From de0699d6ba92e111519823d0faa1d21827a5a35d Mon Sep 17 00:00:00 2001 From: Falko Habel Date: Sat, 31 Aug 2024 08:06:28 +0200 Subject: [PATCH] inference class added --- src/model/Inference.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 src/model/Inference.py diff --git a/src/model/Inference.py b/src/model/Inference.py new file mode 100644 index 0000000..b70dca9 --- /dev/null +++ b/src/model/Inference.py @@ -0,0 +1,32 @@ +from transformers import BertTokenizer, BertForSequenceClassification, AdamW +import pyarrow.parquet as pq +import torch + + +class FakeNewsInference: + def __init__(self, model_path): + self.tokenizer = BertTokenizer.from_pretrained(model_path) + self.model = BertForSequenceClassification.from_pretrained(model_path) + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.model.to(self.device) + self.model.eval() + + def predict(self, text): + inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self.model(**inputs) + probabilities = torch.softmax(outputs.logits, dim=1) + prediction = torch.argmax(probabilities, dim=1).item() + + return prediction, probabilities[0][prediction].item() + +# Usage example +if __name__ == '__main__': + # Inference example + inference = FakeNewsInference('VeriMind') + sample_text = "Breaking news: Scientists discover new planet in solar system" + prediction, confidence = inference.predict(sample_text) + print(f"Prediction: {'Real' if prediction == 1 else 'Fake'}") + print(f"Confidence: {confidence:.4f}") \ No newline at end of file