diff --git a/run.py b/run.py new file mode 100644 index 0000000..fb20e63 --- /dev/null +++ b/run.py @@ -0,0 +1,27 @@ +data_path1 = "/root/training_data/vision-dataset/images_pretrain.parquet" +data_path2 = "/root/training_data/vision-dataset/vector_img_pretrain.parquet" + +from aiia.model import AIIABase +from aiia.model.config import AIIAConfig +from aiia.pretrain import Pretrainer + +# Create your model +config = AIIAConfig(model_name="AIIA-Base-512x20k") +model = AIIABase(config) + +# Initialize pretrainer with the model +pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config) + +# List of dataset paths +dataset_paths = [ + data_path1, + data_path2 +] + +# Start training with multiple datasets +pretrainer.train( + dataset_paths=dataset_paths, + num_epochs=10, + batch_size=2, + sample_size=10000 +) \ No newline at end of file diff --git a/src/aiia/pretrain/pretrainer.py b/src/aiia/pretrain/pretrainer.py index fa84fcb..30ebc92 100644 --- a/src/aiia/pretrain/pretrainer.py +++ b/src/aiia/pretrain/pretrainer.py @@ -186,7 +186,7 @@ class Pretrainer: if val_loss < best_val_loss: best_val_loss = val_loss - self.save_model("AIIA-base-512") + self.model.save("AIIA-base-512") print("Best model saved!") self.save_losses('losses.csv') @@ -214,10 +214,6 @@ class Pretrainer: print(f"Validation Loss: {avg_val_loss:.4f}") return avg_val_loss - def save_model(self, path): - """Save the model and projection head.""" - self.model.save(path) - torch.save(self.projection_head.state_dict(), f"{path}_projection_head.pth") def save_losses(self, csv_file): """Save training and validation losses to a CSV file."""