import io
from typing import Dict, List, Union

import pandas as pd
import torch
import torchvision.transforms as transforms
from aiia.model import AIIABase
from PIL import Image
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, Dataset


class ImageDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        # Decode the low-resolution image (512px) from its raw bytes
        img_stream = io.BytesIO(row['image_512'])
        low_res_image = Image.open(img_stream).convert('RGB')

        # Decode the high-resolution image (1024px) from its raw bytes
        high_stream = io.BytesIO(row['image_1024'])
        high_res_image = Image.open(high_stream).convert('RGB')

        # Apply transformations if specified
        if self.transform:
            low_res_image = self.transform(low_res_image)
            high_res_image = self.transform(high_res_image)

        return {'low_res': low_res_image, 'high_res': high_res_image}


class TrainingBase:
    def __init__(self,
                 model_name: str,
                 dataset_paths: Union[List[str], Dict[str, str]],
                 batch_size: int = 32,
                 learning_rate: float = 0.001,
                 num_workers: int = 4,
                 train_ratio: float = 0.8):
        """
        Base class for training models with multiple-dataset support.

        Args:
            model_name (str): Name of the model to initialize
            dataset_paths (Union[List[str], Dict[str, str]]): Paths to the
                datasets (train and optional validation)
            batch_size (int): Batch size for training
            learning_rate (float): Learning rate for the optimizer
            num_workers (int): Number of workers for data loading
            train_ratio (float): Ratio of data used for training
                (the remainder goes to validation)
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Stored for _initialize_datasets, which may split off a validation set
        self.train_ratio = train_ratio

        # Initialize datasets and loaders
        self.dataset_paths = dataset_paths
        self._initialize_datasets()

        # Initialize model and training parameters
        self.model_name = model_name
        self.learning_rate = learning_rate
        self._initialize_model()

    def _initialize_datasets(self):
        """Helper method to initialize datasets."""
        raise NotImplementedError("This method should be implemented in child classes")

    def _initialize_model(self):
        """Helper method to initialize the model architecture."""
        raise NotImplementedError("This method should be implemented in child classes")

    def train(self, num_epochs: int = 10):
        """Train the model for the specified number of epochs."""
        self.model.to(self.device)

        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs}")

            # Train phase
            self._train_epoch()

            # Validation phase; sets self.current_val_loss
            self._validate_epoch()

            # Checkpoint whenever the validation loss improves
            if self.current_val_loss < self.best_val_loss:
                self.best_val_loss = self.current_val_loss
                self.save_model()

    def _train_epoch(self):
        """Train the model for one epoch."""
        raise NotImplementedError("This method should be implemented in child classes")

    def _validate_epoch(self):
        """Validate model performance."""
        raise NotImplementedError("This method should be implemented in child classes")

    def save_model(self):
        """Save the current best model."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_loss': self.best_val_loss
        }, f"{self.model_name}_best.pth")


class Finetuner(TrainingBase):
    def __init__(self,
                 model_name: str = "AIIA-Base-512",
                 dataset_paths: Union[List[str], Dict[str, str]] = None,
                 batch_size: int = 32,
                 learning_rate: float = 0.001,
                 num_workers: int = 4,
                 train_ratio: float = 0.8):
        """
        Specialized trainer for image super-resolution tasks.

        Args:
            Same as TrainingBase
        """
        super().__init__(model_name, dataset_paths, batch_size,
                         learning_rate, num_workers, train_ratio)

    def _initialize_datasets(self):
        """Initialize image datasets."""
        # Load dataframes from parquet files
        if isinstance(self.dataset_paths, dict):
            df_train = pd.read_parquet(self.dataset_paths['train'])
            df_val = (pd.read_parquet(self.dataset_paths['val'])
                      if 'val' in self.dataset_paths else None)
        elif isinstance(self.dataset_paths, list):
            df_train = pd.concat(
                [pd.read_parquet(path) for path in self.dataset_paths],
                ignore_index=True
            )
            df_val = None
        else:
            raise ValueError("Invalid dataset_paths format")

        # Split off a validation set if none was provided
        if df_val is None:
            df_train, df_val = train_test_split(
                df_train,
                test_size=1 - self.train_ratio,
                random_state=42
            )

        # Define preprocessing transforms; pixel values are normalized to [-1, 1]
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        # Create datasets and dataloaders
        self.train_dataset = ImageDataset(df_train, transform=self.transform)
        self.val_dataset = ImageDataset(df_val, transform=self.transform)

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )

    def _initialize_model(self):
        """Initialize and modify the super-resolution model."""
        # Load the pretrained base model
        self.model = AIIABase.load(self.model_name)

        # Freeze the CNN backbone; only the new upscaling head is trained
        for param in self.model.cnn.parameters():
            param.requires_grad = False

        # Add an upscaling head that doubles the spatial resolution and
        # projects the features back to 3 RGB channels. This assumes the
        # backbone preserves the 512px input resolution, so a single 2x
        # upsample matches the 1024px targets.
        hidden_size = self.model.config.hidden_size
        self.model.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(hidden_size, 3, kernel_size=3, padding=1)
        )

        # Initialize the loss function and an optimizer over the upscaling
        # head only (the backbone is frozen)
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(
            self.model.upsample.parameters(),
            lr=self.learning_rate
        )
        self.best_val_loss = float('inf')

    def _train_epoch(self):
        """Train the model for one epoch."""
        self.model.train()
        running_loss = 0.0

        for batch in self.train_loader:
            low_res = batch['low_res'].to(self.device)
            high_res = batch['high_res'].to(self.device)

            # Forward pass: frozen backbone features -> trainable upscaling head
            features = self.model.cnn(low_res)
            outputs = self.model.upsample(features)
            loss = self.criterion(outputs, high_res)

            # Backward pass and optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / len(self.train_loader)
        print(f"Train Loss: {epoch_loss:.4f}")

    def _validate_epoch(self):
        """Validate model performance."""
        self.model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in self.val_loader:
                low_res = batch['low_res'].to(self.device)
                high_res = batch['high_res'].to(self.device)

                features = self.model.cnn(low_res)
                outputs = self.model.upsample(features)
                loss = self.criterion(outputs, high_res)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(self.val_loader)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        # Record the epoch's loss so train() can checkpoint on improvement
        self.current_val_loss = avg_val_loss

    def __repr__(self):
        return f"Model ({self.model_name}, batch_size={self.batch_size})"
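# ---------------------------------------------------------------------------
# ImageDataset expects each parquet row to hold the raw encoded bytes of an
# image pair in the columns 'image_512' and 'image_1024'. Below is a minimal
# sketch of how such a file could be produced from image files on disk; the
# helper name `build_pair_parquet` and the PNG encoding are illustrative
# assumptions, not part of the training pipeline above.
def build_pair_parquet(pairs, out_path):
    """Serialize (path_512, path_1024) file pairs into the expected schema."""
    def encode(path):
        # Re-encode as PNG bytes so ImageDataset can decode them with PIL
        buf = io.BytesIO()
        Image.open(path).convert('RGB').save(buf, format='PNG')
        return buf.getvalue()

    rows = [{'image_512': encode(low), 'image_1024': encode(high)}
            for low, high in pairs]
    pd.DataFrame(rows).to_parquet(out_path)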
# Example usage:
if __name__ == "__main__":
    finetuner = Finetuner(
        dataset_paths={
            'train': "/root/training_data/vision-dataset/image_upscaler.parquet",
            'val': "/root/training_data/vision-dataset/image_vec_upscaler.parquet",
        },
        batch_size=2,
        learning_rate=0.001
    )
    finetuner.train(num_epochs=10)
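# ---------------------------------------------------------------------------
# A minimal inference sketch, assuming the checkpoint layout written by
# save_model() and the head configuration from _initialize_model(). The
# function name `upscale_image` and all paths are illustrative placeholders.
def upscale_image(checkpoint_path, image_path, model_name="AIIA-Base-512"):
    """Upscale one image with a finetuned checkpoint; returns a PIL image."""
    # Rebuild the architecture exactly as during training, then load weights
    model = AIIABase.load(model_name)
    hidden_size = model.config.hidden_size
    model.upsample = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
        nn.Conv2d(hidden_size, 3, kernel_size=3, padding=1)
    )
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state['model_state_dict'])
    model.eval()

    # Same preprocessing as in training
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    with torch.no_grad():
        x = transform(Image.open(image_path).convert('RGB')).unsqueeze(0)
        y = model.upsample(model.cnn(x)).squeeze(0)

    # Undo the [-1, 1] normalization before converting back to an image
    y = y.clamp(-1, 1) * 0.5 + 0.5
    return transforms.ToPILImage()(y)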