finetune_class #1
@@ -142,11 +142,13 @@ for epoch in range(num_epochs):
     with autocast(device_type=device.type):
         if use_checkpointing:
-            # Wrap the forward pass with checkpointing to trade compute for memory.
-            outputs = checkpoint(lambda x: model(x), low_res)
+            # Ensure the input tensor requires gradient so that checkpointing records the computation graph.
+            low_res.requires_grad_()
+            outputs = checkpoint(model, low_res)
         else:
             outputs = model(low_res)
         loss = criterion(outputs, high_res)
 
 
     scaler.scale(loss).backward()
     scaler.step(optimizer)
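For context, a minimal, self-contained sketch of the training-loop pattern this hunk lands on: `autocast` plus `GradScaler` for mixed precision, with the forward pass optionally routed through `torch.utils.checkpoint`. The `nn.Sequential` model, tensor shapes, loss, optimizer, and the `zero_grad()`/`scaler.update()` calls outside the visible hunk are stand-ins and assumptions, not code from this PR; `use_reentrant=True` is spelled out to make explicit why the input must require grad, which is the behavior this change fixes.

```python
# Hypothetical sketch -- model, shapes, and hyperparameters are stand-ins;
# only the autocast/GradScaler/checkpoint wiring mirrors the patched hunk.
import torch
import torch.nn as nn
from torch.amp import GradScaler, autocast            # assumes PyTorch >= 2.3
from torch.utils.checkpoint import checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Sequential(                                # stand-in network
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, 3, padding=1),
).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler(enabled=(device.type == "cuda"))  # no-op scaling on CPU
use_checkpointing = True
num_epochs = 2

for epoch in range(num_epochs):
    # Random stand-in batch; the real loop would iterate a DataLoader.
    low_res = torch.randn(4, 3, 32, 32, device=device)
    high_res = torch.randn(4, 3, 32, 32, device=device)

    optimizer.zero_grad(set_to_none=True)
    with autocast(device_type=device.type):
        if use_checkpointing:
            # Reentrant checkpointing records no graph unless at least one
            # input requires grad -- exactly what the requires_grad_() call
            # added in this diff guarantees.
            low_res.requires_grad_()
            outputs = checkpoint(model, low_res, use_reentrant=True)
        else:
            outputs = model(low_res)
        loss = criterion(outputs, high_res)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()  # outside the visible hunk; standard AMP recipe
```

Passing `model` directly to `checkpoint` (rather than wrapping it in a `lambda`, as the old line did) keeps the callable picklable and avoids an extra closure, while `low_res.requires_grad_()` ensures the checkpointed segment participates in autograd even though the input itself is data, not a parameter.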