# Source listing: 99 lines, 3.9 KiB, Python
import re
|
|
from scripts.BaseOllama import OllamaChatBot
|
|
from scripts.Rag import Rag
|
|
from termcolor import colored
|
|
|
|
|
|
# Marker prepended to code segments by TerminalBot.extract_code so that
# show_reply can tell code apart from prose and colour it differently.
CODE_PREFIX = "[1337]"

# Not referenced in this file's visible code; presumably read elsewhere — TODO confirm.
CONFIG_FILE = "config/config.json"


class TerminalBot:
    """Interactive terminal chat front-end.

    If a context path is given, the initial prompt and every follow-up prompt
    are answered through ``Rag``. Otherwise they go straight to
    ``OllamaChatBot``. Replies are printed with fenced code blocks in light red
    and surrounding prose in white.
    """

    def __init__(self, inital_prompt, context, base_url='http://localhost:11434',
                 embeddings_url='http://localhost:11434', base_model='mistral',
                 embeddings_model='mxbai-embed-large', base_header=None, embeddings_header=None):
        """Build both back-ends and initialise the RAG pipeline.

        Args:
            inital_prompt: Prompt sent first when ``start()`` runs. The
                misspelled name is kept because callers may pass it by keyword.
            context: Path handed to ``Rag.receive_data`` as ``file_path``, or
                ``None`` to chat without retrieval.
            base_url: URL used for the chat model (both ``Rag`` and
                ``OllamaChatBot``).
            embeddings_url: URL used by ``Rag`` for the embedding model.
            base_model: Chat model name.
            embeddings_model: Embedding model name passed to ``Rag``.
            base_header: Headers forwarded to ``OllamaChatBot`` and ``Rag``.
            embeddings_header: Embedding-request headers forwarded to ``Rag``.
        """
        self.init_prompt = inital_prompt
        self.context = context
        self.rag = Rag(embeddings=embeddings_model, model=base_model,
                       base_url_llm=base_url, base_url_embed=embeddings_url,
                       base_header=base_header, embeddings_header=embeddings_header)
        self.bot = OllamaChatBot(base_url=base_url, model=base_model, headers=base_header)
        # The RAG back-end is initialised eagerly, even if no context is used.
        self.rag.init_ollama()

    def start(self):
        """Answer the initial prompt, print it, then run the interactive loop.

        With a context path, the file is first loaded via
        ``Rag.receive_data``. If the first element of its result is truthy,
        the second element is printed and the session ends without chatting.
        """
        if self.context is None:
            self.show_reply(self.bot.get_request(prompt=self.init_prompt))
            self.conversation_without_context()
            return

        checks = self.rag.receive_data(file_path=self.context)
        if checks[0]:
            # NOTE(review): presumably (failed, message) from Rag — confirm.
            self.show_reply(checks[1])
            return
        self.show_reply(self.rag.get_request(prompt=self.init_prompt))
        self.conversation_with_context()

    def conversation_with_context(self):
        """Chat through the RAG pipeline until the user enters an empty prompt.

        Returns:
            The string ``"Finished Conversation"``.
        """
        return self._chat_loop(self.rag.get_request)

    def conversation_without_context(self):
        """Chat with the plain chat bot until the user enters an empty prompt.

        Returns:
            The string ``"Finished Conversation"``.
        """
        return self._chat_loop(self.bot.get_request)

    def _chat_loop(self, ask):
        """Repeatedly read a prompt, pass it to ``ask(prompt=...)`` and print the reply.

        This is a loop rather than one recursive call per turn, so long
        sessions cannot exhaust Python's recursion limit. The return value
        also reaches the caller instead of being lost in nested frames.
        """
        while True:
            prompt = self.print_for_input().strip()
            if not prompt:
                return "Finished Conversation"
            self.show_reply(ask(prompt=prompt))

    def print_for_input(self) -> str:
        """Read a multi-line prompt from standard input until end-of-file.

        Returns:
            The lines entered, joined with newlines. The result is an empty
            string when EOF arrives before any line.
        """
        print("Type in your prompt: (Finish with ctrl + d or ctrl + z)")
        message_lines = []
        while True:
            try:
                message_lines.append(input(":"))
            except EOFError:
                break
        # Keep line breaks: joining with '' glued consecutive lines together.
        return '\n'.join(message_lines)

    def show_reply(self, message) -> None:
        """Print ``message``: code segments in light red, prose in white."""
        for part in self.extract_code(message):
            if part.startswith(CODE_PREFIX):
                print(colored(part[len(CODE_PREFIX):], "light_red"))
            else:
                print(colored(part, "white"))

    def extract_code(self, input_string, replacement=CODE_PREFIX) -> list:
        """Split text into prose and triple-backtick code segments.

        Segments alternate prose / code / prose ..., starting with prose.
        Each code segment is prefixed with ``replacement``, so callers can
        detect it with ``startswith``. A language tag standing alone on the
        opening fence line (e.g. ``python``, ``c++``) is removed, and the
        line break after it is kept. Prose segments are returned unchanged.
        Whitespace-only prose and empty code blocks are dropped. An
        unterminated final fence is treated as code up to the end of the text.

        Args:
            input_string: Text to split.
            replacement: Marker prepended to every code segment.

        Returns:
            list[str]: The non-empty segments, in their original order.
        """
        segments = []
        inside_code = False
        # The capturing group makes re.split keep the fences, so each fence
        # can flip between prose and code.
        for part in re.split(r'(```)', input_string):
            if part == "```":
                inside_code = not inside_code
                continue
            if inside_code:
                # Strip an info string only if it is the whole opening line;
                # inline code such as ```ls``` keeps its content.
                code = re.sub(r'^\S*(?=\r?\n)', '', part, count=1)
                if code.strip():
                    segments.append(replacement + code)
            elif part.strip():
                segments.append(part)
        return segments