diff --git a/src/pretrain.py b/src/pretrain.py index 38663e3..31dce37 100644 --- a/src/pretrain.py +++ b/src/pretrain.py @@ -49,7 +49,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): aiia_loader = AIIADataLoader( merged_df, column="image_bytes", - batch_size=8, + batch_size=4, pretraining=True, collate_fn=safe_collate )