fixed checkpoints tensors

This commit is contained in:
Falko Victor Habel 2025-02-15 13:08:09 +01:00
parent 75be3291d3
commit e69d0e90ec
1 changed files with 1 additions and 0 deletions

View File

@ -135,6 +135,7 @@ def finetune_model(model: nn.Module, datasets: list[str], batch_size=1, epochs=1
with autocast(device_type="cuda"): with autocast(device_type="cuda"):
if use_checkpoint: if use_checkpoint:
# Use checkpointing if requested. # Use checkpointing if requested.
low_res = batch['low_res'].to(device).requires_grad_()
features = checkpoint(lambda x: model(x), low_res) features = checkpoint(lambda x: model(x), low_res)
else: else:
features = model(low_res) features = model(low_res)