import torch
import torch.nn as nn
import torch.optim as optim
import os
import csv
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
import time
import shutil
import datetime


class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        # Number of epochs with no significant improvement before stopping
        self.patience = patience
        # Minimum change in loss required to count as an improvement
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
            return True  # Improved
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False  # Not improved


class aiuNNTrainer:
    def __init__(self, upscaler_model, dataset_class=None):
        """
        Initialize the upscaler trainer

        Args:
            upscaler_model: The model to fine-tune
            dataset_class: The dataset class to use for loading data (optional)
        """
        self.model = upscaler_model
        self.dataset_class = dataset_class
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device, memory_format=torch.channels_last)
        self.criterion = nn.MSELoss()
        self.optimizer = None
        self.scaler = GradScaler()
        self.best_loss = float('inf')
        self.csv_path = None
        self.checkpoint_dir = None
        self.data_loader = None
        self.validation_loader = None
        self.last_checkpoint_time = time.time()
        self.checkpoint_interval = 2 * 60 * 60  # 2 hours
        self.last_22_date = None
        self.recent_checkpoints = []
        self.current_epoch = 0

    def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2,
                  custom_train_dataset=None, custom_val_dataset=None):
        """
        Load data using either a custom dataset instance or the dataset class provided at initialization

        Args:
            dataset_params (dict/list): Parameters to pass to the dataset class constructor
            batch_size (int): Batch size for training
            validation_split (float): Proportion of data to use for validation
            custom_train_dataset: A pre-instantiated dataset to use for training (optional)
            custom_val_dataset: A pre-instantiated dataset to use for validation (optional)
        """
        # If custom datasets are provided directly, use them
        if custom_train_dataset is not None:
            train_dataset = custom_train_dataset
            val_dataset = custom_val_dataset if custom_val_dataset is not None else None
        else:
            # Otherwise instantiate dataset using the class and parameters
            if self.dataset_class is None:
                raise ValueError("No dataset class provided. Either provide a dataset class at initialization or custom datasets.")
            # Create dataset instance; a non-dict value is assumed to be a list of parquet files
            dataset = self.dataset_class(**(dataset_params if isinstance(dataset_params, dict)
                                            else {'parquet_files': dataset_params}))

            # Split into train and validation sets
            dataset_size = len(dataset)
            val_size = int(validation_split * dataset_size)
            train_size = dataset_size - val_size
            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )

        # Create data loaders
        self.data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True
        )

        if val_dataset is not None:
            self.validation_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=True
            )
            print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")
        else:
            self.validation_loader = None
            print(f"Loaded {len(train_dataset)} training samples (no validation set)")

        return self.data_loader, self.validation_loader

    def _setup_logging(self, output_path):
        """Set up basic logging and checkpoint directory"""
        # Create checkpoint directory
        self.checkpoint_dir = os.path.join(output_path, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Set up CSV logging
        self.csv_path = os.path.join(output_path, 'training_log.csv')
        with open(self.csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            if self.validation_loader:
                writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            else:
                writer.writerow(['Epoch', 'Train Loss'])

    def _evaluate(self):
        """Evaluate the model on validation data and return the summed batch loss"""
        if self.validation_loader is None:
            return 0.0

        self.model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                with autocast(device_type=self.device.type):
                    outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                val_loss += loss.item()
                del low_res, high_res, outputs, loss

        self.model.train()
        return val_loss

    def _save_checkpoint(self, epoch, batch_count, is_best=False, is_22=False):
        """Save checkpoint with support for regular, best, and 22:00 saves"""
        if is_22:
            # The 22:00 checkpoint is named after the current date in a fixed UTC+2 timezone
            today = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))).date()
            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
        else:
            checkpoint_name = f"checkpoint_epoch{epoch}_batch{batch_count}.pt"

        checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)

        checkpoint_data = {
            'epoch': epoch,
            'batch': batch_count,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_loss': self.best_loss,
            'scaler_state_dict': self.scaler.state_dict()
        }

        torch.save(checkpoint_data, checkpoint_path)

        # Save best model separately
        if is_best:
            best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
            self.model.save_pretrained(best_model_path)

        return checkpoint_path

    def _handle_checkpoints(self, epoch, batch_count, is_improved):
        """Handle periodic and 22:00 checkpoint saving"""
        current_time = time.time()
        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))

        # Regular interval checkpoint
        if (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
            self._save_checkpoint(epoch, batch_count, is_improved)
            self.last_checkpoint_time = current_time

        # Special 22:00 checkpoint (at most once per day, within the first 15 minutes of the hour)
        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
        if is_22_oclock and self.last_22_date != current_dt.date():
            self._save_checkpoint(epoch, batch_count, is_improved, is_22=True)
            self.last_22_date = current_dt.date()
    def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
        """Finetune the upscaler model"""
        if self.data_loader is None:
            raise ValueError("Data not loaded. Call load_data first.")

        # Set up the checkpoint directory and CSV log header
        self._setup_logging(output_path=output_path)

        # Setup optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Load existing checkpoint if available; the stored epoch is 1-based,
        # so resume the in-progress epoch at its 0-based index
        checkpoint_info = self.load_checkpoint()
        start_epoch = max(checkpoint_info[0] - 1, 0) if checkpoint_info else 0
        start_batch = checkpoint_info[1] if checkpoint_info else 0

        # Setup early stopping; keep the best loss restored from a checkpoint, if any
        early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
        if checkpoint_info is None:
            self.best_loss = float('inf')

        # Training loop
        self.model.train()
        for epoch in range(start_epoch, epochs):
            self.current_epoch = epoch
            epoch_loss = 0.0

            # Materialize the loader so training can resume mid-epoch;
            # note that this keeps every batch of the epoch in memory
            train_batches = list(enumerate(self.data_loader))
            start_idx = start_batch if epoch == start_epoch else 0

            progress_bar = tqdm(train_batches[start_idx:],
                                initial=start_idx,
                                total=len(train_batches),
                                desc=f"Epoch {epoch + 1}/{epochs}")

            for batch_idx, (low_res, high_res) in progress_bar:
                # Training step
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                self.optimizer.zero_grad()

                with autocast(device_type=self.device.type):
                    if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                        low_res.requires_grad_()
                        outputs = checkpoint(self.model, low_res, use_reentrant=False)
                    else:
                        outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_postfix({'loss': loss.item()})

                # Handle checkpoints
                self._handle_checkpoints(epoch + 1, batch_idx + 1, loss.item() < self.best_loss)

                del low_res, high_res, outputs, loss

            # End of epoch processing
            avg_train_loss = epoch_loss / len(self.data_loader)

            # Validation phase
            if self.validation_loader:
                val_loss = self._evaluate() / len(self.validation_loader)
                is_improved = val_loss < self.best_loss
                if is_improved:
                    self.best_loss = val_loss

                # Log to CSV
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, val_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
            else:
                is_improved = avg_train_loss < self.best_loss
                if is_improved:
                    self.best_loss = avg_train_loss

                # Log to CSV
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")

            # Save best model if improved
            if is_improved:
                best_model_path = os.path.join(output_path, "best_model")
                self.model.save_pretrained(best_model_path)

            # Check early stopping: __call__ reports whether the loss improved,
            # the stop decision itself lives in the early_stop flag
            early_stopping(val_loss if self.validation_loader else avg_train_loss)
            if early_stopping.early_stop:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

            # Cleanup
            gc.collect()
            torch.cuda.empty_cache()
        return self.best_loss

    def load_checkpoint(self, specific_checkpoint=None):
        """Enhanced checkpoint loading with specific checkpoint support"""
        if specific_checkpoint:
            checkpoint_path = os.path.join(self.checkpoint_dir, specific_checkpoint)
        else:
            checkpoint_files = [f for f in os.listdir(self.checkpoint_dir)
                                if f.startswith("checkpoint_") and f.endswith(".pt")]
            if not checkpoint_files:
                return None

            # Most recently modified checkpoint first
            checkpoint_files.sort(key=lambda x: os.path.getmtime(
                os.path.join(self.checkpoint_dir, x)), reverse=True)
            checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_files[0])

        if not os.path.exists(checkpoint_path):
            return None

        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
        self.best_loss = checkpoint['best_loss']

        print(f"Loaded checkpoint from {checkpoint_path}")
        return checkpoint['epoch'], checkpoint['batch']

    def save(self, output_path=None):
        """
        Save the best model to the specified path

        Args:
            output_path (str, optional): Path to save the model. If None, tries to use
                the checkpoint directory from training.

        Returns:
            str: Path where the model was saved

        Raises:
            ValueError: If no output path is specified and no checkpoint directory exists
        """
        if output_path is None and self.checkpoint_dir is not None:
            # First try to copy the best model if it exists
            best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
            if os.path.exists(best_model_path):
                output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
                shutil.copytree(best_model_path, output_path, dirs_exist_ok=True)
                print(f"Copied best model to {output_path}")
                return output_path
            else:
                # If no best model exists, save current model state
                output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")

        if output_path is None:
            raise ValueError("No output path specified and no checkpoint directory exists from training.")

        self.model.save_pretrained(output_path)
        print(f"Model saved to {output_path}")

        return output_path
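

# ---------------------------------------------------------------------------
# Usage sketch (not part of the trainer): DummyUpscaler and the random tensor
# dataset below are hypothetical stand-ins, not the real aiuNN model or data.
# The trainer only assumes the model provides forward() and save_pretrained(),
# and that the dataset yields (low_res, high_res) tensor pairs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    class DummyUpscaler(nn.Module):
        """Minimal stand-in model: bilinear 2x upscale followed by a conv layer."""

        def __init__(self):
            super().__init__()
            self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(self.upsample(x))

        def save_pretrained(self, path):
            # Stand-in for the real model's save_pretrained(): dump the state dict
            os.makedirs(path, exist_ok=True)
            torch.save(self.state_dict(), os.path.join(path, "model.pt"))

    # Random (low_res, high_res) pairs, only to exercise the training loop
    low = torch.rand(8, 3, 32, 32)
    high = torch.rand(8, 3, 64, 64)
    dummy_data = TensorDataset(low, high)

    trainer = aiuNNTrainer(DummyUpscaler())
    trainer.load_data(custom_train_dataset=dummy_data, batch_size=2)
    # trainer.use_checkpointing = True  # optional: enable gradient checkpointing
    trainer.finetune(output_path="./runs/smoke_test", epochs=1)
    trainer.save()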