# aiuNN/src/aiunn/finetune/trainer.py

import torch
import torch.nn as nn
import torch.optim as optim
import os
import csv
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
import time
import shutil
import datetime


class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        # Number of epochs with no significant improvement before stopping
        self.patience = patience
        # Minimum change in loss required to count as an improvement
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
            return True  # Improved
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False  # Not improved
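
# Illustrative use of EarlyStopping (not part of the trainer itself): __call__
# returns True while the loss is still improving; the stop signal is the
# early_stop flag, which flips after `patience` non-improving epochs.
#
#   stopper = EarlyStopping(patience=2, min_delta=0.01)
#   for loss in [0.50, 0.40, 0.399, 0.398]:   # last two changes are < min_delta
#       stopper(loss)
#       if stopper.early_stop:
#           break                             # triggers on the fourth value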


class aiuNNTrainer:
    def __init__(self, upscaler_model, dataset_class=None):
        """
        Initialize the upscaler trainer

        Args:
            upscaler_model: The model to fine-tune
            dataset_class: The dataset class to use for loading data (optional)
        """
        self.model = upscaler_model
        self.dataset_class = dataset_class
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device, memory_format=torch.channels_last)
        self.criterion = nn.MSELoss()
        self.optimizer = None
        self.scaler = GradScaler()
        self.best_loss = float('inf')
        self.csv_path = None
        self.checkpoint_dir = None
        self.data_loader = None
        self.validation_loader = None
        self.last_checkpoint_time = time.time()
        self.checkpoint_interval = 2 * 60 * 60  # 2 hours
        self.last_22_date = None
        self.recent_checkpoints = []
        self.current_epoch = 0

    def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2,
                  custom_train_dataset=None, custom_val_dataset=None):
        """
        Load data using either a custom dataset instance or the dataset class provided at initialization

        Args:
            dataset_params (dict/list): Parameters to pass to the dataset class constructor
            batch_size (int): Batch size for training
            validation_split (float): Proportion of data to use for validation
            custom_train_dataset: A pre-instantiated dataset to use for training (optional)
            custom_val_dataset: A pre-instantiated dataset to use for validation (optional)
        """
        # If custom datasets are provided directly, use them
        if custom_train_dataset is not None:
            train_dataset = custom_train_dataset
            val_dataset = custom_val_dataset if custom_val_dataset is not None else None
        else:
            # Otherwise instantiate dataset using the class and parameters
            if self.dataset_class is None:
                raise ValueError("No dataset class provided. Either provide a dataset class at initialization or custom datasets.")

            # Create dataset instance
            dataset = self.dataset_class(**(dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params}))

            # Split into train and validation sets
            dataset_size = len(dataset)
            val_size = int(validation_split * dataset_size)
            train_size = dataset_size - val_size
            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )

        # Create data loaders
        self.data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True
        )

        if val_dataset is not None:
            self.validation_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=True
            )
            print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")
        else:
            self.validation_loader = None
            print(f"Loaded {len(train_dataset)} training samples (no validation set)")

        return self.data_loader, self.validation_loader
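
    # Illustrative calls (hypothetical dataset/variable names; items are assumed
    # to be (low_res, high_res) tensor pairs):
    #   trainer.load_data(dataset_params=['shard0.parquet'], batch_size=2)
    #   trainer.load_data(custom_train_dataset=my_train_ds,
    #                     custom_val_dataset=my_val_ds, batch_size=2)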

    def _setup_logging(self, output_path):
        """Set up basic logging and checkpoint directory"""
        # Create checkpoint directory
        self.checkpoint_dir = os.path.join(output_path, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Set up CSV logging
        self.csv_path = os.path.join(output_path, 'training_log.csv')
        with open(self.csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            if self.validation_loader:
                writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            else:
                writer.writerow(['Epoch', 'Train Loss'])

    def _evaluate(self):
        """Evaluate the model on validation data; returns the summed per-batch loss (the caller averages it)"""
        if self.validation_loader is None:
            return 0.0

        self.model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                with autocast(device_type=self.device.type):
                    outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                val_loss += loss.item()
                del low_res, high_res, outputs, loss

        self.model.train()
        return val_loss

    def _save_checkpoint(self, epoch, batch_count, is_best=False, is_22=False):
        """Save checkpoint with support for regular, best, and 22:00 saves"""
        if is_22:
            today = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))).date()
            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
        else:
            checkpoint_name = f"checkpoint_epoch{epoch}_batch{batch_count}.pt"

        checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)

        checkpoint_data = {
            'epoch': epoch,
            'batch': batch_count,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_loss': self.best_loss,
            'scaler_state_dict': self.scaler.state_dict()
        }
        torch.save(checkpoint_data, checkpoint_path)

        # Save best model separately
        if is_best:
            best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
            self.model.save_pretrained(best_model_path)

        return checkpoint_path

    def _handle_checkpoints(self, epoch, batch_count, is_improved):
        """Handle periodic and 22:00 checkpoint saving"""
        current_time = time.time()
        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))

        # Regular interval checkpoint
        if (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
            self._save_checkpoint(epoch, batch_count, is_improved)
            self.last_checkpoint_time = current_time

        # Special 22:00 checkpoint
        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
        if is_22_oclock and self.last_22_date != current_dt.date():
            self._save_checkpoint(epoch, batch_count, is_improved, is_22=True)
            self.last_22_date = current_dt.date()
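
    # Resulting files (derived from the two methods above): a rolling
    # "checkpoint_epoch{E}_batch{B}.pt" at most once per checkpoint_interval
    # (2 h by default), plus one "checkpoint_22h_YYYYMMDD.pt" per day saved
    # during the 22:00-22:15 window of the hard-coded UTC+2 offset.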

    def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
        """Finetune the upscaler model"""
        if self.data_loader is None:
            raise ValueError("Data not loaded. Call load_data first.")

        # Set up the checkpoint directory and CSV log
        self._setup_logging(output_path=output_path)

        # Setup optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Load existing checkpoint if available
        checkpoint_info = self.load_checkpoint()
        start_epoch = checkpoint_info[0] if checkpoint_info else 0
        start_batch = checkpoint_info[1] if checkpoint_info else 0

        # Setup early stopping
        early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
        self.best_loss = float('inf')

        # Training loop
        self.model.train()
        for epoch in range(start_epoch, epochs):
            self.current_epoch = epoch
            epoch_loss = 0.0

            # Materialize the batches so training can resume mid-epoch from start_batch
            train_batches = list(enumerate(self.data_loader))
            start_idx = start_batch if epoch == start_epoch else 0

            progress_bar = tqdm(train_batches[start_idx:],
                                initial=start_idx,
                                total=len(train_batches),
                                desc=f"Epoch {epoch + 1}/{epochs}")

            for batch_idx, (low_res, high_res) in progress_bar:
                # Training step
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                self.optimizer.zero_grad()

                with autocast(device_type=self.device.type):
                    if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                        low_res.requires_grad_()
                        outputs = checkpoint(self.model, low_res)
                    else:
                        outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_postfix({'loss': loss.item()})

                # Handle checkpoints
                self._handle_checkpoints(epoch + 1, batch_idx + 1, loss.item() < self.best_loss)

                del low_res, high_res, outputs, loss

            # End of epoch processing
            avg_train_loss = epoch_loss / len(self.data_loader)

            # Validation phase
            if self.validation_loader:
                val_loss = self._evaluate() / len(self.validation_loader)
                is_improved = val_loss < self.best_loss
                if is_improved:
                    self.best_loss = val_loss

                # Log to CSV
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, val_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
            else:
                is_improved = avg_train_loss < self.best_loss
                if is_improved:
                    self.best_loss = avg_train_loss

                # Log to CSV
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")

            # Save best model if improved
            if is_improved:
                best_model_path = os.path.join(output_path, "best_model")
                self.model.save_pretrained(best_model_path)

            # Check early stopping: __call__ returns True when the loss improved,
            # so the stop condition is the early_stop flag, not the return value
            early_stopping(val_loss if self.validation_loader else avg_train_loss)
            if early_stopping.early_stop:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

            # Cleanup
            gc.collect()
            torch.cuda.empty_cache()

        return self.best_loss

    def load_checkpoint(self, specific_checkpoint=None):
        """Enhanced checkpoint loading with specific checkpoint support"""
        if self.checkpoint_dir is None or not os.path.isdir(self.checkpoint_dir):
            return None

        if specific_checkpoint:
            checkpoint_path = os.path.join(self.checkpoint_dir, specific_checkpoint)
        else:
            checkpoint_files = [f for f in os.listdir(self.checkpoint_dir)
                                if f.startswith("checkpoint_") and f.endswith(".pt")]
            if not checkpoint_files:
                return None

            checkpoint_files.sort(key=lambda x: os.path.getmtime(
                os.path.join(self.checkpoint_dir, x)), reverse=True)
            checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_files[0])

        if not os.path.exists(checkpoint_path):
            return None

        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
        self.best_loss = checkpoint['best_loss']

        print(f"Loaded checkpoint from {checkpoint_path}")
        return checkpoint['epoch'], checkpoint['batch']

    def save(self, output_path=None):
        """
        Save the best model to the specified path

        Args:
            output_path (str, optional): Path to save the model. If None, tries to use the checkpoint directory from training.

        Returns:
            str: Path where the model was saved

        Raises:
            ValueError: If no output path is specified and no checkpoint directory exists
        """
        if output_path is None and self.checkpoint_dir is not None:
            # First try to copy the best model if it exists
            best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
            if os.path.exists(best_model_path):
                output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
                shutil.copytree(best_model_path, output_path, dirs_exist_ok=True)
                print(f"Copied best model to {output_path}")
                return output_path
            else:
                # If no best model exists, save current model state
                output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")

        if output_path is None:
            raise ValueError("No output path specified and no checkpoint directory exists from training.")

        self.model.save_pretrained(output_path)
        print(f"Model saved to {output_path}")
        return output_path
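

# Minimal usage sketch (illustrative, not executed on import). `MyUpscaler` and
# `UpscaleDataset` are hypothetical stand-ins for a model exposing
# save_pretrained() and a dataset class yielding (low_res, high_res) tensors:
#
#   model = MyUpscaler.from_pretrained("path/to/base")
#   trainer = aiuNNTrainer(model, dataset_class=UpscaleDataset)
#   trainer.load_data(dataset_params=["data/train.parquet"], batch_size=4)
#   trainer.finetune(output_path="runs/upscaler", epochs=10, lr=1e-4)
#   trainer.save()   # copies best_model/ to final_model/ when available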