updated inference class and test image
This commit is contained in:
parent 62825e9731
commit 704ad6106d
@@ -1,73 +1,136 @@
import torch
import numpy as np
from PIL import Image
from albumentations import Compose, Normalize
from albumentations.pytorch import ToTensorV2
from torch import nn
from aiia import AIIABase

class Upscaler(nn.Module):
    """
    Transforms the base model's final feature map using a transposed convolution.

    The base model produces a feature map of size 512x512. This layer upsamples
    by a factor of 2 (yielding 1024x1024) and maps the hidden features to the
    output channels using a single ConvTranspose2d layer.
    """

    def __init__(self, base_model: AIIABase):
        super().__init__()
        self.base_model = base_model
        # A single ConvTranspose2d replaces separate upsampling and convolutional layers.
        self.last_transform = nn.ConvTranspose2d(
            in_channels=base_model.config.hidden_size,
            out_channels=base_model.config.num_channels,
            kernel_size=base_model.config.kernel_size,
            stride=2,
            padding=1,
            output_padding=1
        )

    def forward(self, x):
        features = self.base_model(x)
        return self.last_transform(features)
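
# Sanity check for the docstring's 512 -> 1024 claim (a hedged sketch, not part
# of the model code). For ConvTranspose2d the output size is
#   H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding
# so with stride=2, padding=1, output_padding=1 the exact doubling only holds if
# base_model.config.kernel_size is 3; the channel counts below are placeholders.
#
# layer = nn.ConvTranspose2d(in_channels=256, out_channels=3,
#                            kernel_size=3, stride=2, padding=1, output_padding=1)
# layer(torch.randn(1, 256, 512, 512)).shape  # torch.Size([1, 3, 1024, 1024])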

class ImageUpscaler:
    def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Initialize the ImageUpscaler with the trained model.

        Args:
            model_path (str): Path to the trained model directory.
            device (str): Device to run inference on ('cuda' or 'cpu').
        """
        self.device = torch.device(device)
        self.model = self.load_model(model_path)
        self.model.eval()  # Set the model to evaluation mode

        # Define preprocessing transformations. Note there is no resize or
        # padding step here, so inputs are assumed to already be 512x512.
        self.preprocess = Compose([
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2()
        ])
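
    # Quick sketch (not part of the commit) of what this pipeline emits:
    # albumentations' Normalize uses max_pixel_value=255.0 by default, so a
    # uint8 RGB input is mapped to roughly [-1, 1], and ToTensorV2 converts
    # the HWC numpy array into a CHW float tensor.
    #
    # demo = Compose([Normalize(mean=[0.5] * 3, std=[0.5] * 3), ToTensorV2()])
    # out = demo(image=np.zeros((512, 512, 3), dtype=np.uint8))
    # out['image'].shape  # torch.Size([3, 512, 512]); values here are all -1.0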

    def load_model(self, model_path: str):
        """
        Load the trained model from the specified path.

        Args:
            model_path (str): Path to the saved model.

        Returns:
            nn.Module: Loaded PyTorch model.
        """
        # Load the base model and wrap it with the Upscaler head
        base_model = AIIABase.load(model_path)
        model = Upscaler(base_model)

        # Move the model to the appropriate device
        return model.to(self.device)

    def preprocess_image(self, image: Image.Image):
        """
        Preprocess the input image for inference.

        Args:
            image (PIL.Image.Image): Input image in PIL format.

        Returns:
            torch.Tensor: Preprocessed image tensor.
        """
        # Convert PIL image to numpy array
        image_array = np.array(image)

        # Apply preprocessing transformations (albumentations returns a dict)
        augmented = self.preprocess(image=image_array)

        # Add batch dimension and move to device
        return augmented['image'].unsqueeze(0).to(self.device)

    def postprocess_image(self, output_tensor: torch.Tensor):
        """
        Postprocess the output tensor to convert it back to an image.

        Args:
            output_tensor (torch.Tensor): Model output tensor.

        Returns:
            PIL.Image.Image: Upscaled image in PIL format.
        """
        # Remove batch dimension and move to CPU
        output_tensor = output_tensor.squeeze(0).cpu()

        # Denormalize from [-1, 1] back to [0, 1]
        output_array = (output_tensor * 0.5 + 0.5).clamp(0, 1).numpy()

        # Convert from CHW (Channels-Height-Width) to HWC (Height-Width-Channels)
        # format and scale to 8-bit pixel values
        output_array = (output_array.transpose(1, 2, 0) * 255).astype(np.uint8)

        # Convert numpy array back to PIL image
        return Image.fromarray(output_array)
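
    # Sketch (not part of the commit): this denormalization inverts the
    # Normalize transform above, e.g. -1.0 maps back to pixel 0 and 1.0 to 255.
    #
    # t = torch.tensor([[[-1.0, 1.0]]])                # CHW, shape 1x1x2
    # arr = (t * 0.5 + 0.5).clamp(0, 1).numpy()
    # (arr.transpose(1, 2, 0) * 255).astype(np.uint8)  # [[[0], [255]]]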

    def upscale_image(self, input_image_path: str):
        """
        Perform upscaling on an input image.

        Args:
            input_image_path (str): Path to the input low-resolution image.

        Returns:
            PIL.Image.Image: Upscaled high-resolution image.
        """
        # Load and preprocess the input image
        input_image = Image.open(input_image_path).convert('RGB')
        preprocessed_image = self.preprocess_image(input_image)

        # Perform inference without tracking gradients, using autocast for
        # mixed-precision inference on the selected device
        with torch.no_grad():
            with torch.amp.autocast(device_type=self.device.type):
                output_tensor = self.model(preprocessed_image)

        # Postprocess and return the upscaled image
        return self.postprocess_image(output_tensor)

# Example usage:
# upscaler = ImageUpscaler(model_path="/path/to/best_model")
# upscaled_image = upscaler.upscale_image("/path/to/low_res_image.jpg")
# upscaled_image.save("/path/to/upscaled_image.jpg")
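
# A runnable variant of the example above (a sketch: "aiuNN-finetuned" is taken
# from this file's previous default model path, and the file names come from the
# previous example; adjust all three to your setup):
if __name__ == "__main__":
    upscaler = ImageUpscaler(model_path="aiuNN-finetuned")
    upscaled_image = upscaler.upscale_image("input.jpg")
    upscaled_image.save("output_2x.jpg")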