diff --git a/src/aiunn/finetune/trainer.py b/src/aiunn/finetune/trainer.py index 40d04cd..e4cdab2 100644 --- a/src/aiunn/finetune/trainer.py +++ b/src/aiunn/finetune/trainer.py @@ -140,17 +140,18 @@ class aiuNNTrainer: val_loss = 0.0 with torch.no_grad(): + val_loss = 0.0 for low_res, high_res in tqdm(self.validation_loader, desc="Validating"): - low_res = low_res.to(self.device, non_blocking=True).to(memory_format=torch.channels_last) + low_res = low_res.to(self.device, non_blocking=True) high_res = high_res.to(self.device, non_blocking=True) - - with autocast(device_type=self.device.type): - outputs = self.model(low_res) - loss = self.criterion(outputs, high_res) - + outputs = self.model(low_res) + loss = self.criterion(outputs, high_res) val_loss += loss.item() - + # Explicitly delete tensors del low_res, high_res, outputs, loss + gc.collect() + torch.cuda.empty_cache() + self.model.train() return val_loss