import torch
import pandas as pd
from PIL import Image, ImageFile
import io
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from aiia.model import AIIABase, AIIA
from sklearn.model_selection import train_test_split
from typing import Dict, List, Union, Optional
import base64


class ImageDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        low_res_stream = None
        high_res_stream = None
        try:
            # Verify data is valid before creating BytesIO
            if not isinstance(row['image_512'], bytes) or not isinstance(row['image_1024'], bytes):
                raise ValueError("Image data must be in bytes format")

            low_res_stream = io.BytesIO(row['image_512'])
            high_res_stream = io.BytesIO(row['image_1024'])

            # Enable loading of truncated images if necessary
            ImageFile.LOAD_TRUNCATED_IMAGES = True

            # Verify the images first; verify() leaves the image object unusable,
            # so rewind the streams and reopen them for the actual decode
            Image.open(low_res_stream).verify()
            Image.open(high_res_stream).verify()

            # Reset stream positions
            low_res_stream.seek(0)
            high_res_stream.seek(0)

            low_res_image = Image.open(low_res_stream).convert('RGB')
            high_res_image = Image.open(high_res_stream).convert('RGB')

        except Exception as e:
            raise ValueError(f"Image loading failed: {str(e)}")
        finally:
            # Only close streams that were actually created
            if low_res_stream is not None:
                low_res_stream.close()
            if high_res_stream is not None:
                high_res_stream.close()

        if self.transform:
            low_res_image = self.transform(low_res_image)
            high_res_image = self.transform(high_res_image)

        return {'low_ress': low_res_image, 'high_ress': high_res_image}


class ModelTrainer:
    def __init__(self,
                 model: AIIA,
                 dataset_paths: List[str],
                 batch_size: int = 32,
                 learning_rate: float = 0.001,
                 num_workers: int = 4,
                 train_ratio: float = 0.8):
        """
        Specialized trainer for image super-resolution tasks.

        Args:
            model (AIIA): Model instance to finetune
            dataset_paths (List[str]): Paths to the parquet datasets
            batch_size (int): Batch size for training
            learning_rate (float): Learning rate for the optimizer
            num_workers (int): Number of workers for data loading
            train_ratio (float): Ratio of data used for training (the rest goes to validation)
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset_paths = dataset_paths
        self.learning_rate = learning_rate
        self.train_ratio = train_ratio
        self.model = model

        # Initialize datasets and loaders
        self._initialize_datasets()

        # Initialize training parameters
        self._initialize_training()

    def _initialize_datasets(self):
        """
        Helper method to initialize datasets.
        """
        if isinstance(self.dataset_paths, list):
            df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
        else:
            raise ValueError("Invalid dataset_paths format. Must be a list.")
        df_train, df_val = train_test_split(
            df_train,
            test_size=1 - self.train_ratio,
            random_state=42
        )

        # Define preprocessing transforms
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        # Create datasets and dataloaders
        self.train_dataset = ImageDataset(df_train, transform=self.transform)
        self.val_dataset = ImageDataset(df_val, transform=self.transform)

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        ) if df_val is not None else None

    def _initialize_training(self):
        """
        Helper method to initialize training parameters.
        """
        # Freeze CNN layers (if applicable)
        try:
            for param in self.model.cnn.parameters():
                param.requires_grad = False
        except AttributeError:
            pass  # If the model doesn't have a 'cnn' attribute, just continue

        # Add an upscaling head if not already present
        if not hasattr(self.model, 'upsample'):
            # Reuse existing configuration values
            hidden_size = self.model.config.hidden_size
            kernel_size = self.model.config.kernel_size
            self.model.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
            )
            # Record the new parameters in the model's configuration
            self.model.config.upsample_hidden_size = hidden_size
            self.model.config.upsample_kernel_size = kernel_size

        # Initialize loss function and optimizer
        self.criterion = nn.MSELoss()

        # Only the upsample head's parameters are trained
        params = [p for p in self.model.upsample.parameters() if p.requires_grad]
        if not params:
            raise ValueError("No parameters found in upsample layer to optimize")

        self.optimizer = torch.optim.Adam(params, lr=self.learning_rate)
        self.best_val_loss = float('inf')

    def train(self, num_epochs: int = 10):
        """
        Train the model for the specified number of epochs.
        """
        self.model.to(self.device)

        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs}")

            # Train phase
            self._train_epoch()

            # Validation phase
            if self.val_loader is not None:
                self._validate_epoch()

                # Save the best model based on validation loss
                if self.current_val_loss < self.best_val_loss:
                    self.best_val_loss = self.current_val_loss
                    self.model.save("aiuNN-finetuned")

    def _train_epoch(self):
        """
        Train the model for one epoch.
        """
        self.model.train()
        running_loss = 0.0

        for batch in self.train_loader:
            low_ress = batch['low_ress'].to(self.device)
            high_ress = batch['high_ress'].to(self.device)

            # Forward pass: extract frozen features, then upscale
            features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
            outputs = self.model.upsample(features)

            loss = self.criterion(outputs, high_ress)

            # Backward pass and optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / len(self.train_loader)
        print(f"Train Loss: {epoch_loss:.4f}")

    def _validate_epoch(self):
        """
        Validate model performance.
        """
        self.model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in self.val_loader:
                low_ress = batch['low_ress'].to(self.device)
                high_ress = batch['high_ress'].to(self.device)

                features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
                outputs = self.model.upsample(features)

                loss = self.criterion(outputs, high_ress)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
        print(f"Validation Loss: {avg_val_loss:.4f}")

        # Track the most recent validation loss so train() can decide whether to save
        self.current_val_loss = avg_val_loss

    def __repr__(self):
        return f"ModelTrainer(model={type(self.model).__name__}, batch_size={self.batch_size})"


if __name__ == "__main__":
    # Load the base model first
    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")

    trainer = ModelTrainer(
        model=model,
        dataset_paths=[
            "/root/training_data/vision-dataset/image_upscaler.parquet",
            "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
        ],
        batch_size=2,
        learning_rate=0.001
    )

    trainer.train(num_epochs=3)