new pretraining script
commit b501ae8317
parent 2b55f02b50
@@ -1,29 +1,65 @@
 import torch
-from torch import nn
+from torch import nn, utils
 import pandas as pd
 from aiia.model.config import AIIAConfig
 from aiia.model import AIIABase
 from aiia.data.DataLoader import AIIADataLoader
+import os
+import copy
 
 def pretrain_model(data_path1, data_path2, num_epochs=3):
+    # Read and merge datasets
     df1 = pd.read_parquet(data_path1).head(10000)
     df2 = pd.read_parquet(data_path2).head(10000)
     merged_df = pd.concat([df1, df2], ignore_index=True)
 
+    # Model configuration
     config = AIIAConfig(
         model_name="AIIA-Base-512x20k",
     )
 
+    # Initialize model and data loader
     model = AIIABase(config)
 
+    # Define a custom collate function to handle preprocessing and skip bad samples
+    def safe_collate(batch):
+        processed_batch = []
+        for sample in batch:
+            try:
+                # Process each sample here (e.g., decode image, preprocess, etc.)
+                # Replace with actual preprocessing steps
+                processed_sample = {
+                    'image': torch.randn(3, 224, 224),  # Example tensor
+                    'target': torch.randint(0, 10, (1,)),  # Example target
+                    'task': 'denoise'  # Example task
+                }
+                processed_batch.append(processed_sample)
+            except Exception as e:
+                print(f"Skipping sample due to error: {e}")
+        if not processed_batch:
+            return None  # Skip batch if all samples are invalid
+
+        # Stack tensors for the batch
+        images = torch.stack([x['image'] for x in processed_batch])
+        targets = torch.stack([x['target'] for x in processed_batch])
+        tasks = [x['task'] for x in processed_batch]
+
+        return (images, targets, tasks)
+
-    aiia_loader = AIIADataLoader(merged_df, column="image_bytes", batch_size=32, pretraining=True)
+    aiia_loader = AIIADataLoader(
+        merged_df,
+        column="image_bytes",
+        batch_size=32,
+        pretraining=True,
+        collate_fn=safe_collate
+    )
 
     train_loader = aiia_loader.train_loader
     val_loader = aiia_loader.val_loader
 
+    # Define loss functions and optimizer
     criterion_denoise = nn.MSELoss()
     criterion_rotate = nn.CrossEntropyLoss()
 
     optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -35,49 +71,55 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
         print(f"\nEpoch {epoch+1}/{num_epochs}")
         print("-" * 20)
 
+        # Training phase
         model.train()
         total_train_loss = 0.0
-        denoise_losses = []
-        rotate_losses = []
 
         for batch in train_loader:
+            if batch is None:
+                continue  # Skip empty batches
+
             noisy_imgs, targets, tasks = batch
 
             noisy_imgs = noisy_imgs.to(device)
             targets = targets.to(device)
 
             optimizer.zero_grad()
 
             outputs = model(noisy_imgs)
             task_losses = []
 
             for i, task in enumerate(tasks):
                 if task == 'denoise':
                     loss = criterion_denoise(outputs[i], targets[i])
-                    denoise_losses.append(loss.item())
                 else:
                     loss = criterion_rotate(outputs[i].unsqueeze(0), targets[i].unsqueeze(0))
-                    rotate_losses.append(loss.item())
                 task_losses.append(loss)
 
             batch_loss = sum(task_losses) / len(task_losses)
             batch_loss.backward()
             optimizer.step()
 
             total_train_loss += batch_loss.item()
 
         avg_total_train_loss = total_train_loss / len(train_loader)
         print(f"Training Loss: {avg_total_train_loss:.4f}")
 
+        # Validation phase
         model.eval()
+        val_loss = 0.0
 
         with torch.no_grad():
-            val_losses = []
             for batch in val_loader:
+                if batch is None:
+                    continue
+
                 noisy_imgs, targets, tasks = batch
 
                 noisy_imgs = noisy_imgs.to(device)
                 targets = targets.to(device)
 
                 outputs = model(noisy_imgs)
 
                 task_losses = []
 
                 for i, task in enumerate(tasks):
                     if task == 'denoise':
                         loss = criterion_denoise(outputs[i], targets[i])
@@ -86,10 +128,11 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
                     task_losses.append(loss)
 
                 batch_loss = sum(task_losses) / len(task_losses)
-                val_losses.append(batch_loss.item())
+                val_loss += batch_loss.item()
 
-            avg_val_loss = sum(val_losses) / len(val_loader)
-            print(f"Validation Loss: {avg_val_loss:.4f}")
+        avg_val_loss = val_loss / len(val_loader)
+        print(f"Validation Loss: {avg_val_loss:.4f}")
 
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
             model.save("BASEv0.1")
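The collate function above stacks placeholder tensors and notes that the real preprocessing still needs to be filled in. As a rough illustration only, and not part of this commit, decoding the raw bytes from the "image_bytes" column might look something like the sketch below; PIL, torchvision, and the per-sample dict layout are all assumptions.

import io

from PIL import Image
from torchvision import transforms

# Resize and convert a decoded image to a 3x224x224 float tensor in [0, 1].
to_tensor = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def decode_image_bytes(image_bytes):
    # Assumed helper: turn the raw parquet bytes into the 'image' tensor used by safe_collate.
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    return to_tensor(img)

Inside safe_collate, 'image' could then be set to decode_image_bytes(sample['image_bytes']) instead of torch.randn(3, 224, 224).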
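For context, a hypothetical way to invoke the script; the __main__ guard and the parquet paths below are illustrative and not taken from the commit.

if __name__ == "__main__":
    # Placeholder paths; point these at the real parquet files.
    pretrain_model(
        "data/dataset_part1.parquet",
        "data/dataset_part2.parquet",
        num_epochs=3,
    )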