updated trainer to save checkpoints every n hours and at 22:00 to save energy

Falko Victor Habel 2025-04-20 22:28:30 +02:00
parent 45d6802cd7
commit b0d0b41944
1 changed file with 152 additions and 73 deletions
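
The scheduling rule this commit implements can be summarized as: write a checkpoint whenever more than `checkpoint_interval` seconds (2 hours by default) have passed since the last save, and additionally once per calendar day while the hard-coded UTC+2 clock reads 22:00–22:14. Below is a minimal, self-contained sketch of that rule for clarity; the function name and its arguments are illustrative only and are not part of the trainer's API.

import datetime
import time

CHECKPOINT_INTERVAL = 2 * 60 * 60  # seconds; matches the trainer's 2-hour default
TZ = datetime.timezone(datetime.timedelta(hours=2))  # the commit hard-codes UTC+2

def should_checkpoint(last_checkpoint_time, last_22_date, now=None):
    """Return (regular_due, evening_due) for the two checkpoint triggers."""
    now = time.time() if now is None else now
    current_dt = datetime.datetime.now(TZ)

    # Trigger 1: more than CHECKPOINT_INTERVAL seconds since the last save
    regular_due = (now - last_checkpoint_time) >= CHECKPOINT_INTERVAL

    # Trigger 2: at most once per day, while the clock reads 22:00-22:14
    in_22_window = current_dt.hour == 22 and current_dt.minute < 15
    evening_due = in_22_window and last_22_date != current_dt.date()

    return regular_due, evening_due

In the trainer itself this logic lives in _handle_checkpoints(), with last_checkpoint_time and last_22_date tracked as instance attributes.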


@@ -10,6 +10,7 @@ from torch.utils.checkpoint import checkpoint
 import gc
 import time
 import shutil
+import datetime

 class EarlyStopping:
@@ -50,10 +51,16 @@ class aiuNNTrainer:
         self.optimizer = None
         self.scaler = GradScaler()
         self.best_loss = float('inf')
-        self.use_checkpointing = True
+        self.csv_path = None
+        self.checkpoint_dir = None
         self.data_loader = None
         self.validation_loader = None
-        self.log_dir = None
+        self.last_checkpoint_time = time.time()
+        self.checkpoint_interval = 2 * 60 * 60  # 2 hours
+        self.last_22_date = None
+        self.recent_checkpoints = []
+        self.current_epoch = 0

     def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
         """
@@ -110,23 +117,19 @@ class aiuNNTrainer:
         return self.data_loader, self.validation_loader

     def _setup_logging(self, output_path):
-        """Set up directory structure for logging and model checkpoints"""
-        timestamp = time.strftime("%Y%m%d-%H%M%S")
-        self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
-        os.makedirs(self.log_dir, exist_ok=True)
+        """Set up basic logging and checkpoint directory"""
         # Create checkpoint directory
-        self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
+        self.checkpoint_dir = os.path.join(output_path, "checkpoints")
         os.makedirs(self.checkpoint_dir, exist_ok=True)

         # Set up CSV logging
-        self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
+        self.csv_path = os.path.join(output_path, 'training_log.csv')
         with open(self.csv_path, mode='w', newline='') as file:
             writer = csv.writer(file)
             if self.validation_loader:
-                writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
+                writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
             else:
-                writer.writerow(['Epoch', 'Train Loss', 'Improved'])
+                writer.writerow(['Epoch', 'Train Loss'])

     def _evaluate(self):
         """Evaluate the model on validation data"""
@@ -152,63 +155,99 @@ class aiuNNTrainer:
         self.model.train()
         return val_loss

-    def _save_checkpoint(self, epoch, is_best=False):
-        """Save model checkpoint"""
-        checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
-        best_model_path = os.path.join(self.log_dir, "best_model")
-
-        # Save the model checkpoint
-        self.model.save(checkpoint_path)
-
-        # If this is the best model so far, copy it to best_model
+    def _save_checkpoint(self, epoch, batch_count, is_best=False, is_22=False):
+        """Save checkpoint with support for regular, best, and 22:00 saves"""
+        if is_22:
+            today = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))).date()
+            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
+        else:
+            checkpoint_name = f"checkpoint_epoch{epoch}_batch{batch_count}.pt"
+
+        checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name)
+
+        checkpoint_data = {
+            'epoch': epoch,
+            'batch': batch_count,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'best_loss': self.best_loss,
+            'scaler_state_dict': self.scaler.state_dict()
+        }
+        torch.save(checkpoint_data, checkpoint_path)
+
+        # Save best model separately
         if is_best:
-            if os.path.exists(best_model_path):
-                shutil.rmtree(best_model_path)
-            self.model.save(best_model_path)
-            print(f"Saved new best model with loss: {self.best_loss:.6f}")
+            best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
+            self.model.save_pretrained(best_model_path)
+
+        return checkpoint_path
+
+    def _handle_checkpoints(self, epoch, batch_count, is_improved):
+        """Handle periodic and 22:00 checkpoint saving"""
+        current_time = time.time()
+        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))
+
+        # Regular interval checkpoint
+        if (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
+            self._save_checkpoint(epoch, batch_count, is_improved)
+            self.last_checkpoint_time = current_time
+
+        # Special 22:00 checkpoint
+        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
+        if is_22_oclock and self.last_22_date != current_dt.date():
+            self._save_checkpoint(epoch, batch_count, is_improved, is_22=True)
+            self.last_22_date = current_dt.date()

     def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
-        """
-        Finetune the upscaler model
-
-        Args:
-            output_path (str): Directory to save models and logs
-            epochs (int): Maximum number of training epochs
-            lr (float): Learning rate
-            patience (int): Early stopping patience
-            min_delta (float): Minimum improvement for early stopping
-        """
-        # Check if data is loaded
+        """Finetune the upscaler model"""
         if self.data_loader is None:
             raise ValueError("Data not loaded. Call load_data first.")

-        # Setup optimizer
+        # Setup optimizer and directories
         self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
+        self.checkpoint_dir = os.path.join(output_path, "checkpoints")
+        os.makedirs(self.checkpoint_dir, exist_ok=True)

-        # Set up logging
-        self._setup_logging(output_path)
+        # Setup CSV logging
+        self.csv_path = os.path.join(output_path, 'training_log.csv')
+        with open(self.csv_path, mode='w', newline='') as file:
+            writer = csv.writer(file)
+            header = ['Epoch', 'Train Loss', 'Validation Loss'] if self.validation_loader else ['Epoch', 'Train Loss']
+            writer.writerow(header)
+
+        # Load existing checkpoint if available
+        checkpoint_info = self.load_checkpoint()
+        start_epoch = checkpoint_info[0] if checkpoint_info else 0
+        start_batch = checkpoint_info[1] if checkpoint_info else 0

         # Setup early stopping
         early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
+        self.best_loss = float('inf')

         # Training loop
         self.model.train()
-        for epoch in range(epochs):
-            # Training phase
+        for epoch in range(start_epoch, epochs):
+            self.current_epoch = epoch
             epoch_loss = 0.0
-            progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")
-            for low_res, high_res in progress_bar:
-                # Move data to GPU with channels_last format where possible
+
+            train_batches = list(enumerate(self.data_loader))
+            start_idx = start_batch if epoch == start_epoch else 0
+
+            progress_bar = tqdm(train_batches[start_idx:],
+                                initial=start_idx,
+                                total=len(train_batches),
+                                desc=f"Epoch {epoch + 1}/{epochs}")
+
+            for batch_idx, (low_res, high_res) in progress_bar:
+                # Training step
                 low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                 high_res = high_res.to(self.device, non_blocking=True)

                 self.optimizer.zero_grad()

                 with autocast(device_type=self.device.type):
-                    if self.use_checkpointing:
-                        # Ensure the input tensor requires gradient so that checkpointing records the computation graph
+                    if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                         low_res.requires_grad_()
                         outputs = checkpoint(self.model, low_res)
                     else:
@@ -222,69 +261,109 @@ class aiuNNTrainer:
                 epoch_loss += loss.item()
                 progress_bar.set_postfix({'loss': loss.item()})

-                # Optionally delete variables to free memory
+                # Handle checkpoints
+                self._handle_checkpoints(epoch + 1, batch_idx + 1, loss.item() < self.best_loss)
+
                 del low_res, high_res, outputs, loss

-            # Calculate average epoch loss
+            # End of epoch processing
             avg_train_loss = epoch_loss / len(self.data_loader)

-            # Validation phase (if validation loader exists)
+            # Validation phase
             if self.validation_loader:
                 val_loss = self._evaluate() / len(self.validation_loader)
                 is_improved = val_loss < self.best_loss
                 if is_improved:
                     self.best_loss = val_loss

-                # Log results
-                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
+                # Log to CSV
                 with open(self.csv_path, mode='a', newline='') as file:
                     writer = csv.writer(file)
-                    writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
+                    writer.writerow([epoch + 1, avg_train_loss, val_loss])
+
+                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
             else:
-                # If no validation, use training loss for improvement tracking
                 is_improved = avg_train_loss < self.best_loss
                 if is_improved:
                     self.best_loss = avg_train_loss

-                # Log results
-                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
+                # Log to CSV
                 with open(self.csv_path, mode='a', newline='') as file:
                     writer = csv.writer(file)
-                    writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])
+                    writer.writerow([epoch + 1, avg_train_loss])
+
+                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")

-            # Save checkpoint
-            self._save_checkpoint(epoch + 1, is_best=is_improved)
-
-            # Perform garbage collection and clear GPU cache after each epoch
-            gc.collect()
-            torch.cuda.empty_cache()
+            # Save best model if improved
+            if is_improved:
+                best_model_path = os.path.join(output_path, "best_model")
+                self.model.save_pretrained(best_model_path)

             # Check early stopping
-            early_stopping(val_loss if self.validation_loader else avg_train_loss)
-            if early_stopping.early_stop:
+            if early_stopping(val_loss if self.validation_loader else avg_train_loss):
                 print(f"Early stopping triggered at epoch {epoch + 1}")
                 break

+            # Cleanup
+            gc.collect()
+            torch.cuda.empty_cache()
+
         return self.best_loss

+    def load_checkpoint(self, specific_checkpoint=None):
+        """Enhanced checkpoint loading with specific checkpoint support"""
+        if specific_checkpoint:
+            checkpoint_path = os.path.join(self.checkpoint_dir, specific_checkpoint)
+        else:
+            checkpoint_files = [f for f in os.listdir(self.checkpoint_dir)
+                                if f.startswith("checkpoint_") and f.endswith(".pt")]
+            if not checkpoint_files:
+                return None
+
+            checkpoint_files.sort(key=lambda x: os.path.getmtime(
+                os.path.join(self.checkpoint_dir, x)), reverse=True)
+            checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_files[0])
+
+        if not os.path.exists(checkpoint_path):
+            return None
+
+        checkpoint = torch.load(checkpoint_path, map_location=self.device)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
+        self.best_loss = checkpoint['best_loss']
+
+        print(f"Loaded checkpoint from {checkpoint_path}")
+        return checkpoint['epoch'], checkpoint['batch']
+
     def save(self, output_path=None):
         """
         Save the best model to the specified path

         Args:
-            output_path (str, optional): Path to save the model. If None, uses the best model from training.
+            output_path (str, optional): Path to save the model. If None, tries to use the checkpoint directory from training.
+
+        Returns:
+            str: Path where the model was saved
+
+        Raises:
+            ValueError: If no output path is specified and no checkpoint directory exists
         """
-        if output_path is None and self.log_dir is not None:
-            best_model_path = os.path.join(self.log_dir, "best_model")
+        if output_path is None and self.checkpoint_dir is not None:
+            # First try to copy the best model if it exists
+            best_model_path = os.path.join(os.path.dirname(self.checkpoint_dir), "best_model")
             if os.path.exists(best_model_path):
-                print(f"Best model already saved at {best_model_path}")
-                return best_model_path
+                output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")
+                shutil.copytree(best_model_path, output_path, dirs_exist_ok=True)
+                print(f"Copied best model to {output_path}")
+                return output_path
             else:
-                output_path = os.path.join(self.log_dir, "final_model")
+                # If no best model exists, save current model state
+                output_path = os.path.join(os.path.dirname(self.checkpoint_dir), "final_model")

         if output_path is None:
-            raise ValueError("No output path specified and no training has been done yet.")
+            raise ValueError("No output path specified and no checkpoint directory exists from training.")

-        self.model.save(output_path)
+        self.model.save_pretrained(output_path)
         print(f"Model saved to {output_path}")
         return output_path
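
For context, a hedged usage sketch of the resume path introduced by this commit: it assumes a trainer instance has already been constructed and its data loaded (neither step appears in this diff), and the checkpoint filename shown is hypothetical, merely following the checkpoint_epoch{epoch}_batch{batch}.pt pattern above. In normal use, re-running finetune() resumes automatically from the newest checkpoint via load_checkpoint().

# Hypothetical resume flow; constructing `trainer` and calling load_data() are assumed.
trainer.finetune(output_path="runs/upscaler", epochs=10)  # writes runs/upscaler/checkpoints/*.pt

# Restore a specific checkpoint instead of the most recent one (filename is illustrative):
resume_point = trainer.load_checkpoint("checkpoint_epoch3_batch1200.pt")
if resume_point is not None:
    start_epoch, start_batch = resume_point
    print(f"Checkpoint restored at epoch {start_epoch}, batch {start_batch}")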