From c702834cee777205010976e68167cda8a0e3012f Mon Sep 17 00:00:00 2001 From: Falko Victor Habel Date: Thu, 17 Apr 2025 10:51:29 +0000 Subject: [PATCH] manual mergerequest --- src/aiia/pretrain/pretrainer.py | 288 ++++++++++++++++++++++++++------ 1 file changed, 233 insertions(+), 55 deletions(-) diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index 93df9dd..f94af2c 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -1,9 +1,11 @@ import torch from torch import nn import csv +import datetime +import time import pandas as pd from tqdm import tqdm -from transformers import PreTrainedModel +from ..model.Model import AIIA from ..model.config import AIIAConfig from ..data.DataLoader import AIIADataLoader import os @@ -21,12 +23,12 @@ class ProjectionHead(nn.Module): return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task class Pretrainer: - def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None): + def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None): """ Initialize the pretrainer with a model. Args: - model (PreTrainedModel): The model instance to pretrain + model (AIIA): The model instance to pretrain learning_rate (float): Learning rate for optimization config (dict): Model configuration containing hidden_size """ @@ -112,20 +114,169 @@ class Pretrainer: return batch_loss - def train(self, dataset_paths,output_path:str="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000): - """ - Train the model using multiple specified datasets. + def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name): + """Save a model checkpoint. Args: - dataset_paths (list): List of paths to parquet datasets - num_epochs (int): Number of training epochs - batch_size (int): Batch size for training - sample_size (int): Number of samples to use from each dataset + checkpoint_dir (str): Directory to save the checkpoint + epoch (int): Current epoch number + batch_count (int): Current batch count + checkpoint_name (str): Name for the checkpoint file + + Returns: + str: Path to the saved checkpoint """ + checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + checkpoint_data = { + 'epoch': epoch + 1, + 'batch': batch_count, + 'model_state_dict': self.model.state_dict(), + 'projection_head_state_dict': self.projection_head.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'train_losses': self.train_losses, + 'val_losses': self.val_losses, + } + torch.save(checkpoint_data, checkpoint_path) + return checkpoint_path + + def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None): + """ + Check for checkpoints and load if available. + + Args: + checkpoint_dir (str): Directory where checkpoints are stored + specific_checkpoint (str, optional): Specific checkpoint file to load. If None, loads the most recent. + + Returns: + tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise + """ + # Create checkpoint directory if it doesn't exist + os.makedirs(checkpoint_dir, exist_ok=True) + + # If a specific checkpoint is requested + if specific_checkpoint: + checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint) + if os.path.exists(checkpoint_path): + return self._load_checkpoint_file(checkpoint_path) + else: + print(f"Specified checkpoint {specific_checkpoint} not found.") + return None + + # Find all checkpoint files + checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pt")] + + if not checkpoint_files: + print("No checkpoints found in directory.") + return None + + # Find the most recent checkpoint + checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True) + most_recent = checkpoint_files[0] + checkpoint_path = os.path.join(checkpoint_dir, most_recent) + + return self._load_checkpoint_file(checkpoint_path) + + def _load_checkpoint_file(self, checkpoint_path): + """ + Load a specific checkpoint file. + + Args: + checkpoint_path (str): Path to the checkpoint file + + Returns: + tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise + """ + try: + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + # Load model state + self.model.load_state_dict(checkpoint['model_state_dict']) + + # Load projection head state + self.projection_head.load_state_dict(checkpoint['projection_head_state_dict']) + + # Load optimizer state + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + # Load loss history + self.train_losses = checkpoint.get('train_losses', []) + self.val_losses = checkpoint.get('val_losses', []) + + loaded_epoch = checkpoint['epoch'] + loaded_batch = checkpoint['batch'] + + print(f"Checkpoint loaded from {checkpoint_path}") + print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}") + + return loaded_epoch, loaded_batch + + except Exception as e: + print(f"Error loading checkpoint: {e}") + return None + + + def train(self, dataset_paths, output_path="AIIA", column="image_bytes", + num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None): + """Train the model using multiple specified datasets with checkpoint resumption support.""" if not dataset_paths: raise ValueError("No dataset paths provided") - # Read and merge all datasets + self._initialize_checkpoint_variables() + start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir) + + dataframes = self._load_and_merge_datasets(dataset_paths, sample_size) + aiia_loader = self._initialize_data_loader(dataframes, column, batch_size) + + criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions() + + for epoch in range(start_epoch, num_epochs): + print(f"\nEpoch {epoch+1}/{num_epochs}") + print("-" * 20) + total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader, + start_batch if (epoch == start_epoch and resume_training) else 0, + criterion_denoise, + criterion_rotate) + + avg_train_loss = total_train_loss / max(batch_count, 1) + self.train_losses.append(avg_train_loss) + print(f"Training Loss: {avg_train_loss:.4f}") + + val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate) + + if val_loss < best_val_loss: + best_val_loss = val_loss + self.model.save(output_path) + print("Best model saved!") + + losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') + self.save_losses(losses_path) + + def _initialize_checkpoint_variables(self): + """Initialize checkpoint tracking variables.""" + self.last_checkpoint_time = time.time() + self.checkpoint_interval = 2 * 60 * 60 # 2 hours in seconds + self.last_22_date = None + self.recent_checkpoints = [] + + def _load_checkpoints(self, checkpoint_dir): + """Load checkpoints and return start epoch, batch, and resumption flag.""" + start_epoch = 0 + start_batch = 0 + resume_training = False + + if checkpoint_dir is not None: + os.makedirs(checkpoint_dir, exist_ok=True) + checkpoint_info = self.load_checkpoint(checkpoint_dir) + if checkpoint_info: + start_epoch, start_batch = checkpoint_info + resume_training = True + # Adjust epoch to be 0-indexed for the loop + start_epoch -= 1 + + return start_epoch, start_batch, resume_training + + def _load_and_merge_datasets(self, dataset_paths, sample_size): + """Load and merge datasets.""" dataframes = [] for path in dataset_paths: try: @@ -133,14 +284,15 @@ class Pretrainer: dataframes.append(df) except Exception as e: print(f"Error loading dataset {path}: {e}") - + if not dataframes: raise ValueError("No valid datasets could be loaded") - - merged_df = pd.concat(dataframes, ignore_index=True) - # Initialize data loader - aiia_loader = AIIADataLoader( + return pd.concat(dataframes, ignore_index=True) + + def _initialize_data_loader(self, merged_df, column, batch_size): + """Initialize the data loader.""" + return AIIADataLoader( merged_df, column=column, batch_size=batch_size, @@ -148,49 +300,75 @@ class Pretrainer: collate_fn=self.safe_collate ) + def _initialize_loss_functions(self): + """Initialize loss functions and tracking variables.""" criterion_denoise = nn.MSELoss() criterion_rotate = nn.CrossEntropyLoss() best_val_loss = float('inf') + return criterion_denoise, criterion_rotate, best_val_loss - for epoch in range(num_epochs): - print(f"\nEpoch {epoch+1}/{num_epochs}") - print("-" * 20) - - # Training phase - self.model.train() - self.projection_head.train() - total_train_loss = 0.0 - batch_count = 0 - - for batch_data in tqdm(aiia_loader.train_loader): - if batch_data is None: - continue - - self.optimizer.zero_grad() - batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate) - - if batch_loss > 0: - batch_loss.backward() - self.optimizer.step() - total_train_loss += batch_loss.item() - batch_count += 1 - - avg_train_loss = total_train_loss / max(batch_count, 1) - self.train_losses.append(avg_train_loss) - print(f"Training Loss: {avg_train_loss:.4f}") + def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate): + """Handle the training phase.""" + self.model.train() + self.projection_head.train() + total_train_loss = 0.0 + batch_count = 0 - # Validation phase - self.model.eval() - self.projection_head.eval() - val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate) - - if val_loss < best_val_loss: - best_val_loss = val_loss - self.model.save_pretrained(output_path) - print("Best model save_pretrainedd!") + train_batches = list(enumerate(train_loader)) + for i, batch_data in tqdm(train_batches[skip_batches:], + initial=skip_batches, + total=len(train_batches)): + if batch_data is None: + continue - losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') - self.save_pretrained_losses(losses_path) + current_batch = i + 1 + self._handle_checkpoints(current_batch) + + self.optimizer.zero_grad() + batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate) + + if batch_loss > 0: + batch_loss.backward() + self.optimizer.step() + total_train_loss += batch_loss.item() + batch_count += 1 + + return total_train_loss, batch_count + + def _handle_checkpoints(self, current_batch): + """Handle checkpoint saving logic.""" + current_time = time.time() + current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))) # German time + today = current_dt.date() + + if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval: + checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt" + checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name) + + # Track and maintain only 3 recent checkpoints + self.recent_checkpoints.append(checkpoint_path) + if len(self.recent_checkpoints) > 3: + oldest = self.recent_checkpoints.pop(0) + if os.path.exists(oldest): + os.remove(oldest) + + self.last_checkpoint_time = current_time + print(f"Checkpoint saved at {checkpoint_path}") + + # Special 22:00 checkpoint (considering it's currently 10:15 PM) + is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15 + + if self.checkpoint_dir and is_22_oclock and self.last_22_date != today: + checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt" + checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name) + self.last_22_date = today + print(f"22:00 Checkpoint saved at {checkpoint_path}") + + def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate): + """Handle the validation phase.""" + self.model.eval() + self.projection_head.eval() + return self._validate(val_loader, criterion_denoise, criterion_rotate) def _validate(self, val_loader, criterion_denoise, criterion_rotate): """Perform validation and return average validation loss.""" @@ -216,8 +394,8 @@ class Pretrainer: return avg_val_loss - def save_pretrained_losses(self, csv_file): - """save_pretrained training and validation losses to a CSV file.""" + def save_losses(self, csv_file): + """Save training and validation losses to a CSV file.""" data = list(zip( range(1, len(self.train_losses) + 1), self.train_losses,