diff --git a/pyproject.toml b/pyproject.toml index ff68c3d..8975d57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ include = '\.pyi?$' [project] name = "aiia" -version = "0.3.1" +version = "0.3.2" description = "AIIA Deep Learning Model Implementation" readme = "README.md" authors = [ diff --git a/setup.cfg b/setup.cfg index 598b087..65869f3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = aiia -version = 0.3.1 +version = 0.3.2 author = Falko Habel author_email = falko.habel@gmx.de description = AIIA deep learning model implementation diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py index c1bba6c..3a8a200 100644 --- a/src/aiia/__init__.py +++ b/src/aiia/__init__.py @@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader from .pretrain.pretrainer import Pretrainer, ProjectionHead -__version__ = "0.3.1" +__version__ = "0.3.2" diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index 6815a90..62c01c0 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -42,6 +42,8 @@ class Pretrainer: ) self.train_losses = [] self.val_losses = [] + self.checkpoint_dir = None # Initialize checkpoint_dir + self.current_epoch = 0 # Add current_epoch tracking @staticmethod def safe_collate(batch): @@ -140,8 +142,7 @@ class Pretrainer: return checkpoint_path def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None): - """ - Check for checkpoints and load if available. + """Check for checkpoints and load if available. Args: checkpoint_dir (str): Directory where checkpoints are stored @@ -177,8 +178,7 @@ class Pretrainer: return self._load_checkpoint_file(checkpoint_path) def _load_checkpoint_file(self, checkpoint_path): - """ - Load a specific checkpoint file. + """Load a specific checkpoint file. Args: checkpoint_path (str): Path to the checkpoint file @@ -214,13 +214,13 @@ class Pretrainer: print(f"Error loading checkpoint: {e}") return None - def train(self, dataset_paths, output_path="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): """Train the model using multiple specified datasets with checkpoint resumption support.""" if not dataset_paths: raise ValueError("No dataset paths provided") + self.checkpoint_dir = checkpoint_dir # Set checkpoint_dir class variable self._initialize_checkpoint_variables() start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir) @@ -230,6 +230,7 @@ class Pretrainer: criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions() for epoch in range(start_epoch, num_epochs): + self.current_epoch = epoch # Update current_epoch print(f"\nEpoch {epoch+1}/{num_epochs}") print("-" * 20) total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader, @@ -393,7 +394,6 @@ class Pretrainer: print(f"Validation Loss: {avg_val_loss:.4f}") return avg_val_loss - def save_losses(self, csv_file): """Save training and validation losses to a CSV file.""" data = list(zip(