diff --git a/README.md b/README.md
index 802ba14..e3a3c9f 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,6 @@
 # aiuNN
 
-Adaptive Image Upscaler using Neural Networks
-
-## Overview
-
-`aiuNN` is an adaptive image upscaling model built on top of the Adaptive Image Intelligence Architecture (AIIA). This project provides fine-tuned versions of AIIA models specifically designed for high-quality image upscaling. By leveraging neural networks, `aiuNN` can significantly enhance the resolution and detail of images.
+Adaptive Image Upscaler using Neural Networks. `aiuNN` is an adaptive image upscaling model built on top of the Adaptive Image Intelligence Architecture (AIIA). This project provides fine-tuned versions of AIIA models specifically designed for high-quality image upscaling. By leveraging neural networks, `aiuNN` can significantly enhance the resolution and detail of images.
 
 ## Features
 
@@ -23,6 +19,8 @@ pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git
 
 ## Usage
 
+### Basic Example
+
 Here's a basic example of how to use `aiuNN` for image upscaling:
 
 ```python src/main.py
@@ -61,12 +59,10 @@ trainer.load_data(dataset_params=dataset_params, batch_size=1)
 trainer.finetune(output_path="trained_model")
 ```
 
-## Dataset
+### Dataset Class
 
 The `UpscaleDataset` class is designed to handle Parquet files containing image data. It loads a subset of images from each file and validates the data types to ensure consistency.
 
-This is an example dataset that you can use with the AIIUN model:
-
 ```python src/example.py
 class UpscaleDataset(Dataset):
     def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):
@@ -134,4 +130,81 @@ class UpscaleDataset(Dataset):
             print(f"\nError at index {idx}: {str(e)}")
             self.failed_indices.add(idx)
             return self[(idx + 1) % len(self)]
-```
\ No newline at end of file
+```
+
+### Trainers
+
+`aiuNN` provides two types of trainers: the standard `aiuNNTrainer` and the optimized `MemoryOptimizedTrainer`. Choose the one that best fits your hardware capabilities and memory constraints.
+
+#### Standard Trainer (`aiuNNTrainer`)
+
+Use the standard trainer when you have sufficient memory resources and do not need advanced optimizations. It is straightforward to use and covers most basic training tasks.
+
+```python src/main.py
+from aiunn import aiuNNTrainer
+
+# Create trainer with your dataset class
+trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)
+
+# Load data using parameters for your dataset
+dataset_params = {
+    'parquet_files': [
+        "path/to/dataset1",
+        "path/to/dataset2"
+    ],
+    'transform': transforms.Compose([transforms.ToTensor()]),
+    'samples_per_file': 5000  # Number of training samples to load per file
+}
+trainer.load_data(dataset_params=dataset_params, batch_size=4)
+
+# Fine-tune the model
+trainer.finetune(output_path="trained_model")
+```
+
+#### Memory-Efficient Trainer (`MemoryOptimizedTrainer`)
+
+Use the `MemoryOptimizedTrainer` when memory resources are limited and training needs to be more efficient. This trainer includes gradient accumulation, memory profiling, and aggressive memory cleanup.
+ +```python src/main.py +from aiunn import MemoryOptimizedTrainer + +# Replace your existing trainer with the optimized version +trainer = MemoryOptimizedTrainer( + upscaler_model=your_model, + dataset_class=your_dataset_class, + use_gradient_accumulation=True, + accumulation_steps=4, # Effective batch size = batch_size * 4 + use_memory_profiling=True, + use_model_compilation=True +) + +# Load data with memory optimizations +trainer.load_data( + dataset_params=your_params, + batch_size=2, # Use smaller batch size with gradient accumulation + val_batch_size=1 # Even smaller validation batches +) + +# Train with optimizations +trainer.finetune( + output_path="./output", + epochs=10, + lr=1e-4 +) + +# Get memory usage summary +memory_summary = trainer.get_memory_summary() +print(memory_summary) +``` + +#### When to Use Which + +- **Use `aiuNNTrainer` when**: + - You have sufficient GPU and RAM resources. + - You prefer a straightforward training process without additional optimizations. + - Your dataset is not excessively large, and memory constraints are not a concern. + +- **Use `MemoryOptimizedTrainer` when**: + - You have limited GPU or RAM resources. + - You need to train on larger datasets that may exceed your hardware capabilities. + - You want to monitor memory usage and optimize training for better efficiency. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 873f10f..47fb3a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools>=45", "wheel"] build-backend = "setuptools.build_meta" [project] name = "aiunn" -version = "0.2.3" +version = "0.2.4" description = "Finetuner for image upscaling using AIIA" readme = "README.md" requires-python = ">=3.10" diff --git a/setup.py b/setup.py index 6e22059..fc567d9 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name="aiunn", - version="0.2.3", + version="0.2.4", packages=find_packages(where="src"), package_dir={"": "src"}, install_requires=[ diff --git a/src/aiunn/__init__.py b/src/aiunn/__init__.py index 05c37ab..f8b8019 100644 --- a/src/aiunn/__init__.py +++ b/src/aiunn/__init__.py @@ -1,6 +1,7 @@ from .finetune.trainer import aiuNNTrainer +from .finetune.memory_trainer import MemoryOptimizedTrainer from .upsampler.aiunn import aiuNN from .upsampler.config import aiuNNConfig from .inference.inference import aiuNNInference -__version__ = "0.2.3" \ No newline at end of file +__version__ = "0.2.4" \ No newline at end of file diff --git a/src/aiunn/finetune/__init__.py b/src/aiunn/finetune/__init__.py index 33239b1..dd4be32 100644 --- a/src/aiunn/finetune/__init__.py +++ b/src/aiunn/finetune/__init__.py @@ -1,3 +1,4 @@ from .trainer import aiuNNTrainer +from .memory_trainer import MemoryOptimizedTrainer -__all__ = ["aiuNNTrainer" ] \ No newline at end of file +__all__ = ["aiuNNTrainer", "MemoryOptimizedTrainer" ] \ No newline at end of file diff --git a/src/aiunn/finetune/memory_trainer.py b/src/aiunn/finetune/memory_trainer.py new file mode 100644 index 0000000..db6be54 --- /dev/null +++ b/src/aiunn/finetune/memory_trainer.py @@ -0,0 +1,363 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import os +import csv +from torch.amp import autocast, GradScaler +from torch.utils.data import DataLoader +from tqdm import tqdm +from torch.utils.checkpoint import checkpoint +import gc +import time +from .trainer import aiuNNTrainer, EarlyStopping +import psutil +import threading + +class 
MemoryOptimizedTrainer(aiuNNTrainer): + def __init__(self, upscaler_model, dataset_class=None, use_gradient_accumulation=True, + accumulation_steps=4, use_memory_profiling=True, use_model_compilation=True): + """ + Enhanced trainer with advanced memory optimizations + + Args: + use_gradient_accumulation: Enable gradient accumulation for larger effective batch sizes + accumulation_steps: Number of steps to accumulate gradients before updating + use_memory_profiling: Enable automatic memory monitoring + use_model_compilation: Use torch.compile for better efficiency + """ + super().__init__(upscaler_model, dataset_class) + + # Memory optimization settings + self.use_gradient_accumulation = use_gradient_accumulation + self.accumulation_steps = accumulation_steps + self.use_memory_profiling = use_memory_profiling + self.use_model_compilation = use_model_compilation + + # Memory monitoring + self.memory_stats = [] + self.peak_memory = 0 + self.memory_monitor_thread = None + self.stop_monitoring = False + + # Compile model for better efficiency if requested + if self.use_model_compilation and hasattr(torch, 'compile'): + try: + self.model = torch.compile(self.model, mode='reduce-overhead') + print("Model compiled successfully for better memory efficiency") + except Exception as e: + print(f"Model compilation failed: {e}, continuing without compilation") + + # Use memory-efficient optimizer settings + self.optimizer_kwargs = { + 'eps': 1e-8, + 'weight_decay': 0.01, + 'amsgrad': False # Disable amsgrad to save memory + } + + def _get_optimal_num_workers(self): + """Calculate optimal number of workers for DataLoader""" + cpu_count = os.cpu_count() + # Use fewer workers to reduce memory overhead + return min(4, max(1, cpu_count // 2)) + + def _memory_monitor(self): + """Background thread to monitor memory usage""" + while not self.stop_monitoring: + try: + # GPU memory if available + if torch.cuda.is_available(): + gpu_memory = torch.cuda.memory_allocated() / (1024**3) # GB + gpu_cached = torch.cuda.memory_reserved() / (1024**3) # GB + self.peak_memory = max(self.peak_memory, gpu_memory) + else: + gpu_memory = gpu_cached = 0 + + # System RAM + ram_usage = psutil.virtual_memory().percent + ram_available = psutil.virtual_memory().available / (1024**3) # GB + + self.memory_stats.append({ + 'timestamp': time.time(), + 'gpu_allocated': gpu_memory, + 'gpu_cached': gpu_cached, + 'ram_percent': ram_usage, + 'ram_available': ram_available + }) + + # Keep only last 100 measurements to prevent memory buildup + if len(self.memory_stats) > 100: + self.memory_stats = self.memory_stats[-50:] + + except Exception: + pass + time.sleep(5) # Check every 5 seconds + + def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2, + custom_train_dataset=None, custom_val_dataset=None, val_batch_size=None): + """Enhanced data loading with memory optimizations""" + + # Use smaller validation batch size to reduce memory spikes + if val_batch_size is None: + val_batch_size = max(1, batch_size // 2) + + # Calculate optimal number of workers + optimal_workers = self._get_optimal_num_workers() + + # Load datasets (keeping existing logic) + if custom_train_dataset is not None: + train_dataset = custom_train_dataset + val_dataset = custom_val_dataset if custom_val_dataset is not None else None + else: + if self.dataset_class is None: + raise ValueError("No dataset class provided.") + + dataset = self.dataset_class(**dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params}) + 
dataset_size = len(dataset) + val_size = int(validation_split * dataset_size) + train_size = dataset_size - val_size + train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size]) + + # Create optimized data loaders + self.data_loader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + pin_memory=True, + num_workers=optimal_workers, + prefetch_factor=2, # Reduce prefetching to save memory + persistent_workers=True if optimal_workers > 0 else False, + drop_last=True # Drop incomplete batches to maintain consistent memory usage + ) + + if val_dataset is not None: + self.validation_loader = DataLoader( + val_dataset, + batch_size=val_batch_size, # Use smaller validation batch size + shuffle=False, + pin_memory=True, + num_workers=min(2, optimal_workers), # Fewer workers for validation + prefetch_factor=1, + persistent_workers=True if optimal_workers > 0 else False + ) + print(f"Loaded {len(train_dataset)} training samples (batch_size={batch_size}) and {len(val_dataset)} validation samples (batch_size={val_batch_size})") + else: + self.validation_loader = None + print(f"Loaded {len(train_dataset)} training samples (batch_size={batch_size})") + + return self.data_loader, self.validation_loader + + def _aggressive_memory_cleanup(self): + """More aggressive memory cleanup""" + # Clear Python cache + gc.collect() + + # Clear PyTorch cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Clear any lingering autograd graphs + if hasattr(torch.autograd, 'set_grad_enabled'): + with torch.no_grad(): + pass + + def _evaluate_memory_efficient(self): + """Memory-efficient validation with smaller chunks""" + if self.validation_loader is None: + return 0.0 + + self.model.eval() + val_loss = 0.0 + num_batches = 0 + + with torch.no_grad(): + for low_res, high_res in self.validation_loader: + low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last) + high_res = high_res.to(self.device, non_blocking=True) + + with autocast(device_type=self.device.type): + outputs = self.model(low_res) + loss = self.criterion(outputs, high_res) + + val_loss += loss.item() + num_batches += 1 + + # Immediate cleanup + del low_res, high_res, outputs, loss + + # Aggressive cleanup every few batches + if num_batches % 10 == 0: + self._aggressive_memory_cleanup() + + self.model.train() + # Final cleanup after validation + self._aggressive_memory_cleanup() + + return val_loss / max(num_batches, 1) + + def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001): + """Enhanced training with memory optimizations""" + if self.data_loader is None: + raise ValueError("Data not loaded. 
Call load_data first.") + + # Start memory monitoring + if self.use_memory_profiling: + self.stop_monitoring = False + self.memory_monitor_thread = threading.Thread(target=self._memory_monitor, daemon=True) + self.memory_monitor_thread.start() + print("Memory monitoring started") + + # Setup logging and optimizer + self._setup_logging(output_path=output_path) + self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, **self.optimizer_kwargs) + + # Load checkpoint if available + checkpoint_info = self.load_checkpoint() + start_epoch = checkpoint_info[0] if checkpoint_info else 0 + start_batch = checkpoint_info[1] if checkpoint_info else 0 + + # Setup early stopping + early_stopping = EarlyStopping(patience=patience, min_delta=min_delta) + self.best_loss = float('inf') + + print(f"Training configuration:") + print(f"- Gradient accumulation: {self.use_gradient_accumulation} (steps: {self.accumulation_steps})") + print(f"- Memory profiling: {self.use_memory_profiling}") + print(f"- Model compilation: {self.use_model_compilation}") + print(f"- Effective batch size: {self.data_loader.batch_size * (self.accumulation_steps if self.use_gradient_accumulation else 1)}") + + # Training loop + self.model.train() + for epoch in range(start_epoch, epochs): + self.current_epoch = epoch + epoch_loss = 0.0 + accumulation_loss = 0.0 + + train_batches = list(enumerate(self.data_loader)) + start_idx = start_batch if epoch == start_epoch else 0 + + progress_bar = tqdm(train_batches[start_idx:], + initial=start_idx, + total=len(train_batches), + desc=f"Epoch {epoch + 1}/{epochs}") + + for batch_idx, (low_res, high_res) in progress_bar: + # Move data to device + low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last) + high_res = high_res.to(self.device, non_blocking=True) + + # Forward pass + with autocast(device_type=self.device.type): + if hasattr(self, 'use_checkpointing') and self.use_checkpointing: + low_res.requires_grad_() + outputs = checkpoint(self.model, low_res) + else: + outputs = self.model(low_res) + loss = self.criterion(outputs, high_res) + + # Scale loss for gradient accumulation + if self.use_gradient_accumulation: + loss = loss / self.accumulation_steps + + # Backward pass + self.scaler.scale(loss).backward() + + accumulation_loss += loss.item() + + # Update weights every accumulation_steps or at the end of epoch + should_step = (not self.use_gradient_accumulation or + (batch_idx + 1) % self.accumulation_steps == 0 or + batch_idx == len(train_batches) - 1) + + if should_step: + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad() + + # Add accumulated loss to epoch loss + if self.use_gradient_accumulation: + epoch_loss += accumulation_loss + accumulation_loss = 0.0 + else: + epoch_loss += loss.item() + + # Update progress bar + current_loss = accumulation_loss if self.use_gradient_accumulation else loss.item() + progress_bar.set_postfix({ + 'loss': current_loss, + 'peak_mem': f"{self.peak_memory:.1f}GB" if self.use_memory_profiling else "N/A" + }) + + # Immediate cleanup + del low_res, high_res, outputs, loss + + # Handle checkpoints + self._handle_checkpoints(epoch + 1, batch_idx + 1, current_loss < self.best_loss) + + # Periodic aggressive cleanup + if batch_idx % 20 == 0: + self._aggressive_memory_cleanup() + + # End of epoch processing + avg_train_loss = epoch_loss / len(self.data_loader) + + # Memory-efficient validation + if self.validation_loader: + val_loss = self._evaluate_memory_efficient() + is_improved = 
val_loss < self.best_loss + if is_improved: + self.best_loss = val_loss + + with open(self.csv_path, mode='a', newline='') as file: + writer = csv.writer(file) + writer.writerow([epoch + 1, avg_train_loss, val_loss]) + + print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}") + if self.use_memory_profiling: + print(f"Peak GPU Memory: {self.peak_memory:.2f}GB") + else: + is_improved = avg_train_loss < self.best_loss + if is_improved: + self.best_loss = avg_train_loss + + with open(self.csv_path, mode='a', newline='') as file: + writer = csv.writer(file) + writer.writerow([epoch + 1, avg_train_loss]) + + print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}") + + # Save best model + if is_improved: + best_model_path = os.path.join(output_path, "best_model") + self.model.save_pretrained(best_model_path) + + # Check early stopping + if early_stopping(val_loss if self.validation_loader else avg_train_loss): + print(f"Early stopping triggered at epoch {epoch + 1}") + break + + # End of epoch cleanup + self._aggressive_memory_cleanup() + + # Stop memory monitoring + if self.use_memory_profiling: + self.stop_monitoring = True + if self.memory_monitor_thread: + self.memory_monitor_thread.join(timeout=1) + print(f"Training completed. Peak GPU memory usage: {self.peak_memory:.2f}GB") + + return self.best_loss + + def get_memory_summary(self): + """Get a summary of memory usage during training""" + if not self.memory_stats: + return "No memory statistics available" + + gpu_peak = max([stat['gpu_allocated'] for stat in self.memory_stats]) + ram_peak = max([stat['ram_percent'] for stat in self.memory_stats]) + + return { + 'peak_gpu_memory_gb': gpu_peak, + 'peak_ram_percent': ram_peak, + 'total_measurements': len(self.memory_stats) + }