diff --git a/src/model/train.py b/src/model/train.py index 7f0671c..3b39b89 100644 --- a/src/model/train.py +++ b/src/model/train.py @@ -67,7 +67,7 @@ class FakeNewsModelTrainer: return TensorDataset(input_ids, attention_mask, labels), sampler - def train(self, train_data, val_data, epochs=5, batch_size=32): + def train(self, train_data, val_data, epochs=13, batch_size=16): train_dataset, train_sampler = train_data train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size) val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)