# Fabelous-Ai-Chat/scripts/Rag.py

# 96 lines
# 3.9 KiB
# Python

from langchain_community.document_loaders import UnstructuredHTMLLoader, WebBaseLoader, UnstructuredMarkdownLoader, PyPDFLoader
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOllama
from langchain.chains import RetrievalQA
import requests
class Rag:
def __init__(self, embeddings, model,
base_url_llm,base_url_embed, base_header,
embeddings_header):
self.retriever = None
self.data = None
self.model = model
self.base_url_llm = base_url_llm
self.base_url_embed = base_url_embed
# check if header exists
if self.is_empty(base_header):
self.base_header = ""
else:
self.base_header = base_header
if self.is_empty(embeddings_header):
self.embeddings_header = ""
else:
self.embeddings_header = embeddings_header
self.embeddings = OllamaEmbeddings(model=embeddings, headers=self.embeddings_header, base_url=self.base_url_embed)
def init_ollama(self):
self.chat_ollama = ChatOllama(
base_url=self.base_url_llm,
model=self.model,
headers = self.base_header
)
def get_file(self, file_path: str):
# Check if the file path starts with 'https://'
if file_path.startswith('https://'):
try:
loader = WebBaseLoader(file_path)
data = loader.load()
if data is None:
return False
except requests.exceptions.SSLError:
return False
else:
file_type = file_path.split(".")[-1]
try:
match file_type:
case "csv":
loader = CSVLoader(file_path=file_path)
data = loader.load()
case "html": # Corrected the typo in the variable name
loader = UnstructuredHTMLLoader(file_path=file_path)
data = loader.load()
case "md":
loader = UnstructuredMarkdownLoader(file_path=file_path)
data = loader.load()
case "pdf":
loader = PyPDFLoader(file_path=file_path)
data = loader.load_and_split()
case _:
return False
except OSError:
return False
self.data = data
return True
def is_empty(self, dictionary: dict) -> bool:
return len(dictionary) == 1 and list(dictionary.keys())[0] == '' and list(dictionary.values())[0] == ''
def receive_data(self, file_path: str):
try:
if self.get_file(file_path):
text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=0)
splitted = text_splitter.split_documents(self.data)
self.retriever = Chroma.from_documents(documents=splitted, embedding=self.embeddings).as_retriever()
return (False, "Success")
else:
return (True, f"'{file_path}' unsupported, read documentation for more information")
except (ValueError, AttributeError) as e:
return (True, f"An unexpected Error occuried: {e}")
def get_request(self, prompt: str) -> str:
qachain=RetrievalQA.from_chain_type(self.chat_ollama, retriever=self.retriever)
try:
return qachain.invoke({"query": prompt})["result"]
except ValueError as e:
return f"An unexpected Error occuried: {e}"