added progressbar for batches
This commit is contained in:
parent
b546f4ee27
commit
6c146f2767
|
@ -4,8 +4,8 @@ import pandas as pd
|
||||||
from aiia.model.config import AIIAConfig
|
from aiia.model.config import AIIAConfig
|
||||||
from aiia.model import AIIABase
|
from aiia.model import AIIABase
|
||||||
from aiia.data.DataLoader import AIIADataLoader
|
from aiia.data.DataLoader import AIIADataLoader
|
||||||
import os
|
from tqdm import tqdm
|
||||||
import copy
|
|
||||||
|
|
||||||
def pretrain_model(data_path1, data_path2, num_epochs=3):
|
def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
# Read and merge datasets
|
# Read and merge datasets
|
||||||
|
@ -49,7 +49,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
aiia_loader = AIIADataLoader(
|
aiia_loader = AIIADataLoader(
|
||||||
merged_df,
|
merged_df,
|
||||||
column="image_bytes",
|
column="image_bytes",
|
||||||
batch_size=4,
|
batch_size=2,
|
||||||
pretraining=True,
|
pretraining=True,
|
||||||
collate_fn=safe_collate
|
collate_fn=safe_collate
|
||||||
)
|
)
|
||||||
|
@ -75,7 +75,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||||
model.train()
|
model.train()
|
||||||
total_train_loss = 0.0
|
total_train_loss = 0.0
|
||||||
|
|
||||||
for batch in train_loader:
|
for batch in tqdm(train_loader):
|
||||||
if batch is None:
|
if batch is None:
|
||||||
continue # Skip empty batches
|
continue # Skip empty batches
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue