Merge pull request 'develop' (#4) from develop into main
Reviewed-on: #4
This commit is contained in:
commit
cdf1e19280
|
@ -0,0 +1,106 @@
|
|||
from aiia import AIIABase
|
||||
from aiunn import aiuNN
|
||||
from aiunn import aiuNNTrainer
|
||||
import pandas as pd
|
||||
import io
|
||||
import base64
|
||||
from PIL import Image, ImageFile
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
|
||||
class UpscaleDataset(Dataset):
|
||||
def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):
|
||||
combined_df = pd.DataFrame()
|
||||
for parquet_file in parquet_files:
|
||||
# Load a subset from each parquet file
|
||||
df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(samples_per_file)
|
||||
combined_df = pd.concat([combined_df, df], ignore_index=True)
|
||||
|
||||
# Validate rows (ensuring each value is bytes or str)
|
||||
self.df = combined_df.apply(self._validate_row, axis=1)
|
||||
self.transform = transform
|
||||
self.failed_indices = set()
|
||||
|
||||
def _validate_row(self, row):
|
||||
for col in ['image_512', 'image_1024']:
|
||||
if not isinstance(row[col], (bytes, str)):
|
||||
raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
|
||||
return row
|
||||
|
||||
def _decode_image(self, data):
|
||||
try:
|
||||
if isinstance(data, str):
|
||||
return base64.b64decode(data)
|
||||
elif isinstance(data, bytes):
|
||||
return data
|
||||
raise ValueError(f"Unsupported data type: {type(data)}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Decoding failed: {str(e)}")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# If previous call failed for this index, use a different index
|
||||
if idx in self.failed_indices:
|
||||
return self[(idx + 1) % len(self)]
|
||||
try:
|
||||
row = self.df.iloc[idx]
|
||||
low_res_bytes = self._decode_image(row['image_512'])
|
||||
high_res_bytes = self._decode_image(row['image_1024'])
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
# Open image bytes with Pillow and convert to RGBA first
|
||||
low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
|
||||
high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')
|
||||
|
||||
# Create a new RGB image with black background
|
||||
low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0))
|
||||
high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0))
|
||||
|
||||
# Composite the original image over the black background
|
||||
low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3])
|
||||
high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3])
|
||||
|
||||
# Now we have true 3-channel RGB images with transparent areas converted to black
|
||||
low_res = low_res_rgb
|
||||
high_res = high_res_rgb
|
||||
|
||||
# Resize the images to reduce VRAM usage
|
||||
low_res = low_res.resize((410, 410), Image.LANCZOS)
|
||||
high_res = high_res.resize((820, 820), Image.LANCZOS)
|
||||
|
||||
# If a transform is provided (e.g. conversion to Tensor), apply it
|
||||
if self.transform:
|
||||
low_res = self.transform(low_res)
|
||||
high_res = self.transform(high_res)
|
||||
return low_res, high_res
|
||||
except Exception as e:
|
||||
print(f"\nError at index {idx}: {str(e)}")
|
||||
self.failed_indices.add(idx)
|
||||
return self[(idx + 1) % len(self)]
|
||||
|
||||
|
||||
if __name__ =="__main__":
|
||||
# Load your base model and upscaler
|
||||
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
|
||||
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
|
||||
upscaler = aiuNN(base_model)
|
||||
|
||||
# Create trainer with your dataset class
|
||||
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
|
||||
|
||||
# Load data using parameters for your dataset
|
||||
dataset_params = {
|
||||
'parquet_files': [
|
||||
"/root/training_data/vision-dataset/image_upscaler.parquet",
|
||||
"/root/training_data/vision-dataset/image_vec_upscaler.parquet"
|
||||
],
|
||||
'transform': transforms.Compose([transforms.ToTensor()]),
|
||||
'samples_per_file': 5000
|
||||
}
|
||||
trainer.load_data(dataset_params=dataset_params, batch_size=1)
|
||||
|
||||
# Fine-tune the model
|
||||
trainer.finetune(output_path="trained_models")
|
|
@ -0,0 +1,17 @@
|
|||
[build-system]
|
||||
requires = ["setuptools>=45", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "aiunn"
|
||||
version = "0.1.1"
|
||||
description = "Finetuner for image upscaling using AIIA"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = {file = "LICENSE"}
|
||||
authors = [
|
||||
{name = "Falko Habel", email = "falko.habel@gmx.de"},
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://gitea.fabelous.app/Machine-Learning/aiuNN"
|
|
@ -0,0 +1,5 @@
|
|||
torch
|
||||
aiia
|
||||
pillow
|
||||
torchvision
|
||||
sklearn
|
|
@ -0,0 +1,14 @@
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="aiunn",
|
||||
version="0.1.1",
|
||||
packages=find_packages(where="src"),
|
||||
package_dir={"": "src"},
|
||||
install_requires=[
|
||||
line.strip()
|
||||
for line in open("requirements.txt")
|
||||
if line.strip() and not line.startswith("#")
|
||||
],
|
||||
python_requires=">=3.10",
|
||||
)
|
|
@ -0,0 +1,6 @@
|
|||
from .finetune.trainer import aiuNNTrainer
|
||||
from .upsampler.aiunn import aiuNN
|
||||
from .upsampler.config import aiuNNConfig
|
||||
from .inference.inference import aiuNNInference
|
||||
|
||||
__version__ = "0.1.1"
|
|
@ -0,0 +1,3 @@
|
|||
from .trainer import aiuNNTrainer
|
||||
|
||||
__all__ = ["aiuNNTrainer" ]
|
|
@ -0,0 +1,290 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import os
|
||||
import csv
|
||||
from torch.amp import autocast, GradScaler
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import gc
|
||||
import time
|
||||
import shutil
|
||||
|
||||
|
||||
class EarlyStopping:
|
||||
def __init__(self, patience=3, min_delta=0.001):
|
||||
# Number of epochs with no significant improvement before stopping
|
||||
# Minimum change in loss required to count as an improvement
|
||||
self.patience = patience
|
||||
self.min_delta = min_delta
|
||||
self.best_loss = float('inf')
|
||||
self.counter = 0
|
||||
self.early_stop = False
|
||||
|
||||
def __call__(self, epoch_loss):
|
||||
if epoch_loss < self.best_loss - self.min_delta:
|
||||
self.best_loss = epoch_loss
|
||||
self.counter = 0
|
||||
return True # Improved
|
||||
else:
|
||||
self.counter += 1
|
||||
if self.counter >= self.patience:
|
||||
self.early_stop = True
|
||||
return False # Not improved
|
||||
|
||||
class aiuNNTrainer:
|
||||
def __init__(self, upscaler_model, dataset_class=None):
|
||||
"""
|
||||
Initialize the upscaler trainer
|
||||
|
||||
Args:
|
||||
upscaler_model: The model to fine-tune
|
||||
dataset_class: The dataset class to use for loading data (optional)
|
||||
"""
|
||||
self.model = upscaler_model
|
||||
self.dataset_class = dataset_class
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.model = self.model.to(self.device, memory_format=torch.channels_last)
|
||||
self.criterion = nn.MSELoss()
|
||||
self.optimizer = None
|
||||
self.scaler = GradScaler()
|
||||
self.best_loss = float('inf')
|
||||
self.use_checkpointing = True
|
||||
self.data_loader = None
|
||||
self.validation_loader = None
|
||||
self.log_dir = None
|
||||
|
||||
def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
|
||||
"""
|
||||
Load data using either a custom dataset instance or the dataset class provided at initialization
|
||||
|
||||
Args:
|
||||
dataset_params (dict/list): Parameters to pass to the dataset class constructor
|
||||
batch_size (int): Batch size for training
|
||||
validation_split (float): Proportion of data to use for validation
|
||||
custom_train_dataset: A pre-instantiated dataset to use for training (optional)
|
||||
custom_val_dataset: A pre-instantiated dataset to use for validation (optional)
|
||||
"""
|
||||
# If custom datasets are provided directly, use them
|
||||
if custom_train_dataset is not None:
|
||||
train_dataset = custom_train_dataset
|
||||
val_dataset = custom_val_dataset if custom_val_dataset is not None else None
|
||||
else:
|
||||
# Otherwise instantiate dataset using the class and parameters
|
||||
if self.dataset_class is None:
|
||||
raise ValueError("No dataset class provided. Either provide a dataset class at initialization or custom datasets.")
|
||||
|
||||
# Create dataset instance
|
||||
dataset = self.dataset_class(**dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params})
|
||||
|
||||
# Split into train and validation sets
|
||||
dataset_size = len(dataset)
|
||||
val_size = int(validation_split * dataset_size)
|
||||
train_size = dataset_size - val_size
|
||||
|
||||
train_dataset, val_dataset = torch.utils.data.random_split(
|
||||
dataset, [train_size, val_size]
|
||||
)
|
||||
|
||||
# Create data loaders
|
||||
self.data_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
if val_dataset is not None:
|
||||
self.validation_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=True
|
||||
)
|
||||
print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")
|
||||
else:
|
||||
self.validation_loader = None
|
||||
print(f"Loaded {len(train_dataset)} training samples (no validation set)")
|
||||
|
||||
return self.data_loader, self.validation_loader
|
||||
|
||||
def _setup_logging(self, output_path):
|
||||
"""Set up directory structure for logging and model checkpoints"""
|
||||
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
||||
self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
|
||||
os.makedirs(self.log_dir, exist_ok=True)
|
||||
|
||||
# Create checkpoint directory
|
||||
self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
|
||||
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
||||
|
||||
# Set up CSV logging
|
||||
self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
|
||||
with open(self.csv_path, mode='w', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
if self.validation_loader:
|
||||
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
|
||||
else:
|
||||
writer.writerow(['Epoch', 'Train Loss', 'Improved'])
|
||||
|
||||
def _evaluate(self):
|
||||
"""Evaluate the model on validation data"""
|
||||
if self.validation_loader is None:
|
||||
return 0.0
|
||||
|
||||
self.model.eval()
|
||||
val_loss = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
|
||||
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
|
||||
high_res = high_res.to(self.device, non_blocking=True)
|
||||
|
||||
with autocast(device_type=self.device.type):
|
||||
outputs = self.model(low_res)
|
||||
loss = self.criterion(outputs, high_res)
|
||||
|
||||
val_loss += loss.item()
|
||||
|
||||
del low_res, high_res, outputs, loss
|
||||
|
||||
self.model.train()
|
||||
return val_loss
|
||||
|
||||
def _save_checkpoint(self, epoch, is_best=False):
|
||||
"""Save model checkpoint"""
|
||||
checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
|
||||
best_model_path = os.path.join(self.log_dir, "best_model")
|
||||
|
||||
# Save the model checkpoint
|
||||
self.model.save(checkpoint_path)
|
||||
|
||||
# If this is the best model so far, copy it to best_model
|
||||
if is_best:
|
||||
if os.path.exists(best_model_path):
|
||||
shutil.rmtree(best_model_path)
|
||||
self.model.save(best_model_path)
|
||||
print(f"Saved new best model with loss: {self.best_loss:.6f}")
|
||||
|
||||
def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
|
||||
"""
|
||||
Finetune the upscaler model
|
||||
|
||||
Args:
|
||||
output_path (str): Directory to save models and logs
|
||||
epochs (int): Maximum number of training epochs
|
||||
lr (float): Learning rate
|
||||
patience (int): Early stopping patience
|
||||
min_delta (float): Minimum improvement for early stopping
|
||||
"""
|
||||
# Check if data is loaded
|
||||
if self.data_loader is None:
|
||||
raise ValueError("Data not loaded. Call load_data first.")
|
||||
|
||||
# Setup optimizer
|
||||
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
||||
|
||||
# Set up logging
|
||||
self._setup_logging(output_path)
|
||||
|
||||
# Setup early stopping
|
||||
early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
|
||||
|
||||
# Training loop
|
||||
self.model.train()
|
||||
|
||||
for epoch in range(epochs):
|
||||
# Training phase
|
||||
epoch_loss = 0.0
|
||||
progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")
|
||||
|
||||
for low_res, high_res in progress_bar:
|
||||
# Move data to GPU with channels_last format where possible
|
||||
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
|
||||
high_res = high_res.to(self.device, non_blocking=True)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
with autocast(device_type=self.device.type):
|
||||
if self.use_checkpointing:
|
||||
# Ensure the input tensor requires gradient so that checkpointing records the computation graph
|
||||
low_res.requires_grad_()
|
||||
outputs = checkpoint(self.model, low_res)
|
||||
else:
|
||||
outputs = self.model(low_res)
|
||||
loss = self.criterion(outputs, high_res)
|
||||
|
||||
self.scaler.scale(loss).backward()
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
|
||||
epoch_loss += loss.item()
|
||||
progress_bar.set_postfix({'loss': loss.item()})
|
||||
|
||||
# Optionally delete variables to free memory
|
||||
del low_res, high_res, outputs, loss
|
||||
|
||||
# Calculate average epoch loss
|
||||
avg_train_loss = epoch_loss / len(self.data_loader)
|
||||
|
||||
# Validation phase (if validation loader exists)
|
||||
if self.validation_loader:
|
||||
val_loss = self._evaluate() / len(self.validation_loader)
|
||||
is_improved = val_loss < self.best_loss
|
||||
if is_improved:
|
||||
self.best_loss = val_loss
|
||||
|
||||
# Log results
|
||||
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
|
||||
with open(self.csv_path, mode='a', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
|
||||
else:
|
||||
# If no validation, use training loss for improvement tracking
|
||||
is_improved = avg_train_loss < self.best_loss
|
||||
if is_improved:
|
||||
self.best_loss = avg_train_loss
|
||||
|
||||
# Log results
|
||||
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
|
||||
with open(self.csv_path, mode='a', newline='') as file:
|
||||
writer = csv.writer(file)
|
||||
writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])
|
||||
|
||||
# Save checkpoint
|
||||
self._save_checkpoint(epoch + 1, is_best=is_improved)
|
||||
|
||||
# Perform garbage collection and clear GPU cache after each epoch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Check early stopping
|
||||
early_stopping(val_loss if self.validation_loader else avg_train_loss)
|
||||
if early_stopping.early_stop:
|
||||
print(f"Early stopping triggered at epoch {epoch + 1}")
|
||||
break
|
||||
|
||||
return self.best_loss
|
||||
|
||||
def save(self, output_path=None):
|
||||
"""
|
||||
Save the best model to the specified path
|
||||
|
||||
Args:
|
||||
output_path (str, optional): Path to save the model. If None, uses the best model from training.
|
||||
"""
|
||||
if output_path is None and self.log_dir is not None:
|
||||
best_model_path = os.path.join(self.log_dir, "best_model")
|
||||
if os.path.exists(best_model_path):
|
||||
print(f"Best model already saved at {best_model_path}")
|
||||
return best_model_path
|
||||
else:
|
||||
output_path = os.path.join(self.log_dir, "final_model")
|
||||
|
||||
if output_path is None:
|
||||
raise ValueError("No output path specified and no training has been done yet.")
|
||||
|
||||
self.model.save(output_path)
|
||||
print(f"Model saved to {output_path}")
|
||||
return output_path
|
|
@ -0,0 +1,3 @@
|
|||
from .inference import aiuNNInference
|
||||
|
||||
__all__ = ["aiuNNInference"]
|
|
@ -0,0 +1,226 @@
|
|||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import io
|
||||
from typing import Union, Optional, Tuple, List
|
||||
from ..upsampler.aiunn import aiuNN
|
||||
|
||||
|
||||
class aiuNNInference:
|
||||
"""
|
||||
Inference class for aiuNN upsampling model.
|
||||
Handles model loading, image upscaling, and output processing.
|
||||
"""
|
||||
def __init__(self, model_path: str, precision: Optional[str] = None, device: Optional[str] = None):
|
||||
"""
|
||||
Initialize the inference class by loading the aiuNN model.
|
||||
|
||||
Args:
|
||||
model_path: Path to the saved model directory
|
||||
precision: Optional precision setting ('fp16', 'bf16', or None for default)
|
||||
device: Optional device specification ('cuda', 'cpu', or None for auto-detection)
|
||||
"""
|
||||
|
||||
|
||||
# Set device
|
||||
if device is None:
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
# Load the model with specified precision
|
||||
self.model = aiuNN.load(model_path, precision=precision)
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# Store configuration for reference
|
||||
self.config = self.model.config
|
||||
|
||||
def preprocess_image(self, image: Union[str, Image.Image, np.ndarray, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess the input image to match model requirements.
|
||||
|
||||
Args:
|
||||
image: Input image as file path, PIL Image, numpy array, or torch tensor
|
||||
|
||||
Returns:
|
||||
Preprocessed tensor ready for model input
|
||||
"""
|
||||
# Handle different input types
|
||||
if isinstance(image, str):
|
||||
# Load from file path
|
||||
image = Image.open(image).convert('RGB')
|
||||
|
||||
if isinstance(image, Image.Image):
|
||||
# Convert PIL Image to tensor
|
||||
image = np.array(image)
|
||||
image = image.transpose(2, 0, 1) # HWC to CHW
|
||||
image = torch.from_numpy(image).float()
|
||||
|
||||
if isinstance(image, np.ndarray):
|
||||
# Convert numpy array to tensor
|
||||
if image.shape[0] == 3:
|
||||
# Already in CHW format
|
||||
pass
|
||||
elif image.shape[-1] == 3:
|
||||
# HWC to CHW format
|
||||
image = image.transpose(2, 0, 1)
|
||||
image = torch.from_numpy(image).float()
|
||||
|
||||
# Normalize to [0, 1] range if needed
|
||||
if image.max() > 1.0:
|
||||
image = image / 255.0
|
||||
|
||||
# Add batch dimension if not present
|
||||
if len(image.shape) == 3:
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
# Move to device
|
||||
image = image.to(self.device)
|
||||
|
||||
return image
|
||||
|
||||
def postprocess_tensor(self, tensor: torch.Tensor) -> Image.Image:
|
||||
"""
|
||||
Convert output tensor to PIL Image.
|
||||
|
||||
Args:
|
||||
tensor: Output tensor from model
|
||||
|
||||
Returns:
|
||||
Processed PIL Image
|
||||
"""
|
||||
# Move to CPU and convert to numpy
|
||||
output = tensor.detach().cpu().squeeze(0).numpy()
|
||||
|
||||
# Ensure proper range [0, 255]
|
||||
output = np.clip(output * 255, 0, 255).astype(np.uint8)
|
||||
|
||||
# Convert from CHW to HWC for PIL
|
||||
output = output.transpose(1, 2, 0)
|
||||
|
||||
# Create PIL Image
|
||||
return Image.fromarray(output)
|
||||
|
||||
@torch.no_grad()
|
||||
def upscale(self, image: Union[str, Image.Image, np.ndarray, torch.Tensor]) -> Image.Image:
|
||||
"""
|
||||
Upscale an image using the aiuNN model.
|
||||
|
||||
Args:
|
||||
image: Input image to upscale
|
||||
|
||||
Returns:
|
||||
Upscaled image as PIL Image
|
||||
"""
|
||||
# Preprocess input
|
||||
input_tensor = self.preprocess_image(image)
|
||||
|
||||
# Run inference
|
||||
output_tensor = self.model(input_tensor)
|
||||
|
||||
# Postprocess output
|
||||
upscaled_image = self.postprocess_tensor(output_tensor)
|
||||
|
||||
return upscaled_image
|
||||
|
||||
def save(self, image: Image.Image, output_path: str, format: Optional[str] = None) -> None:
|
||||
"""
|
||||
Save the upscaled image to a file.
|
||||
|
||||
Args:
|
||||
image: PIL Image to save
|
||||
output_path: Path where the image should be saved
|
||||
format: Optional format override (e.g., 'PNG', 'JPEG')
|
||||
"""
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
|
||||
|
||||
# Save the image
|
||||
image.save(output_path, format=format)
|
||||
|
||||
def convert_to_binary(self, image: Image.Image, format: str = 'PNG') -> bytes:
|
||||
"""
|
||||
Convert the image to binary data.
|
||||
|
||||
Args:
|
||||
image: PIL Image to convert
|
||||
format: Image format to use for binary conversion
|
||||
|
||||
Returns:
|
||||
Binary representation of the image
|
||||
"""
|
||||
# Use BytesIO to convert to binary
|
||||
binary_output = io.BytesIO()
|
||||
image.save(binary_output, format=format)
|
||||
|
||||
# Get the binary data
|
||||
binary_data = binary_output.getvalue()
|
||||
|
||||
return binary_data
|
||||
|
||||
def process_batch(self,
|
||||
images: List[Union[str, Image.Image]],
|
||||
output_dir: Optional[str] = None,
|
||||
save_format: str = 'PNG',
|
||||
return_binary: bool = False) -> Union[List[Image.Image], List[bytes], None]:
|
||||
"""
|
||||
Process multiple images in batch.
|
||||
|
||||
Args:
|
||||
images: List of input images (paths or PIL Images)
|
||||
output_dir: Optional directory to save results
|
||||
save_format: Format to use when saving images
|
||||
return_binary: Whether to return binary data instead of PIL Images
|
||||
|
||||
Returns:
|
||||
List of processed images or binary data, or None if only saving
|
||||
"""
|
||||
results = []
|
||||
|
||||
for i, img in enumerate(images):
|
||||
# Upscale the image
|
||||
upscaled = self.upscale(img)
|
||||
|
||||
# Save if output directory is provided
|
||||
if output_dir:
|
||||
# Extract filename if input is a path
|
||||
if isinstance(img, str):
|
||||
filename = os.path.basename(img)
|
||||
base, _ = os.path.splitext(filename)
|
||||
else:
|
||||
base = f"upscaled_{i}"
|
||||
|
||||
output_path = os.path.join(output_dir, f"{base}.{save_format.lower()}")
|
||||
self.save(upscaled, output_path, format=save_format)
|
||||
|
||||
# Add to results based on return type
|
||||
if return_binary:
|
||||
results.append(self.convert_to_binary(upscaled, format=save_format))
|
||||
else:
|
||||
results.append(upscaled)
|
||||
|
||||
return results if (not output_dir or return_binary or not save_format) else None
|
||||
|
||||
|
||||
# Example usage (can be removed)
|
||||
if __name__ == "__main__":
|
||||
# Initialize inference with a model path
|
||||
inferencer = aiuNNInference("path/to/model", precision="bf16")
|
||||
|
||||
# Upscale a single image
|
||||
upscaled_image = inferencer.upscale("input_image.jpg")
|
||||
|
||||
# Save the result
|
||||
inferencer.save(upscaled_image, "output_image.png")
|
||||
|
||||
# Convert to binary
|
||||
binary_data = inferencer.convert_to_binary(upscaled_image)
|
||||
|
||||
# Process a batch of images
|
||||
inferencer.process_batch(
|
||||
["image1.jpg", "image2.jpg"],
|
||||
output_dir="output_folder",
|
||||
save_format="PNG"
|
||||
)
|
|
@ -0,0 +1,5 @@
|
|||
from .aiunn import aiuNN
|
||||
from .config import aiuNNConfig
|
||||
|
||||
|
||||
__all__ = ["aiuNN", "aiuNNConfig"]
|
|
@ -0,0 +1,82 @@
|
|||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import warnings
|
||||
from aiia import AIIA, AIIAConfig, AIIABase
|
||||
from .config import aiuNNConfig
|
||||
import warnings
|
||||
|
||||
|
||||
class aiuNN(AIIA):
|
||||
def __init__(self, base_model: AIIABase):
|
||||
super().__init__(base_model.config)
|
||||
self.base_model = base_model
|
||||
|
||||
# Pass the unified base configuration using the new parameter.
|
||||
self.config = aiuNNConfig(base_config=base_model.config)
|
||||
|
||||
self.upsample = nn.Upsample(
|
||||
scale_factor=self.config.upsample_scale,
|
||||
mode=self.config.upsample_mode,
|
||||
align_corners=self.config.upsample_align_corners
|
||||
)
|
||||
# Conversion layer: change from hidden size channels to 3 channels.
|
||||
self.to_rgb = nn.Conv2d(
|
||||
in_channels=self.base_model.config.hidden_size,
|
||||
out_channels=3,
|
||||
kernel_size=1
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base_model(x)
|
||||
x = self.upsample(x)
|
||||
x = self.to_rgb(x) # Ensures output has 3 channels.
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def load(cls, path, precision: str = None):
|
||||
# Load the configuration from disk.
|
||||
config = AIIAConfig.load(path)
|
||||
# Reconstruct the base model from the loaded configuration.
|
||||
base_model = AIIABase(config)
|
||||
# Instantiate the Upsampler using the proper base model.
|
||||
upsampler = cls(base_model)
|
||||
|
||||
# Load state dict and handle precision conversion if needed.
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
state_dict = torch.load(f"{path}/model.pth", map_location=device)
|
||||
if precision is not None:
|
||||
if precision.lower() == 'fp16':
|
||||
dtype = torch.float16
|
||||
elif precision.lower() == 'bf16':
|
||||
if device == 'cuda' and not torch.cuda.is_bf16_supported():
|
||||
warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")
|
||||
|
||||
for key, param in state_dict.items():
|
||||
if torch.is_tensor(param):
|
||||
state_dict[key] = param.to(dtype)
|
||||
upsampler.load_state_dict(state_dict)
|
||||
return upsampler
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from aiia import AIIABase, AIIAConfig
|
||||
# Create a configuration and build a base model.
|
||||
config = AIIAConfig()
|
||||
base_model = AIIABase(config)
|
||||
# Instantiate Upsampler from the base model (works correctly).
|
||||
upsampler = aiuNN(base_model)
|
||||
|
||||
# Save the model (both configuration and weights).
|
||||
upsampler.save("hehe")
|
||||
|
||||
# Now load using the overridden load method; this will load the complete model.
|
||||
upsampler_loaded = aiuNN.load("hehe", precision="bf16")
|
||||
print("Updated configuration:", upsampler_loaded.config.__dict__)
|
|
@ -0,0 +1,50 @@
|
|||
from aiia import AIIAConfig
|
||||
|
||||
|
||||
class aiuNNConfig(AIIAConfig):
|
||||
def __init__(
|
||||
self,
|
||||
base_config=None,
|
||||
upsample_scale: int = 2,
|
||||
upsample_mode: str = 'bilinear',
|
||||
upsample_align_corners: bool = False,
|
||||
layers=None,
|
||||
**kwargs
|
||||
):
|
||||
# Start with a single configuration dictionary.
|
||||
config_data = {}
|
||||
if base_config is not None:
|
||||
# If base_config is an object with a to_dict method, use it.
|
||||
if hasattr(base_config, "to_dict"):
|
||||
config_data.update(base_config.to_dict())
|
||||
elif isinstance(base_config, dict):
|
||||
config_data.update(base_config)
|
||||
|
||||
# Update with any additional keyword arguments (if needed).
|
||||
config_data.update(kwargs)
|
||||
|
||||
# Initialize base AIIAConfig with a single merged configuration.
|
||||
super().__init__(**config_data)
|
||||
|
||||
# Upsampler-specific parameters.
|
||||
self.upsample_scale = upsample_scale
|
||||
self.upsample_mode = upsample_mode
|
||||
self.upsample_align_corners = upsample_align_corners
|
||||
|
||||
# Use layers from the argument or initialize an empty list.
|
||||
self.layers = layers if layers is not None else []
|
||||
|
||||
# Add the upsample layer details only once.
|
||||
self.add_upsample_layer()
|
||||
|
||||
def add_upsample_layer(self):
|
||||
upsample_layer = {
|
||||
'name': 'Upsample',
|
||||
'type': 'nn.Upsample',
|
||||
'scale_factor': self.upsample_scale,
|
||||
'mode': self.upsample_mode,
|
||||
'align_corners': self.upsample_align_corners
|
||||
}
|
||||
# Append the layer only if it isn’t already present.
|
||||
if not any(layer.get('name') == 'Upsample' for layer in self.layers):
|
||||
self.layers.append(upsample_layer)
|
Loading…
Reference in New Issue