diff --git a/src/pretrain.py b/src/pretrain.py index 31dce37..45aac3f 100644 --- a/src/pretrain.py +++ b/src/pretrain.py @@ -4,8 +4,8 @@ import pandas as pd from aiia.model.config import AIIAConfig from aiia.model import AIIABase from aiia.data.DataLoader import AIIADataLoader -import os -import copy +from tqdm import tqdm + def pretrain_model(data_path1, data_path2, num_epochs=3): # Read and merge datasets @@ -49,7 +49,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): aiia_loader = AIIADataLoader( merged_df, column="image_bytes", - batch_size=4, + batch_size=2, pretraining=True, collate_fn=safe_collate ) @@ -75,7 +75,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): model.train() total_train_loss = 0.0 - for batch in train_loader: + for batch in tqdm(train_loader): if batch is None: continue # Skip empty batches