import io

import numpy as np
import torch
from aiia import AIIABase
from albumentations import Compose, Normalize
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch import nn


class Upscaler(nn.Module):
    """
    Transforms the base model's final feature map using a transposed convolution.

    The base model produces a feature map of spatial size 512x512. This layer
    upsamples it by a factor of 2 (yielding 1024x1024) and maps the hidden
    features to the output channels using a single ConvTranspose2d layer.
    """

    def __init__(self, base_model: AIIABase):
        super().__init__()
        self.base_model = base_model
        # A single ConvTranspose2d replaces separate upsampling and convolutional layers.
        self.last_transform = nn.ConvTranspose2d(
            in_channels=base_model.config.hidden_size,
            out_channels=base_model.config.num_channels,
            kernel_size=base_model.config.kernel_size,
            stride=2,
            padding=1,
            output_padding=1
        )
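        # Spatial-size sketch (standard ConvTranspose2d arithmetic, dilation=1):
        #   out = (in - 1) * stride - 2 * padding + kernel_size + output_padding
        # With stride=2, padding=1, output_padding=1 this doubles the input size
        # exactly when kernel_size == 3 (out = 2 * in); the factor-2 claim in
        # the docstring therefore assumes config.kernel_size == 3.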

    def forward(self, x):
        features = self.base_model(x)
        return self.last_transform(features)
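
# Expected tensor flow through Upscaler, inferred from the docstring and the
# layer shapes above (not verified against AIIABase itself):
#   input (N, num_channels, 512, 512)
#   -> base_model features (N, hidden_size, 512, 512)
#   -> last_transform output (N, num_channels, 1024, 1024)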

class ImageUpscaler:
    def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = torch.device(device)
        self.model = self.load_model(model_path)
        self.model.eval()  # Set to evaluation mode

        # Define preprocessing transformations
        self.preprocess = Compose([
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2()
        ])
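        # Note: albumentations' Normalize divides by max_pixel_value (255.0 by
        # default) before applying mean/std, so uint8 input lands in [-1, 1];
        # postprocess_image inverts exactly this mapping.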

    def load_model(self, model_path: str):
        """
        Load the trained model from the specified path.
        """
        base_model = AIIABase.load(model_path)  # Load the pretrained base model
        model = Upscaler(base_model)  # Wrap it with the upscaling head
        return model.to(self.device)

    def preprocess_image(self, image: Image.Image):
        """
        Preprocess an input image for inference.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL.Image.Image object")

        # Convert to a numpy array (HWC) and apply the preprocessing pipeline
        image_array = np.array(image)
        augmented = self.preprocess(image=image_array)

        # ToTensorV2 yields a CHW tensor; add the batch dimension and move to
        # the target device, giving shape (1, C, H, W)
        return augmented['image'].unsqueeze(0).to(self.device)

    def postprocess_image(self, output_tensor: torch.Tensor):
        """
        Convert an output tensor back to a PIL image.
        """
        # .float() guards against half-precision outputs when autocast is active
        output_tensor = output_tensor.squeeze(0).float().cpu()  # Remove batch dimension
        # Invert the Normalize transform: [-1, 1] -> [0, 1] -> [0, 255]
        output_array = (output_tensor * 0.5 + 0.5).clamp(0, 1).numpy() * 255
        output_array = output_array.transpose(1, 2, 0).astype(np.uint8)  # CHW -> HWC
        return Image.fromarray(output_array)

    def upscale_image(self, input_image_path: str):
        """
        Perform upscaling on an input image.
        """
        input_image = Image.open(input_image_path).convert('RGB')  # Ensure RGB format
        preprocessed_image = self.preprocess_image(input_image)

        with torch.no_grad():
            # Use the actual device type so autocast also works on CPU-only runs
            with torch.amp.autocast(device_type=self.device.type):
                output_tensor = self.model(preprocessed_image)

        return self.postprocess_image(output_tensor)


# Example usage:
if __name__ == "__main__":
    upscaler = ImageUpscaler(model_path="/root/vision/aiuNN/best_model")
    upscaled_image = upscaler.upscale_image("/root/vision/aiuNN/input.jpg")
    upscaled_image.save("upscaled_image.jpg")