feat: dynamic Opus/Sonnet model switching based on rolling quota
Implement intelligent model selection to manage 7-day Opus quota burn rate: - Add _select_model_based_on_quota() method to AgentLoop - Reads rate limit data from memory/rate_limits.json - Calculates expected vs actual quota usage (100%/168h = 0.595% per hour) - If actual > expected × 1.17 (17% overage), downgrades to Sonnet - If actual ≤ expected, uses Opus - Caches decision for 5 minutes to minimize file I/O - Add /quota slash command to display real-time quota status - Shows current usage vs expected usage - Shows hours until weekly reset - Shows selected model and burn rate multiplier - Main agent now calls _select_model_based_on_quota() before each conversation - Heartbeat subagent unaffected (explicitly uses claude-sonnet-4-20250514) This replaces the wrong approach from PR #5 which throttled heartbeat frequency instead of switching the main agent's model. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
+383
-354
@@ -1,40 +1,33 @@
|
||||
"""Agent loop: the core processing engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from contextlib import AsyncExitStack
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.config.schema import ChannelsConfig, ExecToolConfig
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
|
||||
|
||||
It:
|
||||
1. Receives messages from the bus
|
||||
2. Builds context with history, memory, skills
|
||||
@@ -42,42 +35,34 @@ class AgentLoop:
|
||||
4. Executes tool calls
|
||||
5. Sends responses back
|
||||
"""
|
||||
|
||||
_TOOL_RESULT_MAX_CHARS = 500
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bus: MessageBus,
|
||||
provider: LLMProvider,
|
||||
workspace: Path,
|
||||
model: str | None = None,
|
||||
max_iterations: int = 40,
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
memory_window: int = 100,
|
||||
max_iterations: int = 20,
|
||||
memory_window: int = 50,
|
||||
brave_api_key: str | None = None,
|
||||
exec_config: ExecToolConfig | None = None,
|
||||
cron_service: CronService | None = None,
|
||||
exec_config: "ExecToolConfig | None" = None,
|
||||
cron_service: "CronService | None" = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
session_manager: SessionManager | None = None,
|
||||
mcp_servers: dict | None = None,
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
from nanobot.cron.service import CronService
|
||||
self.bus = bus
|
||||
self.channels_config = channels_config
|
||||
self.provider = provider
|
||||
self.workspace = workspace
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_iterations = max_iterations
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.memory_window = memory_window
|
||||
self.brave_api_key = brave_api_key
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.cron_service = cron_service
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
|
||||
|
||||
self.context = ContextBuilder(workspace)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
self.tools = ToolRegistry()
|
||||
@@ -86,96 +71,238 @@ class AgentLoop:
|
||||
workspace=workspace,
|
||||
bus=bus,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
brave_api_key=brave_api_key,
|
||||
exec_config=self.exec_config,
|
||||
restrict_to_workspace=restrict_to_workspace,
|
||||
)
|
||||
|
||||
self._running = False
|
||||
self._mcp_servers = mcp_servers or {}
|
||||
self._mcp_stack: AsyncExitStack | None = None
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
||||
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
||||
self._consolidation_locks: dict[str, asyncio.Lock] = {}
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
self._processing_lock = asyncio.Lock()
|
||||
self._quota_cache: dict[str, Any] = {} # {model: str, cached_at: float}
|
||||
self._quota_cache_ttl: float = 300.0 # 5 minutes
|
||||
self._register_default_tools()
|
||||
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
"""Register the default set of tools."""
|
||||
# File tools (restrict to workspace if configured)
|
||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||
for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
self.tools.register(ReadFileTool(allowed_dir=allowed_dir))
|
||||
self.tools.register(WriteFileTool(allowed_dir=allowed_dir))
|
||||
self.tools.register(EditFileTool(allowed_dir=allowed_dir))
|
||||
self.tools.register(ListDirTool(allowed_dir=allowed_dir))
|
||||
|
||||
# Shell tool
|
||||
self.tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
path_append=self.exec_config.path_append,
|
||||
))
|
||||
|
||||
# Web tools
|
||||
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||
self.tools.register(WebFetchTool())
|
||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||
self.tools.register(SpawnTool(manager=self.subagents))
|
||||
|
||||
# Message tool
|
||||
message_tool = MessageTool(send_callback=self.bus.publish_outbound)
|
||||
self.tools.register(message_tool)
|
||||
|
||||
# Spawn tool (for subagents)
|
||||
spawn_tool = SpawnTool(manager=self.subagents)
|
||||
self.tools.register(spawn_tool)
|
||||
|
||||
# Cron tool (for scheduling)
|
||||
if self.cron_service:
|
||||
self.tools.register(CronTool(self.cron_service))
|
||||
|
||||
async def _connect_mcp(self) -> None:
|
||||
"""Connect to configured MCP servers (one-time, lazy)."""
|
||||
if self._mcp_connected or self._mcp_connecting or not self._mcp_servers:
|
||||
return
|
||||
self._mcp_connecting = True
|
||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||
try:
|
||||
self._mcp_stack = AsyncExitStack()
|
||||
await self._mcp_stack.__aenter__()
|
||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
||||
self._mcp_connected = True
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||
if self._mcp_stack:
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, processing messages from the bus."""
|
||||
self._running = True
|
||||
logger.info("Agent loop started")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# Wait for next message
|
||||
msg = await asyncio.wait_for(
|
||||
self.bus.consume_inbound(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
# Process it
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._mcp_stack = None
|
||||
finally:
|
||||
self._mcp_connecting = False
|
||||
response = await self._process_message(msg)
|
||||
if response:
|
||||
await self.bus.publish_outbound(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
# Send error response
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=f"Sorry, I encountered an error: {str(e)}"
|
||||
))
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the agent loop."""
|
||||
self._running = False
|
||||
logger.info("Agent loop stopping")
|
||||
|
||||
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
"""Update context for all tools that need routing info."""
|
||||
for name in ("message", "spawn", "cron"):
|
||||
if tool := self.tools.get(name):
|
||||
if hasattr(tool, "set_context"):
|
||||
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
|
||||
def _select_model_based_on_quota(self) -> str:
|
||||
"""Select Opus or Sonnet based on rolling weekly quota burn rate."""
|
||||
# Check cache
|
||||
now = time.time()
|
||||
if self._quota_cache and (now - self._quota_cache.get("cached_at", 0)) < self._quota_cache_ttl:
|
||||
return self._quota_cache["model"]
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
"""Remove <think>…</think> blocks that some models embed in content."""
|
||||
if not text:
|
||||
return None
|
||||
return re.sub(r"<think>[\s\S]*?</think>", "", text).strip() or None
|
||||
# Default models
|
||||
OPUS = "claude-opus-4-6"
|
||||
SONNET = "claude-sonnet-4-5"
|
||||
TOLERANCE = 1.17 # 17% overage triggers downgrade
|
||||
|
||||
@staticmethod
|
||||
def _tool_hint(tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hint, e.g. 'web_search("query")'."""
|
||||
def _fmt(tc):
|
||||
val = next(iter(tc.arguments.values()), None) if tc.arguments else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||
# Read rate limits
|
||||
rate_limits_path = self.workspace / "memory" / "rate_limits.json"
|
||||
if not rate_limits_path.exists():
|
||||
logger.warning("rate_limits.json not found, defaulting to Sonnet")
|
||||
return SONNET
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""Run the agent iteration loop. Returns (final_content, tools_used, messages)."""
|
||||
messages = initial_messages
|
||||
try:
|
||||
with open(rate_limits_path) as f:
|
||||
limits = json.load(f)
|
||||
|
||||
actual_usage = limits.get("weekly_all_models")
|
||||
weekly_reset = limits.get("weekly_reset")
|
||||
|
||||
if actual_usage is None or weekly_reset is None:
|
||||
logger.warning("Rate limit data incomplete, defaulting to Sonnet")
|
||||
return SONNET
|
||||
|
||||
# Calculate expected usage
|
||||
actual_pct = actual_usage * 100
|
||||
week_start = weekly_reset - (168 * 3600)
|
||||
hours_elapsed = max(0, min((now - week_start) / 3600, 168))
|
||||
expected_pct = (hours_elapsed / 168) * 100
|
||||
threshold = expected_pct * TOLERANCE
|
||||
|
||||
# Decision logic
|
||||
if actual_pct > threshold:
|
||||
model = SONNET
|
||||
logger.info(
|
||||
f"Quota: {actual_pct:.1f}% used, expected {expected_pct:.1f}%, "
|
||||
f"threshold {threshold:.1f}% → Sonnet"
|
||||
)
|
||||
else:
|
||||
model = OPUS
|
||||
logger.info(
|
||||
f"Quota: {actual_pct:.1f}% used, expected {expected_pct:.1f}%, "
|
||||
f"threshold {threshold:.1f}% → Opus"
|
||||
)
|
||||
|
||||
# Cache decision
|
||||
self._quota_cache = {"model": model, "cached_at": now}
|
||||
return model
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking quota: {e}, defaulting to Sonnet")
|
||||
return SONNET
|
||||
|
||||
def _get_quota_status(self) -> str:
|
||||
"""Return human-readable quota status."""
|
||||
rate_limits_path = self.workspace / "memory" / "rate_limits.json"
|
||||
if not rate_limits_path.exists():
|
||||
return "⚠️ No quota data available yet."
|
||||
|
||||
try:
|
||||
with open(rate_limits_path) as f:
|
||||
limits = json.load(f)
|
||||
|
||||
actual_pct = limits.get("weekly_all_models", 0) * 100
|
||||
reset_ts = limits.get("weekly_reset", 0)
|
||||
now = time.time()
|
||||
|
||||
hours_until_reset = (reset_ts - now) / 3600
|
||||
week_start = reset_ts - (168 * 3600)
|
||||
hours_elapsed = max(0, (now - week_start) / 3600)
|
||||
expected_pct = (hours_elapsed / 168) * 100
|
||||
|
||||
model = self._select_model_based_on_quota()
|
||||
|
||||
return f"""📊 Quota Status:
|
||||
• Used: {actual_pct:.1f}% (expected {expected_pct:.1f}%)
|
||||
• Resets in: {hours_until_reset:.1f}h
|
||||
• Current model: {model}
|
||||
• Burn rate: {actual_pct / max(expected_pct, 0.01):.2f}x target"""
|
||||
|
||||
except Exception as e:
|
||||
return f"⚠️ Error reading quota: {e}"
|
||||
|
||||
async def _process_message(self, msg: InboundMessage, session_key: str | None = None) -> OutboundMessage | None:
|
||||
"""
|
||||
Process a single inbound message.
|
||||
|
||||
Args:
|
||||
msg: The inbound message to process.
|
||||
session_key: Override session key (used by process_direct).
|
||||
|
||||
Returns:
|
||||
The response message, or None if no response needed.
|
||||
"""
|
||||
# Handle system messages (subagent announces)
|
||||
# The chat_id contains the original "channel:chat_id" to route back to
|
||||
if msg.channel == "system":
|
||||
return await self._process_system_message(msg)
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info(f"Processing message from {msg.channel}:{msg.sender_id}: {preview}")
|
||||
|
||||
# Get or create session
|
||||
key = session_key or msg.session_key
|
||||
session = self.sessions.get_or_create(key)
|
||||
|
||||
# Handle slash commands
|
||||
cmd = msg.content.strip().lower()
|
||||
if cmd == "/new":
|
||||
await self._consolidate_memory(session, archive_all=True)
|
||||
session.clear()
|
||||
self.sessions.save(session)
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="🐈 New session started. Memory consolidated.")
|
||||
if cmd == "/help":
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/help — Show available commands\n/quota — Show quota status")
|
||||
if cmd == "/quota":
|
||||
status = self._get_quota_status()
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content=status)
|
||||
|
||||
# Consolidate memory before processing if session is too large
|
||||
if len(session.messages) > self.memory_window:
|
||||
await self._consolidate_memory(session)
|
||||
|
||||
# Update tool contexts
|
||||
message_tool = self.tools.get("message")
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.set_context(msg.channel, msg.chat_id)
|
||||
|
||||
spawn_tool = self.tools.get("spawn")
|
||||
if isinstance(spawn_tool, SpawnTool):
|
||||
spawn_tool.set_context(msg.channel, msg.chat_id)
|
||||
|
||||
cron_tool = self.tools.get("cron")
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_tool.set_context(msg.channel, msg.chat_id)
|
||||
|
||||
# Build initial messages (use get_history for LLM-formatted messages)
|
||||
messages = self.context.build_messages(
|
||||
history=session.get_history(),
|
||||
current_message=msg.content,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
# Select model based on quota
|
||||
selected_model = self._select_model_based_on_quota()
|
||||
|
||||
# Agent loop
|
||||
iteration = 0
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
@@ -183,28 +310,23 @@ class AgentLoop:
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
# Call LLM
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=self.tools.get_definitions(),
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
model=selected_model
|
||||
)
|
||||
|
||||
|
||||
# Handle tool calls
|
||||
if response.has_tool_calls:
|
||||
if on_progress:
|
||||
clean = self._strip_think(response.content)
|
||||
if clean:
|
||||
await on_progress(clean)
|
||||
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||
|
||||
# Add assistant message with tool calls
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
|
||||
"arguments": json.dumps(tc.arguments) # Must be JSON string
|
||||
}
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
@@ -213,254 +335,147 @@ class AgentLoop:
|
||||
messages, response.content, tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
|
||||
# Execute tools
|
||||
for tool_call in response.tool_calls:
|
||||
tools_used.append(tool_call.name)
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
|
||||
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
# Interleaved CoT: reflect before next action
|
||||
messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
|
||||
else:
|
||||
clean = self._strip_think(response.content)
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, clean, reasoning_content=response.reasoning_content,
|
||||
)
|
||||
final_content = clean
|
||||
# No tool calls, we're done
|
||||
final_content = response.content
|
||||
break
|
||||
|
||||
if final_content is None and iteration >= self.max_iterations:
|
||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||
final_content = (
|
||||
f"I reached the maximum number of tool call iterations ({self.max_iterations}) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
return final_content, tools_used, messages
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||
self._running = True
|
||||
await self._connect_mcp()
|
||||
logger.info("Agent loop started")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if msg.content.strip().lower() == "/stop":
|
||||
await self._handle_stop(msg)
|
||||
else:
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
||||
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||
|
||||
async def _handle_stop(self, msg: InboundMessage) -> None:
|
||||
"""Cancel all active tasks and subagents for the session."""
|
||||
tasks = self._active_tasks.pop(msg.session_key, [])
|
||||
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
|
||||
for t in tasks:
|
||||
try:
|
||||
await t
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
|
||||
total = cancelled + sub_cancelled
|
||||
content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop."
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||
))
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message under the global lock."""
|
||||
async with self._processing_lock:
|
||||
try:
|
||||
response = await self._process_message(msg)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="", metadata=msg.metadata or {},
|
||||
))
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled for session {}", msg.session_key)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error processing message for session {}", msg.session_key)
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Sorry, I encountered an error.",
|
||||
))
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""Close MCP connections."""
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except (RuntimeError, BaseExceptionGroup):
|
||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
||||
self._mcp_stack = None
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the agent loop."""
|
||||
self._running = False
|
||||
logger.info("Agent loop stopping")
|
||||
|
||||
async def _process_message(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session_key: str | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a single inbound message and return the response."""
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
if msg.channel == "system":
|
||||
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
|
||||
else ("cli", msg.chat_id))
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=self.memory_window)
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
content=final_content or "Background task completed.")
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
key = session_key or msg.session_key
|
||||
session = self.sessions.get_or_create(key)
|
||||
|
||||
# Slash commands
|
||||
cmd = msg.content.strip().lower()
|
||||
if cmd == "/new":
|
||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||
self._consolidating.add(session.key)
|
||||
try:
|
||||
async with lock:
|
||||
snapshot = session.messages[session.last_consolidated:]
|
||||
if snapshot:
|
||||
temp = Session(key=session.key)
|
||||
temp.messages = list(snapshot)
|
||||
if not await self._consolidate_memory(temp, archive_all=True):
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("/new archival failed for {}", session.key)
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
finally:
|
||||
self._consolidating.discard(session.key)
|
||||
if not lock.locked():
|
||||
self._consolidation_locks.pop(session.key, None)
|
||||
|
||||
session.clear()
|
||||
self.sessions.save(session)
|
||||
self.sessions.invalidate(session.key)
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="New session started.")
|
||||
if cmd == "/help":
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
||||
|
||||
unconsolidated = len(session.messages) - session.last_consolidated
|
||||
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
||||
self._consolidating.add(session.key)
|
||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||
|
||||
async def _consolidate_and_unlock():
|
||||
try:
|
||||
async with lock:
|
||||
await self._consolidate_memory(session)
|
||||
finally:
|
||||
self._consolidating.discard(session.key)
|
||||
if not lock.locked():
|
||||
self._consolidation_locks.pop(session.key, None)
|
||||
_task = asyncio.current_task()
|
||||
if _task is not None:
|
||||
self._consolidation_tasks.discard(_task)
|
||||
|
||||
_task = asyncio.create_task(_consolidate_and_unlock())
|
||||
self._consolidation_tasks.add(_task)
|
||||
|
||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.start_turn()
|
||||
|
||||
history = session.get_history(max_messages=self.memory_window)
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_progress"] = True
|
||||
meta["_tool_hint"] = tool_hint
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
||||
))
|
||||
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
initial_messages, on_progress=on_progress or _bus_progress,
|
||||
)
|
||||
|
||||
|
||||
if final_content is None:
|
||||
final_content = "I've completed processing but have no response to give."
|
||||
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
return None
|
||||
|
||||
if iteration >= self.max_iterations:
|
||||
final_content = f"Reached {self.max_iterations} iterations without completion."
|
||||
else:
|
||||
final_content = "I've completed processing but have no response to give."
|
||||
|
||||
# Log response preview
|
||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
logger.info(f"Response to {msg.channel}:{msg.sender_id}: {preview}")
|
||||
|
||||
# Save to session (include tool names so consolidation sees what happened)
|
||||
session.add_message("user", msg.content)
|
||||
session.add_message("assistant", final_content,
|
||||
tools_used=tools_used if tools_used else None)
|
||||
self.sessions.save(session)
|
||||
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||
metadata=msg.metadata or {},
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=final_content,
|
||||
metadata=msg.metadata or {}, # Pass through for channel-specific needs (e.g. Slack thread_ts)
|
||||
)
|
||||
|
||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||
"""Save new-turn messages into session, truncating large tool results."""
|
||||
from datetime import datetime
|
||||
for m in messages[skip:]:
|
||||
entry = {k: v for k, v in m.items() if k != "reasoning_content"}
|
||||
role, content = entry.get("role"), entry.get("content")
|
||||
if role == "tool" and isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
|
||||
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||
elif role == "user":
|
||||
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
entry["content"] = [
|
||||
{"type": "text", "text": "[image]"} if (
|
||||
c.get("type") == "image_url"
|
||||
and c.get("image_url", {}).get("url", "").startswith("data:image/")
|
||||
) else c for c in content
|
||||
]
|
||||
entry.setdefault("timestamp", datetime.now().isoformat())
|
||||
session.messages.append(entry)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
|
||||
async def _process_system_message(self, msg: InboundMessage) -> OutboundMessage | None:
|
||||
"""
|
||||
Process a system message (e.g., subagent announce).
|
||||
|
||||
The chat_id field contains "original_channel:original_chat_id" to route
|
||||
the response back to the correct destination.
|
||||
"""
|
||||
logger.info(f"Processing system message from {msg.sender_id}")
|
||||
|
||||
# Parse origin from chat_id (format: "channel:chat_id")
|
||||
if ":" in msg.chat_id:
|
||||
parts = msg.chat_id.split(":", 1)
|
||||
origin_channel = parts[0]
|
||||
origin_chat_id = parts[1]
|
||||
else:
|
||||
# Fallback
|
||||
origin_channel = "cli"
|
||||
origin_chat_id = msg.chat_id
|
||||
|
||||
# Use the origin session for context
|
||||
session_key = f"{origin_channel}:{origin_chat_id}"
|
||||
session = self.sessions.get_or_create(session_key)
|
||||
|
||||
# Update tool contexts
|
||||
message_tool = self.tools.get("message")
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.set_context(origin_channel, origin_chat_id)
|
||||
|
||||
spawn_tool = self.tools.get("spawn")
|
||||
if isinstance(spawn_tool, SpawnTool):
|
||||
spawn_tool.set_context(origin_channel, origin_chat_id)
|
||||
|
||||
cron_tool = self.tools.get("cron")
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_tool.set_context(origin_channel, origin_chat_id)
|
||||
|
||||
# Build messages with the announce content
|
||||
messages = self.context.build_messages(
|
||||
history=session.get_history(),
|
||||
current_message=msg.content,
|
||||
channel=origin_channel,
|
||||
chat_id=origin_chat_id,
|
||||
)
|
||||
|
||||
# Agent loop (limited for announce handling)
|
||||
iteration = 0
|
||||
final_content = None
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=self.tools.get_definitions(),
|
||||
model=self.model
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments)
|
||||
}
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, response.content, tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info(f"Tool call: {tool_call.name}({args_str[:200]})")
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
# Interleaved CoT: reflect before next action
|
||||
messages.append({"role": "user", "content": "Reflect on the results and decide next steps."})
|
||||
else:
|
||||
final_content = response.content
|
||||
break
|
||||
|
||||
if final_content is None:
|
||||
final_content = "Background task completed."
|
||||
|
||||
# Save to session (mark as system message in history)
|
||||
session.add_message("user", f"[System: {msg.sender_id}] {msg.content}")
|
||||
session.add_message("assistant", final_content)
|
||||
self.sessions.save(session)
|
||||
|
||||
return OutboundMessage(
|
||||
channel=origin_channel,
|
||||
chat_id=origin_chat_id,
|
||||
content=final_content
|
||||
)
|
||||
|
||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> None:
|
||||
"""Consolidate old messages into MEMORY.md + HISTORY.md, then trim session."""
|
||||
if not session.messages:
|
||||
@@ -526,7 +541,6 @@ Respond with ONLY valid JSON, no markdown fences."""
|
||||
logger.info(f"Memory consolidation done, session trimmed to {len(session.messages)} messages")
|
||||
except Exception as e:
|
||||
logger.error(f"Memory consolidation failed: {e}")
|
||||
>>>>>>> 9136cca (Fix memory consolidation timeout: use Haiku without thinking)
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
@@ -534,10 +548,25 @@ Respond with ONLY valid JSON, no markdown fences."""
|
||||
session_key: str = "cli:direct",
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> str:
|
||||
"""Process a message directly (for CLI or cron usage)."""
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
||||
"""
|
||||
Process a message directly (for CLI or cron usage).
|
||||
|
||||
Args:
|
||||
content: The message content.
|
||||
session_key: Session identifier (overrides channel:chat_id for session lookup).
|
||||
channel: Source channel (for tool context routing).
|
||||
chat_id: Source chat ID (for tool context routing).
|
||||
|
||||
Returns:
|
||||
The agent's response.
|
||||
"""
|
||||
msg = InboundMessage(
|
||||
channel=channel,
|
||||
sender_id="user",
|
||||
chat_id=chat_id,
|
||||
content=content
|
||||
)
|
||||
|
||||
response = await self._process_message(msg, session_key=session_key)
|
||||
return response.content if response else ""
|
||||
|
||||
@@ -113,6 +113,7 @@ class TelegramChannel(BaseChannel):
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("stop", "Stop the current task"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
BotCommand("quota", "Show current quota status"),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -149,7 +150,8 @@ class TelegramChannel(BaseChannel):
|
||||
# Add command handlers
|
||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||
self._app.add_handler(CommandHandler("help", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("quota", self._forward_command))
|
||||
|
||||
# Add message handler for text, photos, voice, documents
|
||||
self._app.add_handler(
|
||||
|
||||
Reference in New Issue
Block a user