From fde8bdcb6f4aa72258736b1073a9ad1777b01cac Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sun, 23 Feb 2025 23:13:27 +0100 Subject: [PATCH] test --- src/aiunn/finetune.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index f0bb3e8..6a2f9cc 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -142,11 +142,13 @@ for epoch in range(num_epochs): with autocast(device_type=device.type): if use_checkpointing: - # Wrap the forward pass with checkpointing to trade compute for memory. - outputs = checkpoint(lambda x: model(x), low_res) + # Ensure the input tensor requires gradient so that checkpointing records the computation graph. + low_res.requires_grad_() + outputs = checkpoint(model, low_res) else: outputs = model(low_res) loss = criterion(outputs, high_res) + scaler.scale(loss).backward() scaler.step(optimizer)