Compare commits

...

2 Commits

Author SHA1 Message Date
Falko Victor Habel 4af1d651da icnreased epoch 2024-09-01 16:08:50 +02:00
Falko Victor Habel da78ad357a removed model error when it comes to storage 2024-09-01 16:08:33 +02:00
1 changed files with 4 additions and 1 deletions

View File

@ -67,7 +67,7 @@ class FakeNewsModelTrainer:
return TensorDataset(input_ids, attention_mask, labels), sampler
def train(self, train_data, val_data, epochs=5, batch_size=32):
def train(self, train_data, val_data, epochs=13, batch_size=16):
train_dataset, train_sampler = train_data
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
@ -76,6 +76,8 @@ class FakeNewsModelTrainer:
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
self.model.to(self.device) # Ensure model is on the correct device
for epoch in range(epochs):
self.model.train()
total_loss = 0
@ -128,6 +130,7 @@ class FakeNewsModelTrainer:
self.model.save_pretrained(path)
self.tokenizer.save_pretrained(path)
class FakeNewsInference:
def __init__(self, model_path):
self.tokenizer = BertTokenizer.from_pretrained(model_path)