From 2121316e3b452b97a89360e2918f5c1928698bb1 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Thu, 30 Jan 2025 10:36:15 +0100
Subject: [PATCH] finetune improvement

---
 src/aiunn/finetune.py | 145 ++++++++++++++++++++++--------------------
 1 file changed, 75 insertions(+), 70 deletions(-)

diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 336f38c..5174d87 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -5,10 +5,9 @@ import io
 from torch import nn
 from torch.utils.data import Dataset, DataLoader
 import torchvision.transforms as transforms
-from aiia.model import AIIABase
+from aiia.model import AIIABase, AIIA
 from sklearn.model_selection import train_test_split
-from typing import Dict, List, Union
-
+from typing import Dict, List, Union, Optional
 
 class ImageDataset(Dataset):
     def __init__(self, dataframe, transform=None):
@@ -36,24 +35,21 @@ class ImageDataset(Dataset):
             low_res_image = self.transform(low_res_image)
             high_res_image = self.transform(high_res_image)
 
-        return {'low_res': low_res_image, 'high_res': high_res_image}
-
-
-
+        return {'low_ress': low_res_image, 'high_ress': high_res_image}
 
 class ModelTrainer:
     def __init__(self,
-                 model_name: str = "AIIA-Base-512",
-                 dataset_paths: List[str] = None,
+                 model: AIIA,
+                 dataset_paths: List[str],
                  batch_size: int = 32,
                  learning_rate: float = 0.001,
                  num_workers: int = 4,
                  train_ratio: float = 0.8):
         """
         Specialized trainer for image super resolution tasks
-
+
         Args:
-            model_name (str): Name of the model to initialize
+            model (nn.Module): Model instance to finetune
             dataset_paths (List[str]): Paths to datasets
             batch_size (int): Batch size for training
             learning_rate (float): Learning rate for optimizer
@@ -64,120 +60,126 @@ class ModelTrainer:
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.dataset_paths = dataset_paths
-        self.model_name = model_name
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
-
+        self.model = model
+
         # Initialize datasets and loaders
         self._initialize_datasets()
-
-        # Initialize model and training parameters
-        self._initialize_model()
-
+
+        # Initialize training parameters
+        self._initialize_training()
+
     def _initialize_datasets(self):
         """
         Helper method to initialize datasets
         """
-        # Read training data based on input format
         if isinstance(self.dataset_paths, list):
             df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True)
         else:
-            raise ValueError("Invalid dataset_paths format. Must be a list or dictionary.")
-
+            raise ValueError("Invalid dataset_paths format. Must be a list.")
+
         df_train, df_val = train_test_split(
             df_train,
             test_size=1 - self.train_ratio,
             random_state=42
         )
-
+
         # Define preprocessing transforms
         self.transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
         ])
-
+
         # Create datasets and dataloaders
         self.train_dataset = ImageDataset(df_train, transform=self.transform)
         self.val_dataset = ImageDataset(df_val, transform=self.transform)
-
+
         self.train_loader = DataLoader(
             self.train_dataset,
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=self.num_workers
         )
-
+
         self.val_loader = DataLoader(
             self.val_dataset,
             batch_size=self.batch_size,
             shuffle=False,
             num_workers=self.num_workers
         ) if df_val is not None else None
-
-    def _initialize_model(self):
+
+    def _initialize_training(self):
         """
-        Helper method to initialize model architecture and training parameters
+        Helper method to initialize training parameters
         """
-        # Load base model
-        self.model = AIIABase.load(self.model_name)
-
-        # Freeze CNN layers
-        for param in self.model.cnn.parameters():
-            param.requires_grad = False
-
-        # Add upscaling layer
-        hidden_size = self.model.config.hidden_size
-        kernel_size = self.model.config.kernel_size
-        self.model.upsample = nn.Sequential(
-            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
-            nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
-        )
-
+        # Freeze CNN layers (if applicable)
+        try:
+            for param in self.model.cnn.parameters():
+                param.requires_grad = False
+        except AttributeError:
+            pass  # If model doesn't have a 'cnn' attribute, just continue
+
+        # Add upscaling layer if not already present
+        if not hasattr(self.model, 'upsample'):
+            hidden_size = self.model.config.hidden_size
+            kernel_size = self.model.config.kernel_size
+            self.model.upsample = nn.Sequential(
+                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+                nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1)
+            )
+
         # Initialize optimizer and loss function
         self.criterion = nn.MSELoss()
+
+        # Get parameters of the upsample layer for training
+        params = [p for p in self.model.upsample.parameters() if p.requires_grad]
+        if not params:
+            raise ValueError("No parameters found in upsample layer to optimize")
+
         self.optimizer = torch.optim.Adam(
-            [param for param in self.model.parameters() if 'upsample' in str(param)],
+            params,
             lr=self.learning_rate
         )
-
+
         self.best_val_loss = float('inf')
-
+
     def train(self, num_epochs: int = 10):
         """
         Train the model for specified number of epochs
        """
         self.model.to(self.device)
-
+
         for epoch in range(num_epochs):
             print(f"Epoch {epoch+1}/{num_epochs}")
-
+
             # Train phase
             self._train_epoch()
-
+
             # Validation phase
             if self.val_loader is not None:
                 self._validate_epoch()
-
+
             # Save best model based on validation loss
             if self.val_loader is not None and self.current_val_loss < self.best_val_loss:
-                self.model.save("aiuNN-base")
-
+                self.model.save("aiuNN-finetuned")
+
     def _train_epoch(self):
         """
         Train model for one epoch
         """
         self.model.train()
         running_loss = 0.0
-
+
         for batch in self.train_loader:
-            low_res = batch['low_ress'].to(self.device)
-            high_res = batch['high_ress'].to(self.device)
+            low_ress = batch['low_ress'].to(self.device)
+            high_ress = batch['high_ress'].to(self.device)
 
             # Forward pass
-            features = self.model.cnn(low_res)
+            features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
             outputs = self.model.upsample(features)
 
-            loss = self.criterion(outputs, high_res)
+            loss = self.criterion(outputs, high_ress)
 
             # Backward pass and optimize
             self.optimizer.zero_grad()
@@ -185,41 +187,44 @@ class ModelTrainer:
             self.optimizer.step()
 
             running_loss += loss.item()
-
+
         epoch_loss = running_loss / len(self.train_loader)
         print(f"Train Loss: {epoch_loss:.4f}")
-
+
     def _validate_epoch(self):
         """
         Validate model performance
         """
         self.model.eval()
-        val_oss = 0.0
+        val_loss = 0.0
 
         with torch.no_grad():
             for batch in self.val_loader:
-                low_res = batch['low_ress'].to(self.device)
-                high_res = batch['high_ress'].to(self.device)
+                low_ress = batch['low_ress'].to(self.device)
+                high_ress = batch['high_ress'].to(self.device)
 
-                features = self.model.cnn(low_res)
+                features = self.model.cnn(low_ress) if hasattr(self.model, 'cnn') else self.model.extract_features(low_ress)
                 outputs = self.model.upsample(features)
 
-                loss = self.criterion(outputs, high_res)
+                loss = self.criterion(outputs, high_ress)
                 val_loss += loss.item()
 
         avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
         print(f"Validation Loss: {avg_val_loss:.4f}")
-
+
         # Update best model
         if avg_val_loss < self.best_val_loss:
             self.best_val_loss = avg_val_loss
-
+
     def __repr__(self):
-        return f"Model ({self.model_name}, batch_size={self.batch_size})"
-
+        return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"
+
 if __name__ == "__main__":
+    # Load your model first
+    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")
+
     trainer = ModelTrainer(
-        model_name="/root/vision/AIIA/AIIA-base-512/",
+        model=model,
         dataset_paths=[
             "/root/training_data/vision-dataset/image_upscaler.parquet",
             "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
@@ -227,5 +232,5 @@ if __name__ == "__main__":
         batch_size=2,
         learning_rate=0.001
     )
-
-    trainer.train(num__epochs=3)
\ No newline at end of file
+
+    trainer.train(num_epochs=3)
\ No newline at end of file
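
The core change in _initialize_training is the freeze-backbone / train-head pattern: freeze the pretrained CNN if the model has one, attach a 2x upsample head if it is missing, and optimize only that head. Below is a minimal runnable sketch of this pattern; TinyBackbone, its channel sizes, and the random tensors are illustrative assumptions standing in for the AIIA model, not the aiia package's API.

import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Illustrative stand-in for the pretrained AIIA model (not the real class)."""
    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.cnn = nn.Sequential(nn.Conv2d(3, hidden_size, 3, padding=1), nn.ReLU())

    def forward(self, x):
        return self.cnn(x)


model = TinyBackbone()

# Freeze the pretrained feature extractor, as _initialize_training does.
for param in model.cnn.parameters():
    param.requires_grad = False

# Attach a 2x upscaling head that maps feature channels back to RGB.
model.upsample = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
    nn.Conv2d(16, 3, kernel_size=3, padding=1),
)

# Optimize only the head's trainable parameters, mirroring the patched optimizer setup.
params = [p for p in model.upsample.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=0.001)

# One training step on a fake low-res / high-res pair.
criterion = nn.MSELoss()
low_res = torch.randn(2, 3, 32, 32)
high_res = torch.randn(2, 3, 64, 64)

outputs = model.upsample(model.cnn(low_res))   # shape (2, 3, 64, 64)
loss = criterion(outputs, high_res)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")

Because the backbone's parameters have requires_grad=False, the backward pass populates gradients only in the upsample head, so the optimizer step never touches the pretrained weights. This is also why the patched code can sanity-check that the head yields at least one trainable parameter before constructing the optimizer.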