added docs
parent 24b6a2a8a7
commit 395f2775c8

@@ -0,0 +1,173 @@
import pandas as pd
import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated and removed in recent versions
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

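# Optional sketch (an addition, not in the original file): seeding torch for more
# reproducible training runs, mirroring the fixed random_state=42 used in
# train_test_split below.
#
#   torch.manual_seed(42)
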
class NewsDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = str(self.texts[item])
        label = self.labels[item]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

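# Illustrative sketch (not part of the original file): the shape of a single
# dataset item, using the same checkpoint the rest of this file loads.
#
#   tok = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
#   ds = NewsDataset(['example headline'], [0], tok)
#   item = ds[0]
#   item['input_ids'].shape       # torch.Size([512]), padded/truncated to max_len
#   item['attention_mask'].shape  # torch.Size([512])
#   item['labels']                # tensor(0)
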
class FakeNewsTrainer:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.to(self.device)

    def train(self, train_texts, train_labels, val_texts, val_labels,
              batch_size=16, num_epochs=5, learning_rate=2e-5):
        train_dataset = NewsDataset(train_texts, train_labels, self.tokenizer)
        val_dataset = NewsDataset(val_texts, val_labels, self.tokenizer)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

        optimizer = AdamW(self.model.parameters(), lr=learning_rate)

        for epoch in range(num_epochs):
            print(f'Epoch {epoch + 1}/{num_epochs}')
            self._train_epoch(train_loader, optimizer)
            accuracy, report = self._evaluate(val_loader)
            print(f'Validation Accuracy: {accuracy}')
            print(f'Classification Report:\n{report}')

    def _train_epoch(self, data_loader, optimizer):
        self.model.train()
        for batch in data_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)
            outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

    def _evaluate(self, data_loader):
        self.model.eval()
        predictions = []
        actual_labels = []
        with torch.no_grad():
            for batch in data_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)
                outputs = self.model(input_ids, attention_mask=attention_mask)
                _, preds = torch.max(outputs.logits, dim=1)
                predictions.extend(preds.cpu().tolist())
                actual_labels.extend(labels.cpu().tolist())
        return (accuracy_score(actual_labels, predictions),
                classification_report(actual_labels, predictions))

    def save_model(self, path):
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

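# Illustrative sketch (not part of the original file): driving FakeNewsTrainer
# directly, without the FakeNewsModel wrapper defined below. The texts and
# labels here are placeholders.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   tok = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
#   mdl = BertForSequenceClassification.from_pretrained(
#       'bert-base-multilingual-cased', num_labels=2)
#   trainer = FakeNewsTrainer(mdl, tok, device)
#   trainer.train(['real story'], [1], ['fake story'], [0],
#                 batch_size=1, num_epochs=1)
#   trainer.save_model('checkpoint-dir')
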
class FakeNewsInference:
    def __init__(self, model_path, device):
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model.to(device)
        self.model.eval()
        self.device = device

    def predict(self, text):
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=512,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask)
            _, preds = torch.max(outputs.logits, dim=1)
        # Label convention: 1 = Real, 0 = Fake.
        return 'Real' if preds.item() == 1 else 'Fake'

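# Illustrative extension (an assumption, not in the original file): if class
# probabilities are wanted instead of a hard label, a softmax over the logits
# inside predict() would provide them:
#
#   probs = torch.softmax(outputs.logits, dim=1)  # shape [1, 2]
#   confidence = probs.max().item()
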
class FakeNewsModel:
    def __init__(self, model_path=None):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if model_path:
            self.inference = FakeNewsInference(model_path, self.device)
            self.tokenizer = self.inference.tokenizer
        else:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
            self.inference = None

    def train(self, csv_path, model_save_path, test_size=0.2, **kwargs):
        # Expects a CSV with 'Title', 'Text', and 'Label' columns; fillna('')
        # keeps rows with a missing title or body from collapsing to the string 'nan'.
        df = pd.read_csv(csv_path)
        df['combined'] = df['Title'].fillna('') + ' ' + df['Text'].fillna('')

        train_texts, val_texts, train_labels, val_labels = train_test_split(
            df['combined'].tolist(), df['Label'].tolist(), test_size=test_size, random_state=42
        )

        model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=2)
        trainer = FakeNewsTrainer(model, self.tokenizer, self.device)
        trainer.train(train_texts, train_labels, val_texts, val_labels, **kwargs)
        trainer.save_model(model_save_path)

        self.inference = FakeNewsInference(model_save_path, self.device)

    def predict(self, text):
        if self.inference is None:
            raise ValueError("Model not trained or loaded. Call train() or load a pre-trained model.")
        return self.inference.predict(text)

# Example usage
if __name__ == "__main__":
    # Initialize the model
    fake_news_model = FakeNewsModel()

    # Train the model
    fake_news_model.train(
        csv_path='/root/schule/WELFake_Dataset.csv',
        model_save_path='VeriMind',
        batch_size=32,
        num_epochs=13,
        learning_rate=2e-5
    )

    # Make a prediction
    sample_text = "Your sample news article text here"
    prediction = fake_news_model.predict(sample_text)
    print(f"The article is predicted to be: {prediction}")

    # Load a pre-trained model
    pretrained_model = FakeNewsModel('VeriMind')
    prediction = pretrained_model.predict(sample_text)
    print(f"Prediction from pre-trained model: {prediction}")