aiuNN/src/aiunn/finetune.py

181 lines
6.4 KiB
Python

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import io
import csv
import base64
from PIL import Image, ImageFile
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
from aiia import AIIABase
from upsampler import Upsampler
# Define a simple EarlyStopping class to monitor the epoch loss.
class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Call the instance once per epoch with that epoch's loss. A loss only counts
    as an improvement when it is lower than the best loss seen so far by more
    than ``min_delta``. After ``patience`` consecutive non-improving calls the
    ``early_stop`` flag is raised; once raised it is never cleared.
    """

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience  # Consecutive non-improving epochs tolerated before stopping.
        self.min_delta = min_delta  # Margin by which a loss must beat best_loss to count as improvement.
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        """Record ``epoch_loss`` and return whether training should stop."""
        threshold = self.best_loss - self.min_delta
        # Written as "not <" (rather than ">=") so NaN losses count as non-improving.
        if not epoch_loss < threshold:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return self.early_stop
        # Significant improvement: remember it and restart the patience window.
        self.best_loss = epoch_loss
        self.counter = 0
        return self.early_stop
# UpscaleDataset to load and preprocess your data.
class UpscaleDataset(Dataset):
    """Paired low-/high-resolution images read from parquet files.

    Each parquet file must provide ``image_512`` and ``image_1024`` columns whose
    cells are either encoded image ``bytes`` or base64-encoded ``str``. A sample
    is valid when the low-res image is 512x512 and the high-res image is
    1024x1024 after RGB conversion. Samples that fail to load are remembered in
    ``failed_indices`` and replaced by the next loadable index.
    """

    def __init__(self, parquet_files: list, transform=None, max_rows_per_file=500):
        """Read and validate the image columns of every parquet file.

        Args:
            parquet_files: Paths of the parquet files to read, in order.
            transform: Optional callable applied to each PIL image of a pair.
            max_rows_per_file: Keep only the first N rows of each file (default
                500, the previously hard-coded limit); ``None`` keeps every row.

        Raises:
            ValueError: If any image cell is neither ``bytes`` nor ``str``.
        """
        frames = []
        for parquet_file in parquet_files:
            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024'])
            if max_rows_per_file is not None:
                df = df.head(max_rows_per_file)
            frames.append(df)
        # Concatenate once: concatenating inside the loop recopies every row
        # loaded so far (quadratic), and concatenating onto an empty DataFrame
        # is deprecated in recent pandas.
        combined_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
        # _validate_row returns each row unchanged, so apply() only validates.
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        """Return ``row`` unchanged; raise ValueError if an image cell has the wrong type."""
        for col in ['image_512', 'image_1024']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        """Return encoded image bytes from a raw ``bytes`` or base64 ``str`` cell.

        Raises:
            RuntimeError: If ``data`` has an unsupported type or is not valid
                base64; the original error is chained as ``__cause__``.
        """
        try:
            if isinstance(data, str):
                return base64.b64decode(data)
            if isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except (ValueError, TypeError) as e:
            # binascii.Error (bad base64) is a ValueError subclass.
            raise RuntimeError(f"Decoding failed: {str(e)}") from e

    def __len__(self):
        return len(self.df)

    def _load_pair(self, idx):
        """Decode, size-check and transform the image pair stored at row ``idx``."""
        row = self.df.iloc[idx]
        low_res_bytes = self._decode_image(row['image_512'])
        high_res_bytes = self._decode_image(row['image_1024'])
        # Let PIL decode images whose byte streams are cut short instead of failing.
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        with Image.open(io.BytesIO(low_res_bytes)) as img:
            low_res = img.convert('RGB')
        with Image.open(io.BytesIO(high_res_bytes)) as img:
            high_res = img.convert('RGB')
        if low_res.size != (512, 512) or high_res.size != (1024, 1024):
            raise ValueError(f"Size mismatch: LowRes={low_res.size}, HighRes={high_res.size}")
        if self.transform:
            # The transform is applied to each image independently; random
            # augmentations here would desynchronise the pair.
            low_res = self.transform(low_res)
            high_res = self.transform(high_res)
        return low_res, high_res

    def __getitem__(self, idx):
        """Return ``(low_res, high_res)`` for ``idx`` or the next loadable index.

        Indices that fail are printed, recorded in ``failed_indices`` and skipped
        on later calls. The search wraps around the dataset at most once, so a
        dataset where nothing loads raises instead of recursing forever.

        Raises:
            RuntimeError: If no sample in the dataset can be loaded.
        """
        n = len(self)
        for offset in range(n):
            candidate = (idx + offset) % n
            if candidate in self.failed_indices:
                continue
            try:
                return self._load_pair(candidate)
            except Exception as e:
                # Best-effort: report and remember the bad sample, then try the
                # next index rather than aborting the training run.
                print(f"\nError at index {candidate}: {str(e)}")
                self.failed_indices.add(candidate)
        raise RuntimeError(f"No loadable samples: all {n} indices have failed")
# Preprocessing: ToTensor turns each PIL image into a float tensor. The same
# transform object is applied to the low-res and the high-res image.
transform = transforms.Compose([transforms.ToTensor()])

# Load the pretrained AIIABase model and wrap it with the Upsampler.
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
base_model = AIIABase.load(pretrained_model_path)
# Training is pinned to the CPU. To use a GPU when one is present, switch to
# torch.device("cuda" if torch.cuda.is_available() else "cpu").
device = torch.device("cpu")
# Build the Upsampler and move it to the device in channels_last memory format.
model = Upsampler(base_model).to(device, memory_format=torch.channels_last)

# When True, the training loop runs the forward pass through activation checkpointing.
use_checkpointing = True

# Paired-image dataset built from both parquet files.
dataset = UpscaleDataset(
    parquet_files=[
        "/root/training_data/vision-dataset/image_upscaler.parquet",
        "/root/training_data/vision-dataset/image_vec_upscaler.parquet",
    ],
    transform=transform,
)
# One pair per batch, reshuffled every epoch; loading runs in the main process
# (DataLoader's default num_workers=0) — raise num_workers to parallelise decoding.
data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)
# Define loss function and optimizer.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 10
model.train()
# Prepare a CSV file for logging training loss. The file is appended to across
# runs, so the header is written only when the file is new/empty.
csv_file = 'losses.csv'
with open(csv_file, mode='a', newline='') as file:
    writer = csv.writer(file)
    if file.tell() == 0:
        writer.writerow(['Epoch', 'Train Loss'])
# Loss scaling only matters for float16 autocast on CUDA; CPU autocast uses
# bfloat16. Binding the scaler to the training device and enabling it only on
# CUDA makes scale()/step()/update() a plain backward + optimizer.step() on CPU,
# instead of a default-"cuda" scaler being paired with CPU tensors.
scaler = GradScaler(device=device.type, enabled=device.type == "cuda")
early_stopping = EarlyStopping(patience=3, min_delta=0.001)
# Training loop with early stopping.
for epoch in range(num_epochs):
    epoch_loss = 0.0  # Sum of per-batch losses for this epoch.
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
    print(f"Epoch: {epoch + 1}")
    for low_res, high_res in progress_bar:
        # Move the batch to the training device; the model input uses
        # channels_last to match the model's memory format.
        low_res = low_res.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        high_res = high_res.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast(device_type=device.type):
            if use_checkpointing:
                # Non-reentrant checkpointing records parameter gradients even
                # when the input does not require grad, so the input no longer
                # has to be marked requires_grad (which also computed an unused
                # gradient w.r.t. the image). The reentrant default is deprecated.
                # NOTE(review): checkpointing the whole model as one segment
                # likely saves little peak memory, because backward recomputes the
                # full forward at once; consider checkpointing sub-modules instead.
                outputs = checkpoint(model, low_res, use_reentrant=False)
            else:
                outputs = model(low_res)
            loss = criterion(outputs, high_res)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress_bar.set_postfix({'loss': batch_loss})
        # Drop references so the tensors can be freed before the next batch.
        del low_res, high_res, outputs, loss
    # Run garbage collection and clear the CUDA cache after each epoch
    # (empty_cache is a no-op when CUDA has not been initialised).
    gc.collect()
    torch.cuda.empty_cache()
    # Log and monitor the mean per-batch loss, so neither the logged value nor
    # the early-stopping min_delta threshold scales with the dataset size.
    mean_loss = epoch_loss / max(len(data_loader), 1)
    print(f"Epoch {epoch + 1}, Loss: {mean_loss}")
    # Record the loss in the CSV log.
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([epoch + 1, mean_loss])
    if early_stopping(mean_loss):
        print(f"Early stopping triggered at epoch {epoch + 1} with loss {mean_loss}")
        break
# Save the fine-tuned model. These are the weights after the last completed
# epoch, not necessarily the epoch with the best loss.
finetuned_model_path = "aiuNN"
model.save(finetuned_model_path)