Compare commits
No commits in common. "4af1d651da54cb4aaf89abb4b49d01f47d09fd52" and "5d8ec1a01feacf422b7dcb551741396b73af3a4f" have entirely different histories.
4af1d651da
...
5d8ec1a01f
|
@ -67,7 +67,7 @@ class FakeNewsModelTrainer:
|
||||||
|
|
||||||
return TensorDataset(input_ids, attention_mask, labels), sampler
|
return TensorDataset(input_ids, attention_mask, labels), sampler
|
||||||
|
|
||||||
def train(self, train_data, val_data, epochs=13, batch_size=16):
|
def train(self, train_data, val_data, epochs=5, batch_size=32):
|
||||||
train_dataset, train_sampler = train_data
|
train_dataset, train_sampler = train_data
|
||||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)
|
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size)
|
||||||
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
|
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
|
||||||
|
@ -76,8 +76,6 @@ class FakeNewsModelTrainer:
|
||||||
total_steps = len(train_dataloader) * epochs
|
total_steps = len(train_dataloader) * epochs
|
||||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
|
||||||
|
|
||||||
self.model.to(self.device) # Ensure model is on the correct device
|
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
@ -130,7 +128,6 @@ class FakeNewsModelTrainer:
|
||||||
self.model.save_pretrained(path)
|
self.model.save_pretrained(path)
|
||||||
self.tokenizer.save_pretrained(path)
|
self.tokenizer.save_pretrained(path)
|
||||||
|
|
||||||
|
|
||||||
class FakeNewsInference:
|
class FakeNewsInference:
|
||||||
def __init__(self, model_path):
|
def __init__(self, model_path):
|
||||||
self.tokenizer = BertTokenizer.from_pretrained(model_path)
|
self.tokenizer = BertTokenizer.from_pretrained(model_path)
|
||||||
|
|
Reference in New Issue