added progressbar for batches

This commit is contained in:
Falko Victor Habel 2025-01-27 09:13:22 +01:00
parent b546f4ee27
commit 6c146f2767
1 changed files with 4 additions and 4 deletions

View File

@ -4,8 +4,8 @@ import pandas as pd
from aiia.model.config import AIIAConfig from aiia.model.config import AIIAConfig
from aiia.model import AIIABase from aiia.model import AIIABase
from aiia.data.DataLoader import AIIADataLoader from aiia.data.DataLoader import AIIADataLoader
import os from tqdm import tqdm
import copy
def pretrain_model(data_path1, data_path2, num_epochs=3): def pretrain_model(data_path1, data_path2, num_epochs=3):
# Read and merge datasets # Read and merge datasets
@ -49,7 +49,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
aiia_loader = AIIADataLoader( aiia_loader = AIIADataLoader(
merged_df, merged_df,
column="image_bytes", column="image_bytes",
batch_size=4, batch_size=2,
pretraining=True, pretraining=True,
collate_fn=safe_collate collate_fn=safe_collate
) )
@ -75,7 +75,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
model.train() model.train()
total_train_loss = 0.0 total_train_loss = 0.0
for batch in train_loader: for batch in tqdm(train_loader):
if batch is None: if batch is None:
continue # Skip empty batches continue # Skip empty batches