diff --git a/src/aiunn/finetune/memory_trainer.py b/src/aiunn/finetune/memory_trainer.py
index db6be54..9a7bd8f 100644
--- a/src/aiunn/finetune/memory_trainer.py
+++ b/src/aiunn/finetune/memory_trainer.py
@@ -237,9 +237,10 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
             start_idx = start_batch if epoch == start_epoch else 0
 
             progress_bar = tqdm(train_batches[start_idx:],
-                              initial=start_idx,
-                              total=len(train_batches),
-                              desc=f"Epoch {epoch + 1}/{epochs}")
+                                initial=start_idx,
+                                total=len(train_batches),
+                                desc=f"Epoch {epoch + 1}/{epochs}")
+
 
             for batch_idx, (low_res, high_res) in progress_bar:
                 # Move data to device
@@ -251,8 +252,10 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
                 if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                     low_res.requires_grad_()
                     outputs = checkpoint(self.model, low_res)
+                    outputs = outputs.clone()  # <-- Clone added here
                 else:
                     outputs = self.model(low_res)
+                    outputs = outputs.clone()  # <-- Clone added here
                 loss = self.criterion(outputs, high_res)
 
                 # Scale loss for gradient accumulation
@@ -266,8 +269,8 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
 
                 # Update weights every accumulation_steps or at the end of epoch
                 should_step = (not self.use_gradient_accumulation or
-                             (batch_idx + 1) % self.accumulation_steps == 0 or
-                             batch_idx == len(train_batches) - 1)
+                               (batch_idx + 1) % self.accumulation_steps == 0 or
+                               batch_idx == len(train_batches) - 1)
 
                 if should_step:
                     self.scaler.step(self.optimizer)
@@ -348,6 +351,16 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
 
         return self.best_loss
 
+
+        # Stop memory monitoring
+        if self.use_memory_profiling:
+            self.stop_monitoring = True
+            if self.memory_monitor_thread:
+                self.memory_monitor_thread.join(timeout=1)
+            print(f"Training completed. Peak GPU memory usage: {self.peak_memory:.2f}GB")
+
+        return self.best_loss
+
     def get_memory_summary(self):
         """Get a summary of memory usage during training"""
         if not self.memory_stats: