# aiuNN/src/aiunn/finetune.py
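"""Finetuning script for the aiuNN Upsampler.

Wraps a pretrained AIIABase backbone in an Upsampler head and trains it on
paired 512x512 / 1024x1024 images loaded from parquet files, using MSE loss,
automatic mixed precision, per-epoch CSV loss logging, and early stopping.
"""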


import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import io
import csv
import base64
from PIL import Image, ImageFile
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from aiia import AIIABase
from aiunn.upsampler import Upsampler


# Define a simple EarlyStopping class to monitor the epoch loss.
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience    # Number of epochs with no significant improvement before stopping.
        self.min_delta = min_delta  # Minimum change in loss required to count as an improvement.
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        # If the current loss is lower than the best loss minus min_delta,
        # update the best loss and reset the counter.
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
        else:
            # No significant improvement: increment the counter.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop


# UpscaleDataset to load and preprocess your data.
class UpscaleDataset(Dataset):
    def __init__(self, parquet_files: list, transform=None):
        combined_df = pd.DataFrame()
        for parquet_file in parquet_files:
            # Load data with head() to limit rows for memory efficiency.
            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(1250)
            combined_df = pd.concat([combined_df, df], ignore_index=True)

        # Validate that each row has proper image formats.
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        for col in ['image_512', 'image_1024']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        try:
            if isinstance(data, str):
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            raise RuntimeError(f"Decoding failed: {str(e)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Skip indices that have previously failed.
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]
        try:
            row = self.df.iloc[idx]
            low_res_bytes = self._decode_image(row['image_512'])
            high_res_bytes = self._decode_image(row['image_1024'])
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
            high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
            # Validate expected sizes.
            if low_res.size != (512, 512) or high_res.size != (1024, 1024):
                raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            return self[(idx + 1) % len(self)]
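

# Optional sanity check (illustrative only; the path below is a placeholder):
# decoding a single sample confirms that the parquet columns hold valid image
# bytes before launching a full training run, e.g.
#   sample_lr, sample_hr = UpscaleDataset(["/path/to/file.parquet"],
#                                         transform=transforms.ToTensor())[0]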

# Define any transformations you require (e.g., converting PIL images to tensors).
transform = transforms.Compose([
    transforms.ToTensor(),
])
# Load the base AIIABase model and wrap it with the Upsampler.
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
base_model = AIIABase.load(pretrained_model_path)
model = Upsampler(base_model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Create the dataset and dataloader.
dataset = UpscaleDataset([
"/root/training_data/vision-dataset/image_upscaler.parquet",
"/root/training_data/vision-dataset/image_vec_upscaler.parquet"
], transform=transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
# Define loss function and optimizer.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 10
model.train()
# Prepare a CSV file for logging training loss.
csv_file = 'losses.csv'
with open(csv_file, mode='a', newline='') as file:
    writer = csv.writer(file)
    # Write the header row only if the file is still empty.
    if file.tell() == 0:
        writer.writerow(['Epoch', 'Train Loss'])
# Initialize automatic mixed precision scaler and EarlyStopping.
scaler = GradScaler()
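# GradScaler scales the loss so that small fp16 gradients do not underflow;
# scaler.step() skips the optimizer step when the gradients contain inf/NaN
# values, and scaler.update() adjusts the scale factor accordingly.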
early_stopping = EarlyStopping(patience=3, min_delta=0.001)
# Training loop with early stopping.
for epoch in range(num_epochs):
    epoch_loss = 0.0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
    print(f"Epoch {epoch + 1}/{num_epochs}")
    for low_res, high_res in progress_bar:
        low_res = low_res.to(device, non_blocking=True)
        high_res = high_res.to(device, non_blocking=True)
        optimizer.zero_grad()
        # Use automatic mixed precision to speed up training on supported hardware.
        with autocast(device_type=device.type):
            outputs = model(low_res)
            loss = criterion(outputs, high_res)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")
# Record the loss in the CSV log.
with open(csv_file, mode='a', newline='') as file:
writer = csv.writer(file)
writer.writerow([epoch + 1, epoch_loss])
# Check early stopping criteria.
if early_stopping(epoch_loss):
print(f"Early stopping triggered at epoch {epoch + 1} with loss {epoch_loss}")
break
# Optionally, save the finetuned model using your library's save method.
finetuned_model_path = "aiuNN"
model.save(finetuned_model_path)
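
# To reuse the finetuned weights later, the model would be reloaded with the
# library's matching load method (assumed here to mirror `AIIABase.load`; the
# exact aiuNN API may differ), e.g.:
#   model = Upsampler.load("aiuNN")
#   model.eval()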