develop #41

Merged
Fabel merged 27 commits from develop into main 2025-04-17 17:08:57 +00:00
1 changed file with 233 additions and 55 deletions
Showing only changes of commit c702834cee - Show all commits

View File

@ -1,9 +1,11 @@
import csv
import datetime
import itertools
import os
import time

import pandas as pd
import torch
from torch import nn
from tqdm import tqdm

from ..data.DataLoader import AIIADataLoader
from ..model.Model import AIIA
from ..model.config import AIIAConfig
@ -21,12 +23,12 @@ class ProjectionHead(nn.Module):
return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task return self.conv_rotate(x).mean(dim=(2, 3)) # Global average pooling for rotation task
class Pretrainer: class Pretrainer:
def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig=None): def __init__(self, model: AIIA, learning_rate=1e-4, config: AIIAConfig=None):
""" """
Initialize the pretrainer with a model. Initialize the pretrainer with a model.
Args: Args:
model (PreTrainedModel): The model instance to pretrain model (AIIA): The model instance to pretrain
learning_rate (float): Learning rate for optimization learning_rate (float): Learning rate for optimization
config (dict): Model configuration containing hidden_size config (dict): Model configuration containing hidden_size
""" """
@ -112,20 +114,169 @@ class Pretrainer:
return batch_loss return batch_loss
def train(self, dataset_paths,output_path:str="AIIA", column="image_bytes", num_epochs=3, batch_size=2, sample_size=10000): def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
""" """Save a model checkpoint.
Train the model using multiple specified datasets.
Args: Args:
dataset_paths (list): List of paths to parquet datasets checkpoint_dir (str): Directory to save the checkpoint
num_epochs (int): Number of training epochs epoch (int): Current epoch number
batch_size (int): Batch size for training batch_count (int): Current batch count
sample_size (int): Number of samples to use from each dataset checkpoint_name (str): Name for the checkpoint file
Returns:
str: Path to the saved checkpoint
""" """
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
checkpoint_data = {
'epoch': epoch + 1,
'batch': batch_count,
'model_state_dict': self.model.state_dict(),
'projection_head_state_dict': self.projection_head.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'train_losses': self.train_losses,
'val_losses': self.val_losses,
}
torch.save(checkpoint_data, checkpoint_path)
return checkpoint_path
def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
"""
Check for checkpoints and load if available.
Args:
checkpoint_dir (str): Directory where checkpoints are stored
specific_checkpoint (str, optional): Specific checkpoint file to load. If None, loads the most recent.
Returns:
tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise
"""
# Create checkpoint directory if it doesn't exist
os.makedirs(checkpoint_dir, exist_ok=True)
# If a specific checkpoint is requested
if specific_checkpoint:
checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint)
if os.path.exists(checkpoint_path):
return self._load_checkpoint_file(checkpoint_path)
else:
print(f"Specified checkpoint {specific_checkpoint} not found.")
return None
# Find all checkpoint files
checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith("checkpoint_") and f.endswith(".pt")]
if not checkpoint_files:
print("No checkpoints found in directory.")
return None
# Find the most recent checkpoint
checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
most_recent = checkpoint_files[0]
checkpoint_path = os.path.join(checkpoint_dir, most_recent)
return self._load_checkpoint_file(checkpoint_path)
def _load_checkpoint_file(self, checkpoint_path):
"""
Load a specific checkpoint file.
Args:
checkpoint_path (str): Path to the checkpoint file
Returns:
tuple: (loaded_epoch, loaded_batch) if checkpoint was loaded, None otherwise
"""
try:
checkpoint = torch.load(checkpoint_path, map_location=self.device)
# Load model state
self.model.load_state_dict(checkpoint['model_state_dict'])
# Load projection head state
self.projection_head.load_state_dict(checkpoint['projection_head_state_dict'])
# Load optimizer state
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Load loss history
self.train_losses = checkpoint.get('train_losses', [])
self.val_losses = checkpoint.get('val_losses', [])
loaded_epoch = checkpoint['epoch']
loaded_batch = checkpoint['batch']
print(f"Checkpoint loaded from {checkpoint_path}")
print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}")
return loaded_epoch, loaded_batch
except Exception as e:
print(f"Error loading checkpoint: {e}")
return None
def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
"""Train the model using multiple specified datasets with checkpoint resumption support."""
if not dataset_paths: if not dataset_paths:
raise ValueError("No dataset paths provided") raise ValueError("No dataset paths provided")
# Read and merge all datasets self._initialize_checkpoint_variables()
start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)
dataframes = self._load_and_merge_datasets(dataset_paths, sample_size)
aiia_loader = self._initialize_data_loader(dataframes, column, batch_size)
criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()
for epoch in range(start_epoch, num_epochs):
print(f"\nEpoch {epoch+1}/{num_epochs}")
print("-" * 20)
total_train_loss, batch_count = self._training_phase(aiia_loader.train_loader,
start_batch if (epoch == start_epoch and resume_training) else 0,
criterion_denoise,
criterion_rotate)
avg_train_loss = total_train_loss / max(batch_count, 1)
self.train_losses.append(avg_train_loss)
print(f"Training Loss: {avg_train_loss:.4f}")
val_loss = self._validation_phase(aiia_loader.val_loader, criterion_denoise, criterion_rotate)
if val_loss < best_val_loss:
best_val_loss = val_loss
self.model.save(output_path)
print("Best model saved!")
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
self.save_losses(losses_path)
def _initialize_checkpoint_variables(self):
"""Initialize checkpoint tracking variables."""
self.last_checkpoint_time = time.time()
self.checkpoint_interval = 2 * 60 * 60 # 2 hours in seconds
self.last_22_date = None
self.recent_checkpoints = []
def _load_checkpoints(self, checkpoint_dir):
"""Load checkpoints and return start epoch, batch, and resumption flag."""
start_epoch = 0
start_batch = 0
resume_training = False
if checkpoint_dir is not None:
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_info = self.load_checkpoint(checkpoint_dir)
if checkpoint_info:
start_epoch, start_batch = checkpoint_info
resume_training = True
# Adjust epoch to be 0-indexed for the loop
start_epoch -= 1
return start_epoch, start_batch, resume_training
def _load_and_merge_datasets(self, dataset_paths, sample_size):
"""Load and merge datasets."""
dataframes = [] dataframes = []
for path in dataset_paths: for path in dataset_paths:
try: try:
@ -137,10 +288,11 @@ class Pretrainer:
if not dataframes: if not dataframes:
raise ValueError("No valid datasets could be loaded") raise ValueError("No valid datasets could be loaded")
merged_df = pd.concat(dataframes, ignore_index=True) return pd.concat(dataframes, ignore_index=True)
# Initialize data loader def _initialize_data_loader(self, merged_df, column, batch_size):
aiia_loader = AIIADataLoader( """Initialize the data loader."""
return AIIADataLoader(
merged_df, merged_df,
column=column, column=column,
batch_size=batch_size, batch_size=batch_size,
@ -148,49 +300,75 @@ class Pretrainer:
collate_fn=self.safe_collate collate_fn=self.safe_collate
) )
def _initialize_loss_functions(self):
"""Initialize loss functions and tracking variables."""
criterion_denoise = nn.MSELoss() criterion_denoise = nn.MSELoss()
criterion_rotate = nn.CrossEntropyLoss() criterion_rotate = nn.CrossEntropyLoss()
best_val_loss = float('inf') best_val_loss = float('inf')
return criterion_denoise, criterion_rotate, best_val_loss
for epoch in range(num_epochs): def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
print(f"\nEpoch {epoch+1}/{num_epochs}") """Handle the training phase."""
print("-" * 20) self.model.train()
self.projection_head.train()
total_train_loss = 0.0
batch_count = 0
# Training phase train_batches = list(enumerate(train_loader))
self.model.train() for i, batch_data in tqdm(train_batches[skip_batches:],
self.projection_head.train() initial=skip_batches,
total_train_loss = 0.0 total=len(train_batches)):
batch_count = 0 if batch_data is None:
continue
for batch_data in tqdm(aiia_loader.train_loader): current_batch = i + 1
if batch_data is None: self._handle_checkpoints(current_batch)
continue
self.optimizer.zero_grad() self.optimizer.zero_grad()
batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate) batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
if batch_loss > 0: if batch_loss > 0:
batch_loss.backward() batch_loss.backward()
self.optimizer.step() self.optimizer.step()
total_train_loss += batch_loss.item() total_train_loss += batch_loss.item()
batch_count += 1 batch_count += 1
avg_train_loss = total_train_loss / max(batch_count, 1) return total_train_loss, batch_count
self.train_losses.append(avg_train_loss)
print(f"Training Loss: {avg_train_loss:.4f}")
# Validation phase def _handle_checkpoints(self, current_batch):
self.model.eval() """Handle checkpoint saving logic."""
self.projection_head.eval() current_time = time.time()
val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate) current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2))) # German time
today = current_dt.date()
if val_loss < best_val_loss: if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
best_val_loss = val_loss checkpoint_name = f"checkpoint_epoch{self.current_epoch+1}_batch{current_batch}.pt"
self.model.save_pretrained(output_path) checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
print("Best model save_pretrainedd!")
losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv') # Track and maintain only 3 recent checkpoints
self.save_pretrained_losses(losses_path) self.recent_checkpoints.append(checkpoint_path)
if len(self.recent_checkpoints) > 3:
oldest = self.recent_checkpoints.pop(0)
if os.path.exists(oldest):
os.remove(oldest)
self.last_checkpoint_time = current_time
print(f"Checkpoint saved at {checkpoint_path}")
# Special 22:00 checkpoint (considering it's currently 10:15 PM)
is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch, current_batch, checkpoint_name)
self.last_22_date = today
print(f"22:00 Checkpoint saved at {checkpoint_path}")
def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
"""Handle the validation phase."""
self.model.eval()
self.projection_head.eval()
return self._validate(val_loader, criterion_denoise, criterion_rotate)
def _validate(self, val_loader, criterion_denoise, criterion_rotate): def _validate(self, val_loader, criterion_denoise, criterion_rotate):
"""Perform validation and return average validation loss.""" """Perform validation and return average validation loss."""
@ -216,8 +394,8 @@ class Pretrainer:
return avg_val_loss return avg_val_loss
def save_pretrained_losses(self, csv_file): def save_losses(self, csv_file):
"""save_pretrained training and validation losses to a CSV file.""" """Save training and validation losses to a CSV file."""
data = list(zip( data = list(zip(
range(1, len(self.train_losses) + 1), range(1, len(self.train_losses) + 1),
self.train_losses, self.train_losses,