diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
new file mode 100644
index 0000000..a911f7f
--- /dev/null
+++ b/src/aiunn/finetune.py
@@ -0,0 +1,144 @@
+import torch
+import pandas as pd
+from PIL import Image
+import io
+from torch import nn
+from torch.utils.data import Dataset, DataLoader
+import torchvision.transforms as transforms
+from aiia.model import AIIABase
+from sklearn.model_selection import train_test_split
+
+
+# Step 1: Define Custom Dataset Class
+class ImageDataset(Dataset):
+    def __init__(self, dataframe, transform=None):
+        self.dataframe = dataframe
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.dataframe)
+
+    def __getitem__(self, idx):
+        row = self.dataframe.iloc[idx]
+
+        # Decode the 512px image from raw bytes
+        img_bytes = row['image_512']
+        img_stream = io.BytesIO(img_bytes)
+        low_res_image = Image.open(img_stream).convert('RGB')
+
+        # Decode the 1024px image from raw bytes
+        high_res_bytes = row['image_1024']
+        high_stream = io.BytesIO(high_res_bytes)
+        high_res_image = Image.open(high_stream).convert('RGB')
+
+        # Apply transformations if specified
+        if self.transform:
+            low_res_image = self.transform(low_res_image)
+            high_res_image = self.transform(high_res_image)
+
+        return {'low_res': low_res_image, 'high_res': high_res_image}
+
+
+# Step 2: Load and Preprocess Data
+# Read the datasets (DataFrames with columns 'image_512' and 'image_1024')
+df1 = pd.read_parquet('/root/training_data/vision-dataset/image_upscaler.parquet')
+df2 = pd.read_parquet('/root/training_data/vision-dataset/image_vec_upscaler.parquet')
+
+# Combine the two datasets into one DataFrame
+df = pd.concat([df1, df2], ignore_index=True)
+
+# Split into training and validation sets
+train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
+
+# Define preprocessing transforms (scale pixel values to [-1, 1])
+transform = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+])
+
+train_dataset = ImageDataset(train_df, transform=transform)
+val_dataset = ImageDataset(val_df, transform=transform)
+
+# Create DataLoaders
+batch_size = 2
+train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+# Step 3: Load Pre-trained Model and Modify for Upscaling
+model = AIIABase.load("AIIA-Base-512")
+
+# Freeze the original CNN layers to prevent catastrophic forgetting
+for param in model.cnn.parameters():
+    param.requires_grad = False
+
+# Add an upsample head. NOTE: this assumes model.cnn outputs a feature map with
+# hidden_size channels at the input's spatial resolution, so 2x upsampling maps
+# a 512px input to the 1024px target.
+hidden_size = model.config.hidden_size  # Assuming this is defined in the model's config
+model.upsample = nn.Sequential(
+    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+    nn.Conv2d(hidden_size, 3, kernel_size=3, padding=1)
+)
+
+# Step 4: Define Loss Function and Optimizer
+criterion = nn.MSELoss()
+
+# Option A (full fine-tune): optimize every parameter.
+# optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)  # Adjust learning rate as needed
+
+# Option B (used here): train only the newly added upsample layers.
+params_to_update = [param for name, param in model.named_parameters() if 'upsample' in name]
+optimizer = torch.optim.Adam(params_to_update, lr=0.001)
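+
+# Optional sanity check (a sketch, not part of the training recipe above):
+# report trainable vs. frozen parameter counts so the freeze in Step 3 and the
+# optimizer's parameter selection in Step 4 can be verified before training.
+num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+num_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+print(f"Trainable parameters: {num_trainable:,} | Frozen parameters: {num_frozen:,}")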
+
+# Step 5: Training Loop
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+best_val_loss = float('inf')
+num_epochs = 10  # Adjust as needed
+
+for epoch in range(num_epochs):
+    model.train()
+    running_loss = 0.0
+
+    for batch in train_loader:
+        low_res = batch['low_res'].to(device)
+        high_res = batch['high_res'].to(device)
+
+        # Forward pass: frozen backbone features -> trainable upsample head
+        features = model.cnn(low_res)
+        outputs = model.upsample(features)
+
+        loss = criterion(outputs, high_res)
+
+        # Backward pass and optimize
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        running_loss += loss.item()
+
+    epoch_loss = running_loss / len(train_loader)
+    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
+
+    # Validation step
+    model.eval()
+    val_loss = 0.0
+
+    with torch.no_grad():
+        for batch in val_loader:
+            low_res = batch['low_res'].to(device)
+            high_res = batch['high_res'].to(device)
+
+            features = model.cnn(low_res)
+            outputs = model.upsample(features)
+
+            loss = criterion(outputs, high_res)
+            val_loss += loss.item()
+
+    # Average over batches so the value is comparable to the training loss
+    val_loss /= len(val_loader)
+    print(f"Validation Loss: {val_loss:.4f}")
+
+    if val_loss < best_val_loss:
+        best_val_loss = val_loss
+        model.save("AIIA-base-512-upscaler")
+        print("Best model saved!")
\ No newline at end of file
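
For reference, a minimal inference sketch for the saved checkpoint could look like the snippet below. It assumes AIIABase.load restores the attached upsample head and that inputs receive the same Normalize preprocessing as training; neither is guaranteed by this diff, and the file paths are hypothetical.

import torch
import torchvision.transforms as transforms
from PIL import Image
from aiia.model import AIIABase

# Load the fine-tuned upscaler (assumes load() restores the upsample head).
model = AIIABase.load("AIIA-base-512-upscaler")
model.eval()

# Same preprocessing as training: scale pixel values to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

image = Image.open("input_512.png").convert("RGB")  # hypothetical input path
x = transform(image).unsqueeze(0)

with torch.no_grad():
    upscaled = model.upsample(model.cnn(x))

# Undo the normalization before converting back to an image.
upscaled = (upscaled.squeeze(0) * 0.5 + 0.5).clamp(0, 1)
transforms.ToPILImage()(upscaled).save("output_1024.png")  # hypothetical output path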