From e69d0e90ec254785704837bb48bcbbe62d4230b8 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sat, 15 Feb 2025 13:08:09 +0100 Subject: [PATCH] fixed checkpoints tensors --- src/aiunn/finetune.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index 6d24bb3..d3279ba 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -135,6 +135,7 @@ def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=1 with autocast(device_type="cuda"): if use_checkpoint: # Use checkpointing if requested. + low_res = batch['low_res'].to(device).requires_grad_() features = checkpoint(lambda x: model(x), low_res) else: features = model(low_res)