finetune_class #1
@@ -10,8 +10,10 @@ import io
 import base64
 import numpy as np
 from torch import nn
-from torch.utils.data import random_split
+from torch.utils.data import random_split, DataLoader
 from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunked, AIIArecursive
+from torch.amp import autocast, GradScaler
+from tqdm import tqdm

 class aiuNNDataset(torch.utils.data.Dataset):
     def __init__(self, parquet_path):
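A note on the new imports: unlike the deprecated torch.cuda.amp helpers, torch.amp.autocast takes a required device_type argument, which is why the autocast calls in the loops below pass device.type. A minimal, self-contained sketch of the intended pattern; the nn.Linear model, optimizer and random tensors are placeholders for illustration only, not part of this change:

import torch
from torch import nn
from torch.amp import autocast, GradScaler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(16, 16).to(device)        # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler()                       # warns and disables itself on CPU-only machines

x = torch.randn(4, 16, device=device)       # dummy batch
y = torch.randn(4, 16, device=device)

optimizer.zero_grad()
with autocast(device_type=device.type):     # device_type is required by torch.amp.autocast
    loss = nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()               # backward on the scaled loss
scaler.step(optimizer)                      # unscales gradients, then runs optimizer.step()
scaler.update()                             # adjusts the loss scale for the next iteration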
@@ -72,7 +74,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
     val_size = len(combined_dataset) - train_size
     train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])

-    train_loader = torch.utils.data.DataLoader(
+    train_loader = DataLoader(
         train_dataset,
         batch_size=batch_size,
         shuffle=True,
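Side note, not part of this change: random_split draws a fresh permutation each run, so the train/validation split differs between invocations. Passing a seeded generator is a common way to keep the split, and therefore the best_val_loss comparison, reproducible. A sketch under that assumption (the 80/20 ratio and the placeholder dataset are illustrative only):

import torch
from torch.utils.data import TensorDataset, random_split

combined_dataset = TensorDataset(torch.randn(100, 3, 32, 32))   # placeholder dataset
train_size = int(0.8 * len(combined_dataset))                   # assumed split ratio
val_size = len(combined_dataset) - train_size
train_dataset, val_dataset = random_split(
    combined_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),                # fixed seed -> identical split every run
)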
@@ -81,7 +83,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
         persistent_workers=True
     )

-    val_loader = torch.utils.data.DataLoader(
+    val_loader = DataLoader(
         val_dataset,
         batch_size=batch_size,
         shuffle=False,
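The DataLoader keyword arguments between batch_size and persistent_workers fall outside these hunks and are not shown. For reference, persistent_workers=True is only accepted when num_workers > 0, and pin_memory=True is the usual companion when feeding a CUDA device. A sketch of a typical configuration under those assumptions (num_workers, pin_memory and the placeholder dataset are not taken from this PR):

import torch
from torch.utils.data import DataLoader, TensorDataset

val_dataset = TensorDataset(torch.randn(64, 3, 32, 32))   # placeholder dataset
batch_size = 2

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,              # assumed value; must be > 0 or persistent_workers raises a ValueError
    pin_memory=True,            # assumed; speeds up host-to-device copies when training on CUDA
    persistent_workers=True,    # keeps worker processes alive across epochs
)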
@@ -91,39 +93,33 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
     )

     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # Limit VRAM usage to 95% of available memory (reducing risk of overflow)
+    if device.type == 'cuda':
+        torch.cuda.set_per_process_memory_fraction(0.95, device=device)

     model = model.to(device)

     criterion = nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=model.config.learning_rate)

-    # Initialize GradScaler for AMP
-    scaler = torch.amp.GradScaler()
+    scaler = GradScaler()

     best_val_loss = float('inf')

-    from tqdm import tqdm

     for epoch in range(epochs):
         model.train()
         train_loss = 0.0

         for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/Training"):
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
             low_res = batch['low_res'].to(device)
             high_res = batch['high_res'].to(device)

             optimizer.zero_grad()
-            # Use AMP autocast for lower precision computations
-            with torch.cuda.amp.autocast():
+            with autocast(device_type=device.type):  # device_type is required by torch.amp.autocast
                 outputs = model(low_res)
                 loss = criterion(outputs, high_res)

-            # Scale the loss for backward pass
             scaler.scale(loss).backward()
             scaler.step(optimizer)
             scaler.update()

             train_loss += loss.item()

         avg_train_loss = train_loss / len(train_loader)
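The memory cap added in this hunk uses torch.cuda.set_per_process_memory_fraction, which bounds how much of the device's total memory the caching allocator may hand to this process; allocations beyond the cap fail with an out-of-memory error instead of claiming the whole card. A standalone sketch of the behaviour (the 0.95 fraction mirrors the change above, the rest is illustrative):

import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
    # Cap this process at 95% of the GPU's total memory, as in the change above.
    torch.cuda.set_per_process_memory_fraction(0.95, device=device)

    total = torch.cuda.get_device_properties(device).total_memory
    print(f"Total VRAM: {total / 1e9:.1f} GB, allocator cap: {0.95 * total / 1e9:.1f} GB")
    # Allocations past the cap raise an out-of-memory error instead of consuming the whole card.

Separately, the per-batch torch.cuda.empty_cache() calls kept in these loops return cached blocks to the driver and typically slow iteration down without lowering the peak memory held by live tensors.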
@@ -131,26 +127,21 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):

         model.eval()
         val_loss = 0.0

         with torch.no_grad():
             for batch in tqdm(val_loader, desc="Validation"):
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()

                 low_res = batch['low_res'].to(device)
                 high_res = batch['high_res'].to(device)
-                with torch.amp.autocast():
+                with autocast(device_type=device.type):  # device_type is required by torch.amp.autocast
                     outputs = model(low_res)
                     loss = criterion(outputs, high_res)
                     val_loss += loss.item()

         avg_val_loss = val_loss / len(val_loader)
         print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
         if avg_val_loss < best_val_loss:
             best_val_loss = avg_val_loss
             model.save("best_model")

     return model

 def main():