81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from agent.loop import run_session, run_turn
|
|
from tests.conftest import sample_history
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.asyncio
async def test_run_turn_basic(mock_anthropic_client):
    """Test that run_turn calls the streaming API once and returns its text.

    Also checks that the user's prompt is forwarded to the API as the final
    entry of the ``messages`` keyword argument.
    """
    # Swap the module-level client in agent.loop for the fixture mock.
    with patch("agent.loop.client", mock_anthropic_client):
        result = await run_turn("What is 2+2?")

    # The streaming endpoint must be hit exactly once for a single turn.
    mock_anthropic_client.messages.stream.assert_called_once()

    # "42" is presumably the canned text the mock_anthropic_client fixture
    # streams back — NOTE(review): confirm against the fixture in conftest.
    assert result == "42"

    # The prompt must actually reach the API, as the last message sent.
    call_args = mock_anthropic_client.messages.stream.call_args
    messages = call_args.kwargs["messages"]
    assert messages[-1]["content"] == "What is 2+2?"
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.asyncio
async def test_run_turn_with_history(mock_anthropic_client, sample_history):
    """Test that run_turn includes conversation history in the API call.

    The expected values assume the ``sample_history`` fixture holds three
    prior messages, the first with content "Hello", and that run_turn
    forwards them unchanged — NOTE(review): confirm against conftest.
    """
    with patch("agent.loop.client", mock_anthropic_client):
        # Only the outgoing request is under test; the reply is ignored.
        await run_turn("What is 2+2?", history=sample_history)

    call_args = mock_anthropic_client.messages.stream.call_args
    messages = call_args.kwargs["messages"]

    # Every history message plus the new user message.
    assert len(messages) == 4
    # History comes first...
    assert messages[0]["content"] == "Hello"
    # ...and the new prompt is appended at the end.
    assert messages[-1]["content"] == "What is 2+2?"
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.asyncio
async def test_run_turn_uses_settings(mock_anthropic_client, settings):
    """Check that the model and token limit sent to the API come from settings."""
    # Replace both the API client and the settings object agent.loop reads.
    with patch("agent.loop.client", mock_anthropic_client), patch(
        "agent.loop.settings", settings
    ):
        await run_turn("test message")

    sent = mock_anthropic_client.messages.stream.call_args.kwargs
    assert sent["model"] == settings.model
    assert sent["max_tokens"] == settings.max_tokens
|
|
|
|
|
|
@pytest.mark.unit
@pytest.mark.asyncio
async def test_run_session_calls_run_turn_with_user_input():
    """Test that text typed at the prompt is passed to run_turn."""
    # Stand-in for ConversationHistory: empty history and a fixed session id.
    mock_history = MagicMock()
    mock_history.get_all.return_value = []
    mock_history.session_id = "test-session"

    with patch("agent.loop.ConversationHistory", return_value=mock_history):
        with patch("agent.loop.run_turn", new_callable=AsyncMock) as mock_run_turn:
            mock_run_turn.return_value = "response text"

            # Scripted user input: one chat message, then the quit command.
            # An iterable side_effect returns the next item on each input()
            # call whatever arguments are passed (with or without a prompt).
            with patch("builtins.input", side_effect=["hello", "/quit"]):
                await run_session()

    mock_run_turn.assert_called()
    assert mock_run_turn.call_args.args[0] == "hello"
|