new pretraining script

Falko Victor Habel 2025-01-26 23:20:15 +01:00
parent 2b55f02b50
commit b501ae8317
1 changed file with 61 additions and 18 deletions


@@ -1,29 +1,65 @@
import torch
from torch import nn, utils
import pandas as pd
from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader
import os
import copy

def pretrain_model(data_path1, data_path2, num_epochs=3):
    # Read and merge the two parquet datasets
    df1 = pd.read_parquet(data_path1).head(10000)
    df2 = pd.read_parquet(data_path2).head(10000)
    merged_df = pd.concat([df1, df2], ignore_index=True)

    # Model configuration
    config = AIIAConfig(
        model_name="AIIA-Base-512x20k",
    )

    # Initialize the model
    model = AIIABase(config)

    # Custom collate function: preprocess each sample and skip bad ones
    def safe_collate(batch):
        processed_batch = []
        for sample in batch:
            try:
                # Process each sample here (e.g., decode image, preprocess, etc.)
                # Replace with actual preprocessing steps
                processed_sample = {
                    'image': torch.randn(3, 224, 224),    # Example tensor
                    'target': torch.randint(0, 10, (1,)), # Example target
                    'task': 'denoise'                     # Example task
                }
                processed_batch.append(processed_sample)
            except Exception as e:
                print(f"Skipping sample due to error: {e}")

        if not processed_batch:
            return None  # Skip the batch if all samples are invalid

        # Stack tensors for the batch
        images = torch.stack([x['image'] for x in processed_batch])
        targets = torch.stack([x['target'] for x in processed_batch])
        tasks = [x['task'] for x in processed_batch]
        return (images, targets, tasks)

    aiia_loader = AIIADataLoader(
        merged_df,
        column="image_bytes",
        batch_size=32,
        pretraining=True,
        collate_fn=safe_collate
    )
    train_loader = aiia_loader.train_loader
    val_loader = aiia_loader.val_loader

    # Define loss functions and optimizer
    criterion_denoise = nn.MSELoss()
    criterion_rotate = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    device = "cuda" if torch.cuda.is_available() else "cpu"
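
The collate function above stacks random placeholder tensors where real preprocessing belongs. As a rough sketch of what that step could look like, the snippet below decodes raw image bytes with PIL. The sample layout (the 'image_bytes', 'target', and 'task' fields), the transform, and the helper name decode_collate are all assumptions, since AIIADataLoader's sample format is not shown in this diff.

import io
import torch
from PIL import Image
from torchvision import transforms

# Assumed preprocessing; the 224x224 size mirrors the placeholder tensors above.
_preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def decode_collate(batch):
    """Decode each sample's raw image bytes, skipping samples that fail."""
    processed = []
    for sample in batch:
        try:
            # Sample keys are assumptions mirroring the column name used above.
            img = Image.open(io.BytesIO(sample['image_bytes'])).convert('RGB')
            processed.append({
                'image': _preprocess(img),
                'target': torch.as_tensor(sample['target']),
                'task': sample['task'],
            })
        except Exception as e:
            print(f"Skipping sample due to error: {e}")
    if not processed:
        return None  # The training loop checks for None and skips the batch
    images = torch.stack([x['image'] for x in processed])
    targets = torch.stack([x['target'] for x in processed])
    tasks = [x['task'] for x in processed]
    return images, targets, tasks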
@@ -35,49 +71,55 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
print(f"\nEpoch {epoch+1}/{num_epochs}")
print("-" * 20)
# Training phase
model.train()
total_train_loss = 0.0
denoise_losses = []
rotate_losses = []
for batch in train_loader:
if batch is None:
continue # Skip empty batches
noisy_imgs, targets, tasks = batch
noisy_imgs = noisy_imgs.to(device)
targets = targets.to(device)
optimizer.zero_grad()
outputs = model(noisy_imgs)
task_losses = []
for i, task in enumerate(tasks):
if task == 'denoise':
loss = criterion_denoise(outputs[i], targets[i])
denoise_losses.append(loss.item())
else:
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
rotate_losses.append(loss.item())
task_losses.append(loss)
batch_loss = sum(task_losses) / len(task_losses)
batch_loss.backward()
optimizer.step()
total_train_loss += batch_loss.item()
avg_total_train_loss = total_train_loss / len(train_loader)
print(f"Training Loss: {avg_total_train_loss:.4f}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                if batch is None:
                    continue

                noisy_imgs, targets, tasks = batch
                noisy_imgs = noisy_imgs.to(device)
                targets = targets.to(device)
                outputs = model(noisy_imgs)

                task_losses = []
                for i, task in enumerate(tasks):
                    if task == 'denoise':
                        loss = criterion_denoise(outputs[i], targets[i])
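
The unsqueeze(0) calls around the rotation loss exist because nn.CrossEntropyLoss conventionally takes a batch dimension: logits of shape (batch, num_classes) and integer class labels of shape (batch,). A minimal standalone sketch, with an illustrative class count:

import torch
from torch import nn

criterion_rotate = nn.CrossEntropyLoss()

logits = torch.randn(4)   # one sample's logits over 4 rotation classes
label = torch.tensor(2)   # scalar class index

# unsqueeze(0) turns the single sample into a batch of one: shapes (1, 4)
# and (1,), matching the (batch, num_classes) / (batch,) convention.
loss = criterion_rotate(logits.unsqueeze(0), label.unsqueeze(0))
print(loss.item())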
@@ -86,10 +128,11 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
                    task_losses.append(loss)

                batch_loss = sum(task_losses) / len(task_losses)
                val_loss += batch_loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save("BASEv0.1")
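
For reference, a hypothetical invocation of the script; the parquet paths below are placeholders, not files known to exist in the repository:

if __name__ == "__main__":
    pretrain_model(
        "data/images_part1.parquet",  # placeholder path
        "data/images_part2.parquet",  # placeholder path
        num_epochs=3,
    )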