finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 12 additions and 21 deletions
Showing only changes of commit e7e7e96001 - Show all commits

View File

@ -10,8 +10,10 @@ import io
import base64 import base64
import numpy as np import numpy as np
from torch import nn from torch import nn
from torch.utils.data import random_split from torch.utils.data import random_split, DataLoader
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
from torch.amp import autocast, GradScaler
from tqdm import tqdm
class aiuNNDataset(torch.utils.data.Dataset): class aiuNNDataset(torch.utils.data.Dataset):
def __init__(self, parquet_path): def __init__(self, parquet_path):
@ -72,7 +74,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
val_size = len(combined_dataset) - train_size val_size = len(combined_dataset) - train_size
train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size]) train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])
train_loader = torch.utils.data.DataLoader( train_loader = DataLoader(
train_dataset, train_dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=True, shuffle=True,
@ -81,7 +83,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
persistent_workers=True persistent_workers=True
) )
val_loader = torch.utils.data.DataLoader( val_loader = DataLoader(
val_dataset, val_dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=False,
@ -91,39 +93,33 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
) )
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Limit VRAM usage to 95% of available memory (reducing risk of overflow)
if device.type == 'cuda':
torch.cuda.set_per_process_memory_fraction(0.95, device=device)
model = model.to(device) model = model.to(device)
criterion = nn.MSELoss() criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate) optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
# Initialize GradScaler for AMP scaler = GradScaler()
scaler = torch.amp.GradScaler()
best_val_loss = float('inf') best_val_loss = float('inf')
from tqdm import tqdm
for epoch in range(epochs): for epoch in range(epochs):
model.train() model.train()
train_loss = 0.0 train_loss = 0.0
for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"): for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"):
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
low_res = batch['low_res'].to(device) low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device) high_res = batch['high_res'].to(device)
optimizer.zero_grad() optimizer.zero_grad()
# Use AMP autocast for lower precision computations with autocast():
with torch.cuda.amp.autocast():
outputs = model(low_res) outputs = model(low_res)
loss = criterion(outputs, high_res) loss = criterion(outputs, high_res)
# Scale the loss for backward pass
scaler.scale(loss).backward() scaler.scale(loss).backward()
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
train_loss += loss.item() train_loss += loss.item()
avg_train_loss = train_loss / len(train_loader) avg_train_loss = train_loss / len(train_loader)
@ -131,26 +127,21 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
model.eval() model.eval()
val_loss = 0.0 val_loss = 0.0
with torch.no_grad(): with torch.no_grad():
for batch in tqdm(val_loader, desc="Validation"): for batch in tqdm(val_loader, desc="Validation"):
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
low_res = batch['low_res'].to(device) low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device) high_res = batch['high_res'].to(device)
with autocast():
with torch.amp.autocast():
outputs = model(low_res) outputs = model(low_res)
loss = criterion(outputs, high_res) loss = criterion(outputs, high_res)
val_loss += loss.item() val_loss += loss.item()
avg_val_loss = val_loss / len(val_loader) avg_val_loss = val_loss / len(val_loader)
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}") print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
if avg_val_loss < best_val_loss: if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss best_val_loss = avg_val_loss
model.save("best_model") model.save("best_model")
return model return model
def main(): def main():