new pretraining script
This commit is contained in:
parent
2b55f02b50
commit
b501ae8317
|
@ -1,29 +1,65 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch import nn, utils
|
||||
import pandas as pd
|
||||
from aiia.model.config import AIIAConfig
|
||||
from aiia.model import AIIABase
|
||||
from aiia.data.DataLoader import AIIADataLoader
|
||||
import os
|
||||
import copy
|
||||
|
||||
def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||
# Read and merge datasets
|
||||
df1 = pd.read_parquet(data_path1).head(10000)
|
||||
df2 = pd.read_parquet(data_path2).head(10000)
|
||||
merged_df = pd.concat([df1, df2], ignore_index=True)
|
||||
|
||||
# Model configuration
|
||||
config = AIIAConfig(
|
||||
model_name="AIIA-Base-512x20k",
|
||||
)
|
||||
|
||||
# Initialize model and data loader
|
||||
model = AIIABase(config)
|
||||
|
||||
aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32, pretraining=True)
|
||||
# Define a custom collate function to handle preprocessing and skip bad samples
|
||||
def safe_collate(batch):
    """Collate raw samples into stacked tensors, dropping unusable ones.

    NOTE: the per-sample processing is still a placeholder. It emits a
    random image tensor, a random integer target and the task label
    'denoise' instead of decoding the incoming sample; replace it with the
    real preprocessing before training on actual data.

    Args:
        batch: Iterable of raw samples passed to the collate function.
            Entries that are ``None`` are skipped (assumed to mark samples
            that could not be loaded -- confirm against the dataset).

    Returns:
        A tuple ``(images, targets, tasks)`` where ``images`` and
        ``targets`` are the stacked per-sample tensors and ``tasks`` is the
        list of per-sample task names, or ``None`` when no sample in the
        batch could be processed so the caller can skip the batch.
    """
    processed_batch = []
    for sample in batch:
        # A missing sample must be dropped, not replaced by fabricated data.
        if sample is None:
            print("Skipping sample due to error: sample is None")
            continue
        try:
            # Process each sample here (e.g., decode image, preprocess, etc.)
            # Replace with actual preprocessing steps
            processed_sample = {
                'image': torch.randn(3, 224, 224),  # Example tensor
                'target': torch.randint(0, 10, (1,)),  # Example target
                'task': 'denoise',  # Example task
            }
        except Exception as e:
            # Deliberate best-effort: one broken sample must not abort the
            # whole batch once real preprocessing is plugged in above.
            print(f"Skipping sample due to error: {e}")
        else:
            processed_batch.append(processed_sample)

    if not processed_batch:
        return None  # Skip batch if all samples are invalid

    # Stack tensors for the batch
    images = torch.stack([x['image'] for x in processed_batch])
    targets = torch.stack([x['target'] for x in processed_batch])
    tasks = [x['task'] for x in processed_batch]

    return (images, targets, tasks)
|
||||
|
||||
aiia_loader = AIIADataLoader(
|
||||
merged_df,
|
||||
column="image_bytes",
|
||||
batch_size=32,
|
||||
pretraining=True,
|
||||
collate_fn=safe_collate
|
||||
)
|
||||
|
||||
train_loader = aiia_loader.train_loader
|
||||
val_loader = aiia_loader.val_loader
|
||||
|
||||
# Define loss functions and optimizer
|
||||
criterion_denoise = nn.MSELoss()
|
||||
criterion_rotate = nn.CrossEntropyLoss()
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
@ -35,27 +71,28 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
print(f"\nEpoch {epoch+1}/{num_epochs}")
|
||||
print("-" * 20)
|
||||
|
||||
# Training phase
|
||||
model.train()
|
||||
total_train_loss = 0.0
|
||||
denoise_losses = []
|
||||
rotate_losses = []
|
||||
|
||||
for batch in train_loader:
|
||||
noisy_imgs, targets, tasks = batch
|
||||
if batch is None:
|
||||
continue # Skip empty batches
|
||||
|
||||
noisy_imgs, targets, tasks = batch
|
||||
noisy_imgs = noisy_imgs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
outputs = model(noisy_imgs)
|
||||
task_losses = []
|
||||
|
||||
for i, task in enumerate(tasks):
|
||||
if task == 'denoise':
|
||||
loss = criterion_denoise(outputs[i], targets[i])
|
||||
denoise_losses.append(loss.item())
|
||||
else:
|
||||
loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
|
||||
rotate_losses.append(loss.item())
|
||||
task_losses.append(loss)
|
||||
|
||||
batch_loss = sum(task_losses) / len(task_losses)
|
||||
|
@ -63,21 +100,26 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
optimizer.step()
|
||||
|
||||
total_train_loss += batch_loss.item()
|
||||
|
||||
avg_total_train_loss = total_train_loss / len(train_loader)
|
||||
print(f"Training Loss: {avg_total_train_loss:.4f}")
|
||||
|
||||
# Validation phase
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
val_losses = []
|
||||
for batch in val_loader:
|
||||
noisy_imgs, targets, tasks = batch
|
||||
val_loss = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in val_loader:
|
||||
if batch is None:
|
||||
continue
|
||||
|
||||
noisy_imgs, targets, tasks = batch
|
||||
noisy_imgs = noisy_imgs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
outputs = model(noisy_imgs)
|
||||
|
||||
task_losses = []
|
||||
|
||||
for i, task in enumerate(tasks):
|
||||
if task == 'denoise':
|
||||
loss = criterion_denoise(outputs[i], targets[i])
|
||||
|
@ -86,8 +128,9 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
task_losses.append(loss)
|
||||
|
||||
batch_loss = sum(task_losses) / len(task_losses)
|
||||
val_losses.append(batch_loss.item())
|
||||
avg_val_loss = sum(val_losses) / len(val_loader)
|
||||
val_loss += batch_loss.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
print(f"Validation Loss: {avg_val_loss:.4f}")
|
||||
|
||||
if avg_val_loss < best_val_loss:
|
||||
|
|
Loading…
Reference in New Issue