Merge pull request 'develop' (#4) from develop into main
Reviewed-on: #4
commit cdf1e19280
@@ -0,0 +1,106 @@
from aiia import AIIABase
from aiunn import aiuNN
from aiunn import aiuNNTrainer
import pandas as pd
import io
import base64
from PIL import Image, ImageFile
from torch.utils.data import Dataset
from torchvision import transforms


class UpscaleDataset(Dataset):
    def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):
        combined_df = pd.DataFrame()
        for parquet_file in parquet_files:
            # Load a subset from each parquet file
            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(samples_per_file)
            combined_df = pd.concat([combined_df, df], ignore_index=True)

        # Validate rows (ensuring each value is bytes or str)
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        for col in ['image_512', 'image_1024']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        try:
            if isinstance(data, str):
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            raise RuntimeError(f"Decoding failed: {str(e)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # If a previous call failed for this index, use a different index
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]
        try:
            row = self.df.iloc[idx]
            low_res_bytes = self._decode_image(row['image_512'])
            high_res_bytes = self._decode_image(row['image_1024'])
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            # Open the image bytes with Pillow and convert to RGBA first
            low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
            high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')

            # Create new RGB images with a black background
            low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0))
            high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0))

            # Composite the original images over the black background
            low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3])
            high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3])

            # Now we have true 3-channel RGB images with transparent areas rendered black
            low_res = low_res_rgb
            high_res = high_res_rgb

            # Resize the images to reduce VRAM usage
            low_res = low_res.resize((410, 410), Image.LANCZOS)
            high_res = high_res.resize((820, 820), Image.LANCZOS)

            # If a transform is provided (e.g. conversion to tensor), apply it
            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            return self[(idx + 1) % len(self)]


if __name__ == "__main__":
    # Load the base model and upscaler
    pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
    base_model = AIIABase.load(pretrained_model_path, precision="bf16")
    upscaler = aiuNN(base_model)

    # Create the trainer with the dataset class
    trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)

    # Load data using parameters for the dataset
    dataset_params = {
        'parquet_files': [
            "/root/training_data/vision-dataset/image_upscaler.parquet",
            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
        ],
        'transform': transforms.Compose([transforms.ToTensor()]),
        'samples_per_file': 5000
    }
    trainer.load_data(dataset_params=dataset_params, batch_size=1)

    # Fine-tune the model
    trainer.finetune(output_path="trained_models")
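A quick sanity check for the dataset above, independent of the trainer — a minimal sketch with a placeholder parquet path; the expected shapes follow from the 410/820 resizes in __getitem__:

    # Hypothetical smoke test for UpscaleDataset (path is a placeholder, not a real file).
    from torch.utils.data import DataLoader

    ds = UpscaleDataset(["/path/to/images.parquet"],
                        transform=transforms.ToTensor(), samples_per_file=100)
    loader = DataLoader(ds, batch_size=2, shuffle=False)
    low, high = next(iter(loader))
    print(low.shape, high.shape)  # torch.Size([2, 3, 410, 410]) torch.Size([2, 3, 820, 820])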
@@ -0,0 +1,17 @@
[build-system]
requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "aiunn"
version = "0.1.1"
description = "Finetuner for image upscaling using AIIA"
readme = "README.md"
requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [
    {name = "Falko Habel", email = "falko.habel@gmx.de"},
]

[project.urls]
"Homepage" = "https://gitea.fabelous.app/Machine-Learning/aiuNN"
@@ -0,0 +1,5 @@
torch
aiia
pillow
torchvision
scikit-learn
@@ -0,0 +1,14 @@
from setuptools import setup, find_packages

setup(
    name="aiunn",
    version="0.1.1",
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    install_requires=[
        line.strip()
        for line in open("requirements.txt")
        if line.strip() and not line.startswith("#")
    ],
    python_requires=">=3.10",
)
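For clarity, the install_requires comprehension above just filters blank and comment lines from requirements.txt; a small illustration with invented file contents:

    # Illustration of the requirements parsing in setup.py (contents invented).
    lines = ["torch\n", "\n", "# a comment\n", "scikit-learn\n"]
    parsed = [line.strip() for line in lines if line.strip() and not line.startswith("#")]
    print(parsed)  # ['torch', 'scikit-learn']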
@@ -0,0 +1,6 @@
from .finetune.trainer import aiuNNTrainer
from .upsampler.aiunn import aiuNN
from .upsampler.config import aiuNNConfig
from .inference.inference import aiuNNInference


__version__ = "0.1.1"
@@ -0,0 +1,3 @@
from .trainer import aiuNNTrainer

__all__ = ["aiuNNTrainer"]
@@ -0,0 +1,290 @@
import torch
import torch.nn as nn
import torch.optim as optim
import os
import csv
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
import time
import shutil


class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        # patience: number of epochs with no significant improvement before stopping
        # min_delta: minimum change in loss required to count as an improvement
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
            return True  # Improved
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False  # Not improved


class aiuNNTrainer:
    def __init__(self, upscaler_model, dataset_class=None):
        """
        Initialize the upscaler trainer.

        Args:
            upscaler_model: The model to fine-tune
            dataset_class: The dataset class to use for loading data (optional)
        """
        self.model = upscaler_model
        self.dataset_class = dataset_class
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device, memory_format=torch.channels_last)
        self.criterion = nn.MSELoss()
        self.optimizer = None
        self.scaler = GradScaler()
        self.best_loss = float('inf')
        self.use_checkpointing = True
        self.data_loader = None
        self.validation_loader = None
        self.log_dir = None

    def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
        """
        Load data using either a custom dataset instance or the dataset class provided at initialization.

        Args:
            dataset_params (dict/list): Parameters to pass to the dataset class constructor
            batch_size (int): Batch size for training
            validation_split (float): Proportion of data to use for validation
            custom_train_dataset: A pre-instantiated dataset to use for training (optional)
            custom_val_dataset: A pre-instantiated dataset to use for validation (optional)
        """
        # If custom datasets are provided directly, use them
        if custom_train_dataset is not None:
            train_dataset = custom_train_dataset
            val_dataset = custom_val_dataset if custom_val_dataset is not None else None
        else:
            # Otherwise instantiate the dataset using the class and parameters
            if self.dataset_class is None:
                raise ValueError("No dataset class provided. Either provide a dataset class at initialization or custom datasets.")

            # Create the dataset instance
            dataset = self.dataset_class(**dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params})

            # Split into train and validation sets
            dataset_size = len(dataset)
            val_size = int(validation_split * dataset_size)
            train_size = dataset_size - val_size

            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )

        # Create data loaders
        self.data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True
        )

        if val_dataset is not None:
            self.validation_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=True
            )
            print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")
        else:
            self.validation_loader = None
            print(f"Loaded {len(train_dataset)} training samples (no validation set)")

        return self.data_loader, self.validation_loader

    def _setup_logging(self, output_path):
        """Set up the directory structure for logging and model checkpoints."""
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
        os.makedirs(self.log_dir, exist_ok=True)

        # Create the checkpoint directory
        self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Set up CSV logging
        self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
        with open(self.csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            if self.validation_loader:
                writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
            else:
                writer.writerow(['Epoch', 'Train Loss', 'Improved'])

    def _evaluate(self):
        """Evaluate the model on validation data and return the summed loss."""
        if self.validation_loader is None:
            return 0.0

        self.model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                with autocast(device_type=self.device.type):
                    outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                val_loss += loss.item()

                del low_res, high_res, outputs, loss

        self.model.train()
        return val_loss

    def _save_checkpoint(self, epoch, is_best=False):
        """Save a model checkpoint."""
        checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
        best_model_path = os.path.join(self.log_dir, "best_model")

        # Save the model checkpoint
        self.model.save(checkpoint_path)

        # If this is the best model so far, copy it to best_model
        if is_best:
            if os.path.exists(best_model_path):
                shutil.rmtree(best_model_path)
            self.model.save(best_model_path)
            print(f"Saved new best model with loss: {self.best_loss:.6f}")

    def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
        """
        Finetune the upscaler model.

        Args:
            output_path (str): Directory to save models and logs
            epochs (int): Maximum number of training epochs
            lr (float): Learning rate
            patience (int): Early stopping patience
            min_delta (float): Minimum improvement for early stopping
        """
        # Check that data is loaded
        if self.data_loader is None:
            raise ValueError("Data not loaded. Call load_data first.")

        # Set up the optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Set up logging
        self._setup_logging(output_path)

        # Set up early stopping
        early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

        # Training loop
        self.model.train()

        for epoch in range(epochs):
            # Training phase
            epoch_loss = 0.0
            progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")

            for low_res, high_res in progress_bar:
                # Move data to the GPU in channels_last format where possible
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                self.optimizer.zero_grad()

                with autocast(device_type=self.device.type):
                    if self.use_checkpointing:
                        # Ensure the input tensor requires gradients so that checkpointing records the computation graph
                        low_res.requires_grad_()
                        outputs = checkpoint(self.model, low_res)
                    else:
                        outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_postfix({'loss': loss.item()})

                # Optionally delete variables to free memory
                del low_res, high_res, outputs, loss

            # Calculate the average epoch loss
            avg_train_loss = epoch_loss / len(self.data_loader)

            # Validation phase (if a validation loader exists)
            if self.validation_loader:
                val_loss = self._evaluate() / len(self.validation_loader)
                is_improved = val_loss < self.best_loss
                if is_improved:
                    self.best_loss = val_loss

                # Log results
                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
            else:
                # With no validation set, track improvement on the training loss
                is_improved = avg_train_loss < self.best_loss
                if is_improved:
                    self.best_loss = avg_train_loss

                # Log results
                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])

            # Save a checkpoint
            self._save_checkpoint(epoch + 1, is_best=is_improved)

            # Perform garbage collection and clear the GPU cache after each epoch
            gc.collect()
            torch.cuda.empty_cache()

            # Check early stopping
            early_stopping(val_loss if self.validation_loader else avg_train_loss)
            if early_stopping.early_stop:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

        return self.best_loss

    def save(self, output_path=None):
        """
        Save the best model to the specified path.

        Args:
            output_path (str, optional): Path to save the model. If None, uses the best model from training.
        """
        if output_path is None and self.log_dir is not None:
            best_model_path = os.path.join(self.log_dir, "best_model")
            if os.path.exists(best_model_path):
                print(f"Best model already saved at {best_model_path}")
                return best_model_path
            else:
                output_path = os.path.join(self.log_dir, "final_model")

        if output_path is None:
            raise ValueError("No output path specified and no training has been done yet.")

        self.model.save(output_path)
        print(f"Model saved to {output_path}")
        return output_path
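To make the early-stopping semantics above concrete, a minimal sketch with invented loss values — improvements smaller than min_delta do not reset the counter:

    # Invented loss sequence; EarlyStopping flips early_stop after `patience` misses.
    stopper = EarlyStopping(patience=2, min_delta=0.01)
    for loss in [1.00, 0.95, 0.949, 0.948]:
        improved = stopper(loss)
        print(improved, stopper.counter, stopper.early_stop)
    # Prints: True 0 False / True 0 False / False 1 False / False 2 True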
@@ -0,0 +1,3 @@
from .inference import aiuNNInference

__all__ = ["aiuNNInference"]
@@ -0,0 +1,226 @@
import os
import torch
import numpy as np
from PIL import Image
import io
from typing import Union, Optional, Tuple, List
from ..upsampler.aiunn import aiuNN


class aiuNNInference:
    """
    Inference class for the aiuNN upsampling model.
    Handles model loading, image upscaling, and output processing.
    """
    def __init__(self, model_path: str, precision: Optional[str] = None, device: Optional[str] = None):
        """
        Initialize the inference class by loading the aiuNN model.

        Args:
            model_path: Path to the saved model directory
            precision: Optional precision setting ('fp16', 'bf16', or None for default)
            device: Optional device specification ('cuda', 'cpu', or None for auto-detection)
        """
        # Set the device
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        # Load the model with the specified precision
        self.model = aiuNN.load(model_path, precision=precision)
        self.model.to(self.device)
        self.model.eval()

        # Store the configuration for reference
        self.config = self.model.config

    def preprocess_image(self, image: Union[str, Image.Image, np.ndarray, torch.Tensor]) -> torch.Tensor:
        """
        Preprocess the input image to match model requirements.

        Args:
            image: Input image as file path, PIL Image, numpy array, or torch tensor

        Returns:
            Preprocessed tensor ready for model input
        """
        # Handle different input types
        if isinstance(image, str):
            # Load from a file path
            image = Image.open(image).convert('RGB')

        if isinstance(image, Image.Image):
            # Convert the PIL Image to a tensor
            image = np.array(image)
            image = image.transpose(2, 0, 1)  # HWC to CHW
            image = torch.from_numpy(image).float()

        if isinstance(image, np.ndarray):
            # Convert the numpy array to a tensor
            if image.shape[0] == 3:
                # Already in CHW format
                pass
            elif image.shape[-1] == 3:
                # HWC to CHW format
                image = image.transpose(2, 0, 1)
            image = torch.from_numpy(image).float()

        # Normalize to the [0, 1] range if needed
        if image.max() > 1.0:
            image = image / 255.0

        # Add a batch dimension if not present
        if len(image.shape) == 3:
            image = image.unsqueeze(0)

        # Move to the device
        image = image.to(self.device)

        return image

    def postprocess_tensor(self, tensor: torch.Tensor) -> Image.Image:
        """
        Convert an output tensor to a PIL Image.

        Args:
            tensor: Output tensor from the model

        Returns:
            Processed PIL Image
        """
        # Move to the CPU and convert to numpy
        output = tensor.detach().cpu().squeeze(0).numpy()

        # Ensure the proper range [0, 255]
        output = np.clip(output * 255, 0, 255).astype(np.uint8)

        # Convert from CHW to HWC for PIL
        output = output.transpose(1, 2, 0)

        # Create the PIL Image
        return Image.fromarray(output)

    @torch.no_grad()
    def upscale(self, image: Union[str, Image.Image, np.ndarray, torch.Tensor]) -> Image.Image:
        """
        Upscale an image using the aiuNN model.

        Args:
            image: Input image to upscale

        Returns:
            Upscaled image as a PIL Image
        """
        # Preprocess the input
        input_tensor = self.preprocess_image(image)

        # Run inference
        output_tensor = self.model(input_tensor)

        # Postprocess the output
        upscaled_image = self.postprocess_tensor(output_tensor)

        return upscaled_image

    def save(self, image: Image.Image, output_path: str, format: Optional[str] = None) -> None:
        """
        Save the upscaled image to a file.

        Args:
            image: PIL Image to save
            output_path: Path where the image should be saved
            format: Optional format override (e.g., 'PNG', 'JPEG')
        """
        # Create the directory if it doesn't exist
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        # Save the image
        image.save(output_path, format=format)

    def convert_to_binary(self, image: Image.Image, format: str = 'PNG') -> bytes:
        """
        Convert the image to binary data.

        Args:
            image: PIL Image to convert
            format: Image format to use for binary conversion

        Returns:
            Binary representation of the image
        """
        # Use BytesIO to convert to binary
        binary_output = io.BytesIO()
        image.save(binary_output, format=format)

        # Get the binary data
        binary_data = binary_output.getvalue()

        return binary_data

    def process_batch(self,
                      images: List[Union[str, Image.Image]],
                      output_dir: Optional[str] = None,
                      save_format: str = 'PNG',
                      return_binary: bool = False) -> Union[List[Image.Image], List[bytes], None]:
        """
        Process multiple images in batch.

        Args:
            images: List of input images (paths or PIL Images)
            output_dir: Optional directory to save results
            save_format: Format to use when saving images
            return_binary: Whether to return binary data instead of PIL Images

        Returns:
            List of processed images or binary data, or None if only saving
        """
        results = []

        for i, img in enumerate(images):
            # Upscale the image
            upscaled = self.upscale(img)

            # Save if an output directory is provided
            if output_dir:
                # Extract the filename if the input is a path
                if isinstance(img, str):
                    filename = os.path.basename(img)
                    base, _ = os.path.splitext(filename)
                else:
                    base = f"upscaled_{i}"

                output_path = os.path.join(output_dir, f"{base}.{save_format.lower()}")
                self.save(upscaled, output_path, format=save_format)

            # Add to results based on the return type
            if return_binary:
                results.append(self.convert_to_binary(upscaled, format=save_format))
            else:
                results.append(upscaled)

        return results if (not output_dir or return_binary or not save_format) else None


# Example usage (can be removed)
if __name__ == "__main__":
    # Initialize inference with a model path
    inferencer = aiuNNInference("path/to/model", precision="bf16")

    # Upscale a single image
    upscaled_image = inferencer.upscale("input_image.jpg")

    # Save the result
    inferencer.save(upscaled_image, "output_image.png")

    # Convert to binary
    binary_data = inferencer.convert_to_binary(upscaled_image)

    # Process a batch of images
    inferencer.process_batch(
        ["image1.jpg", "image2.jpg"],
        output_dir="output_folder",
        save_format="PNG"
    )
@@ -0,0 +1,5 @@
from .aiunn import aiuNN
from .config import aiuNNConfig


__all__ = ["aiuNN", "aiuNNConfig"]
@@ -0,0 +1,82 @@
import os
import torch
import torch.nn as nn
import warnings
from aiia import AIIA, AIIAConfig, AIIABase
from .config import aiuNNConfig


class aiuNN(AIIA):
    def __init__(self, base_model: AIIABase):
        super().__init__(base_model.config)
        self.base_model = base_model

        # Pass the unified base configuration using the new parameter.
        self.config = aiuNNConfig(base_config=base_model.config)

        self.upsample = nn.Upsample(
            scale_factor=self.config.upsample_scale,
            mode=self.config.upsample_mode,
            align_corners=self.config.upsample_align_corners
        )
        # Conversion layer: change from hidden-size channels to 3 channels.
        self.to_rgb = nn.Conv2d(
            in_channels=self.base_model.config.hidden_size,
            out_channels=3,
            kernel_size=1
        )

    def forward(self, x):
        x = self.base_model(x)
        x = self.upsample(x)
        x = self.to_rgb(x)  # Ensures the output has 3 channels.
        return x

    @classmethod
    def load(cls, path, precision: str = None):
        # Load the configuration from disk.
        config = AIIAConfig.load(path)
        # Reconstruct the base model from the loaded configuration.
        base_model = AIIABase(config)
        # Instantiate the upsampler using the proper base model.
        upsampler = cls(base_model)

        # Load the state dict and handle precision conversion if needed.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        state_dict = torch.load(f"{path}/model.pth", map_location=device)
        if precision is not None:
            if precision.lower() == 'fp16':
                dtype = torch.float16
            elif precision.lower() == 'bf16':
                if device == 'cuda' and not torch.cuda.is_bf16_supported():
                    warnings.warn("BF16 is not supported on this GPU. Falling back to FP16.")
                    dtype = torch.float16
                else:
                    dtype = torch.bfloat16
            else:
                raise ValueError("Unsupported precision. Use 'fp16', 'bf16', or leave as None.")

            for key, param in state_dict.items():
                if torch.is_tensor(param):
                    state_dict[key] = param.to(dtype)
        upsampler.load_state_dict(state_dict)
        return upsampler


if __name__ == "__main__":
    from aiia import AIIABase, AIIAConfig
    # Create a configuration and build a base model.
    config = AIIAConfig()
    base_model = AIIABase(config)
    # Instantiate the upsampler from the base model.
    upsampler = aiuNN(base_model)

    # Save the model (both configuration and weights).
    upsampler.save("hehe")

    # Now load using the overridden load method; this loads the complete model.
    upsampler_loaded = aiuNN.load("hehe", precision="bf16")
    print("Updated configuration:", upsampler_loaded.config.__dict__)
@@ -0,0 +1,50 @@
from aiia import AIIAConfig


class aiuNNConfig(AIIAConfig):
    def __init__(
        self,
        base_config=None,
        upsample_scale: int = 2,
        upsample_mode: str = 'bilinear',
        upsample_align_corners: bool = False,
        layers=None,
        **kwargs
    ):
        # Start with a single configuration dictionary.
        config_data = {}
        if base_config is not None:
            # If base_config is an object with a to_dict method, use it.
            if hasattr(base_config, "to_dict"):
                config_data.update(base_config.to_dict())
            elif isinstance(base_config, dict):
                config_data.update(base_config)

        # Update with any additional keyword arguments (if needed).
        config_data.update(kwargs)

        # Initialize the base AIIAConfig with a single merged configuration.
        super().__init__(**config_data)

        # Upsampler-specific parameters.
        self.upsample_scale = upsample_scale
        self.upsample_mode = upsample_mode
        self.upsample_align_corners = upsample_align_corners

        # Use layers from the argument or initialize an empty list.
        self.layers = layers if layers is not None else []

        # Add the upsample layer details only once.
        self.add_upsample_layer()

    def add_upsample_layer(self):
        upsample_layer = {
            'name': 'Upsample',
            'type': 'nn.Upsample',
            'scale_factor': self.upsample_scale,
            'mode': self.upsample_mode,
            'align_corners': self.upsample_align_corners
        }
        # Append the layer only if it isn't already present.
        if not any(layer.get('name') == 'Upsample' for layer in self.layers):
            self.layers.append(upsample_layer)
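A minimal sketch of how the config merge above behaves, assuming AIIAConfig accepts arbitrary keyword arguments (the base dict keys here are invented for illustration):

    # Invented base dict; aiuNNConfig merges it, then appends the Upsample layer exactly once.
    base = {"hidden_size": 256}
    cfg = aiuNNConfig(base_config=base, upsample_scale=2)
    print(cfg.upsample_scale, cfg.upsample_mode)      # 2 bilinear
    print([layer["name"] for layer in cfg.layers])    # ['Upsample']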