diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index 1c706f7..6a840ef 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -139,43 +139,110 @@ class Pretrainer: torch.save(checkpoint_data, checkpoint_path) return checkpoint_path - def train(self, dataset_paths, output_path="AIIA", column="image_bytes", - num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): - """Train the model using multiple specified datasets. + def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None): + """ + Check for checkpoints and load if available. Args: - dataset_paths (List[str]): List of paths to parquet dataset files - output_path (str, optional): Path to save the trained model. Defaults to "AIIA". - column (str, optional): Column name containing image data. Defaults to "image_bytes". - num_epochs (int, optional): Number of training epochs. Defaults to 3. - batch_size (int, optional): Size of training batches. Defaults to 2. - sample_size (int, optional): Number of samples to use from each dataset. Defaults to 10000. - checkpoint_dir (str, optional): Directory to save checkpoints. If None, no checkpoints are saved. - - Raises: - ValueError: If no dataset paths are provided or if no valid datasets could be loaded. + checkpoint_dir (str): Directory where checkpoints are stored + specific_checkpoint (str, optional): Specific checkpoint file to load. If None, loads the most recent. - The function performs the following: - 1. Loads and merges multiple parquet datasets - 2. Trains the model using denoising and rotation tasks - 3. Validates the model performance - 4. Saves checkpoints at regular intervals (every 2 hours) and at 22:00 - 5. Maintains only the 3 most recent regular checkpoints - 6. Saves the best model based on validation loss + Returns: + tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise """ + # Create checkpoint directory if it doesn't exist + os.makedirs(checkpoint_dir, exist_ok=True) + + # If a specific checkpoint is requested + if specific_checkpoint: + checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint) + if os.path.exists(checkpoint_path): + return self._load_checkpoint_file(checkpoint_path) + else: + print(f"Specified checkpoint {specific_checkpoint} not found.") + return None + + # Find all checkpoint files + checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pt")] + + if not checkpoint_files: + print("No checkpoints found in directory.") + return None + + # Find the most recent checkpoint + checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True) + most_recent = checkpoint_files[0] + checkpoint_path = os.path.join(checkpoint_dir, most_recent) + + return self._load_checkpoint_file(checkpoint_path) + + def _load_checkpoint_file(self, checkpoint_path): + """ + Load a specific checkpoint file. + + Args: + checkpoint_path (str): Path to the checkpoint file + + Returns: + tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise + """ + try: + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + # Load model state + self.model.load_state_dict(checkpoint['model_state_dict']) + + # Load projection head state + self.projection_head.load_state_dict(checkpoint['projection_head_state_dict']) + + # Load optimizer state + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + # Load loss history + self.train_losses = checkpoint.get('train_losses', []) + self.val_losses = checkpoint.get('val_losses', []) + + loaded_epoch = checkpoint['epoch'] + loaded_batch = checkpoint['batch'] + + print(f"Checkpoint loaded from {checkpoint_path}") + print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}") + + return loaded_epoch, loaded_batch + + except Exception as e: + print(f"Error loading checkpoint: {e}") + return None + + + def train(self, dataset_paths, output_path="AIIA", column="image_bytes", + num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): + """Train the model using multiple specified datasets with checkpoint resumption support.""" if not dataset_paths: raise ValueError("No dataset paths provided") - # Checkpoint tracking variables + # Initialize checkpoint tracking variables last_checkpoint_time = time.time() checkpoint_interval = 2 * 60 * 60 # 2 hours in seconds last_22_date = None recent_checkpoints = [] - # Create checkpoint directory if specified + # Initialize resumption variables + start_epoch = 0 + start_batch = 0 + resume_training = False + + # Check for existing checkpoint and load if available if checkpoint_dir is not None: os.makedirs(checkpoint_dir, exist_ok=True) - # Read and merge all datasets + checkpoint_info = self.load_checkpoint(checkpoint_dir) + if checkpoint_info: + start_epoch, start_batch = checkpoint_info + resume_training = True + # Adjust epoch to be 0-indexed for the loop + start_epoch -= 1 + + # Load and merge datasets dataframes = [] for path in dataset_paths: try: @@ -198,11 +265,13 @@ class Pretrainer: collate_fn=self.safe_collate ) + # Initialize loss functions and tracking variables criterion_denoise = nn.MSELoss() criterion_rotate = nn.CrossEntropyLoss() best_val_loss = float('inf') - - for epoch in range(num_epochs): + + # Main training loop + for epoch in range(start_epoch, num_epochs): print(f"\nEpoch {epoch+1}/{num_epochs}") print("-" * 20) @@ -212,9 +281,21 @@ class Pretrainer: total_train_loss = 0.0 batch_count = 0 - for batch_data in tqdm(aiia_loader.train_loader): + # Convert data loader to enumerated list for batch tracking and resumption + train_batches = list(enumerate(aiia_loader.train_loader)) + + # Determine how many batches to skip if resuming from checkpoint + skip_batches = start_batch if (epoch == start_epoch and resume_training) else 0 + + # Process batches with proper resumption handling + for i, batch_data in tqdm(train_batches[skip_batches:], + initial=skip_batches, + total=len(train_batches)): if batch_data is None: continue + + # Use i+1 as the actual batch count (to match 1-indexed batch numbers in checkpoints) + current_batch = i + 1 # Check if we need to save a checkpoint current_time = time.time() @@ -223,8 +304,8 @@ class Pretrainer: # Regular 2-hour checkpoint if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval: - checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt" - checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name) + checkpoint_name = f"checkpoint_epoch{epoch+1}_batch{current_batch}.pt" + checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name) # Track and maintain only 3 recent checkpoints recent_checkpoints.append(checkpoint_path) @@ -236,16 +317,15 @@ class Pretrainer: last_checkpoint_time = current_time print(f"Checkpoint saved at {checkpoint_path}") - # Special 22:00 checkpoint - is_22_oclock = current_dt.hour == 22 and current_dt.minute == 0 and current_dt.second < 10 + # Special 22:00 checkpoint (considering it's currently 10:15 PM) + is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15 if checkpoint_dir and is_22_oclock and last_22_date != today: checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt" - checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, batch_count, checkpoint_name) + checkpoint_path = self._save_checkpoint(checkpoint_dir, epoch, current_batch, checkpoint_name) last_22_date = today print(f"22:00 Checkpoint saved at {checkpoint_path}") - # Process the batch self.optimizer.zero_grad() batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate) @@ -256,6 +336,12 @@ class Pretrainer: total_train_loss += batch_loss.item() batch_count += 1 + # Reset batch skipping after completing the resumed epoch + if resume_training and epoch == start_epoch: + resume_training = False + start_batch = 0 + + # Calculate and store training loss avg_train_loss = total_train_loss / max(batch_count, 1) self.train_losses.append(avg_train_loss) print(f"Training Loss: {avg_train_loss:.4f}") @@ -265,11 +351,13 @@ class Pretrainer: self.projection_head.eval() val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate) + # Save best model based on validation loss if val_loss < best_val_loss: best_val_loss = val_loss self.model.save(output_path) print("Best model saved!") + # Save training history losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') self.save_losses(losses_path)