# aiuNN/src/aiunn/finetune.py
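"""Fine-tune the AIIA base model as a 512 -> 1024 image upscaler.

Loads paired 512px/1024px images from parquet files, trains with MSE loss under
automatic mixed precision, logs per-epoch loss to losses.csv, and saves the
finetuned model.
"""
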
import base64
import csv
import io

import pandas as pd
import torch
from aiia import AIIABase
from PIL import Image, ImageFile
from torch import nn, optim
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm


class UpscaleDataset(Dataset):
    def __init__(self, parquet_files: list, transform=None):
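        """Load 512px/1024px image pairs from one or more parquet files."""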
        combined_df = pd.DataFrame()
        for parquet_file in parquet_files:
            # Limit rows per file to keep memory usage bounded
            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(5000)
            combined_df = pd.concat([combined_df, df], ignore_index=True)

        # Validate data format
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        """Ensure both image columns hold raw bytes or base64 strings."""
        for col in ['image_512', 'image_1024']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        """Universal decoder handling both base64 strings and raw bytes."""
        try:
            if isinstance(data, str):
                # Handle base64 encoded strings
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            raise RuntimeError(f"Decoding failed: {str(e)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
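        """Return a (low_res, high_res) pair, skipping samples that fail to load."""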
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]  # Skip indices that previously failed
        try:
            row = self.df.iloc[idx]
            # Decode both images
            low_res_bytes = self._decode_image(row['image_512'])
            high_res_bytes = self._decode_image(row['image_1024'])
            # Load images with truncation handling
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
            high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
            # Validate image sizes
            if low_res.size != (512, 512) or high_res.size != (1024, 1024):
                raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            return self[(idx + 1) % len(self)]  # Return the next valid sample


# Example transform: converting PIL images to tensors
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Replace with your actual pretrained model path
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"

# Load the pretrained AIIA base model via its `load` class method
model = AIIABase.load(pretrained_model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Create the dataset and dataloader
dataset = UpscaleDataset(
    [
        "/root/training_data/vision-dataset/image_upscaler.parquet",
        "/root/training_data/vision-dataset/image_vec_upscaler.parquet",
    ],
    transform=transform,
)
data_loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Define a loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 10
model.train() # Set model in training mode
csv_file = 'losses.csv'
# Create or open the CSV file and write the header if it doesn't exist
with open(csv_file, mode='a', newline='') as file:
    writer = csv.writer(file)
    # Write the header only if the file is empty
    if file.tell() == 0:
        writer.writerow(['Epoch', 'Train Loss'])
# Create a gradient scaler (for scaling gradients when using AMP)
scaler = GradScaler()
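
# Note: the MSE loss below assumes model(low_res) returns a tensor with the
# same spatial size as high_res (i.e. an upscaled 1024x1024 output).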
for epoch in range(num_epochs):
    epoch_loss = 0.0
    data_loader_with_progress = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
    for low_res, high_res in data_loader_with_progress:
        low_res = low_res.to(device, non_blocking=True)
        high_res = high_res.to(device, non_blocking=True)
        optimizer.zero_grad()
        # Automatic mixed precision context (autocast expects the device type string)
        with autocast(device_type=device.type):
            outputs = model(low_res)
            loss = criterion(outputs, high_res)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")

    # Append the training loss to the CSV file
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([epoch + 1, epoch_loss])

# Optionally, save the finetuned model to a new directory
finetuned_model_path = "aiuNN"
model.save(finetuned_model_path)
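
# To resume from the finetuned checkpoint later (assuming AIIABase.load accepts
# the directory written by model.save):
# model = AIIABase.load(finetuned_model_path)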