downszied batchsize
This commit is contained in:
parent
32526c3c30
commit
8dad1d7150
|
@ -9,13 +9,13 @@ import copy
|
|||
|
||||
def pretrain_model(data_path1, data_path2, num_epochs=3):
|
||||
# Read and merge datasets
|
||||
df1 = pd.read_parquet(data_path1).head(2500)
|
||||
df2 = pd.read_parquet(data_path2).head(2500)
|
||||
df1 = pd.read_parquet(data_path1).head(10000)
|
||||
df2 = pd.read_parquet(data_path2).head(10000)
|
||||
merged_df = pd.concat([df1, df2], ignore_index=True)
|
||||
|
||||
# Model configuration
|
||||
config = AIIAConfig(
|
||||
model_name="AIIA-Base-512x5k",
|
||||
model_name="AIIA-Base-512x20k",
|
||||
)
|
||||
|
||||
# Initialize model and data loader
|
||||
|
@ -49,7 +49,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
aiia_loader = AIIADataLoader(
|
||||
merged_df,
|
||||
column="image_bytes",
|
||||
batch_size=32,
|
||||
batch_size=8,
|
||||
pretraining=True,
|
||||
collate_fn=safe_collate
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue