import csv
import datetime
import os
import time

import pandas as pd
import torch
from torch import nn
from tqdm import tqdm
from transformers import PreTrainedModel

from ..model.config import AIIAConfig
from ..data.DataLoader import AIIADataLoader


class ProjectionHead(nn.Module):
    """Task-specific heads that map backbone features to pretraining targets."""

    def __init__(self, hidden_size):
        super().__init__()
        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)  # 4 classes: 0, 90, 180, 270 degrees

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        # Global average pooling collapses the rotation feature map into class logits.
        return self.conv_rotate(x).mean(dim=(2, 3))
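
# Shape example (illustrative, assuming the usual (B, C, H, W) feature layout):
# with hidden_size=512 and features x of shape (B, 512, H, W),
# ProjectionHead(512)(x, task='denoise') returns a (B, 3, H, W) image prediction,
# while task='rotate' returns (B, 4) rotation logits.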

class Pretrainer:
    def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig = None):
        """
        Initialize the pretrainer with a model.

        Args:
            model (PreTrainedModel): The model instance to pretrain
            learning_rate (float): Learning rate for optimization
            config (AIIAConfig): Model configuration containing hidden_size
        """
        if config is None:
            raise ValueError("A config with a hidden_size attribute is required")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        self.projection_head = ProjectionHead(config.hidden_size).to(self.device)
        self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.projection_head.parameters()),
            lr=learning_rate
        )
        self.train_losses = []
        self.val_losses = []
        self.checkpoint_dir = None  # Set when train() is called
        self.current_epoch = 0  # Tracked for checkpoint naming

    @staticmethod
    def safe_collate(batch):
        """Safely collate batch data, handling both denoise and rotate tasks."""
        denoise_batch = []
        rotate_batch = []

        for sample in batch:
            try:
                noisy_img, target, task = sample
                if task == 'denoise':
                    denoise_batch.append({'image': noisy_img, 'target': target, 'task': task})
                else:  # rotate task
                    rotate_batch.append({'image': noisy_img, 'target': target, 'task': task})
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue

        if not denoise_batch and not rotate_batch:
            return None

        batch_data = {'denoise': None, 'rotate': None}
        if denoise_batch:
            images = torch.stack([x['image'] for x in denoise_batch])
            targets = torch.stack([x['target'] for x in denoise_batch])
            batch_data['denoise'] = (images, targets)
        if rotate_batch:
            images = torch.stack([x['image'] for x in rotate_batch])
            targets = torch.stack([x['target'] for x in rotate_batch])
            batch_data['rotate'] = (images, targets)
        return batch_data

    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate, training=True):
        """Process a single batch and return the combined task loss.

        The `training` flag is informational only; the gradient context
        (e.g. torch.no_grad()) is managed by the caller.
        """
        batch_loss = 0

        if batch_data['denoise'] is not None:
            noisy_imgs, targets = batch_data['denoise']
            noisy_imgs = noisy_imgs.to(self.device)
            targets = targets.to(self.device)

            features = self.model(noisy_imgs)
            outputs = self.projection_head(features, task='denoise')
            batch_loss += criterion_denoise(outputs, targets)

        if batch_data['rotate'] is not None:
            imgs, targets = batch_data['rotate']
            imgs = imgs.to(self.device)
            targets = targets.long().to(self.device)

            features = self.model(imgs)
            outputs = self.projection_head(features, task='rotate')
            batch_loss += criterion_rotate(outputs, targets)

        return batch_loss

    def _save_checkpoint(self, checkpoint_dir, epoch, batch_count, checkpoint_name):
        """Save a model checkpoint.

        Args:
            checkpoint_dir (str): Directory to save the checkpoint
            epoch (int): Current epoch number
            batch_count (int): Current batch count
            checkpoint_name (str): Name for the checkpoint file

        Returns:
            str: Path to the saved checkpoint
        """
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        checkpoint_data = {
            'epoch': epoch + 1,
            'batch': batch_count,
            'model_state_dict': self.model.state_dict(),
            'projection_head_state_dict': self.projection_head.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
        }
        torch.save(checkpoint_data, checkpoint_path)
        return checkpoint_path
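
    # Note: 'epoch' is stored 1-indexed (epoch + 1), so a loaded value of N
    # means "epoch N was in progress"; _load_checkpoints() converts it back to
    # the 0-indexed loop counter when resuming.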
    def load_checkpoint(self, checkpoint_dir, specific_checkpoint=None):
        """Check for checkpoints and load if available.

        Args:
            checkpoint_dir (str): Directory where checkpoints are stored
            specific_checkpoint (str, optional): Specific checkpoint file to load.
                If None, loads the most recent.

        Returns:
            tuple: (loaded_epoch, loaded_batch) if a checkpoint was loaded, None otherwise
        """
        # Create the checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        # If a specific checkpoint is requested, load exactly that file
        if specific_checkpoint:
            checkpoint_path = os.path.join(checkpoint_dir, specific_checkpoint)
            if os.path.exists(checkpoint_path):
                return self._load_checkpoint_file(checkpoint_path)
            print(f"Specified checkpoint {specific_checkpoint} not found.")
            return None

        # Find all checkpoint files
        checkpoint_files = [f for f in os.listdir(checkpoint_dir)
                            if f.startswith("checkpoint_") and f.endswith(".pt")]
        if not checkpoint_files:
            print("No checkpoints found in directory.")
            return None

        # Pick the most recent checkpoint by modification time
        checkpoint_files.sort(key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)),
                              reverse=True)
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_files[0])
        return self._load_checkpoint_file(checkpoint_path)

    def _load_checkpoint_file(self, checkpoint_path):
        """Load a specific checkpoint file.

        Args:
            checkpoint_path (str): Path to the checkpoint file

        Returns:
            tuple: (loaded_epoch, loaded_batch) if the checkpoint was loaded, None otherwise
        """
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            # Restore model, projection head, and optimizer state
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.projection_head.load_state_dict(checkpoint['projection_head_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            # Restore loss history
            self.train_losses = checkpoint.get('train_losses', [])
            self.val_losses = checkpoint.get('val_losses', [])

            loaded_epoch = checkpoint['epoch']
            loaded_batch = checkpoint['batch']

            print(f"Checkpoint loaded from {checkpoint_path}")
            print(f"Resuming from epoch {loaded_epoch}, batch {loaded_batch}")
            return loaded_epoch, loaded_batch
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            return None

    def train(self, dataset_paths, output_path="AIIA", column="image_bytes",
              num_epochs=3, batch_size=2, sample_size=10000, checkpoint_dir=None):
        """Train the model using multiple specified datasets, with checkpoint resumption support."""
        if not dataset_paths:
            raise ValueError("No dataset paths provided")

        self.checkpoint_dir = checkpoint_dir
        self._initialize_checkpoint_variables()
        start_epoch, start_batch, resume_training = self._load_checkpoints(checkpoint_dir)

        dataframes = self._load_and_merge_datasets(dataset_paths, sample_size)
        aiia_loader = self._initialize_data_loader(dataframes, column, batch_size)
        criterion_denoise, criterion_rotate, best_val_loss = self._initialize_loss_functions()

        for epoch in range(start_epoch, num_epochs):
            self.current_epoch = epoch
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print("-" * 20)

            # Skip already-processed batches only on the epoch we are resuming into
            skip_batches = start_batch if (epoch == start_epoch and resume_training) else 0
            total_train_loss, batch_count = self._training_phase(
                aiia_loader.train_loader, skip_batches, criterion_denoise, criterion_rotate)

            avg_train_loss = total_train_loss / max(batch_count, 1)
            self.train_losses.append(avg_train_loss)
            print(f"Training Loss: {avg_train_loss:.4f}")

            val_loss = self._validation_phase(aiia_loader.val_loader,
                                              criterion_denoise, criterion_rotate)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.model.save_pretrained(output_path)
                print("Best model saved!")

        losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
        self.save_losses(losses_path)

    def _initialize_checkpoint_variables(self):
        """Initialize checkpoint tracking variables."""
        self.last_checkpoint_time = time.time()
        self.checkpoint_interval = 2 * 60 * 60  # 2 hours in seconds
        self.last_22_date = None
        self.recent_checkpoints = []

    def _load_checkpoints(self, checkpoint_dir):
        """Load checkpoints and return the start epoch, start batch, and resumption flag."""
        start_epoch = 0
        start_batch = 0
        resume_training = False

        if checkpoint_dir is not None:
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_info = self.load_checkpoint(checkpoint_dir)
            if checkpoint_info:
                start_epoch, start_batch = checkpoint_info
                resume_training = True
                # Checkpoints store a 1-indexed epoch; convert back to the
                # 0-indexed value used by the training loop
                start_epoch -= 1
        return start_epoch, start_batch, resume_training

    def _load_and_merge_datasets(self, dataset_paths, sample_size):
        """Load each parquet dataset (capped at sample_size rows) and merge them."""
        dataframes = []
        for path in dataset_paths:
            try:
                df = pd.read_parquet(path).head(sample_size)
                dataframes.append(df)
            except Exception as e:
                print(f"Error loading dataset {path}: {e}")

        if not dataframes:
            raise ValueError("No valid datasets could be loaded")

        return pd.concat(dataframes, ignore_index=True)
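
    # Note: sample_size caps rows per dataset, not overall, so the merged frame
    # can hold up to len(dataset_paths) * sample_size rows.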
    def _initialize_data_loader(self, merged_df, column, batch_size):
        """Initialize the data loader."""
        return AIIADataLoader(
            merged_df,
            column=column,
            batch_size=batch_size,
            pretraining=True,
            collate_fn=self.safe_collate
        )

    def _initialize_loss_functions(self):
        """Initialize loss functions and the best-validation-loss tracker."""
        criterion_denoise = nn.MSELoss()
        criterion_rotate = nn.CrossEntropyLoss()
        best_val_loss = float('inf')
        return criterion_denoise, criterion_rotate, best_val_loss

    def _training_phase(self, train_loader, skip_batches, criterion_denoise, criterion_rotate):
        """Handle the training phase."""
        self.model.train()
        self.projection_head.train()
        total_train_loss = 0.0
        batch_count = 0

        # Materialize the loader up front so that resumed runs can slice past
        # already-processed batches; this holds the whole epoch in memory.
        train_batches = list(enumerate(train_loader))
        for i, batch_data in tqdm(train_batches[skip_batches:],
                                  initial=skip_batches,
                                  total=len(train_batches)):
            if batch_data is None:
                continue

            current_batch = i + 1
            self._handle_checkpoints(current_batch)

            self.optimizer.zero_grad()
            batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)
            if batch_loss > 0:
                batch_loss.backward()
                self.optimizer.step()
                total_train_loss += batch_loss.item()
                batch_count += 1

        return total_train_loss, batch_count

    def _handle_checkpoints(self, current_batch):
        """Handle periodic and daily 22:00 checkpoint saving."""
        current_time = time.time()
        # Fixed UTC+2 offset (German summer time); this does not track DST changes.
        current_dt = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=2)))
        today = current_dt.date()

        # Interval-based checkpoint every self.checkpoint_interval seconds
        if self.checkpoint_dir and (current_time - self.last_checkpoint_time) >= self.checkpoint_interval:
            checkpoint_name = f"checkpoint_epoch{self.current_epoch + 1}_batch{current_batch}.pt"
            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch,
                                                    current_batch, checkpoint_name)

            # Track and keep only the three most recent interval checkpoints
            self.recent_checkpoints.append(checkpoint_path)
            if len(self.recent_checkpoints) > 3:
                oldest = self.recent_checkpoints.pop(0)
                if os.path.exists(oldest):
                    os.remove(oldest)

            self.last_checkpoint_time = current_time
            print(f"Checkpoint saved at {checkpoint_path}")

        # Daily checkpoint within the first 15 minutes after 22:00, at most once per day
        is_22_oclock = current_dt.hour == 22 and current_dt.minute < 15
        if self.checkpoint_dir and is_22_oclock and self.last_22_date != today:
            checkpoint_name = f"checkpoint_22h_{today.strftime('%Y%m%d')}.pt"
            checkpoint_path = self._save_checkpoint(self.checkpoint_dir, self.current_epoch,
                                                    current_batch, checkpoint_name)
            self.last_22_date = today
            print(f"22:00 Checkpoint saved at {checkpoint_path}")

    def _validation_phase(self, val_loader, criterion_denoise, criterion_rotate):
        """Handle the validation phase."""
        self.model.eval()
        self.projection_head.eval()
        return self._validate(val_loader, criterion_denoise, criterion_rotate)

    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
        """Perform validation and return the average validation loss."""
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue
                batch_loss = self._process_batch(
                    batch_data, criterion_denoise, criterion_rotate, training=False
                )
                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1

        avg_val_loss = val_loss / max(val_batch_count, 1)
        self.val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        return avg_val_loss
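
    # Note: zip() below truncates to the shorter of the two loss lists, so rows
    # stay aligned even if the lists differ in length (e.g. after resuming).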
    def save_losses(self, csv_file):
        """Save training and validation losses to a CSV file."""
        data = list(zip(
            range(1, len(self.train_losses) + 1),
            self.train_losses,
            self.val_losses
        ))
        with open(csv_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            writer.writerows(data)
        print(f"Loss data has been written to {csv_file}")
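

# Usage sketch (illustrative; the backbone import below is an assumption, since
# only the config and data loader are imported in this module):
#
#   from aiia.model.config import AIIAConfig
#   from aiia.model import AIIA  # hypothetical backbone exposing save_pretrained()
#
#   config = AIIAConfig(hidden_size=512)  # hidden_size value is an example
#   model = AIIA(config)
#   trainer = Pretrainer(model, learning_rate=1e-4, config=config)
#   trainer.train(
#       dataset_paths=["data/part1.parquet", "data/part2.parquet"],  # example paths
#       output_path="AIIA",
#       num_epochs=3,
#       batch_size=2,
#       checkpoint_dir="checkpoints",
#   )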