import csv
import os

import pandas as pd
import torch
from torch import nn
from tqdm import tqdm
from transformers import PreTrainedModel

from ..data.DataLoader import AIIADataLoader
from ..model.config import AIIAConfig


class ProjectionHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.conv_denoise = nn.Conv2d(hidden_size, 3, kernel_size=1)
        self.conv_rotate = nn.Conv2d(hidden_size, 4, kernel_size=1)  # 4 classes for 0, 90, 180, 270 degrees

    def forward(self, x, task='denoise'):
        if task == 'denoise':
            return self.conv_denoise(x)
        else:
            return self.conv_rotate(x).mean(dim=(2, 3))  # Global average pooling for rotation task
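# Minimal shape check for ProjectionHead (a sketch: the hidden size, spatial
# extent, and batch size below are illustrative assumptions, and the function
# is not called anywhere in this module). It assumes the backbone emits
# feature maps of shape (batch, hidden_size, H, W), which is what
# Pretrainer._process_batch feeds into the head.
def _demo_projection_head():
    head = ProjectionHead(hidden_size=256)
    features = torch.randn(2, 256, 16, 16)           # (batch, hidden, H, W)
    denoised = head(features, task='denoise')        # per-pixel RGB prediction
    rotation_logits = head(features, task='rotate')  # one logit per rotation class
    assert denoised.shape == (2, 3, 16, 16)
    assert rotation_logits.shape == (2, 4)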
class Pretrainer:
    def __init__(self, model: PreTrainedModel, learning_rate=1e-4, config: AIIAConfig = None):
        """
        Initialize the pretrainer with a model.

        Args:
            model (PreTrainedModel): The model instance to pretrain
            learning_rate (float): Learning rate for optimization
            config (AIIAConfig): Model configuration providing hidden_size;
                defaults to the model's own config when omitted
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        # Fall back to the model's own config so config=None does not crash.
        if config is None:
            config = model.config
        hidden_size = config.hidden_size
        self.projection_head = ProjectionHead(hidden_size).to(self.device)
        # One optimizer updates both the backbone and the projection head.
        self.optimizer = torch.optim.AdamW(
            list(self.model.parameters()) + list(self.projection_head.parameters()),
            lr=learning_rate
        )
        self.train_losses = []
        self.val_losses = []

    @staticmethod
    def safe_collate(batch):
        """Safely collate batch data handling both denoise and rotate tasks."""
        denoise_batch = []
        rotate_batch = []

        for sample in batch:
            try:
                noisy_img, target, task = sample
                if task == 'denoise':
                    denoise_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
                else:  # rotate task
                    rotate_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue

        if not denoise_batch and not rotate_batch:
            return None

        batch_data = {
            'denoise': None,
            'rotate': None
        }

        if denoise_batch:
            images = torch.stack([x['image'] for x in denoise_batch])
            targets = torch.stack([x['target'] for x in denoise_batch])
            batch_data['denoise'] = (images, targets)

        if rotate_batch:
            images = torch.stack([x['image'] for x in rotate_batch])
            targets = torch.stack([x['target'] for x in rotate_batch])
            batch_data['rotate'] = (images, targets)

        return batch_data

    def _process_batch(self, batch_data, criterion_denoise, criterion_rotate, training=True):
        """Process a single batch of data and return the combined loss."""
        batch_loss = 0

        if batch_data['denoise'] is not None:
            noisy_imgs, targets = batch_data['denoise']
            noisy_imgs = noisy_imgs.to(self.device)
            targets = targets.to(self.device)

            features = self.model(noisy_imgs)
            outputs = self.projection_head(features, task='denoise')
            loss = criterion_denoise(outputs, targets)
            batch_loss += loss

        if batch_data['rotate'] is not None:
            imgs, targets = batch_data['rotate']
            imgs = imgs.to(self.device)
            targets = targets.long().to(self.device)

            features = self.model(imgs)
            outputs = self.projection_head(features, task='rotate')
            loss = criterion_rotate(outputs, targets)
            batch_loss += loss

        return batch_loss

    def train(self, dataset_paths, output_path: str = "AIIA", column="image_bytes",
              num_epochs=3, batch_size=2, sample_size=10000):
        """
        Train the model using multiple specified datasets.

        Args:
            dataset_paths (list): List of paths to parquet datasets
            output_path (str): Directory where the best model is saved
            column (str): Name of the dataframe column holding image bytes
            num_epochs (int): Number of training epochs
            batch_size (int): Batch size for training
            sample_size (int): Number of samples to use from each dataset
        """
        if not dataset_paths:
            raise ValueError("No dataset paths provided")

        # Read and merge all datasets
        dataframes = []
        for path in dataset_paths:
            try:
                df = pd.read_parquet(path).head(sample_size)
                dataframes.append(df)
            except Exception as e:
                print(f"Error loading dataset {path}: {e}")

        if not dataframes:
            raise ValueError("No valid datasets could be loaded")

        merged_df = pd.concat(dataframes, ignore_index=True)

        # Initialize data loader
        aiia_loader = AIIADataLoader(
            merged_df,
            column=column,
            batch_size=batch_size,
            pretraining=True,
            collate_fn=self.safe_collate
        )

        criterion_denoise = nn.MSELoss()
        criterion_rotate = nn.CrossEntropyLoss()
        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 20)

            # Training phase
            self.model.train()
            self.projection_head.train()
            total_train_loss = 0.0
            batch_count = 0

            for batch_data in tqdm(aiia_loader.train_loader):
                if batch_data is None:
                    continue

                self.optimizer.zero_grad()
                batch_loss = self._process_batch(batch_data, criterion_denoise, criterion_rotate)

                if batch_loss > 0:
                    batch_loss.backward()
                    self.optimizer.step()
                    total_train_loss += batch_loss.item()
                    batch_count += 1

            avg_train_loss = total_train_loss / max(batch_count, 1)
            self.train_losses.append(avg_train_loss)
            print(f"Training Loss: {avg_train_loss:.4f}")

            # Validation phase
            self.model.eval()
            self.projection_head.eval()
            val_loss = self._validate(aiia_loader.val_loader, criterion_denoise, criterion_rotate)

            # Checkpoint on improvement. Note only the backbone is saved; the
            # projection head is not written to disk.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.model.save_pretrained(output_path)
                print("Best model saved!")

        losses_path = os.path.join(os.path.dirname(output_path), 'losses.csv')
        self.save_losses(losses_path)

    def _validate(self, val_loader, criterion_denoise, criterion_rotate):
        """Perform validation and return average validation loss."""
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue

                batch_loss = self._process_batch(
                    batch_data, criterion_denoise, criterion_rotate, training=False
                )
                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1

        avg_val_loss = val_loss / max(val_batch_count, 1)
        self.val_losses.append(avg_val_loss)
        print(f"Validation Loss: {avg_val_loss:.4f}")
        return avg_val_loss

    def save_losses(self, csv_file):
        """Save training and validation losses to a CSV file."""
        data = list(zip(
            range(1, len(self.train_losses) + 1),
            self.train_losses,
            self.val_losses
        ))

        with open(csv_file, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['Epoch', 'Train Loss', 'Validation Loss'])
            writer.writerows(data)
        print(f"Loss data has been written to {csv_file}")
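# Usage sketch (hypothetical): the model class, import paths, parquet path,
# and hidden size below are illustrative assumptions, not part of this module,
# so the example is left as a comment rather than executable code.
#
#     from aiia.model import AIIA              # hypothetical import path
#     from aiia.model.config import AIIAConfig
#
#     config = AIIAConfig(hidden_size=256)     # assumed constructor argument
#     model = AIIA(config)
#     trainer = Pretrainer(model, learning_rate=1e-4, config=config)
#     trainer.train(
#         dataset_paths=["data/images.parquet"],  # hypothetical dataset
#         output_path="checkpoints/aiia-pretrained",
#         num_epochs=3,
#         batch_size=2,
#     )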