develop #4
|
@ -63,6 +63,7 @@ class aiuNNDataset(torch.utils.data.Dataset):
|
|||
'low_res': augmented_low['image'],
|
||||
'high_res': augmented_high['image']
|
||||
}
|
||||
|
||||
def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
|
||||
loaded_datasets = [aiuNNDataset(d) for d in datasets]
|
||||
combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
|
||||
|
@ -95,6 +96,9 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
|
|||
criterion = nn.MSELoss()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
|
||||
|
||||
# Initialize GradScaler for AMP
|
||||
scaler = torch.amp.GradScaler()
|
||||
|
||||
best_val_loss = float('inf')
|
||||
|
||||
from tqdm import tqdm
|
||||
|
@ -110,11 +114,16 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
|
|||
high_res = batch['high_res'].to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
outputs = model(low_res)
|
||||
loss = criterion(outputs, high_res)
|
||||
# Use AMP autocast for lower precision computations
|
||||
with torch.cuda.amp.autocast():
|
||||
outputs = model(low_res)
|
||||
loss = criterion(outputs, high_res)
|
||||
|
||||
# Scale the loss for backward pass
|
||||
scaler.scale(loss).backward()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
train_loss += loss.item()
|
||||
|
||||
avg_train_loss = train_loss / len(train_loader)
|
||||
|
@ -131,8 +140,9 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
|
|||
low_res = batch['low_res'].to(device)
|
||||
high_res = batch['high_res'].to(device)
|
||||
|
||||
outputs = model(low_res)
|
||||
loss = criterion(outputs, high_res)
|
||||
with torch.amp.autocast():
|
||||
outputs = model(low_res)
|
||||
loss = criterion(outputs, high_res)
|
||||
val_loss += loss.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
|
|
Loading…
Reference in New Issue