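"""
Training utilities for fine-tuning an image upscaling model.

Contains an `EarlyStopping` helper and an `aiuNNTrainer` class that handles
data loading, mixed-precision training with gradient scaling, optional
gradient checkpointing, CSV logging, periodic and daily 22:00 checkpoints,
and saving of the best model.
"""
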
import torch
import torch.nn as nn
import torch.optim as optim
import os
import csv
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
import time
import shutil
import datetime


class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        # Number of epochs with no significant improvement before stopping
        # Minimum change in loss required to count as an improvement
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
            return True  # Improved
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False  # Not improved


class aiuNNTrainer:
    def __init__(self, upscaler_model, dataset_class=None):
        """
        Initialize the upscaler trainer.

        Args:
            upscaler_model: The model to fine-tune
            dataset_class: The dataset class to use for loading data (optional)
        """
        self.model = upscaler_model
        self.dataset_class = dataset_class
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device, memory_format=torch.channels_last)
        self.criterion = nn.MSELoss()
        self.optimizer = None
        self.scaler = GradScaler()
        self.best_loss = float('inf')
        self.csv_path = None
        self.checkpoint_dir = None
        self.data_loader = None
        self.validation_loader = None
        self.last_checkpoint_time = time.time()
        self.checkpoint_interval = 2 * 60 * 60  # 2 hours
        self.last_22_date = None
        self.recent_checkpoints = []
        self.current_epoch = 0
        # Set to True to trade compute for memory via torch.utils.checkpoint during training
        self.use_checkpointing = False

    def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
        """
        Load data using either a custom dataset instance or the dataset class provided at initialization.

        Args:
            dataset_params (dict/list): Parameters to pass to the dataset class constructor
            batch_size (int): Batch size for training
            validation_split (float): Proportion of data to use for validation
            custom_train_dataset: A pre-instantiated dataset to use for training (optional)
            custom_val_dataset: A pre-instantiated dataset to use for validation (optional)
        """
        # If custom datasets are provided directly, use them
        if custom_train_dataset is not None:
            train_dataset = custom_train_dataset
            val_dataset = custom_val_dataset if custom_val_dataset is not None else None
        else:
            # Otherwise instantiate dataset using the class and parameters
            if self.dataset_class is None:
                raise ValueError("No dataset class provided. Either provide a dataset class at initialization or custom datasets.")

            # Create dataset instance
            dataset = self.dataset_class(**(dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params}))

            # Split into train and validation sets
            dataset_size = len(dataset)
            val_size = int(validation_split * dataset_size)
            train_size = dataset_size - val_size

            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )

        # Create data loaders
        self.data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True
        )

        if val_dataset is not None:
            self.validation_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=True
            )
            print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")
        else:
            self.validation_loader = None
            print(f"Loaded {len(train_dataset)} training samples (no validation set)")

        return self.data_loader, self.validation_loader

    def _setup_logging(self, output_path):
        """Set up basic logging and checkpoint directory"""
        # Create checkpoint directory
        self.checkpoint_dir = os.path.join(output_path, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Set up CSV logging
        self.csv_path = os.path.join(output_path, 'training_log.csv')
        with open(self.csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            if self.validation_loader:
                writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            else:
                writer.writerow(['Epoch', 'Train Loss'])

    def _evaluate(self):
        """Evaluate the model on validation data"""
        if self.validation_loader is None:
            return 0.0

        self.model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                with autocast(device_type=self.device.type):
                    outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                val_loss += loss.item()

                del low_res, high_res, outputs, loss

        self.model.train()
        return val_loss

    def _save_checkpoint(self, epoch, batch_count, is_best=False, is_22=False):
        """Save checkpoint with support for regular, best, and 22:00 saves"""
        if is_22:
            today = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))).date()
            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
        else:
            checkpoint_name = f"checkpoint_epoch{epoch}_batch{batch_count}.pt"

        checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)

        checkpoint_data = {
            'epoch': epoch,
            'batch': batch_count,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_loss': self.best_loss,
            'scaler_state_dict': self.scaler.state_dict()
        }

        torch.save(checkpoint_data, checkpoint_path)

        # Save best model separately
        if is_best:
            best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
            self.model.save_pretrained(best_model_path)

        return checkpoint_path

    def _handle_checkpoints(self, epoch, batch_count, is_improved):
        """Handle periodic and 22:00 checkpoint saving"""
        current_time = time.time()
        # 22:00 is evaluated in a fixed UTC+2 timezone
        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))

        # Regular interval checkpoint
        if (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
            self._save_checkpoint(epoch, batch_count, is_improved)
            self.last_checkpoint_time = current_time

        # Special 22:00 checkpoint (at most once per day, within the first 15 minutes of the hour)
        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
        if is_22_oclock and self.last_22_date != current_dt.date():
            self._save_checkpoint(epoch, batch_count, is_improved, is_22=True)
            self.last_22_date = current_dt.date()

    def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
        """Finetune the upscaler model"""
        if self.data_loader is None:
            raise ValueError("Data not loaded. Call load_data first.")

        # Set up logging (creates the checkpoint directory and CSV log)
        self._setup_logging(output_path=output_path)

        # Setup optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Load existing checkpoint if available
        checkpoint_info = self.load_checkpoint()
        start_epoch = checkpoint_info[0] if checkpoint_info else 0
        start_batch = checkpoint_info[1] if checkpoint_info else 0

        # Setup early stopping; keep the best loss from a resumed checkpoint if there is one
        early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
        if checkpoint_info is None:
            self.best_loss = float('inf')

        # Training loop
        self.model.train()
        for epoch in range(start_epoch, epochs):
            self.current_epoch = epoch
            epoch_loss = 0.0

            # Materialize the batches so training can resume mid-epoch from start_batch
            train_batches = list(enumerate(self.data_loader))
            start_idx = start_batch if epoch == start_epoch else 0

            progress_bar = tqdm(train_batches[start_idx:],
                                initial=start_idx,
                                total=len(train_batches),
                                desc=f"Epoch {epoch + 1}/{epochs}")

            for batch_idx, (low_res, high_res) in progress_bar:
                # Training step
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                self.optimizer.zero_grad()

                with autocast(device_type=self.device.type):
                    if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                        low_res.requires_grad_()
                        outputs = checkpoint(self.model, low_res)
                    else:
                        outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_postfix({'loss': loss.item()})

                # Handle checkpoints
                self._handle_checkpoints(epoch + 1, batch_idx + 1, loss.item() < self.best_loss)

                del low_res, high_res, outputs, loss

            # End of epoch processing
            avg_train_loss = epoch_loss / len(self.data_loader)

            # Validation phase
            if self.validation_loader:
                val_loss = self._evaluate() / len(self.validation_loader)
                is_improved = val_loss < self.best_loss
                if is_improved:
                    self.best_loss = val_loss

                # Log to CSV
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, val_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
            else:
                is_improved = avg_train_loss < self.best_loss
                if is_improved:
                    self.best_loss = avg_train_loss

                # Log to CSV
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")

            # Save best model if improved
            if is_improved:
                best_model_path = os.path.join(output_path, "best_model")
                self.model.save_pretrained(best_model_path)

            # Update early stopping with this epoch's loss and stop once patience is exhausted
            early_stopping(val_loss if self.validation_loader else avg_train_loss)
            if early_stopping.early_stop:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

            # Cleanup
            gc.collect()
            torch.cuda.empty_cache()

        return self.best_loss

    def load_checkpoint(self, specific_checkpoint=None):
        """Enhanced checkpoint loading with specific checkpoint support"""
        if specific_checkpoint:
            checkpoint_path = os.path.join(self.checkpoint_dir, specific_checkpoint)
        else:
            checkpoint_files = [f for f in os.listdir(self.checkpoint_dir)
                                if f.startswith("checkpoint_") and f.endswith(".pt")]
            if not checkpoint_files:
                return None

            # Pick the most recently modified checkpoint
            checkpoint_files.sort(key=lambda x: os.path.getmtime(
                os.path.join(self.checkpoint_dir, x)), reverse=True)
            checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_files[0])

        if not os.path.exists(checkpoint_path):
            return None

        # Named checkpoint_data to avoid shadowing the imported checkpoint() function
        checkpoint_data = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint_data['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
        self.scaler.load_state_dict(checkpoint_data['scaler_state_dict'])
        self.best_loss = checkpoint_data['best_loss']

        print(f"Loaded checkpoint from {checkpoint_path}")
        return checkpoint_data['epoch'], checkpoint_data['batch']

    def save(self, output_path=None):
        """
        Save the best model to the specified path.

        Args:
            output_path (str, optional): Path to save the model. If None, tries to use the checkpoint directory from training.

        Returns:
            str: Path where the model was saved

        Raises:
            ValueError: If no output path is specified and no checkpoint directory exists
        """
        if output_path is None and self.checkpoint_dir is not None:
            # First try to copy the best model if it exists
            best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
            if os.path.exists(best_model_path):
                output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
                shutil.copytree(best_model_path, output_path, dirs_exist_ok=True)
                print(f"Copied best model to {output_path}")
                return output_path
            else:
                # If no best model exists, save current model state
                output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")

        if output_path is None:
            raise ValueError("No output path specified and no checkpoint directory exists from training.")

        self.model.save_pretrained(output_path)
        print(f"Model saved to {output_path}")
        return output_path
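
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed):
# `MyUpscalerModel` and `MyParquetDataset` below are placeholder names, not
# part of this module. The trainer only assumes that the model implements
# `save_pretrained(path)` and that the dataset yields (low_res, high_res)
# tensor pairs.
#
#     model = MyUpscalerModel()
#     trainer = aiuNNTrainer(model, dataset_class=MyParquetDataset)
#
#     # dataset_params may be a dict of constructor kwargs or a list of
#     # parquet files (forwarded as `parquet_files=...`)
#     trainer.load_data(dataset_params=['data/train.parquet'],
#                       batch_size=4, validation_split=0.2)
#
#     # Optional: enable gradient checkpointing to reduce memory usage
#     trainer.use_checkpointing = True
#
#     trainer.finetune(output_path='runs/upscaler', epochs=10, lr=1e-4)
#     trainer.save()  # copies the best model to runs/upscaler/final_model
# ---------------------------------------------------------------------------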