# AIIA/src/pretrain.py
#
# (code-viewer metadata: 206 lines, 7.1 KiB, Python)

import torch
from torch import nn, utils
import pandas as pd
from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader
from tqdm import tqdm
def _safe_collate(batch):
    """Group a list of ``(image, target, task)`` samples into per-task stacked tensors.

    Samples whose task is ``'denoise'`` go to the denoise group; every other
    task is treated as rotation. A sample that cannot be unpacked (or whose
    task cannot be compared) is skipped with a printed message rather than
    failing the whole batch.

    Defined at module level (not nested) so it can be pickled when a data
    loader uses worker processes.

    Returns:
        ``None`` when no sample was usable; otherwise a dict with keys
        ``'denoise'`` and ``'rotate'``, each either ``None`` or an
        ``(images, targets)`` tuple built with ``torch.stack``.
    """
    groups = {"denoise": [], "rotate": []}
    for sample in batch:
        try:
            noisy_img, target, task = sample
            key = "denoise" if task == "denoise" else "rotate"
        except Exception as e:  # best-effort: one bad sample must not kill the batch
            print(f"Skipping sample due to error: {e}")
            continue
        groups[key].append((noisy_img, target))

    if not groups["denoise"] and not groups["rotate"]:
        return None

    batch_data = {}
    for key, items in groups.items():
        if items:
            images, targets = zip(*items)
            batch_data[key] = (torch.stack(images), torch.stack(targets))
        else:
            batch_data[key] = None
    return batch_data


def _compute_batch_loss(model, batch_data, device, criterion_denoise, criterion_rotate):
    """Run ``model`` on each task present in ``batch_data`` and return the summed loss.

    Returns ``None`` when the batch contains neither a denoise nor a rotate
    group, so callers can skip the step explicitly instead of comparing the
    loss against zero.
    """
    losses = []

    if batch_data["denoise"] is not None:
        noisy_imgs, targets = batch_data["denoise"]
        noisy_imgs = noisy_imgs.to(device)
        targets = targets.to(device)
        outputs = model(noisy_imgs)
        # Reshape to the target's own shape instead of a hard-coded
        # (batch, 3, 224, 224); identical for 3x224x224 targets. Output and
        # target must still have the same number of elements.
        outputs = outputs.reshape_as(targets)
        losses.append(criterion_denoise(outputs, targets))

    if batch_data["rotate"] is not None:
        imgs, targets = batch_data["rotate"]
        imgs = imgs.to(device)
        targets = targets.long().to(device)
        outputs = model(imgs)
        # NOTE(review): the flattened model output is used directly as
        # classification logits (no separate rotation head is visible here),
        # so the class count equals the flattened feature size — confirm.
        outputs = outputs.reshape(targets.size(0), -1)
        losses.append(criterion_rotate(outputs, targets))

    if not losses:
        return None
    return sum(losses)


def pretrain_model(data_path1, data_path2, num_epochs=3, *, rows_per_file=10000,
                   batch_size=2, save_path="BASEv0.1"):
    """Pretrain an ``AIIABase`` model on two parquet datasets with denoise and rotation tasks.

    The first ``rows_per_file`` rows of each parquet file are concatenated and
    handed to ``AIIADataLoader`` (reading the ``"image_bytes"`` column in
    pretraining mode). Training uses AdamW with ``config.learning_rate``, MSE
    loss for denoising and cross-entropy for rotation. After each epoch the
    mean training and validation losses are printed, and the model is saved
    via ``model.save(save_path)`` whenever the validation loss improves.

    Args:
        data_path1: Path to the first parquet file.
        data_path2: Path to the second parquet file.
        num_epochs: Number of training epochs.
        rows_per_file: Rows taken from the head of each file (previously
            hard-coded to 10000).
        batch_size: Batch size passed to ``AIIADataLoader`` (previously 2).
        save_path: Argument passed to ``model.save`` (previously "BASEv0.1").
    """
    # Read and merge datasets
    df1 = pd.read_parquet(data_path1).head(rows_per_file)
    df2 = pd.read_parquet(data_path2).head(rows_per_file)
    merged_df = pd.concat([df1, df2], ignore_index=True)

    # Model configuration
    config = AIIAConfig(model_name="AIIA-Base-512x20k")
    model = AIIABase(config)

    aiia_loader = AIIADataLoader(
        merged_df,
        column="image_bytes",
        batch_size=batch_size,
        pretraining=True,
        collate_fn=_safe_collate,
    )
    train_loader = aiia_loader.train_loader
    val_loader = aiia_loader.val_loader

    criterion_denoise = nn.MSELoss()
    criterion_rotate = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    best_val_loss = float("inf")
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 20)

        # Training phase
        model.train()
        total_train_loss = 0.0
        batch_count = 0
        for batch_data in tqdm(train_loader):
            if batch_data is None:
                continue
            optimizer.zero_grad()
            loss = _compute_batch_loss(
                model, batch_data, device, criterion_denoise, criterion_rotate
            )
            if loss is None:
                continue
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            batch_count += 1

        avg_train_loss = total_train_loss / max(batch_count, 1)
        print(f"Training Loss: {avg_train_loss:.4f}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_batch_count = 0
        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue
                loss = _compute_batch_loss(
                    model, batch_data, device, criterion_denoise, criterion_rotate
                )
                if loss is None:
                    continue
                val_loss += loss.item()
                val_batch_count += 1

        if val_batch_count == 0:
            # Without validation batches there is no metric to compare; keep
            # the latest weights instead of reporting a fake 0.0 loss.
            print("No validation batches; saving latest model without comparison.")
            model.save(save_path)
            continue

        avg_val_loss = val_loss / val_batch_count
        print(f"Validation Loss: {avg_val_loss:.4f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save(save_path)
            print("Best model saved!")
if __name__ == "__main__":
    # Script entry point: pretrain on the two vision-dataset parquet files
    # for three epochs.
    checkpoint_parquet = "/root/training_data/vision-dataset/images_checkpoint.parquet"
    vector_parquet = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
    pretrain_model(
        data_path1=checkpoint_parquet,
        data_path2=vector_parquet,
        num_epochs=3,
    )