diff --git a/src/model/train.py b/src/model/train.py index fc9b014..7f0671c 100644 --- a/src/model/train.py +++ b/src/model/train.py @@ -76,6 +76,8 @@ class FakeNewsModelTrainer: total_steps = len(train_dataloader) * epochs scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps) + self.model.to(self.device) # Ensure model is on the correct device + for epoch in range(epochs): self.model.train() total_loss = 0 @@ -128,6 +130,7 @@ class FakeNewsModelTrainer: self.model.save_pretrained(path) self.tokenizer.save_pretrained(path) + class FakeNewsInference: def __init__(self, model_path): self.tokenizer = BertTokenizer.from_pretrained(model_path)