aiuNN/src/aiunn/finetune.py

263 lines
8.8 KiB
Python

import torch
import pandas as pd
from PIL import Image, ImageFile
import io
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from aiia.model import AIIABase, AIIA
from sklearn.model_selection import train_test_split
from typing import Dict, List, Union, Optional
import base64
class ImageDataset(Dataset):
    """Paired low/high resolution images stored as encoded bytes in a DataFrame.

    Each row must hold the encoded low-resolution image in ``'image_512'`` and
    the encoded high-resolution image in ``'image_1024'``. Items are returned
    as ``{'low_ress': ..., 'high_ress': ...}`` after RGB conversion and the
    optional ``transform``.
    """

    def __init__(self, dataframe, transform=None):
        """
        Args:
            dataframe (pd.DataFrame): Table with ``image_512`` / ``image_1024``
                columns containing encoded image bytes.
            transform (callable, optional): Applied to both images of a pair.
        """
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of image pairs (rows of the DataFrame)."""
        return len(self.dataframe)

    @staticmethod
    def _decode_rgb(data):
        """Decode encoded image bytes into a fully loaded RGB PIL image.

        The stream and the source image are closed on every path. ``convert``
        decodes the pixel data into a new image first, so the returned image
        stays usable after the stream is closed.
        """
        with io.BytesIO(data) as stream, Image.open(stream) as image:
            return image.convert('RGB')

    def __getitem__(self, idx):
        """Load, decode and (optionally) transform the image pair at ``idx``.

        Raises:
            ValueError: If a column does not hold ``bytes`` or decoding fails.
        """
        row = self.dataframe.iloc[idx]
        try:
            # Validate the raw data before creating any stream. The old
            # ``finally`` cleanup hit unbound names on this path and masked
            # the real error.
            if not isinstance(row['image_512'], bytes) or not isinstance(row['image_1024'], bytes):
                raise ValueError("Image data must be in bytes format")
            # Enable loading of truncated images if necessary (Pillow global flag).
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            # Note: the former ``.verify()`` calls ran on the converted images,
            # which are not file-backed, so they never checked anything.
            # ``convert`` performs the full decode and raises on undecodable data.
            low_res_image = self._decode_rgb(row['image_512'])
            high_res_image = self._decode_rgb(row['image_1024'])
        except Exception as e:
            raise ValueError(f"Image loading failed: {str(e)}") from e
        if self.transform:
            low_res_image = self.transform(low_res_image)
            high_res_image = self.transform(high_res_image)
        return {'low_ress': low_res_image, 'high_ress': high_res_image}
class ModelTrainer:
    """Fine-tune a model's upsampling head for 2x image super-resolution.

    On construction the trainer:
    - loads and splits the parquet datasets;
    - freezes the model's ``cnn`` sub-module, if it has one;
    - attaches an ``upsample`` head (bilinear 2x upsampling followed by a
      ``Conv2d`` to 3 channels) when the model does not already define one;
    - creates an Adam optimizer over the head's trainable parameters only.
    """

    def __init__(self,
                 model: "AIIA",
                 dataset_paths: Union[str, List[str]],
                 batch_size: int = 32,
                 learning_rate: float = 0.001,
                 num_workers: int = 4,
                 train_ratio: float = 0.8):
        """
        Specialized trainer for image super resolution tasks

        Args:
            model (AIIA): Model instance to finetune. Unless it already has an
                ``upsample`` attribute, it must expose ``config.hidden_size``
                and ``config.kernel_size``. It must also provide either a
                ``cnn`` sub-module or an ``extract_features`` method.
            dataset_paths (str | List[str]): Path(s) to parquet datasets with
                ``image_512`` / ``image_1024`` byte columns.
            batch_size (int): Batch size for training
            learning_rate (float): Learning rate for optimizer
            num_workers (int): Number of workers for data loading
            train_ratio (float): Ratio of data to use for training (rest goes
                to validation). Must be in (0, 1]; 1.0 disables validation.

        Raises:
            ValueError: On invalid ``dataset_paths`` / ``train_ratio``, or when
                the upsample head has no trainable parameters.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset_paths = dataset_paths
        self.learning_rate = learning_rate
        self.train_ratio = train_ratio
        self.model = model
        # Initialize datasets and loaders
        self._initialize_datasets()
        # Initialize training parameters
        self._initialize_training()

    def _initialize_datasets(self):
        """
        Load the parquet files, split them and build the datasets/loaders.

        ``self.val_dataset`` / ``self.val_loader`` are ``None`` when
        ``train_ratio`` is 1.0.

        Raises:
            ValueError: If ``dataset_paths`` is not a path or a list/tuple of
                paths, is empty, or ``train_ratio`` is outside (0, 1].
        """
        paths = self.dataset_paths
        if isinstance(paths, str):
            paths = [paths]
        if not isinstance(paths, (list, tuple)):
            raise ValueError("Invalid dataset_paths format. Must be a path or a list of paths.")
        if not paths:
            raise ValueError("dataset_paths must contain at least one path.")
        if not 0 < self.train_ratio <= 1:
            raise ValueError(f"train_ratio must be in (0, 1], got {self.train_ratio}")

        df_all = pd.concat([pd.read_parquet(path) for path in paths], ignore_index=True)

        if self.train_ratio < 1:
            df_train, df_val = train_test_split(
                df_all,
                test_size=1 - self.train_ratio,
                random_state=42
            )
        else:
            # sklearn rejects test_size=0, so a ratio of 1.0 means "no validation split".
            df_train, df_val = df_all, None

        # Define preprocessing transforms (maps pixel values to roughly [-1, 1]).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        # Create datasets and dataloaders
        self.train_dataset = ImageDataset(df_train, transform=self.transform)
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers
        )
        if df_val is not None:
            self.val_dataset = ImageDataset(df_val, transform=self.transform)
            self.val_loader = DataLoader(
                self.val_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers
            )
        else:
            self.val_dataset = None
            self.val_loader = None

    @staticmethod
    def _same_padding(kernel_size):
        """Return the Conv2d padding that keeps H/W unchanged for odd kernel sizes."""
        if isinstance(kernel_size, int):
            return kernel_size // 2
        return tuple(k // 2 for k in kernel_size)

    def _initialize_training(self):
        """
        Freeze the backbone, attach the upsample head and build loss/optimizer.

        Raises:
            ValueError: If the upsample head has no trainable parameters.
        """
        # Freeze CNN layers (if applicable) so only the head is fine-tuned.
        try:
            for param in self.model.cnn.parameters():
                param.requires_grad = False
        except AttributeError:
            pass  # Model has no usable 'cnn' sub-module; nothing to freeze.

        # Add upscaling layer if not already present
        if not hasattr(self.model, 'upsample'):
            hidden_size = self.model.config.hidden_size
            kernel_size = self.model.config.kernel_size
            # NOTE(review): assumes the backbone emits `hidden_size` channels; confirm.
            self.model.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                # Derive padding from the kernel size. The former fixed
                # padding=1 only preserved the spatial size for 3x3 kernels.
                nn.Conv2d(hidden_size, 3, kernel_size=kernel_size,
                          padding=self._same_padding(kernel_size))
            )
            # Record the head's configuration on the model config.
            self.model.config.upsample_hidden_size = hidden_size
            self.model.config.upsample_kernel_size = kernel_size

        self.criterion = nn.MSELoss()

        # Only the upsample head is optimized.
        params = [p for p in self.model.upsample.parameters() if p.requires_grad]
        if not params:
            raise ValueError("No parameters found in upsample layer to optimize")
        self.optimizer = torch.optim.Adam(params, lr=self.learning_rate)

        self.best_val_loss = float('inf')
        # Most recent validation loss; None until the first validation pass.
        self.current_val_loss = None

    def _forward(self, low_ress):
        """Extract features with the backbone and run them through the upsample head."""
        if hasattr(self.model, 'cnn'):
            features = self.model.cnn(low_ress)
        else:
            features = self.model.extract_features(low_ress)
        return self.model.upsample(features)

    def train(self, num_epochs: int = 10, save_path: str = "aiuNN-finetuned"):
        """
        Train the model for specified number of epochs

        With a validation split, the model is validated after every epoch. It
        is saved to ``save_path`` each time the validation loss beats the best
        seen so far. Without a validation split, the final model is saved once
        after training.

        Args:
            num_epochs (int): Number of passes over the training data.
            save_path (str): Destination passed to ``model.save``.
        """
        self.model.to(self.device)
        for epoch in range(num_epochs):
            print(f"Epoch {epoch+1}/{num_epochs}")
            # Train phase
            self._train_epoch()
            # Validation phase: compare against the best *before* updating it,
            # otherwise an improvement can never be detected.
            if self.val_loader is not None:
                val_loss = self._validate_epoch()
                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.model.save(save_path)
        if self.val_loader is None and num_epochs > 0:
            self.model.save(save_path)

    def _train_epoch(self):
        """
        Train the upsample head for one epoch.

        Returns:
            float: Mean training loss over the batches (0.0 for an empty loader).
        """
        self.model.train()
        running_loss = 0.0
        for batch in self.train_loader:
            low_ress = batch['low_ress'].to(self.device)
            high_ress = batch['high_ress'].to(self.device)
            # Forward pass
            outputs = self._forward(low_ress)
            loss = self.criterion(outputs, high_ress)
            # Backward pass and optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
        # Guard against division by zero on an empty loader.
        epoch_loss = running_loss / max(len(self.train_loader), 1)
        print(f"Train Loss: {epoch_loss:.4f}")
        return epoch_loss

    def _validate_epoch(self):
        """
        Evaluate the model on the validation loader without gradient tracking.

        Stores the result in ``self.current_val_loss``.

        Returns:
            float: Mean validation loss (0 for an empty loader).
        """
        self.model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in self.val_loader:
                low_ress = batch['low_ress'].to(self.device)
                high_ress = batch['high_ress'].to(self.device)
                outputs = self._forward(low_ress)
                val_loss += self.criterion(outputs, high_ress).item()
        avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
        print(f"Validation Loss: {avg_val_loss:.4f}")
        self.current_val_loss = avg_val_loss
        return avg_val_loss

    def __repr__(self):
        return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"
if __name__ == "__main__":
    # Parquet files holding the paired low/high resolution training images.
    dataset_files = [
        "/root/training_data/vision-dataset/image_upscaler.parquet",
        "/root/training_data/vision-dataset/image_vec_upscaler.parquet",
    ]
    # Load the pretrained AIIA base checkpoint that will be fine-tuned.
    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
    trainer = ModelTrainer(
        model=model,
        dataset_paths=dataset_files,
        learning_rate=0.001,
        batch_size=2,
    )
    trainer.train(num_epochs=3)