import torch
import pandas as pd
from torch import nn
from tqdm import tqdm

from aiia.model.config import AIIAConfig
from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader


def pretrain_model(data_path1, data_path2, num_epochs=3):
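    """Pretrain AIIABase with denoise and rotate pretext tasks on two parquet datasets."""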
    # Read and merge datasets
    df1 = pd.read_parquet(data_path1).head(10000)
    df2 = pd.read_parquet(data_path2).head(10000)
    merged_df = pd.concat([df1, df2], ignore_index=True)

    # Model configuration
    config = AIIAConfig(
        model_name="AIIA-Base-512x20k",
    )

    # Initialize model
    model = AIIABase(config)

    def safe_collate(batch):
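        """Regroup a mixed batch into per-task (images, targets) tensors.

        Each sample is expected to be a (noisy_img, target, task) tuple;
        malformed samples are skipped, and None is returned if nothing survives.
        """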
        denoise_batch = []
        rotate_batch = []

        for sample in batch:
            try:
                noisy_img, target, task = sample
                if task == 'denoise':
                    denoise_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
                else:  # rotate task
                    rotate_batch.append({
                        'image': noisy_img,
                        'target': target,
                        'task': task
                    })
            except Exception as e:
                print(f"Skipping sample due to error: {e}")
                continue

        if not denoise_batch and not rotate_batch:
            return None

        batch_data = {
            'denoise': None,
            'rotate': None
        }

        # Process denoise batch
        if denoise_batch:
            images = torch.stack([x['image'] for x in denoise_batch])
            targets = torch.stack([x['target'] for x in denoise_batch])
            batch_data['denoise'] = (images, targets)

        # Process rotate batch
        if rotate_batch:
            images = torch.stack([x['image'] for x in rotate_batch])
            targets = torch.stack([x['target'] for x in rotate_batch])
            batch_data['rotate'] = (images, targets)

        return batch_data

    aiia_loader = AIIADataLoader(
        merged_df,
        column="image_bytes",
        batch_size=2,
        pretraining=True,
        collate_fn=safe_collate
    )
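    # NOTE: with pretraining=True, the loader is assumed to yield
    # (image, target, task) triples per sample, the format safe_collate unpacks.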

    train_loader = aiia_loader.train_loader
    val_loader = aiia_loader.val_loader

    criterion_denoise = nn.MSELoss()
    criterion_rotate = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
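    # Both pretext-task losses feed one optimizer over the shared backbone;
    # they are summed into batch_loss below and backpropagated together.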

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Track the best validation loss so the model is only checkpointed on improvement
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 20)

        # Training phase
        model.train()
        total_train_loss = 0.0
        batch_count = 0

        for batch_data in tqdm(train_loader):
            if batch_data is None:
                continue

            optimizer.zero_grad()
            batch_loss = 0

            # Handle denoise task
            if batch_data['denoise'] is not None:
                noisy_imgs, targets = batch_data['denoise']
                noisy_imgs = noisy_imgs.to(device)
                targets = targets.to(device)

                # Print shapes for debugging
                print("\nDenoising task shapes:")
                print(f"Input shape: {noisy_imgs.shape}")
                print(f"Target shape: {targets.shape}")

                outputs = model(noisy_imgs)
                print(f"Raw output shape: {outputs.shape}")

                # Reshape output to match target dimensions
                batch_size = targets.size(0)
                outputs = outputs.view(batch_size, 3, 224, 224)
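                # NOTE: the hardcoded (3, 224, 224) shape assumes 224x224 RGB
                # inputs; adjust if the dataset resolution differs.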
                print(f"Reshaped output shape: {outputs.shape}")

                loss = criterion_denoise(outputs, targets)
                batch_loss += loss

            # Handle rotate task
            if batch_data['rotate'] is not None:
                imgs, targets = batch_data['rotate']
                imgs = imgs.to(device)
                targets = targets.long().to(device)

                # Print shapes for debugging
                print("\nRotation task shapes:")
                print(f"Input shape: {imgs.shape}")
                print(f"Target shape: {targets.shape}")

                outputs = model(imgs)
                print(f"Raw output shape: {outputs.shape}")

                # Reshape output for rotation classification
                outputs = outputs.view(targets.size(0), -1)  # Flatten to [batch_size, features]
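                # NOTE: CrossEntropyLoss expects [batch_size, num_classes] logits,
                # so this flatten assumes AIIABase's output size matches the
                # number of rotation classes.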
                print(f"Reshaped output shape: {outputs.shape}")

                loss = criterion_rotate(outputs, targets)
                batch_loss += loss

            # Backpropagate only if at least one task contributed a loss
            if batch_loss > 0:
                batch_loss.backward()
                optimizer.step()
                total_train_loss += batch_loss.item()
                batch_count += 1

        avg_train_loss = total_train_loss / max(batch_count, 1)
        print(f"Training Loss: {avg_train_loss:.4f}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_batch_count = 0

        with torch.no_grad():
            for batch_data in val_loader:
                if batch_data is None:
                    continue

                batch_loss = 0

                # Handle denoise task
                if batch_data['denoise'] is not None:
                    noisy_imgs, targets = batch_data['denoise']
                    noisy_imgs = noisy_imgs.to(device)
                    targets = targets.to(device)

                    outputs = model(noisy_imgs)
                    batch_size = targets.size(0)
                    outputs = outputs.view(batch_size, 3, 224, 224)
                    loss = criterion_denoise(outputs, targets)
                    batch_loss += loss

                # Handle rotate task
                if batch_data['rotate'] is not None:
                    imgs, targets = batch_data['rotate']
                    imgs = imgs.to(device)
                    targets = targets.long().to(device)

                    outputs = model(imgs)
                    outputs = outputs.view(targets.size(0), -1)
                    loss = criterion_rotate(outputs, targets)
                    batch_loss += loss

                if batch_loss > 0:
                    val_loss += batch_loss.item()
                    val_batch_count += 1

        avg_val_loss = val_loss / max(val_batch_count, 1)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        # Checkpoint whenever validation improves
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            model.save("BASEv0.1")
            print("Best model saved!")


if __name__ == "__main__":
    data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
    data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
    pretrain_model(data_path1, data_path2, num_epochs=3)