improved memory usage in training #21
|
@ -140,17 +140,18 @@ class aiuNNTrainer:
|
||||||
val_loss = 0.0
|
val_loss = 0.0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
val_loss = 0.0
|
||||||
for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
|
for low_res, high_res in tqdm(self.validation_loader, desc="Validating"):
|
||||||
low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
|
low_res = low_res.to(self.device, non_blocking=True)
|
||||||
high_res = high_res.to(self.device, non_blocking=True)
|
high_res = high_res.to(self.device, non_blocking=True)
|
||||||
|
outputs = self.model(low_res)
|
||||||
with autocast(device_type=self.device.type):
|
loss = self.criterion(outputs, high_res)
|
||||||
outputs = self.model(low_res)
|
|
||||||
loss = self.criterion(outputs, high_res)
|
|
||||||
|
|
||||||
val_loss += loss.item()
|
val_loss += loss.item()
|
||||||
|
# Explicitly delete tensors
|
||||||
del low_res, high_res, outputs, loss
|
del low_res, high_res, outputs, loss
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
return val_loss
|
return val_loss
|
||||||
|
|
Loading…
Reference in New Issue