This commit is contained in:
Björn Ruthotto 2024-10-07 12:35:30 +02:00
parent 57cc69ea43
commit 08b990e658
5 changed files with 69 additions and 4 deletions

8
.gitignore vendored
View File

@ -153,6 +153,14 @@ dmypy.json
# Cython debug symbols
cython_debug/
#ML
VeraMind-Mini/
# OS generated files #
######################
.DS_Store
.DS_Store?
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore

0
src/Ai/.gitignore vendored
View File

38
src/Ai/interence.py Normal file
View File

@ -0,0 +1,38 @@
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class VeraMindInference:
def __init__(self, model_path, max_len=512):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
self.model.to(self.device)
self.model.eval()
self.max_len = max_len
def predict(self, text):
encoding = self.tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=self.max_len,
return_token_type_ids=False,
padding='max_length',
truncation=True,
return_attention_mask=True,
return_tensors='pt',
)
input_ids = encoding['input_ids'].to(self.device)
attention_mask = encoding['attention_mask'].to(self.device)
with torch.no_grad():
outputs = self.model(input_ids, attention_mask=attention_mask).logits
prediction = torch.sigmoid(outputs).cpu().numpy()[0][0]
is_fake = prediction >= 0.5
confidence = prediction if is_fake else 1 - prediction
return {
"result": is_fake,
"confidence": float(confidence)
}

View File

@ -1,5 +1,6 @@
from views.mainScreen import MainFrame
from models.data import TextData
from Ai.interence import VeraMindInference
class MainFrameController:
def __init__(self,frame:MainFrame) -> None:
@ -17,5 +18,17 @@ class MainFrameController:
def press_check_button(self):
text_data = self.get_textdata()
print(f"text:{text_data.text}")
self.prediction(text_data)
self.frame.output_textbox.configure(state="normal")
self.frame.output_textbox.delete("0.0", "end")
self.frame.output_textbox.insert("0.0",f"{text_data.get_output()}")
self.frame.output_textbox.configure(state="disabled")
def prediction(self, text_data:TextData) -> TextData:
inference = VeraMindInference('VeraMind-Mini')
result = inference.predict(text_data.text)
text_data.confidence = result["confidence"]
text_data.isfake_news = result["result"]
print(f"Prediction: {'Real' if text_data.isfake_news == 1 else 'Fake'}")
print(f"Confidence: {text_data.confidence}")
return text_data

View File

@ -4,6 +4,8 @@ class TextData:
def __init__(self, url: str = "") -> None:
self.url = url
self.text = ""
self.isfake_news = False
self.confidence = None
self._extractor = None
def set_url(self, url: str) -> None:
@ -15,7 +17,7 @@ class TextData:
if not self.url:
print("No url")
return True
if not self.text:
print("Extrahiere Text von URL...")
self._extractor = WebTextExtractor(self.url)
@ -23,6 +25,10 @@ class TextData:
self._extractor.extract_text()
self.text = self._extractor.get_text()
return False
def get_output(self):
if self.confidence != None:
output = f"Prediction: {'Real' if self.isfake_news else 'Fake'}" + f" Confidence: {self.confidence:.4f}"
print(output)
return output