diff --git a/input.jpg b/input.jpg new file mode 100644 index 0000000..0426a63 Binary files /dev/null and b/input.jpg differ diff --git a/src/aiunn/inference.py b/src/aiunn/inference.py index c86905b..09b1098 100644 --- a/src/aiunn/inference.py +++ b/src/aiunn/inference.py @@ -1,73 +1,136 @@ import torch +from albumentations import Compose, Normalize +from albumentations.pytorch import ToTensorV2 from PIL import Image -import torchvision.transforms as T -from torch.nn import functional as F -from aiia.model import AIIABase +import numpy as np +import io +from torch import nn +from aiia import AIIABase -class UpScaler: - def __init__(self, model_path="aiuNN-finetuned", device="cuda"): +class Upscaler(nn.Module): + """ + Transforms the base model's final feature map using a transposed convolution. + The base model produces a feature map of size 512x512. + This layer upsamples by a factor of 2 (yielding 1024x1024) and maps the hidden features + to the output channels using a single ConvTranspose2d layer. + """ + def __init__(self, base_model: AIIABase): + super(Upscaler, self).__init__() + self.base_model = base_model + # Instead of adding separate upsampling and convolutional layers, we use a ConvTranspose2d layer. + self.last_transform = nn.ConvTranspose2d( + in_channels=base_model.config.hidden_size, + out_channels=base_model.config.num_channels, + kernel_size=base_model.config.kernel_size, + stride=2, + padding=1, + output_padding=1 + ) + + def forward(self, x): + features = self.base_model(x) + return self.last_transform(features) + +class ImageUpscaler: + def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'): + """ + Initialize the ImageUpscaler with the trained model. + + Args: + model_path (str): Path to the trained model directory. + device (str): Device to run inference on ('cuda' or 'cpu'). 
+ """ self.device = torch.device(device) - self.model = AIIABase.load(model_path).to(self.device) - self.model.eval() + self.model = self.load_model(model_path) + self.model.eval() # Set the model to evaluation mode - # Preprocessing transforms - self.preprocess = T.Compose([ - T.Lambda(lambda img: self._pad_to_square(img)), - T.Resize(512), - T.ToTensor(), - T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + # Define preprocessing transformations + self.preprocess = Compose([ + Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ToTensorV2() ]) - - def _pad_to_square(self, pil_img): - """Pad image to square while maintaining aspect ratio""" - w, h = pil_img.size - max_side = max(w, h) - hp = (max_side - w) // 2 - vp = (max_side - h) // 2 - padding = (hp, vp, max_side - w - hp, max_side - h - vp) - return T.functional.pad(pil_img, padding, 0, 'constant') - def _remove_padding(self, tensor, original_size): - """Remove padding added during preprocessing""" - _, _, h, w = tensor.shape - orig_w, orig_h = original_size + def load_model(self, model_path: str): + """ + Load the trained model from the specified path. - # Calculate scale factor - scale = 512 / max(orig_w, orig_h) - new_w = int(orig_w * scale) - new_h = int(orig_h * scale) + Args: + model_path (str): Path to the saved model. - # Calculate padding offsets - pad_w = (512 - new_w) // 2 - pad_h = (512 - new_h) // 2 + Returns: + nn.Module: Loaded PyTorch model. + """ + # Load the base model and wrap it with Upscaler + base_model = AIIABase.load(model_path) + model = Upscaler(base_model) - # Remove padding - unpad = tensor[:, :, pad_h:pad_h+new_h, pad_w:pad_w+new_w] + # Move the model to the appropriate device + return model.to(self.device) + + def preprocess_image(self, image: Image.Image): + """ + Preprocess the input image for inference. 
- # Resize to target 2x resolution - return F.interpolate(unpad, size=(orig_h*2, orig_w*2), mode='bilinear', align_corners=False) - - def upscale(self, input_image): - # Preprocess - original_size = input_image.size - input_tensor = self.preprocess(input_image).unsqueeze(0).to(self.device) + Args: + image (PIL.Image.Image): Input image in PIL format. - # Inference + Returns: + torch.Tensor: Preprocessed image tensor. + """ + # Convert PIL image to numpy array + image_array = np.array(image) + + # Apply preprocessing transformations + augmented = self.preprocess(image=image_array) + + # Add batch dimension and move to device + return augmented['image'].unsqueeze(0).to(self.device) + + def postprocess_image(self, output_tensor: torch.Tensor): + """ + Postprocess the output tensor to convert it back to an image. + + Args: + output_tensor (torch.Tensor): Model output tensor. + + Returns: + PIL.Image.Image: Upscaled image in PIL format. + """ + # Remove batch dimension and move to CPU + output_tensor = output_tensor.squeeze(0).cpu() + + # Denormalize and convert to numpy array + output_array = (output_tensor * 0.5 + 0.5).clamp(0, 1).numpy() + + # Convert from CHW (Channels-Height-Width) to HWC (Height-Width-Channels) format + output_array = (output_array.transpose(1, 2, 0) * 255).astype(np.uint8) + + # Convert numpy array back to PIL image + return Image.fromarray(output_array) + + def upscale_image(self, input_image_path: str): + """ + Perform upscaling on an input image. + + Args: + input_image_path (str): Path to the input low-resolution image. + + Returns: + PIL.Image.Image: Upscaled high-resolution image. 
+ """ + # Load and preprocess the input image + input_image = Image.open(input_image_path).convert('RGB') + preprocessed_image = self.preprocess_image(input_image) + + # Perform inference with the model with torch.no_grad(): - features = self.model.cnn(input_tensor) - output = self.model.upsample(features) + with torch.amp.autocast(device_type=self.device.type): + output_tensor = self.model(preprocessed_image) - # Postprocess - output = self._remove_padding(output, original_size) - - # Convert to PIL Image - output = output.squeeze(0).cpu().detach() - output = (output * 0.5 + 0.5).clamp(0, 1) - return T.functional.to_pil_image(output) + # Postprocess and return the upscaled image + return self.postprocess_image(output_tensor) -# Usage example -if __name__ == "__main__": - upscaler = UpScaler() - input_image = Image.open("input.jpg") - output_image = upscaler.upscale(input_image) - output_image.save("output_2x.jpg") +# Example usage: +# upscaler = ImageUpscaler(model_path="/path/to/best_model") +# upscaled_image = upscaler.upscale_image("/path/to/low_res_image.jpg") +# upscaled_image.save("/path/to/upscaled_image.jpg")