74 lines
2.5 KiB
Python
74 lines
2.5 KiB
Python
import torch
|
|
from PIL import Image
|
|
import torchvision.transforms as T
|
|
from torch.nn import functional as F
|
|
from aiia.model import AIIABase
|
|
|
|
class UpScaler:
    """2x image upscaler built on an ``AIIABase`` model.

    An input image is padded to a square, resized to ``INPUT_SIZE`` pixels,
    normalized to [-1, 1] and passed through ``model.cnn`` followed by
    ``model.upsample``. The square padding is then cropped away and the
    result is resized to exactly twice the original width and height.
    """

    # Side length (pixels) of the square image fed to the model.
    INPUT_SIZE = 512

    def __init__(self, model_path="AIIA-base-512-upscaler", device="cuda"):
        """Load the model and build the preprocessing pipeline.

        Args:
            model_path: Path or identifier passed to ``AIIABase.load``.
            device: Torch device string the model and inputs are moved to.
        """
        self.device = torch.device(device)
        self.model = AIIABase.load(model_path).to(self.device)
        self.model.eval()

        # Preprocessing: pad to a square (so resizing keeps the aspect ratio),
        # resize to the model's input size, convert to a tensor and map the
        # [0, 1] pixel range to [-1, 1].
        self.preprocess = T.Compose([
            T.Lambda(self._pad_to_square),
            T.Resize(self.INPUT_SIZE),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def _pad_to_square(self, pil_img):
        """Pad ``pil_img`` with zero-valued borders so it becomes square.

        The image is centred. When the size difference is odd, the extra
        pixel of padding goes to the right/bottom edge.
        """
        w, h = pil_img.size
        side = max(w, h)
        left = (side - w) // 2
        top = (side - h) // 2
        # torchvision padding order: (left, top, right, bottom)
        padding = (left, top, side - w - left, side - h - top)
        return T.functional.pad(pil_img, padding, fill=0, padding_mode="constant")

    def _remove_padding(self, tensor, original_size):
        """Crop the square padding off ``tensor`` and resize it to 2x.

        Args:
            tensor: Batched model output ``(N, C, H, W)`` covering the padded
                square built by ``_pad_to_square``. ``H`` and ``W`` are read
                from the tensor itself, so the model output does not have to
                be ``INPUT_SIZE`` pixels wide.
            original_size: ``(width, height)`` of the input image (PIL order).

        Returns:
            Tensor of shape ``(N, C, 2 * height, 2 * width)``.
        """
        _, _, out_h, out_w = tensor.shape
        orig_w, orig_h = original_size
        side = max(orig_w, orig_h)

        # Offsets of the original content inside the padded square; these
        # mirror the offsets chosen by _pad_to_square.
        left = (side - orig_w) // 2
        top = (side - orig_h) // 2

        # Map the content box from padded-square pixels onto the tensor's own
        # grid. Using the real output size (rather than a fixed 512) keeps the
        # crop aligned when the model returns an upsampled feature map.
        x0 = round(left * out_w / side)
        y0 = round(top * out_h / side)
        # Keep at least one column/row so extremely thin images do not produce
        # an empty crop, which F.interpolate cannot resize.
        x1 = max(x0 + 1, round((left + orig_w) * out_w / side))
        y1 = max(y0 + 1, round((top + orig_h) * out_h / side))

        unpad = tensor[:, :, y0:y1, x0:x1]

        # Resize to the target 2x resolution
        return F.interpolate(
            unpad, size=(orig_h * 2, orig_w * 2), mode="bilinear", align_corners=False
        )

    def upscale(self, input_image):
        """Return a 2x-upscaled copy of ``input_image``.

        Args:
            input_image: PIL image. Images in modes other than RGB (e.g. RGBA,
                L, P) are converted to RGB first, because the preprocessing
                normalizes exactly three channels.

        Returns:
            PIL image of size ``(2 * width, 2 * height)``; RGB assuming the
            model emits three channels.
        """
        if input_image.mode != "RGB":
            input_image = input_image.convert("RGB")
        original_size = input_image.size
        input_tensor = self.preprocess(input_image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            features = self.model.cnn(input_tensor)
            output = self.model.upsample(features)
            output = self._remove_padding(output, original_size)

        # Undo the [-1, 1] normalization and convert back to a PIL image.
        output = output.squeeze(0).cpu().detach()
        output = (output * 0.5 + 0.5).clamp(0, 1)
        return T.functional.to_pil_image(output)
|
|
|
|
# Usage example
if __name__ == "__main__":
    upscaler = UpScaler()
    # Use the image as a context manager so the underlying file is closed;
    # upscale() runs inside the block because PIL loads pixel data lazily.
    with Image.open("input.jpg") as input_image:
        output_image = upscaler.upscale(input_image)
    output_image.save("output_2x.jpg")
|