This commit is contained in:
Falko Victor Habel 2025-02-23 23:13:27 +01:00
parent a51300c77c
commit fde8bdcb6f
1 changed files with 4 additions and 2 deletions

View File

@ -142,11 +142,13 @@ for epoch in range(num_epochs):
with autocast(device_type=device.type):
if use_checkpointing:
# Wrap the forward pass with checkpointing to trade compute for memory.
outputs = checkpoint(lambda x: model(x), low_res)
# Ensure the input tensor requires gradient so that checkpointing records the computation graph.
low_res.requires_grad_()
outputs = checkpoint(model, low_res)
else:
outputs = model(low_res)
loss = criterion(outputs, high_res)
scaler.scale(loss).backward()
scaler.step(optimizer)