develop #4
@@ -135,6 +135,7 @@ def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=1
         with autocast(device_type="cuda"):
             if use_checkpoint:
                 # Use checkpointing if requested.
+                low_res = batch['low_res'].to(device).requires_grad_()
                 features = checkpoint(lambda x: model(x), low_res)
             else:
                 features = model(low_res)
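For context, a minimal runnable sketch of the pattern this added line enables is below. It is not the repository's code: the tiny model, tensor shapes, and the explicit use_reentrant=True flag are stand-ins chosen for illustration. Under PyTorch's reentrant checkpoint implementation, the wrapped segment is only re-run during backward if at least one tensor input requires grad, which is why the checkpointed input gets .requires_grad_().

# Minimal sketch (not from the PR): why the checkpointed input is marked
# with .requires_grad_(). Model and tensor names are illustrative only.
import torch
import torch.nn as nn
from torch.amp import autocast
from torch.utils.checkpoint import checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()).to(device)
low_res = torch.randn(2, 3, 32, 32, device=device)  # stand-in for batch['low_res']

with autocast(device_type=device):
    # With the reentrant checkpoint, an input that requires grad keeps the
    # segment attached to the autograd graph; otherwise the parameters
    # inside receive no gradients.
    features = checkpoint(lambda x: model(x), low_res.requires_grad_(), use_reentrant=True)

features.float().sum().backward()
print(model[0].weight.grad is not None)  # expected: True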