From de6a67cb4e0d544c709d17988d8f315fea490641 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Wed, 5 Feb 2025 16:16:31 +0100 Subject: [PATCH] updated finetuning script --- src/aiunn/finetune.py | 536 +++++++++++------------------------------- 1 file changed, 141 insertions(+), 395 deletions(-) diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index be54fe2..4fde964 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -1,409 +1,155 @@ import torch import pandas as pd -from PIL import Image, ImageFile -import io -from torch import nn -from torch.utils.data import Dataset, DataLoader -import torchvision.transforms as transforms -from aiia.model import AIIABase, AIIA, AIIAConfig -from sklearn.model_selection import train_test_split -from typing import Dict, List, Union, Optional -import base64 -from tqdm import tqdm +import numpy as np +import cv2 import os +from albumentations import ( + Compose, Resize, Normalize, RandomBrightnessContrast, + HorizontalFlip, VerticalFlip, Rotate, GaussianBlur +) +from albumentations.pytorch import ToTensorV2 +from torch import nn +# Import the model and config from your existing code +from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive - -class ImageDataset(Dataset): - def __init__(self, dataframe, transform=None): - """ - Initialize the dataset with a dataframe containing image data +class aiuNNDataset(torch.utils.data.Dataset): + def __init__(self, parquet_path, config=None): + # Read the Parquet file + self.df = pd.read_parquet(parquet_path) - Args: - dataframe (pd.DataFrame): DataFrame containing 'image_512' and 'image_1024' columns - transform (callable, optional): Optional transform to be applied to both images - """ - self.dataframe = dataframe - self.transform = transform - - def __len__(self): - return len(self.dataframe) - - def __getitem__(self, idx): - """ - Get a pair of low and high resolution images - - Args: - idx (int): Index of the data point - - Returns: - dict: Contains 'low_ress' and 'high_ress' PIL images or transformed tensors - """ - row = self.dataframe.iloc[idx] - - try: - # Handle both bytes and base64 encoded strings - low_res_data = row['image_512'] - high_res_data = row['image_1024'] - - if isinstance(low_res_data, str): - # Decode base64 string to bytes - low_res_data = base64.b64decode(low_res_data) - high_res_data = base64.b64decode(high_res_data) - - # Verify data is valid before creating BytesIO - if not isinstance(low_res_data, bytes) or not isinstance(high_res_data, bytes): - raise ValueError(f"Invalid image data format at index {idx}") - - # Create image streams - low_res_stream = io.BytesIO(low_res_data) - high_res_stream = io.BytesIO(high_res_data) - - # Enable loading of truncated images - ImageFile.LOAD_TRUNCATED_IMAGES = True - - # Load and convert images to RGB - low_res_image = Image.open(low_res_stream).convert('RGB') - high_res_image = Image.open(high_res_stream).convert('RGB') - - # Create fresh copies for verify() since it modifies the image object - low_res_verify = low_res_image.copy() - high_res_verify = high_res_image.copy() - - # Verify images are valid - try: - low_res_verify.verify() - high_res_verify.verify() - except Exception as e: - raise ValueError(f"Image verification failed at index {idx}: {str(e)}") - finally: - low_res_verify.close() - high_res_verify.close() - - # Apply transforms if specified - if self.transform: - low_res_image = self.transform(low_res_image) - high_res_image = self.transform(high_res_image) - - return { - 'low_ress': low_res_image, # Note: Using 'low_ress' to match ModelTrainer - 'high_ress': high_res_image # Note: Using 'high_ress' to match ModelTrainer - } - - except Exception as e: - raise RuntimeError(f"Error loading images at index {idx}: {str(e)}") - - finally: - # Ensure streams are closed - if 'low_res_stream' in locals(): - low_res_stream.close() - if 'high_res_stream' in locals(): - high_res_stream.close() - -class SuperResolutionModel(AIIA): - def __init__(self, base_model: AIIA, config: AIIAConfig): - super(SuperResolutionModel, self).__init__(config=config) - # Use base model as encoder - self.encoder = base_model - for param in self.encoder.parameters(): - param.requires_grad = False # Freeze encoder layers - - # Add decoder layers to reconstruct image - self.decoder = nn.Sequential( - nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1), - nn.ReLU(), - nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1), - nn.ReLU(), - nn.ConvTranspose2d(128, 3, kernel_size=3, padding=1) - ) - - def forward(self, x): - features = self.encoder(x) - output = self.decoder(features) - return output - -class FineTuner: - def __init__(self, - model: AIIA, - dataset_paths: List[str], - batch_size: int = 32, - learning_rate: float = 0.001, - num_workers: int = 4, - train_ratio: float = 0.8, - output_dir: str = "./training_logs"): - """ - Specialized trainer for image super resolution tasks - - Args: - model (nn.Module): Model instance to finetune - dataset_paths (List[str]): Paths to datasets - batch_size (int): Batch size for training - learning_rate (float): Learning rate for optimizer - num_workers (int): Number of workers for data loading - train_ratio (float): Ratio of data to use for training (rest goes to validation) - output_dir (str): Directory to save training logs and model checkpoints - """ - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.batch_size = batch_size - self.num_workers = num_workers - self.dataset_paths = dataset_paths - self.learning_rate = learning_rate - self.train_ratio = train_ratio - self.model = model - self.output_dir = output_dir - - # Create output directory if it doesn't exist - os.makedirs(output_dir, exist_ok=True) - - # Initialize history tracking - self.train_losses = [] - self.val_losses = [] - self.best_val_loss = float('inf') - self.current_val_loss = float('inf') - - # Initialize CSV logging - self.log_file = os.path.join(output_dir, 'training_log.csv') - if not os.path.exists(self.log_file): - with open(self.log_file, 'w') as f: - f.write('epoch,train_loss,val_loss,best_val_loss\n') - - # Initialize datasets and loaders - self._initialize_datasets() - - # Initialize training parameters - self._initialize_training() - - def _freeze_layers(self): - """ - Freeze all layers except those that are part of the decoder or upsampling - We'll assume the last few layers are responsible for upsampling/reconstruction - """ - try: - # Try to identify encoder layers and freeze them - for name, param in self.model.named_parameters(): - if 'encoder' in name: - param.requires_grad = False - - # Unfreeze certain layers (example: last 3 decoder layers) - # Modify this based on your actual model architecture - for name, param in self.model.named_parameters(): - if 'decoder' in name and 'block4' in name or 'block5' in name: - param.requires_grad = True - - except Exception as e: - print(f"Warning: Couldn't freeze layers - {str(e)}") - pass - - def _initialize_datasets(self): - """ - Helper method to initialize datasets - """ - if isinstance(self.dataset_paths, list): - df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True) - else: - raise ValueError("Invalid dataset_paths format. Must be a list.") - - # Split into train and validation sets - df_train, df_val = train_test_split( - df_train, - test_size=1 - self.train_ratio, - random_state=42 - ) - - # Define preprocessing transforms with augmentation - self.transform = transforms.Compose([ - transforms.ToTensor(), - transforms.RandomResizedCrop(256), - transforms.RandomHorizontalFlip(), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + # Data augmentation pipeline + self.augmentation = Compose([ + Resize(height=512, width=512), + RandomBrightnessContrast(), + HorizontalFlip(p=0.5), + VerticalFlip(p=0.5), + Rotate(limit=45), + GaussianBlur(p=0.3), + Normalize(mean=[0.5], std=[0.5]), + ToTensorV2() ]) - - # Create datasets and dataloaders - self.train_dataset = ImageDataset(df_train, transform=self.transform) - self.val_dataset = ImageDataset(df_val, transform=self.transform) - - self.train_loader = DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=True, - num_workers=self.num_workers - ) - - self.val_loader = DataLoader( - self.val_dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers - ) if df_val is not None else None - - def _log_metrics(self, epoch: int, train_loss: float, val_loss: float): - """ - Log training metrics to CSV file - Args: - epoch (int): Current epoch number - train_loss (float): Training loss for the epoch - val_loss (float): Validation loss for the epoch - """ - with open(self.log_file, 'a') as f: - f.write(f'{epoch},{train_loss:.6f},{val_loss:.6f},{self.best_val_loss:.6f}\n') + def __len__(self): + return len(self.df) - def _initialize_training(self): - """ - Helper method to initialize training parameters - """ - # Freeze layers except those we want to finetune - self._freeze_layers() - - # Initialize optimizer and scheduler - params_to_optimize = [p for p in self.model.parameters() if p.requires_grad] + def __getitem__(self, idx): + # Get the byte strings + low_res_bytes = self.df.iloc[idx]['low_res'] + high_res_bytes = self.df.iloc[idx]['high_res'] - if not params_to_optimize: - raise ValueError("No parameters found to optimize") - - # Use Adam with weight decay for better regularization - self.optimizer = torch.optim.Adam( - params_to_optimize, - lr=self.learning_rate, - weight_decay=1e-4 # Add L2 regularization - ) - - # Reduce learning rate when validation loss plateaus - self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer, - factor=0.1, # Multiply LR by this factor on plateau - patience=3, # Number of epochs to wait before reducing LR - verbose=True - ) - - # Use a combination of L1 and L2 losses for better performance - self.criterion = nn.L1Loss() - self.mse_criterion = nn.MSELoss() - - def _train_epoch(self): - """ - Train model for one epoch - Returns: - float: Average training loss for the epoch - """ - self.model.train() - running_loss = 0.0 - - for batch in tqdm(self.train_loader, desc="Training"): - low_ress = batch['low_ress'].to(self.device) - high_ress = batch['high_ress'].to(self.device) - - # Forward pass (we'll use the model's existing architecture without adding layers) - try: - features = self.model(low_ress) - print("Features shape:", features.shape) # Check output dimensions - print("High-res shape:", high_ress.shape) # Check target dimensions - except Exception as e: - raise RuntimeError(f"Error during forward pass: {str(e)}") - - # Calculate loss with different scaling for L1 and MSE components - l1_loss = self.criterion(features, high_ress) * 0.5 - mse_loss = self.mse_criterion(features, high_ress) * 0.5 - total_loss = l1_loss + mse_loss - - # Backward pass and optimize - self.optimizer.zero_grad() - total_loss.backward() - self.optimizer.step() - - running_loss += total_loss.item() - - epoch_loss = running_loss / len(self.train_loader) - self.train_losses.append(epoch_loss) - print(f"Train Loss: {epoch_loss:.4f}") - return epoch_loss - - def _train_epoch(self): - """Train model for one epoch""" - self.model.train() - running_loss = 0.0 - - for batch in tqdm(self.train_loader, desc="Training"): - low_ress = batch['low_ress'].to(self.device) - high_ress = batch['high_ress'].to(self.device) - - # Forward pass - try: - outputs = self.model(low_ress) # Now outputs are images - print("Output shape:", outputs.shape) - print("High-res shape:", high_ress.shape) - except Exception as e: - raise RuntimeError(f"Error during forward pass: {str(e)}") - - # Calculate loss - l1_loss = self.criterion(outputs, high_ress) * 0.5 - mse_loss = self.mse_criterion(outputs, high_ress) * 0.5 - total_loss = l1_loss + mse_loss - - # Backward pass and optimize - self.optimizer.zero_grad() - total_loss.backward() - self.optimizer.step() - - running_loss += total_loss.item() - - epoch_loss = running_loss / len(self.train_loader) - self.train_lossess.append(epoch_loss) - print(f"Train Loss: {epoch_loss:.4f}") - return epoch_loss - - def train(self, num_epochs: int = 10): - """ - Train the model for specified number of epochs - Args: - num_epochs (int): Number of epochs to train for - """ - self.model.to(self.device) + # Convert bytes to numpy arrays + low_res = cv2.imdecode(np.frombuffer(low_res_bytes, np.uint8), -1) + high_res = cv2.imdecode(np.frombuffer(high_res_bytes, np.uint8), -1) - print(f"Training metrics will be logged to: {self.log_file}") - - for epoch in range(num_epochs): - print(f"\nEpoch {epoch+1}/{num_epochs}") - - # Train phase - train_loss = self._train_epoch() - - # Validation phase - if self.val_loader is not None: - val_loss = self._validate_epoch() - - # Update learning rate scheduler based on validation loss - self.scheduler.step(val_loss) - - # Log metrics - self._log_metrics(epoch + 1, train_loss, val_loss) - - # Save best model based on validation loss - if self.current_val_loss < self.best_val_loss: - print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}") - self.best_val_loss = self.current_val_loss - model_save_path = os.path.join(self.output_dir, "aiuNN-optimized") - self.model.save(model_save_path) - print(f"Model saved to: {model_save_path}") - - # After training, save the final model - final_model_path = os.path.join(self.output_dir, "aiuNN-final") - self.model.save(final_model_path) - print(f"\nFinal model saved to: {final_model_path}") + # Apply augmentation and normalization + augmented = self.augmentation(image=low_res, mask=high_res) + low_res = augmented['image'] + high_res = augmented['mask'] - -if __name__ == "__main__": - # Load your model first - config = AIIAConfig.load("/root/vision/AIIA/AIIA-base-512") - model = SuperResolutionModel(base_model=AIIABase.load("/root/vision/AIIA/AIIA-base-512"), config=config) + return { + 'low_res': low_res, + 'high_res': high_res + } - trainer = FineTuner( - model=model, - dataset_paths=[ - "/root/training_data/vision-dataset/image_upscaler.parquet", - "/root/training_data/vision-dataset/image_vec_upscaler.parquet" - ], - batch_size=2, # Increased batch size - learning_rate=1e-4 # Reduced initial LR + +def finetune_model(model: AIIA, train_parquet_path, val_parquet_path, batch_size=8, epochs = 10): + # Initialize dataset and dataloader + train_dataset = aiuNNDataset(train_parquet_path) + val_dataset = aiuNNDataset(val_parquet_path) + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=4 ) + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=4 + ) + + # Set device + device = 'cuda' if torch.cuda.is_available() else 'cpu' + model.to(device) + + # Define loss function and optimizer + criterion = nn.MSELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate) + + best_val_loss = float('inf') + + for epoch in range(epochs): + model.train() + + train_loss = 0.0 + + for batch_idx, batch in enumerate(train_loader): + low_res = batch['low_res'].to(device) + high_res = batch['high_res'].to(device) + + # Forward pass + outputs = model(low_res) + + # Calculate loss + loss = criterion(outputs, high_res.permute(0, 3, 1, 2)) # Adjust for channel dimensions + + # Backward pass and optimize + optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_loss += loss.item() + + avg_train_loss = train_loss / len(train_loader) + + print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss:.4f}") + + # Validation + model.eval() + val_loss = 0.0 + + with torch.no_grad(): + for batch in val_loader: + low_res = batch['low_res'].to(device) + high_res = batch['high_res'].to(device) + + outputs = model(low_res) + loss = criterion(outputs, high_res.permute(0, 3, 1, 2)) + + val_loss += loss.item() + + avg_val_loss = val_loss / len(val_loader) + + print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}") + + # Save best model + if avg_val_loss < best_val_loss: + best_val_loss = avg_val_loss + model.save("best_model") + + return model - trainer.train(num_epochs=10) # Extended training time \ No newline at end of file +def main(): + # Paths to your data + train_parquet_path = "/root/training_data/vision-dataset/image_upscaler.parquet" + val_parquet_path = "/root/training_data/vision-dataset/image_vec_upscaler.parquet" + + # Load pretrained model + model = AIIA.load("/root/vision/AIIA/AIIA-base-512") + + # Add final upsampling layer if needed (depending on your specific architecture) + if hasattr(model, 'chunked_'): + model.add_module('final_upsample', nn.Upsample(scale_factor=2, mode='bilinear')) + + # Fine-tune + finetune_model( + model, + train_parquet_path, + val_parquet_path + ) + +if __name__ == '__main__': + main() \ No newline at end of file