diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 545de8c..4cf3d2e 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -10,7 +10,7 @@
 from sklearn.model_selection import train_test_split
 from typing import Dict, List, Union, Optional
 import base64
 from tqdm import tqdm
-
+import os

 class ImageDataset(Dataset):
@@ -99,14 +99,15 @@ class ImageDataset(Dataset):
         if 'high_res_stream' in locals():
             high_res_stream.close()


-class ModelTrainer:
+class FineTuner:
     def __init__(self, model: AIIA,
                  dataset_paths: List[str],
                  batch_size: int = 32,
                  learning_rate: float = 0.001,
                  num_workers: int = 4,
-                 train_ratio: float = 0.8):
+                 train_ratio: float = 0.8,
+                 output_dir: str = "./training_logs"):
         """
         Specialized trainer for image super resolution tasks

@@ -117,6 +118,7 @@ class ModelTrainer:
             learning_rate (float): Learning rate for optimizer
             num_workers (int): Number of workers for data loading
             train_ratio (float): Ratio of data to use for training (rest goes to validation)
+            output_dir (str): Directory to save training logs and model checkpoints
         """
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.batch_size = batch_size
@@ -125,6 +127,22 @@ class ModelTrainer:
         self.learning_rate = learning_rate
         self.train_ratio = train_ratio
         self.model = model
+        self.output_dir = output_dir
+
+        # Create output directory if it doesn't exist
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Initialize history tracking
+        self.train_losses = []
+        self.val_losses = []
+        self.best_val_loss = float('inf')
+        self.current_val_loss = float('inf')
+
+        # Initialize CSV logging
+        self.log_file = os.path.join(output_dir, 'training_log.csv')
+        if not os.path.exists(self.log_file):
+            with open(self.log_file, 'w') as f:
+                f.write('epoch,train_loss,val_loss,best_val_loss\n')

         # Initialize datasets and loaders
         self._initialize_datasets()
@@ -171,6 +189,19 @@ class ModelTrainer:
             num_workers=self.num_workers
         ) if df_val is not None else None

+    def _log_metrics(self, epoch: int, train_loss: float, val_loss: float):
+        """
+        Log training metrics to CSV file
+
+        Args:
+            epoch (int): Current epoch number
+            train_loss (float): Training loss for the epoch
+            val_loss (float): Validation loss for the epoch
+        """
+        with open(self.log_file, 'a') as f:
+            f.write(f'{epoch},{train_loss:.6f},{val_loss:.6f},{self.best_val_loss:.6f}\n')
+
+
     def _initialize_training(self):
         """
         Helper method to initialize training parameters
@@ -211,29 +242,11 @@ class ModelTrainer:

         self.best_val_loss = float('inf')

-    def train(self, num_epochs: int = 10):
-        """
-        Train the model for specified number of epochs
-        """
-        self.model.to(self.device)
-
-        for epoch in tqdm(range(num_epochs), desc="Training"):
-            print(f"Epoch {epoch+1}/{num_epochs}")
-
-            # Train phase
-            self._train_epoch()
-
-            # Validation phase
-            if self.val_loader is not None:
-                self._validate_epoch()
-
-            # Save best model based on validation loss
-            if self.val_loader is not None and self.current_val_loss < self.best_val_loss:
-                self.model.save("aiuNN-finetuned")
-
     def _train_epoch(self):
         """
         Train model for one epoch
+        Returns:
+            float: Average training loss for the epoch
         """
         self.model.train()
         running_loss = 0.0
@@ -256,11 +269,15 @@ class ModelTrainer:
            running_loss += loss.item()

         epoch_loss = running_loss / len(self.train_loader)
+        self.train_losses.append(epoch_loss)
         print(f"Train Loss: {epoch_loss:.4f}")
+        return epoch_loss

     def _validate_epoch(self):
         """
         Validate model performance
+        Returns:
+            float: Average validation loss for the epoch
         """
         self.model.eval()
         val_loss = 0.0
@@ -276,12 +293,41 @@ class ModelTrainer:
                loss = self.criterion(outputs, high_ress)
                val_loss += loss.item()

-        avg_val_loss = val_loss / len(self.val_loader) if self.val_loader else 0
-        print(f"Validation Loss: {avg_val_loss:.4f}")
+        self.current_val_loss = val_loss / len(self.val_loader)
+        self.val_losses.append(self.current_val_loss)
+        print(f"Validation Loss: {self.current_val_loss:.4f}")
+        return self.current_val_loss

-        # Update best model
-        if avg_val_loss < self.best_val_loss:
-            self.best_val_loss = avg_val_loss
+    def train(self, num_epochs: int = 10):
+        """
+        Train the model for specified number of epochs
+        Args:
+            num_epochs (int): Number of epochs to train for
+        """
+        self.model.to(self.device)
+
+        print(f"Training metrics will be logged to: {self.log_file}")
+
+        for epoch in range(num_epochs):
+            print(f"\nEpoch {epoch+1}/{num_epochs}")
+
+            # Train phase
+            train_loss = self._train_epoch()
+
+            # Validation phase
+            if self.val_loader is not None:
+                val_loss = self._validate_epoch()
+
+                # Log metrics
+                self._log_metrics(epoch + 1, train_loss, val_loss)
+
+                # Save best model based on validation loss
+                if self.current_val_loss < self.best_val_loss:
+                    print(f"Validation loss improved from {self.best_val_loss:.4f} to {self.current_val_loss:.4f}")
+                    self.best_val_loss = self.current_val_loss
+                    model_save_path = os.path.join(self.output_dir, "aiuNN-finetuned")
+                    self.model.save(model_save_path)
+                    print(f"Model saved to: {model_save_path}")

     def __repr__(self):
         return f"ModelTrainer (model={type(self.model).__name__}, batch_size={self.batch_size})"
@@ -290,7 +336,7 @@ if __name__ == "__main__":
     # Load your model first
     model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")

-    trainer = ModelTrainer(
+    trainer = FineTuner(
         model=model,
         dataset_paths=[
             "/root/training_data/vision-dataset/image_upscaler.parquet",
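
For reviewers: a minimal usage sketch of the renamed class, mirroring the
__main__ block touched by the last hunk. The import locations are assumptions
(the module exporting AIIABase is not visible in these hunks), and the
dataset_paths list is truncated above, so only the first entry is shown.

    # Hypothetical driver script; import paths are assumed, not confirmed.
    from AIIA import AIIABase          # assumption: actual source of AIIABase not shown
    from aiunn.finetune import FineTuner

    model = AIIABase.load("/root/vision/AIIA/AIIA-base-512/")

    trainer = FineTuner(
        model=model,
        dataset_paths=[
            "/root/training_data/vision-dataset/image_upscaler.parquet",
        ],
        batch_size=32,
        learning_rate=0.001,
        output_dir="./training_logs",  # new parameter; created via os.makedirs
    )
    trainer.train(num_epochs=10)       # logs per-epoch metrics to training_log.csv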
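
The CSV written by _log_metrics can then be inspected offline. A sketch,
assuming pandas and matplotlib are available in the environment (neither is
imported by finetune.py itself); the column names match the header written
in __init__.

    import pandas as pd
    import matplotlib.pyplot as plt

    # Columns: epoch, train_loss, val_loss, best_val_loss (per the CSV header)
    log = pd.read_csv("./training_logs/training_log.csv")

    plt.plot(log["epoch"], log["train_loss"], label="train")
    plt.plot(log["epoch"], log["val_loss"], label="val")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.legend()
    plt.savefig("./training_logs/loss_curves.png")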