develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed file with 144 additions and 0 deletions
Showing only changes of commit d85faadcc1

src/aiunn/finetune.py

@@ -0,0 +1,144 @@
import torch
import pandas as pd
from PIL import Image
import io
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from aiia.model import AIIABase
from sklearn.model_selection import train_test_split


# Step 1: Define Custom Dataset Class
class ImageDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        # Decode image_512 from bytes
        img_bytes = row['image_512']
        img_stream = io.BytesIO(img_bytes)
        low_res_image = Image.open(img_stream).convert('RGB')

        # Decode image_1024 from bytes
        high_res_bytes = row['image_1024']
        high_stream = io.BytesIO(high_res_bytes)
        high_res_image = Image.open(high_stream).convert('RGB')

        # Apply transformations if specified
        if self.transform:
            low_res_image = self.transform(low_res_image)
            high_res_image = self.transform(high_res_image)

        return {'low_res': low_res_image, 'high_res': high_res_image}


# Step 2: Load and Preprocess Data
# Read the dataset (assuming it's a DataFrame with columns 'image_512' and 'image_1024')
df1 = pd.read_parquet('/root/training_data/vision-dataset/image_upscaler.parquet')
df2 = pd.read_parquet('/root/training_data/vision-dataset/image_vec_upscaler.parquet')

# Combine the two datasets into one DataFrame
df = pd.concat([df1, df2], ignore_index=True)

# Split into training and validation sets
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

# Define preprocessing transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = ImageDataset(train_df, transform=transform)
val_dataset = ImageDataset(val_df, transform=transform)

# Create DataLoaders
batch_size = 2
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
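# Optional sanity check (not part of the training flow): decoding one sample is a
# quick way to confirm the bytes -> PIL -> tensor pipeline; the exact shapes depend
# on the stored image resolutions (nominally 512x512 and 1024x1024).
# sample = train_dataset[0]
# print(sample['low_res'].shape, sample['high_res'].shape)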
# Step 3: Load Pre-trained Model and Modify for Upscaling
model = AIIABase.load("AIIA-Base-512")

# Freeze original CNN layers to prevent catastrophic forgetting
for param in model.cnn.parameters():
    param.requires_grad = False

# Add upsample module
hidden_size = model.config.hidden_size  # Assuming this is defined in your model's config
model.upsample = torch.nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
    nn.Conv2d(hidden_size, 3, kernel_size=3, padding=1)
)
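# Note: this head assumes model.cnn returns a feature map with `hidden_size`
# channels; the single 2x upsample only reaches the 1024x1024 target if the
# frozen CNN preserves the 512x512 input resolution, so adjust scale_factor
# (or stack more upsampling stages) if it downsamples.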
# Step 4: Define Loss Function and Optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)  # Adjust learning rate as needed

# Alternatively, if you want to train only the new layers:
params_to_update = []
for name, param in model.named_parameters():
    if 'upsample' in name:
        params_to_update.append(param)
optimizer = torch.optim.Adam(params_to_update, lr=0.001)
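# As written, this second optimizer replaces the first one, so only the new
# upsample parameters are updated (consistent with the CNN freeze in Step 3).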
# Step 5: Training Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

best_val_loss = float('inf')
num_epochs = 10  # Adjust as needed

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        low_res = batch['low_res'].to(device)
        high_res = batch['high_res'].to(device)

        # Forward pass
        features = model.cnn(low_res)
        outputs = model.upsample(features)
        loss = criterion(outputs, high_res)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    # Validation Step
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            low_res = batch['low_res'].to(device)
            high_res = batch['high_res'].to(device)

            features = model.cnn(low_res)
            outputs = model.upsample(features)
            loss = criterion(outputs, high_res)
            val_loss += loss.item()
print(f"Validation Loss: {val_loss:.4f}")
if val_loss < best_val_loss:
best_val_loss = val_loss
model.save("AIIA-base-512-upscaler")
print("Best model saved!")