removed json loade, because it was buggy
This commit is contained in:
parent
4dc355b2cb
commit
8d24965c67
|
@ -4,7 +4,10 @@ class OllamaChatBot:
|
|||
def __init__(self, base_url, model, headers):
|
||||
self.base_url = base_url
|
||||
self.model = model
|
||||
self.headers = headers
|
||||
if self.is_empty(headers):
|
||||
self.headers = ""
|
||||
else:
|
||||
self.headers = headers
|
||||
self.messanges = []
|
||||
|
||||
if headers is None:
|
||||
|
@ -18,6 +21,10 @@ class OllamaChatBot:
|
|||
model=self.model,
|
||||
headers = self.headers
|
||||
)
|
||||
|
||||
def is_empty(self, dictionary):
|
||||
return len(dictionary) == 1 and list(dictionary.keys())[0] == '' and list(dictionary.values())[0] == ''
|
||||
|
||||
|
||||
|
||||
def get_request(self, prompt):
|
||||
|
|
|
@ -25,19 +25,22 @@ class ChatGUI(CTk.CTk):
|
|||
self.start_message_processing_thread()
|
||||
|
||||
def get_response_from_ollama(self, prompt, context):
|
||||
if context != "":
|
||||
if self.context != context:
|
||||
checks = self.rag.receive_data(file_path=context)
|
||||
if checks[0]:
|
||||
return checks[1]
|
||||
else:
|
||||
self.context = context
|
||||
self.rag.init_ollama()
|
||||
|
||||
return self.rag.get_request(prompt=prompt)
|
||||
try:
|
||||
if context != "":
|
||||
if self.context != context:
|
||||
checks = self.rag.receive_data(file_path=context)
|
||||
if checks[0]:
|
||||
return checks[1]
|
||||
else:
|
||||
self.context = context
|
||||
self.rag.init_ollama()
|
||||
|
||||
return self.rag.get_request(prompt=prompt)
|
||||
|
||||
else:
|
||||
return self.bot.get_request(prompt=prompt)
|
||||
else:
|
||||
return self.bot.get_request(prompt=prompt)
|
||||
except ValueError:
|
||||
return "An unexpected Error occuried"
|
||||
|
||||
def on_send(self, event=None):
|
||||
message = self.entry_bar.get().strip()
|
||||
|
@ -65,7 +68,8 @@ class ChatGUI(CTk.CTk):
|
|||
|
||||
def select_file(self):
|
||||
file_path = filedialog.askopenfilename()
|
||||
self.file_entry.insert(1, file_path)
|
||||
self.file_entry.delete(0, "end")
|
||||
self.file_entry.insert(0, file_path)
|
||||
|
||||
def create_widgets(self):
|
||||
self.geometry("900x600")
|
||||
|
@ -109,6 +113,8 @@ class ChatGUI(CTk.CTk):
|
|||
for message in self.history:
|
||||
message.pack_forget()
|
||||
self.history = []
|
||||
self.bot.messanges = []
|
||||
self.rag.init_ollama()
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -6,8 +6,6 @@ from langchain_community.embeddings import OllamaEmbeddings
|
|||
from langchain_community.vectorstores import Chroma
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
from langchain.chains import RetrievalQA
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
|
||||
|
||||
|
@ -21,8 +19,15 @@ class Rag:
|
|||
|
||||
self.base_url_llm = base_url_llm
|
||||
self.base_url_embed = base_url_embed
|
||||
self.base_header = base_header
|
||||
self.embeddings_header = embeddings_header
|
||||
|
||||
if self.is_empty(base_header):
|
||||
self.base_header = ""
|
||||
else:
|
||||
self.base_header = base_header
|
||||
if self.is_empty(embeddings_header):
|
||||
self.embeddings_header = ""
|
||||
else:
|
||||
self.embeddings_header = embeddings_header
|
||||
self.embeddings = OllamaEmbeddings(model=embeddings, headers=self.embeddings_header, base_url=self.base_url_embed)
|
||||
|
||||
def init_ollama(self):
|
||||
|
@ -49,8 +54,6 @@ class Rag:
|
|||
case "html": # Corrected the typo in the variable name
|
||||
loader = UnstructuredHTMLLoader(file_path=file_path)
|
||||
data = loader.load()
|
||||
case "json":
|
||||
data = json.loads(Path(file_path).read_text())
|
||||
case "md":
|
||||
loader = UnstructuredMarkdownLoader(file_path=file_path)
|
||||
data = loader.load()
|
||||
|
@ -67,17 +70,26 @@ class Rag:
|
|||
return True
|
||||
|
||||
|
||||
def is_empty(self, dictionary):
|
||||
return len(dictionary) == 1 and list(dictionary.keys())[0] == '' and list(dictionary.values())[0] == ''
|
||||
|
||||
|
||||
|
||||
def receive_data(self, file_path):
|
||||
if self.get_file(file_path):
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=0)
|
||||
splitted = text_splitter.split_documents(self.data)
|
||||
self.retriever = Chroma.from_documents(documents=splitted, embedding=self.embeddings).as_retriever()
|
||||
return (False, "Success")
|
||||
else:
|
||||
return (True, f"'{file_path}' unsupported, read documentation for more information")
|
||||
|
||||
try:
|
||||
if self.get_file(file_path):
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=0)
|
||||
splitted = text_splitter.split_documents(self.data)
|
||||
self.retriever = Chroma.from_documents(documents=splitted, embedding=self.embeddings).as_retriever()
|
||||
return (False, "Success")
|
||||
else:
|
||||
return (True, f"'{file_path}' unsupported, read documentation for more information")
|
||||
except (ValueError, AttributeError):
|
||||
return (True, "An unexpected Error occuried")
|
||||
def get_request(self, prompt):
|
||||
qachain=RetrievalQA.from_chain_type(self.chat_ollama, retriever=self.retriever)
|
||||
return qachain.invoke({"query": prompt})["result"]
|
||||
try:
|
||||
return qachain.invoke({"query": prompt})["result"]
|
||||
except ValueError:
|
||||
return (True, "An unexpected Error occuried")
|
||||
|
Loading…
Reference in New Issue