aiuNN/src/aiunn/finetune/memory_trainer.py

import torch
import torch.nn as nn
import torch.optim as optim
import os
import csv
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
import time
from .trainer import aiuNNTrainer, EarlyStopping
import psutil
import threading


class MemoryOptimizedTrainer(aiuNNTrainer):
    def __init__(self, upscaler_model, dataset_class=None, use_gradient_accumulation=True,
                 accumulation_steps=4, use_memory_profiling=True, use_model_compilation=True):
        """
        Enhanced trainer with advanced memory optimizations

        Args:
            use_gradient_accumulation: Enable gradient accumulation for larger effective batch sizes
            accumulation_steps: Number of steps to accumulate gradients before updating
            use_memory_profiling: Enable automatic memory monitoring
            use_model_compilation: Use torch.compile for better efficiency
        """
        super().__init__(upscaler_model, dataset_class)

        # Memory optimization settings
        self.use_gradient_accumulation = use_gradient_accumulation
        self.accumulation_steps = accumulation_steps
        self.use_memory_profiling = use_memory_profiling
        self.use_model_compilation = use_model_compilation

        # Memory monitoring
        self.memory_stats = []
        self.peak_memory = 0
        self.memory_monitor_thread = None
        self.stop_monitoring = False

        # Compile model for better efficiency if requested
        if self.use_model_compilation and hasattr(torch, 'compile'):
            try:
                self.model = torch.compile(self.model, mode='reduce-overhead')
                print("Model compiled successfully for better memory efficiency")
            except Exception as e:
                print(f"Model compilation failed: {e}, continuing without compilation")

        # Use memory-efficient optimizer settings
        self.optimizer_kwargs = {
            'eps': 1e-8,
            'weight_decay': 0.01,
            'amsgrad': False  # Disable amsgrad to save memory
        }
    def _get_optimal_num_workers(self):
        """Calculate optimal number of workers for DataLoader"""
        cpu_count = os.cpu_count() or 1  # os.cpu_count() can return None
        # Use fewer workers to reduce memory overhead
        return min(4, max(1, cpu_count // 2))
    def _memory_monitor(self):
        """Background thread to monitor memory usage"""
        while not self.stop_monitoring:
            try:
                # GPU memory if available
                if torch.cuda.is_available():
                    gpu_memory = torch.cuda.memory_allocated() / (1024**3)  # GB
                    gpu_cached = torch.cuda.memory_reserved() / (1024**3)  # GB
                    self.peak_memory = max(self.peak_memory, gpu_memory)
                else:
                    gpu_memory = gpu_cached = 0

                # System RAM
                ram_usage = psutil.virtual_memory().percent
                ram_available = psutil.virtual_memory().available / (1024**3)  # GB

                self.memory_stats.append({
                    'timestamp': time.time(),
                    'gpu_allocated': gpu_memory,
                    'gpu_cached': gpu_cached,
                    'ram_percent': ram_usage,
                    'ram_available': ram_available
                })

                # Keep only last 100 measurements to prevent memory buildup
                if len(self.memory_stats) > 100:
                    self.memory_stats = self.memory_stats[-50:]
            except Exception:
                pass
            time.sleep(5)  # Check every 5 seconds
    def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2,
                  custom_train_dataset=None, custom_val_dataset=None, val_batch_size=None):
        """Enhanced data loading with memory optimizations"""
        # Use smaller validation batch size to reduce memory spikes
        if val_batch_size is None:
            val_batch_size = max(1, batch_size // 2)

        # Calculate optimal number of workers
        optimal_workers = self._get_optimal_num_workers()

        # Load datasets (keeping existing logic)
        if custom_train_dataset is not None:
            train_dataset = custom_train_dataset
            val_dataset = custom_val_dataset if custom_val_dataset is not None else None
        else:
            if self.dataset_class is None:
                raise ValueError("No dataset class provided.")
            if isinstance(dataset_params, dict):
                dataset = self.dataset_class(**dataset_params)
            else:
                dataset = self.dataset_class(parquet_files=dataset_params)
            dataset_size = len(dataset)
            val_size = int(validation_split * dataset_size)
            train_size = dataset_size - val_size
            train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        # Create optimized data loaders
        self.data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=optimal_workers,
            prefetch_factor=2,  # Reduce prefetching to save memory
            persistent_workers=optimal_workers > 0,
            drop_last=True  # Drop incomplete batches to maintain consistent memory usage
        )

        if val_dataset is not None:
            self.validation_loader = DataLoader(
                val_dataset,
                batch_size=val_batch_size,  # Use smaller validation batch size
                shuffle=False,
                pin_memory=True,
                num_workers=min(2, optimal_workers),  # Fewer workers for validation
                prefetch_factor=1,
                persistent_workers=optimal_workers > 0
            )
            print(f"Loaded {len(train_dataset)} training samples (batch_size={batch_size}) "
                  f"and {len(val_dataset)} validation samples (batch_size={val_batch_size})")
        else:
            self.validation_loader = None
            print(f"Loaded {len(train_dataset)} training samples (batch_size={batch_size})")

        return self.data_loader, self.validation_loader
    def _aggressive_memory_cleanup(self):
        """More aggressive memory cleanup"""
        # Clear Python garbage
        gc.collect()

        # Clear PyTorch CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
    def _evaluate_memory_efficient(self):
        """Memory-efficient validation with smaller chunks"""
        if self.validation_loader is None:
            return 0.0

        self.model.eval()
        val_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for low_res, high_res in self.validation_loader:
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                with autocast(device_type=self.device.type):
                    outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                val_loss += loss.item()
                num_batches += 1

                # Immediate cleanup
                del low_res, high_res, outputs, loss

                # Aggressive cleanup every few batches
                if num_batches % 10 == 0:
                    self._aggressive_memory_cleanup()

        self.model.train()
        # Final cleanup after validation
        self._aggressive_memory_cleanup()
        return val_loss / max(num_batches, 1)
    def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
        """Enhanced training with memory optimizations"""
        if self.data_loader is None:
            raise ValueError("Data not loaded. Call load_data first.")

        # Start memory monitoring
        if self.use_memory_profiling:
            self.stop_monitoring = False
            self.memory_monitor_thread = threading.Thread(target=self._memory_monitor, daemon=True)
            self.memory_monitor_thread.start()
            print("Memory monitoring started")

        # Setup logging and optimizer (criterion and scaler are expected to be provided by aiuNNTrainer)
        self._setup_logging(output_path=output_path)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, **self.optimizer_kwargs)

        # Load checkpoint if available
        checkpoint_info = self.load_checkpoint()
        start_epoch = checkpoint_info[0] if checkpoint_info else 0
        start_batch = checkpoint_info[1] if checkpoint_info else 0

        # Setup early stopping
        early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
        self.best_loss = float('inf')

        print("Training configuration:")
        print(f"- Gradient accumulation: {self.use_gradient_accumulation} (steps: {self.accumulation_steps})")
        print(f"- Memory profiling: {self.use_memory_profiling}")
        print(f"- Model compilation: {self.use_model_compilation}")
        print(f"- Effective batch size: {self.data_loader.batch_size * (self.accumulation_steps if self.use_gradient_accumulation else 1)}")

        # Training loop
        self.model.train()
        for epoch in range(start_epoch, epochs):
            self.current_epoch = epoch
            epoch_loss = 0.0
            accumulation_loss = 0.0

            train_batches = list(enumerate(self.data_loader))
            start_idx = start_batch if epoch == start_epoch else 0

            progress_bar = tqdm(train_batches[start_idx:],
                                initial=start_idx,
                                total=len(train_batches),
                                desc=f"Epoch {epoch + 1}/{epochs}")

            for batch_idx, (low_res, high_res) in progress_bar:
                # Move data to device
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                # Forward pass
                with autocast(device_type=self.device.type):
                    if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                        low_res.requires_grad_()
                        outputs = checkpoint(self.model, low_res)
                        outputs = outputs.clone()  # work on a cloned copy of the model output
                    else:
                        outputs = self.model(low_res)
                        outputs = outputs.clone()  # work on a cloned copy of the model output
                    loss = self.criterion(outputs, high_res)

                # Scale loss for gradient accumulation
                if self.use_gradient_accumulation:
                    loss = loss / self.accumulation_steps

                # Backward pass
                self.scaler.scale(loss).backward()
                accumulation_loss += loss.item()
                # Capture the display loss before the accumulator is reset below
                current_loss = accumulation_loss if self.use_gradient_accumulation else loss.item()

                # Update weights every accumulation_steps or at the end of the epoch
                should_step = (not self.use_gradient_accumulation or
                               (batch_idx + 1) % self.accumulation_steps == 0 or
                               batch_idx == len(train_batches) - 1)

                if should_step:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()

                    # Add accumulated loss to epoch loss
                    if self.use_gradient_accumulation:
                        epoch_loss += accumulation_loss
                        accumulation_loss = 0.0
                    else:
                        epoch_loss += loss.item()

                # Update progress bar
                progress_bar.set_postfix({
                    'loss': current_loss,
                    'peak_mem': f"{self.peak_memory:.1f}GB" if self.use_memory_profiling else "N/A"
                })

                # Immediate cleanup
                del low_res, high_res, outputs, loss

                # Handle checkpoints
                self._handle_checkpoints(epoch + 1, batch_idx + 1, current_loss < self.best_loss)

                # Periodic aggressive cleanup
                if batch_idx % 20 == 0:
                    self._aggressive_memory_cleanup()

            # End of epoch processing
            avg_train_loss = epoch_loss / len(self.data_loader)

            # Memory-efficient validation
            if self.validation_loader:
                val_loss = self._evaluate_memory_efficient()

                is_improved = val_loss < self.best_loss
                if is_improved:
                    self.best_loss = val_loss

                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, val_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
                if self.use_memory_profiling:
                    print(f"Peak GPU Memory: {self.peak_memory:.2f}GB")
            else:
                is_improved = avg_train_loss < self.best_loss
                if is_improved:
                    self.best_loss = avg_train_loss

                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")

            # Save best model
            if is_improved:
                best_model_path = os.path.join(output_path, "best_model")
                self.model.save_pretrained(best_model_path)

            # Check early stopping
            if early_stopping(val_loss if self.validation_loader else avg_train_loss):
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

            # End of epoch cleanup
            self._aggressive_memory_cleanup()

        # Stop memory monitoring
        if self.use_memory_profiling:
            self.stop_monitoring = True
            if self.memory_monitor_thread:
                self.memory_monitor_thread.join(timeout=1)
            print(f"Training completed. Peak GPU memory usage: {self.peak_memory:.2f}GB")

        return self.best_loss
    def get_memory_summary(self):
        """Get a summary of memory usage during training"""
        if not self.memory_stats:
            return "No memory statistics available"

        gpu_peak = max(stat['gpu_allocated'] for stat in self.memory_stats)
        ram_peak = max(stat['ram_percent'] for stat in self.memory_stats)

        return {
            'peak_gpu_memory_gb': gpu_peak,
            'peak_ram_percent': ram_peak,
            'total_measurements': len(self.memory_stats)
        }
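

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): the model
# instance, dataset class, and paths below are placeholders and would need to
# be replaced with the project's actual upscaler model and dataset.
# ---------------------------------------------------------------------------
# trainer = MemoryOptimizedTrainer(
#     upscaler_model,                  # a pretrained aiuNN upscaler instance (assumed)
#     dataset_class=MyUpscaleDataset,  # hypothetical dataset accepting parquet_files=...
#     use_gradient_accumulation=True,
#     accumulation_steps=4,
# )
# trainer.load_data(dataset_params=["data/train.parquet"], batch_size=2)
# trainer.finetune(output_path="./checkpoints", epochs=10, lr=1e-4, patience=3)
# print(trainer.get_memory_summary())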