in between safe

Falko Victor Habel 2025-04-14 22:00:50 +02:00
parent 18249a852a
commit 5457bca963
1 changed file with 69 additions and 10 deletions


@@ -1,6 +1,8 @@
 import torch
 from torch import nn
 import csv
+import datetime
+import time
 import pandas as pd
 from tqdm import tqdm
 from ..model.Model import AIIA
@@ -112,19 +114,22 @@ class Pretrainer:
         return batch_loss
 
-    def train(self, dataset_paths,output_path:str="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000):
-        """
-        Train the model using multiple specified datasets.
-
-        Args:
-            dataset_paths (list): List of paths to parquet datasets
-            num_epochs (int): Number of training epochs
-            batch_size (int): Batch size for training
-            sample_size (int): Number of samples to use from each dataset
-        """
+    def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
+              num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
+        """Train the model using multiple specified datasets."""
         if not dataset_paths:
             raise ValueError("No dataset paths provided")
 
+        # Checkpoint tracking variables
+        last_checkpoint_time = time.time()
+        checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
+        last_22_date = None
+        recent_checkpoints = []
+
+        # Create checkpoint directory if specified
+        if checkpoint_dir is not None:
+            os.makedirs(checkpoint_dir, exist_ok=True)
+
         # Read and merge all datasets
         dataframes = []
         for path in dataset_paths:
@@ -166,6 +171,59 @@ class Pretrainer:
                 if batch_data is None:
                     continue
 
+                # Check if we need to save a checkpoint
+                current_time = time.time()
+                current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))  # German time
+                today = current_dt.date()
+
+                # Regular 2-hour checkpoint
+                if checkpoint_dir and (current_time - last_checkpoint_time) >= checkpoint_interval:
+                    checkpoint_path = os.path.join(
+                        checkpoint_dir,
+                        f"checkpoint_epoch{epoch+1}_batch{batch_count}.pt"
+                    )
+                    torch.save({
+                        'epoch': epoch + 1,
+                        'batch': batch_count,
+                        'model_state_dict': self.model.state_dict(),
+                        'projection_head_state_dict': self.projection_head.state_dict(),
+                        'optimizer_state_dict': self.optimizer.state_dict(),
+                        'train_losses': self.train_losses,
+                        'val_losses': self.val_losses,
+                    }, checkpoint_path)
+
+                    # Track and maintain only 3 recent checkpoints
+                    recent_checkpoints.append(checkpoint_path)
+                    if len(recent_checkpoints) > 3:
+                        oldest = recent_checkpoints.pop(0)
+                        if os.path.exists(oldest):
+                            os.remove(oldest)
+
+                    last_checkpoint_time = current_time
+                    print(f"Checkpoint saved at {checkpoint_path}")
+
+                # Special 22:00 checkpoint
+                is_22_oclock = current_dt.hour == 22 and current_dt.minute == 0 and current_dt.second < 10
+
+                if checkpoint_dir and is_22_oclock and last_22_date != today:
+                    checkpoint_path = os.path.join(
+                        checkpoint_dir,
+                        f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
+                    )
+                    torch.save({
+                        'epoch': epoch + 1,
+                        'batch': batch_count,
+                        'model_state_dict': self.model.state_dict(),
+                        'projection_head_state_dict': self.projection_head.state_dict(),
+                        'optimizer_state_dict': self.optimizer.state_dict(),
+                        'train_losses': self.train_losses,
+                        'val_losses': self.val_losses,
+                    }, checkpoint_path)
+
+                    last_22_date = today
+                    print(f"22:00 Checkpoint saved at {checkpoint_path}")
+
+                # Process the batch
                 self.optimizer.zero_grad()
                 batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
@@ -192,6 +250,7 @@ class Pretrainer:
         losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
         self.save_losses(losses_path)
 
+
     def _validate(self, val_loader, criterion_denoise, criterion_rotate):
        """Perform validation and return average validation loss."""
        val_loss = 0.0
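
Below is a minimal usage sketch of the checkpoint_dir option added in this commit, followed by a restore step that is not part of the commit. The Pretrainer/AIIA construction and all file paths are illustrative assumptions; only the train() keyword arguments and the saved dictionary keys come from the diff above.

# Sketch only: constructor arguments and paths are hypothetical; the
# train(..., checkpoint_dir=...) call and the checkpoint dict keys match this commit.
import torch

pretrainer = Pretrainer(model=AIIA(config), config=config)  # assumed setup
pretrainer.train(
    ["data/shard_0.parquet", "data/shard_1.parquet"],
    output_path="AIIA",
    checkpoint_dir="checkpoints",  # enables the 2-hour and 22:00 saves
    num_epochs=3,
    batch_size=2,
)

# Resuming from one of the rolling checkpoints (filename pattern from the diff);
# assumes model, projection_head and optimizer are already initialized on the instance.
state = torch.load("checkpoints/checkpoint_epoch1_batch500.pt")
pretrainer.model.load_state_dict(state["model_state_dict"])
pretrainer.projection_head.load_state_dict(state["projection_head_state_dict"])
pretrainer.optimizer.load_state_dict(state["optimizer_state_dict"])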