import torch
import torch.nn as nn
import torch.optim as optim
import os
import csv
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
import time
import shutil


class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience    # Number of epochs with no significant improvement before stopping
        self.min_delta = min_delta  # Minimum change in loss required to count as an improvement
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
            return True  # Improved
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False  # Not improved


class aiuNNTrainer:
    def __init__(self, upscaler_model, dataset_class=None):
        """
        Initialize the upscaler trainer

        Args:
            upscaler_model: The model to fine-tune
            dataset_class: The dataset class to use for loading data (optional)
        """
        self.model = upscaler_model
        self.dataset_class = dataset_class
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device, memory_format=torch.channels_last)
        self.criterion = nn.MSELoss()
        self.optimizer = None
        self.scaler = GradScaler()
        self.best_loss = float('inf')
        self.use_checkpointing = True
        self.data_loader = None
        self.validation_loader = None
        self.log_dir = None

    def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2,
                  custom_train_dataset=None, custom_val_dataset=None):
        """
        Load data using either a custom dataset instance or the dataset class provided at initialization

        Args:
            dataset_params (dict/list): Parameters to pass to the dataset class constructor
            batch_size (int): Batch size for training
            validation_split (float): Proportion of data to use for validation
            custom_train_dataset: A pre-instantiated dataset to use for training (optional)
            custom_val_dataset: A pre-instantiated dataset to use for validation (optional)
        """
        # If custom datasets are provided directly, use them
        if custom_train_dataset is not None:
            train_dataset = custom_train_dataset
            val_dataset = custom_val_dataset if custom_val_dataset is not None else None
        else:
            # Otherwise instantiate dataset using the class and parameters
            if self.dataset_class is None:
                raise ValueError("No dataset class provided. Either provide a dataset class at initialization or custom datasets.")
            # Create dataset instance
            dataset = self.dataset_class(**(dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params}))

            # Split into train and validation sets
            dataset_size = len(dataset)
            val_size = int(validation_split * dataset_size)
            train_size = dataset_size - val_size

            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )

        # Create data loaders
        self.data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True
        )

        if val_dataset is not None:
            self.validation_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=True
            )
            print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")
        else:
            self.validation_loader = None
            print(f"Loaded {len(train_dataset)} training samples (no validation set)")

        return self.data_loader, self.validation_loader

    def _setup_logging(self, output_path):
        """Set up directory structure for logging and model checkpoints"""
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
        os.makedirs(self.log_dir, exist_ok=True)

        # Create checkpoint directory
        self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Set up CSV logging
        self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
        with open(self.csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            if self.validation_loader:
                writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
            else:
                writer.writerow(['Epoch', 'Train Loss', 'Improved'])

    def _evaluate(self):
        """Evaluate the model on validation data"""
        if self.validation_loader is None:
            return 0.0

        self.model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                with autocast(device_type=self.device.type):
                    outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                val_loss += loss.item()

                del low_res, high_res, outputs, loss

        self.model.train()
        return val_loss

    def _save_checkpoint(self, epoch, is_best=False):
        """Save model checkpoint"""
        checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
        best_model_path = os.path.join(self.log_dir, "best_model")

        # Save the model checkpoint
        self.model.save(checkpoint_path)

        # If this is the best model so far, copy it to best_model
        if is_best:
            if os.path.exists(best_model_path):
                shutil.rmtree(best_model_path)
            self.model.save(best_model_path)
            print(f"Saved new best model with loss: {self.best_loss:.6f}")

    def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
        """
        Finetune the upscaler model

        Args:
            output_path (str): Directory to save models and logs
            epochs (int): Maximum number of training epochs
            lr (float): Learning rate
            patience (int): Early stopping patience
            min_delta (float): Minimum improvement for early stopping
        """
        # Check if data is loaded
        if self.data_loader is None:
            raise ValueError("Data not loaded. Call load_data first.")
        # Setup optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Set up logging
        self._setup_logging(output_path)

        # Setup early stopping
        early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

        # Training loop
        self.model.train()

        for epoch in range(epochs):
            # Training phase
            epoch_loss = 0.0
            progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")

            for low_res, high_res in progress_bar:
                # Move data to GPU with channels_last format where possible
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                self.optimizer.zero_grad()

                with autocast(device_type=self.device.type):
                    if self.use_checkpointing:
                        # Ensure the input tensor requires gradient so that checkpointing records the computation graph
                        low_res.requires_grad_()
                        outputs = checkpoint(self.model, low_res)
                    else:
                        outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_postfix({'loss': loss.item()})

                # Optionally delete variables to free memory
                del low_res, high_res, outputs, loss

            # Calculate average epoch loss
            avg_train_loss = epoch_loss / len(self.data_loader)

            # Validation phase (if validation loader exists)
            if self.validation_loader:
                val_loss = self._evaluate() / len(self.validation_loader)
                is_improved = val_loss < self.best_loss
                if is_improved:
                    self.best_loss = val_loss

                # Log results
                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
            else:
                # If no validation, use training loss for improvement tracking
                is_improved = avg_train_loss < self.best_loss
                if is_improved:
                    self.best_loss = avg_train_loss

                # Log results
                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])

            # Save checkpoint
            self._save_checkpoint(epoch + 1, is_best=is_improved)

            # Perform garbage collection and clear GPU cache after each epoch
            gc.collect()
            torch.cuda.empty_cache()

            # Check early stopping
            early_stopping(val_loss if self.validation_loader else avg_train_loss)
            if early_stopping.early_stop:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

        return self.best_loss

    def save(self, output_path=None):
        """
        Save the best model to the specified path

        Args:
            output_path (str, optional): Path to save the model. If None, uses the best model from training.
        """
        if output_path is None and self.log_dir is not None:
            best_model_path = os.path.join(self.log_dir, "best_model")
            if os.path.exists(best_model_path):
                print(f"Best model already saved at {best_model_path}")
                return best_model_path
            else:
                output_path = os.path.join(self.log_dir, "final_model")

        if output_path is None:
            raise ValueError("No output path specified and no training has been done yet.")

        self.model.save(output_path)
        print(f"Model saved to {output_path}")
        return output_path
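

# --- Usage sketch (illustrative only, kept commented out) -------------------
# A minimal sketch of how this trainer is intended to be driven. The names
# `MyUpscaler`, `MyParquetDataset`, and the file path below are hypothetical
# placeholders, not part of this module: any nn.Module exposing a `.save(path)`
# method and any Dataset yielding (low_res, high_res) tensor pairs should fit.
#
# if __name__ == "__main__":
#     model = MyUpscaler()                                      # placeholder model with a .save() method
#     trainer = aiuNNTrainer(model, dataset_class=MyParquetDataset)
#
#     # A list passed as dataset_params is forwarded as the 'parquet_files' kwarg
#     trainer.load_data(dataset_params=["data/train.parquet"],  # placeholder path
#                       batch_size=4, validation_split=0.2)
#
#     trainer.finetune(output_path="runs", epochs=10, lr=1e-4,
#                      patience=3, min_delta=0.001)
#     trainer.save()                                            # returns the path of the saved model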