From e7e7e960010a2d432d32a7ed199368b3ddff0414 Mon Sep 17 00:00:00 2001
From: Falko Habel
Date: Fri, 14 Feb 2025 22:03:04 +0100
Subject: [PATCH] max gpu usage added

---
 src/aiunn/finetune.py | 33 ++++++++++++---------------------
 1 file changed, 12 insertions(+), 21 deletions(-)

diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py
index d994f75..84f6b69 100644
--- a/src/aiunn/finetune.py
+++ b/src/aiunn/finetune.py
@@ -10,8 +10,10 @@ import io
 import base64
 import numpy as np
 from torch import nn
-from torch.utils.data import random_split
+from torch.utils.data import random_split, DataLoader
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
+from torch.amp import autocast, GradScaler
+from tqdm import tqdm
 
 class aiuNNDataset(torch.utils.data.Dataset):
     def __init__(self, parquet_path):
@@ -72,7 +74,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
     val_size = len(combined_dataset) - train_size
     train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])
 
-    train_loader = torch.utils.data.DataLoader(
+    train_loader = DataLoader(
         train_dataset,
         batch_size=batch_size,
         shuffle=True,
@@ -81,7 +83,7 @@
         persistent_workers=True
     )
 
-    val_loader = torch.utils.data.DataLoader(
+    val_loader = DataLoader(
         val_dataset,
         batch_size=batch_size,
         shuffle=False,
@@ -91,39 +93,33 @@
     )
 
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # Limit VRAM usage to 95% of available memory (reducing risk of overflow)
+    if device.type == 'cuda':
+        torch.cuda.set_per_process_memory_fraction(0.95, device=device)
+
     model = model.to(device)
     criterion = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
 
-    # Initialize GradScaler for AMP
-    scaler = torch.amp.GradScaler()
-
+    scaler = GradScaler()
     best_val_loss = float('inf')
 
-    from tqdm import tqdm
-
     for epoch in range(epochs):
         model.train()
         train_loss = 0.0
-
         for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"):
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
 
             low_res = batch['low_res'].to(device)
             high_res = batch['high_res'].to(device)
-
             optimizer.zero_grad()
-            # Use AMP autocast for lower precision computations
-            with torch.cuda.amp.autocast():
+            with autocast():
                 outputs = model(low_res)
                 loss = criterion(outputs, high_res)
-
-            # Scale the loss for backward pass
             scaler.scale(loss).backward()
             scaler.step(optimizer)
             scaler.update()
-
            train_loss += loss.item()
 
         avg_train_loss = train_loss / len(train_loader)
@@ -131,26 +127,21 @@
         model.eval()
         val_loss = 0.0
-
         with torch.no_grad():
             for batch in tqdm(val_loader, desc="Validation"):
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
-
                 low_res = batch['low_res'].to(device)
                 high_res = batch['high_res'].to(device)
-
-                with torch.amp.autocast():
+                with autocast():
                     outputs = model(low_res)
                     loss = criterion(outputs, high_res)
                     val_loss += loss.item()
-
         avg_val_loss = val_loss / len(val_loader)
         print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
 
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
             model.save("best_model")
-
     return model
 
 def main():
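
Note on the switch from torch.cuda.amp to torch.amp in this patch: the device-agnostic torch.amp.autocast takes a required device_type argument, whereas the removed torch.cuda.amp.autocast() assumed CUDA implicitly, so the new "with autocast():" calls would need a device_type to run on recent PyTorch releases. A minimal sketch of the likely adjustment, reusing the device variable already defined in finetune_model (the exact call site naming is an assumption, not part of the patch):

    # sketch only: pass the device type explicitly to the device-agnostic autocast
    with autocast(device_type=device.type):
        outputs = model(low_res)
        loss = criterion(outputs, high_res)

GradScaler() imported from torch.amp defaults to the CUDA device, so it can stay as written in the patch.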