Compare commits

2 Commits

c92fa68e92 ... e7b9da37d6

| Author | SHA1 | Date |
| --- | --- | --- |
| | e7b9da37d6 | |
| | 159ada872b | |
```diff
@@ -237,9 +237,10 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
             start_idx = start_batch if epoch == start_epoch else 0
 
             progress_bar = tqdm(train_batches[start_idx:],
                                 initial=start_idx,
                                 total=len(train_batches),
                                 desc=f"Epoch {epoch + 1}/{epochs}")
+
 
             for batch_idx, (low_res, high_res) in progress_bar:
                 # Move data to device
```
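This hunk is mostly context around the resumable progress bar. As a minimal sketch (toy batch list, hypothetical resume index and epoch counters, not the trainer's real data), slicing the batch list and passing `initial`/`total` to `tqdm` is what keeps the displayed progress aligned with the full epoch when training restarts mid-epoch:

```python
from tqdm import tqdm

# Toy stand-ins for the trainer's data; assumes each batch already carries its
# index, which is what lets the loop unpack `batch_idx` without enumerate().
train_batches = [(i, (f"low_{i}", f"high_{i}")) for i in range(100)]
start_idx = 40          # hypothetical resume point (start_batch)
epoch, epochs = 0, 5    # hypothetical epoch counters

progress_bar = tqdm(train_batches[start_idx:],
                    initial=start_idx,
                    total=len(train_batches),
                    desc=f"Epoch {epoch + 1}/{epochs}")

for batch_idx, (low_res, high_res) in progress_bar:
    pass  # the per-batch training step would run here
```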
```diff
@@ -251,8 +252,10 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
                 if hasattr(self, 'use_checkpointing') and self.use_checkpointing:
                     low_res.requires_grad_()
                     outputs = checkpoint(self.model, low_res)
+                    outputs = outputs.clone()  # <-- Clone added here
                 else:
                     outputs = self.model(low_res)
+                    outputs = outputs.clone()  # <-- Clone added here
                 loss = self.criterion(outputs, high_res)
 
                 # Scale loss for gradient accumulation
```
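The two `+` lines are the substantive change in this comparison: both the checkpointed and the plain forward pass now clone their output before it reaches the loss. A minimal sketch of the same pattern, with a toy convolution standing in for the aiuNN model (shapes, model, and the flag variable are assumptions mirroring the hunk):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Toy stand-ins (shapes and model are assumptions, not the aiuNN upscaler).
model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
criterion = torch.nn.MSELoss()
low_res = torch.randn(2, 3, 16, 16)
high_res = torch.randn(2, 3, 16, 16)
use_checkpointing = True

if use_checkpointing:
    # checkpoint() needs at least one input that tracks gradients, otherwise
    # the recomputed forward has nothing to backpropagate into.
    low_res.requires_grad_()
    outputs = checkpoint(model, low_res)
    outputs = outputs.clone()  # fresh tensor for the loss, as in the hunk
else:
    outputs = model(low_res)
    outputs = outputs.clone()

loss = criterion(outputs, high_res)
loss.backward()  # gradients still flow through the clone
```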
```diff
@@ -266,8 +269,8 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
 
                 # Update weights every accumulation_steps or at the end of epoch
                 should_step = (not self.use_gradient_accumulation or
                                (batch_idx + 1) % self.accumulation_steps == 0 or
                                batch_idx == len(train_batches) - 1)
 
                 if should_step:
                     self.scaler.step(self.optimizer)
```
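This hunk only shifts by the lines added above; the accumulation logic reads the same on both sides. A self-contained sketch of that logic, with toy tensors and the attribute names lifted from the hunk (the GradScaler/optimizer setup here is an assumption):

```python
import torch

# Toy model and data standing in for the trainer's state.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
criterion = torch.nn.MSELoss()

use_gradient_accumulation = True
accumulation_steps = 4
train_batches = [(torch.randn(4, 8), torch.randn(4, 8)) for _ in range(10)]

for batch_idx, (x, y) in enumerate(train_batches):
    loss = criterion(model(x), y)
    if use_gradient_accumulation:
        loss = loss / accumulation_steps  # keep accumulated grads at full-batch scale
    scaler.scale(loss).backward()

    # Step every accumulation_steps batches, or on the final batch of the epoch.
    should_step = (not use_gradient_accumulation or
                   (batch_idx + 1) % accumulation_steps == 0 or
                   batch_idx == len(train_batches) - 1)
    if should_step:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```

Dividing the loss by `accumulation_steps` before `backward()` keeps the effective gradient magnitude the same as one large batch, which is why the trailing "Scale loss for gradient accumulation" comment sits just before this hunk.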
```diff
@@ -348,6 +351,16 @@ class MemoryOptimizedTrainer(aiuNNTrainer):
 
         return self.best_loss
+
+
+        # Stop memory monitoring
+        if self.use_memory_profiling:
+            self.stop_monitoring = True
+            if self.memory_monitor_thread:
+                self.memory_monitor_thread.join(timeout=1)
+            print(f"Training completed. Peak GPU memory usage: {self.peak_memory:.2f}GB")
+
+        return self.best_loss
 
     def get_memory_summary(self):
         """Get a summary of memory usage during training"""
         if not self.memory_stats:
```
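The new lines at the end of training flip a stop flag, join the monitor thread with a short timeout, and report peak memory. A minimal sketch of that shutdown handshake, with a simplified monitor loop (the class layout and the sampling body are assumptions; a real monitor would sample GPU memory each tick):

```python
import threading
import time

class MemoryMonitor:
    """Toy stand-in for the trainer's background memory monitor."""

    def __init__(self):
        self.stop_monitoring = False
        self.peak_memory = 0.0
        self.memory_monitor_thread = threading.Thread(target=self._monitor, daemon=True)
        self.memory_monitor_thread.start()

    def _monitor(self):
        # Poll until the trainer asks us to stop; a real monitor would track
        # torch.cuda.memory_allocated() here and update self.peak_memory.
        while not self.stop_monitoring:
            time.sleep(0.1)

    def shutdown(self):
        # Same pattern as the added hunk: flag, bounded join, final report.
        self.stop_monitoring = True
        if self.memory_monitor_thread:
            self.memory_monitor_thread.join(timeout=1)
        print(f"Training completed. Peak GPU memory usage: {self.peak_memory:.2f}GB")

monitor = MemoryMonitor()
monitor.shutdown()
```

The bounded `join(timeout=1)` keeps training teardown from blocking indefinitely if the monitor thread is wedged.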