# aiuNN/src/aiunn/inference.py

import torch
from PIL import Image
import torchvision.transforms as T
from torch.nn import functional as F
from aiia.model import AIIABase
class UpScaler:
    """Upscale PIL images with a pretrained AIIA model.

    Each input is padded to a square, resized to ``INPUT_SIZE`` pixels per
    side, normalised to [-1, 1] and passed through ``model.cnn`` followed by
    ``model.upsample``. The square padding is then cropped off the network
    output, and the result is resized to ``SCALE_FACTOR`` times the original
    image size.
    """

    # Side length (pixels) of the square image fed to the network.
    INPUT_SIZE = 512
    # Ratio between the returned image size and the input image size.
    SCALE_FACTOR = 2

    def __init__(self, model_path="AIIA-base-512-upscaler", device="cuda"):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_path: Path or identifier handed to ``AIIABase.load``.
            device: Torch device string; the model and every input batch
                are moved to this device.
        """
        self.device = torch.device(device)
        self.model = AIIABase.load(model_path).to(self.device)
        self.model.eval()
        self.preprocess = T.Compose([
            T.Lambda(self._pad_to_square),
            # The image is already square here, so an explicit (size, size)
            # target equals Resize(size) but states the intent directly.
            T.Resize((self.INPUT_SIZE, self.INPUT_SIZE)),
            T.ToTensor(),
            # Maps [0, 1] to [-1, 1]; inverted again in ``upscale``.
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    @staticmethod
    def _square_padding(width, height):
        """Return ``(left, top, right, bottom)`` padding that centres a
        ``width`` x ``height`` image in a square of side ``max(width, height)``.

        When the difference is odd, the extra pixel goes to the right/bottom.
        Shared by padding and un-padding so both use identical geometry.
        """
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        return left, top, side - width - left, side - height - top

    def _pad_to_square(self, pil_img):
        """Pad ``pil_img`` with constant fill 0 so it becomes square.

        The image content stays centred. Padding is ``(left, top, right,
        bottom)`` as computed by ``_square_padding``.
        """
        w, h = pil_img.size
        padding = self._square_padding(w, h)
        return T.functional.pad(pil_img, padding, fill=0, padding_mode='constant')

    def _remove_padding(self, tensor, original_size):
        """Crop the square padding off a network output and resize it.

        Args:
            tensor: Batched image tensor ``(N, C, H, W)`` produced from the
                padded square input. H and W may be any size (the network
                input size or a multiple of it). Crop offsets are scaled to
                the tensor's actual dimensions, not assumed to be 512.
            original_size: ``(width, height)`` of the unpadded input image,
                as given by ``PIL.Image.size``.

        Returns:
            Tensor of shape
            ``(N, C, height * SCALE_FACTOR, width * SCALE_FACTOR)``.
        """
        _, _, out_h, out_w = tensor.shape
        orig_w, orig_h = original_size
        left, top, _, _ = self._square_padding(orig_w, orig_h)
        side = max(orig_w, orig_h)

        # Map coordinates in the padded square onto the output grid.
        scale_y = out_h / side
        scale_x = out_w / side
        y0 = round(top * scale_y)
        x0 = round(left * scale_x)
        # Never produce an empty crop, even for extreme aspect ratios.
        crop_h = max(1, round(orig_h * scale_y))
        crop_w = max(1, round(orig_w * scale_x))
        content = tensor[:, :, y0:y0 + crop_h, x0:x0 + crop_w]

        target = (orig_h * self.SCALE_FACTOR, orig_w * self.SCALE_FACTOR)
        return F.interpolate(content, size=target, mode='bilinear', align_corners=False)

    def upscale(self, input_image):
        """Upscale ``input_image`` and return the result as a PIL image.

        Args:
            input_image: A PIL image in any mode. It is converted to RGB first
                so that grayscale, palette and RGBA inputs match the
                three-channel normalisation.

        Returns:
            An RGB PIL image of size ``(width * SCALE_FACTOR,
            height * SCALE_FACTOR)``.
        """
        original_size = input_image.size
        rgb_image = input_image.convert("RGB")
        batch = self.preprocess(rgb_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            features = self.model.cnn(batch)
            output = self.model.upsample(features)
            output = self._remove_padding(output, original_size)

        # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1].
        image = (output.squeeze(0).cpu() * 0.5 + 0.5).clamp(0, 1)
        return T.functional.to_pil_image(image)
# Usage example: upscale input.jpg and write the result to output_2x.jpg.
if __name__ == "__main__":
    upscaler = UpScaler()
    # Open the source inside a context manager so its file handle is closed
    # once the image has been processed. The upscaled image is a new object,
    # so it can be saved after the source is closed.
    with Image.open("input.jpg") as input_image:
        output_image = upscaler.upscale(input_image)
    output_image.save("output_2x.jpg")