Merge pull request 'added aiunn inference' (#2) from add_use_class into develop
Reviewed-on: #2
This commit is contained in:
commit
5321eee803
|
@ -1,5 +1,6 @@
|
|||
from .finetune.trainer import aiuNNTrainer
|
||||
from .upsampler.aiunn import aiuNN
|
||||
from .upsampler.config import aiuNNConfig
|
||||
from .inference.inference import aiuNNInference
|
||||
|
||||
__version__ = "0.1.1"
|
|
@ -0,0 +1,3 @@
|
|||
from .inference import aiuNNInference
|
||||
|
||||
__all__ = ["aiuNNInference"]
|
|
@ -1,96 +1,226 @@
|
|||
import os
|
||||
import torch
|
||||
from albumentations import Compose, Normalize
|
||||
from albumentations.pytorch import ToTensorV2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import io
|
||||
from torch import nn
|
||||
from aiia import AIIABase
|
||||
from typing import Union, Optional, Tuple, List
|
||||
from ..upsampler.aiunn import aiuNN
|
||||
|
||||
|
||||
class Upscaler(nn.Module):
|
||||
class aiuNNInference:
|
||||
"""
|
||||
Transforms the base model's final feature map using a transposed convolution.
|
||||
The base model produces a feature map of size 512x512.
|
||||
This layer upsamples by a factor of 2 (yielding 1024x1024) and maps the hidden features
|
||||
to the output channels using a single ConvTranspose2d layer.
|
||||
Inference class for aiuNN upsampling model.
|
||||
Handles model loading, image upscaling, and output processing.
|
||||
"""
|
||||
def __init__(self, base_model: AIIABase):
|
||||
super(Upscaler, self).__init__()
|
||||
self.base_model = base_model
|
||||
# Instead of adding separate upsampling and convolutional layers, we use a ConvTranspose2d layer.
|
||||
self.last_transform = nn.ConvTranspose2d(
|
||||
in_channels=base_model.config.hidden_size,
|
||||
out_channels=base_model.config.num_channels,
|
||||
kernel_size=base_model.config.kernel_size,
|
||||
stride=2,
|
||||
padding=1,
|
||||
output_padding=1
|
||||
)
|
||||
def __init__(self, model_path: str, precision: Optional[str] = None, device: Optional[str] = None):
|
||||
"""
|
||||
Initialize the inference class by loading the aiuNN model.
|
||||
|
||||
def forward(self, x):
|
||||
features = self.base_model(x)
|
||||
return self.last_transform(features)
|
||||
Args:
|
||||
model_path: Path to the saved model directory
|
||||
precision: Optional precision setting ('fp16', 'bf16', or None for default)
|
||||
device: Optional device specification ('cuda', 'cpu', or None for auto-detection)
|
||||
"""
|
||||
|
||||
|
||||
# Set device
|
||||
if device is None:
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
# Load the model with specified precision
|
||||
self.model = aiuNN.load(model_path, precision=precision)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# Store configuration for reference
|
||||
self.config = self.model.config
|
||||
|
||||
def preprocess_image(self, image: Union[str, Image.Image, np.ndarray, torch.Tensor]) -> torch.Tensor:
    """
    Convert an input image into a tensor ready to be fed to the model.

    Args:
        image: Input image as file path, PIL Image, numpy array (CHW or HWC
            with 3 channels), or torch tensor.

    Returns:
        Tensor scaled to the [0, 1] range and moved to ``self.device``.
        A batch dimension is prepended when the input is 3-dimensional.

    Raises:
        TypeError: If ``image`` is not one of the supported input types.

    Notes:
        8-bit integer input (PIL images, ``uint8`` arrays/tensors) is always
        divided by 255. Other dtypes keep the value-based check: they are
        divided by 255 only when their maximum exceeds 1.0.
    """
    if isinstance(image, str):
        # Load from file path; convert() forces the pixel data to be read,
        # so the file can be closed as soon as the block exits.
        with Image.open(image) as opened:
            image = opened.convert('RGB')

    # Tracks whether the data came in as 8-bit pixels, so that very dark
    # images (max value <= 1) are still scaled to [0, 1] correctly.
    is_eight_bit = False

    if isinstance(image, Image.Image):
        # Match the file-path branch: grayscale/palette/RGBA images become
        # 3-channel RGB so that the HWC -> CHW transpose below is valid.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        pixels = np.array(image)
        is_eight_bit = pixels.dtype == np.uint8
        pixels = pixels.transpose(2, 0, 1)  # HWC to CHW
        tensor = torch.from_numpy(np.ascontiguousarray(pixels)).float()
    elif isinstance(image, np.ndarray):
        is_eight_bit = image.dtype == np.uint8
        if image.shape[0] == 3:
            # Already in CHW format
            pass
        elif image.shape[-1] == 3:
            # HWC to CHW format
            image = image.transpose(2, 0, 1)
        # ascontiguousarray also covers views with negative strides, which
        # torch.from_numpy cannot wrap directly.
        tensor = torch.from_numpy(np.ascontiguousarray(image)).float()
    elif isinstance(image, torch.Tensor):
        is_eight_bit = image.dtype == torch.uint8
        # Floating-point tensors keep their dtype; integer/bool tensors are
        # promoted so the result is always a floating-point tensor.
        tensor = image if image.is_floating_point() else image.float()
    else:
        raise TypeError(
            "image must be a file path, PIL.Image.Image, numpy.ndarray or "
            f"torch.Tensor, got {type(image).__name__}"
        )

    # Normalize to [0, 1] range if needed
    if is_eight_bit or tensor.max() > 1.0:
        tensor = tensor / 255.0

    # Add batch dimension if not present
    if tensor.dim() == 3:
        tensor = tensor.unsqueeze(0)

    # Move to device
    return tensor.to(self.device)
|
||||
|
||||
|
||||
class ImageUpscaler:
|
||||
def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
|
||||
self.device = torch.device(device)
|
||||
self.model = self.load_model(model_path)
|
||||
self.model.eval() # Set to evaluation mode
|
||||
def postprocess_tensor(self, tensor: torch.Tensor) -> Image.Image:
|
||||
"""
|
||||
Convert output tensor to PIL Image.
|
||||
|
||||
# Define preprocessing transformations
|
||||
self.preprocess = Compose([
|
||||
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
||||
ToTensorV2()
|
||||
])
|
||||
Args:
|
||||
tensor: Output tensor from model
|
||||
|
||||
Returns:
|
||||
Processed PIL Image
|
||||
"""
|
||||
# Move to CPU and convert to numpy
|
||||
output = tensor.detach().cpu().squeeze(0).numpy()
|
||||
|
||||
# Ensure proper range [0, 255]
|
||||
output = np.clip(output * 255, 0, 255).astype(np.uint8)
|
||||
|
||||
# Convert from CHW to HWC for PIL
|
||||
output = output.transpose(1, 2, 0)
|
||||
|
||||
# Create PIL Image
|
||||
return Image.fromarray(output)
|
||||
|
||||
def load_model(self, model_path: str):
|
||||
@torch.no_grad()
|
||||
def upscale(self, image: Union[str, Image.Image, np.ndarray, torch.Tensor]) -> Image.Image:
|
||||
"""
|
||||
Load the trained model from the specified path.
|
||||
Upscale an image using the aiuNN model.
|
||||
|
||||
Args:
|
||||
image: Input image to upscale
|
||||
|
||||
Returns:
|
||||
Upscaled image as PIL Image
|
||||
"""
|
||||
base_model = AIIABase.load(model_path) # Load base model
|
||||
model = Upscaler(base_model) # Wrap with Upscaler
|
||||
return model.to(self.device)
|
||||
# Preprocess input
|
||||
input_tensor = self.preprocess_image(image)
|
||||
|
||||
# Run inference
|
||||
output_tensor = self.model(input_tensor)
|
||||
|
||||
# Postprocess output
|
||||
upscaled_image = self.postprocess_tensor(output_tensor)
|
||||
|
||||
return upscaled_image
|
||||
|
||||
def preprocess_image(self, image: Image.Image):
|
||||
def save(self, image: Image.Image, output_path: str, format: Optional[str] = None) -> None:
|
||||
"""
|
||||
Preprocess input image for inference.
|
||||
Save the upscaled image to a file.
|
||||
|
||||
Args:
|
||||
image: PIL Image to save
|
||||
output_path: Path where the image should be saved
|
||||
format: Optional format override (e.g., 'PNG', 'JPEG')
|
||||
"""
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError("Input must be a PIL.Image.Image object")
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
|
||||
|
||||
# Convert to numpy array and apply preprocessing
|
||||
image_array = np.array(image)
|
||||
augmented = self.preprocess(image=image_array)
|
||||
# Save the image
|
||||
image.save(output_path, format=format)
|
||||
|
||||
# Add batch dimension and move to device
|
||||
return augmented['image'].unsqueeze(0).to(self.device)
|
||||
def convert_to_binary(self, image: Image.Image, format: str = 'PNG') -> bytes:
    """
    Serialize an image into an in-memory byte string.

    Args:
        image: PIL Image to encode
        format: Image format passed to ``image.save`` (e.g., 'PNG', 'JPEG')

    Returns:
        The encoded image as bytes
    """
    # Encode into a temporary in-memory buffer and hand back its contents
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        return buffer.getvalue()
|
||||
|
||||
def postprocess_image(self, output_tensor: torch.Tensor):
|
||||
def process_batch(self,
|
||||
images: List[Union[str, Image.Image]],
|
||||
output_dir: Optional[str] = None,
|
||||
save_format: str = 'PNG',
|
||||
return_binary: bool = False) -> Union[List[Image.Image], List[bytes], None]:
|
||||
"""
|
||||
Convert output tensor back to an image.
|
||||
"""
|
||||
output_tensor = output_tensor.squeeze(0).cpu() # Remove batch dimension
|
||||
output_array = (output_tensor * 0.5 + 0.5).clamp(0, 1).numpy() * 255
|
||||
output_array = output_array.transpose(1, 2, 0).astype(np.uint8) # CHW -> HWC
|
||||
return Image.fromarray(output_array)
|
||||
|
||||
def upscale_image(self, input_image_path: str):
|
||||
"""
|
||||
Perform upscaling on an input image.
|
||||
"""
|
||||
input_image = Image.open(input_image_path).convert('RGB') # Ensure RGB format
|
||||
preprocessed_image = self.preprocess_image(input_image)
|
||||
Process multiple images in batch.
|
||||
|
||||
with torch.no_grad():
|
||||
with torch.amp.autocast(device_type="cuda"):
|
||||
output_tensor = self.model(preprocessed_image)
|
||||
Args:
|
||||
images: List of input images (paths or PIL Images)
|
||||
output_dir: Optional directory to save results
|
||||
save_format: Format to use when saving images
|
||||
return_binary: Whether to return binary data instead of PIL Images
|
||||
|
||||
Returns:
|
||||
List of processed images or binary data, or None if only saving
|
||||
"""
|
||||
results = []
|
||||
|
||||
return self.postprocess_image(output_tensor)
|
||||
for i, img in enumerate(images):
|
||||
# Upscale the image
|
||||
upscaled = self.upscale(img)
|
||||
|
||||
# Save if output directory is provided
|
||||
if output_dir:
|
||||
# Extract filename if input is a path
|
||||
if isinstance(img, str):
|
||||
filename = os.path.basename(img)
|
||||
base, _ = os.path.splitext(filename)
|
||||
else:
|
||||
base = f"upscaled_{i}"
|
||||
|
||||
output_path = os.path.join(output_dir, f"{base}.{save_format.lower()}")
|
||||
self.save(upscaled, output_path, format=save_format)
|
||||
|
||||
# Add to results based on return type
|
||||
if return_binary:
|
||||
results.append(self.convert_to_binary(upscaled, format=save_format))
|
||||
else:
|
||||
results.append(upscaled)
|
||||
|
||||
return results if (not output_dir or return_binary or not save_format) else None
|
||||
|
||||
|
||||
# Example usage:
|
||||
upscaler = ImageUpscaler(model_path="/root/vision/aiuNN/best_model")
|
||||
upscaled_image = upscaler.upscale_image("/root/vision/aiuNN/input.jpg")
|
||||
upscaled_image.save("upscaled_image.jpg")
|
||||
# Example usage (can be removed)
|
||||
if __name__ == "__main__":
    # Demonstration only: load a model, then run the single-image and
    # batch entry points of aiuNNInference.
    engine = aiuNNInference("path/to/model", precision="bf16")

    # Single image: upscale, write to disk, then encode as bytes
    result = engine.upscale("input_image.jpg")
    engine.save(result, "output_image.png")
    encoded = engine.convert_to_binary(result)

    # Batch: upscale several files and save them into one folder
    batch_inputs = ["image1.jpg", "image2.jpg"]
    engine.process_batch(
        batch_inputs,
        output_dir="output_folder",
        save_format="PNG",
    )
|
Loading…
Reference in New Issue