diff --git a/src/pretrain.py b/src/pretrain.py
index 8436d51..4d98538 100644
--- a/src/pretrain.py
+++ b/src/pretrain.py
@@ -1,29 +1,65 @@
 import torch
-from torch import nn
+from torch import nn, utils
 import pandas as pd
 from aiia.model.config import AIIAConfig
 from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
+import os
+import copy
 
 def pretrain_model(data_path1, data_path2, num_epochs=3):
+    # Read and merge datasets
     df1 = pd.read_parquet(data_path1).head(10000)
     df2 = pd.read_parquet(data_path2).head(10000)
     merged_df = pd.concat([df1, df2], ignore_index=True)
 
+    # Model configuration
     config = AIIAConfig(
         model_name="AIIA-Base-512x20k",
     )
 
+    # Initialize model and data loader
     model = AIIABase(config)
+
+    # Define a custom collate function to handle preprocessing and skip bad samples
+    def safe_collate(batch):
+        processed_batch = []
+        for sample in batch:
+            try:
+                # Process each sample here (e.g., decode image, preprocess, etc.)
+                # Replace with actual preprocessing steps
+                processed_sample = {
+                    'image': torch.randn(3, 224, 224),     # Example tensor
+                    'target': torch.randint(0, 10, (1,)),  # Example target
+                    'task': 'denoise'                      # Example task
+                }
+                processed_batch.append(processed_sample)
+            except Exception as e:
+                print(f"Skipping sample due to error: {e}")
+        if not processed_batch:
+            return None  # Skip batch if all samples are invalid
 
-    aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32, pretraining=True)
+        # Stack tensors for the batch
+        images = torch.stack([x['image'] for x in processed_batch])
+        targets = torch.stack([x['target'] for x in processed_batch])
+        tasks = [x['task'] for x in processed_batch]
+
+        return (images, targets, tasks)
+
+    aiia_loader = AIIADataLoader(
+        merged_df,
+        column="image_bytes",
+        batch_size=32,
+        pretraining=True,
+        collate_fn=safe_collate
+    )
 
     train_loader = aiia_loader.train_loader
     val_loader = aiia_loader.val_loader
 
+    # Define loss functions and optimizer
     criterion_denoise = nn.MSELoss()
     criterion_rotate = nn.CrossEntropyLoss()
-
     optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -35,49 +71,55 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
         print(f"\nEpoch {epoch+1}/{num_epochs}")
         print("-" * 20)
 
+        # Training phase
         model.train()
         total_train_loss = 0.0
-        denoise_losses = []
-        rotate_losses = []
-
+
         for batch in train_loader:
+            if batch is None:
+                continue  # Skip empty batches
+
             noisy_imgs, targets, tasks = batch
-
             noisy_imgs = noisy_imgs.to(device)
             targets = targets.to(device)
+
             optimizer.zero_grad()
-
+
             outputs = model(noisy_imgs)
 
             task_losses = []
+
             for i, task in enumerate(tasks):
                 if task == 'denoise':
                     loss = criterion_denoise(outputs[i], targets[i])
-                    denoise_losses.append(loss.item())
                 else:
                     loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
-                    rotate_losses.append(loss.item())
                 task_losses.append(loss)
 
             batch_loss = sum(task_losses) / len(task_losses)
             batch_loss.backward()
             optimizer.step()
-
+
             total_train_loss += batch_loss.item()
+
        avg_total_train_loss = total_train_loss / len(train_loader)
         print(f"Training Loss: {avg_total_train_loss:.4f}")
 
+        # Validation phase
         model.eval()
+        val_loss = 0.0
+
         with torch.no_grad():
-            val_losses = []
             for batch in val_loader:
+                if batch is None:
+                    continue
+
                 noisy_imgs, targets, tasks = batch
-
                 noisy_imgs = noisy_imgs.to(device)
                 targets = targets.to(device)
                 outputs = model(noisy_imgs)
-
                 task_losses = []
+
                 for i, task in enumerate(tasks):
                     if task == 'denoise':
                         loss = criterion_denoise(outputs[i], targets[i])
@@ -86,10 +128,11 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
                     task_losses.append(loss)
 
                 batch_loss = sum(task_losses) / len(task_losses)
-                val_losses.append(batch_loss.item())
-        avg_val_loss = sum(val_losses) / len(val_loader)
-        print(f"Validation Loss: {avg_val_loss:.4f}")
-
+                val_loss += batch_loss.item()
+
+        avg_val_loss = val_loss / len(val_loader)
+        print(f"Validation Loss: {avg_val_loss:.4f}")
+
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
             model.save("BASEv0.1")
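Usage note (not part of the patch): a minimal sketch of how the revised entry point could be invoked after this change; the parquet paths below are placeholder assumptions, not values taken from the diff.

    # Hypothetical driver for pretrain_model(); dataset paths are placeholders.
    if __name__ == "__main__":
        pretrain_model(
            data_path1="dataset_part1.parquet",  # assumed path, adjust to your data
            data_path2="dataset_part2.parquet",  # assumed path, adjust to your data
            num_epochs=3,
        )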