# aiuNN/src/aiunn/finetune/trainer.py


import csv
import gc
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import GradScaler, autocast
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader
from tqdm import tqdm


class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience      # epochs with no significant improvement before stopping
        self.min_delta = min_delta    # minimum loss decrease that counts as an improvement
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
            return True  # Improved
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False  # Not improved
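
# Illustration of the EarlyStopping contract (sketch only, the values are made up):
# __call__ returns True when the loss improved by more than min_delta and resets
# the counter; otherwise it returns False and, after `patience` consecutive stale
# epochs, sets .early_stop so the caller can break out of its training loop.
#
#     stopper = EarlyStopping(patience=2, min_delta=0.01)
#     for loss in (1.00, 0.50, 0.499, 0.498):   # last two are not real improvements
#         stopper(loss)
#         if stopper.early_stop:
#             break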


class aiuNNTrainer:
    def __init__(self, upscaler_model, dataset_class=None):
        """
        Initialize the upscaler trainer.

        Args:
            upscaler_model: The model to fine-tune
            dataset_class: The dataset class to use for loading data (optional)
        """
        self.model = upscaler_model
        self.dataset_class = dataset_class
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device, memory_format=torch.channels_last)
        self.criterion = nn.MSELoss()
        self.optimizer = None
        self.scaler = GradScaler()
        self.best_loss = float('inf')
        self.use_checkpointing = True
        self.data_loader = None
        self.validation_loader = None
        self.log_dir = None

    def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2,
                  custom_train_dataset=None, custom_val_dataset=None):
        """
        Load data using either a custom dataset instance or the dataset class provided at initialization.

        Args:
            dataset_params (dict/list): Parameters to pass to the dataset class constructor
            batch_size (int): Batch size for training
            validation_split (float): Proportion of data to use for validation
            custom_train_dataset: A pre-instantiated dataset to use for training (optional)
            custom_val_dataset: A pre-instantiated dataset to use for validation (optional)
        """
        # If custom datasets are provided directly, use them
        if custom_train_dataset is not None:
            train_dataset = custom_train_dataset
            val_dataset = custom_val_dataset if custom_val_dataset is not None else None
        else:
            # Otherwise instantiate the dataset using the class and parameters
            if self.dataset_class is None:
                raise ValueError("No dataset class provided. Either provide a dataset class at initialization or custom datasets.")

            # Create the dataset instance: a dict is unpacked as keyword arguments,
            # anything else is passed through as the 'parquet_files' argument
            if isinstance(dataset_params, dict):
                dataset = self.dataset_class(**dataset_params)
            else:
                dataset = self.dataset_class(parquet_files=dataset_params)

            # Split into train and validation sets
            dataset_size = len(dataset)
            val_size = int(validation_split * dataset_size)
            train_size = dataset_size - val_size
            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )

        # Create data loaders
        self.data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True
        )

        if val_dataset is not None:
            self.validation_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=True
            )
            print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")
        else:
            self.validation_loader = None
            print(f"Loaded {len(train_dataset)} training samples (no validation set)")

        return self.data_loader, self.validation_loader
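
    # Usage sketch for load_data (the dataset class, variables, and file names
    # below are hypothetical placeholders, not taken from the aiuNN codebase):
    #
    #     # Option A: let the trainer build and split the dataset itself
    #     trainer = aiuNNTrainer(model, dataset_class=UpscaleParquetDataset)
    #     trainer.load_data(dataset_params=['data/shard0.parquet'], batch_size=4)
    #
    #     # Option B: pass pre-built datasets and skip the random split
    #     trainer = aiuNNTrainer(model)
    #     trainer.load_data(custom_train_dataset=train_ds, custom_val_dataset=val_ds, batch_size=4)
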
    def _setup_logging(self, output_path):
        """Set up the directory structure for logging and model checkpoints."""
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
        os.makedirs(self.log_dir, exist_ok=True)

        # Create checkpoint directory
        self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Set up CSV logging
        self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
        with open(self.csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            if self.validation_loader:
                writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
            else:
                writer.writerow(['Epoch', 'Train Loss', 'Improved'])
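
    # Resulting layout under `output_path` (timestamp varies per run):
    #
    #     training_run_YYYYMMDD-HHMMSS/
    #         training_log.csv      one row per epoch, written by finetune
    #         checkpoints/          epoch_<n>.pt saved after every epoch
    #         best_model            refreshed whenever the tracked loss improves
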
    def _evaluate(self):
        """Evaluate the model on the validation data and return the loss summed over all batches."""
        if self.validation_loader is None:
            return 0.0

        self.model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                with autocast(device_type=self.device.type):
                    outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                val_loss += loss.item()
                del low_res, high_res, outputs, loss

        self.model.train()
        return val_loss

    def _save_checkpoint(self, epoch, is_best=False):
        """Save a model checkpoint for the given epoch and update the best model if improved."""
        checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
        best_model_path = os.path.join(self.log_dir, "best_model")

        # Save the model checkpoint
        self.model.save(checkpoint_path)

        # If this is the best model so far, copy it to best_model
        if is_best:
            if os.path.exists(best_model_path):
                shutil.rmtree(best_model_path)
            self.model.save(best_model_path)
            print(f"Saved new best model with loss: {self.best_loss:.6f}")
    def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
        """
        Finetune the upscaler model.

        Args:
            output_path (str): Directory to save models and logs
            epochs (int): Maximum number of training epochs
            lr (float): Learning rate
            patience (int): Early stopping patience
            min_delta (float): Minimum improvement for early stopping
        """
        # Check that data is loaded
        if self.data_loader is None:
            raise ValueError("Data not loaded. Call load_data first.")

        # Set up optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Set up logging
        self._setup_logging(output_path)

        # Set up early stopping
        early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

        # Training loop
        self.model.train()
        for epoch in range(epochs):
            # Training phase
            epoch_loss = 0.0
            progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")

            for low_res, high_res in progress_bar:
                # Move data to the GPU with channels_last format where possible
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                self.optimizer.zero_grad()

                with autocast(device_type=self.device.type):
                    if self.use_checkpointing:
                        # Ensure the input tensor requires gradients so that
                        # checkpointing records the computation graph
                        low_res.requires_grad_()
                        outputs = checkpoint(self.model, low_res)
                    else:
                        outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_postfix({'loss': loss.item()})

                # Delete per-batch tensors to free memory
                del low_res, high_res, outputs, loss

            # Calculate average epoch loss
            avg_train_loss = epoch_loss / len(self.data_loader)

            # Validation phase (if a validation loader exists)
            if self.validation_loader:
                val_loss = self._evaluate() / len(self.validation_loader)
                is_improved = val_loss < self.best_loss
                if is_improved:
                    self.best_loss = val_loss

                # Log results
                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
            else:
                # Without a validation set, track improvement on the training loss
                is_improved = avg_train_loss < self.best_loss
                if is_improved:
                    self.best_loss = avg_train_loss

                # Log results
                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])

            # Save checkpoint
            self._save_checkpoint(epoch + 1, is_best=is_improved)

            # Perform garbage collection and clear the GPU cache after each epoch
            gc.collect()
            torch.cuda.empty_cache()

            # Check early stopping
            early_stopping(val_loss if self.validation_loader else avg_train_loss)
            if early_stopping.early_stop:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

        return self.best_loss
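
    # The use_checkpointing flag set in __init__ routes the forward pass
    # through torch.utils.checkpoint, trading extra compute during backward
    # for lower activation memory. Setting `trainer.use_checkpointing = False`
    # before calling finetune disables this and runs a plain forward pass.
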
    def save(self, output_path=None):
        """
        Save the best model to the specified path.

        Args:
            output_path (str, optional): Path to save the model. If None, uses the best model from training.
        """
        if output_path is None and self.log_dir is not None:
            best_model_path = os.path.join(self.log_dir, "best_model")
            if os.path.exists(best_model_path):
                print(f"Best model already saved at {best_model_path}")
                return best_model_path
            else:
                output_path = os.path.join(self.log_dir, "final_model")

        if output_path is None:
            raise ValueError("No output path specified and no training has been done yet.")

        self.model.save(output_path)
        print(f"Model saved to {output_path}")
        return output_path
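

# End-to-end usage sketch. The import path and class names below are
# placeholders for whatever upscaler model and dataset classes the surrounding
# aiuNN package provides; they are not taken from this file.
#
#     from aiunn import aiuNN                      # hypothetical model class
#     from aiunn.finetune import UpscaleDataset    # hypothetical dataset class
#
#     model = aiuNN.load("path/to/base_model")     # hypothetical loading API
#     trainer = aiuNNTrainer(model, dataset_class=UpscaleDataset)
#     trainer.load_data(dataset_params={'parquet_files': ['data/train.parquet']}, batch_size=4)
#     best_loss = trainer.finetune(output_path='runs', epochs=10, lr=1e-4, patience=3)
#     trainer.save()    # returns the path of the saved best/final model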