looping chatbot with ephemeral history
This commit is contained in:
+23
-7
@@ -3,19 +3,35 @@ import asyncio
|
|||||||
from anthropic import AsyncAnthropic
|
from anthropic import AsyncAnthropic
|
||||||
|
|
||||||
from agent.config import settings
|
from agent.config import settings
|
||||||
|
from agent.history import ConversationHistory
|
||||||
|
|
||||||
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
client = AsyncAnthropic(api_key=settings.anthropic_api_key)
|
||||||
|
history = ConversationHistory()
|
||||||
|
|
||||||
|
|
||||||
async def run_turn(user_message: str) -> str:
|
async def run_turn(user_message: str, history: list[dict] = None) -> str:
|
||||||
|
|
||||||
|
if history is None:
|
||||||
|
history = []
|
||||||
|
|
||||||
|
# add the new user message to history
|
||||||
|
messages = history + [{"role": "user", "content": user_message}]
|
||||||
|
|
||||||
message = await client.messages.create(
|
message = await client.messages.create(
|
||||||
model=settings.model,
|
model=settings.model,
|
||||||
max_tokens=settings.max_tokens,
|
max_tokens=settings.max_tokens,
|
||||||
messages=[
|
messages=messages,
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": user_message,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
async def run_session():
|
||||||
|
while True:
|
||||||
|
user_input = input("You: ")
|
||||||
|
history.add_message("user", user_input)
|
||||||
|
|
||||||
|
response = await run_turn(user_input, history.get_all())
|
||||||
|
history.add_message("assistant", response.content[0].text)
|
||||||
|
|
||||||
|
print(f"Assistant: {response.content[0].text}")
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from agent.loop import run_turn
|
from agent.loop import run_session
|
||||||
|
|
||||||
|
|
||||||
async def run_tui():
|
async def run_tui():
|
||||||
user_message = "what is the answer to life the universe and everythings?"
|
await run_session()
|
||||||
message = await run_turn(user_message)
|
|
||||||
print(message.content)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from pydantic import ValidationError
|
|||||||
from agent.config import Settings
|
from agent.config import Settings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
def test_settings_with_all_values():
|
def test_settings_with_all_values():
|
||||||
"""Test Settings loads correctly with all values provided."""
|
"""Test Settings loads correctly with all values provided."""
|
||||||
settings = Settings(
|
settings = Settings(
|
||||||
@@ -20,6 +21,7 @@ def test_settings_with_all_values():
|
|||||||
assert settings.max_tokens == 500
|
assert settings.max_tokens == 500
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
def test_settings_defaults():
|
def test_settings_defaults():
|
||||||
"""Test Settings uses defaults for optional values."""
|
"""Test Settings uses defaults for optional values."""
|
||||||
settings = Settings(
|
settings = Settings(
|
||||||
@@ -31,6 +33,7 @@ def test_settings_defaults():
|
|||||||
assert settings.max_tokens == 500
|
assert settings.max_tokens == 500
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
def test_settings_missing_required_field():
|
def test_settings_missing_required_field():
|
||||||
"""Test Settings raises error when required field is missing."""
|
"""Test Settings raises error when required field is missing."""
|
||||||
with pytest.raises(ValidationError) as exc_info:
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
@@ -40,6 +43,7 @@ def test_settings_missing_required_field():
|
|||||||
assert "anthropic_api_key" in str(exc_info.value)
|
assert "anthropic_api_key" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
def test_settings_type_validation():
|
def test_settings_type_validation():
|
||||||
"""Test Settings validates types correctly."""
|
"""Test Settings validates types correctly."""
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
|
|||||||
+42
-2
@@ -2,9 +2,11 @@ from unittest.mock import AsyncMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from agent.loop import run_turn
|
from agent.loop import run_session, run_turn
|
||||||
|
from tests.conftest import sample_history
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_turn_basic(mock_anthropic_client):
|
async def test_run_turn_basic(mock_anthropic_client):
|
||||||
"""test that run_turn calls the API and returns a message"""
|
"""test that run_turn calls the API and returns a message"""
|
||||||
@@ -21,9 +23,27 @@ async def test_run_turn_basic(mock_anthropic_client):
|
|||||||
|
|
||||||
# verify call has correct parameters
|
# verify call has correct parameters
|
||||||
call_args = mock_anthropic_client.messages.create.call_args
|
call_args = mock_anthropic_client.messages.create.call_args
|
||||||
assert call_args.kwargs["messages"][0]["content"] == "What is 2+2?"
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_turn_with_history(mock_anthropic_client, sample_history):
|
||||||
|
"""test that run_turn includes conversation history in the API call"""
|
||||||
|
|
||||||
|
# patch the client with our mock
|
||||||
|
with patch("agent.loop.client", mock_anthropic_client):
|
||||||
|
result = await run_turn("What is 2+2?", history=sample_history)
|
||||||
|
|
||||||
|
call_args = mock_anthropic_client.messages.create.call_args
|
||||||
|
messages = call_args.kwargs["messages"]
|
||||||
|
|
||||||
|
# verify all history was included plus new message
|
||||||
|
assert len(messages) == 4
|
||||||
|
assert messages[0]["content"] == "Hello"
|
||||||
|
assert messages[-1]["content"] == "What is 2+2?"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_turn_uses_settings(mock_anthropic_client, settings):
|
async def test_run_turn_uses_settings(mock_anthropic_client, settings):
|
||||||
"""Test that run_turn uses settings correctly."""
|
"""Test that run_turn uses settings correctly."""
|
||||||
@@ -36,3 +56,23 @@ async def test_run_turn_uses_settings(mock_anthropic_client, settings):
|
|||||||
call_args = mock_anthropic_client.messages.create.call_args
|
call_args = mock_anthropic_client.messages.create.call_args
|
||||||
assert call_args.kwargs["model"] == settings.model
|
assert call_args.kwargs["model"] == settings.model
|
||||||
assert call_args.kwargs["max_tokens"] == settings.max_tokens
|
assert call_args.kwargs["max_tokens"] == settings.max_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_session_calls_run_turn_with_user_input():
|
||||||
|
"""Test that user input is passed to run_turn"""
|
||||||
|
|
||||||
|
# Mock the input()
|
||||||
|
with patch("builtins.input", side_effect=["hello", KeyboardInterrupt]):
|
||||||
|
# Mock run_turn to avoid actually calling the API
|
||||||
|
with patch("agent.loop.run_turn", new_callable=AsyncMock) as mock_run_turn:
|
||||||
|
mock_run_turn.return_value = AsyncMock(content=[AsyncMock(text="response")])
|
||||||
|
|
||||||
|
try:
|
||||||
|
await run_session()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_run_turn.assert_called()
|
||||||
|
assert mock_run_turn.call_args.args[0] == "hello"
|
||||||
|
|||||||
Reference in New Issue
Block a user