updated finetuning script to work with Upsampler and added Early Stopping

This commit is contained in:
Falko Victor Habel 2025-02-22 17:53:06 +01:00
parent 2aba93caea
commit 736886021c
1 changed file with 71 additions and 50 deletions


@@ -1,40 +1,63 @@
import base64
import csv
import io

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image, ImageFile
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

from aiia import AIIABase
from upsampler import Upsampler
# Define a simple EarlyStopping class to monitor the epoch loss.
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience  # Number of epochs with no significant improvement before stopping.
        self.min_delta = min_delta  # Minimum change in loss required to count as an improvement.
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        # If the current loss is lower than the best loss minus min_delta, update the best loss and reset the counter.
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
        else:
            # No significant improvement: increment the counter.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
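# Illustrative usage of EarlyStopping (not part of the original script): the stopper is
# called once per epoch with that epoch's loss and returns True once `patience`
# consecutive epochs fail to improve the best loss by at least `min_delta`.
#
#   stopper = EarlyStopping(patience=2, min_delta=0.01)
#   for loss in [1.00, 0.995, 0.994]:
#       if stopper(loss):
#           print("stop")  # triggered on the third call: two non-improving epochs in a row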
# UpscaleDataset to load and preprocess the data.
class UpscaleDataset(Dataset):
    def __init__(self, parquet_files: list, transform=None):
        combined_df = pd.DataFrame()
        for parquet_file in parquet_files:
            # Load data with head() to limit rows for memory efficiency.
            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(1250)
            combined_df = pd.concat([combined_df, df], ignore_index=True)

        # Validate that each row has proper image formats.
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()
    def _validate_row(self, row):
        """Ensure both image columns exist and hold a supported data type."""
        for col in ['image_512', 'image_1024']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        """Universal decoder handling both base64 strings and raw bytes."""
        try:
            if isinstance(data, str):
                # Handle base64 encoded strings.
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
@@ -46,102 +69,100 @@ class UpscaleDataset(Dataset):
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Skip indices that have previously failed.
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]
        try:
            row = self.df.iloc[idx]
            # Decode both images.
            low_res_bytes = self._decode_image(row['image_512'])
            high_res_bytes = self._decode_image(row['image_1024'])

            # Load images with truncation handling.
            ImageFile.LOAD_TRUNCATED_IMAGES = True
            low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
            high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')

            # Validate expected sizes.
            if low_res.size != (512, 512) or high_res.size != (1024, 1024):
                raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")

            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            # Fall back to the next valid sample.
            return self[(idx + 1) % len(self)]
# Define any transformations you require (e.g., converting PIL images to tensors).
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load the base AIIABase model and wrap it with the Upsampler.
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
base_model = AIIABase.load(pretrained_model_path)
model = Upsampler(base_model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Create the dataset and dataloader.
dataset = UpscaleDataset([
    "/root/training_data/vision-dataset/image_upscaler.parquet",
    "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
], transform=transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
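# Optional tweak (not part of the original script): on machines with spare CPU cores,
# worker processes and pinned memory can speed up host-to-GPU transfer, e.g.:
#   data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)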
# Define loss function and optimizer.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 10
model.train()  # Set the model to training mode.
# Prepare a CSV file for logging the training loss.
csv_file = 'losses.csv'

# Create or open the CSV file and write the header if it doesn't exist.
with open(csv_file, mode='a', newline='') as file:
    writer = csv.writer(file)
    # Write the header only if the file is empty.
    if file.tell() == 0:
        writer.writerow(['Epoch', 'Train Loss'])
# Initialize the automatic mixed precision scaler and EarlyStopping.
scaler = GradScaler()
early_stopping = EarlyStopping(patience=3, min_delta=0.001)
# Training loop with early stopping.
for epoch in range(num_epochs):
    epoch_loss = 0.0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
    print(f"Epoch: {epoch}")
    for low_res, high_res in progress_bar:
        low_res = low_res.to(device, non_blocking=True)
        high_res = high_res.to(device, non_blocking=True)

        optimizer.zero_grad()

        # Use automatic mixed precision to speed up training on supported hardware.
        with autocast(device_type=device.type):
            outputs = model(low_res)
            loss = criterion(outputs, high_res)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")
    # Record the loss in the CSV log.
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([epoch + 1, epoch_loss])

    # Check the early stopping criteria.
    if early_stopping(epoch_loss):
        print(f"Early stopping triggered at epoch {epoch + 1} with loss {epoch_loss}")
        break

# Optionally, save the finetuned model using your library's save method.
finetuned_model_path = "aiuNN"
model.save(finetuned_model_path)
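# Inference sketch (illustrative, not part of this commit): assuming the Upsampler class
# mirrors AIIABase's `load` classmethod, the finetuned weights saved above could be
# reloaded and applied to a single 512x512 image roughly like this:
#
#   upsampler = Upsampler.load(finetuned_model_path).to(device).eval()
#   image = transforms.ToTensor()(Image.open("example_512.png").convert('RGB'))
#   with torch.no_grad():
#       upscaled = upsampler(image.unsqueeze(0).to(device))  # expected shape: 1x3x1024x1024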