updated memory fix
Gitea Actions For AIIA / Explore-Gitea-Actions (push) Successful in 9m43s
This commit is contained in:
parent 8010605f44
commit 159ada872b
@@ -241,6 +241,7 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
                                total=len(train_batches),
                                desc=f"Epoch {epoch + 1}/{epochs}")

            for batch_idx, (low_res, high_res) in progress_bar:
                # Move data to device
                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
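For context, this hunk sits in the batch loop: the iterator is wrapped in a progress bar (the total=/desc= arguments suggest tqdm) and each batch is moved to the GPU with a non-blocking copy in channels_last layout. A minimal sketch of that pattern, assuming a standard DataLoader of (low_res, high_res) pairs; the helper name run_epoch is illustrative, not from the repository:

import torch
from tqdm import tqdm

def run_epoch(model, criterion, train_batches, device, epoch, epochs):
    progress_bar = tqdm(enumerate(train_batches),
                        total=len(train_batches),
                        desc=f"Epoch {epoch + 1}/{epochs}")
    for batch_idx, (low_res, high_res) in progress_bar:
        # Non-blocking host-to-device copy (effective with pinned memory),
        # then convert to channels_last so NHWC-friendly conv kernels can be used.
        low_res = low_res.to(device, non_blocking=True).to(memory_format=torch.channels_last)
        high_res = high_res.to(device, non_blocking=True)
        outputs = model(low_res)
        loss = criterion(outputs, high_res)
        progress_bar.set_postfix(loss=loss.item())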
@@ -251,8 +252,10 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
                if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                    low_res.requires_grad_()
                    outputs = checkpoint(self.model, low_res)
                    outputs = outputs.clone()  # <-- Clone added here
                else:
                    outputs = self.model(low_res)
                    outputs = outputs.clone()  # <-- Clone added here
                loss = self.criterion(outputs, high_res)

                # Scale loss for gradient accumulation
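The change in this hunk clones the model output on both the gradient-checkpointing path and the plain forward path. A minimal, self-contained sketch of the checkpoint-plus-clone pattern (toy model and random tensors, not the repository's code): with the default reentrant checkpointing, gradients only flow through the checkpointed segment if at least one input requires grad, which is why the trainer calls requires_grad_() on the input; the added clone() presumably gives downstream code its own tensor rather than an alias of the recomputed output.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))
criterion = nn.MSELoss()

low_res = torch.randn(2, 3, 16, 16)
high_res = torch.randn(2, 3, 16, 16)

# checkpoint() skips saving intermediate activations and recomputes them in
# the backward pass; the input must require grad for gradients to propagate
# through the segment under reentrant checkpointing.
low_res.requires_grad_()
outputs = checkpoint(model, low_res)
outputs = outputs.clone()  # decouple from the checkpoint output, mirroring the commit

loss = criterion(outputs, high_res)
loss.backward()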
@@ -348,6 +351,16 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
                    return self.best_loss

        # Stop memory monitoring
        if self.use_memory_profiling:
            self.stop_monitoring = True
            if self.memory_monitor_thread:
                self.memory_monitor_thread.join(timeout=1)
            print(f"Training completed. Peak GPU memory usage: {self.peak_memory:.2f}GB")

        return self.best_loss

    def get_memory_summary(self):
        """Get a summary of memory usage during training"""
        if not self.memory_stats:
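The shutdown sequence added here assumes a background monitor thread started elsewhere in the trainer (self.use_memory_profiling, self.stop_monitoring, self.memory_monitor_thread, self.peak_memory). A hypothetical sketch of what such a monitor could look like; the attribute names mirror the diff, but the sampling body is an assumption, not the repository's implementation:

import threading
import time
import torch

class MemoryMonitorMixin:
    def start_memory_monitor(self, interval=1.0):
        self.stop_monitoring = False
        self.peak_memory = 0.0
        self.memory_stats = []

        def _monitor():
            # Sample allocated GPU memory until the stop flag is set.
            while not self.stop_monitoring:
                if torch.cuda.is_available():
                    used_gb = torch.cuda.memory_allocated() / 1024 ** 3
                    self.peak_memory = max(self.peak_memory, used_gb)
                    self.memory_stats.append(used_gb)
                time.sleep(interval)

        self.memory_monitor_thread = threading.Thread(target=_monitor, daemon=True)
        self.memory_monitor_thread.start()

    def stop_memory_monitor(self):
        # Mirrors the shutdown sequence added in the hunk above.
        self.stop_monitoring = True
        if self.memory_monitor_thread:
            self.memory_monitor_thread.join(timeout=1)
        print(f"Training completed. Peak GPU memory usage: {self.peak_memory:.2f}GB")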