test_model #13

Merged
Fabel merged 21 commits from test_model into develop 2024-09-03 08:53:54 +00:00
1 changed files with 3 additions and 0 deletions
Showing only changes of commit da78ad357a - Show all commits

View File

@ -76,6 +76,8 @@ class FakeNewsModelTrainer:
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
self.model.to(self.device) # Ensure model is on the correct device
for epoch in range(epochs):
self.model.train()
total_loss = 0
@ -128,6 +130,7 @@ class FakeNewsModelTrainer:
self.model.save_pretrained(path)
self.tokenizer.save_pretrained(path)
class FakeNewsInference:
def __init__(self, model_path):
self.tokenizer = BertTokenizer.from_pretrained(model_path)