stream responses and all tests passing
This commit is contained in:
+30
-12
@@ -24,24 +24,42 @@ def settings():
|
||||
)
|
||||
|
||||
|
||||
# tests/conftest.py - replace mock_anthropic_client fixture
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_anthropic_client():
|
||||
"""Mock anthropic client that returns a fake response."""
|
||||
"""Mock Anthropic client with streaming support."""
|
||||
mock_client = AsyncMock()
|
||||
|
||||
# create a realistic fake response
|
||||
fake_message = Message(
|
||||
id="msg_test123",
|
||||
type="message",
|
||||
role="assistant",
|
||||
content=[TextBlock(type="text", text="42")],
|
||||
model="claude-test-model",
|
||||
stop_reason="end_turn",
|
||||
usage=Usage(input_tokens=10, output_tokens=5),
|
||||
# Create mock stream
|
||||
mock_stream = AsyncMock()
|
||||
|
||||
# Mock text_stream - simple async generator
|
||||
async def fake_text():
|
||||
yield "42"
|
||||
|
||||
mock_stream.text_stream = fake_text()
|
||||
|
||||
# Mock get_final_message
|
||||
mock_stream.get_final_message = AsyncMock(
|
||||
return_value=Message(
|
||||
id="msg_test123",
|
||||
type="message",
|
||||
role="assistant",
|
||||
content=[TextBlock(type="text", text="42")],
|
||||
model="claude-test-model",
|
||||
stop_reason="end_turn",
|
||||
usage=Usage(input_tokens=10, output_tokens=1),
|
||||
)
|
||||
)
|
||||
|
||||
# make messages.create() return this fake message
|
||||
mock_client.messages.create = AsyncMock(return_value=fake_message)
|
||||
# Mock context manager
|
||||
mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
|
||||
mock_stream.__aexit__ = AsyncMock()
|
||||
|
||||
# Wire up
|
||||
mock_client.messages.stream = MagicMock(return_value=mock_stream)
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
+15
-13
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -16,13 +16,13 @@ async def test_run_turn_basic(mock_anthropic_client):
|
||||
result = await run_turn("What is 2+2?")
|
||||
|
||||
# verify client was called
|
||||
mock_anthropic_client.messages.create.assert_called_once()
|
||||
mock_anthropic_client.messages.stream.assert_called_once()
|
||||
|
||||
# verify message returned
|
||||
assert result == "42"
|
||||
|
||||
# verify call has correct parameters
|
||||
call_args = mock_anthropic_client.messages.create.call_args
|
||||
call_args = mock_anthropic_client.messages.stream.call_args
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@@ -34,7 +34,7 @@ async def test_run_turn_with_history(mock_anthropic_client, sample_history):
|
||||
with patch("agent.loop.client", mock_anthropic_client):
|
||||
result = await run_turn("What is 2+2?", history=sample_history)
|
||||
|
||||
call_args = mock_anthropic_client.messages.create.call_args
|
||||
call_args = mock_anthropic_client.messages.stream.call_args
|
||||
messages = call_args.kwargs["messages"]
|
||||
|
||||
# verify all history was included plus new message
|
||||
@@ -53,7 +53,7 @@ async def test_run_turn_uses_settings(mock_anthropic_client, settings):
|
||||
await run_turn("test message")
|
||||
|
||||
# Verify settings were used
|
||||
call_args = mock_anthropic_client.messages.create.call_args
|
||||
call_args = mock_anthropic_client.messages.stream.call_args
|
||||
assert call_args.kwargs["model"] == settings.model
|
||||
assert call_args.kwargs["max_tokens"] == settings.max_tokens
|
||||
|
||||
@@ -63,16 +63,18 @@ async def test_run_turn_uses_settings(mock_anthropic_client, settings):
|
||||
async def test_run_session_calls_run_turn_with_user_input():
|
||||
"""Test that user input is passed to run_turn"""
|
||||
|
||||
# Mock the input()
|
||||
with patch("builtins.input", side_effect=["hello", KeyboardInterrupt]):
|
||||
# Mock run_turn to avoid actually calling the API
|
||||
with patch("agent.loop.run_turn", new_callable=AsyncMock) as mock_run_turn:
|
||||
mock_run_turn.return_value = AsyncMock(content=[AsyncMock(text="response")])
|
||||
mock_history = MagicMock()
|
||||
mock_history.get_all.return_value = []
|
||||
mock_history.session_id = "test-session"
|
||||
|
||||
try:
|
||||
with patch("agent.loop.ConversationHistory", return_value=mock_history):
|
||||
with patch("agent.loop.run_turn", new_callable=AsyncMock) as mock_run_turn:
|
||||
mock_run_turn.return_value = "response text"
|
||||
|
||||
# Provide inputs: command, then quit
|
||||
inputs = iter(["hello", "/quit"])
|
||||
with patch("builtins.input", side_effect=lambda _: next(inputs)):
|
||||
await run_session()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
mock_run_turn.assert_called()
|
||||
assert mock_run_turn.call_args.args[0] == "hello"
|
||||
|
||||
@@ -67,11 +67,13 @@ async def test_sandbox_run_executes_command():
|
||||
sb = PodmanSandbox()
|
||||
await sb.__aenter__()
|
||||
|
||||
result = await sb.run("echo 'hello from sandbox'")
|
||||
result = sb.run("echo 'hello from sandbox'")
|
||||
|
||||
# Verify exec_run was called with shell wrapper
|
||||
mock_container.exec_run.assert_called_once_with(
|
||||
["/bin/sh", "-c", "echo 'hello from sandbox'"], workdir="/workspace"
|
||||
["/bin/sh", "-c", "echo 'hello from sandbox'"],
|
||||
workdir="/workspace",
|
||||
demux=False,
|
||||
)
|
||||
assert result == "hello from sandbox\n"
|
||||
|
||||
@@ -83,7 +85,7 @@ async def test_tool_call_fails_if_sandbox_crashes():
|
||||
|
||||
# Simulate crashed sandbox (container is None)
|
||||
mock_sandbox = MagicMock()
|
||||
mock_sandbox.run = AsyncMock(side_effect=RuntimeError("Container crashed"))
|
||||
mock_sandbox.run = MagicMock(side_effect=RuntimeError("Container crashed"))
|
||||
|
||||
result = await bash("ls -la", mock_sandbox)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user