improved memory usage in training #21

Merged
Fabel merged 1 commit from feat/model_fix into main 2025-06-04 19:23:15 +00:00
1 changed file with 8 additions and 7 deletions

View File

@@ -140,17 +140,18 @@ class aiuNNTrainer:
         val_loss = 0.0
         with torch.no_grad():
-            val_loss = 0.0
             for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
-                low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
+                low_res = low_res.to(self.device, non_blocking=True)
                 high_res = high_res.to(self.device, non_blocking=True)
-                outputs = self.model(low_res)
-                loss = self.criterion(outputs, high_res)
+                with autocast(device_type=self.device.type):
+                    outputs = self.model(low_res)
+                    loss = self.criterion(outputs, high_res)
                 val_loss += loss.item()
+                # Explicitly delete tensors
                 del low_res, high_res, outputs, loss
+        gc.collect()
+        torch.cuda.empty_cache()
         self.model.train()
         return val_loss