added a new trainer called MemoryOptimizedTrainer
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Successful in 29s
This commit is contained in:
parent aa06b4cf57
commit 1cf8b9e09b

README.md (91 changes)

@@ -1,10 +1,6 @@
# aiuNN
-Adaptive Image Upscaler using Neural Networks
-
-## Overview
-
-`aiuNN` is an adaptive image upscaling model built on top of the Adaptive Image Intelligence Architecture (AIIA). This project provides fine-tuned versions of AIIA models specifically designed for high-quality image upscaling. By leveraging neural networks, `aiuNN` can significantly enhance the resolution and detail of images.
+Adaptive Image Upscaler using Neural Networks. `aiuNN` is an adaptive image upscaling model built on top of the Adaptive Image Intelligence Architecture (AIIA). This project provides fine-tuned versions of AIIA models specifically designed for high-quality image upscaling. By leveraging neural networks, `aiuNN` can significantly enhance the resolution and detail of images.
## Features

@@ -23,6 +19,8 @@ pip install git+https://gitea.fabelous.app/Machine-Learning/aiuNN.git

## Usage

### Basic Example

Here's a basic example of how to use `aiuNN` for image upscaling:

```python src/main.py

@@ -61,12 +59,10 @@ trainer.load_data(dataset_params=dataset_params, batch_size=1)
trainer.finetune(output_path="trained_model")
```

## Dataset

### Dataset Class

The `UpscaleDataset` class is designed to handle Parquet files containing image data. It loads a subset of images from each file and validates the data types to ensure consistency.

This is an example dataset that you can use with the `aiuNN` model:

```python src/example.py
class UpscaleDataset(Dataset):
    def __init__(self, parquet_files: list, transform=None, samples_per_file=10_000):

@@ -134,4 +130,81 @@ class UpscaleDataset(Dataset):
print(f"\nError at index {idx}: {str(e)}")
|
||||
self.failed_indices.add(idx)
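            # Fall back to the next sample (wrapping around) when this one fails to load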
            return self[(idx + 1) % len(self)]
```

### Trainers

`aiuNN` provides two trainers: the standard `aiuNNTrainer` and the optimized `MemoryOptimizedTrainer`. Choose the one that best fits your hardware capabilities and memory constraints.

#### Standard Trainer (`aiuNNTrainer`)

Use the standard trainer when you have sufficient memory and do not need advanced optimizations; it is straightforward to use for most basic training tasks.
```python src/main.py
from aiunn import aiuNNTrainer

# Create trainer with your dataset class
trainer = aiuNNTrainer(upscaler, dataset_class=UpscaleDataset)

# Load data using parameters for your dataset
dataset_params = {
    'parquet_files': [
        "path/to/dataset1",
        "path/to/dataset2"
    ],
    'transform': transforms.Compose([transforms.ToTensor()]),
    'samples_per_file': 5000  # Number of training samples to load per file
}
trainer.load_data(dataset_params=dataset_params, batch_size=4)

# Fine-tune the model
trainer.finetune(output_path="trained_model")
```

#### Memory-Efficient Trainer (`MemoryOptimizedTrainer`)

Use `MemoryOptimizedTrainer` when memory is limited and you need more efficient training. It adds gradient accumulation, background memory profiling, and aggressive memory cleanup.
```python src/main.py
from aiunn import MemoryOptimizedTrainer

# Replace your existing trainer with the optimized version
trainer = MemoryOptimizedTrainer(
    upscaler_model=your_model,
    dataset_class=your_dataset_class,
    use_gradient_accumulation=True,
    accumulation_steps=4,  # Effective batch size = batch_size * 4
    use_memory_profiling=True,
    use_model_compilation=True
)

# Load data with memory optimizations
trainer.load_data(
    dataset_params=your_params,
    batch_size=2,     # Use a smaller batch size with gradient accumulation
    val_batch_size=1  # Even smaller validation batches
)

# Train with optimizations
trainer.finetune(
    output_path="./output",
    epochs=10,
    lr=1e-4
)

# Get memory usage summary
memory_summary = trainer.get_memory_summary()
print(memory_summary)
```
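
With the settings above the effective batch size is 2 × 4 = 8, and `get_memory_summary()` returns a dictionary with the peak GPU memory in GB, the peak RAM usage in percent, and the number of measurements collected by the background monitor.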

#### When to Use Which

- **Use `aiuNNTrainer` when**:
  - You have sufficient GPU and RAM resources.
  - You prefer a straightforward training process without additional optimizations.
  - Your dataset is not excessively large, and memory constraints are not a concern.

- **Use `MemoryOptimizedTrainer` when**:
  - You have limited GPU or RAM resources.
  - You need to train on datasets that may exceed your hardware's capabilities.
  - You want to monitor memory usage and optimize training for better efficiency (see the sketch below for one way to choose automatically).
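
As a rough illustration, here is a minimal sketch of how you might pick between the two at runtime based on total GPU memory; the 8 GB threshold and the `make_trainer` helper are hypothetical, not part of the library:

```python
import torch
from aiunn import aiuNNTrainer, MemoryOptimizedTrainer

def make_trainer(model, dataset_class, min_gb_for_standard=8.0):
    """Hypothetical helper: use the standard trainer on roomy GPUs,
    otherwise fall back to the memory-optimized one."""
    if torch.cuda.is_available():
        total_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    else:
        total_gb = 0.0  # CPU-only: be conservative
    if total_gb >= min_gb_for_standard:
        return aiuNNTrainer(model, dataset_class=dataset_class)
    return MemoryOptimizedTrainer(
        model,
        dataset_class=dataset_class,
        use_gradient_accumulation=True,
        accumulation_steps=4,
    )
```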

pyproject.toml

@@ -3,7 +3,7 @@ requires = ["setuptools>=45", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "aiunn"
-version = "0.2.3"
+version = "0.2.4"
description = "Finetuner for image upscaling using AIIA"
readme = "README.md"
requires-python = ">=3.10"

setup.py (2 changes)

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
setup(
    name="aiunn",
-    version="0.2.3",
+    version="0.2.4",
    packages=find_packages(where="src"),
    package_dir={"": "src"},
    install_requires=[

src/aiunn/__init__.py

@@ -1,6 +1,7 @@
from .finetune.trainer import aiuNNTrainer
+from .finetune.memory_trainer import MemoryOptimizedTrainer
from .upsampler.aiunn import aiuNN
from .upsampler.config import aiuNNConfig
from .inference.inference import aiuNNInference

-__version__ = "0.2.3"
+__version__ = "0.2.4"

src/aiunn/finetune/__init__.py

@@ -1,3 +1,4 @@
from .trainer import aiuNNTrainer
+from .memory_trainer import MemoryOptimizedTrainer

-__all__ = ["aiuNNTrainer"]
+__all__ = ["aiuNNTrainer", "MemoryOptimizedTrainer"]

src/aiunn/finetune/memory_trainer.py (new file)

@@ -0,0 +1,363 @@
import torch
import torch.nn as nn
import torch.optim as optim
import os
import csv
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
import time
from .trainer import aiuNNTrainer, EarlyStopping
import psutil
import threading


class MemoryOptimizedTrainer(aiuNNTrainer):
    def __init__(self, upscaler_model, dataset_class=None, use_gradient_accumulation=True,
                 accumulation_steps=4, use_memory_profiling=True, use_model_compilation=True):
        """
        Enhanced trainer with advanced memory optimizations.

        Args:
            use_gradient_accumulation: Enable gradient accumulation for larger effective batch sizes
            accumulation_steps: Number of steps to accumulate gradients before updating
            use_memory_profiling: Enable automatic memory monitoring
            use_model_compilation: Use torch.compile for better efficiency
        """
        super().__init__(upscaler_model, dataset_class)

        # Memory optimization settings
        self.use_gradient_accumulation = use_gradient_accumulation
        self.accumulation_steps = accumulation_steps
        self.use_memory_profiling = use_memory_profiling
        self.use_model_compilation = use_model_compilation

        # Memory monitoring
        self.memory_stats = []
        self.peak_memory = 0
        self.memory_monitor_thread = None
        self.stop_monitoring = False

        # Compile model for better efficiency if requested
        if self.use_model_compilation and hasattr(torch, 'compile'):
            try:
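                # 'reduce-overhead' trades longer warmup for lower per-step overhead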
                self.model = torch.compile(self.model, mode='reduce-overhead')
                print("Model compiled successfully for better memory efficiency")
            except Exception as e:
                print(f"Model compilation failed: {e}, continuing without compilation")

        # Use memory-efficient optimizer settings
        self.optimizer_kwargs = {
            'eps': 1e-8,
            'weight_decay': 0.01,
            'amsgrad': False  # Disable amsgrad to save memory
        }

    def _get_optimal_num_workers(self):
        """Calculate optimal number of workers for DataLoader"""
        cpu_count = os.cpu_count()
        # Use fewer workers to reduce memory overhead
        return min(4, max(1, cpu_count // 2))

    def _memory_monitor(self):
        """Background thread to monitor memory usage"""
        while not self.stop_monitoring:
            try:
                # GPU memory if available
                if torch.cuda.is_available():
                    gpu_memory = torch.cuda.memory_allocated() / (1024**3)  # GB
                    gpu_cached = torch.cuda.memory_reserved() / (1024**3)  # GB
                    self.peak_memory = max(self.peak_memory, gpu_memory)
                else:
                    gpu_memory = gpu_cached = 0

                # System RAM
                ram_usage = psutil.virtual_memory().percent
                ram_available = psutil.virtual_memory().available / (1024**3)  # GB

                self.memory_stats.append({
                    'timestamp': time.time(),
                    'gpu_allocated': gpu_memory,
                    'gpu_cached': gpu_cached,
                    'ram_percent': ram_usage,
                    'ram_available': ram_available
                })

                # Keep only the last 100 measurements to prevent memory buildup
                if len(self.memory_stats) > 100:
                    self.memory_stats = self.memory_stats[-50:]

            except Exception:
                pass
            time.sleep(5)  # Check every 5 seconds

    def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2,
                  custom_train_dataset=None, custom_val_dataset=None, val_batch_size=None):
        """Enhanced data loading with memory optimizations"""

        # Use a smaller validation batch size to reduce memory spikes
        if val_batch_size is None:
            val_batch_size = max(1, batch_size // 2)

        # Calculate the optimal number of workers
        optimal_workers = self._get_optimal_num_workers()

        # Load datasets (keeping existing logic)
        if custom_train_dataset is not None:
            train_dataset = custom_train_dataset
            val_dataset = custom_val_dataset if custom_val_dataset is not None else None
        else:
            if self.dataset_class is None:
                raise ValueError("No dataset class provided.")
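
            # dataset_params may be a full kwargs dict or just a list of parquet files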
            dataset = self.dataset_class(**dataset_params if isinstance(dataset_params, dict) else {'parquet_files': dataset_params})
            dataset_size = len(dataset)
            val_size = int(validation_split * dataset_size)
            train_size = dataset_size - val_size
            train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        # Create optimized data loaders
        self.data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=optimal_workers,
            prefetch_factor=2,  # Reduce prefetching to save memory
            persistent_workers=True if optimal_workers > 0 else False,
            drop_last=True  # Drop incomplete batches to maintain consistent memory usage
        )

        if val_dataset is not None:
            self.validation_loader = DataLoader(
                val_dataset,
                batch_size=val_batch_size,  # Use the smaller validation batch size
                shuffle=False,
                pin_memory=True,
                num_workers=min(2, optimal_workers),  # Fewer workers for validation
                prefetch_factor=1,
                persistent_workers=True if optimal_workers > 0 else False
            )
            print(f"Loaded {len(train_dataset)} training samples (batch_size={batch_size}) and {len(val_dataset)} validation samples (batch_size={val_batch_size})")
        else:
            self.validation_loader = None
            print(f"Loaded {len(train_dataset)} training samples (batch_size={batch_size})")

        return self.data_loader, self.validation_loader

    def _aggressive_memory_cleanup(self):
        """More aggressive memory cleanup"""
        # Collect garbage on the Python heap
        gc.collect()

        # Clear the PyTorch CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        # Clear any lingering autograd graphs
        if hasattr(torch.autograd, 'set_grad_enabled'):
            with torch.no_grad():
                pass

    def _evaluate_memory_efficient(self):
        """Memory-efficient validation with smaller chunks"""
        if self.validation_loader is None:
            return 0.0

        self.model.eval()
        val_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for low_res, high_res in self.validation_loader:
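                # channels_last memory format typically speeds up convolutions under AMP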
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                with autocast(device_type=self.device.type):
                    outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                val_loss += loss.item()
                num_batches += 1

                # Immediate cleanup
                del low_res, high_res, outputs, loss

                # Aggressive cleanup every few batches
                if num_batches % 10 == 0:
                    self._aggressive_memory_cleanup()

        self.model.train()
        # Final cleanup after validation
        self._aggressive_memory_cleanup()

        return val_loss / max(num_batches, 1)

    def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
        """Enhanced training with memory optimizations"""
        if self.data_loader is None:
            raise ValueError("Data not loaded. Call load_data first.")

        # Start memory monitoring
        if self.use_memory_profiling:
            self.stop_monitoring = False
            self.memory_monitor_thread = threading.Thread(target=self._memory_monitor, daemon=True)
            self.memory_monitor_thread.start()
            print("Memory monitoring started")

        # Setup logging and optimizer
        self._setup_logging(output_path=output_path)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, **self.optimizer_kwargs)

        # Load checkpoint if available
        checkpoint_info = self.load_checkpoint()
        start_epoch = checkpoint_info[0] if checkpoint_info else 0
        start_batch = checkpoint_info[1] if checkpoint_info else 0

        # Setup early stopping
        early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
        self.best_loss = float('inf')

        print("Training configuration:")
        print(f"- Gradient accumulation: {self.use_gradient_accumulation} (steps: {self.accumulation_steps})")
        print(f"- Memory profiling: {self.use_memory_profiling}")
        print(f"- Model compilation: {self.use_model_compilation}")
        print(f"- Effective batch size: {self.data_loader.batch_size * (self.accumulation_steps if self.use_gradient_accumulation else 1)}")

        # Training loop
        self.model.train()
        for epoch in range(start_epoch, epochs):
            self.current_epoch = epoch
            epoch_loss = 0.0
            accumulation_loss = 0.0
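
            # NOTE: list(enumerate(...)) materializes every batch up front so the
            # loop can resume mid-epoch from start_idx; on very large datasets this
            # can itself be memory-heavy.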
            train_batches = list(enumerate(self.data_loader))
            start_idx = start_batch if epoch == start_epoch else 0

            progress_bar = tqdm(train_batches[start_idx:],
                                initial=start_idx,
                                total=len(train_batches),
                                desc=f"Epoch {epoch + 1}/{epochs}")

            for batch_idx, (low_res, high_res) in progress_bar:
                # Move data to device
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                # Forward pass
                with autocast(device_type=self.device.type):
                    if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                        low_res.requires_grad_()
                        outputs = checkpoint(self.model, low_res)
                    else:
                        outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                # Scale loss for gradient accumulation
                if self.use_gradient_accumulation:
                    loss = loss / self.accumulation_steps

                # Backward pass
                self.scaler.scale(loss).backward()

                accumulation_loss += loss.item()

                # Update weights every accumulation_steps or at the end of the epoch
                should_step = (not self.use_gradient_accumulation or
                               (batch_idx + 1) % self.accumulation_steps == 0 or
                               batch_idx == len(train_batches) - 1)

                if should_step:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()

                    # Add accumulated loss to epoch loss
                    if self.use_gradient_accumulation:
                        epoch_loss += accumulation_loss
                        accumulation_loss = 0.0
                    else:
                        epoch_loss += loss.item()

                # Update progress bar
                current_loss = accumulation_loss if self.use_gradient_accumulation else loss.item()
                progress_bar.set_postfix({
                    'loss': current_loss,
                    'peak_mem': f"{self.peak_memory:.1f}GB" if self.use_memory_profiling else "N/A"
                })

                # Immediate cleanup
                del low_res, high_res, outputs, loss

                # Handle checkpoints
                self._handle_checkpoints(epoch + 1, batch_idx + 1, current_loss < self.best_loss)

                # Periodic aggressive cleanup
                if batch_idx % 20 == 0:
                    self._aggressive_memory_cleanup()

            # End of epoch processing
            avg_train_loss = epoch_loss / len(self.data_loader)

            # Memory-efficient validation
            if self.validation_loader:
                val_loss = self._evaluate_memory_efficient()
                is_improved = val_loss < self.best_loss
                if is_improved:
                    self.best_loss = val_loss

                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, val_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
                if self.use_memory_profiling:
                    print(f"Peak GPU Memory: {self.peak_memory:.2f}GB")
            else:
                is_improved = avg_train_loss < self.best_loss
                if is_improved:
                    self.best_loss = avg_train_loss

                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")

            # Save best model
            if is_improved:
                best_model_path = os.path.join(output_path, "best_model")
                self.model.save_pretrained(best_model_path)

            # Check early stopping
            if early_stopping(val_loss if self.validation_loader else avg_train_loss):
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

            # End of epoch cleanup
            self._aggressive_memory_cleanup()

        # Stop memory monitoring
        if self.use_memory_profiling:
            self.stop_monitoring = True
            if self.memory_monitor_thread:
                self.memory_monitor_thread.join(timeout=1)
            print(f"Training completed. Peak GPU memory usage: {self.peak_memory:.2f}GB")

        return self.best_loss

    def get_memory_summary(self):
        """Get a summary of memory usage during training"""
        if not self.memory_stats:
            return "No memory statistics available"

        gpu_peak = max(stat['gpu_allocated'] for stat in self.memory_stats)
        ram_peak = max(stat['ram_percent'] for stat in self.memory_stats)

        return {
            'peak_gpu_memory_gb': gpu_peak,
            'peak_ram_percent': ram_peak,
            'total_measurements': len(self.memory_stats)
        }