develop #4

Merged
Fabel merged 103 commits from develop into main 2025-03-01 21:47:17 +00:00
1 changed files with 2 additions and 2 deletions
Showing only changes of commit b4dd550f8d - Show all commits

View File

@ -15,7 +15,7 @@ from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunke
class aiuNNDataset(torch.utils.data.Dataset): class aiuNNDataset(torch.utils.data.Dataset):
def __init__(self, parquet_path): def __init__(self, parquet_path):
self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2500) self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2000)
self.augmentation = Compose([ self.augmentation = Compose([
RandomBrightnessContrast(p=0.5), RandomBrightnessContrast(p=0.5),
@ -139,7 +139,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10):
print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}") print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")
if avg_val_loss < best_val_loss: if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss best_val_loss = avg_val_loss
torch.save(model.state_dict(), "best_model.pth") model.save("best_model")
return model return model