aiuNN/src/aiunn/inference.py

95 lines
3.5 KiB
Python

import torch
from albumentations import Compose, Normalize
from albumentations.pytorch import ToTensorV2
from PIL import Image
import numpy as np
import io
from torch import nn
from aiia import AIIABase
class Upscaler(nn.Module):
    """
    Wrap a base model with a single transposed-convolution output head.

    The wrapped model's output is treated as a feature map with
    ``config.hidden_size`` channels. One ``nn.ConvTranspose2d`` layer both
    upsamples that map (stride 2) and projects it to ``config.num_channels``
    channels, so no separate upsampling and convolution layers are needed.

    With ``stride=2``, ``padding=1`` and ``output_padding=1`` an input of
    spatial size ``H`` comes out at ``2 * H + kernel_size - 3`` (for an
    integer kernel size), i.e. exactly doubled when ``config.kernel_size``
    is 3.

    NOTE(review): the original author states the base model emits 512x512
    feature maps (giving 1024x1024 output); that is not verifiable from
    this file -- confirm against the base model.
    """

    def __init__(self, base_model: AIIABase):
        """
        Build the upscaling head from the base model's configuration.

        Args:
            base_model: Model exposing ``config.hidden_size``,
                ``config.num_channels`` and ``config.kernel_size``.
        """
        super().__init__()
        self.base_model = base_model
        cfg = base_model.config
        # A transposed convolution upsamples and remaps channels in one step.
        self.last_transform = nn.ConvTranspose2d(
            cfg.hidden_size,
            cfg.num_channels,
            cfg.kernel_size,
            stride=2,
            padding=1,
            output_padding=1,
        )

    def forward(self, x):
        """Run the base model on ``x`` and return its upsampled output."""
        return self.last_transform(self.base_model(x))
class ImageUpscaler:
    """
    Inference helper: load an ``Upscaler`` model and apply it to images.

    Input images are normalized per channel with mean 0.5 and std 0.5 before
    being fed to the model; ``postprocess_image`` applies the inverse mapping
    (``x * 0.5 + 0.5``) to turn the model output back into an 8-bit image.
    """

    def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Load the model and set up preprocessing.

        Args:
            model_path: Location passed to ``AIIABase.load``.
            device: Torch device string. The default is chosen once, when
                the class is defined: 'cuda' if available, otherwise 'cpu'.
        """
        self.device = torch.device(device)
        self.model = self.load_model(model_path)
        self.model.eval()  # Set to evaluation mode
        # Preprocessing: normalize, then convert HWC numpy -> CHW tensor.
        self.preprocess = Compose([
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2()
        ])

    def load_model(self, model_path: str):
        """
        Load the base model from ``model_path`` and wrap it in ``Upscaler``.

        Returns:
            The wrapped model, moved to ``self.device``.
        """
        base_model = AIIABase.load(model_path)  # Load base model
        # NOTE(review): only the base model is restored here. The
        # ``last_transform`` layer created by ``Upscaler`` keeps its fresh
        # initialization unless its weights are loaded elsewhere -- confirm
        # how the trained upscaling head is supposed to be restored.
        model = Upscaler(base_model)  # Wrap with Upscaler
        return model.to(self.device)

    def preprocess_image(self, image: Image.Image):
        """
        Turn a PIL image into a normalized, batched tensor on ``self.device``.

        Args:
            image: Input image. Non-RGB images are converted to RGB first.

        Returns:
            The preprocessed tensor with a leading batch dimension of 1.

        Raises:
            ValueError: If ``image`` is not a ``PIL.Image.Image``.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL.Image.Image object")
        # The three-value mean/std above only fit three-channel input, so
        # grayscale/RGBA/palette images are converted instead of failing.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_array = np.array(image)
        augmented = self.preprocess(image=image_array)
        # Add batch dimension and move to device
        return augmented['image'].unsqueeze(0).to(self.device)

    def postprocess_image(self, output_tensor: torch.Tensor):
        """
        Convert a model output tensor back into a PIL image.

        Args:
            output_tensor: Model output; the leading batch dimension (if of
                size 1) is removed and the rest is treated as CHW.

        Returns:
            An 8-bit ``PIL.Image.Image``.
        """
        # Detach and upcast: the tensor may be float16 (autocast) or still
        # attached to a graph, neither of which converts cleanly to numpy.
        chw = output_tensor.squeeze(0).detach().cpu().float()
        # Undo the mean/std 0.5 normalization and scale to [0, 255].
        scaled = (chw * 0.5 + 0.5).clamp(0, 1) * 255
        # Round before the integer cast; a bare cast would truncate and
        # bias every pixel downwards.
        pixels = scaled.round().numpy().transpose(1, 2, 0).astype(np.uint8)  # CHW -> HWC
        return Image.fromarray(pixels)

    def upscale_image(self, input_image_path: str):
        """
        Upscale the image stored at ``input_image_path``.

        Args:
            input_image_path: Path of the image file to read.

        Returns:
            The upscaled image as a ``PIL.Image.Image``.
        """
        # Close the file handle promptly; ``convert`` returns a loaded copy.
        with Image.open(input_image_path) as source:
            input_image = source.convert('RGB')  # Ensure RGB format
        preprocessed_image = self.preprocess_image(input_image)

        with torch.no_grad():
            if self.device.type == "cuda":
                # Mixed precision is only requested when actually on CUDA.
                with torch.amp.autocast(device_type="cuda"):
                    output_tensor = self.model(preprocessed_image)
            else:
                output_tensor = self.model(preprocessed_image)

        return self.postprocess_image(output_tensor)
# Example usage:
# NOTE(review): these statements are at module top level, so they execute
# whenever this file is run or imported: a model is loaded from a hard-coded
# absolute path, one image is upscaled, and the result is written to
# "upscaled_image.jpg" relative to the current working directory. Consider
# moving them under an `if __name__ == "__main__":` guard.
upscaler = ImageUpscaler(model_path="/root/vision/aiuNN/best_model")
upscaled_image = upscaler.upscale_image("/root/vision/aiuNN/input.jpg")
upscaled_image.save("upscaled_image.jpg")