Added Ollama chat instances, with and without RAG
This commit is contained in:
parent
fe347ab150
commit
bd452e51f0
|
@ -0,0 +1,30 @@
|
||||||
|
from langchain_community.chat_models import ChatOllama
|
||||||
|
|
||||||
|
class OllamaChatBot:
    """Thin wrapper around ChatOllama that keeps a rolling message history.

    Only the most recent HISTORY_LIMIT messages (including the newest prompt)
    are sent to the model on each request.
    """

    # Sliding-window size for the conversation history sent to the model.
    HISTORY_LIMIT = 5

    def __init__(self, base_url, model, headers):
        """Create the chat client.

        Args:
            base_url: Base URL of the Ollama server.
            model: Name of the model to run.
            headers: Optional HTTP headers dict; omitted from the client
                when None.
        """
        self.base_url = base_url
        self.model = model
        self.headers = headers
        # NOTE: attribute name kept as "messanges" (sic) for backward
        # compatibility with any external code reading this history list.
        self.messanges = []

        # Build the client once instead of duplicating the constructor call;
        # only pass headers when the caller supplied them.
        kwargs = {"base_url": self.base_url, "model": self.model}
        if self.headers is not None:
            kwargs["headers"] = self.headers
        self.ollama = ChatOllama(**kwargs)

    def get_request(self, prompt):
        """Append *prompt* to the history and return the model's reply text.

        Bug fix: the original sliced a freshly-created empty local list
        (sending no messages once the history exceeded 5) and, even if it had
        sliced the right list, ``[:5]`` would have kept the *oldest* five
        entries. We now send the HISTORY_LIMIT most recent messages.
        """
        self.messanges.append(prompt)
        window = self.messanges[-self.HISTORY_LIMIT:]
        return self.ollama.invoke(window).content
|
|
@ -0,0 +1,83 @@
|
||||||
|
from langchain_community.document_loaders import UnstructuredHTMLLoader, WebBaseLoader, UnstructuredMarkdownLoader, PyPDFLoader
|
||||||
|
from langchain_community.document_loaders.csv_loader import CSVLoader
|
||||||
|
from langchain_community.embeddings import OllamaEmbeddings
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
from langchain_community.embeddings import OllamaEmbeddings
|
||||||
|
from langchain_community.vectorstores import Chroma
|
||||||
|
from langchain_community.chat_models import ChatOllama
|
||||||
|
from langchain.chains import RetrievalQA
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Rag:
    """Retrieval-augmented QA over a single loaded document or URL.

    Typical flow: construct, call ``init_ollama()``, ingest a source with
    ``receive_data(path_or_url)``, then query with ``get_request(prompt)``.
    """

    def __init__(self, embeddings, model,
                 base_url_llm, base_url_embed, base_header,
                 embeddings_header):
        """Store connection settings and build the embedding client.

        Args:
            embeddings: Embedding model name passed to OllamaEmbeddings.
            model: Chat model name passed to ChatOllama.
            base_url_llm: Ollama server URL for the chat model.
            base_url_embed: Ollama server URL for the embedding model.
            base_header: HTTP headers for the chat client.
            embeddings_header: HTTP headers for the embedding client.
        """
        # Filled in later by receive_data(); queries before that will fail.
        self.retriever = None
        self.data = None

        self.model = model
        self.base_url_llm = base_url_llm
        self.base_url_embed = base_url_embed
        self.base_header = base_header
        self.embeddings_header = embeddings_header

        self.embeddings = OllamaEmbeddings(
            model=embeddings,
            headers=self.embeddings_header,
            base_url=self.base_url_embed,
        )

    def init_ollama(self):
        """Create the ChatOllama client used by get_request()."""
        self.chat_ollama = ChatOllama(
            base_url=self.base_url_llm,
            model=self.model,
            headers=self.base_header,
        )

    def get_file(self, file_path):
        """Load *file_path* (local file or https URL) into ``self.data``.

        Returns True when something was loaded, False when loading failed.
        """
        if file_path.startswith('https://'):
            data = WebBaseLoader(file_path).load()
            if data is None:
                return False
        else:
            # Dispatch on the file extension: everything after the last dot.
            suffix = file_path.rsplit(".", 1)[-1]
            try:
                if suffix == "csv":
                    data = CSVLoader(file_path=file_path).load()
                elif suffix == "html":
                    data = UnstructuredHTMLLoader(file_path=file_path).load()
                elif suffix == "json":
                    # NOTE(review): stores raw parsed JSON, not Document
                    # objects; receive_data()'s split_documents presumably
                    # expects Documents — confirm before relying on the json
                    # path end-to-end.
                    data = json.loads(Path(file_path).read_text())
                elif suffix == "md":
                    data = UnstructuredMarkdownLoader(file_path=file_path).load()
                elif suffix == "pdf":
                    data = PyPDFLoader(file_path=file_path).load_and_split()
                else:
                    # Unknown extension: fall back to treating it as a web page.
                    data = WebBaseLoader(file_path).load()
            except OSError:
                return False

        self.data = data
        return True

    def receive_data(self, file_path):
        """Load, split, and index *file_path*.

        Returns a ``(error_flag, message)`` tuple: ``(False, "Success")`` on
        success, ``(True, <reason>)`` when the source could not be loaded.
        """
        if not self.get_file(file_path):
            return (True, f"'{file_path}' unsupported, read documentation for more information")

        splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=0)
        chunks = splitter.split_documents(self.data)
        self.retriever = Chroma.from_documents(
            documents=chunks, embedding=self.embeddings
        ).as_retriever()
        return (False, "Success")

    def get_request(self, prompt):
        """Answer *prompt* via retrieval-augmented QA over the indexed data."""
        chain = RetrievalQA.from_chain_type(self.chat_ollama, retriever=self.retriever)
        return chain.invoke({"query": prompt})["result"]
|
||||||
|
|
Loading…
Reference in New Issue