import torch
import torch.nn as nn
import torch.optim as optim
import os
import csv
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.checkpoint import checkpoint
import gc
import time
from .trainer import aiuNNTrainer, EarlyStopping
import psutil
import threading


class MemoryOptimizedTrainer(aiuNNTrainer):
    def __init__(self, upscaler_model, dataset_class=None, use_gradient_accumulation=True,
                 accumulation_steps=4, use_memory_profiling=True, use_model_compilation=True):
        """
        Enhanced trainer with advanced memory optimizations

        Args:
            use_gradient_accumulation: Enable gradient accumulation for larger effective batch sizes
            accumulation_steps: Number of steps to accumulate gradients before updating
            use_memory_profiling: Enable automatic memory monitoring
            use_model_compilation: Use torch.compile for better efficiency
        """
        super().__init__(upscaler_model, dataset_class)

        # Memory optimization settings
        self.use_gradient_accumulation = use_gradient_accumulation
        self.accumulation_steps = accumulation_steps
        self.use_memory_profiling = use_memory_profiling
        self.use_model_compilation = use_model_compilation

        # Memory monitoring
        self.memory_stats = []
        self.peak_memory = 0
        self.memory_monitor_thread = None
        self.stop_monitoring = False

        # Compile model for better efficiency if requested
        if self.use_model_compilation and hasattr(torch, 'compile'):
            try:
                self.model = torch.compile(self.model, mode='reduce-overhead')
                print("Model compiled successfully for better memory efficiency")
            except Exception as e:
                print(f"Model compilation failed: {e}, continuing without compilation")

        # Use memory-efficient optimizer settings
        self.optimizer_kwargs = {
            'eps': 1e-8,
            'weight_decay': 0.01,
            'amsgrad': False  # Disable amsgrad to save memory
        }

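    # With gradient accumulation enabled, the optimizer steps once every
    # `accumulation_steps` batches, so the effective batch size is
    # DataLoader batch_size * accumulation_steps (e.g. batch_size=2 with
    # accumulation_steps=4 behaves like a batch of 8), while peak activation
    # memory stays at the per-batch level. This matches the
    # "Effective batch size" value printed in finetune().
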
    def _get_optimal_num_workers(self):
        """Calculate optimal number of workers for DataLoader"""
        cpu_count = os.cpu_count() or 1  # os.cpu_count() can return None on some platforms
        # Use fewer workers to reduce memory overhead
        return min(4, max(1, cpu_count // 2))

    def _memory_monitor(self):
        """Background thread to monitor memory usage"""
        while not self.stop_monitoring:
            try:
                # GPU memory if available
                if torch.cuda.is_available():
                    gpu_memory = torch.cuda.memory_allocated() / (1024**3)  # GB
                    gpu_cached = torch.cuda.memory_reserved() / (1024**3)   # GB
                    self.peak_memory = max(self.peak_memory, gpu_memory)
                else:
                    gpu_memory = gpu_cached = 0

                # System RAM
                ram_usage = psutil.virtual_memory().percent
                ram_available = psutil.virtual_memory().available / (1024**3)  # GB

                self.memory_stats.append({
                    'timestamp': time.time(),
                    'gpu_allocated': gpu_memory,
                    'gpu_cached': gpu_cached,
                    'ram_percent': ram_usage,
                    'ram_available': ram_available
                })

                # Trim to the most recent 50 samples once more than 100 accumulate,
                # to prevent the statistics themselves from building up memory
                if len(self.memory_stats) > 100:
                    self.memory_stats = self.memory_stats[-50:]

            except Exception:
                pass
            time.sleep(5)  # Check every 5 seconds

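    # Each entry appended to self.memory_stats has the shape below (values are
    # illustrative, not measured); GPU figures are in GB, RAM usage is a percentage:
    #   {'timestamp': 1716842400.0, 'gpu_allocated': 3.2, 'gpu_cached': 4.0,
    #    'ram_percent': 41.5, 'ram_available': 18.7}
    # get_memory_summary() aggregates these samples after training.
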
    def load_data(self, dataset_params=None, batch_size=1, validation_split=0.2,
                  custom_train_dataset=None, custom_val_dataset=None, val_batch_size=None):
        """Enhanced data loading with memory optimizations"""

        # Use smaller validation batch size to reduce memory spikes
        if val_batch_size is None:
            val_batch_size = max(1, batch_size // 2)

        # Calculate optimal number of workers
        optimal_workers = self._get_optimal_num_workers()

        # Load datasets (keeping existing logic)
        if custom_train_dataset is not None:
            train_dataset = custom_train_dataset
            val_dataset = custom_val_dataset if custom_val_dataset is not None else None
        else:
            if self.dataset_class is None:
                raise ValueError("No dataset class provided.")

            dataset = self.dataset_class(**(dataset_params if isinstance(dataset_params, dict)
                                            else {'parquet_files': dataset_params}))
            dataset_size = len(dataset)
            val_size = int(validation_split * dataset_size)
            train_size = dataset_size - val_size
            train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        # Create optimized data loaders
        self.data_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=optimal_workers,
            prefetch_factor=2,  # Reduce prefetching to save memory
            persistent_workers=optimal_workers > 0,
            drop_last=True  # Drop incomplete batches to maintain consistent memory usage
        )

        if val_dataset is not None:
            self.validation_loader = DataLoader(
                val_dataset,
                batch_size=val_batch_size,  # Use smaller validation batch size
                shuffle=False,
                pin_memory=True,
                num_workers=min(2, optimal_workers),  # Fewer workers for validation
                prefetch_factor=1,
                persistent_workers=optimal_workers > 0
            )
            print(f"Loaded {len(train_dataset)} training samples (batch_size={batch_size}) "
                  f"and {len(val_dataset)} validation samples (batch_size={val_batch_size})")
        else:
            self.validation_loader = None
            print(f"Loaded {len(train_dataset)} training samples (batch_size={batch_size})")

        return self.data_loader, self.validation_loader

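    # Illustrative call patterns (the dataset class and file paths below are
    # placeholders, not part of this module):
    #   trainer.load_data(dataset_params={'parquet_files': ['shard0.parquet']}, batch_size=2)
    #   trainer.load_data(dataset_params=['shard0.parquet', 'shard1.parquet'], batch_size=2)
    #   # a non-dict dataset_params value is wrapped as {'parquet_files': <value>}
    #   trainer.load_data(custom_train_dataset=my_train_ds, custom_val_dataset=my_val_ds)
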
    def _aggressive_memory_cleanup(self):
        """More aggressive memory cleanup"""
        # Collect unreferenced Python objects (this also frees any autograd
        # graphs that are no longer referenced)
        gc.collect()

        # Return cached GPU memory to the allocator and wait for pending kernels
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def _evaluate_memory_efficient(self):
        """Memory-efficient validation with smaller chunks"""
        if self.validation_loader is None:
            return 0.0

        self.model.eval()
        val_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for low_res, high_res in self.validation_loader:
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                with autocast(device_type=self.device.type):
                    outputs = self.model(low_res)
                    loss = self.criterion(outputs, high_res)

                val_loss += loss.item()
                num_batches += 1

                # Immediate cleanup
                del low_res, high_res, outputs, loss

                # Aggressive cleanup every few batches
                if num_batches % 10 == 0:
                    self._aggressive_memory_cleanup()

        self.model.train()
        # Final cleanup after validation
        self._aggressive_memory_cleanup()

        return val_loss / max(num_batches, 1)

    def finetune(self, output_path, epochs=10, lr=1e-4, patience=3, min_delta=0.001):
        """Enhanced training with memory optimizations"""
        if self.data_loader is None:
            raise ValueError("Data not loaded. Call load_data first.")

        # Start memory monitoring
        if self.use_memory_profiling:
            self.stop_monitoring = False
            self.memory_monitor_thread = threading.Thread(target=self._memory_monitor, daemon=True)
            self.memory_monitor_thread.start()
            print("Memory monitoring started")

        # Setup logging and optimizer
        self._setup_logging(output_path=output_path)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, **self.optimizer_kwargs)

        # Load checkpoint if available
        checkpoint_info = self.load_checkpoint()
        start_epoch = checkpoint_info[0] if checkpoint_info else 0
        start_batch = checkpoint_info[1] if checkpoint_info else 0

        # Setup early stopping
        early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
        self.best_loss = float('inf')

        print("Training configuration:")
        print(f"- Gradient accumulation: {self.use_gradient_accumulation} (steps: {self.accumulation_steps})")
        print(f"- Memory profiling: {self.use_memory_profiling}")
        print(f"- Model compilation: {self.use_model_compilation}")
        print(f"- Effective batch size: {self.data_loader.batch_size * (self.accumulation_steps if self.use_gradient_accumulation else 1)}")

        # Training loop
        self.model.train()
        for epoch in range(start_epoch, epochs):
            self.current_epoch = epoch
            epoch_loss = 0.0
            accumulation_loss = 0.0

            # Materialize the batch list so training can resume mid-epoch from start_batch
            train_batches = list(enumerate(self.data_loader))
            start_idx = start_batch if epoch == start_epoch else 0

            progress_bar = tqdm(train_batches[start_idx:],
                                initial=start_idx,
                                total=len(train_batches),
                                desc=f"Epoch {epoch + 1}/{epochs}")

            for batch_idx, (low_res, high_res) in progress_bar:
                # Move data to device
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
                high_res = high_res.to(self.device, non_blocking=True)

                # Forward pass
                with autocast(device_type=self.device.type):
                    if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                        low_res.requires_grad_()
                        outputs = checkpoint(self.model, low_res)
                        outputs = outputs.clone()  # clone added here
                    else:
                        outputs = self.model(low_res)
                        outputs = outputs.clone()  # clone added here
                    loss = self.criterion(outputs, high_res)

                # Scale loss for gradient accumulation
                if self.use_gradient_accumulation:
                    loss = loss / self.accumulation_steps

                # Backward pass
                self.scaler.scale(loss).backward()

                accumulation_loss += loss.item()

                # Loss reported for this batch: the accumulated loss of the current
                # window when using gradient accumulation, otherwise the batch loss.
                # Captured here, before the accumulator is reset below.
                current_loss = accumulation_loss if self.use_gradient_accumulation else loss.item()

                # Update weights every accumulation_steps or at the end of epoch
                should_step = (not self.use_gradient_accumulation or
                               (batch_idx + 1) % self.accumulation_steps == 0 or
                               batch_idx == len(train_batches) - 1)

                if should_step:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()

                    # Add accumulated loss to epoch loss
                    if self.use_gradient_accumulation:
                        epoch_loss += accumulation_loss
                        accumulation_loss = 0.0
                    else:
                        epoch_loss += loss.item()

                # Update progress bar
                progress_bar.set_postfix({
                    'loss': current_loss,
                    'peak_mem': f"{self.peak_memory:.1f}GB" if self.use_memory_profiling else "N/A"
                })

                # Immediate cleanup
                del low_res, high_res, outputs, loss

                # Handle checkpoints
                self._handle_checkpoints(epoch + 1, batch_idx + 1, current_loss < self.best_loss)

                # Periodic aggressive cleanup
                if batch_idx % 20 == 0:
                    self._aggressive_memory_cleanup()

            # End of epoch processing
            avg_train_loss = epoch_loss / len(self.data_loader)

            # Memory-efficient validation
            if self.validation_loader:
                val_loss = self._evaluate_memory_efficient()
                is_improved = val_loss < self.best_loss
                if is_improved:
                    self.best_loss = val_loss

                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss, val_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {val_loss:.6f}")
                if self.use_memory_profiling:
                    print(f"Peak GPU Memory: {self.peak_memory:.2f}GB")
            else:
                is_improved = avg_train_loss < self.best_loss
                if is_improved:
                    self.best_loss = avg_train_loss

                with open(self.csv_path, mode='a', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow([epoch + 1, avg_train_loss])

                print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.6f}")

            # Save best model
            if is_improved:
                best_model_path = os.path.join(output_path, "best_model")
                self.model.save_pretrained(best_model_path)

            # Check early stopping
            if early_stopping(val_loss if self.validation_loader else avg_train_loss):
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

            # End of epoch cleanup
            self._aggressive_memory_cleanup()

        # Stop memory monitoring
        if self.use_memory_profiling:
            self.stop_monitoring = True
            if self.memory_monitor_thread:
                self.memory_monitor_thread.join(timeout=1)
            print(f"Training completed. Peak GPU memory usage: {self.peak_memory:.2f}GB")

        return self.best_loss

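    # Worked example of the should_step schedule above (illustrative numbers):
    # with use_gradient_accumulation=True, accumulation_steps=4 and 10 batches per
    # epoch, the optimizer steps after batch indices 3, 7 and 9; the final batch of
    # an epoch always triggers a step so no accumulated gradients are left unapplied.
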
    def get_memory_summary(self):
        """Get a summary of memory usage during training"""
        if not self.memory_stats:
            return "No memory statistics available"

        gpu_peak = max(stat['gpu_allocated'] for stat in self.memory_stats)
        ram_peak = max(stat['ram_percent'] for stat in self.memory_stats)

        return {
            'peak_gpu_memory_gb': gpu_peak,
            'peak_ram_percent': ram_peak,
            'total_measurements': len(self.memory_stats)
        }
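

# Usage sketch (not part of the original module; the model and dataset class names
# below are placeholders, and `finetune` returns the best loss seen during training):
#
#   trainer = MemoryOptimizedTrainer(my_upscaler_model, dataset_class=MyParquetDataset,
#                                    use_gradient_accumulation=True, accumulation_steps=4)
#   trainer.load_data(dataset_params={'parquet_files': ['shard0.parquet']}, batch_size=2)
#   best_loss = trainer.finetune(output_path='checkpoints/', epochs=10, lr=1e-4)
#   print(trainer.get_memory_summary())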