"""Inference utilities for the aiuNN upsampling model.

This module is the post-image of ``src/aiunn/inference/inference.py`` from
the patch "added aiunn inference" (commit
2e1556d30651f5fdb1e96f9b69ecaa1b97497b55, Falko Habel, 2025-02-25).  It
replaces the earlier ``Upscaler`` / ``ImageUpscaler`` implementation that
was built on ``AIIABase`` and albumentations.

The same patch also touches two package files:

* ``src/aiunn/__init__.py`` gains
  ``from .inference.inference import aiuNNInference``.
* ``src/aiunn/inference/__init__.py`` gains
  ``from .inference import aiuNNInference`` and
  ``__all__ = ["aiuNNInference"]``.
"""

import io
import os
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image

from ..upsampler.aiunn import aiuNN


class aiuNNInference:
    """
    Inference class for aiuNN upsampling model.

    Handles model loading, image upscaling, and output processing.
    """

    def __init__(self, model_path: str, precision: Optional[str] = None, device: Optional[str] = None):
        """
        Initialize the inference class by loading the aiuNN model.

        Args:
            model_path: Path to the saved model directory
            precision: Optional precision setting ('fp16', 'bf16', or None for default)
            device: Optional device specification ('cuda', 'cpu', or None for auto-detection)
        """
        # Fall back to CUDA when it is available, otherwise run on the CPU.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        # Load the model with the requested precision and switch to eval mode.
        self.model = aiuNN.load(model_path, precision=precision)
        self.model.to(self.device)
        self.model.eval()

        # Store configuration for reference
        self.config = self.model.config

    def _model_dtype(self) -> torch.dtype:
        """Return the floating-point dtype of the model weights.

        Falls back to ``torch.float32`` when the model exposes no
        floating-point parameter.
        """
        for param in self.model.parameters():
            if param.is_floating_point():
                return param.dtype
        return torch.float32

    @staticmethod
    def _to_channels_first(array: np.ndarray) -> np.ndarray:
        """Return *array* in CHW (or NCHW) layout.

        A 3-D array is treated as CHW when its first axis has length 3 and
        as HWC when only its last axis has length 3.  The same rule is
        applied to 4-D arrays (NCHW vs. NHWC).  Anything else is returned
        unchanged.
        """
        if array.ndim == 3 and array.shape[0] != 3 and array.shape[-1] == 3:
            return array.transpose(2, 0, 1)
        if array.ndim == 4 and array.shape[1] != 3 and array.shape[-1] == 3:
            return array.transpose(0, 3, 1, 2)
        return array

    def preprocess_image(
        self,
        image: Union[str, os.PathLike, Image.Image, np.ndarray, torch.Tensor],
    ) -> torch.Tensor:
        """
        Preprocess the input image to match model requirements.

        Args:
            image: Input image as file path, PIL Image, numpy array, or torch tensor

        Returns:
            Float tensor with a leading batch dimension, scaled to [0, 1],
            placed on ``self.device`` and cast to the dtype of the model
            weights.

        Raises:
            TypeError: If *image* is none of the supported input types.
            ValueError: If *image* contains no elements.
        """
        if isinstance(image, (str, os.PathLike)):
            # convert() returns a fully loaded copy, so the file handle can
            # be released as soon as the with-block ends.
            with Image.open(image) as opened:
                image = opened.convert('RGB')

        if isinstance(image, Image.Image):
            # Path inputs are always forced to RGB; do the same for PIL
            # inputs so grayscale / RGBA / palette images do not break the
            # HWC -> CHW conversion below.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = np.array(image)

        if isinstance(image, np.ndarray):
            image = torch.from_numpy(self._to_channels_first(image))

        if not isinstance(image, torch.Tensor):
            raise TypeError(
                "image must be a file path, PIL Image, numpy array or torch tensor, "
                f"got {type(image).__name__}"
            )
        if image.numel() == 0:
            raise ValueError("image is empty")

        # uint8 data is always 8-bit pixel data and must be rescaled even
        # when the picture is so dark that its maximum is <= 1.  For every
        # other dtype keep the value-range heuristic.
        is_8bit = image.dtype == torch.uint8
        image = image.float()
        if is_8bit or image.max() > 1.0:
            image = image / 255.0

        # Add batch dimension if not present
        if image.dim() == 3:
            image = image.unsqueeze(0)

        # Match the weights' device and dtype; a float32 input fed to
        # half-precision weights would otherwise raise a dtype mismatch.
        return image.to(device=self.device, dtype=self._model_dtype())

    def postprocess_tensor(self, tensor: torch.Tensor) -> Image.Image:
        """
        Convert output tensor to PIL Image.

        Args:
            tensor: Output tensor from model, either CHW or 1xCHW, with
                values expected in [0, 1]

        Returns:
            Processed PIL Image

        Raises:
            ValueError: If *tensor* is a batch of more than one image or is
                not 3-D after the batch axis has been removed.
        """
        # Cast to float32 first: numpy has no bfloat16, so calling .numpy()
        # on a bf16 tensor fails.
        output = tensor.detach().to(device='cpu', dtype=torch.float32)

        if output.dim() == 4:
            if output.shape[0] != 1:
                raise ValueError(
                    f"expected a single image, got a batch of {output.shape[0]}"
                )
            output = output[0]
        if output.dim() != 3:
            raise ValueError(
                f"expected a CHW tensor, got shape {tuple(tensor.shape)}"
            )

        # Map non-finite values into range, then scale to [0, 255].  Round
        # instead of truncating so that e.g. 254.9999 becomes 255, not 254.
        output = torch.nan_to_num(output, nan=0.0, posinf=1.0, neginf=0.0)
        pixels = output.clamp(0.0, 1.0).mul(255.0).round().to(torch.uint8)

        # Convert from CHW to HWC for PIL
        array = pixels.permute(1, 2, 0).contiguous().numpy()
        if array.shape[-1] == 1:
            # PIL expects a 2-D array for single-channel images.
            array = array[..., 0]

        return Image.fromarray(array)

    @torch.no_grad()
    def upscale(
        self,
        image: Union[str, os.PathLike, Image.Image, np.ndarray, torch.Tensor],
    ) -> Image.Image:
        """
        Upscale an image using the aiuNN model.

        Args:
            image: Input image to upscale

        Returns:
            Upscaled image as PIL Image
        """
        input_tensor = self.preprocess_image(image)
        output_tensor = self.model(input_tensor)
        return self.postprocess_tensor(output_tensor)

    def save(self, image: Image.Image, output_path: str, format: Optional[str] = None) -> None:
        """
        Save the upscaled image to a file.

        Args:
            image: PIL Image to save
            output_path: Path where the image should be saved
            format: Optional format override (e.g., 'PNG', 'JPEG')
        """
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        image.save(output_path, format=format)

    def convert_to_binary(self, image: Image.Image, format: str = 'PNG') -> bytes:
        """
        Convert the image to binary data.

        Args:
            image: PIL Image to convert
            format: Image format to use for binary conversion

        Returns:
            Binary representation of the image
        """
        with io.BytesIO() as buffer:
            image.save(buffer, format=format)
            return buffer.getvalue()

    def process_batch(self,
                      images: List[Union[str, Image.Image]],
                      output_dir: Optional[str] = None,
                      save_format: str = 'PNG',
                      return_binary: bool = False) -> Union[List[Image.Image], List[bytes], None]:
        """
        Process multiple images one after another.

        Args:
            images: List of input images (paths or PIL Images)
            output_dir: Optional directory to save results
            save_format: Format to use when saving images
            return_binary: Whether to return binary data instead of PIL Images

        Returns:
            ``None`` when ``output_dir`` is given and ``return_binary`` is
            false (results are only written to disk).  Otherwise a list of
            encoded images if ``return_binary`` is true, or a list of PIL
            Images if it is false.
        """
        only_saving = bool(output_dir) and not return_binary
        results = []
        used_stems = set()

        for index, img in enumerate(images):
            upscaled = self.upscale(img)

            if output_dir:
                # Reuse the input file name when there is one.
                if isinstance(img, (str, os.PathLike)):
                    base, _ = os.path.splitext(os.path.basename(os.fspath(img)))
                else:
                    base = f"upscaled_{index}"

                # Inputs from different directories may share a basename;
                # disambiguate so that a later result cannot overwrite an
                # earlier one.
                stem = base
                suffix = 1
                while stem in used_stems:
                    stem = f"{base}_{suffix}"
                    suffix += 1
                used_stems.add(stem)

                output_path = os.path.join(output_dir, f"{stem}.{save_format.lower()}")
                self.save(upscaled, output_path, format=save_format)

            if only_saving:
                # Nothing is returned in this mode, so do not keep every
                # upscaled image alive in memory.
                continue

            if return_binary:
                results.append(self.convert_to_binary(upscaled, format=save_format))
            else:
                results.append(upscaled)

        return None if only_saving else results


def _example() -> None:
    """Demonstrate the public API with placeholder paths."""
    # Initialize inference with a model path
    inferencer = aiuNNInference("path/to/model", precision="bf16")

    # Upscale a single image
    upscaled_image = inferencer.upscale("input_image.jpg")

    # Save the result
    inferencer.save(upscaled_image, "output_image.png")

    # Convert to binary
    binary_data = inferencer.convert_to_binary(upscaled_image)

    # Process a batch of images
    inferencer.process_batch(
        ["image1.jpg", "image2.jpg"],
        output_dir="output_folder",
        save_format="PNG"
    )


# Example usage (can be removed).  The module uses a relative import, so it
# has to be started as ``python -m aiunn.inference.inference``.
if __name__ == "__main__":
    _example()