fixed paths

This commit is contained in:
Falko Victor Habel 2025-02-25 15:47:21 +01:00
parent 09f196294c
commit fcebc103b8
9 changed files with 308 additions and 209 deletions

View File

@@ -1,6 +1,5 @@
from .finetune.trainer import aiuNNTrainer
from .upsampler.aiunn import aiuNN
from .upsampler.config import aiuNNConfig
from .finetune import *
from .inference import UpScaler
__version__ = "0.1.0"
__version__ = "0.1.1"

View File

@@ -1,199 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import io
import csv
import base64
from PIL import Image, ImageFile
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
from aiia import AIIABase
from upsampler import Upsampler

# Define a simple EarlyStopping class to monitor the epoch loss.
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience  # Number of epochs with no significant improvement before stopping.
        self.min_delta = min_delta  # Minimum change in loss required to count as an improvement.
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

# UpscaleDataset to load and preprocess your data.
class UpscaleDataset(Dataset):
    def __init__(self, parquet_files: list, transform=None):
        combined_df = pd.DataFrame()
        for parquet_file in parquet_files:
            # Load a subset (head(2500)) from each parquet file.
            df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(2500)
            combined_df = pd.concat([combined_df, df], ignore_index=True)

        # Validate rows (ensuring each value is bytes or str).
        self.df = combined_df.apply(self._validate_row, axis=1)
        self.transform = transform
        self.failed_indices = set()

    def _validate_row(self, row):
        for col in ['image_512', 'image_1024']:
            if not isinstance(row[col], (bytes, str)):
                raise ValueError(f"Invalid data type in column {col}: {type(row[col])}")
        return row

    def _decode_image(self, data):
        try:
            if isinstance(data, str):
                return base64.b64decode(data)
            elif isinstance(data, bytes):
                return data
            raise ValueError(f"Unsupported data type: {type(data)}")
        except Exception as e:
            raise RuntimeError(f"Decoding failed: {str(e)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # If a previous call failed for this index, use a different index.
        if idx in self.failed_indices:
            return self[(idx + 1) % len(self)]
        try:
            row = self.df.iloc[idx]
            low_res_bytes = self._decode_image(row['image_512'])
            high_res_bytes = self._decode_image(row['image_1024'])
            ImageFile.LOAD_TRUNCATED_IMAGES = True

            # Open the image bytes with Pillow and convert to RGBA first.
            low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA')
            high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA')

            # Create new RGB images with a black background.
            low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0))
            high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0))

            # Composite the original images over the black background.
            low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3])
            high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3])

            # Now we have true 3-channel RGB images with transparent areas converted to black.
            low_res = low_res_rgb
            high_res = high_res_rgb

            # Resize the images to reduce VRAM usage.
            low_res = low_res.resize((384, 384), Image.LANCZOS)
            high_res = high_res.resize((768, 768), Image.LANCZOS)

            # If a transform is provided (e.g. conversion to Tensor), apply it.
            if self.transform:
                low_res = self.transform(low_res)
                high_res = self.transform(high_res)
            return low_res, high_res
        except Exception as e:
            print(f"\nError at index {idx}: {str(e)}")
            self.failed_indices.add(idx)
            return self[(idx + 1) % len(self)]

# Define any transformations you require.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load the base AIIABase model and wrap it with the Upsampler.
pretrained_model_path = "/root/vision/AIIA/AIIA-base-512"
base_model = AIIABase.load(pretrained_model_path, precision="bf16")
model = Upsampler(base_model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to the device using the channels_last memory format.
model = model.to(device, memory_format=torch.channels_last)

# Optional: flag to enable gradient checkpointing.
use_checkpointing = True

# Create the dataset and dataloader.
dataset = UpscaleDataset([
    "/root/training_data/vision-dataset/image_upscaler.parquet",
    "/root/training_data/vision-dataset/image_vec_upscaler.parquet"
], transform=transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)  # Consider adjusting num_workers if needed.

# Define the loss function and optimizer.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 10
model.train()

# Prepare a CSV file for logging the training loss.
csv_file = 'losses.csv'
with open(csv_file, mode='a', newline='') as file:
    writer = csv.writer(file)
    if file.tell() == 0:
        writer.writerow(['Epoch', 'Train Loss'])

scaler = GradScaler()
early_stopping = EarlyStopping(patience=3, min_delta=0.001)

# Training loop with early stopping.
for epoch in range(num_epochs):
    epoch_loss = 0.0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}")
    print(f"Epoch: {epoch + 1}")
    for low_res, high_res in progress_bar:
        # Move data to the GPU with channels_last format where possible.
        low_res = low_res.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        high_res = high_res.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast(device_type=device.type):
            if use_checkpointing:
                # Ensure the input tensor requires gradients so that checkpointing records the computation graph.
                low_res.requires_grad_()
                outputs = checkpoint(model, low_res)
            else:
                outputs = model(low_res)
            loss = criterion(outputs, high_res)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

        # Optionally delete variables to free memory.
        del low_res, high_res, outputs, loss

    # Perform garbage collection and clear the GPU cache after each epoch.
    gc.collect()
    torch.cuda.empty_cache()

    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")

    # Record the loss in the CSV log.
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([epoch + 1, epoch_loss])

    if early_stopping(epoch_loss):
        print(f"Early stopping triggered at epoch {epoch + 1} with loss {epoch_loss}")
        break

# Optionally save the fine-tuned model.
finetuned_model_path = "aiuNN"
model.save(finetuned_model_path)

View File

@@ -0,0 +1,3 @@
from .trainer import aiuNNTrainer
__all__ = ["aiuNNTrainer"]

View File

@@ -0,0 +1,289 @@
import torch
import torch.nn as nn
import torch.optim as optim
import os
import csv
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
import time
import shutil


class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        # patience: number of epochs with no significant improvement before stopping.
        # min_delta: minimum change in loss required to count as an improvement.
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, epoch_loss):
        if epoch_loss < self.best_loss - self.min_delta:
            self.best_loss = epoch_loss
            self.counter = 0
            return True  # Improved
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False  # Not improved


class aiuNNTrainer:
    def __init__(self, upscaler_model, dataset_class=None):
        """
        Initialize the upscaler trainer.

        Args:
            upscaler_model: The model to fine-tune
            dataset_class: The dataset class to use for loading data (optional)
        """
        self.model = upscaler_model
        self.dataset_class = dataset_class
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device, memory_format=torch.channels_last)
        self.criterion = nn.MSELoss()
        self.optimizer = None
        self.scaler = GradScaler()
        self.best_loss = float('inf')
        self.use_checkpointing = True
        self.data_loader = None
        self.validation_loader = None
        self.log_dir = None

    def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2,
                  custom_train_dataset=None, custom_val_dataset=None):
        """
        Load data using either a custom dataset instance or the dataset class provided at initialization.

        Args:
            dataset_params (dict/list): Parameters to pass to the dataset class constructor
            batch_size (int): Batch size for training
            validation_split (float): Proportion of data to use for validation
            custom_train_dataset: A pre-instantiated dataset to use for training (optional)
            custom_val_dataset: A pre-instantiated dataset to use for validation (optional)
        """
        # If custom datasets are provided directly, use them.
        if custom_train_dataset is not None:
            train_dataset = custom_train_dataset
            val_dataset = custom_val_dataset if custom_val_dataset is not None else None
        else:
            # Otherwise instantiate the dataset using the class and parameters.
            if self.dataset_class is None:
                raise ValueError("No dataset class provided. Either provide a dataset class at initialization or custom datasets.")

            # Create the dataset instance.
            dataset = self.dataset_class(**dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params})

            # Split into train and validation sets.
            dataset_size = len(dataset)
            val_size = int(validation_split * dataset_size)
            train_size = dataset_size - val_size
            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )

        # Create the data loaders.
        self.data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True
        )

        if val_dataset is not None:
            self.validation_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                pin_memory=True
            )
            print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples")
        else:
            self.validation_loader = None
            print(f"Loaded {len(train_dataset)} training samples (no validation set)")

        return self.data_loader, self.validation_loader

    def _setup_logging(self, output_path):
        """Set up the directory structure for logging and model checkpoints."""
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        self.log_dir = os.path.join(output_path, f"training_run_{timestamp}")
        os.makedirs(self.log_dir, exist_ok=True)

        # Create the checkpoint directory.
        self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Set up CSV logging.
        self.csv_path = os.path.join(self.log_dir, 'training_log.csv')
        with open(self.csv_path, mode='w', newline='') as file:
            writer = csv.writer(file)
            if self.validation_loader:
                writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved'])
            else:
                writer.writerow(['Epoch', 'Train Loss', 'Improved'])

    def _evaluate(self):
        """Evaluate the model on the validation data."""
        if self.validation_loader is None:
            return 0.0

        self.model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)
                with autocast(device_type=self.device.type):
                    outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)
                val_loss += loss.item()
                del low_res, high_res, outputs, loss
        self.model.train()
        return val_loss

    def _save_checkpoint(self, epoch, is_best=False):
        """Save a model checkpoint."""
        checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt")
        best_model_path = os.path.join(self.log_dir, "best_model")

        # Save the model checkpoint.
        self.model.save(checkpoint_path)

        # If this is the best model so far, copy it to best_model.
        if is_best:
            if os.path.exists(best_model_path):
                shutil.rmtree(best_model_path)
            self.model.save(best_model_path)
            print(f"Saved new best model with loss: {self.best_loss:.6f}")

    def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
        """
        Finetune the upscaler model.

        Args:
            output_path (str): Directory to save models and logs
            epochs (int): Maximum number of training epochs
            lr (float): Learning rate
            patience (int): Early stopping patience
            min_delta (float): Minimum improvement for early stopping
        """
        # Check that data is loaded.
        if self.data_loader is None:
            raise ValueError("Data not loaded. Call load_data first.")

        # Set up the optimizer.
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Set up logging.
        self._setup_logging(output_path)

        # Set up early stopping.
        early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)

        # Training loop.
        self.model.train()
        for epoch in range(epochs):
            # Training phase.
            epoch_loss = 0.0
            progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}")

            for low_res, high_res in progress_bar:
                # Move data to the GPU with channels_last format where possible.
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                self.optimizer.zero_grad()
                with autocast(device_type=self.device.type):
                    if self.use_checkpointing:
                        # Ensure the input tensor requires gradients so that checkpointing records the computation graph.
                        low_res.requires_grad_()
                        outputs = checkpoint(self.model, low_res)
                    else:
                        outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_postfix({'loss': loss.item()})

                # Optionally delete variables to free memory.
                del low_res, high_res, outputs, loss

            # Calculate the average epoch loss.
            avg_train_loss = epoch_loss / len(self.data_loader)

            # Validation phase (if a validation loader exists).
            if self.validation_loader:
                val_loss = self._evaluate() / len(self.validation_loader)
                is_improved = val_loss < self.best_loss
                if is_improved:
                    self.best_loss = val_loss

                # Log results.
                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"])
            else:
                # If there is no validation set, use the training loss for improvement tracking.
                is_improved = avg_train_loss < self.best_loss
                if is_improved:
                    self.best_loss = avg_train_loss

                # Log results.
                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")
                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"])

            # Save a checkpoint.
            self._save_checkpoint(epoch + 1, is_best=is_improved)

            # Perform garbage collection and clear the GPU cache after each epoch.
            gc.collect()
            torch.cuda.empty_cache()
            # Check early stopping: __call__ reports whether the loss improved,
            # so consult the early_stop flag to decide when to stop.
            early_stopping(val_loss if self.validation_loader else avg_train_loss)
            if early_stopping.early_stop:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break
        return self.best_loss

    def save(self, output_path=None):
        """
        Save the best model to the specified path.

        Args:
            output_path (str, optional): Path to save the model. If None, uses the best model from training.
        """
        if output_path is None and self.log_dir is not None:
            best_model_path = os.path.join(self.log_dir, "best_model")
            if os.path.exists(best_model_path):
                print(f"Best model already saved at {best_model_path}")
                return best_model_path
            else:
                output_path = os.path.join(self.log_dir, "final_model")

        if output_path is None:
            raise ValueError("No output path specified and no training has been done yet.")

        self.model.save(output_path)
        print(f"Model saved to {output_path}")
        return output_path
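
A minimal usage sketch for the new aiuNNTrainer, based on the docstrings above. The package import path, the pretrained-model path, the parquet path, and the UpscaleDataset class are placeholders and assumptions, not part of this commit:

from aiia import AIIABase
from aiunn import aiuNN, aiuNNTrainer            # assumed top-level package name
from my_datasets import UpscaleDataset           # hypothetical Dataset returning (low_res, high_res) tensors

# Wrap a pretrained AIIABase backbone with the aiuNN upsampler, as done elsewhere in this commit.
base_model = AIIABase.load("/path/to/AIIA-base-512", precision="bf16")
model = aiuNN(base_model)

# Option 1: let the trainer build the dataset from a class plus parameters.
trainer = aiuNNTrainer(model, dataset_class=UpscaleDataset)
trainer.load_data(
    dataset_params=["/path/to/image_upscaler.parquet"],  # a list is forwarded as parquet_files
    batch_size=1,
    validation_split=0.2,
)

# Option 2 (alternative): pass pre-instantiated datasets directly.
# trainer.load_data(custom_train_dataset=train_ds, custom_val_dataset=val_ds, batch_size=1)

# Fine-tune; checkpoints and training_log.csv are written to a timestamped run directory under output_path.
best_loss = trainer.finetune(output_path="runs", epochs=10, lr=1e-4, patience=3, min_delta=0.001)

# Persist the best (or final) model.
trainer.save()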

View File

View File

@@ -31,6 +31,8 @@ class Upscaler(nn.Module):
    def forward(self, x):
        features = self.base_model(x)
        return self.last_transform(features)

class ImageUpscaler:
    def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = torch.device(device)

View File

@@ -0,0 +1,5 @@
from .aiunn import aiuNN
from .config import aiuNNConfig
__all__ = ["aiuNN", "aiuNNConfig"]

View File

@@ -3,17 +3,17 @@ import torch
import torch.nn as nn
import warnings
from aiia import AIIA, AIIAConfig, AIIABase
from config import UpsamplerConfig
from .config import aiuNNConfig
import warnings

class Upsampler(AIIA):
class aiuNN(AIIA):
    def __init__(self, base_model: AIIABase):
        super().__init__(base_model.config)
        self.base_model = base_model

        # Pass the unified base configuration using the new parameter.
        self.config = UpsamplerConfig(base_config=base_model.config)
        self.config = aiuNNConfig(base_config=base_model.config)

        self.upsample = nn.Upsample(
            scale_factor=self.config.upsample_scale,
@@ -72,11 +72,11 @@ if __name__ == "__main__":
    config = AIIAConfig()
    base_model = AIIABase(config)

    # Instantiate Upsampler from the base model (works correctly).
    upsampler = Upsampler(base_model)
    upsampler = aiuNN(base_model)

    # Save the model (both configuration and weights).
    upsampler.save("hehe")

    # Now load using the overridden load method; this will load the complete model.
    upsampler_loaded = Upsampler.load("hehe", precision="bf16")
    upsampler_loaded = aiuNN.load("hehe", precision="bf16")
    print("Updated configuration:", upsampler_loaded.config.__dict__)