diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index 92fa954..774da0c 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -106,7 +106,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac torch.cuda.empty_cache() low_res = batch['low_res'].to(device) high_res = batch['high_res'].to(device) - with autocast(): + with autocast(device_type="cuda"): if use_checkpoint: outputs = checkpoint(lambda x: model(x), low_res) else: @@ -134,7 +134,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac torch.cuda.empty_cache() low_res = batch['low_res'].to(device) high_res = batch['high_res'].to(device) - with autocast(): + with autocast(device_type="cuda"): outputs = model(low_res) loss = criterion(outputs, high_res) val_loss += loss.item()