diff --git a/src/pretrain.py b/src/pretrain.py index c6e7705..38663e3 100644 --- a/src/pretrain.py +++ b/src/pretrain.py @@ -9,13 +9,13 @@ import copy def pretrain_model(data_path1, data_path2, num_epochs=3): # Read and merge datasets - df1 = pd.read_parquet(data_path1).head(2500) - df2 = pd.read_parquet(data_path2).head(2500) + df1 = pd.read_parquet(data_path1).head(10000) + df2 = pd.read_parquet(data_path2).head(10000) merged_df = pd.concat([df1, df2], ignore_index=True) # Model configuration config = AIIAConfig( - model_name="AIIA-Base-512x5k", + model_name="AIIA-Base-512x20k", ) # Initialize model and data loader @@ -49,7 +49,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): aiia_loader = AIIADataLoader( merged_df, column="image_bytes", - batch_size=32, + batch_size=8, pretraining=True, collate_fn=safe_collate )