updated finetuning script to work with the Upsampler and added early stopping
parent 2aba93caea
commit 736886021c

@@ -1,40 +1,63 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
 import pandas as pd
 import io
-from PIL import Image, ImageFile
-from torch.utils.data import Dataset
-from torchvision import transforms
-from aiia import AIIABase
 import csv
-from tqdm import tqdm
 import base64
+from PIL import Image, ImageFile
 from torch.amp import autocast, GradScaler
-import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms
+from tqdm import tqdm
+
+from aiia import AIIABase
+from upsampler import Upsampler
+
+
+# Define a simple EarlyStopping class to monitor the epoch loss.
+class EarlyStopping:
+    def __init__(self, patience=3, min_delta=0.001):
+        self.patience = patience  # Number of epochs with no significant improvement before stopping.
+        self.min_delta = min_delta  # Minimum change in loss required to count as an improvement.
+        self.best_loss = float('inf')
+        self.counter = 0
+        self.early_stop = False
+
+    def __call__(self, epoch_loss):
+        # If current loss is lower than the best loss minus min_delta, update best loss and reset counter.
+        if epoch_loss < self.best_loss - self.min_delta:
+            self.best_loss = epoch_loss
+            self.counter = 0
+        else:
+            # No significant improvement: increment counter.
+            self.counter += 1
+            if self.counter >= self.patience:
+                self.early_stop = True
+        return self.early_stop
+
+
+# UpscaleDataset to load and preprocess your data.
 class UpscaleDataset(Dataset):
     def __init__(self, parquet_files: list, transform=None):
         combined_df = pd.DataFrame()
         for parquet_file in parquet_files:
-            # Load data with chunking for memory efficiency
+            # Load data with head() to limit rows for memory efficiency.
             df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(1250)
             combined_df = pd.concat([combined_df, df], ignore_index=True)

-        # Validate data format
+        # Validate that each row has proper image formats.
         self.df = combined_df.apply(self._validate_row, axis=1)
         self.transform = transform
         self.failed_indices = set()

     def _validate_row(self, row):
-        """Ensure both images exist and have correct dimensions"""
         for col in ['image_512', 'image_1024']:
             if not isinstance(row[col], (bytes, str)):
                 raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
         return row

     def _decode_image(self, data):
-        """Universal decoder handling both base64 strings and bytes"""
         try:
             if isinstance(data, str):
-                # Handle base64 encoded strings
                 return base64.b64decode(data)
             elif isinstance(data, bytes):
                 return data
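A quick sanity check of the EarlyStopping helper introduced in this commit, run against a made-up loss sequence (the numbers below are illustrative only and are not taken from the training run):

stopper = EarlyStopping(patience=3, min_delta=0.001)
# Two real improvements, then a plateau; the third non-improving epoch trips the stop flag.
for epoch, loss in enumerate([0.90, 0.80, 0.75, 0.7495, 0.7496, 0.7494], start=1):
    if stopper(loss):
        print(f"would stop after epoch {epoch}")  # prints: would stop after epoch 6
        break

With patience=3 and min_delta=0.001, training only stops once three consecutive epochs fail to beat the best loss by at least 0.001, which guards against stopping on ordinary noise in the epoch loss.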
@@ -46,87 +69,78 @@ class UpscaleDataset(Dataset):
         return len(self.df)

     def __getitem__(self, idx):
+        # Skip indices that have previously failed.
         if idx in self.failed_indices:
-            return self[(idx + 1) % len(self)] # Skip failed indices
+            return self[(idx + 1) % len(self)]

         try:
             row = self.df.iloc[idx]

-            # Decode both images
             low_res_bytes = self._decode_image(row['image_512'])
             high_res_bytes = self._decode_image(row['image_1024'])

-            # Load images with truncation handling
             ImageFile.LOAD_TRUNCATED_IMAGES = True
             low_res = Image.open(io.BytesIO(low_res_bytes)).convert('RGB')
             high_res = Image.open(io.BytesIO(high_res_bytes)).convert('RGB')
-            # Validate image sizes
+            # Validate expected sizes
            if low_res.size != (512, 512) or high_res.size != (1024, 1024):
                 raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")

             if self.transform:
                 low_res = self.transform(low_res)
                 high_res = self.transform(high_res)

             return low_res, high_res

         except Exception as e:
             print(f"\nError at index {idx}: {str(e)}")
             self.failed_indices.add(idx)
-            return self[(idx + 1) % len(self)] # Return next valid sample
+            return self[(idx + 1) % len(self)]

-# Example transform: converting PIL images to tensors
+# Define any transformations you require (e.g., converting PIL images to tensors)
 transform = transforms.Compose([
     transforms.ToTensor(),
 ])

-# Replace with your actual pretrained model path
+# Load the base AIIABase model and wrap it with the Upsampler.
 pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
-# Load the model using the AIIA.load class method (the implementation copied in your query)
-model = AIIABase.load(pretrained_model_path)
+base_model = AIIABase.load(pretrained_model_path)
+model = Upsampler(base_model)

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = model.to(device)
-from torch import nn, optim
-from torch.utils.data import DataLoader

-# Create your dataset and dataloader
+# Create the dataset and dataloader.
-dataset = UpscaleDataset(["/root/training_data/vision-dataset/image_upscaler.parquet", "/root/training_data/vision-dataset/image_vec_upscaler.parquet"], transform=transform)
+dataset = UpscaleDataset([
+    "/root/training_data/vision-dataset/image_upscaler.parquet",
+    "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
+], transform=transform)
 data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

-# Define a loss function and optimizer
+# Define loss function and optimizer.
 criterion = nn.MSELoss()
 optimizer = optim.Adam(model.parameters(), lr=1e-4)

 num_epochs = 10
-model.train() # Set model in training mode
+model.train()

+# Prepare a CSV file for logging training loss.
 csv_file = 'losses.csv'

-# Create or open the CSV file and write the header if it doesn't exist
 with open(csv_file, mode='a', newline='') as file:
     writer = csv.writer(file)
-    # Write the header only if the file is empty
     if file.tell() == 0:
         writer.writerow(['Epoch', 'Train Loss'])

-# Create a gradient scaler (for scaling gradients when using AMP)
+# Initialize automatic mixed precision scaler and EarlyStopping.
 scaler = GradScaler()
+early_stopping = EarlyStopping(patience=3, min_delta=0.001)

+# Training loop with early stopping.
 for epoch in range(num_epochs):
     epoch_loss = 0.0
-    data_loader_with_progress = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
-    for low_res, high_res in data_loader_with_progress:
+    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
+    print(f"Epoch: {epoch}")
+    for low_res, high_res in progress_bar:
         low_res = low_res.to(device, non_blocking=True)
         high_res = high_res.to(device, non_blocking=True)

         optimizer.zero_grad()

-        # Use automatic mixed precision context
-        with autocast(device_type="cuda"):
+        # Use automatic mixed precision to speed up training on supported hardware.
+        with autocast(device_type=device.type):
             outputs = model(low_res)
             loss = criterion(outputs, high_res)

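The unchanged context between this hunk and the next is not shown in the diff. With torch.amp, the backward pass and optimizer step conventionally go through the GradScaler, roughly as in the sketch below (a standard pattern, not lines taken from this commit); the scaler.update() call at the top of the next hunk completes it:

        # Scale the loss so small fp16 gradients do not underflow, then backpropagate.
        scaler.scale(loss).backward()
        # Unscale gradients and apply the optimizer step (skipped if inf/NaN gradients are detected).
        scaler.step(optimizer)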
@@ -135,13 +149,20 @@ for epoch in range(num_epochs):
         scaler.update()

         epoch_loss += loss.item()
+        progress_bar.set_postfix({'loss': loss.item()})

     print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")

-    # Append the training loss to the CSV file
+    # Record the loss in the CSV log.
     with open(csv_file, mode='a', newline='') as file:
         writer = csv.writer(file)
         writer.writerow([epoch + 1, epoch_loss])

-# Optionally, save the finetuned model to a new directory
+    # Check early stopping criteria.
+    if early_stopping(epoch_loss):
+        print(f"Early stopping triggered at epoch {epoch + 1} with loss {epoch_loss}")
+        break
+
+# Optionally, save the finetuned model using your library's save method.
 finetuned_model_path = "aiuNN"
 model.save(finetuned_model_path)
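For reference, a minimal inference sketch reusing the objects defined in the script above (model, transform, device); the file names are hypothetical, and it assumes the upsampler returns a normalized RGB tensor shaped like the 1024x1024 training targets:

model.eval()
low_res_img = Image.open("example_512.png").convert('RGB')      # hypothetical 512x512 input
input_tensor = transform(low_res_img).unsqueeze(0).to(device)   # add a batch dimension
with torch.no_grad():
    upscaled = model(input_tensor)
# Convert the output tensor back to a PIL image and save it.
transforms.ToPILImage()(upscaled.squeeze(0).clamp(0, 1).cpu()).save("example_1024.png")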