develop #4
|
@ -135,6 +135,7 @@ def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=1
|
||||||
with autocast(device_type="cuda"):
|
with autocast(device_type="cuda"):
|
||||||
if use_checkpoint:
|
if use_checkpoint:
|
||||||
# Use checkpointing if requested.
|
# Use checkpointing if requested.
|
||||||
|
low_res = batch['low_res'].to(device).requires_grad_()
|
||||||
features = checkpoint(lambda x: model(x), low_res)
|
features = checkpoint(lambda x: model(x), low_res)
|
||||||
else:
|
else:
|
||||||
features = model(low_res)
|
features = model(low_res)
|
||||||
|
|
Loading…
Reference in New Issue