diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py
index 94d412d..1c706f7 100644
--- a/src/aiia/pretrain/pretrainer.py
+++ b/src/aiia/pretrain/pretrainer.py
@@ -114,9 +114,55 @@ class Pretrainer:
 
         return batch_loss
 
+    def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
+        """Save a model checkpoint.
+
+        Args:
+            checkpoint_dir (str): Directory to save the checkpoint
+            epoch (int): Current epoch number
+            batch_count (int): Current batch count
+            checkpoint_name (str): Name for the checkpoint file
+
+        Returns:
+            str: Path to the saved checkpoint
+        """
+        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
+        checkpoint_data = {
+            'epoch': epoch + 1,
+            'batch': batch_count,
+            'model_state_dict': self.model.state_dict(),
+            'projection_head_state_dict': self.projection_head.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'train_losses': self.train_losses,
+            'val_losses': self.val_losses,
+        }
+        torch.save(checkpoint_data, checkpoint_path)
+        return checkpoint_path
+
     def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
               num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
-        """Train the model using multiple specified datasets."""
+        """Train the model using multiple specified datasets.
+
+        Args:
+            dataset_paths (List[str]): List of paths to parquet dataset files
+            output_path (str, optional): Path to save the trained model. Defaults to "AIIA".
+            column (str, optional): Column name containing image data. Defaults to "image_bytes".
+            num_epochs (int, optional): Number of training epochs. Defaults to 3.
+            batch_size (int, optional): Size of training batches. Defaults to 2.
+            sample_size (int, optional): Number of samples to use from each dataset. Defaults to 10000.
+            checkpoint_dir (str, optional): Directory to save checkpoints. If None, no checkpoints are saved.
+
+        Raises:
+            ValueError: If no dataset paths are provided or if no valid datasets could be loaded.
+
+        The function performs the following:
+        1. Loads and merges multiple parquet datasets
+        2. Trains the model using denoising and rotation tasks
+        3. Validates the model performance
+        4. Saves checkpoints at regular intervals (every 2 hours) and at 22:00
+        5. Maintains only the 3 most recent regular checkpoints
+        6. Saves the best model based on validation loss
+        """
         if not dataset_paths:
             raise ValueError("No dataset paths provided")
 
@@ -129,7 +175,6 @@ class Pretrainer:
         # Create checkpoint directory if specified
         if checkpoint_dir is not None:
             os.makedirs(checkpoint_dir, exist_ok=True)
-
         # Read and merge all datasets
         dataframes = []
         for path in dataset_paths:
@@ -175,22 +220,11 @@ class Pretrainer:
                 current_time = time.time()
                 current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
                 today = current_dt.date()
-
+
                 # Regular 2-hour checkpoint
                 if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval:
-                    checkpoint_path = os.path.join(
-                        checkpoint_dir,
-                        f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt"
-                    )
-                    torch.save({
-                        'epoch': epoch + 1,
-                        'batch': batch_count,
-                        'model_state_dict': self.model.state_dict(),
-                        'projection_head_state_dict': self.projection_head.state_dict(),
-                        'optimizer_state_dict': self.optimizer.state_dict(),
-                        'train_losses': self.train_losses,
-                        'val_losses': self.val_losses,
-                    }, checkpoint_path)
+                    checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt"
+                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
 
                     # Track and maintain only 3 recent checkpoints
                     recent_checkpoints.append(checkpoint_path)
@@ -206,23 +240,12 @@ class Pretrainer:
 
                 is_22_oclock = current_dt.hour == 22 and current_dt.minute == 0 and current_dt.second < 10
 
                 if checkpoint_dir and is_22_oclock and last_22_date != today:
-                    checkpoint_path = os.path.join(
-                        checkpoint_dir,
-                        f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
-                    )
-                    torch.save({
-                        'epoch': epoch + 1,
-                        'batch': batch_count,
-                        'model_state_dict': self.model.state_dict(),
-                        'projection_head_state_dict': self.projection_head.state_dict(),
-                        'optimizer_state_dict': self.optimizer.state_dict(),
-                        'train_losses': self.train_losses,
-                        'val_losses': self.val_losses,
-                    }, checkpoint_path)
-
+                    checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
+                    checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name)
                     last_22_date = today
                     print(f"22:00 Checkpoint saved at {checkpoint_path}")
 
+                # Process the batch
                 self.optimizer.zero_grad()
                 batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
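
Reviewer note (not part of the patch): since the checkpoint format is now centralized in _save_checkpoint, the matching restore path is straightforward. Below is a minimal sketch of how a saved checkpoint could be loaded to resume training, assuming a Pretrainer instance that exposes the model, projection_head, and optimizer attributes used in the diff above; load_checkpoint is a hypothetical helper, not something this patch adds.

    import torch

    def load_checkpoint(pretrainer, checkpoint_path):
        """Restore state written by Pretrainer._save_checkpoint (sketch)."""
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        pretrainer.model.load_state_dict(checkpoint['model_state_dict'])
        pretrainer.projection_head.load_state_dict(checkpoint['projection_head_state_dict'])
        pretrainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        pretrainer.train_losses = checkpoint['train_losses']
        pretrainer.val_losses = checkpoint['val_losses']
        # Note: 'epoch' was stored as epoch + 1, i.e. the 1-based epoch in progress.
        return checkpoint['epoch'], checkpoint['batch']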