inference class added
This commit is contained in:
parent
cbfcad6088
commit
de0699d6ba
|
@ -0,0 +1,32 @@
|
|||
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
|
||||
|
||||
class FakeNewsInference:
|
||||
def __init__(self, model_path):
|
||||
self.tokenizer = BertTokenizer.from_pretrained(model_path)
|
||||
self.model = BertForSequenceClassification.from_pretrained(model_path)
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
def predict(self, text):
|
||||
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
probabilities = torch.softmax(outputs.logits, dim=1)
|
||||
prediction = torch.argmax(probabilities, dim=1).item()
|
||||
|
||||
return prediction, probabilities[0][prediction].item()
|
||||
|
||||
# Usage example
|
||||
if __name__ == '__main__':
|
||||
# Inference example
|
||||
inference = FakeNewsInference('VeriMind')
|
||||
sample_text = "Breaking news: Scientists discover new planet in solar system"
|
||||
prediction, confidence = inference.predict(sample_text)
|
||||
print(f"Prediction: {'Real' if prediction == 1 else 'Fake'}")
|
||||
print(f"Confidence: {confidence:.4f}")
|
Reference in New Issue