stream responses and all tests passing
This commit is contained in:
+69
-8
@@ -10,7 +10,21 @@ from agent.tools import TOOL_SCHEMAS, dispatch_tool
|
|||||||
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
||||||
|
|
||||||
|
|
||||||
async def run_turn(user_message: str, history: list[dict] = None, sandbox=None) -> str:
|
async def run_turn(
|
||||||
|
user_message: str, history: list[dict] = None, sandbox=None, stream_callback=None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Run one turn of agent loop with streaming support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_message: User's input
|
||||||
|
history: Conversation history
|
||||||
|
sandbox: Sandbox session
|
||||||
|
stream_callback: Optional callback(text) for streaming output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Final text response
|
||||||
|
"""
|
||||||
|
|
||||||
if history is None:
|
if history is None:
|
||||||
history = []
|
history = []
|
||||||
@@ -18,44 +32,76 @@ async def run_turn(user_message: str, history: list[dict] = None, sandbox=None)
|
|||||||
# add the new user message to history
|
# add the new user message to history
|
||||||
messages = history + [{"role": "user", "content": user_message}]
|
messages = history + [{"role": "user", "content": user_message}]
|
||||||
|
|
||||||
response = await client.messages.create(
|
full_response = "" # accumulate full response from stream
|
||||||
|
|
||||||
|
async with client.messages.stream(
|
||||||
model=settings.model,
|
model=settings.model,
|
||||||
max_tokens=settings.max_tokens,
|
max_tokens=settings.max_tokens,
|
||||||
tools=TOOL_SCHEMAS,
|
tools=TOOL_SCHEMAS,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
) as stream:
|
||||||
|
# stream text as it arrives
|
||||||
|
async for text in stream.text_stream:
|
||||||
|
full_response += text
|
||||||
|
if stream_callback:
|
||||||
|
await stream_callback(text)
|
||||||
|
|
||||||
|
# get the final message (includes tool calls)
|
||||||
|
response = await stream.get_final_message()
|
||||||
|
|
||||||
while response.stop_reason == "tool_use":
|
while response.stop_reason == "tool_use":
|
||||||
tool_results = []
|
tool_results = []
|
||||||
for block in response.content:
|
for block in response.content:
|
||||||
if block.type == "tool_use":
|
if block.type == "tool_use":
|
||||||
|
# show what tool is being called
|
||||||
|
if stream_callback:
|
||||||
|
await stream_callback(
|
||||||
|
f"\n🔧 Running: {block.name}({block.input})\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute tool
|
||||||
result = await dispatch_tool(
|
result = await dispatch_tool(
|
||||||
tool_name=block.name, tool_input=block.input, sandbox=sandbox
|
tool_name=block.name, tool_input=block.input, sandbox=sandbox
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# show result preview
|
||||||
|
if stream_callback:
|
||||||
|
preview = result[:100] + "..." if len(result) > 100 else result
|
||||||
|
await stream_callback(f"✓ Result: {preview}\n\n")
|
||||||
|
|
||||||
tool_results.append(
|
tool_results.append(
|
||||||
{"type": "tool_result", "tool_use_id": block.id, "content": result}
|
{"type": "tool_result", "tool_use_id": block.id, "content": result}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# update messages with tool results
|
||||||
messages = messages + [
|
messages = messages + [
|
||||||
{"role": "assistant", "content": response.content},
|
{"role": "assistant", "content": response.content},
|
||||||
{"role": "user", "content": tool_results},
|
{"role": "user", "content": tool_results},
|
||||||
]
|
]
|
||||||
|
|
||||||
response = await client.messages.create(
|
# get next response
|
||||||
|
async with client.messages.stream(
|
||||||
model=settings.model,
|
model=settings.model,
|
||||||
max_tokens=settings.max_tokens,
|
max_tokens=settings.max_tokens,
|
||||||
tools=TOOL_SCHEMAS,
|
tools=TOOL_SCHEMAS,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
) as stream:
|
||||||
|
async for text in stream.text_stream:
|
||||||
|
full_response += text
|
||||||
|
if stream_callback:
|
||||||
|
await stream_callback(text)
|
||||||
|
|
||||||
return next(block.text for block in response.content if hasattr(block, "text"))
|
response = await stream.get_final_message()
|
||||||
|
|
||||||
|
return full_response
|
||||||
|
|
||||||
|
|
||||||
async def run_session(sandbox=None):
|
async def run_session(sandbox=None):
|
||||||
"""simple CLI session - temporary until TUI is built"""
|
"""simple CLI session - temporary until TUI is built"""
|
||||||
history = ConversationHistory()
|
history = ConversationHistory()
|
||||||
|
|
||||||
print("Codeing agent ready. Type /quit to quit.")
|
print(f"Codeing agent ready. Session: {history.session_id}")
|
||||||
|
print("Type /quit to quit.\n")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
user_input = input("You: ").strip()
|
user_input = input("You: ").strip()
|
||||||
@@ -71,4 +117,19 @@ async def run_session(sandbox=None):
|
|||||||
response = await run_turn(user_input, history.get_all(), sandbox)
|
response = await run_turn(user_input, history.get_all(), sandbox)
|
||||||
history.add_message("assistant", response)
|
history.add_message("assistant", response)
|
||||||
|
|
||||||
print(f"\nAssistant: {response}")
|
# Print "Agent: " then stream response
|
||||||
|
print(f"\nAgent: ", end="", flush=True)
|
||||||
|
|
||||||
|
# callback that prints text as it arrives
|
||||||
|
async def print_stream(text: str):
|
||||||
|
print(text, end="", flush=True)
|
||||||
|
|
||||||
|
# run turn with streaming
|
||||||
|
response = await run_turn(
|
||||||
|
user_input, history.get_all(), sandbox, stream_callback=print_stream
|
||||||
|
)
|
||||||
|
|
||||||
|
print() # new line after response
|
||||||
|
|
||||||
|
# add to history
|
||||||
|
history.add_message("assistant", response)
|
||||||
|
|||||||
+3
-3
@@ -30,12 +30,12 @@ class PodmanSandbox:
|
|||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def run(self, command: str) -> str:
|
def run(self, command: str) -> str:
|
||||||
"""Execute command in microVM/"""
|
"""Execute command in microVM/"""
|
||||||
exit_code, output = self.container.exec_run(
|
exit_code, output = self.container.exec_run(
|
||||||
["/bin/sh", "-c", command], workdir="/workspace"
|
["/bin/sh", "-c", command], workdir="/workspace", demux=False
|
||||||
)
|
)
|
||||||
return output.decode()
|
return output.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
async def __aexit__(self, *args):
|
async def __aexit__(self, *args):
|
||||||
if self.container:
|
if self.container:
|
||||||
|
|||||||
+24
-6
@@ -24,24 +24,42 @@ def settings():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# tests/conftest.py - replace mock_anthropic_client fixture
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_anthropic_client():
|
def mock_anthropic_client():
|
||||||
"""Mock anthropic client that returns a fake response."""
|
"""Mock Anthropic client with streaming support."""
|
||||||
mock_client = AsyncMock()
|
mock_client = AsyncMock()
|
||||||
|
|
||||||
# create a realistic fake response
|
# Create mock stream
|
||||||
fake_message = Message(
|
mock_stream = AsyncMock()
|
||||||
|
|
||||||
|
# Mock text_stream - simple async generator
|
||||||
|
async def fake_text():
|
||||||
|
yield "42"
|
||||||
|
|
||||||
|
mock_stream.text_stream = fake_text()
|
||||||
|
|
||||||
|
# Mock get_final_message
|
||||||
|
mock_stream.get_final_message = AsyncMock(
|
||||||
|
return_value=Message(
|
||||||
id="msg_test123",
|
id="msg_test123",
|
||||||
type="message",
|
type="message",
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=[TextBlock(type="text", text="42")],
|
content=[TextBlock(type="text", text="42")],
|
||||||
model="claude-test-model",
|
model="claude-test-model",
|
||||||
stop_reason="end_turn",
|
stop_reason="end_turn",
|
||||||
usage=Usage(input_tokens=10, output_tokens=5),
|
usage=Usage(input_tokens=10, output_tokens=1),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# make messages.create() return this fake message
|
# Mock context manager
|
||||||
mock_client.messages.create = AsyncMock(return_value=fake_message)
|
mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
|
||||||
|
mock_stream.__aexit__ = AsyncMock()
|
||||||
|
|
||||||
|
# Wire up
|
||||||
|
mock_client.messages.stream = MagicMock(return_value=mock_stream)
|
||||||
|
|
||||||
return mock_client
|
return mock_client
|
||||||
|
|
||||||
|
|||||||
+15
-13
@@ -1,4 +1,4 @@
|
|||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -16,13 +16,13 @@ async def test_run_turn_basic(mock_anthropic_client):
|
|||||||
result = await run_turn("What is 2+2?")
|
result = await run_turn("What is 2+2?")
|
||||||
|
|
||||||
# verify client was called
|
# verify client was called
|
||||||
mock_anthropic_client.messages.create.assert_called_once()
|
mock_anthropic_client.messages.stream.assert_called_once()
|
||||||
|
|
||||||
# verify message returned
|
# verify message returned
|
||||||
assert result == "42"
|
assert result == "42"
|
||||||
|
|
||||||
# verify call has correct parameters
|
# verify call has correct parameters
|
||||||
call_args = mock_anthropic_client.messages.create.call_args
|
call_args = mock_anthropic_client.messages.stream.call_args
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
@@ -34,7 +34,7 @@ async def test_run_turn_with_history(mock_anthropic_client, sample_history):
|
|||||||
with patch("agent.loop.client", mock_anthropic_client):
|
with patch("agent.loop.client", mock_anthropic_client):
|
||||||
result = await run_turn("What is 2+2?", history=sample_history)
|
result = await run_turn("What is 2+2?", history=sample_history)
|
||||||
|
|
||||||
call_args = mock_anthropic_client.messages.create.call_args
|
call_args = mock_anthropic_client.messages.stream.call_args
|
||||||
messages = call_args.kwargs["messages"]
|
messages = call_args.kwargs["messages"]
|
||||||
|
|
||||||
# verify all history was included plus new message
|
# verify all history was included plus new message
|
||||||
@@ -53,7 +53,7 @@ async def test_run_turn_uses_settings(mock_anthropic_client, settings):
|
|||||||
await run_turn("test message")
|
await run_turn("test message")
|
||||||
|
|
||||||
# Verify settings were used
|
# Verify settings were used
|
||||||
call_args = mock_anthropic_client.messages.create.call_args
|
call_args = mock_anthropic_client.messages.stream.call_args
|
||||||
assert call_args.kwargs["model"] == settings.model
|
assert call_args.kwargs["model"] == settings.model
|
||||||
assert call_args.kwargs["max_tokens"] == settings.max_tokens
|
assert call_args.kwargs["max_tokens"] == settings.max_tokens
|
||||||
|
|
||||||
@@ -63,16 +63,18 @@ async def test_run_turn_uses_settings(mock_anthropic_client, settings):
|
|||||||
async def test_run_session_calls_run_turn_with_user_input():
|
async def test_run_session_calls_run_turn_with_user_input():
|
||||||
"""Test that user input is passed to run_turn"""
|
"""Test that user input is passed to run_turn"""
|
||||||
|
|
||||||
# Mock the input()
|
mock_history = MagicMock()
|
||||||
with patch("builtins.input", side_effect=["hello", KeyboardInterrupt]):
|
mock_history.get_all.return_value = []
|
||||||
# Mock run_turn to avoid actually calling the API
|
mock_history.session_id = "test-session"
|
||||||
with patch("agent.loop.run_turn", new_callable=AsyncMock) as mock_run_turn:
|
|
||||||
mock_run_turn.return_value = AsyncMock(content=[AsyncMock(text="response")])
|
|
||||||
|
|
||||||
try:
|
with patch("agent.loop.ConversationHistory", return_value=mock_history):
|
||||||
|
with patch("agent.loop.run_turn", new_callable=AsyncMock) as mock_run_turn:
|
||||||
|
mock_run_turn.return_value = "response text"
|
||||||
|
|
||||||
|
# Provide inputs: command, then quit
|
||||||
|
inputs = iter(["hello", "/quit"])
|
||||||
|
with patch("builtins.input", side_effect=lambda _: next(inputs)):
|
||||||
await run_session()
|
await run_session()
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
|
|
||||||
mock_run_turn.assert_called()
|
mock_run_turn.assert_called()
|
||||||
assert mock_run_turn.call_args.args[0] == "hello"
|
assert mock_run_turn.call_args.args[0] == "hello"
|
||||||
|
|||||||
@@ -67,11 +67,13 @@ async def test_sandbox_run_executes_command():
|
|||||||
sb = PodmanSandbox()
|
sb = PodmanSandbox()
|
||||||
await sb.__aenter__()
|
await sb.__aenter__()
|
||||||
|
|
||||||
result = await sb.run("echo 'hello from sandbox'")
|
result = sb.run("echo 'hello from sandbox'")
|
||||||
|
|
||||||
# Verify exec_run was called with shell wrapper
|
# Verify exec_run was called with shell wrapper
|
||||||
mock_container.exec_run.assert_called_once_with(
|
mock_container.exec_run.assert_called_once_with(
|
||||||
["/bin/sh", "-c", "echo 'hello from sandbox'"], workdir="/workspace"
|
["/bin/sh", "-c", "echo 'hello from sandbox'"],
|
||||||
|
workdir="/workspace",
|
||||||
|
demux=False,
|
||||||
)
|
)
|
||||||
assert result == "hello from sandbox\n"
|
assert result == "hello from sandbox\n"
|
||||||
|
|
||||||
@@ -83,7 +85,7 @@ async def test_tool_call_fails_if_sandbox_crashes():
|
|||||||
|
|
||||||
# Simulate crashed sandbox (container is None)
|
# Simulate crashed sandbox (container is None)
|
||||||
mock_sandbox = MagicMock()
|
mock_sandbox = MagicMock()
|
||||||
mock_sandbox.run = AsyncMock(side_effect=RuntimeError("Container crashed"))
|
mock_sandbox.run = MagicMock(side_effect=RuntimeError("Container crashed"))
|
||||||
|
|
||||||
result = await bash("ls -la", mock_sandbox)
|
result = await bash("ls -la", mock_sandbox)
|
||||||
|
|
||||||
|
|||||||
+4
-1
@@ -1,3 +1,6 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
async def bash(command: str, sandbox=None) -> str:
|
async def bash(command: str, sandbox=None) -> str:
|
||||||
"""
|
"""
|
||||||
Execute a bash command in the sandbox.
|
Execute a bash command in the sandbox.
|
||||||
@@ -14,7 +17,7 @@ async def bash(command: str, sandbox=None) -> str:
|
|||||||
return "Error: Sandbox not available"
|
return "Error: Sandbox not available"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await sandbox.run(command)
|
result = await asyncio.to_thread(sandbox.run, command)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user