test
commit fde8bdcb6f (parent a51300c77c)
@@ -142,11 +142,13 @@ for epoch in range(num_epochs):
     with autocast(device_type=device.type):
         if use_checkpointing:
             # Wrap the forward pass with checkpointing to trade compute for memory.
-            outputs = checkpoint(lambda x: model(x), low_res)
+            # Ensure the input tensor requires gradient so that checkpointing records the computation graph.
+            low_res.requires_grad_()
+            outputs = checkpoint(model, low_res)
         else:
             outputs = model(low_res)
         loss = criterion(outputs, high_res)

     scaler.scale(loss).backward()
     scaler.step(optimizer)
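For context, a minimal, self-contained sketch of the pattern this change adopts: mixed precision (autocast + GradScaler) combined with gradient checkpointing, where the checkpointed input is marked as requiring grad so the recomputed forward pass still carries gradients back to the parameters. The model, tensor shapes, and hyperparameters below are placeholders for illustration only, not the ones used in this repository.

# Hypothetical stand-ins for the repository's model, data, and optimizer.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, 3, padding=1),
).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# GradScaler is a no-op when disabled, so the same step works on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

low_res = torch.randn(4, 3, 32, 32, device=device)
high_res = torch.randn(4, 3, 32, 32, device=device)
use_checkpointing = True

optimizer.zero_grad()
with torch.autocast(device_type=device.type):
    if use_checkpointing:
        # Reentrant checkpointing only propagates gradients if at least one
        # input requires grad; marking the input keeps the backward graph
        # intact even though low_res is data rather than a parameter.
        low_res.requires_grad_()
        outputs = checkpoint(model, low_res)
    else:
        outputs = model(low_res)
    loss = criterion(outputs, high_res)

# Scale the loss so small half-precision gradients do not underflow;
# step() unscales before applying the optimizer update.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()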