inference class added

This commit is contained in:
Falko Victor Habel 2024-08-31 08:06:28 +02:00
parent cbfcad6088
commit de0699d6ba
1 changed files with 32 additions and 0 deletions

32
src/model/Inference.py Normal file
View File

@ -0,0 +1,32 @@
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
import pyarrow.parquet as pq
import torch
class FakeNewsInference:
def __init__(self, model_path):
self.tokenizer = BertTokenizer.from_pretrained(model_path)
self.model = BertForSequenceClassification.from_pretrained(model_path)
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
self.model.eval()
def predict(self, text):
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
probabilities = torch.softmax(outputs.logits, dim=1)
prediction = torch.argmax(probabilities, dim=1).item()
return prediction, probabilities[0][prediction].item()
# Usage example
if __name__ == '__main__':
# Inference example
inference = FakeNewsInference('VeriMind')
sample_text = "Breaking news: Scientists discover new planet in solar system"
prediction, confidence = inference.predict(sample_text)
print(f"Prediction: {'Real' if prediction == 1 else 'Fake'}")
print(f"Confidence: {confidence:.4f}")