from langchain_community.document_loaders import UnstructuredHTMLLoader, WebBaseLoader, UnstructuredMarkdownLoader, PyPDFLoader
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOllama
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_core.documents import Document  # used to wrap raw JSON text for the splitter
from pathlib import Path
import json


class Rag:
    """Minimal RAG pipeline: load a document, index it in Chroma, and query it via Ollama."""

    def __init__(self, embeddings, model,
                 base_url_llm, base_url_embed, base_header,
                 embeddings_header):
        self.retriever = None
        self.data = None
        self.model = model

        self.base_url_llm = base_url_llm
        self.base_url_embed = base_url_embed
        self.base_header = base_header
        self.embeddings_header = embeddings_header
        # Embeddings may be served by a separate Ollama instance, hence the dedicated URL and headers.
        self.embeddings = OllamaEmbeddings(model=embeddings, headers=self.embeddings_header, base_url=self.base_url_embed)

    def init_ollama(self):
        self.chat_ollama = ChatOllama(
            base_url=self.base_url_llm,
            model=self.model,
            headers=self.base_header
        )

    def get_file(self, file_path):
        # Remote URLs are fetched with WebBaseLoader; local files are routed
        # to a loader based on their extension.
        if file_path.startswith('https://'):
            loader = WebBaseLoader(file_path)
            data = loader.load()
            if data is None:
                return False
        else:
            file_type = file_path.split(".")[-1]
            try:
                match file_type:
                    case "csv":
                        loader = CSVLoader(file_path=file_path)
                        data = loader.load()
                    case "html":
                        loader = UnstructuredHTMLLoader(file_path=file_path)
                        data = loader.load()
                    case "json":
                        # json.loads() returns plain Python objects, not Documents,
                        # so re-serialise the parsed data and wrap it in a Document
                        # for the text splitter downstream.
                        raw = json.loads(Path(file_path).read_text())
                        data = [Document(page_content=json.dumps(raw, indent=2))]
                    case "md":
                        loader = UnstructuredMarkdownLoader(file_path=file_path)
                        data = loader.load()
                    case "pdf":
                        loader = PyPDFLoader(file_path=file_path)
                        data = loader.load_and_split()
                    case _:
                        # Unknown extension: fall back to treating the path as a web page.
                        loader = WebBaseLoader(file_path)
                        data = loader.load()
            except OSError:
                return False

        self.data = data
        return True

    def receive_data(self, file_path):
        # Returns an (error, message) tuple: error is False when indexing succeeded.
        if self.get_file(file_path):
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=0)
            splits = text_splitter.split_documents(self.data)
            self.retriever = Chroma.from_documents(documents=splits, embedding=self.embeddings).as_retriever()
            return (False, "Success")
        else:
            return (True, f"'{file_path}' unsupported, read documentation for more information")

    def get_request(self, prompt):
        qachain = RetrievalQA.from_chain_type(self.chat_ollama, retriever=self.retriever)
        return qachain.invoke({"query": prompt})["result"]
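

# Example usage: a minimal sketch assuming a local Ollama server. The model
# names, URLs and header values below are placeholders, not part of the class.
if __name__ == "__main__":
    rag = Rag(
        embeddings="nomic-embed-text",           # assumed embedding model name
        model="llama3",                          # assumed chat model name
        base_url_llm="http://localhost:11434",
        base_url_embed="http://localhost:11434",
        base_header={},                          # e.g. auth headers, if the servers need them
        embeddings_header={},
    )
    rag.init_ollama()                            # must be called before get_request()

    err, msg = rag.receive_data("https://example.com/article.html")
    if err:
        print(msg)
    else:
        print(rag.get_request("Summarise the document in two sentences."))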