diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index 1644662..336f38c 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -41,83 +41,10 @@ class ImageDataset(Dataset): -class TrainingBase: - def __init__(self, - model_name: str, - dataset_paths: Union[List[str], Dict[str, str]], - batch_size: int = 32, - learning_rate: float = 0.001, - num_workers: int = 4, - train_ratio: float = 0.8): - """ - Base class for training models with multiple dataset support - - Args: - model_name (str): Name of the model to initialize - dataset_paths (Union[List[str], Dict[str, str]]): Paths to datasets (train and optional validation) - batch_size (int): Batch size for training - learning_rate (float): Learning rate for optimizer - num_workers (int): Number of workers for data loading - train_ratio (float): Ratio of data to use for training (rest goes to validation) - """ - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.batch_size = batch_size - self.num_workers = num_workers - - # Initialize datasets and loaders - self.dataset_paths = dataset_paths - self._initialize_datasets() - - # Initialize model and training parameters - self.model_name = model_name - self.learning_rate = learning_rate - self._initialize_model() - - def _initialize_datasets(self): - """Helper method to initialize datasets""" - raise NotImplementedError("This method should be implemented in child classes") - - def _initialize_model(self): - """Helper method to initialize model architecture""" - raise NotImplementedError("This method should be implemented in child classes") - - def train(self, num_epochs: int = 10): - """Train the model for specified number of epochs""" - self.model.to(self.device) - - for epoch in range(num_epochs): - print(f"Epoch {epoch+1}/{num_epochs}") - - # Train phase - self._train_epoch() - - # Validation phase - self._validate_epoch() - - # Save best model based on validation loss - if self.current_val_loss < self.best_val_loss: - self.save_model() - - def _train_epoch(self): - """Train model for one epoch""" - raise NotImplementedError("This method should be implemented in child classes") - - def _validate_epoch(self): - """Validate model performance""" - raise NotImplementedError("This method should be implemented in child classes") - - def save_model(self): - """Save current best model""" - torch.save({ - 'model_state_dict': self.model.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), - 'best_val_loss': self.best_val_loss - }, f"{self.model_name}_best.pth") - -class Finetuner(TrainingBase): +class ModelTrainer: def __init__(self, model_name: str = "AIIA-Base-512", - dataset_paths: Union[List[str], Dict[str, str]] = None, + dataset_paths: List[str] = None, batch_size: int = 32, learning_rate: float = 0.001, num_workers: int = 4, @@ -126,25 +53,42 @@ class Finetuner(TrainingBase): Specialized trainer for image super resolution tasks Args: - Same as TrainingBase + model_name (str): Name of the model to initialize + dataset_paths (List[str]): Paths to datasets + batch_size (int): Batch size for training + learning_rate (float): Learning rate for optimizer + num_workers (int): Number of workers for data loading + train_ratio (float): Ratio of data to use for training (rest goes to validation) """ - super().__init__(model_name, dataset_paths, batch_size, learning_rate, num_workers, train_ratio) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.batch_size = batch_size + self.num_workers = num_workers + self.dataset_paths = dataset_paths + self.model_name = model_name + self.learning_rate = learning_rate + self.train_ratio = train_ratio + # Initialize datasets and loaders + self._initialize_datasets() + + # Initialize model and training parameters + self._initialize_model() + def _initialize_datasets(self): - """Initialize image datasets""" - # Load dataframes from parquet files - if isinstance(self.dataset_paths, dict): - df_train = pd.read_parquet(self.dataset_paths['train']) - df_val = pd.read_parquet(self.dataset_paths['val']) if 'val' in self.dataset_paths else None - elif isinstance(self.dataset_paths, list): + """ + Helper method to initialize datasets + """ + # Read training data based on input format + if isinstance(self.dataset_paths, list): df_train = pd.concat([pd.read_parquet(path) for path in self.dataset_paths], ignore_index=True) - df_val = None else: - raise ValueError("Invalid dataset_paths format") + raise ValueError("Invalid dataset_paths format. Must be a list or dictionary.") - # Split into train and validation sets if needed - if df_val is None: - df_train, df_val = train_test_split(df_train, test_size=1 - self.train_ratio, random_state=42) + df_train, df_val = train_test_split( + df_train, + test_size=1 - self.train_ratio, + random_state=42 + ) # Define preprocessing transforms self.transform = transforms.Compose([ @@ -168,10 +112,12 @@ class Finetuner(TrainingBase): batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers - ) + ) if df_val is not None else None def _initialize_model(self): - """Initialize and modify the super resolution model""" + """ + Helper method to initialize model architecture and training parameters + """ # Load base model self.model = AIIABase.load(self.model_name) @@ -181,9 +127,10 @@ class Finetuner(TrainingBase): # Add upscaling layer hidden_size = self.model.config.hidden_size + kernel_size = self.model.config.kernel_size self.model.upsample = nn.Sequential( nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), - nn.Conv2d(hidden_size, 3, kernel_size=3, padding=1) + nn.Conv2d(hidden_size, 3, kernel_size=kernel_size, padding=1) ) # Initialize optimizer and loss function @@ -195,14 +142,36 @@ class Finetuner(TrainingBase): self.best_val_loss = float('inf') + def train(self, num_epochs: int = 10): + """ + Train the model for specified number of epochs + """ + self.model.to(self.device) + + for epoch in range(num_epochs): + print(f"Epoch {epoch+1}/{num_epochs}") + + # Train phase + self._train_epoch() + + # Validation phase + if self.val_loader is not None: + self._validate_epoch() + + # Save best model based on validation loss + if self.val_loader is not None and self.current_val_loss < self.best_val_loss: + self.model.save("aiuNN-base") + def _train_epoch(self): - """Train model for one epoch""" + """ + Train model for one epoch + """ self.model.train() running_loss = 0.0 for batch in self.train_loader: - low_res = batch['low_res'].to(self.device) - high_res = batch['high_res'].to(self.device) + low_res = batch['low_ress'].to(self.device) + high_res = batch['high_ress'].to(self.device) # Forward pass features = self.model.cnn(low_res) @@ -221,14 +190,16 @@ class Finetuner(TrainingBase): print(f"Train Loss: {epoch_loss:.4f}") def _validate_epoch(self): - """Validate model performance""" + """ + Validate model performance + """ self.model.eval() - val_loss = 0.0 + val_oss = 0.0 with torch.no_grad(): for batch in self.val_loader: - low_res = batch['low_res'].to(self.device) - high_res = batch['high_res'].to(self.device) + low_res = batch['low_ress'].to(self.device) + high_res = batch['high_ress'].to(self.device) features = self.model.cnn(low_res) outputs = self.model.upsample(features) @@ -236,24 +207,25 @@ class Finetuner(TrainingBase): loss = self.criterion(outputs, high_res) val_loss += loss.item() - avg_val_loss = val_loss / len(self.val_loader) + avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0 print(f"Validation Loss: {avg_val_loss:.4f}") # Update best model if avg_val_loss < self.best_val_loss: self.best_val_loss = avg_val_loss - + def __repr__(self): return f"Model ({self.model_name}, batch_size={self.batch_size})" - - -# Example usage: + if __name__ == "__main__": - finetuner = Finetuner( - train_parquet_path="/root/training_data/vision-dataset/image_upscaler.parquet", - val_parquet_path="/root/training_data/vision-dataset/image_vec_upscaler.parquet", + trainer = ModelTrainer( + model_name="/root/vision/AIIA/AIIA-base-512/", + dataset_paths=[ + "/root/training_data/vision-dataset/image_upscaler.parquet", + "/root/training_data/vision-dataset/image_vec_upscaler.parquet" + ], batch_size=2, learning_rate=0.001 ) - finetuner.train_model(num_epochs=10) \ No newline at end of file + trainer.train(num__epochs=3) \ No newline at end of file