Altium-archimajor-3d-printe.../assistant/QueryAssistant.py

210 lines
7.4 KiB
Python

import argparse
import logging
import time
from openai import OpenAI
from openai import AssistantEventHandler
from typing_extensions import override
# Configure the root logger at import time so the INFO-level progress
# messages emitted by this module are shown.
logging.basicConfig(level=logging.INFO)
class QueryAssistant:
    """
    A class to manage querying an OpenAI assistant.

    Provides methods to create threads, send messages, and stream or fetch responses.
    """

    # Statuses at which a run stops progressing without outside intervention.
    # Polling must stop on every one of them, not only "completed"/"failed";
    # otherwise a cancelled or expired run would be polled forever.
    _TERMINAL_RUN_STATUSES = frozenset(
        {
            "completed",
            "failed",
            "cancelled",
            "expired",
            "incomplete",
            "requires_action",
        }
    )

    # Seconds to wait between status checks in the fallback polling loop.
    _POLL_INTERVAL_SECONDS = 2

    def __init__(self, assistant_id: str):
        """
        Initialize the QueryAssistant with the given assistant ID and OpenAI client.

        :param assistant_id: The ID of the OpenAI assistant.
        """
        self.client = OpenAI()
        self.assistant_id = assistant_id

    def create_thread(self):
        """
        Create a new thread for the assistant.

        :return: The created thread object.
        """
        logging.info("Creating a new thread...")
        thread = self.client.beta.threads.create()
        logging.info(f"Thread created: {thread.id}")
        return thread

    def create_message(self, thread_id: str, content: str):
        """
        Create a user message in the specified thread with the given content.

        :param thread_id: The ID of the thread.
        :param content: The content of the message.
        :return: The created message object.
        """
        logging.info(f"Creating message in thread {thread_id}...")
        message = self.client.beta.threads.messages.create(
            thread_id=thread_id, role="user", content=content
        )
        logging.info("Message created")
        return message

    def stream_response(self, thread_id: str):
        """
        Stream the response from the assistant for the specified thread.

        Output is printed incrementally by :class:`QueryAssistant.EventHandler`.

        :param thread_id: The ID of the thread.
        """
        logging.info(f"Streaming response for thread {thread_id}...")
        with self.client.beta.threads.runs.stream(
            thread_id=thread_id,
            assistant_id=self.assistant_id,
            event_handler=self.EventHandler(),
        ) as stream:
            stream.until_done()
        logging.info("Response streaming completed")

    def fetch_response(self, thread_id: str):
        """
        Fetch the response from the assistant for the specified thread (non-streaming).

        Starts a run, waits for it to settle, then prints the text of every
        assistant message in the thread. If the run settles in any status
        other than "completed", the status and any available error details
        are logged instead.

        :param thread_id: The ID of the thread.
        """
        logging.info(f"Fetching response for thread {thread_id}...")
        run = self.client.beta.threads.runs.create_and_poll(
            thread_id=thread_id, assistant_id=self.assistant_id
        )

        # Safety net in case the run is returned before it has settled. The
        # delay between checks keeps the number of GET requests down.
        while run.status not in self._TERMINAL_RUN_STATUSES:
            time.sleep(self._POLL_INTERVAL_SECONDS)
            run = self.client.beta.threads.runs.retrieve(
                thread_id=thread_id, run_id=run.id
            )
            logging.info(f"Run status: {run.status}")

        if run.status != "completed":
            logging.error(f"Run failed with status: {run.status}")
            # Either field may be absent or None depending on how the run ended.
            last_error = getattr(run, "last_error", None)
            if last_error:
                logging.error(f"Last error: {last_error}")
            incomplete_details = getattr(run, "incomplete_details", None)
            if incomplete_details:
                logging.error(f"Incomplete details: {incomplete_details}")
            return

        messages = self.client.beta.threads.messages.list(thread_id=thread_id).data
        for message in messages:
            if message.role != "assistant":
                continue
            for content in message.content:
                if content.type == "text":
                    print(content.text.value)

    class EventHandler(AssistantEventHandler):
        """
        A class to handle events from the assistant's response stream.

        Each callback echoes the streamed content to stdout as it arrives.
        """

        @override
        def on_text_created(self, text) -> None:
            """
            Handle the event when text is created by the assistant.

            :param text: The created text.
            """
            logging.info("Text created by assistant")
            print("\nassistant > ", end="", flush=True)

        @override
        def on_text_delta(self, delta, snapshot):
            """
            Handle the event when there is a delta in the assistant's response.

            :param delta: The response delta.
            :param snapshot: The snapshot of the response.
            """
            print(delta.value, end="", flush=True)

        @override
        def on_tool_call_created(self, tool_call):
            """
            Handle the event when a tool call is created by the assistant.

            :param tool_call: The created tool call.
            """
            logging.info(f"Tool call created: {tool_call.type}")
            print(f"\nassistant > {tool_call.type}\n", flush=True)

        @override
        def on_tool_call_delta(self, delta, snapshot):
            """
            Handle the event when there is a delta in the assistant's tool call.

            Only code-interpreter deltas are echoed: the input code as it is
            streamed, followed by any log outputs.

            :param delta: The tool call delta.
            :param snapshot: The snapshot of the tool call.
            """
            if delta.type != "code_interpreter":
                return
            interpreter = delta.code_interpreter
            # A delta may carry no code-interpreter payload at all.
            if not interpreter:
                return
            if interpreter.input:
                print(interpreter.input, end="", flush=True)
            if interpreter.outputs:
                print("\n\noutput >", flush=True)
                for output in interpreter.outputs:
                    if output.type == "logs":
                        print(f"\n{output.logs}", flush=True)
def main(query: str, assistant_id: str, context: str, use_streaming: bool):
    """
    Run a single query against the assistant and print its reply.

    A fresh thread is created, the context and query are combined into one
    user message, and the reply is either streamed or fetched in one go.

    :param query: The query to ask the assistant.
    :param assistant_id: The ID of the assistant.
    :param context: The context to set before the query.
    :param use_streaming: Boolean flag to determine if streaming should be used.
    """
    assistant = QueryAssistant(assistant_id=assistant_id)
    thread = assistant.create_thread()

    # Context and query travel together as one message
    full_query = f"Context: {context}\nQuery: {query}"

    # Echo the outgoing message between two separator rules
    separator = "=" * 100
    print("\n" + separator)
    print(full_query)
    print(separator + "\n")

    assistant.create_message(thread_id=thread.id, content=full_query)

    # Choose the response mode, then run it against the thread
    respond = assistant.stream_response if use_streaming else assistant.fetch_response
    respond(thread_id=thread.id)
    print("\n")
if __name__ == "__main__":
    # Fallback values used when the matching command-line flag is omitted
    DEFAULT_QUERY = "What are you capable of as an assistant?"
    DEFAULT_CONTEXT = (
        "Use your vector store to answer questions about the ArchiMajor Board. "
        "Take time to understand the context and introspect. "
        "If you don't know the answer simply respond with 'I don't know'. "
        "It is NEVER okay to return an empty response."
    )

    # Command-line interface: three string options plus one boolean switch
    parser = argparse.ArgumentParser(description="Run an assistant query.")
    string_options = (
        ("--query", DEFAULT_QUERY, "The query to ask the assistant."),
        ("--assistant_id", "asst_QnPL19gHzraXGnLYdsVZG4dt", "The assistant ID to use."),
        ("--context", DEFAULT_CONTEXT, "The context to set before the query."),
    )
    for option_name, option_default, option_help in string_options:
        parser.add_argument(
            option_name, type=str, default=option_default, help=option_help
        )
    parser.add_argument(
        "--use-streaming",
        action="store_true",
        help="Flag to determine if streaming should be used.",
    )

    # Hand the parsed values to the main function
    args = parser.parse_args()
    main(args.query, args.assistant_id, args.context, args.use_streaming)