From 736886021c7858ed5d1cd7d013321a2640eba03b Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Sat, 22 Feb 2025 17:53:06 +0100
Subject: [PATCH] updated finetuning script to work with Upsampler and added
 Early Stopping

---
 src/aiunn/finetune.py | 121 +++++++++++++++++++++++++-----------------
 1 file changed, 71 insertions(+), 50 deletions(-)

diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 8ef0fa1..d785030 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -1,40 +1,63 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
 import pandas as pd
 import io
-from PIL import Image, ImageFile
-from torch.utils.data import Dataset
-from torchvision import transforms
-from aiia import AIIABase
 import csv
-from tqdm import tqdm
 import base64
+from PIL import Image, ImageFile
 from torch.amp import autocast, GradScaler
-import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from tqdm import tqdm
+from aiia import AIIABase
+from upsampler import Upsampler
+
+# Define a simple EarlyStopping class to monitor the epoch loss.
+class EarlyStopping:
+    def __init__(self, patience=3, min_delta=0.001):
+        self.patience = patience  # Number of epochs with no significant improvement before stopping.
+        self.min_delta = min_delta  # Minimum change in loss required to count as an improvement.
+        self.best_loss = float('inf')
+        self.counter = 0
+        self.early_stop = False
+
+    def __call__(self, epoch_loss):
+        # If current loss is lower than the best loss minus min_delta, update best loss and reset counter.
+        if epoch_loss < self.best_loss - self.min_delta:
+            self.best_loss = epoch_loss
+            self.counter = 0
+        else:
+            # No significant improvement: increment counter.
+            self.counter += 1
+            if self.counter >= self.patience:
+                self.early_stop = True
+        return self.early_stop
+
 
+# UpscaleDataset to load and preprocess your data.
 class UpscaleDataset(Dataset):
     def __init__(self, parquet_files: list, transform=None):
         combined_df = pd.DataFrame()
         for parquet_file in parquet_files:
-            # Load data with chunking for memory efficiency
+            # Load data with head() to limit rows for memory efficiency.
            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(1250)
             combined_df = pd.concat([combined_df, df], ignore_index=True)
 
-        # Validate data format
+        # Validate that each row has proper image formats.
         self.df = combined_df.apply(self._validate_row, axis=1)
         self.transform = transform
         self.failed_indices = set()
 
     def _validate_row(self, row):
-        """Ensure both images exist and have correct dimensions"""
         for col in ['image_512', 'image_1024']:
             if not isinstance(row[col], (bytes, str)):
                 raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
         return row
 
     def _decode_image(self, data):
-        """Universal decoder handling both base64 strings and bytes"""
         try:
             if isinstance(data, str):
-                # Handle base64 encoded strings
                 return base64.b64decode(data)
             elif isinstance(data, bytes):
                 return data
@@ -46,102 +69,100 @@ class UpscaleDataset(Dataset):
         return len(self.df)
 
     def __getitem__(self, idx):
+        # Skip indices that have previously failed.
         if idx in self.failed_indices:
-            return self[(idx + 1) % len(self)]  # Skip failed indices
-
+            return self[(idx + 1) % len(self)]
         try:
             row = self.df.iloc[idx]
-
-            # Decode both images
             low_res_bytes = self._decode_image(row['image_512'])
             high_res_bytes = self._decode_image(row['image_1024'])
-
-            # Load images with truncation handling
             ImageFile.LOAD_TRUNCATED_IMAGES = True
             low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
             high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
-
-            # Validate image sizes
+            # Validate expected sizes
             if low_res.size != (512, 512) or high_res.size != (1024, 1024):
                 raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
-
             if self.transform:
                 low_res = self.transform(low_res)
                 high_res = self.transform(high_res)
-
             return low_res, high_res
-
         except Exception as e:
             print(f"\nError at index {idx}: {str(e)}")
             self.failed_indices.add(idx)
-            return self[(idx + 1) % len(self)]  # Return next valid sample
+            return self[(idx + 1) % len(self)]
 
-# Example transform: converting PIL images to tensors
+# Define any transformations you require (e.g., converting PIL images to tensors)
 transform = transforms.Compose([
     transforms.ToTensor(),
 ])
-
-# Replace with your actual pretrained model path
-pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
-
-# Load the model using the AIIA.load class method (the implementation copied in your query)
-model = AIIABase.load(pretrained_model_path)
+# Load the base AIIABase model and wrap it with the Upsampler.
+pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
+base_model = AIIABase.load(pretrained_model_path)
+model = Upsampler(base_model)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = model.to(device)
 
-from torch import nn, optim
-from torch.utils.data import DataLoader
-# Create your dataset and dataloader
-dataset = UpscaleDataset(["/root/training_data/vision-dataset/image_upscaler.parquet", "/root/training_data/vision-dataset/image_vec_upscaler.parquet"], transform=transform)
+# Create the dataset and dataloader.
+dataset = UpscaleDataset([
+    "/root/training_data/vision-dataset/image_upscaler.parquet",
+    "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
+], transform=transform)
 data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
 
-# Define a loss function and optimizer
+# Define loss function and optimizer.
 criterion = nn.MSELoss()
 optimizer = optim.Adam(model.parameters(), lr=1e-4)
 
 num_epochs = 10
-model.train()  # Set model in training mode
-
+model.train()
+# Prepare a CSV file for logging training loss.
 csv_file = 'losses.csv'
-
-# Create or open the CSV file and write the header if it doesn't exist
 with open(csv_file, mode='a', newline='') as file:
     writer = csv.writer(file)
-    # Write the header only if the file is empty
     if file.tell() == 0:
         writer.writerow(['Epoch', 'Train Loss'])
 
-# Create a gradient scaler (for scaling gradients when using AMP)
+# Initialize automatic mixed precision scaler and EarlyStopping.
 scaler = GradScaler()
+early_stopping = EarlyStopping(patience=3, min_delta=0.001)
 
+# Training loop with early stopping.
 for epoch in range(num_epochs):
     epoch_loss = 0.0
-    data_loader_with_progress = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
-    for low_res, high_res in data_loader_with_progress:
+    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
+    print(f"Epoch: {epoch}")
+    for low_res, high_res in progress_bar:
         low_res = low_res.to(device, non_blocking=True)
         high_res = high_res.to(device, non_blocking=True)
-
+
         optimizer.zero_grad()
-        # Use automatic mixed precision context
-        with autocast(device_type="cuda"):
+        # Use automatic mixed precision to speed up training on supported hardware.
+        with autocast(device_type=device.type):
             outputs = model(low_res)
             loss = criterion(outputs, high_res)
 
         scaler.scale(loss).backward()
         scaler.step(optimizer)
         scaler.update()
-
+
         epoch_loss += loss.item()
+        progress_bar.set_postfix({'loss': loss.item()})
+
     print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")
-
-    # Append the training loss to the CSV file
+
+    # Record the loss in the CSV log.
     with open(csv_file, mode='a', newline='') as file:
         writer = csv.writer(file)
         writer.writerow([epoch + 1, epoch_loss])
 
-# Optionally, save the finetuned model to a new directory
+    # Check early stopping criteria.
+    if early_stopping(epoch_loss):
+        print(f"Early stopping triggered at epoch {epoch + 1} with loss {epoch_loss}")
+        break
+
+# Optionally, save the finetuned model using your library's save method.
 finetuned_model_path = "aiuNN"
 model.save(finetuned_model_path)
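A quick sanity check of the EarlyStopping helper introduced by this patch (a minimal sketch, not part of the diff; it assumes the class is defined as in src/aiunn/finetune.py, and the patience/min_delta values here are illustrative):

    # best_loss starts at inf; the counter only grows when an epoch fails
    # to improve the best loss by at least min_delta.
    stopper = EarlyStopping(patience=2, min_delta=0.01)
    for loss in [1.0, 0.5, 0.499, 0.498]:
        if stopper(loss):
            # Returns True after two consecutive epochs without a >= 0.01 improvement.
            print("early stop")
            break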