Compare commits
4 Commits
c09212de6f
...
ff6f279728
Author | SHA1 | Date |
---|---|---|
|
ff6f279728 | |
|
3ae9c7d09e | |
|
23c6891984 | |
|
c33e9c9740 |
|
@ -10,7 +10,7 @@ include = '\.pyi?$'
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "aiia"
|
name = "aiia"
|
||||||
version = "0.3.1"
|
version = "0.3.2"
|
||||||
description = "AIIA Deep Learning Model Implementation"
|
description = "AIIA Deep Learning Model Implementation"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[metadata]
|
[metadata]
|
||||||
name = aiia
|
name = aiia
|
||||||
version = 0.3.1
|
version = 0.3.2
|
||||||
author = Falko Habel
|
author = Falko Habel
|
||||||
author_email = falko.habel@gmx.de
|
author_email = falko.habel@gmx.de
|
||||||
description = AIIA deep learning model implementation
|
description = AIIA deep learning model implementation
|
||||||
|
|
|
@ -4,4 +4,4 @@ from .data.DataLoader import DataLoader
|
||||||
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
from .pretrain.pretrainer import Pretrainer, ProjectionHead
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.3.1"
|
__version__ = "0.3.2"
|
||||||
|
|
|
@ -42,6 +42,8 @@ class Pretrainer:
|
||||||
)
|
)
|
||||||
self.train_losses = []
|
self.train_losses = []
|
||||||
self.val_losses = []
|
self.val_losses = []
|
||||||
|
self.checkpoint_dir = None # Initialize checkpoint_dir
|
||||||
|
self.current_epoch = 0 # Add current_epoch tracking
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def safe_collate(batch):
|
def safe_collate(batch):
|
||||||
|
@ -140,8 +142,7 @@ class Pretrainer:
|
||||||
return checkpoint_path
|
return checkpoint_path
|
||||||
|
|
||||||
def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
|
def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
|
||||||
"""
|
"""Check for checkpoints and load if available.
|
||||||
Check for checkpoints and load if available.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
checkpoint_dir (str): Directory where checkpoints are stored
|
checkpoint_dir (str): Directory where checkpoints are stored
|
||||||
|
@ -177,8 +178,7 @@ class Pretrainer:
|
||||||
return self._load_checkpoint_file(checkpoint_path)
|
return self._load_checkpoint_file(checkpoint_path)
|
||||||
|
|
||||||
def _load_checkpoint_file(self, checkpoint_path):
|
def _load_checkpoint_file(self, checkpoint_path):
|
||||||
"""
|
"""Load a specific checkpoint file.
|
||||||
Load a specific checkpoint file.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
checkpoint_path (str): Path to the checkpoint file
|
checkpoint_path (str): Path to the checkpoint file
|
||||||
|
@ -214,13 +214,13 @@ class Pretrainer:
|
||||||
print(f"Error loading checkpoint: {e}")
|
print(f"Error loading checkpoint: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
|
def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
|
||||||
num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
|
num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
|
||||||
"""Train the model using multiple specified datasets with checkpoint resumption support."""
|
"""Train the model using multiple specified datasets with checkpoint resumption support."""
|
||||||
if not dataset_paths:
|
if not dataset_paths:
|
||||||
raise ValueError("No dataset paths provided")
|
raise ValueError("No dataset paths provided")
|
||||||
|
|
||||||
|
self.checkpoint_dir = checkpoint_dir # Set checkpoint_dir class variable
|
||||||
self._initialize_checkpoint_variables()
|
self._initialize_checkpoint_variables()
|
||||||
start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
|
start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
|
||||||
|
|
||||||
|
@ -230,6 +230,7 @@ class Pretrainer:
|
||||||
criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
|
criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
|
||||||
|
|
||||||
for epoch in range(start_epoch, num_epochs):
|
for epoch in range(start_epoch, num_epochs):
|
||||||
|
self.current_epoch = epoch # Update current_epoch
|
||||||
print(f"\nEpoch {epoch+1}/{num_epochs}")
|
print(f"\nEpoch {epoch+1}/{num_epochs}")
|
||||||
print("-" * 20)
|
print("-" * 20)
|
||||||
total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader,
|
total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader,
|
||||||
|
@ -393,7 +394,6 @@ class Pretrainer:
|
||||||
print(f"Validation Loss: {avg_val_loss:.4f}")
|
print(f"Validation Loss: {avg_val_loss:.4f}")
|
||||||
return avg_val_loss
|
return avg_val_loss
|
||||||
|
|
||||||
|
|
||||||
def save_losses(self, csv_file):
|
def save_losses(self, csv_file):
|
||||||
"""Save training and validation losses to a CSV file."""
|
"""Save training and validation losses to a CSV file."""
|
||||||
data = list(zip(
|
data = list(zip(
|
||||||
|
|
Loading…
Reference in New Issue