From c33e9c97406b705c8dfa172bba7d608b0977aa00 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Wed, 23 Apr 2025 15:23:01 +0200 Subject: [PATCH 1/3] hotfix for pretrainer --- src/aiia/pretrain/pretrainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index 6815a90..62c01c0 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -42,6 +42,8 @@ class Pretrainer: ) self.train_losses = [] self.val_losses = [] + self.checkpoint_dir = None # Initialize checkpoint_dir + self.current_epoch = 0 # Add current_epoch tracking @staticmethod def safe_collate(batch): @@ -140,8 +142,7 @@ class Pretrainer: return checkpoint_path def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None): - """ - Check for checkpoints and load if available. + """Check for checkpoints and load if available. Args: checkpoint_dir (str): Directory where checkpoints are stored @@ -177,8 +178,7 @@ class Pretrainer: return self._load_checkpoint_file(checkpoint_path) def _load_checkpoint_file(self, checkpoint_path): - """ - Load a specific checkpoint file. + """Load a specific checkpoint file. Args: checkpoint_path (str): Path to the checkpoint file @@ -214,13 +214,13 @@ class Pretrainer: print(f"Error loading checkpoint: {e}") return None - def train(self, dataset_paths, output_path="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): """Train the model using multiple specified datasets with checkpoint resumption support.""" if not dataset_paths: raise ValueError("No dataset paths provided") + self.checkpoint_dir = checkpoint_dir # Set checkpoint_dir class variable self._initialize_checkpoint_variables() start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir) @@ -230,6 +230,7 @@ class Pretrainer: criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions() for epoch in range(start_epoch, num_epochs): + self.current_epoch = epoch # Update current_epoch print(f"\nEpoch {epoch+1}/{num_epochs}") print("-" * 20) total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader, @@ -393,7 +394,6 @@ class Pretrainer: print(f"Validation Loss: {avg_val_loss:.4f}") return avg_val_loss - def save_losses(self, csv_file): """Save training and validation losses to a CSV file.""" data = list(zip( From 23c6891984d7837b7e62eed1df1dea70dd7b2a38 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Wed, 23 Apr 2025 15:23:20 +0200 Subject: [PATCH 2/3] updated version number --- pyproject.toml | 2 +- setup.cfg | 2 +- src/aiia/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ff68c3d..8d6469c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ include = '\.pyi?$' [project] name = "aiia" -version = "0.3.1" +version = "0.3.11" description = "AIIA Deep Learning Model Implementation" readme = "README.md" authors = [ diff --git a/setup.cfg b/setup.cfg index 598b087..13b41cf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = aiia -version = 0.3.1 +version = 0.3.11 author = Falko Habel author_email = falko.habel@gmx.de description = AIIA deep learning model implementation diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py index c1bba6c..07037e6 100644 --- a/src/aiia/__init__.py +++ b/src/aiia/__init__.py @@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader from .pretrain.pretrainer import Pretrainer, ProjectionHead -__version__ = "0.3.1" +__version__ = "0.3.11" From 3ae9c7d09e04c59fea7586ed6ce25a9832ca0182 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Wed, 23 Apr 2025 15:34:20 +0200 Subject: [PATCH 3/3] fix version number --- pyproject.toml | 2 +- setup.cfg | 2 +- src/aiia/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8d6469c..8975d57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ include = '\.pyi?$' [project] name = "aiia" -version = "0.3.11" +version = "0.3.2" description = "AIIA Deep Learning Model Implementation" readme = "README.md" authors = [ diff --git a/setup.cfg b/setup.cfg index 13b41cf..65869f3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = aiia -version = 0.3.11 +version = 0.3.2 author = Falko Habel author_email = falko.habel@gmx.de description = AIIA deep learning model implementation diff --git a/src/aiia/__init__.py b/src/aiia/__init__.py index 07037e6..3a8a200 100644 --- a/src/aiia/__init__.py +++ b/src/aiia/__init__.py @@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader from .pretrain.pretrainer import Pretrainer, ProjectionHead -__version__ = "0.3.11" +__version__ = "0.3.2"