finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 2 additions and 2 deletions
Showing only changes of commit 1234fc5beb - Show all commits

View File

@ -106,7 +106,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
torch.cuda.empty_cache() torch.cuda.empty_cache()
low_res = batch['low_res'].to(device) low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device) high_res = batch['high_res'].to(device)
with autocast(): with autocast(device_type="cuda"):
if use_checkpoint: if use_checkpoint:
outputs = checkpoint(lambda x: model(x), low_res) outputs = checkpoint(lambda x: model(x), low_res)
else: else:
@ -134,7 +134,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
torch.cuda.empty_cache() torch.cuda.empty_cache()
low_res = batch['low_res'].to(device) low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device) high_res = batch['high_res'].to(device)
with autocast(): with autocast(device_type="cuda"):
outputs = model(low_res) outputs = model(low_res)
loss = criterion(outputs, high_res) loss = criterion(outputs, high_res)
val_loss += loss.item() val_loss += loss.item()