fixed savings
This commit is contained in:
parent
3631df7f0a
commit
3c0e9e8ac1
|
@ -0,0 +1,27 @@
|
|||
data_path1 = "/root/training_data/vision-dataset/images_pretrain.parquet"
|
||||
data_path2 = "/root/training_data/vision-dataset/vector_img_pretrain.parquet"
|
||||
|
||||
from aiia.model import AIIABase
|
||||
from aiia.model.config import AIIAConfig
|
||||
from aiia.pretrain import Pretrainer
|
||||
|
||||
# Create your model
|
||||
config = AIIAConfig(model_name="AIIA-Base-512x20k")
|
||||
model = AIIABase(config)
|
||||
|
||||
# Initialize pretrainer with the model
|
||||
pretrainer = Pretrainer(model, learning_rate=config.learning_rate, config=config)
|
||||
|
||||
# List of dataset paths
|
||||
dataset_paths = [
|
||||
data_path1,
|
||||
data_path2
|
||||
]
|
||||
|
||||
# Start training with multiple datasets
|
||||
pretrainer.train(
|
||||
dataset_paths=dataset_paths,
|
||||
num_epochs=10,
|
||||
batch_size=2,
|
||||
sample_size=10000
|
||||
)
|
|
@ -186,7 +186,7 @@ class Pretrainer:
|
|||
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
self.save_model("AIIA-base-512")
|
||||
self.model.save("AIIA-base-512")
|
||||
print("Best model saved!")
|
||||
|
||||
self.save_losses('losses.csv')
|
||||
|
@ -214,10 +214,6 @@ class Pretrainer:
|
|||
print(f"Validation Loss: {avg_val_loss:.4f}")
|
||||
return avg_val_loss
|
||||
|
||||
def save_model(self, path):
|
||||
"""Save the model and projection head."""
|
||||
self.model.save(path)
|
||||
torch.save(self.projection_head.state_dict(), f"{path}_projection_head.pth")
|
||||
|
||||
def save_losses(self, csv_file):
|
||||
"""Save training and validation losses to a CSV file."""
|
||||
|
|
Loading…
Reference in New Issue