fixed savings

This commit is contained in:
Falko Victor Habel 2025-01-28 17:18:21 +01:00
parent 3631df7f0a
commit 3c0e9e8ac1
2 changed files with 28 additions and 5 deletions

27
run.py Normal file
View File

@ -0,0 +1,27 @@
data_path1 = "/root/training_data/vision-dataset/images_pretrain.parquet"
data_path2 = "/root/training_data/vision-dataset/vector_img_pretrain.parquet"
from aiia.model import AIIABase
from aiia.model.config import AIIAConfig
from aiia.pretrain import Pretrainer
# Create your model
config = AIIAConfig(model_name="AIIA-Base-512x20k")
model = AIIABase(config)
# Initialize pretrainer with the model
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)
# List of dataset paths
dataset_paths = [
data_path1,
data_path2
]
# Start training with multiple datasets
pretrainer.train(
dataset_paths=dataset_paths,
num_epochs=10,
batch_size=2,
sample_size=10000
)

View File

@ -186,7 +186,7 @@ class Pretrainer:
if val_loss < best_val_loss:
best_val_loss = val_loss
self.save_model("AIIA-base-512")
self.model.save("AIIA-base-512")
print("Best model saved!")
self.save_losses('losses.csv')
@ -214,10 +214,6 @@ class Pretrainer:
print(f"Validation Loss: {avg_val_loss:.4f}")
return avg_val_loss
def save_model(self, path):
"""Save the model and projection head."""
self.model.save(path)
torch.save(self.projection_head.state_dict(), f"{path}_projection_head.pth")
def save_losses(self, csv_file):
"""Save training and validation losses to a CSV file."""