from langchain_community.document_loaders import UnstructuredHTMLLoader, WebBaseLoader, UnstructuredMarkdownLoader, PyPDFLoader from langchain_community.document_loaders.csv_loader import CSVLoader from langchain_community.embeddings import OllamaEmbeddings from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_community.vectorstores import Chroma from langchain_community.chat_models import ChatOllama from langchain.chains import RetrievalQA import requests class Rag: def __init__(self, embeddings, model, base_url_llm,base_url_embed, base_header, embeddings_header): self.retriever = None self.data = None self.model = model self.base_url_llm = base_url_llm self.base_url_embed = base_url_embed # check if header exists if self.is_empty(base_header): self.base_header = "" else: self.base_header = base_header if self.is_empty(embeddings_header): self.embeddings_header = "" else: self.embeddings_header = embeddings_header self.embeddings = OllamaEmbeddings(model=embeddings, headers=self.embeddings_header, base_url=self.base_url_embed) def init_ollama(self): self.chat_ollama = ChatOllama( base_url=self.base_url_llm, model=self.model, headers = self.base_header ) def get_file(self, file_path: str): # Check if the file path starts with 'https://' if file_path.startswith('https://'): try: loader = WebBaseLoader(file_path) data = loader.load() if data is None: return False except requests.exceptions.SSLError: return False else: file_type = file_path.split(".")[-1] try: match file_type: case "csv": loader = CSVLoader(file_path=file_path) data = loader.load() case "html": # Corrected the typo in the variable name loader = UnstructuredHTMLLoader(file_path=file_path) data = loader.load() case "md": loader = UnstructuredMarkdownLoader(file_path=file_path) data = loader.load() case "pdf": loader = PyPDFLoader(file_path=file_path) data = loader.load_and_split() case _: return False except OSError: return False self.data = data return True def is_empty(self, dictionary: dict) -> bool: return len(dictionary) == 1 and list(dictionary.keys())[0] == '' and list(dictionary.values())[0] == '' def receive_data(self, file_path: str): try: if self.get_file(file_path): text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=0) splitted = text_splitter.split_documents(self.data) self.retriever = Chroma.from_documents(documents=splitted, embedding=self.embeddings).as_retriever() return (False, "Success") else: return (True, f"'{file_path}' unsupported, read documentation for more information") except (ValueError, AttributeError) as e: return (True, f"An unexpected Error occuried: {e}") def get_request(self, prompt: str) -> str: qachain=RetrievalQA.from_chain_type(self.chat_ollama, retriever=self.retriever) try: return qachain.invoke({"query": prompt})["result"] except ValueError as e: return f"An unexpected Error occuried: {e}"