From c33e9c97406b705c8dfa172bba7d608b0977aa00 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Wed, 23 Apr 2025 15:23:01 +0200 Subject: [PATCH] hotfix for pretrainer --- src/aiia/pretrain/pretrainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index 6815a90..62c01c0 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -42,6 +42,8 @@ class Pretrainer: ) self.train_losses = [] self.val_losses = [] + self.checkpoint_dir = None # Initialize checkpoint_dir + self.current_epoch = 0 # Add current_epoch tracking @staticmethod def safe_collate(batch): @@ -140,8 +142,7 @@ class Pretrainer: return checkpoint_path def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None): - """ - Check for checkpoints and load if available. + """Check for checkpoints and load if available. Args: checkpoint_dir (str): Directory where checkpoints are stored @@ -177,8 +178,7 @@ class Pretrainer: return self._load_checkpoint_file(checkpoint_path) def _load_checkpoint_file(self, checkpoint_path): - """ - Load a specific checkpoint file. + """Load a specific checkpoint file. Args: checkpoint_path (str): Path to the checkpoint file @@ -214,13 +214,13 @@ class Pretrainer: print(f"Error loading checkpoint: {e}") return None - def train(self, dataset_paths, output_path="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): """Train the model using multiple specified datasets with checkpoint resumption support.""" if not dataset_paths: raise ValueError("No dataset paths provided") + self.checkpoint_dir = checkpoint_dir # Set checkpoint_dir class variable self._initialize_checkpoint_variables() start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir) @@ -230,6 +230,7 @@ class Pretrainer: criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions() for epoch in range(start_epoch, num_epochs): + self.current_epoch = epoch # Update current_epoch print(f"\nEpoch {epoch+1}/{num_epochs}") print("-" * 20) total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader, @@ -393,7 +394,6 @@ class Pretrainer: print(f"Validation Loss: {avg_val_loss:.4f}") return avg_val_loss - def save_losses(self, csv_file): """Save training and validation losses to a CSV file.""" data = list(zip(