develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed file with 12 additions and 21 deletions
Showing only changes of commit e7e7e96001

@@ -10,8 +10,10 @@ import io
import base64
import numpy as np
from torch import nn
from torch.utils.data import random_split
from torch.utils.data import random_split, DataLoader
from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
from torch.amp import autocast, GradScaler
from tqdm import tqdm
class aiuNNDataset(torch.utils.data.Dataset):
def __init__(self, parquet_path):
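This hunk moves the AMP helpers to the device-generic torch.amp namespace and hoists the DataLoader and tqdm imports to module level. A minimal sketch of the new-style setup (not part of the commit; it assumes PyTorch >= 2.3, where the generic GradScaler lives under torch.amp, and notes that torch.amp.autocast expects an explicit device_type):

    import torch
    from torch.amp import autocast, GradScaler

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    scaler = GradScaler(device.type)             # device-generic scaler ('cuda' or 'cpu')
    with autocast(device_type=device.type):      # device_type is required by torch.amp.autocast
        pass                                     # forward pass would run in mixed precision here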
@@ -72,7 +74,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
val_size = len(combined_dataset) - train_size
train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])
train_loader = torch.utils.data.DataLoader(
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
@@ -81,7 +83,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
persistent_workers=True
)
val_loader = torch.utils.data.DataLoader(
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
@@ -91,39 +93,33 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
)
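Both loaders are now built with the directly imported DataLoader; only batch_size, shuffle and persistent_workers are visible in this diff, the remaining keyword arguments sit in hidden context lines. A self-contained sketch of an equivalent setup (the stand-in dataset, num_workers and pin_memory values are illustrative assumptions, not taken from the commit; the real code presumably builds combined_dataset from aiuNNDataset instances):

    import torch
    from torch.utils.data import DataLoader, TensorDataset, random_split

    combined_dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.randn(100, 3, 64, 64))
    train_size = int(0.8 * len(combined_dataset))     # 80/20 split is an assumption
    val_size = len(combined_dataset) - train_size
    train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True,
                              num_workers=4, pin_memory=True, persistent_workers=True)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False,
                            num_workers=4, pin_memory=True, persistent_workers=True)
    # persistent_workers=True keeps worker processes alive between epochs and requires num_workers > 0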
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Limit VRAM usage to 95% of available memory (reducing risk of overflow)
if device.type == 'cuda':
torch.cuda.set_per_process_memory_fraction(0.95, device=device)
model = model.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)
# Initialize GradScaler for AMP
scaler = torch.amp.GradScaler()
scaler = GradScaler()
best_val_loss = float('inf')
from tqdm import tqdm
for epoch in range(epochs):
model.train()
train_loss = 0.0
for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"):
if torch.cuda.is_available():
torch.cuda.empty_cache()
low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device)
optimizer.zero_grad()
# Use AMP autocast for lower precision computations
with torch.cuda.amp.autocast():
with autocast():
outputs = model(low_res)
loss = criterion(outputs, high_res)
# Scale the loss for backward pass
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
train_loss += loss.item()
avg_train_loss = train_loss / len(train_loader)
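The updated loop runs the forward pass under autocast and routes the backward pass through GradScaler. A self-contained sketch of one such training step (tiny stand-in model and random tensors; the explicit device_type argument is an assumption about the torch.amp API rather than something shown in the diff):

    import torch
    from torch import nn
    from torch.amp import autocast, GradScaler

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = nn.Conv2d(3, 3, kernel_size=3, padding=1).to(device)   # stand-in for the AIIA model
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)      # learning rate is illustrative
    scaler = GradScaler(device.type)

    low_res = torch.randn(2, 3, 32, 32, device=device)
    high_res = torch.randn(2, 3, 32, 32, device=device)

    optimizer.zero_grad()
    with autocast(device_type=device.type):
        outputs = model(low_res)
        loss = criterion(outputs, high_res)
    scaler.scale(loss).backward()   # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)          # unscales gradients and skips the step if they contain inf/nan
    scaler.update()                 # adjusts the scale factor for the next iteration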
@@ -131,26 +127,21 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
model.eval()
val_loss = 0.0
with torch.no_grad():
for batch in tqdm(val_loader, desc="Validation"):
if torch.cuda.is_available():
torch.cuda.empty_cache()
low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device)
with torch.amp.autocast():
with autocast():
outputs = model(low_res)
loss = criterion(outputs, high_res)
val_loss += loss.item()
avg_val_loss = val_loss / len(val_loader)
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
model.save("best_model")
return model
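The best checkpoint is kept whenever an epoch's average validation loss improves on best_val_loss; model.save("best_model") is the AIIA library's own persistence method. A generic PyTorch equivalent, with a hypothetical filename, would look like:

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # torch.save as a generic stand-in for model.save("best_model")
        torch.save(model.state_dict(), "best_model.pt")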
def main():