stream responses and all tests passing
This commit is contained in:
+69
-8
@@ -10,7 +10,21 @@ from agent.tools import TOOL_SCHEMAS, dispatch_tool
|
||||
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
||||
|
||||
|
||||
async def run_turn(user_message: str, history: list[dict] = None, sandbox=None) -> str:
|
||||
async def run_turn(
|
||||
user_message: str, history: list[dict] = None, sandbox=None, stream_callback=None
|
||||
) -> str:
|
||||
"""
|
||||
Run one turn of agent loop with streaming support.
|
||||
|
||||
Args:
|
||||
user_message: User's input
|
||||
history: Conversation history
|
||||
sandbox: Sandbox session
|
||||
stream_callback: Optional callback(text) for streaming output
|
||||
|
||||
Returns:
|
||||
Final text response
|
||||
"""
|
||||
|
||||
if history is None:
|
||||
history = []
|
||||
@@ -18,44 +32,76 @@ async def run_turn(user_message: str, history: list[dict] = None, sandbox=None)
|
||||
# add the new user message to history
|
||||
messages = history + [{"role": "user", "content": user_message}]
|
||||
|
||||
response = await client.messages.create(
|
||||
full_response = "" # accumulate full response from stream
|
||||
|
||||
async with client.messages.stream(
|
||||
model=settings.model,
|
||||
max_tokens=settings.max_tokens,
|
||||
tools=TOOL_SCHEMAS,
|
||||
messages=messages,
|
||||
)
|
||||
) as stream:
|
||||
# stream text as it arrives
|
||||
async for text in stream.text_stream:
|
||||
full_response += text
|
||||
if stream_callback:
|
||||
await stream_callback(text)
|
||||
|
||||
# get the final message (includes tool calls)
|
||||
response = await stream.get_final_message()
|
||||
|
||||
while response.stop_reason == "tool_use":
|
||||
tool_results = []
|
||||
for block in response.content:
|
||||
if block.type == "tool_use":
|
||||
# show what tool is being called
|
||||
if stream_callback:
|
||||
await stream_callback(
|
||||
f"\n🔧 Running: {block.name}({block.input})\n"
|
||||
)
|
||||
|
||||
# Execute tool
|
||||
result = await dispatch_tool(
|
||||
tool_name=block.name, tool_input=block.input, sandbox=sandbox
|
||||
)
|
||||
|
||||
# show result preview
|
||||
if stream_callback:
|
||||
preview = result[:100] + "..." if len(result) > 100 else result
|
||||
await stream_callback(f"✓ Result: {preview}\n\n")
|
||||
|
||||
tool_results.append(
|
||||
{"type": "tool_result", "tool_use_id": block.id, "content": result}
|
||||
)
|
||||
|
||||
# update messages with tool results
|
||||
messages = messages + [
|
||||
{"role": "assistant", "content": response.content},
|
||||
{"role": "user", "content": tool_results},
|
||||
]
|
||||
|
||||
response = await client.messages.create(
|
||||
# get next response
|
||||
async with client.messages.stream(
|
||||
model=settings.model,
|
||||
max_tokens=settings.max_tokens,
|
||||
tools=TOOL_SCHEMAS,
|
||||
messages=messages,
|
||||
)
|
||||
) as stream:
|
||||
async for text in stream.text_stream:
|
||||
full_response += text
|
||||
if stream_callback:
|
||||
await stream_callback(text)
|
||||
|
||||
return next(block.text for block in response.content if hasattr(block, "text"))
|
||||
response = await stream.get_final_message()
|
||||
|
||||
return full_response
|
||||
|
||||
|
||||
async def run_session(sandbox=None):
|
||||
"""simple CLI session - temporary until TUI is built"""
|
||||
history = ConversationHistory()
|
||||
|
||||
print("Codeing agent ready. Type /quit to quit.")
|
||||
print(f"Codeing agent ready. Session: {history.session_id}")
|
||||
print("Type /quit to quit.\n")
|
||||
|
||||
while True:
|
||||
user_input = input("You: ").strip()
|
||||
@@ -71,4 +117,19 @@ async def run_session(sandbox=None):
|
||||
response = await run_turn(user_input, history.get_all(), sandbox)
|
||||
history.add_message("assistant", response)
|
||||
|
||||
print(f"\nAssistant: {response}")
|
||||
# Print "Agent: " then stream response
|
||||
print(f"\nAgent: ", end="", flush=True)
|
||||
|
||||
# callback that prints text as it arrives
|
||||
async def print_stream(text: str):
|
||||
print(text, end="", flush=True)
|
||||
|
||||
# run turn with streaming
|
||||
response = await run_turn(
|
||||
user_input, history.get_all(), sandbox, stream_callback=print_stream
|
||||
)
|
||||
|
||||
print() # new line after response
|
||||
|
||||
# add to history
|
||||
history.add_message("assistant", response)
|
||||
|
||||
Reference in New Issue
Block a user