added torch.amp support

parent b4dd550f8d
commit ca44dd8a77
@@ -63,6 +63,7 @@ class aiuNNDataset(torch.utils.data.Dataset):
             'low_res': augmented_low['image'],
             'high_res': augmented_high['image']
         }
+
 def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
     loaded_datasets = [aiuNNDataset(d) for d in datasets]
     combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
@@ -95,6 +96,9 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
     criterion = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
 
+    # Initialize GradScaler for AMP
+    scaler = torch.amp.GradScaler()
+
     best_val_loss = float('inf')
 
     from tqdm import tqdm
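Editorial note, not part of the commit: GradScaler exists because float16 gradients can underflow to zero. It multiplies the loss by a large factor before backward(), unscales the gradients inside step(), skips the optimizer step when inf/nan gradients appear, and adapts the factor over time. A minimal self-contained sketch of the pattern the hunk above sets up, assuming a CUDA device (the model, loss function, and data here are placeholders):

import torch
import torch.nn as nn

device = torch.device('cuda')
model = nn.Linear(8, 8).to(device)            # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
scaler = torch.amp.GradScaler()               # defaults to device='cuda'

x = torch.randn(4, 8, device=device)
y = torch.randn(4, 8, device=device)

with torch.amp.autocast(device_type='cuda'):  # forward pass in mixed precision
    loss = loss_fn(model(x), y)

scaler.scale(loss).backward()                 # backward on the scaled loss
scaler.step(optimizer)                        # unscales grads, skips step on inf/nan
scaler.update()                               # adjusts the scale factor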
@@ -110,11 +114,16 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
             high_res = batch['high_res'].to(device)
 
             optimizer.zero_grad()
-            outputs = model(low_res)
-            loss = criterion(outputs, high_res)
+            # Use AMP autocast for lower precision computations
+            with torch.cuda.amp.autocast():
+                outputs = model(low_res)
+                loss = criterion(outputs, high_res)
 
-            loss.backward()
-            optimizer.step()
+            # Scale the loss for backward pass
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+
             train_loss += loss.item()
 
         avg_train_loss = train_loss / len(train_loader)
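Two details the hunk above glosses over (editorial note, not part of the commit): torch.cuda.amp.autocast() still works but recent PyTorch deprecates it in favor of torch.amp.autocast(device_type='cuda'), and if gradient clipping is wanted it must happen on unscaled gradients via scaler.unscale_(). A hedged sketch of that variant, reusing the names from the diff (model, criterion, optimizer, scaler, low_res, high_res; max_norm=1.0 is an arbitrary illustrative value):

optimizer.zero_grad()

with torch.amp.autocast(device_type='cuda'):   # new-style autocast
    outputs = model(low_res)
    loss = criterion(outputs, high_res)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                     # bring grads back to true scale
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # optional clipping
scaler.step(optimizer)                         # safe: grads already unscaled
scaler.update()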
@@ -131,8 +140,9 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
             low_res = batch['low_res'].to(device)
             high_res = batch['high_res'].to(device)
 
-            outputs = model(low_res)
-            loss = criterion(outputs, high_res)
+            with torch.amp.autocast(device_type=device.type):
+                outputs = model(low_res)
+                loss = criterion(outputs, high_res)
             val_loss += loss.item()
 
         avg_val_loss = val_loss / len(val_loader)
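A closing note on the validation hunk (editorial): torch.amp.autocast() requires a device_type argument (hence the device.type fix above), and autocast alone does not disable gradient tracking, so the forward pass still builds a graph unless the loop runs under torch.no_grad(). No scaler is needed here because there is no backward pass. A sketch of how the validation body could look under that assumption, again reusing the names from the diff:

model.eval()
with torch.no_grad():                          # no graph needed for validation
    for batch in val_loader:
        low_res = batch['low_res'].to(device)
        high_res = batch['high_res'].to(device)
        with torch.amp.autocast(device_type=device.type):
            outputs = model(low_res)
            loss = criterion(outputs, high_res)
        val_loss += loss.item()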