290 lines
12 KiB
Python
290 lines
12 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import os
|
|
import csv
|
|
from torch.amp import autocast, GradScaler
|
|
from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
from torch.utils.checkpoint import checkpoint
|
|
import gc
|
|
import time
|
|
import shutil
|
|
|
|
|
|
class EarlyStopping:
    """Watch a monitored loss and raise a flag once it stops improving.

    A call counts as an improvement only when the new loss undercuts the best
    loss seen so far by more than ``min_delta``. After ``patience``
    consecutive non-improving calls, ``early_stop`` becomes ``True``. This
    class never clears the flag again, even if a later call improves.

    Args:
        patience: Number of consecutive calls without a significant
            improvement that trigger early stopping.
        min_delta: Margin by which a loss must beat the best loss so far to
            count as an improvement.
    """

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        # Consecutive calls that did not beat best_loss by more than min_delta.
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        """Record ``epoch_loss`` and return ``True`` if it was an improvement."""
        threshold = self.best_loss - self.min_delta
        if not epoch_loss < threshold:
            # No significant improvement; count it and check the patience budget.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False
        self.best_loss = epoch_loss
        self.counter = 0
        return True
|
|
|
|
class aiuNNTrainer:
|
|
def __init__(self, upscaler_model, dataset_class=None):
|
|
"""
|
|
Initialize the upscaler trainer
|
|
|
|
Args:
|
|
upscaler_model: The model to fine-tune
|
|
dataset_class: The dataset class to use for loading data (optional)
|
|
"""
|
|
self.model = upscaler_model
|
|
self.dataset_class = dataset_class
|
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
self.model = self.model.to(self.device, memory_format=torch.channels_last)
|
|
self.criterion = nn.MSELoss()
|
|
self.optimizer = None
|
|
self.scaler = GradScaler()
|
|
self.best_loss = float('inf')
|
|
self.use_checkpointing = True
|
|
self.data_loader = None
|
|
self.validation_loader = None
|
|
self.log_dir = None
|
|
|
|
def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
|
|
"""
|
|
Load data using either a custom dataset instance or the dataset class provided at initialization
|
|
|
|
Args:
|
|
dataset_params (dict/list): Parameters to pass to the dataset class constructor
|
|
batch_size (int): Batch size for training
|
|
validation_split (float): Proportion of data to use for validation
|
|
custom_train_dataset: A pre-instantiated dataset to use for training (optional)
|
|
custom_val_dataset: A pre-instantiated dataset to use for validation (optional)
|
|
"""
|
|
# If custom datasets are provided directly, use them
|
|
if custom_train_dataset is not None:
|
|
train_dataset = custom_train_dataset
|
|
val_dataset = custom_val_dataset if custom_val_dataset is not None else None
|
|
else:
|
|
# Otherwise instantiate dataset using the class and parameters
|
|
if self.dataset_class is None:
|
|
raise ValueError("No dataset class provided. Either provide a dataset class at initialization or custom datasets.")
|
|
|
|
# Create dataset instance
|
|
dataset = self.dataset_class(**dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params})
|
|
|
|
# Split into train and validation sets
|
|
dataset_size = len(dataset)
|
|
val_size = int(validation_split * dataset_size)
|
|
train_size = dataset_size - val_size
|
|
|
|
train_dataset, val_dataset = torch.utils.data.random_split(
|
|
dataset, [train_size, val_size]
|
|
)
|
|
|
|
# Create data loaders
|
|
self.data_loader = DataLoader(
|
|
train_dataset,
|
|
batch_size=batch_size,
|
|
shuffle=True,
|
|
pin_memory=True
|
|
)
|
|
|
|
if val_dataset is not None:
|
|
self.validation_loader = DataLoader(
|
|
val_dataset,
|
|
batch_size=batch_size,
|
|
shuffle=False,
|
|
pin_memory=True
|
|
)
|
|
print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")
|
|
else:
|
|
self.validation_loader = None
|
|
print(f"Loaded {len(train_dataset)} training samples (no validation set)")
|
|
|
|
return self.data_loader, self.validation_loader
|
|
|
|
def _setup_logging(self, output_path):
|
|
"""Set up directory structure for logging and model checkpoints"""
|
|
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
|
self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
|
|
os.makedirs(self.log_dir, exist_ok=True)
|
|
|
|
# Create checkpoint directory
|
|
self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
|
|
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
|
|
|
# Set up CSV logging
|
|
self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
|
|
with open(self.csv_path, mode='w', newline='') as file:
|
|
writer = csv.writer(file)
|
|
if self.validation_loader:
|
|
writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
|
|
else:
|
|
writer.writerow(['Epoch', 'Train Loss', 'Improved'])
|
|
|
|
def _evaluate(self):
|
|
"""Evaluate the model on validation data"""
|
|
if self.validation_loader is None:
|
|
return 0.0
|
|
|
|
self.model.eval()
|
|
val_loss = 0.0
|
|
|
|
with torch.no_grad():
|
|
for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
|
|
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
|
|
high_res = high_res.to(self.device, non_blocking=True)
|
|
|
|
with autocast(device_type=self.device.type):
|
|
outputs = self.model(low_res)
|
|
loss = self.criterion(outputs, high_res)
|
|
|
|
val_loss += loss.item()
|
|
|
|
del low_res, high_res, outputs, loss
|
|
|
|
self.model.train()
|
|
return val_loss
|
|
|
|
def _save_checkpoint(self, epoch, is_best=False):
|
|
"""Save model checkpoint"""
|
|
checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
|
|
best_model_path = os.path.join(self.log_dir, "best_model")
|
|
|
|
# Save the model checkpoint
|
|
self.model.save(checkpoint_path)
|
|
|
|
# If this is the best model so far, copy it to best_model
|
|
if is_best:
|
|
if os.path.exists(best_model_path):
|
|
shutil.rmtree(best_model_path)
|
|
self.model.save(best_model_path)
|
|
print(f"Saved new best model with loss: {self.best_loss:.6f}")
|
|
|
|
def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
|
|
"""
|
|
Finetune the upscaler model
|
|
|
|
Args:
|
|
output_path (str): Directory to save models and logs
|
|
epochs (int): Maximum number of training epochs
|
|
lr (float): Learning rate
|
|
patience (int): Early stopping patience
|
|
min_delta (float): Minimum improvement for early stopping
|
|
"""
|
|
# Check if data is loaded
|
|
if self.data_loader is None:
|
|
raise ValueError("Data not loaded. Call load_data first.")
|
|
|
|
# Setup optimizer
|
|
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
|
|
|
|
# Set up logging
|
|
self._setup_logging(output_path)
|
|
|
|
# Setup early stopping
|
|
early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
|
|
|
|
# Training loop
|
|
self.model.train()
|
|
|
|
for epoch in range(epochs):
|
|
# Training phase
|
|
epoch_loss = 0.0
|
|
progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")
|
|
|
|
for low_res, high_res in progress_bar:
|
|
# Move data to GPU with channels_last format where possible
|
|
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
|
|
high_res = high_res.to(self.device, non_blocking=True)
|
|
|
|
self.optimizer.zero_grad()
|
|
|
|
with autocast(device_type=self.device.type):
|
|
if self.use_checkpointing:
|
|
# Ensure the input tensor requires gradient so that checkpointing records the computation graph
|
|
low_res.requires_grad_()
|
|
outputs = checkpoint(self.model, low_res)
|
|
else:
|
|
outputs = self.model(low_res)
|
|
loss = self.criterion(outputs, high_res)
|
|
|
|
self.scaler.scale(loss).backward()
|
|
self.scaler.step(self.optimizer)
|
|
self.scaler.update()
|
|
|
|
epoch_loss += loss.item()
|
|
progress_bar.set_postfix({'loss': loss.item()})
|
|
|
|
# Optionally delete variables to free memory
|
|
del low_res, high_res, outputs, loss
|
|
|
|
# Calculate average epoch loss
|
|
avg_train_loss = epoch_loss / len(self.data_loader)
|
|
|
|
# Validation phase (if validation loader exists)
|
|
if self.validation_loader:
|
|
val_loss = self._evaluate() / len(self.validation_loader)
|
|
is_improved = val_loss < self.best_loss
|
|
if is_improved:
|
|
self.best_loss = val_loss
|
|
|
|
# Log results
|
|
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
|
|
with open(self.csv_path, mode='a', newline='') as file:
|
|
writer = csv.writer(file)
|
|
writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
|
|
else:
|
|
# If no validation, use training loss for improvement tracking
|
|
is_improved = avg_train_loss < self.best_loss
|
|
if is_improved:
|
|
self.best_loss = avg_train_loss
|
|
|
|
# Log results
|
|
print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
|
|
with open(self.csv_path, mode='a', newline='') as file:
|
|
writer = csv.writer(file)
|
|
writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])
|
|
|
|
# Save checkpoint
|
|
self._save_checkpoint(epoch + 1, is_best=is_improved)
|
|
|
|
# Perform garbage collection and clear GPU cache after each epoch
|
|
gc.collect()
|
|
torch.cuda.empty_cache()
|
|
|
|
# Check early stopping
|
|
early_stopping(val_loss if self.validation_loader else avg_train_loss)
|
|
if early_stopping.early_stop:
|
|
print(f"Early stopping triggered at epoch {epoch + 1}")
|
|
break
|
|
|
|
return self.best_loss
|
|
|
|
def save(self, output_path=None):
|
|
"""
|
|
Save the best model to the specified path
|
|
|
|
Args:
|
|
output_path (str, optional): Path to save the model. If None, uses the best model from training.
|
|
"""
|
|
if output_path is None and self.log_dir is not None:
|
|
best_model_path = os.path.join(self.log_dir, "best_model")
|
|
if os.path.exists(best_model_path):
|
|
print(f"Best model already saved at {best_model_path}")
|
|
return best_model_path
|
|
else:
|
|
output_path = os.path.join(self.log_dir, "final_model")
|
|
|
|
if output_path is None:
|
|
raise ValueError("No output path specified and no training has been done yet.")
|
|
|
|
self.model.save(output_path)
|
|
print(f"Model saved to {output_path}")
|
|
return output_path |