inference class added
This commit is contained in:
parent
cbfcad6088
commit
de0699d6ba
|
@ -0,0 +1,32 @@
|
||||||
|
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class FakeNewsInference:
|
||||||
|
def __init__(self, model_path):
|
||||||
|
self.tokenizer = BertTokenizer.from_pretrained(model_path)
|
||||||
|
self.model = BertForSequenceClassification.from_pretrained(model_path)
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def predict(self, text):
|
||||||
|
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
probabilities = torch.softmax(outputs.logits, dim=1)
|
||||||
|
prediction = torch.argmax(probabilities, dim=1).item()
|
||||||
|
|
||||||
|
return prediction, probabilities[0][prediction].item()
|
||||||
|
|
||||||
|
# Usage example
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Inference example
|
||||||
|
inference = FakeNewsInference('VeriMind')
|
||||||
|
sample_text = "Breaking news: Scientists discover new planet in solar system"
|
||||||
|
prediction, confidence = inference.predict(sample_text)
|
||||||
|
print(f"Prediction: {'Real' if prediction == 1 else 'Fake'}")
|
||||||
|
print(f"Confidence: {confidence:.4f}")
|
Reference in New Issue