finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 4 additions and 2 deletions
Showing only changes of commit fde8bdcb6f - Show all commits

View File

@ -142,11 +142,13 @@ for epoch in range(num_epochs):
with autocast(device_type=device.type): with autocast(device_type=device.type):
if use_checkpointing: if use_checkpointing:
# Wrap the forward pass with checkpointing to trade compute for memory. # Ensure the input tensor requires gradient so that checkpointing records the computation graph.
outputs = checkpoint(lambda x: model(x), low_res) low_res.requires_grad_()
outputs = checkpoint(model, low_res)
else: else:
outputs = model(low_res) outputs = model(low_res)
loss = criterion(outputs, high_res) loss = criterion(outputs, high_res)
scaler.scale(loss).backward() scaler.scale(loss).backward()
scaler.step(optimizer) scaler.step(optimizer)