Merge pull request 'feat/bugfix' (#42) from feat/bugfix into main
Run VectorLoader Script / Explore-Gitea-Actions (push) Successful in 20s Details
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Successful in 38s Details

Reviewed-on: #42
This commit is contained in:
Falko Victor Habel 2025-04-23 13:39:01 +00:00
commit ff6f279728
4 changed files with 9 additions and 9 deletions

View File

@ -10,7 +10,7 @@ include = '\.pyi?$'
[project] [project]
name = "aiia" name = "aiia"
version = "0.3.1" version = "0.3.2"
description = "AIIA Deep Learning Model Implementation" description = "AIIA Deep Learning Model Implementation"
readme = "README.md" readme = "README.md"
authors = [ authors = [

View File

@ -1,6 +1,6 @@
[metadata] [metadata]
name = aiia name = aiia
version = 0.3.1 version = 0.3.2
author = Falko Habel author = Falko Habel
author_email = falko.habel@gmx.de author_email = falko.habel@gmx.de
description = AIIA deep learning model implementation description = AIIA deep learning model implementation

View File

@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
from .pretrain.pretrainer import Pretrainer, ProjectionHead from .pretrain.pretrainer import Pretrainer, ProjectionHead
__version__ = "0.3.1" __version__ = "0.3.2"

View File

@ -42,6 +42,8 @@ class Pretrainer:
) )
self.train_losses = [] self.train_losses = []
self.val_losses = [] self.val_losses = []
self.checkpoint_dir = None # Initialize checkpoint_dir
self.current_epoch = 0 # Add current_epoch tracking
@staticmethod @staticmethod
def safe_collate(batch): def safe_collate(batch):
@ -140,8 +142,7 @@ class Pretrainer:
return checkpoint_path return checkpoint_path
def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None): def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
""" """Check for checkpoints and load if available.
Check for checkpoints and load if available.
Args: Args:
checkpoint_dir (str): Directory where checkpoints are stored checkpoint_dir (str): Directory where checkpoints are stored
@ -177,8 +178,7 @@ class Pretrainer:
return self._load_checkpoint_file(checkpoint_path) return self._load_checkpoint_file(checkpoint_path)
def _load_checkpoint_file(self, checkpoint_path): def _load_checkpoint_file(self, checkpoint_path):
""" """Load a specific checkpoint file.
Load a specific checkpoint file.
Args: Args:
checkpoint_path (str): Path to the checkpoint file checkpoint_path (str): Path to the checkpoint file
@ -214,13 +214,13 @@ class Pretrainer:
print(f"Error loading checkpoint: {e}") print(f"Error loading checkpoint: {e}")
return None return None
def train(self, dataset_paths, output_path="AIIA", column="image_bytes", def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
"""Train the model using multiple specified datasets with checkpoint resumption support.""" """Train the model using multiple specified datasets with checkpoint resumption support."""
if not dataset_paths: if not dataset_paths:
raise ValueError("No dataset paths provided") raise ValueError("No dataset paths provided")
self.checkpoint_dir = checkpoint_dir # Set checkpoint_dir class variable
self._initialize_checkpoint_variables() self._initialize_checkpoint_variables()
start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir) start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
@ -230,6 +230,7 @@ class Pretrainer:
criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions() criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
for epoch in range(start_epoch, num_epochs): for epoch in range(start_epoch, num_epochs):
self.current_epoch = epoch # Update current_epoch
print(f"\nEpoch {epoch+1}/{num_epochs}") print(f"\nEpoch {epoch+1}/{num_epochs}")
print("-" * 20) print("-" * 20)
total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader, total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader,
@ -393,7 +394,6 @@ class Pretrainer:
print(f"Validation Loss: {avg_val_loss:.4f}") print(f"Validation Loss: {avg_val_loss:.4f}")
return avg_val_loss return avg_val_loss
def save_losses(self, csv_file): def save_losses(self, csv_file):
"""Save training and validation losses to a CSV file.""" """Save training and validation losses to a CSV file."""
data = list(zip( data = list(zip(