finetune_class #1

Merged
Fabel merged 96 commits from finetune_class into develop 2025-02-26 12:13:09 +00:00
1 changed files with 8 additions and 5 deletions
Showing only changes of commit 20dd9f68ed - Show all commits

View File

@ -15,7 +15,7 @@ from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunke
class aiuNNDataset(torch.utils.data.Dataset):
def __init__(self, parquet_path):
# Read the Parquet file
self.df = pd.read_parquet(parquet_path).head(2500)
self.df = pd.read_parquet(parquet_path).head(1250)
# Data augmentation pipeline without Resize as it's redundant
self.augmentation = Compose([
@ -124,12 +124,15 @@ def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10):
best_val_loss = float('inf')
from tqdm import tqdm
for epoch in range(epochs):
model.train()
train_loss = 0.0
for batch_idx, batch in enumerate(train_loader):
for batch_idx, batch in enumerate(tqdm(train_loader)):
# Your training code here
low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device)
@ -155,7 +158,7 @@ def finetune_model(model: AIIA, datasets:list[str], batch_size=2, epochs=10):
val_loss = 0.0
with torch.no_grad():
for batch in val_loader:
for batch in tqdm(val_loader, desc="Validation"):
low_res = batch['low_res'].to(device)
high_res = batch['high_res'].to(device)