develop #4
|
@ -1,5 +1,6 @@
|
||||||
from .finetune.trainer import aiuNNTrainer
|
from .finetune.trainer import aiuNNTrainer
|
||||||
from .upsampler.aiunn import aiuNN
|
from .upsampler.aiunn import aiuNN
|
||||||
from .upsampler.config import aiuNNConfig
|
from .upsampler.config import aiuNNConfig
|
||||||
|
from .inference.inference import aiuNNInference
|
||||||
|
|
||||||
__version__ = "0.1.1"
|
__version__ = "0.1.1"
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .inference import aiuNNInference
|
||||||
|
|
||||||
|
__all__ = ["aiuNNInference"]
|
|
@ -1,96 +1,226 @@
|
||||||
|
import os
|
||||||
import torch
|
import torch
|
||||||
from albumentations import Compose, Normalize
|
|
||||||
from albumentations.pytorch import ToTensorV2
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
from torch import nn
|
from typing import Union, Optional, Tuple, List
|
||||||
from aiia import AIIABase
|
from ..upsampler.aiunn import aiuNN
|
||||||
|
|
||||||
|
|
||||||
class Upscaler(nn.Module):
|
class aiuNNInference:
|
||||||
"""
|
"""
|
||||||
Transforms the base model's final feature map using a transposed convolution.
|
Inference class for aiuNN upsampling model.
|
||||||
The base model produces a feature map of size 512x512.
|
Handles model loading, image upscaling, and output processing.
|
||||||
This layer upsamples by a factor of 2 (yielding 1024x1024) and maps the hidden features
|
|
||||||
to the output channels using a single ConvTranspose2d layer.
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, base_model: AIIABase):
|
def __init__(self, model_path: str, precision: Optional[str] = None, device: Optional[str] = None):
|
||||||
super(Upscaler, self).__init__()
|
"""
|
||||||
self.base_model = base_model
|
Initialize the inference class by loading the aiuNN model.
|
||||||
# Instead of adding separate upsampling and convolutional layers, we use a ConvTranspose2d layer.
|
|
||||||
self.last_transform = nn.ConvTranspose2d(
|
Args:
|
||||||
in_channels=base_model.config.hidden_size,
|
model_path: Path to the saved model directory
|
||||||
out_channels=base_model.config.num_channels,
|
precision: Optional precision setting ('fp16', 'bf16', or None for default)
|
||||||
kernel_size=base_model.config.kernel_size,
|
device: Optional device specification ('cuda', 'cpu', or None for auto-detection)
|
||||||
stride=2,
|
"""
|
||||||
padding=1,
|
|
||||||
output_padding=1
|
|
||||||
|
# Set device
|
||||||
|
if device is None:
|
||||||
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
else:
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# Load the model with specified precision
|
||||||
|
self.model = aiuNN.load(model_path, precision=precision)
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
# Store configuration for reference
|
||||||
|
self.config = self.model.config
|
||||||
|
|
||||||
|
def preprocess_image(self, image: Union[str, Image.Image, np.ndarray, torch.Tensor]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Preprocess the input image to match model requirements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input image as file path, PIL Image, numpy array, or torch tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preprocessed tensor ready for model input
|
||||||
|
"""
|
||||||
|
# Handle different input types
|
||||||
|
if isinstance(image, str):
|
||||||
|
# Load from file path
|
||||||
|
image = Image.open(image).convert('RGB')
|
||||||
|
|
||||||
|
if isinstance(image, Image.Image):
|
||||||
|
# Convert PIL Image to tensor
|
||||||
|
image = np.array(image)
|
||||||
|
image = image.transpose(2, 0, 1) # HWC to CHW
|
||||||
|
image = torch.from_numpy(image).float()
|
||||||
|
|
||||||
|
if isinstance(image, np.ndarray):
|
||||||
|
# Convert numpy array to tensor
|
||||||
|
if image.shape[0] == 3:
|
||||||
|
# Already in CHW format
|
||||||
|
pass
|
||||||
|
elif image.shape[-1] == 3:
|
||||||
|
# HWC to CHW format
|
||||||
|
image = image.transpose(2, 0, 1)
|
||||||
|
image = torch.from_numpy(image).float()
|
||||||
|
|
||||||
|
# Normalize to [0, 1] range if needed
|
||||||
|
if image.max() > 1.0:
|
||||||
|
image = image / 255.0
|
||||||
|
|
||||||
|
# Add batch dimension if not present
|
||||||
|
if len(image.shape) == 3:
|
||||||
|
image = image.unsqueeze(0)
|
||||||
|
|
||||||
|
# Move to device
|
||||||
|
image = image.to(self.device)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def postprocess_tensor(self, tensor: torch.Tensor) -> Image.Image:
|
||||||
|
"""
|
||||||
|
Convert output tensor to PIL Image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: Output tensor from model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed PIL Image
|
||||||
|
"""
|
||||||
|
# Move to CPU and convert to numpy
|
||||||
|
output = tensor.detach().cpu().squeeze(0).numpy()
|
||||||
|
|
||||||
|
# Ensure proper range [0, 255]
|
||||||
|
output = np.clip(output * 255, 0, 255).astype(np.uint8)
|
||||||
|
|
||||||
|
# Convert from CHW to HWC for PIL
|
||||||
|
output = output.transpose(1, 2, 0)
|
||||||
|
|
||||||
|
# Create PIL Image
|
||||||
|
return Image.fromarray(output)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def upscale(self, image: Union[str, Image.Image, np.ndarray, torch.Tensor]) -> Image.Image:
|
||||||
|
"""
|
||||||
|
Upscale an image using the aiuNN model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input image to upscale
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Upscaled image as PIL Image
|
||||||
|
"""
|
||||||
|
# Preprocess input
|
||||||
|
input_tensor = self.preprocess_image(image)
|
||||||
|
|
||||||
|
# Run inference
|
||||||
|
output_tensor = self.model(input_tensor)
|
||||||
|
|
||||||
|
# Postprocess output
|
||||||
|
upscaled_image = self.postprocess_tensor(output_tensor)
|
||||||
|
|
||||||
|
return upscaled_image
|
||||||
|
|
||||||
|
def save(self, image: Image.Image, output_path: str, format: Optional[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
Save the upscaled image to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: PIL Image to save
|
||||||
|
output_path: Path where the image should be saved
|
||||||
|
format: Optional format override (e.g., 'PNG', 'JPEG')
|
||||||
|
"""
|
||||||
|
# Create directory if it doesn't exist
|
||||||
|
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
|
||||||
|
|
||||||
|
# Save the image
|
||||||
|
image.save(output_path, format=format)
|
||||||
|
|
||||||
|
def convert_to_binary(self, image: Image.Image, format: str = 'PNG') -> bytes:
|
||||||
|
"""
|
||||||
|
Convert the image to binary data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: PIL Image to convert
|
||||||
|
format: Image format to use for binary conversion
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Binary representation of the image
|
||||||
|
"""
|
||||||
|
# Use BytesIO to convert to binary
|
||||||
|
binary_output = io.BytesIO()
|
||||||
|
image.save(binary_output, format=format)
|
||||||
|
|
||||||
|
# Get the binary data
|
||||||
|
binary_data = binary_output.getvalue()
|
||||||
|
|
||||||
|
return binary_data
|
||||||
|
|
||||||
|
def process_batch(self,
|
||||||
|
images: List[Union[str, Image.Image]],
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
save_format: str = 'PNG',
|
||||||
|
return_binary: bool = False) -> Union[List[Image.Image], List[bytes], None]:
|
||||||
|
"""
|
||||||
|
Process multiple images in batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: List of input images (paths or PIL Images)
|
||||||
|
output_dir: Optional directory to save results
|
||||||
|
save_format: Format to use when saving images
|
||||||
|
return_binary: Whether to return binary data instead of PIL Images
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of processed images or binary data, or None if only saving
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, img in enumerate(images):
|
||||||
|
# Upscale the image
|
||||||
|
upscaled = self.upscale(img)
|
||||||
|
|
||||||
|
# Save if output directory is provided
|
||||||
|
if output_dir:
|
||||||
|
# Extract filename if input is a path
|
||||||
|
if isinstance(img, str):
|
||||||
|
filename = os.path.basename(img)
|
||||||
|
base, _ = os.path.splitext(filename)
|
||||||
|
else:
|
||||||
|
base = f"upscaled_{i}"
|
||||||
|
|
||||||
|
output_path = os.path.join(output_dir, f"{base}.{save_format.lower()}")
|
||||||
|
self.save(upscaled, output_path, format=save_format)
|
||||||
|
|
||||||
|
# Add to results based on return type
|
||||||
|
if return_binary:
|
||||||
|
results.append(self.convert_to_binary(upscaled, format=save_format))
|
||||||
|
else:
|
||||||
|
results.append(upscaled)
|
||||||
|
|
||||||
|
return results if (not output_dir or return_binary or not save_format) else None
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage (can be removed)
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Initialize inference with a model path
|
||||||
|
inferencer = aiuNNInference("path/to/model", precision="bf16")
|
||||||
|
|
||||||
|
# Upscale a single image
|
||||||
|
upscaled_image = inferencer.upscale("input_image.jpg")
|
||||||
|
|
||||||
|
# Save the result
|
||||||
|
inferencer.save(upscaled_image, "output_image.png")
|
||||||
|
|
||||||
|
# Convert to binary
|
||||||
|
binary_data = inferencer.convert_to_binary(upscaled_image)
|
||||||
|
|
||||||
|
# Process a batch of images
|
||||||
|
inferencer.process_batch(
|
||||||
|
["image1.jpg", "image2.jpg"],
|
||||||
|
output_dir="output_folder",
|
||||||
|
save_format="PNG"
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
features = self.base_model(x)
|
|
||||||
return self.last_transform(features)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageUpscaler:
|
|
||||||
def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
|
|
||||||
self.device = torch.device(device)
|
|
||||||
self.model = self.load_model(model_path)
|
|
||||||
self.model.eval() # Set to evaluation mode
|
|
||||||
|
|
||||||
# Define preprocessing transformations
|
|
||||||
self.preprocess = Compose([
|
|
||||||
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
|
||||||
ToTensorV2()
|
|
||||||
])
|
|
||||||
|
|
||||||
def load_model(self, model_path: str):
|
|
||||||
"""
|
|
||||||
Load the trained model from the specified path.
|
|
||||||
"""
|
|
||||||
base_model = AIIABase.load(model_path) # Load base model
|
|
||||||
model = Upscaler(base_model) # Wrap with Upscaler
|
|
||||||
return model.to(self.device)
|
|
||||||
|
|
||||||
def preprocess_image(self, image: Image.Image):
|
|
||||||
"""
|
|
||||||
Preprocess input image for inference.
|
|
||||||
"""
|
|
||||||
if not isinstance(image, Image.Image):
|
|
||||||
raise ValueError("Input must be a PIL.Image.Image object")
|
|
||||||
|
|
||||||
# Convert to numpy array and apply preprocessing
|
|
||||||
image_array = np.array(image)
|
|
||||||
augmented = self.preprocess(image=image_array)
|
|
||||||
|
|
||||||
# Add batch dimension and move to device
|
|
||||||
return augmented['image'].unsqueeze(0).to(self.device)
|
|
||||||
|
|
||||||
def postprocess_image(self, output_tensor: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Convert output tensor back to an image.
|
|
||||||
"""
|
|
||||||
output_tensor = output_tensor.squeeze(0).cpu() # Remove batch dimension
|
|
||||||
output_array = (output_tensor * 0.5 + 0.5).clamp(0, 1).numpy() * 255
|
|
||||||
output_array = output_array.transpose(1, 2, 0).astype(np.uint8) # CHW -> HWC
|
|
||||||
return Image.fromarray(output_array)
|
|
||||||
|
|
||||||
def upscale_image(self, input_image_path: str):
|
|
||||||
"""
|
|
||||||
Perform upscaling on an input image.
|
|
||||||
"""
|
|
||||||
input_image = Image.open(input_image_path).convert('RGB') # Ensure RGB format
|
|
||||||
preprocessed_image = self.preprocess_image(input_image)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
with torch.amp.autocast(device_type="cuda"):
|
|
||||||
output_tensor = self.model(preprocessed_image)
|
|
||||||
|
|
||||||
return self.postprocess_image(output_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
# Example usage:
|
|
||||||
upscaler = ImageUpscaler(model_path="/root/vision/aiuNN/best_model")
|
|
||||||
upscaled_image = upscaler.upscale_image("/root/vision/aiuNN/input.jpg")
|
|
||||||
upscaled_image.save("upscaled_image.jpg")
|
|
||||||
|
|
Loading…
Reference in New Issue