From 6c146f2767774ed0571475e1b47e6fdbf5ad81dc Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Mon, 27 Jan 2025 09:13:22 +0100 Subject: [PATCH] added progressbar for batches --- src/pretrain.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pretrain.py b/src/pretrain.py index 31dce37..45aac3f 100644 --- a/src/pretrain.py +++ b/src/pretrain.py @@ -4,8 +4,8 @@ import pandas as pd from aiia.model.config import AIIAConfig from aiia.model import AIIABase from aiia.data.DataLoader import AIIADataLoader -import os -import copy +from tqdm import tqdm + def pretrain_model(data_path1, data_path2, num_epochs=3): # Read and merge datasets @@ -49,7 +49,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): aiia_loader = AIIADataLoader( merged_df, column="image_bytes", - batch_size=4, + batch_size=2, pretraining=True, collate_fn=safe_collate ) @@ -75,7 +75,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3): model.train() total_train_loss = 0.0 - for batch in train_loader: + for batch in tqdm(train_loader): if batch is None: continue # Skip empty batches