From 3c0e9e8ac1be6e88551ecc1304acfb1c75d4d311 Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Tue, 28 Jan 2025 17:18:21 +0100 Subject: [PATCH] fixed savings --- run.py | 27 +++++++++++++++++++++++++++ src/aiia/pretrain/pretrainer.py | 6 +----- 2 files changed, 28 insertions(+), 5 deletions(-) create mode 100644 run.py diff --git a/run.py b/run.py new file mode 100644 index 0000000..fb20e63 --- /dev/null +++ b/run.py @@ -0,0 +1,27 @@ +data_path1 = "/root/training_data/vision-dataset/images_pretrain.parquet" +data_path2 = "/root/training_data/vision-dataset/vector_img_pretrain.parquet" + +from aiia.model import AIIABase +from aiia.model.config import AIIAConfig +from aiia.pretrain import Pretrainer + +# Create your model +config = AIIAConfig(model_name="AIIA-Base-512x20k") +model = AIIABase(config) + +# Initialize pretrainer with the model +pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config) + +# List of dataset paths +dataset_paths = [ + data_path1, + data_path2 +] + +# Start training with multiple datasets +pretrainer.train( + dataset_paths=dataset_paths, + num_epochs=10, + batch_size=2, + sample_size=10000 +) \ No newline at end of file diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index fa84fcb..30ebc92 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -186,7 +186,7 @@ class Pretrainer: if val_loss < best_val_loss: best_val_loss = val_loss - self.save_model("AIIA-base-512") + self.model.save("AIIA-base-512") print("Best model saved!") self.save_losses('losses.csv') @@ -214,10 +214,6 @@ class Pretrainer: print(f"Validation Loss: {avg_val_loss:.4f}") return avg_val_loss - def save_model(self, path): - """Save the model and projection head.""" - self.model.save(path) - torch.save(self.projection_head.state_dict(), f"{path}_projection_head.pth") def save_losses(self, csv_file): """Save training and validation losses to a CSV file."""