95 lines
3.9 KiB
Python
95 lines
3.9 KiB
Python
from langchain_community.document_loaders import UnstructuredHTMLLoader, WebBaseLoader, UnstructuredMarkdownLoader, PyPDFLoader
|
|
from langchain_community.document_loaders.csv_loader import CSVLoader
|
|
from langchain_community.embeddings import OllamaEmbeddings
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
from langchain_community.embeddings import OllamaEmbeddings
|
|
from langchain_community.vectorstores import Chroma
|
|
from langchain_community.chat_models import ChatOllama
|
|
from langchain.chains import RetrievalQA
|
|
|
|
|
|
|
|
class Rag:
|
|
def __init__(self, embeddings, model,
|
|
base_url_llm,base_url_embed, base_header,
|
|
embeddings_header):
|
|
self.retriever = None
|
|
self.data = None
|
|
self.model = model
|
|
|
|
self.base_url_llm = base_url_llm
|
|
self.base_url_embed = base_url_embed
|
|
|
|
if self.is_empty(base_header):
|
|
self.base_header = ""
|
|
else:
|
|
self.base_header = base_header
|
|
if self.is_empty(embeddings_header):
|
|
self.embeddings_header = ""
|
|
else:
|
|
self.embeddings_header = embeddings_header
|
|
self.embeddings = OllamaEmbeddings(model=embeddings, headers=self.embeddings_header, base_url=self.base_url_embed)
|
|
|
|
def init_ollama(self):
|
|
self.chat_ollama = ChatOllama(
|
|
base_url=self.base_url_llm,
|
|
model=self.model,
|
|
headers = self.base_header
|
|
)
|
|
|
|
def get_file(self, file_path):
|
|
# Check if the file path starts with 'https://'
|
|
if file_path.startswith('https://'):
|
|
loader = WebBaseLoader(file_path)
|
|
data = loader.load()
|
|
if data is None:
|
|
return False
|
|
else:
|
|
file_type = file_path.split(".")[-1]
|
|
try:
|
|
match file_type:
|
|
case "csv":
|
|
loader = CSVLoader(file_path=file_path)
|
|
data = loader.load()
|
|
case "html": # Corrected the typo in the variable name
|
|
loader = UnstructuredHTMLLoader(file_path=file_path)
|
|
data = loader.load()
|
|
case "md":
|
|
loader = UnstructuredMarkdownLoader(file_path=file_path)
|
|
data = loader.load()
|
|
case "pdf":
|
|
loader = PyPDFLoader(file_path=file_path)
|
|
data = loader.load_and_split()
|
|
case _:
|
|
loader = WebBaseLoader(file_path)
|
|
data = loader.load()
|
|
except OSError:
|
|
return False
|
|
|
|
self.data = data
|
|
return True
|
|
|
|
|
|
def is_empty(self, dictionary):
|
|
return len(dictionary) == 1 and list(dictionary.keys())[0] == '' and list(dictionary.values())[0] == ''
|
|
|
|
|
|
|
|
def receive_data(self, file_path):
|
|
try:
|
|
if self.get_file(file_path):
|
|
text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=0)
|
|
splitted = text_splitter.split_documents(self.data)
|
|
self.retriever = Chroma.from_documents(documents=splitted, embedding=self.embeddings).as_retriever()
|
|
return (False, "Success")
|
|
else:
|
|
return (True, f"'{file_path}' unsupported, read documentation for more information")
|
|
except (ValueError, AttributeError):
|
|
return (True, "An unexpected Error occuried")
|
|
def get_request(self, prompt):
|
|
qachain=RetrievalQA.from_chain_type(self.chat_ollama, retriever=self.retriever)
|
|
try:
|
|
return qachain.invoke({"query": prompt})["result"]
|
|
except ValueError:
|
|
return (True, "An unexpected Error occuried")
|
|
|