diff --git a/scripts/BaseOllama.py b/scripts/BaseOllama.py
new file mode 100644
index 0000000..4402669
--- /dev/null
+++ b/scripts/BaseOllama.py
@@ -0,0 +1,27 @@
+from langchain_community.chat_models import ChatOllama
+
+
+class OllamaChatBot:
+    def __init__(self, base_url, model, headers):
+        self.base_url = base_url
+        self.model = model
+        self.headers = headers
+        self.messages = []
+
+        if headers is None:
+            self.ollama = ChatOllama(
+                base_url=self.base_url,
+                model=self.model,
+            )
+        else:
+            self.ollama = ChatOllama(
+                base_url=self.base_url,
+                model=self.model,
+                headers=self.headers,
+            )
+
+    def get_request(self, prompt):
+        self.messages.append(prompt)
+        # Send only the five most recent messages as context.
+        messages = self.messages[-5:]
+        return self.ollama.invoke(messages).content
diff --git a/scripts/Rag.py b/scripts/Rag.py
new file mode 100644
index 0000000..9d01c4d
--- /dev/null
+++ b/scripts/Rag.py
@@ -0,0 +1,89 @@
+from langchain_community.document_loaders import (
+    PyPDFLoader,
+    UnstructuredHTMLLoader,
+    UnstructuredMarkdownLoader,
+    WebBaseLoader,
+)
+from langchain_community.document_loaders.csv_loader import CSVLoader
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import Chroma
+from langchain_community.chat_models import ChatOllama
+from langchain_core.documents import Document
+from langchain.chains import RetrievalQA
+from pathlib import Path
+
+
+class Rag:
+    def __init__(self, embeddings, model,
+                 base_url_llm, base_url_embed,
+                 base_header, embeddings_header):
+        self.retriever = None
+        self.data = None
+        self.model = model
+
+        self.base_url_llm = base_url_llm
+        self.base_url_embed = base_url_embed
+        self.base_header = base_header
+        self.embeddings_header = embeddings_header
+        self.embeddings = OllamaEmbeddings(
+            model=embeddings,
+            headers=self.embeddings_header,
+            base_url=self.base_url_embed,
+        )
+
+    def init_ollama(self):
+        self.chat_ollama = ChatOllama(
+            base_url=self.base_url_llm,
+            model=self.model,
+            headers=self.base_header,
+        )
+
+    def get_file(self, file_path):
+        # URLs go through WebBaseLoader; local files are matched by extension.
+        if file_path.startswith(("http://", "https://")):
+            loader = WebBaseLoader(file_path)
+            data = loader.load()
+            if data is None:
+                return False
+        else:
+            file_type = file_path.split(".")[-1]
+            try:
+                match file_type:
+                    case "csv":
+                        loader = CSVLoader(file_path=file_path)
+                        data = loader.load()
+                    case "html":
+                        loader = UnstructuredHTMLLoader(file_path=file_path)
+                        data = loader.load()
+                    case "json":
+                        # json.loads returns plain Python objects, which the
+                        # splitter cannot handle; wrap the raw text instead.
+                        data = [Document(page_content=Path(file_path).read_text())]
+                    case "md":
+                        loader = UnstructuredMarkdownLoader(file_path=file_path)
+                        data = loader.load()
+                    case "pdf":
+                        loader = PyPDFLoader(file_path=file_path)
+                        data = loader.load_and_split()
+                    case _:
+                        loader = WebBaseLoader(file_path)
+                        data = loader.load()
+            except OSError:
+                return False
+
+        self.data = data
+        return True
+
+    def receive_data(self, file_path):
+        if self.get_file(file_path):
+            text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=0)
+            splits = text_splitter.split_documents(self.data)
+            self.retriever = Chroma.from_documents(documents=splits, embedding=self.embeddings).as_retriever()
+            return (False, "Success")
+        else:
+            return (True, f"'{file_path}' unsupported, read documentation for more information")
+
+    def get_request(self, prompt):
+        qachain = RetrievalQA.from_chain_type(self.chat_ollama, retriever=self.retriever)
+        return qachain.invoke({"query": prompt})["result"]
\ No newline at end of file
diff --git a/scripts/TerminalBot.py b/scripts/TerminalChat.py
similarity index 100%
rename from scripts/TerminalBot.py
rename to scripts/TerminalChat.py
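
Usage note: the sketch below shows how OllamaChatBot is meant to be driven. The endpoint and model name are placeholder assumptions, not values from this PR; adjust them to whatever your Ollama server runs. Only the five most recent prompts are forwarded to the model.

# Minimal sketch; assumes the repo root is on sys.path and a local
# Ollama server with a pulled "llama2" model (both placeholders).
from scripts.BaseOllama import OllamaChatBot

bot = OllamaChatBot(
    base_url="http://localhost:11434",  # default Ollama endpoint (assumption)
    model="llama2",
    headers=None,  # no auth headers needed for a local server
)

print(bot.get_request("What is retrieval-augmented generation?"))
# Follow-ups reuse the capped five-message history.
print(bot.get_request("Summarize that in one sentence."))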
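
And a similar sketch for Rag, under the same caveats: the model names, URLs, and document path are placeholders. Note that init_ollama() must run before get_request(), since that is where self.chat_ollama is created; receive_data() returns an (error, message) pair.

# Minimal sketch; all endpoint, model, and file values are placeholders.
from scripts.Rag import Rag

rag = Rag(
    embeddings="nomic-embed-text",  # placeholder embedding model
    model="llama2",                 # placeholder chat model
    base_url_llm="http://localhost:11434",
    base_url_embed="http://localhost:11434",
    base_header=None,
    embeddings_header=None,
)
rag.init_ollama()  # creates self.chat_ollama used by get_request()

err, msg = rag.receive_data("docs/manual.pdf")  # hypothetical local file
if err:
    print(msg)
else:
    print(rag.get_request("What does the manual say about installation?"))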