updated pretrain method

This commit is contained in:
Falko Victor Habel 2025-01-26 13:05:24 +01:00
parent de3d58f6db
commit 32daaadddd
1 changed files with 2 additions and 4 deletions

View File

@ -34,7 +34,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
batch_size=32,
val_split=0.2,
seed=42,
column="file_path",
column="image_bytes",
label_column=None
)
@ -71,8 +71,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
# Training phase
model.train()
total_train_loss = 0.0
denoise_train_loss = 0.0
rotate_train_loss = 0.0
for batch in train_dataloader:
images, targets, tasks = zip(*batch)
@ -144,6 +142,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
print("Best model saved!")
if __name__ == "__main__":
data_path1 = "/root/training_data/vision-dataset/images_dataset.parquet"
data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
pretrain_model(data_path1, data_path2, num_epochs=8)