From fcebc103b8a25eb1e0cc68c73ce81edf9d60e7a4 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Tue, 25 Feb 2025 15:47:21 +0100 Subject: [PATCH] fixed paths --- src/aiunn/__init__.py | 9 +- src/aiunn/finetune.py | 199 ------------ src/aiunn/finetune/__init__.py | 3 + src/aiunn/finetune/trainer.py | 289 ++++++++++++++++++ src/aiunn/inference/__init__.py | 0 src/aiunn/{ => inference}/inference.py | 2 + src/aiunn/upsampler/__init__.py | 5 + .../{upsampler.py => upsampler/aiunn.py} | 10 +- src/aiunn/{ => upsampler}/config.py | 0 9 files changed, 308 insertions(+), 209 deletions(-) delete mode 100644 src/aiunn/finetune.py create mode 100644 src/aiunn/finetune/__init__.py create mode 100644 src/aiunn/finetune/trainer.py create mode 100644 src/aiunn/inference/__init__.py rename src/aiunn/{ => inference}/inference.py (99%) create mode 100644 src/aiunn/upsampler/__init__.py rename src/aiunn/{upsampler.py => upsampler/aiunn.py} (92%) rename src/aiunn/{ => upsampler}/config.py (100%) diff --git a/src/aiunn/__init__.py b/src/aiunn/__init__.py index a8013f3..2e6f021 100644 --- a/src/aiunn/__init__.py +++ b/src/aiunn/__init__.py @@ -1,6 +1,5 @@ +from .finetune.trainer import aiuNNTrainer +from .upsampler.aiunn import aiuNN +from .upsampler.config import aiuNNConfig -from .finetune import * -from .inference import UpScaler - -__version__ = "0.1.0" - +__version__ = "0.1.1" \ No newline at end of file diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py deleted file mode 100644 index 508a4e8..0000000 --- a/src/aiunn/finetune.py +++ /dev/null @@ -1,199 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim -import pandas as pd -import io -import csv -import base64 -from PIL import Image, ImageFile -from torch.amp import autocast, GradScaler -from torch.utils.data import Dataset, DataLoader -from torchvision import transforms -from tqdm import tqdm -from torch.utils.checkpoint import checkpoint -import gc - -from aiia import AIIABase -from upsampler import Upsampler - -# Define a simple EarlyStopping class to monitor the epoch loss. -class EarlyStopping: - def __init__(self, patience=3, min_delta=0.001): - self.patience = patience # Number of epochs with no significant improvement before stopping. - self.min_delta = min_delta # Minimum change in loss required to count as an improvement. - self.best_loss = float('inf') - self.counter = 0 - self.early_stop = False - - def __call__(self, epoch_loss): - if epoch_loss < self.best_loss - self.min_delta: - self.best_loss = epoch_loss - self.counter = 0 - else: - self.counter += 1 - if self.counter >= self.patience: - self.early_stop = True - return self.early_stop - -# UpscaleDataset to load and preprocess your data. 
-class UpscaleDataset(Dataset): - def __init__(self, parquet_files: list, transform=None): - combined_df = pd.DataFrame() - for parquet_file in parquet_files: - # Load a subset (head(2500)) from each parquet file - df = pd.read_parquet(parquet_file, columns=['image_512', 'image_1024']).head(2500) - combined_df = pd.concat([combined_df, df], ignore_index=True) - - # Validate rows (ensuring each value is bytes or str) - self.df = combined_df.apply(self._validate_row, axis=1) - self.transform = transform - self.failed_indices = set() - - def _validate_row(self, row): - for col in ['image_512', 'image_1024']: - if not isinstance(row[col], (bytes, str)): - raise ValueError(f"Invalid data type in column {col}: {type(row[col])}") - return row - - def _decode_image(self, data): - try: - if isinstance(data, str): - return base64.b64decode(data) - elif isinstance(data, bytes): - return data - raise ValueError(f"Unsupported data type: {type(data)}") - except Exception as e: - raise RuntimeError(f"Decoding failed: {str(e)}") - - def __len__(self): - return len(self.df) - - def __getitem__(self, idx): - # If previous call failed for this index, use a different index. - if idx in self.failed_indices: - return self[(idx + 1) % len(self)] - try: - row = self.df.iloc[idx] - low_res_bytes = self._decode_image(row['image_512']) - high_res_bytes = self._decode_image(row['image_1024']) - ImageFile.LOAD_TRUNCATED_IMAGES = True - # Open image bytes with Pillow and convert to RGBA first - low_res_rgba = Image.open(io.BytesIO(low_res_bytes)).convert('RGBA') - high_res_rgba = Image.open(io.BytesIO(high_res_bytes)).convert('RGBA') - - # Create a new RGB image with black background - low_res_rgb = Image.new("RGB", low_res_rgba.size, (0, 0, 0)) - high_res_rgb = Image.new("RGB", high_res_rgba.size, (0, 0, 0)) - - # Composite the original image over the black background - low_res_rgb.paste(low_res_rgba, mask=low_res_rgba.split()[3]) - high_res_rgb.paste(high_res_rgba, mask=high_res_rgba.split()[3]) - - # Now we have true 3-channel RGB images with transparent areas converted to black - low_res = low_res_rgb - high_res = high_res_rgb - - # Resize the images to reduce VRAM usage. - low_res = low_res.resize((384, 384), Image.LANCZOS) - high_res = high_res.resize((768, 768), Image.LANCZOS) - # If a transform is provided (e.g. conversion to Tensor), apply it. - if self.transform: - low_res = self.transform(low_res) - high_res = self.transform(high_res) - return low_res, high_res - except Exception as e: - print(f"\nError at index {idx}: {str(e)}") - self.failed_indices.add(idx) - return self[(idx + 1) % len(self)] - -# Define any transformations you require. -transform = transforms.Compose([ - transforms.ToTensor(), -]) - -# Load the base AIIABase model and wrap it with the Upsampler. -pretrained_model_path = "/root/vision/AIIA/AIIA-base-512" -base_model = AIIABase.load(pretrained_model_path, precision="bf16") -model = Upsampler(base_model) - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# Move model to device using channels_last memory format. -model = model.to(device, memory_format=torch.channels_last) - -# Optional: flag to enable gradient checkpointing. -use_checkpointing = True - -# Create the dataset and dataloader. 
-dataset = UpscaleDataset([ - "/root/training_data/vision-dataset/image_upscaler.parquet", - "/root/training_data/vision-dataset/image_vec_upscaler.parquet" -], transform=transform) -data_loader = DataLoader(dataset, batch_size=1, shuffle=True) # Consider adjusting num_workers if needed. - -# Define loss function and optimizer. -criterion = nn.MSELoss() -optimizer = optim.Adam(model.parameters(), lr=1e-4) - -num_epochs = 10 -model.train() - -# Prepare a CSV file for logging training loss. -csv_file = 'losses.csv' -with open(csv_file, mode='a', newline='') as file: - writer = csv.writer(file) - if file.tell() == 0: - writer.writerow(['Epoch', 'Train Loss']) - -scaler = GradScaler() -early_stopping = EarlyStopping(patience=3, min_delta=0.001) - -# Training loop with early stopping. -for epoch in range(num_epochs): - epoch_loss = 0.0 - progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}") - print(f"Epoch: {epoch + 1}") - for low_res, high_res in progress_bar: - # Move data to GPU with channels_last format where possible. - low_res = low_res.to(device, non_blocking=True).to(memory_format=torch.channels_last) - high_res = high_res.to(device, non_blocking=True) - - optimizer.zero_grad() - - with autocast(device_type=device.type): - if use_checkpointing: - # Ensure the input tensor requires gradient so that checkpointing records the computation graph. - low_res.requires_grad_() - outputs = checkpoint(model, low_res) - else: - outputs = model(low_res) - loss = criterion(outputs, high_res) - - - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - - epoch_loss += loss.item() - progress_bar.set_postfix({'loss': loss.item()}) - - # Optionally delete variables to free memory. - del low_res, high_res, outputs, loss - - # Perform garbage collection and clear GPU cache after each epoch. - gc.collect() - torch.cuda.empty_cache() - - print(f"Epoch {epoch + 1}, Loss: {epoch_loss}") - - # Record the loss in the CSV log. - with open(csv_file, mode='a', newline='') as file: - writer = csv.writer(file) - writer.writerow([epoch + 1, epoch_loss]) - - if early_stopping(epoch_loss): - print(f"Early stopping triggered at epoch {epoch + 1} with loss {epoch_loss}") - break - -# Optionally save the fine-tuned model. 
-finetuned_model_path = "aiuNN"
-model.save(finetuned_model_path)
diff --git a/src/aiunn/finetune/__init__.py b/src/aiunn/finetune/__init__.py
new file mode 100644
index 0000000..33239b1
--- /dev/null
+++ b/src/aiunn/finetune/__init__.py
@@ -0,0 +1,3 @@
+from .trainer import aiuNNTrainer
+
+__all__ = ["aiuNNTrainer" ]
\ No newline at end of file
diff --git a/src/aiunn/finetune/trainer.py b/src/aiunn/finetune/trainer.py
new file mode 100644
index 0000000..01047b9
--- /dev/null
+++ b/src/aiunn/finetune/trainer.py
@@ -0,0 +1,289 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import os
+import csv
+from torch.amp import autocast, GradScaler
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from torch.utils.checkpoint import checkpoint
+import gc
+import time
+import shutil
+
+
+class EarlyStopping:
+    def __init__(self, patience=3, min_delta=0.001):
+        # Number of epochs with no significant improvement before stopping
+        # Minimum change in loss required to count as an improvement
+        self.patience = patience
+        self.min_delta = min_delta
+        self.best_loss = float('inf')
+        self.counter = 0
+        self.early_stop = False
+
+    def __call__(self, epoch_loss):
+        if epoch_loss < self.best_loss - self.min_delta:
+            self.best_loss = epoch_loss
+            self.counter = 0
+        else:
+            self.counter += 1
+            if self.counter >= self.patience:
+                self.early_stop = True
+        # Return True only once patience is exhausted, so the finetune loop breaks correctly
+        return self.early_stop
+
+class aiuNNTrainer:
+    def __init__(self, upscaler_model, dataset_class=None):
+        """
+        Initialize the upscaler trainer
+
+        Args:
+            upscaler_model: The model to fine-tune
+            dataset_class: The dataset class to use for loading data (optional)
+        """
+        self.model = upscaler_model
+        self.dataset_class = dataset_class
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = self.model.to(self.device, memory_format=torch.channels_last)
+        self.criterion = nn.MSELoss()
+        self.optimizer = None
+        self.scaler = GradScaler()
+        self.best_loss = float('inf')
+        self.use_checkpointing = True
+        self.data_loader = None
+        self.validation_loader = None
+        self.log_dir = None
+
+    def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, custom_train_dataset=None, custom_val_dataset=None):
+        """
+        Load data using either a custom dataset instance or the dataset class provided at initialization
+
+        Args:
+            dataset_params (dict/list): Parameters to pass to the dataset class constructor
+            batch_size (int): Batch size for training
+            validation_split (float): Proportion of data to use for validation
+            custom_train_dataset: A pre-instantiated dataset to use for training (optional)
+            custom_val_dataset: A pre-instantiated dataset to use for validation (optional)
+        """
+        # If custom datasets are provided directly, use them
+        if custom_train_dataset is not None:
+            train_dataset = custom_train_dataset
+            val_dataset = custom_val_dataset if custom_val_dataset is not None else None
+        else:
+            # Otherwise instantiate dataset using the class and parameters
+            if self.dataset_class is None:
+                raise ValueError("No dataset class provided.
Either provide a dataset class at initialization or custom datasets.") + + # Create dataset instance + dataset = self.dataset_class(**dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params}) + + # Split into train and validation sets + dataset_size = len(dataset) + val_size = int(validation_split * dataset_size) + train_size = dataset_size - val_size + + train_dataset, val_dataset = torch.utils.data.random_split( + dataset, [train_size, val_size] + ) + + # Create data loaders + self.data_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + pin_memory=True + ) + + if val_dataset is not None: + self.validation_loader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + pin_memory=True + ) + print(f"Loaded {len(train_dataset)} training samples and {len(val_dataset)} validation samples") + else: + self.validation_loader = None + print(f"Loaded {len(train_dataset)} training samples (no validation set)") + + return self.data_loader, self.validation_loader + + def _setup_logging(self, output_path): + """Set up directory structure for logging and model checkpoints""" + timestamp = time.strftime("%Y%m%d-%H%M%S") + self.log_dir = os.path.join(output_path, f"training_run_{timestamp}") + os.makedirs(self.log_dir, exist_ok=True) + + # Create checkpoint directory + self.checkpoint_dir = os.path.join(self.log_dir, "checkpoints") + os.makedirs(self.checkpoint_dir, exist_ok=True) + + # Set up CSV logging + self.csv_path = os.path.join(self.log_dir, 'training_log.csv') + with open(self.csv_path, mode='w', newline='') as file: + writer = csv.writer(file) + if self.validation_loader: + writer.writerow(['Epoch', 'Train Loss', 'Validation Loss', 'Improved']) + else: + writer.writerow(['Epoch', 'Train Loss', 'Improved']) + + def _evaluate(self): + """Evaluate the model on validation data""" + if self.validation_loader is None: + return 0.0 + + self.model.eval() + val_loss = 0.0 + + with torch.no_grad(): + for low_res, high_res in tqdm(self.validation_loader, desc="Validating"): + low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last) + high_res = high_res.to(self.device, non_blocking=True) + + with autocast(device_type=self.device.type): + outputs = self.model(low_res) + loss = self.criterion(outputs, high_res) + + val_loss += loss.item() + + del low_res, high_res, outputs, loss + + self.model.train() + return val_loss + + def _save_checkpoint(self, epoch, is_best=False): + """Save model checkpoint""" + checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}.pt") + best_model_path = os.path.join(self.log_dir, "best_model") + + # Save the model checkpoint + self.model.save(checkpoint_path) + + # If this is the best model so far, copy it to best_model + if is_best: + if os.path.exists(best_model_path): + shutil.rmtree(best_model_path) + self.model.save(best_model_path) + print(f"Saved new best model with loss: {self.best_loss:.6f}") + + def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001): + """ + Finetune the upscaler model + + Args: + output_path (str): Directory to save models and logs + epochs (int): Maximum number of training epochs + lr (float): Learning rate + patience (int): Early stopping patience + min_delta (float): Minimum improvement for early stopping + """ + # Check if data is loaded + if self.data_loader is None: + raise ValueError("Data not loaded. 
Call load_data first.") + + # Setup optimizer + self.optimizer = optim.Adam(self.model.parameters(), lr=lr) + + # Set up logging + self._setup_logging(output_path) + + # Setup early stopping + early_stopping = EarlyStopping(patience=patience, min_delta=min_delta) + + # Training loop + self.model.train() + + for epoch in range(epochs): + # Training phase + epoch_loss = 0.0 + progress_bar = tqdm(self.data_loader, desc=f"Epoch {epoch + 1}/{epochs}") + + for low_res, high_res in progress_bar: + # Move data to GPU with channels_last format where possible + low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last) + high_res = high_res.to(self.device, non_blocking=True) + + self.optimizer.zero_grad() + + with autocast(device_type=self.device.type): + if self.use_checkpointing: + # Ensure the input tensor requires gradient so that checkpointing records the computation graph + low_res.requires_grad_() + outputs = checkpoint(self.model, low_res) + else: + outputs = self.model(low_res) + loss = self.criterion(outputs, high_res) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + epoch_loss += loss.item() + progress_bar.set_postfix({'loss': loss.item()}) + + # Optionally delete variables to free memory + del low_res, high_res, outputs, loss + + # Calculate average epoch loss + avg_train_loss = epoch_loss / len(self.data_loader) + + # Validation phase (if validation loader exists) + if self.validation_loader: + val_loss = self._evaluate() / len(self.validation_loader) + is_improved = val_loss < self.best_loss + if is_improved: + self.best_loss = val_loss + + # Log results + print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}") + with open(self.csv_path, mode='a', newline='') as file: + writer = csv.writer(file) + writer.writerow([epoch + 1, avg_train_loss, val_loss, "Yes" if is_improved else "No"]) + else: + # If no validation, use training loss for improvement tracking + is_improved = avg_train_loss < self.best_loss + if is_improved: + self.best_loss = avg_train_loss + + # Log results + print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}") + with open(self.csv_path, mode='a', newline='') as file: + writer = csv.writer(file) + writer.writerow([epoch + 1, avg_train_loss, "Yes" if is_improved else "No"]) + + # Save checkpoint + self._save_checkpoint(epoch + 1, is_best=is_improved) + + # Perform garbage collection and clear GPU cache after each epoch + gc.collect() + torch.cuda.empty_cache() + + # Check early stopping + if early_stopping(val_loss if self.validation_loader else avg_train_loss): + print(f"Early stopping triggered at epoch {epoch + 1}") + break + + return self.best_loss + + def save(self, output_path=None): + """ + Save the best model to the specified path + + Args: + output_path (str, optional): Path to save the model. If None, uses the best model from training. 
+ """ + if output_path is None and self.log_dir is not None: + best_model_path = os.path.join(self.log_dir, "best_model") + if os.path.exists(best_model_path): + print(f"Best model already saved at {best_model_path}") + return best_model_path + else: + output_path = os.path.join(self.log_dir, "final_model") + + if output_path is None: + raise ValueError("No output path specified and no training has been done yet.") + + self.model.save(output_path) + print(f"Model saved to {output_path}") + return output_path \ No newline at end of file diff --git a/src/aiunn/inference/__init__.py b/src/aiunn/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/aiunn/inference.py b/src/aiunn/inference/inference.py similarity index 99% rename from src/aiunn/inference.py rename to src/aiunn/inference/inference.py index 991b708..6ed5b0d 100644 --- a/src/aiunn/inference.py +++ b/src/aiunn/inference/inference.py @@ -31,6 +31,8 @@ class Upscaler(nn.Module): def forward(self, x): features = self.base_model(x) return self.last_transform(features) + + class ImageUpscaler: def __init__(self, model_path: str, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'): self.device = torch.device(device) diff --git a/src/aiunn/upsampler/__init__.py b/src/aiunn/upsampler/__init__.py new file mode 100644 index 0000000..e179503 --- /dev/null +++ b/src/aiunn/upsampler/__init__.py @@ -0,0 +1,5 @@ +from .aiunn import aiuNN +from .config import aiuNNConfig + + +__all__ = ["aiuNN", "aiuNNConfig"] \ No newline at end of file diff --git a/src/aiunn/upsampler.py b/src/aiunn/upsampler/aiunn.py similarity index 92% rename from src/aiunn/upsampler.py rename to src/aiunn/upsampler/aiunn.py index 657e6b3..ca7ba2b 100644 --- a/src/aiunn/upsampler.py +++ b/src/aiunn/upsampler/aiunn.py @@ -3,17 +3,17 @@ import torch import torch.nn as nn import warnings from aiia import AIIA, AIIAConfig, AIIABase -from config import UpsamplerConfig +from .config import aiuNNConfig import warnings -class Upsampler(AIIA): +class aiuNN(AIIA): def __init__(self, base_model: AIIABase): super().__init__(base_model.config) self.base_model = base_model # Pass the unified base configuration using the new parameter. - self.config = UpsamplerConfig(base_config=base_model.config) + self.config = aiuNNConfig(base_config=base_model.config) self.upsample = nn.Upsample( scale_factor=self.config.upsample_scale, @@ -72,11 +72,11 @@ if __name__ == "__main__": config = AIIAConfig() base_model = AIIABase(config) # Instantiate Upsampler from the base model (works correctly). - upsampler = Upsampler(base_model) + upsampler = aiuNN(base_model) # Save the model (both configuration and weights). upsampler.save("hehe") # Now load using the overridden load method; this will load the complete model. - upsampler_loaded = Upsampler.load("hehe", precision="bf16") + upsampler_loaded = aiuNN.load("hehe", precision="bf16") print("Updated configuration:", upsampler_loaded.config.__dict__) diff --git a/src/aiunn/config.py b/src/aiunn/upsampler/config.py similarity index 100% rename from src/aiunn/config.py rename to src/aiunn/upsampler/config.py
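Reviewer note on the new layout: the sketch below shows how the relocated classes are expected to fit together after this patch. It is illustrative only; the model and parquet paths are placeholders, and UpscaleDataset stands in for the dataset class deleted from src/aiunn/finetune.py, which callers are now assumed to maintain in their own code.

# Illustrative usage of the reorganized package (paths and UpscaleDataset are placeholders).
from aiia import AIIABase
from aiunn import aiuNN, aiuNNTrainer
from upscale_dataset import UpscaleDataset  # hypothetical module; the class no longer ships with aiunn

# Wrap a pretrained AIIABase backbone in the renamed upsampler class.
base_model = AIIABase.load("/path/to/AIIA-base-512", precision="bf16")
upscaler = aiuNN(base_model)

# A list passed as dataset_params is forwarded to the dataset class as
# {'parquet_files': [...]}; a dict would be unpacked as keyword arguments instead.
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
trainer.load_data(
    dataset_params=["/path/to/image_upscaler.parquet"],
    batch_size=1,
    validation_split=0.2,
)

# Runs the AMP + gradient-checkpointing loop and writes the CSV log and
# checkpoints under output_path/training_run_<timestamp>/.
trainer.finetune(output_path="/path/to/runs", epochs=10, lr=1e-4, patience=3, min_delta=0.001)
trainer.save()  # with no argument, reuses the best_model (or final_model) from the run directory

For pre-built datasets, load_data also accepts custom_train_dataset and custom_val_dataset in place of dataset_params.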