develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed files with 4 additions and 2 deletions
Showing only changes of commit fde8bdcb6f - Show all commits

View File

@ -142,11 +142,13 @@ for epoch in range(num_epochs):
with autocast(device_type=device.type):
if use_checkpointing:
# Wrap the forward pass with checkpointing to trade compute for memory.
outputs = checkpoint(lambda x: model(x), low_res)
# Ensure the input tensor requires gradient so that checkpointing records the computation graph.
low_res.requires_grad_()
outputs = checkpoint(model, low_res)
else:
outputs = model(low_res)
loss = criterion(outputs, high_res)
scaler.scale(loss).backward()
scaler.step(optimizer)