diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py
index 42ba4b8..94d412d 100644
--- a/src/aiia/pretrain/pretrainer.py
+++ b/src/aiia/pretrain/pretrainer.py
@@ -1,6 +1,8 @@
 import torch
 from torch import nn
 import csv
+import datetime
+import time
 import pandas as pd
 from tqdm import tqdm
 from ..model.Model import AIIA
@@ -112,18 +114,21 @@ class Pretrainer:
 
         return batch_loss
 
-    def train(self, dataset_paths,output_path:str="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
-        """
-        Train the model using multiple specified datasets.
-
-        Args:
-            dataset_paths (list): List of paths to parquet datasets
-            num_epochs (int): Number of training epochs
-            batch_size (int): Batch size for training
-            sample_size (int): Number of samples to use from each dataset
-        """
+    def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
+              num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
+        """Train the model using multiple specified datasets."""
         if not dataset_paths:
             raise ValueError("No dataset paths provided")
+
+        # Checkpoint tracking variables
+        last_checkpoint_time = time.time()
+        checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
+        last_22_date = None
+        recent_checkpoints = []
+
+        # Create checkpoint directory if specified
+        if checkpoint_dir is not None:
+            os.makedirs(checkpoint_dir, exist_ok=True)
 
         # Read and merge all datasets
         dataframes = []
@@ -166,6 +171,59 @@
                 if batch_data is None:
                     continue
 
+                # Check if we need to save a checkpoint
+                current_time = time.time()
+                current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
+                today = current_dt.date()
+
+                # Regular 2-hour checkpoint
+                if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval:
+                    checkpoint_path = os.path.join(
+                        checkpoint_dir,
+                        f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt"
+                    )
+                    torch.save({
+                        'epoch': epoch + 1,
+                        'batch': batch_count,
+                        'model_state_dict': self.model.state_dict(),
+                        'projection_head_state_dict': self.projection_head.state_dict(),
+                        'optimizer_state_dict': self.optimizer.state_dict(),
+                        'train_losses': self.train_losses,
+                        'val_losses': self.val_losses,
+                    }, checkpoint_path)
+
+                    # Track and maintain only 3 recent checkpoints
+                    recent_checkpoints.append(checkpoint_path)
+                    if len(recent_checkpoints) > 3:
+                        oldest = recent_checkpoints.pop(0)
+                        if os.path.exists(oldest):
+                            os.remove(oldest)
+
+                    last_checkpoint_time = current_time
+                    print(f"Checkpoint saved at {checkpoint_path}")
+
+                # Special 22:00 checkpoint
+                is_22_oclock = current_dt.hour == 22 and current_dt.minute == 0 and current_dt.second < 10
+
+                if checkpoint_dir and is_22_oclock and last_22_date != today:
+                    checkpoint_path = os.path.join(
+                        checkpoint_dir,
+                        f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
+                    )
+                    torch.save({
+                        'epoch': epoch + 1,
+                        'batch': batch_count,
+                        'model_state_dict': self.model.state_dict(),
+                        'projection_head_state_dict': self.projection_head.state_dict(),
+                        'optimizer_state_dict': self.optimizer.state_dict(),
+                        'train_losses': self.train_losses,
+                        'val_losses': self.val_losses,
+                    }, checkpoint_path)
+
+                    last_22_date = today
+                    print(f"22:00 Checkpoint saved at {checkpoint_path}")
+
+                # Process the batch
                 self.optimizer.zero_grad()
                 batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
 
@@ -192,6 +250,7 @@ class Pretrainer:
             losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
             self.save_losses(losses_path)
 
+
     def _validate(self, val_loader, criterion_denoise, criterion_rotate):
         """Perform validation and return average validation loss."""
         val_loss = 0.0
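
Usage note (not part of the patch): a minimal sketch, assuming the module path from the diff header and a Pretrainer constructor whose arguments are not shown here, of how the new checkpoint_dir parameter could be enabled and how a saved checkpoint could be restored afterwards. Only the train() keyword arguments and the checkpoint dictionary keys are taken from the diff; the constructor call, dataset paths, and checkpoint filename are illustrative placeholders.

    import torch

    from aiia.pretrain.pretrainer import Pretrainer  # module path taken from the diff header

    # Placeholder construction: the real constructor arguments are not shown in this diff.
    trainer = Pretrainer()

    # Passing checkpoint_dir enables the periodic (2-hour) and 22:00 checkpoints added above.
    trainer.train(
        ["data/shard_0.parquet", "data/shard_1.parquet"],  # hypothetical dataset paths
        output_path="AIIA",
        checkpoint_dir="checkpoints",
    )

    # Restoring state later: the dict keys match exactly what torch.save writes in the patch.
    ckpt = torch.load("checkpoints/checkpoint_epoch1_batch500.pt", map_location="cpu")  # hypothetical filename
    trainer.model.load_state_dict(ckpt["model_state_dict"])
    trainer.projection_head.load_state_dict(ckpt["projection_head_state_dict"])
    trainer.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    trainer.train_losses = ckpt["train_losses"]
    trainer.val_losses = ckpt["val_losses"]
    start_epoch = ckpt["epoch"]  # 1-based epoch that was in progress when the checkpoint was written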