# aiuNN/src/aiunn/finetune.py
import torch
import pandas as pd
from PIL import Image
import io
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from aiia.model import AIIABase
from sklearn.model_selection import train_test_split
from typing import Dict, List, Optional, Union


class ImageDataset(Dataset):
    """Paired image dataset: encoded 512px inputs and 1024px targets stored in a dataframe."""

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        # Decode image_512 from bytes
        img_bytes = row['image_512']
        img_stream = io.BytesIO(img_bytes)
        low_res_image = Image.open(img_stream).convert('RGB')

        # Decode image_1024 from bytes
        high_res_bytes = row['image_1024']
        high_stream = io.BytesIO(high_res_bytes)
        high_res_image = Image.open(high_stream).convert('RGB')

        # Apply transformations if specified
        if self.transform:
            low_res_image = self.transform(low_res_image)
            high_res_image = self.transform(high_res_image)

        return {'low_res': low_res_image, 'high_res': high_res_image}
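

# A minimal sketch of the parquet schema ImageDataset expects: columns
# 'image_512' and 'image_1024' holding encoded image bytes. The toy frame
# below is illustrative only, not part of the training pipeline:
#
#     buf = io.BytesIO()
#     Image.new('RGB', (512, 512)).save(buf, format='PNG')
#     df = pd.DataFrame({'image_512': [buf.getvalue()],
#                        'image_1024': [buf.getvalue()]})
#     sample = ImageDataset(df)[0]  # {'low_res': ..., 'high_res': ...}
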

class TrainingBase:
    def __init__(self,
                 model_name: str,
                 dataset_paths: Union[List[str], Dict[str, str]],
                 batch_size: int = 32,
                 learning_rate: float = 0.001,
                 num_workers: int = 4,
                 train_ratio: float = 0.8):
        """
        Base class for training models with multiple dataset support

        Args:
            model_name (str): Name of the model to initialize
            dataset_paths (Union[List[str], Dict[str, str]]): Paths to datasets (train and optional validation)
            batch_size (int): Batch size for training
            learning_rate (float): Learning rate for the optimizer
            num_workers (int): Number of workers for data loading
            train_ratio (float): Ratio of data to use for training (the rest goes to validation)
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_ratio = train_ratio

        # Initialize datasets and loaders
        self.dataset_paths = dataset_paths
        self._initialize_datasets()

        # Initialize model and training parameters
        self.model_name = model_name
        self.learning_rate = learning_rate
        self._initialize_model()

    def _initialize_datasets(self):
        """Helper method to initialize datasets"""
        raise NotImplementedError("This method should be implemented in child classes")

    def _initialize_model(self):
        """Helper method to initialize the model architecture"""
        raise NotImplementedError("This method should be implemented in child classes")

    def train(self, num_epochs: int = 10):
        """Train the model for the specified number of epochs"""
        self.model.to(self.device)
        for epoch in range(num_epochs):
            print(f"Epoch {epoch+1}/{num_epochs}")

            # Train phase
            self._train_epoch()

            # Validation phase
            val_loss = self._validate_epoch()

            # Save the best model based on validation loss
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.save_model()

    def _train_epoch(self):
        """Train the model for one epoch"""
        raise NotImplementedError("This method should be implemented in child classes")

    def _validate_epoch(self):
        """Validate model performance; must return the average validation loss"""
        raise NotImplementedError("This method should be implemented in child classes")

    def save_model(self):
        """Save the current best model"""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_loss': self.best_val_loss
        }, f"{self.model_name}_best.pth")

class Finetuner(TrainingBase):
    def __init__(self,
                 model_name: str = "AIIA-Base-512",
                 dataset_paths: Optional[Union[List[str], Dict[str, str]]] = None,
                 batch_size: int = 32,
                 learning_rate: float = 0.001,
                 num_workers: int = 4,
                 train_ratio: float = 0.8):
        """
        Specialized trainer for image super-resolution tasks

        Args:
            Same as TrainingBase
        """
        super().__init__(model_name, dataset_paths, batch_size, learning_rate, num_workers, train_ratio)

    def _initialize_datasets(self):
        """Initialize image datasets"""
        # Load dataframes from parquet files
        if isinstance(self.dataset_paths, dict):
            df_train = pd.read_parquet(self.dataset_paths['train'])
            df_val = pd.read_parquet(self.dataset_paths['val']) if 'val' in self.dataset_paths else None
        elif isinstance(self.dataset_paths, list):
            df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
            df_val = None
        else:
            raise ValueError("Invalid dataset_paths format: expected a dict or a list of parquet paths")

        # Split into train and validation sets if no explicit validation set was given
        if df_val is None:
            df_train, df_val = train_test_split(df_train, test_size=1 - self.train_ratio, random_state=42)

        # Define preprocessing transforms (scales pixel values to [-1, 1])
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        # Create datasets and dataloaders
        self.train_dataset = ImageDataset(df_train, transform=self.transform)
        self.val_dataset = ImageDataset(df_val, transform=self.transform)
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )
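
    # The two accepted dataset_paths forms, for reference (file names are
    # illustrative):
    #   {'train': 'train.parquet', 'val': 'val.parquet'}  -> explicit validation set
    #   ['part1.parquet', 'part2.parquet']                -> concatenated, then
    #                                                        split by train_ratio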

    def _initialize_model(self):
        """Initialize and modify the super-resolution model"""
        # Load the pretrained base model
        self.model = AIIABase.load(self.model_name)

        # Freeze the CNN backbone
        for param in self.model.cnn.parameters():
            param.requires_grad = False

        # Add an upscaling head: 2x bilinear upsampling followed by a 3x3 conv back to RGB
        hidden_size = self.model.config.hidden_size
        self.model.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(hidden_size, 3, kernel_size=3, padding=1)
        )

        # Initialize optimizer and loss function; only the new upsampling head is trained
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(
            self.model.upsample.parameters(),
            lr=self.learning_rate
        )
        self.best_val_loss = float('inf')
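
    # Note: the 2x upsample reproduces the 1024px target only if the frozen
    # backbone preserves the 512px input resolution; that is an assumption
    # about AIIABase, not something this module verifies. A quick sanity
    # check could look like:
    #
    #     with torch.no_grad():
    #         feats = self.model.cnn(torch.zeros(1, 3, 512, 512))
    #     assert feats.shape[-2:] == (512, 512), feats.shape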

    def _train_epoch(self):
        """Train the model for one epoch"""
        self.model.train()
        running_loss = 0.0
        for batch in self.train_loader:
            low_res = batch['low_res'].to(self.device)
            high_res = batch['high_res'].to(self.device)

            # Forward pass
            features = self.model.cnn(low_res)
            outputs = self.model.upsample(features)
            loss = self.criterion(outputs, high_res)

            # Backward pass and optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / len(self.train_loader)
        print(f"Train Loss: {epoch_loss:.4f}")

    def _validate_epoch(self):
        """Validate model performance and return the average validation loss"""
        self.model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in self.val_loader:
                low_res = batch['low_res'].to(self.device)
                high_res = batch['high_res'].to(self.device)
                features = self.model.cnn(low_res)
                outputs = self.model.upsample(features)
                loss = self.criterion(outputs, high_res)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(self.val_loader)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        return avg_val_loss

    def __repr__(self):
        return f"Finetuner(model_name={self.model_name}, batch_size={self.batch_size})"


# Example usage:
if __name__ == "__main__":
    finetuner = Finetuner(
        dataset_paths={
            'train': "/root/training_data/vision-dataset/image_upscaler.parquet",
            'val': "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
        },
        batch_size=2,
        learning_rate=0.001
    )
    finetuner.train(num_epochs=10)
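
    # A sketch of inference with the finetuned head; 'sample.png' and the
    # assumption that model.cnn accepts a single 4D tensor are illustrative.
    finetuner.model.eval()
    with torch.no_grad():
        img = Image.open("sample.png").convert("RGB")
        x = finetuner.transform(img).unsqueeze(0).to(finetuner.device)
        sr = finetuner.model.upsample(finetuner.model.cnn(x))
    # Undo the [-1, 1] normalization before converting back to an image.
    sr = sr.squeeze(0).clamp(-1.0, 1.0) * 0.5 + 0.5
    transforms.ToPILImage()(sr.cpu()).save("sample_upscaled.png")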