diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index 5b2f8fb..d994f75 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -63,6 +63,7 @@ class aiuNNDataset(torch.utils.data.Dataset):
             'low_res': augmented_low['image'],
             'high_res': augmented_high['image']
         }
+
 def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
     loaded_datasets = [aiuNNDataset(d) for d in datasets]
     combined_dataset = torch.utils.data.ConcatDataset(loaded_datasets)
@@ -95,6 +96,9 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
     criterion = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
 
+    # Initialize GradScaler for AMP
+    scaler = torch.amp.GradScaler()
+
     best_val_loss = float('inf')
 
    from tqdm import tqdm
@@ -110,11 +114,16 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
             high_res = batch['high_res'].to(device)
 
             optimizer.zero_grad()
-            outputs = model(low_res)
-            loss = criterion(outputs, high_res)
-
-            loss.backward()
-            optimizer.step()
+            # Run the forward pass under AMP autocast for mixed-precision computation
+            with torch.amp.autocast(device_type='cuda'):
+                outputs = model(low_res)
+                loss = criterion(outputs, high_res)
+
+            # Scale the loss for the backward pass, then step the optimizer via the scaler
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+
             train_loss += loss.item()
 
         avg_train_loss = train_loss / len(train_loader)
@@ -131,8 +140,9 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
                 low_res = batch['low_res'].to(device)
                 high_res = batch['high_res'].to(device)
 
-                outputs = model(low_res)
-                loss = criterion(outputs, high_res)
+                with torch.amp.autocast(device_type='cuda'):
+                    outputs = model(low_res)
+                    loss = criterion(outputs, high_res)
                 val_loss += loss.item()
 
         avg_val_loss = val_loss / len(val_loader)
@@ -160,4 +170,4 @@ def main():
     )
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
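
Reviewer note: the autocast/GradScaler pattern introduced by this patch can be exercised in isolation before it is wired into finetune_model. The following is a minimal sketch only, assuming a CUDA device and a recent PyTorch with torch.amp.GradScaler; net, batch, and target are hypothetical stand-ins for the repository's model and data, not names from this codebase.

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins (not from this repo): a tiny model and random tensors.
    device = torch.device('cuda')
    net = nn.Conv2d(3, 3, kernel_size=3, padding=1).to(device)
    batch = torch.randn(2, 3, 64, 64, device=device)   # plays the role of low_res
    target = torch.randn(2, 3, 64, 64, device=device)  # plays the role of high_res

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
    scaler = torch.amp.GradScaler()  # same scaler construction as in the patch

    optimizer.zero_grad()
    # Forward pass and loss in reduced precision where autocast deems it safe.
    with torch.amp.autocast(device_type='cuda'):
        outputs = net(batch)
        loss = criterion(outputs, target)
    # Scaled backward pass, unscale-and-step, then update the scale factor.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

The ordering shown (scale -> backward -> step -> update) is the part the patch relies on; swapping step and update, or calling optimizer.step() directly, would bypass the scaler's inf/NaN checks.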