diff --git a/src/aiunn/finetune.py b/src/aiunn/finetune.py index 284c606..5b2f8fb 100644 --- a/src/aiunn/finetune.py +++ b/src/aiunn/finetune.py @@ -15,7 +15,7 @@ from aiia import AIIA, AIIAConfig, AIIABase, AIIABaseShared, AIIAmoe, AIIAchunke class aiuNNDataset(torch.utils.data.Dataset): def __init__(self, parquet_path): - self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2500) + self.df = pd.read_parquet(parquet_path, columns=['image_512', 'image_1024']).head(2000) self.augmentation = Compose([ RandomBrightnessContrast(p=0.5), @@ -139,7 +139,7 @@ def finetune_model(model: AIIA, datasets: list[str], batch_size=2, epochs=10): print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}") if avg_val_loss < best_val_loss: best_val_loss = avg_val_loss - torch.save(model.state_dict(), "best_model.pth") + model.save("best_model") return model