finetune_class #1
|
@ -106,7 +106,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
low_res = batch['low_res'].to(device)
|
low_res = batch['low_res'].to(device)
|
||||||
high_res = batch['high_res'].to(device)
|
high_res = batch['high_res'].to(device)
|
||||||
with autocast():
|
with autocast(device_type="cuda"):
|
||||||
if use_checkpoint:
|
if use_checkpoint:
|
||||||
outputs = checkpoint(lambda x: model(x), low_res)
|
outputs = checkpoint(lambda x: model(x), low_res)
|
||||||
else:
|
else:
|
||||||
|
@ -134,7 +134,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=1, epochs=10, ac
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
low_res = batch['low_res'].to(device)
|
low_res = batch['low_res'].to(device)
|
||||||
high_res = batch['high_res'].to(device)
|
high_res = batch['high_res'].to(device)
|
||||||
with autocast():
|
with autocast(device_type="cuda"):
|
||||||
outputs = model(low_res)
|
outputs = model(low_res)
|
||||||
loss = criterion(outputs, high_res)
|
loss = criterion(outputs, high_res)
|
||||||
val_loss += loss.item()
|
val_loss += loss.item()
|
||||||
|
|
Loading…
Reference in New Issue