develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed file with 18 additions and 8 deletions
Showing only changes of commit ca44dd8a77


@@ -63,6 +63,7 @@ class aiuNNDataset(torch.utils.data.Dataset):
             'low_res': augmented_low['image'],
             'high_res': augmented_high['image']
         }
+
 def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
     loaded_datasets = [aiuNNDataset(d) for d in datasets]
     combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
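The train_loader and val_loader referenced in the later hunks are built outside this commit's visible lines. A hypothetical sketch of how the concatenated dataset is typically split and wrapped; the 80/20 ratio and loader options are assumptions, not code from this PR:

# Hypothetical sketch, not from this commit: split the combined dataset
# and build the loaders used below. Ratio and options are assumptions.
from torch.utils.data import DataLoader, random_split

train_size = int(0.8 * len(combined_dataset))
val_size = len(combined_dataset) - train_size
train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)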
@@ -95,6 +96,9 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
     criterion = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
+    # Initialize GradScaler for AMP
+    scaler = torch.amp.GradScaler()
     best_val_loss = float('inf')
     from tqdm import tqdm
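GradScaler keeps fp16 gradients from underflowing by multiplying the loss before backward() and unscaling gradients before the optimizer step. For reference, a sketch of the bare torch.amp.GradScaler() call above with its documented defaults written out:

# Equivalent to the bare torch.amp.GradScaler() above, with the
# documented default arguments spelled out explicitly.
scaler = torch.amp.GradScaler(
    "cuda",                # device the scaler targets
    init_scale=2.0**16,    # starting loss-scale factor
    growth_factor=2.0,     # scale increase after a stable stretch of steps
    backoff_factor=0.5,    # scale decrease after a step with inf/NaN grads
    growth_interval=2000,  # steps between growth attempts
    enabled=True,
)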
@@ -110,11 +114,16 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
             high_res = batch['high_res'].to(device)
             optimizer.zero_grad()
-            outputs = model(low_res)
-            loss = criterion(outputs, high_res)
+            # Use AMP autocast for lower precision computations
+            with torch.cuda.amp.autocast():
+                outputs = model(low_res)
+                loss = criterion(outputs, high_res)
+            # Scale the loss for backward pass
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
-            loss.backward()
-            optimizer.step()
             train_loss += loss.item()
         avg_train_loss = train_loss / len(train_loader)
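The hunk above follows the standard PyTorch AMP recipe: run the forward pass under autocast, then let GradScaler scale the loss, drive the optimizer step, and adjust the scale factor. A self-contained sketch of that pattern; model, data, and target are placeholder names, not code from this PR:

# Minimal, runnable sketch of the AMP training step this commit applies.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(16, 16).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()
# Disable scaling/autocast when CUDA is unavailable so the sketch runs anywhere
scaler = torch.amp.GradScaler(enabled=(device.type == "cuda"))

data = torch.randn(8, 16, device=device)
target = torch.randn(8, 16, device=device)

optimizer.zero_grad()
# Forward pass in reduced precision where it is numerically safe
with torch.amp.autocast(device_type=device.type, enabled=(device.type == "cuda")):
    output = model(data)
    loss = criterion(output, target)
# Scale the loss so fp16 gradients do not underflow, step, then update the scale
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()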
@@ -131,8 +140,9 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
             low_res = batch['low_res'].to(device)
             high_res = batch['high_res'].to(device)
-            outputs = model(low_res)
-            loss = criterion(outputs, high_res)
+            with torch.amp.autocast(device_type="cuda"):
+                outputs = model(low_res)
+                loss = criterion(outputs, high_res)
             val_loss += loss.item()
         avg_val_loss = val_loss / len(val_loader)
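Two notes on the validation hunk: autocast in the torch.amp namespace requires an explicit device_type argument (added above, since the original call would raise a TypeError), and validation is conventionally run under model.eval() and torch.no_grad(), which the visible hunk does not show. A sketch with those two additions; they are assumptions, not part of this commit:

# Sketch only; model.eval() and torch.no_grad() are assumptions
# not visible in this commit.
model.eval()
val_loss = 0.0
with torch.no_grad():
    for batch in val_loader:
        low_res = batch['low_res'].to(device)
        high_res = batch['high_res'].to(device)
        with torch.amp.autocast(device_type="cuda"):
            outputs = model(low_res)
            loss = criterion(outputs, high_res)
        val_loss += loss.item()
avg_val_loss = val_loss / len(val_loader)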