210 lines
7.4 KiB
Python
210 lines
7.4 KiB
Python
import argparse
|
|
import logging
|
|
import time
|
|
from openai import OpenAI
|
|
from openai import AssistantEventHandler
|
|
from typing_extensions import override
|
|
|
|
# Configure the root logger so the logging.info(...) progress messages used
# throughout this script are emitted (INFO level and above).
logging.basicConfig(level="INFO")
class QueryAssistant:
    """
    A class to manage querying an OpenAI assistant.

    Provides methods to create threads, send messages, and stream or fetch responses.
    """

    # Run statuses in which a run can still progress without any action from us.
    # Every other status ("completed", "failed", "cancelled", "expired",
    # "incomplete", "requires_action") is final as far as this class is concerned:
    # it never submits tool outputs, so polling on those would never end.
    _PENDING_RUN_STATUSES = frozenset({"queued", "in_progress", "cancelling"})

    def __init__(self, assistant_id: str):
        """
        Initialize the QueryAssistant with the given assistant ID and OpenAI client.

        :param assistant_id: The ID of the OpenAI assistant.
        """
        self.client = OpenAI()
        self.assistant_id = assistant_id

    def create_thread(self):
        """
        Create a new thread for the assistant.

        :return: The created thread object.
        """
        logging.info("Creating a new thread...")
        thread = self.client.beta.threads.create()
        logging.info(f"Thread created: {thread.id}")
        return thread

    def create_message(self, thread_id: str, content: str):
        """
        Create a user message in the specified thread with the given content.

        :param thread_id: The ID of the thread.
        :param content: The content of the message.
        :return: The created message object.
        """
        logging.info(f"Creating message in thread {thread_id}...")
        message = self.client.beta.threads.messages.create(
            thread_id=thread_id, role="user", content=content
        )
        logging.info("Message created")
        return message

    def stream_response(self, thread_id: str):
        """
        Stream the response from the assistant for the specified thread.

        Output is printed incrementally by the nested ``EventHandler``.

        :param thread_id: The ID of the thread.
        """
        logging.info(f"Streaming response for thread {thread_id}...")
        with self.client.beta.threads.runs.stream(
            thread_id=thread_id,
            assistant_id=self.assistant_id,
            event_handler=self.EventHandler(),
        ) as stream:
            stream.until_done()
        logging.info("Response streaming completed")

    def fetch_response(self, thread_id: str, poll_interval: float = 2.0):
        """
        Fetch the response from the assistant for the specified thread (non-streaming).

        Starts a run with ``create_and_poll``. If that call returns while the run
        is still pending, polling continues every ``poll_interval`` seconds.
        When the run completes, the assistant's text replies are printed in
        chronological order. For any other final status, the status and any
        incomplete details are logged as errors.

        :param thread_id: The ID of the thread.
        :param poll_interval: Seconds to wait between status checks (default 2).
        """
        logging.info(f"Fetching response for thread {thread_id}...")
        run = self.client.beta.threads.runs.create_and_poll(
            thread_id=thread_id, assistant_id=self.assistant_id
        )

        # Keep polling (with a delay to limit GET requests) only while the run can
        # still move forward by itself. Statuses such as "cancelled", "expired",
        # "incomplete" or "requires_action" never change on their own, so looping
        # "until completed or failed" would hang forever on them.
        while run.status in self._PENDING_RUN_STATUSES:
            time.sleep(poll_interval)
            run = self.client.beta.threads.runs.retrieve(
                thread_id=thread_id, run_id=run.id
            )
        logging.info(f"Run status: {run.status}")

        if run.status == "completed":
            # Messages are listed newest-first by default; ask for ascending order
            # so a multi-part reply is printed in the order it was written.
            messages = self.client.beta.threads.messages.list(
                thread_id=thread_id, order="asc"
            ).data
            for message in messages:
                if message.role != "assistant":
                    continue
                for content in message.content:
                    if content.type == "text":
                        print(content.text.value)
        else:
            logging.error(f"Run failed with status: {run.status}")
            if run.incomplete_details:
                logging.error(f"Incomplete details: {run.incomplete_details}")

    class EventHandler(AssistantEventHandler):
        """
        A class to handle events from the assistant's response stream.

        Text and code-interpreter activity is echoed to stdout as it arrives.
        """

        @override
        def on_text_created(self, text) -> None:
            """
            Handle the event when text is created by the assistant.

            :param text: The created text.
            """
            logging.info("Text created by assistant")
            print("\nassistant > ", end="", flush=True)

        @override
        def on_text_delta(self, delta, snapshot):
            """
            Handle the event when there is a delta in the assistant's response.

            :param delta: The response delta.
            :param snapshot: The snapshot of the response.
            """
            # A delta may carry no text (value None); skip it so "None" is not printed.
            if delta.value:
                print(delta.value, end="", flush=True)

        @override
        def on_tool_call_created(self, tool_call):
            """
            Handle the event when a tool call is created by the assistant.

            :param tool_call: The created tool call.
            """
            logging.info(f"Tool call created: {tool_call.type}")
            print(f"\nassistant > {tool_call.type}\n", flush=True)

        @override
        def on_tool_call_delta(self, delta, snapshot):
            """
            Handle the event when there is a delta in the assistant's tool call.

            Only code-interpreter deltas are shown: their input code and any
            "logs" outputs.

            :param delta: The tool call delta.
            :param snapshot: The snapshot of the tool call.
            """
            if delta.type != "code_interpreter":
                return
            interpreter = delta.code_interpreter
            # The code_interpreter payload is optional on a delta; nothing to show.
            if interpreter is None:
                return
            if interpreter.input:
                print(interpreter.input, end="", flush=True)
            if interpreter.outputs:
                print("\n\noutput >", flush=True)
                for output in interpreter.outputs:
                    if output.type == "logs":
                        print(f"\n{output.logs}", flush=True)
def main(query: str, assistant_id: str, context: str, use_streaming: bool):
    """
    Run a single context-plus-query exchange with an OpenAI assistant.

    Creates a fresh thread and combines the context and query into one user
    message. The message is echoed to stdout between two rules of 100 ``=``
    characters. The reply is then either streamed or fetched without streaming.

    :param query: The question put to the assistant.
    :param assistant_id: ID of the assistant to run.
    :param context: Text placed ahead of the query in the same message.
    :param use_streaming: True to stream the reply, False to fetch it once the run ends.
    """
    session = QueryAssistant(assistant_id=assistant_id)
    thread_id = session.create_thread().id

    # Context and query travel together as a single user message.
    message_text = "Context: {}\nQuery: {}".format(context, query)

    # Echo exactly what is being sent, framed by divider lines.
    divider = "=" * 100
    print("\n" + divider)
    print(message_text)
    print(divider + "\n")

    session.create_message(thread_id=thread_id, content=message_text)

    respond = session.stream_response if use_streaming else session.fetch_response
    respond(thread_id=thread_id)

    print("\n")
if __name__ == "__main__":
    # Fallback values for --query and --context.
    DEFAULT_QUERY = "What are you capable of as an assistant?"
    DEFAULT_CONTEXT = (
        "Use your vector store to answer questions about the ArchiMajor Board. "
        "Take time to understand the context and introspect. "
        "If you don't know the answer simply respond with 'I don't know'. "
        "It is NEVER okay to return an empty response."
    )

    # Command-line interface, declared as (flag, add_argument options) pairs and
    # registered in the order listed.
    cli = argparse.ArgumentParser(description="Run an assistant query.")
    option_specs = (
        (
            "--query",
            dict(
                type=str,
                default=DEFAULT_QUERY,
                help="The query to ask the assistant.",
            ),
        ),
        (
            "--assistant_id",
            dict(
                type=str,
                default="asst_QnPL19gHzraXGnLYdsVZG4dt",
                help="The assistant ID to use.",
            ),
        ),
        (
            "--context",
            dict(
                type=str,
                default=DEFAULT_CONTEXT,
                help="The context to set before the query.",
            ),
        ),
        (
            "--use-streaming",
            dict(
                action="store_true",
                help="Flag to determine if streaming should be used.",
            ),
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    # Parse the command line and hand the values to main().
    parsed = cli.parse_args()
    main(
        query=parsed.query,
        assistant_id=parsed.assistant_id,
        context=parsed.context,
        use_streaming=parsed.use_streaming,
    )