updated pretrain method
This commit is contained in:
parent
de3d58f6db
commit
32daaadddd
|
@ -34,7 +34,7 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
batch_size=32,
|
||||
val_split=0.2,
|
||||
seed=42,
|
||||
column="file_path",
|
||||
column="image_bytes",
|
||||
label_column=None
|
||||
)
|
||||
|
||||
|
@ -71,8 +71,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
# Training phase
|
||||
model.train()
|
||||
total_train_loss = 0.0
|
||||
denoise_train_loss = 0.0
|
||||
rotate_train_loss = 0.0
|
||||
|
||||
for batch in train_dataloader:
|
||||
images, targets, tasks = zip(*batch)
|
||||
|
@ -144,6 +142,6 @@ def pretrain_model(data_path1, data_path2, num_epochs=3):
|
|||
print("Best model saved!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
data_path1 = "/root/training_data/vision-dataset/images_dataset.parquet"
|
||||
data_path1 = "/root/training_data/vision-dataset/images_checkpoint.parquet"
|
||||
data_path2 = "/root/training_data/vision-dataset/vec_images_dataset.parquet"
|
||||
pretrain_model(data_path1, data_path2, num_epochs=8)
|
Loading…
Reference in New Issue