diff --git a/src/pretrain.py b/src/pretrain.py index 6fc9922..09e9856 100644 --- a/src/pretrain.py +++ b/src/pretrain.py @@ -12,8 +12,8 @@ from aiia.data.DataLoader import AIIADataLoader def pretrain_model(data_path1, data_path2, num_epochs=3): # Merge the two parquet files - df1 = pd.read_parquet(data_path1) - df2 = pd.read_parquet(data_path2) + df1 = pd.read_parquet(data_path1).head(10000) + df2 = pd.read_parquet(data_path2).head(10000) merged_df = pd.concat([df1, df2], ignore_index=True) # Create a new AIIAConfig instance