diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index 6d24bb3..d3279ba 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -135,6 +135,7 @@ def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=1 with autocast(device_type="cuda"): if use_checkpoint: # Use checkpointing if requested. + low_res = batch['low_res'].to(device).requires_grad_() features = checkpoint(lambda x: model(x), low_res) else: features = model(low_res)