feat(bus): add correlation store for request-response
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -16,6 +16,9 @@ class MessageBus:
|
||||
def __init__(self):
|
||||
self.inbound: asyncio.Queue[InboundMessage] = asyncio.Queue()
|
||||
self.outbound: asyncio.Queue[OutboundMessage] = asyncio.Queue()
|
||||
self._outbound_subscribers: dict[str, list[Callable[[OutboundMessage], Awaitable[None]]]] = {}
|
||||
self._correlation_store: dict[str, asyncio.Future] = {}
|
||||
self._running = False
|
||||
|
||||
async def publish_inbound(self, msg: InboundMessage) -> None:
|
||||
"""Publish a message from a channel to the agent."""
|
||||
@@ -33,6 +36,60 @@ class MessageBus:
|
||||
"""Consume the next outbound message (blocks until available)."""
|
||||
return await self.outbound.get()
|
||||
|
||||
def register_correlation(self, correlation_id: str) -> asyncio.Future:
|
||||
"""Register a Future to be resolved when a matching outbound message appears."""
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
self._correlation_store[correlation_id] = future
|
||||
return future
|
||||
|
||||
def resolve_correlation(self, msg: OutboundMessage) -> None:
|
||||
"""Check if an outbound message has a correlation_id and resolve the matching Future."""
|
||||
cid = msg.metadata.get("correlation_id") if msg.metadata else None
|
||||
if cid and cid in self._correlation_store:
|
||||
future = self._correlation_store.pop(cid)
|
||||
if not future.done():
|
||||
future.set_result(msg.content)
|
||||
|
||||
def cancel_correlation(self, correlation_id: str) -> None:
|
||||
"""Cancel and remove a pending correlation."""
|
||||
future = self._correlation_store.pop(correlation_id, None)
|
||||
if future and not future.done():
|
||||
future.cancel()
|
||||
|
||||
def subscribe_outbound(
|
||||
self,
|
||||
channel: str,
|
||||
callback: Callable[[OutboundMessage], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Subscribe to outbound messages for a specific channel."""
|
||||
if channel not in self._outbound_subscribers:
|
||||
self._outbound_subscribers[channel] = []
|
||||
self._outbound_subscribers[channel].append(callback)
|
||||
|
||||
async def dispatch_outbound(self) -> None:
|
||||
"""
|
||||
Dispatch outbound messages to subscribed channels.
|
||||
Run this as a background task.
|
||||
"""
|
||||
self._running = True
|
||||
while self._running:
|
||||
try:
|
||||
msg = await asyncio.wait_for(self.outbound.get(), timeout=1.0)
|
||||
subscribers = self._outbound_subscribers.get(msg.channel, [])
|
||||
for callback in subscribers:
|
||||
try:
|
||||
await callback(msg)
|
||||
except Exception as e:
|
||||
logger.error(f"Error dispatching to {msg.channel}: {e}")
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher loop."""
|
||||
self._running = False
|
||||
|
||||
>>>>>>> e5bad4e (feat(bus): add correlation store for request-response)
|
||||
@property
|
||||
def inbound_size(self) -> int:
|
||||
"""Number of pending inbound messages."""
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Tests for bus-level correlation (request-response via Futures)."""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bus():
|
||||
return MessageBus()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_correlation_returns_future(bus):
|
||||
future = bus.register_correlation("test-id-1")
|
||||
assert isinstance(future, asyncio.Future)
|
||||
assert not future.done()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_correlation_sets_future_result(bus):
|
||||
future = bus.register_correlation("test-id-1")
|
||||
msg = OutboundMessage(channel="hook", chat_id="test", content="hello", metadata={"correlation_id": "test-id-1"})
|
||||
bus.resolve_correlation(msg)
|
||||
assert future.done()
|
||||
assert future.result() == "hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_correlation_no_match_is_noop(bus):
|
||||
future = bus.register_correlation("test-id-1")
|
||||
msg = OutboundMessage(channel="hook", chat_id="test", content="hello", metadata={"correlation_id": "other-id"})
|
||||
bus.resolve_correlation(msg)
|
||||
assert not future.done()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_correlation_no_metadata_is_noop(bus):
|
||||
future = bus.register_correlation("test-id-1")
|
||||
msg = OutboundMessage(channel="hook", chat_id="test", content="hello")
|
||||
bus.resolve_correlation(msg)
|
||||
assert not future.done()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_correlation_cleans_up_store(bus):
|
||||
future = bus.register_correlation("test-id-1")
|
||||
msg = OutboundMessage(channel="hook", chat_id="test", content="hello", metadata={"correlation_id": "test-id-1"})
|
||||
bus.resolve_correlation(msg)
|
||||
assert "test-id-1" not in bus._correlation_store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_correlation(bus):
|
||||
future = bus.register_correlation("test-id-1")
|
||||
bus.cancel_correlation("test-id-1")
|
||||
assert "test-id-1" not in bus._correlation_store
|
||||
assert future.cancelled()
|
||||
Reference in New Issue
Block a user