import torch
from PIL import Image
import torchvision.transforms as T
from torch.nn import functional as F
from aiia.model import AIIABase


class UpScaler:
    """Upscale a PIL image to twice its resolution with an ``AIIABase`` model.

    The image is padded to a square, resized to ``512 x 512``, normalized to
    ``[-1, 1]`` and passed through ``model.cnn`` followed by
    ``model.upsample``.  The padded border is then cropped away and the
    result is resized to exactly ``(2 * height, 2 * width)`` of the input.
    """

    # Side length (pixels) of the square image fed to the model.
    _INPUT_SIZE = 512

    def __init__(self, model_path="AIIA-base-512-upscaler", device="cuda"):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_path: Identifier or path handed to ``AIIABase.load``.
            device: Anything accepted by ``torch.device`` (e.g. ``"cuda"``,
                ``"cpu"``).
        """
        self.device = torch.device(device)
        self.model = AIIABase.load(model_path).to(self.device)
        self.model.eval()

        # Preprocessing: pad to square -> resize -> tensor -> scale to [-1, 1].
        self.preprocess = T.Compose([
            T.Lambda(self._pad_to_square),
            T.Resize(self._INPUT_SIZE),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def _pad_to_square(self, pil_img):
        """Pad ``pil_img`` with black borders so that it becomes square.

        The original content stays centered; when the required padding is
        odd, the extra pixel goes to the right/bottom edge.
        """
        w, h = pil_img.size
        max_side = max(w, h)
        hp = (max_side - w) // 2
        vp = (max_side - h) // 2
        # Order is (left, top, right, bottom).
        padding = (hp, vp, max_side - w - hp, max_side - h - vp)
        return T.functional.pad(pil_img, padding, 0, 'constant')

    def _remove_padding(self, tensor, original_size):
        """Crop away the square padding and resize to 2x the original size.

        Args:
            tensor: 4-D tensor indexed as ``(batch, channel, height, width)``
                holding the (padded, square) model output.
            original_size: ``(width, height)`` of the image before padding,
                as returned by ``PIL.Image.Image.size``.

        Returns:
            Tensor of spatial size ``(2 * orig_h, 2 * orig_w)``.
        """
        _, _, h, w = tensor.shape
        orig_w, orig_h = original_size
        max_side = max(orig_w, orig_h)

        # Extent of the real image content inside the padded tensor.
        # Integer arithmetic is used on purpose: ``int(orig * (size / max))``
        # can truncate the longest side to ``size - 1`` because of float
        # rounding (e.g. max side 49).  Using the tensor's own spatial size
        # instead of a fixed 512 keeps the crop correct for any output size.
        # ``max(1, ...)`` keeps very elongated images from producing an empty
        # crop, which ``F.interpolate`` cannot handle.
        new_w = max(1, orig_w * w // max_side)
        new_h = max(1, orig_h * h // max_side)

        # The content was centered by ``_pad_to_square``.
        pad_w = (w - new_w) // 2
        pad_h = (h - new_h) // 2

        unpad = tensor[:, :, pad_h:pad_h + new_h, pad_w:pad_w + new_w]

        # Resize to the target 2x resolution.
        return F.interpolate(
            unpad,
            size=(orig_h * 2, orig_w * 2),
            mode='bilinear',
            align_corners=False,
        )

    def upscale(self, input_image):
        """Return a 2x upscaled copy of ``input_image``.

        Args:
            input_image: A ``PIL.Image.Image``.  Images that are not in RGB
                mode (grayscale, palette, RGBA, ...) are converted to RGB
                first because the normalization step is defined for exactly
                three channels; any alpha channel is dropped.

        Returns:
            A ``PIL.Image.Image`` of size ``(2 * width, 2 * height)``.
        """
        original_size = input_image.size
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")

        # Preprocess and add the batch dimension.
        input_tensor = self.preprocess(input_image).unsqueeze(0).to(self.device)

        # Inference
        with torch.no_grad():
            features = self.model.cnn(input_tensor)
            output = self.model.upsample(features)

        # Postprocess
        output = self._remove_padding(output, original_size)

        # Undo the [-1, 1] normalization and convert to a PIL image.
        output = output.squeeze(0).detach().cpu()
        output = (output * 0.5 + 0.5).clamp(0, 1)
        return T.functional.to_pil_image(output)


# Usage example
if __name__ == "__main__":
    upscaler = UpScaler()
    # The context manager closes the underlying file once the pixel data
    # has been consumed by ``upscale``.
    with Image.open("input.jpg") as input_image:
        output_image = upscaler.upscale(input_image)
    output_image.save("output_2x.jpg")