Added Ollama Chat instances, both with and without RAG

This commit is contained in:
Falko Victor Habel 2024-05-18 21:04:39 +02:00
parent fe347ab150
commit bd452e51f0
3 changed files with 113 additions and 0 deletions

30
scripts/BaseOllama.py Normal file
View File

@ -0,0 +1,30 @@
from langchain_community.chat_models import ChatOllama
class OllamaChatBot:
    """Thin wrapper around ChatOllama that keeps a short rolling prompt history.

    The attribute name ``messanges`` (sic) is preserved because it is part of
    the object's public state; renaming it could break external callers.
    """

    def __init__(self, base_url, model, headers):
        """Create the chat client.

        :param base_url: base URL of the Ollama server.
        :param model: name of the model to chat with.
        :param headers: optional dict of extra HTTP headers, or None.
        """
        self.base_url = base_url
        self.model = model
        self.headers = headers
        self.messanges = []  # rolling history of raw prompt strings
        # Only pass the headers keyword when headers were actually supplied.
        if headers is None:
            self.ollama = ChatOllama(
                base_url=self.base_url,
                model=self.model,
            )
        else:
            self.ollama = ChatOllama(
                base_url=self.base_url,
                model=self.model,
                headers=self.headers,
            )

    def get_request(self, prompt):
        """Append *prompt* to the history and query the model with recent context.

        Bug fix: the original sliced a freshly created EMPTY local list when
        the history exceeded five entries, so the model was invoked with no
        messages at all. We now send the five most recent prompts.

        :param prompt: user prompt string.
        :returns: the model's reply text.
        """
        self.messanges.append(prompt)
        # Bound the request size to the 5 most recent prompts.
        recent = self.messanges[-5:]
        return self.ollama.invoke(recent).content

83
scripts/Rag.py Normal file
View File

@ -0,0 +1,83 @@
from langchain_community.document_loaders import UnstructuredHTMLLoader, WebBaseLoader, UnstructuredMarkdownLoader, PyPDFLoader
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOllama
from langchain.chains import RetrievalQA
from pathlib import Path
import json
class Rag:
def __init__(self, embeddings, model,
base_url_llm,base_url_embed, base_header,
embeddings_header):
self.retriever = None
self.data = None
self.model = model
self.base_url_llm = base_url_llm
self.base_url_embed = base_url_embed
self.base_header = base_header
self.embeddings_header = embeddings_header
self.embeddings = OllamaEmbeddings(model=embeddings, headers=self.embeddings_header, base_url=self.base_url_embed)
def init_ollama(self):
self.chat_ollama = ChatOllama(
base_url=self.base_url_llm,
model=self.model,
headers = self.base_header
)
def get_file(self, file_path):
# Check if the file path starts with 'https://'
if file_path.startswith('https://'):
loader = WebBaseLoader(file_path)
data = loader.load()
if data is None:
return False
else:
file_type = file_path.split(".")[-1]
try:
match file_type:
case "csv":
loader = CSVLoader(file_path=file_path)
data = loader.load()
case "html": # Corrected the typo in the variable name
loader = UnstructuredHTMLLoader(file_path=file_path)
data = loader.load()
case "json":
data = json.loads(Path(file_path).read_text())
case "md":
loader = UnstructuredMarkdownLoader(file_path=file_path)
data = loader.load()
case "pdf":
loader = PyPDFLoader(file_path=file_path)
data = loader.load_and_split()
case _:
loader = WebBaseLoader(file_path)
data = loader.load()
except OSError:
return False
self.data = data
return True
def receive_data(self, file_path):
if self.get_file(file_path):
text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=0)
splitted = text_splitter.split_documents(self.data)
self.retriever = Chroma.from_documents(documents=splitted, embedding=self.embeddings).as_retriever()
return (False, "Success")
else:
return (True, f"'{file_path}' unsupported, read documentation for more information")
def get_request(self, prompt):
qachain=RetrievalQA.from_chain_type(self.chat_ollama, retriever=self.retriever)
return qachain.invoke({"query": prompt})["result"]