import io

import numpy as np
import torch
from albumentations import Compose, Normalize
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch import nn

from aiia import AIIABase


class Upscaler(nn.Module):
    """Upsample the base model's output feature map with one transposed convolution.

    ``base_model`` is run first. Its output goes through a single
    ``nn.ConvTranspose2d`` with these settings:

    - in channels: ``config.hidden_size``
    - out channels: ``config.num_channels``
    - kernel size: ``config.kernel_size``
    - stride 2, padding 1, output_padding 1

    For a feature map of spatial size ``H`` the output size is
    ``2 * H + kernel_size - 3``. That is an exact 2x upscale only when
    ``config.kernel_size == 3``.

    NOTE(review): assumes ``base_model(x)`` returns an (N, hidden_size, H, W)
    tensor. This is not visible here; confirm against ``AIIABase``.
    """

    def __init__(self, base_model: AIIABase):
        super().__init__()
        self.base_model = base_model
        # A single ConvTranspose2d does both the 2x spatial upsampling and the
        # hidden -> output channel projection.
        self.last_transform = nn.ConvTranspose2d(
            in_channels=base_model.config.hidden_size,
            out_channels=base_model.config.num_channels,
            kernel_size=base_model.config.kernel_size,
            stride=2,
            padding=1,
            output_padding=1,
        )

    def forward(self, x):
        """Run the base model, then upsample its features."""
        features = self.base_model(x)
        return self.last_transform(features)


class ImageUpscaler:
    """Load an :class:`Upscaler` and apply it to individual images.

    Input pixels are normalized to [-1, 1] before inference. Model output is
    mapped back from [-1, 1] to 8-bit pixels.
    """

    def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """Load the model from ``model_path`` onto ``device`` in eval mode.

        Args:
            model_path: Path passed to ``AIIABase.load``.
            device: Torch device string. Defaults to ``'cuda'`` when CUDA is
                available at import time, otherwise ``'cpu'``.
        """
        self.device = torch.device(device)
        self.model = self.load_model(model_path)
        self.model.eval()  # disable dropout / use running batch-norm stats

        # With albumentations' default max_pixel_value=255 this computes
        # (x / 255 - 0.5) / 0.5, i.e. uint8 [0, 255] -> float [-1, 1].
        # ToTensorV2 then converts HWC -> CHW.
        self.preprocess = Compose([
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2(),
        ])

    def load_model(self, model_path: str):
        """Load the base model and wrap it in an :class:`Upscaler` on ``self.device``.

        NOTE(review): only the base model's weights are loaded here. The
        ``Upscaler.last_transform`` layer is freshly (randomly) initialized on
        every load. If the trained transposed-conv weights were saved, they need
        to be restored separately. Confirm how the checkpoint was written.
        """
        base_model = AIIABase.load(model_path)
        model = Upscaler(base_model)
        return model.to(self.device)

    def preprocess_image(self, image: Image.Image):
        """Convert a PIL image into a normalized (1, 3, H, W) tensor on ``self.device``.

        Non-RGB images (grayscale, palette, RGBA, ...) are converted to RGB
        first, because the normalization expects exactly three channels.

        Raises:
            ValueError: If ``image`` is not a ``PIL.Image.Image``.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL.Image.Image object")
        if image.mode != "RGB":
            image = image.convert("RGB")

        image_array = np.array(image)  # HWC, uint8
        augmented = self.preprocess(image=image_array)
        # Add a batch dimension and move to the inference device.
        return augmented['image'].unsqueeze(0).to(self.device)

    def postprocess_image(self, output_tensor: torch.Tensor):
        """Convert a model output tensor in [-1, 1] back to a PIL image.

        Expects a (1, C, H, W) or (C, H, W) tensor. Single-channel output is
        returned as a grayscale image.
        """
        # Drop the batch dim, move to CPU, and upcast: under autocast the
        # output may be float16, which loses precision in the arithmetic below.
        chw = output_tensor.squeeze(0).detach().cpu().float()
        # Invert Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1] -> [0, 255].
        # Round rather than truncate so values like 254.9 map to 255.
        pixels = (chw * 0.5 + 0.5).clamp(0, 1).mul(255).round().to(torch.uint8)
        hwc = pixels.permute(1, 2, 0).contiguous().numpy()  # CHW -> HWC
        if hwc.shape[-1] == 1:
            # PIL cannot build an image from an (H, W, 1) array.
            hwc = hwc[..., 0]
        return Image.fromarray(hwc)

    def upscale_image(self, input_image_path: str):
        """Load the image at ``input_image_path``, upscale it and return a PIL image."""
        # The context manager closes the underlying file. convert() returns an
        # independent in-memory copy, so it stays valid after the block.
        with Image.open(input_image_path) as img:
            input_image = img.convert('RGB')
        preprocessed_image = self.preprocess_image(input_image)

        # Mixed precision only on CUDA; on other devices autocast is disabled
        # instead of warning about an unavailable 'cuda' device type.
        use_amp = self.device.type == "cuda"
        with torch.no_grad(), torch.amp.autocast(device_type=self.device.type, enabled=use_amp):
            output_tensor = self.model(preprocessed_image)
        return self.postprocess_image(output_tensor)


if __name__ == "__main__":
    # Example usage. Guarded so that importing this module does not load a
    # model or touch the filesystem.
    upscaler = ImageUpscaler(model_path="/root/vision/aiuNN/best_model")
    upscaled_image = upscaler.upscale_image("/root/vision/aiuNN/input.jpg")
    upscaled_image.save("upscaled_image.jpg")