finetune_class #1

@@ -7,6 +7,7 @@ import io
 from torch import nn
 from aiia import AIIABase
 
+
 class Upscaler(nn.Module):
     """
     Transforms the base model's final feature map using a transposed convolution.
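
The hunks below also rely on names imported outside the visible context lines (np, Image, Compose, torch). As a rough sketch only, and not part of the diff, the module-level import block presumably looks something like this; the albumentations import in particular is an assumption inferred from the dict-style self.preprocess(image=...) call in preprocess_image:

# Sketch of the imports the visible code appears to rely on (not part of the diff).
# The albumentations line is an assumption; the others appear in the hunks or are
# implied by usage elsewhere in the file.
import io

import numpy as np
import torch
from torch import nn
from PIL import Image
from albumentations import Compose

from aiia import AIIABase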
@@ -30,19 +31,11 @@ class Upscaler(nn.Module):
     def forward(self, x):
         features = self.base_model(x)
         return self.last_transform(features)
 
 class ImageUpscaler:
     def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
-        """
-        Initialize the ImageUpscaler with the trained model.
-
-        Args:
-            model_path (str): Path to the trained model directory.
-            device (str): Device to run inference on ('cuda' or 'cpu').
-        """
         self.device = torch.device(device)
         self.model = self.load_model(model_path)
-        self.model.eval()  # Set the model to evaluation mode
-
+        self.model.eval()  # Set to evaluation mode
         # Define preprocessing transformations
         self.preprocess = Compose([
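
Only the docstring and forward() of Upscaler are visible in these hunks; the constructor that defines self.last_transform is not shown. Purely as a hedged sketch consistent with "a transposed convolution", the module presumably has roughly this shape, where the channel counts and kernel settings are illustrative assumptions rather than the repository's actual values:

import torch
from torch import nn

class Upscaler(nn.Module):
    """
    Transforms the base model's final feature map using a transposed convolution.
    """
    def __init__(self, base_model: nn.Module, in_channels: int = 512, out_channels: int = 3):
        super().__init__()
        self.base_model = base_model
        # One ConvTranspose2d that doubles the spatial resolution of the feature map;
        # channel counts, kernel size, and stride here are placeholders.
        self.last_transform = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1
        )

    def forward(self, x):
        features = self.base_model(x)
        return self.last_transform(features)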
@@ -53,34 +46,20 @@ class ImageUpscaler:
     def load_model(self, model_path: str):
         """
         Load the trained model from the specified path.
-
-        Args:
-            model_path (str): Path to the saved model.
-
-        Returns:
-            nn.Module: Loaded PyTorch model.
         """
-        # Load the base model and wrap it with Upscaler
-        base_model = AIIABase.load(model_path)
-        model = Upscaler(base_model)
-
-        # Move the model to the appropriate device
+        base_model = AIIABase.load(model_path)  # Load base model
+        model = Upscaler(base_model)  # Wrap with Upscaler
         return model.to(self.device)
 
     def preprocess_image(self, image: Image.Image):
         """
-        Preprocess the input image for inference.
-
-        Args:
-            image (PIL.Image.Image): Input image in PIL format.
-
-        Returns:
-            torch.Tensor: Preprocessed image tensor.
+        Preprocess input image for inference.
         """
-        # Convert PIL image to numpy array
-        image_array = np.array(image)
+        if not isinstance(image, Image.Image):
+            raise ValueError("Input must be a PIL.Image.Image object")
 
-        # Apply preprocessing transformations
+        # Convert to numpy array and apply preprocessing
+        image_array = np.array(image)
         augmented = self.preprocess(image=image_array)
 
         # Add batch dimension and move to device
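
The Compose([...]) definition itself falls outside these hunks, so the exact transforms are not visible. Given the augmented = self.preprocess(image=image_array) call above and the * 0.5 + 0.5 denormalization in postprocess_image below, it is presumably an albumentations pipeline normalized with mean 0.5 and std 0.5. A self-contained sketch under that assumption, where the 256x256 resize target and the input path are placeholders:

import numpy as np
from PIL import Image
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2

# Assumed shape of self.preprocess; the real transform list is not shown in the diff.
preprocess = Compose([
    Resize(256, 256),                                       # placeholder input size
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),   # matches the *0.5 + 0.5 denorm
    ToTensorV2(),                                           # HWC uint8 -> CHW float tensor
])

image_array = np.array(Image.open("input.jpg").convert("RGB"))   # placeholder path
tensor = preprocess(image=image_array)["image"].unsqueeze(0)     # add batch dimension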
@@ -88,48 +67,27 @@ class ImageUpscaler:
 
     def postprocess_image(self, output_tensor: torch.Tensor):
         """
-        Postprocess the output tensor to convert it back to an image.
-
-        Args:
-            output_tensor (torch.Tensor): Model output tensor.
-
-        Returns:
-            PIL.Image.Image: Upscaled image in PIL format.
+        Convert output tensor back to an image.
         """
-        # Remove batch dimension and move to CPU
-        output_tensor = output_tensor.squeeze(0).cpu()
-
-        # Denormalize and convert to numpy array
-        output_array = (output_tensor * 0.5 + 0.5).clamp(0, 1).numpy()
-
-        # Convert from CHW (Channels-Height-Width) to HWC (Height-Width-Channels) format
-        output_array = (output_array.transpose(1, 2, 0) * 255).astype(np.uint8)
-
-        # Convert numpy array back to PIL image
+        output_tensor = output_tensor.squeeze(0).cpu()
+        output_array = (output_tensor * 0.5 + 0.5).clamp(0, 1).numpy() * 255
+        output_array = output_array.transpose(1, 2, 0).astype(np.uint8)  # CHW -> HWC
         return Image.fromarray(output_array)
 
     def upscale_image(self, input_image_path: str):
         """
         Perform upscaling on an input image.
-
-        Args:
-            input_image_path (str): Path to the input low-resolution image.
-
-        Returns:
-            PIL.Image.Image: Upscaled high-resolution image.
         """
-        # Load and preprocess the input image
-        input_image = Image.open(input_image_path).convert('RGB')
+        input_image = Image.open(input_image_path).convert('RGB')  # Ensure RGB format
         preprocessed_image = self.preprocess_image(input_image)
 
-        # Perform inference with the model
         with torch.no_grad():
             with torch.amp.autocast(device_type="cuda"):
                 output_tensor = self.model(preprocessed_image)
 
-        # Postprocess and return the upscaled image
         return self.postprocess_image(output_tensor)
 
 
 # Example usage:
 upscaler = ImageUpscaler(model_path="/root/vision/aiuNN/best_model")
 upscaled_image = upscaler.upscale_image("/root/vision/aiuNN/input.jpg")
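
For reference, the reworked postprocess_image folds the 8-bit scaling into the denormalization step; a quick check of the mapping it applies to model outputs in the normalized [-1, 1] range:

import torch

# Same expression as the new postprocess_image: (x * 0.5 + 0.5).clamp(0, 1) * 255
x = torch.tensor([-1.2, -1.0, 0.0, 1.0])   # -1.2 shows the effect of clamp()
print((x * 0.5 + 0.5).clamp(0, 1) * 255)   # -> 0.0, 0.0, 127.5, 255.0

The example usage at the bottom then yields an ordinary PIL.Image, so persisting the result is just upscaled_image.save(...) with whatever output path the project uses.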