Port teamwork parity slice and add miniapp foundation
This commit is contained in:
@@ -2,6 +2,7 @@ __pycache__/
|
||||
archive/
|
||||
.omx/
|
||||
.clawd-agents/
|
||||
.clawd-state/
|
||||
# Claude Code local artifacts
|
||||
.claude/settings.local.json
|
||||
.claude/sessions/
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
target/
|
||||
.omx/
|
||||
.clawd-agents/
|
||||
.clawd-state/
|
||||
|
||||
Generated
+3
@@ -282,6 +282,7 @@ dependencies = [
|
||||
"channel-gateway-core",
|
||||
"futures-core",
|
||||
"reqwest",
|
||||
"runtime",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"subtle",
|
||||
@@ -293,6 +294,7 @@ name = "claw-telegram"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"api",
|
||||
"axum",
|
||||
"base64",
|
||||
"bollard",
|
||||
"channel-gateway-core",
|
||||
@@ -305,6 +307,7 @@ dependencies = [
|
||||
"sha2",
|
||||
"tokio",
|
||||
"tools",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -10,8 +10,11 @@ pub use manifest::{
|
||||
ProfileRecord, WorkerDefaults, WorkerSpec,
|
||||
};
|
||||
pub use protocol::{
|
||||
GeneratedFileDescriptor, InboundAttachment, TurnSource, WorkerApprovalDecision,
|
||||
WorkerStatusResponse, WorkerTurnAccepted, WorkerTurnEvent, WorkerTurnRequest,
|
||||
GeneratedFileDescriptor, InboundAttachment, TurnSource, WorkerAgentListResponse,
|
||||
WorkerApprovalDecision, WorkerBackgroundApprovalListResponse, WorkerMailboxMessageEvent,
|
||||
WorkerMailboxPendingResponse, WorkerMailboxSummaryResponse, WorkerStatusResponse,
|
||||
WorkerTaskListResponse, WorkerTaskSnapshotResponse, WorkerTeamCreatedEvent,
|
||||
WorkerTeamSnapshotResponse, WorkerTurnAccepted, WorkerTurnEvent, WorkerTurnRequest,
|
||||
};
|
||||
pub use runtime_host::{
|
||||
ApprovalDecision, ApprovalRequestPayload, ApprovalResponder, AttachmentKind, AttachmentRef,
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use runtime::{
|
||||
BackgroundApprovalRecord, MailboxSummary, RuntimeTaskRecord, TaskListRecord, TeamRecord,
|
||||
};
|
||||
|
||||
use crate::runtime_host::{ApprovalRequestPayload, AttachmentKind};
|
||||
|
||||
@@ -45,6 +48,24 @@ pub struct GeneratedFileDescriptor {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct WorkerTeamCreatedEvent {
|
||||
pub team: TeamRecord,
|
||||
pub task_list_id: String,
|
||||
pub team_file_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct WorkerMailboxMessageEvent {
|
||||
pub team_name: String,
|
||||
pub sender: String,
|
||||
pub count: usize,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub recipients: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub summary: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum WorkerTurnEvent {
|
||||
AssistantTextDelta {
|
||||
@@ -67,6 +88,29 @@ pub enum WorkerTurnEvent {
|
||||
AutoCompaction {
|
||||
removed_message_count: usize,
|
||||
},
|
||||
TaskCreated {
|
||||
task_list_id: String,
|
||||
task: TaskListRecord,
|
||||
},
|
||||
TaskUpdated {
|
||||
task_list_id: String,
|
||||
task: TaskListRecord,
|
||||
},
|
||||
TaskStopped {
|
||||
task: RuntimeTaskRecord,
|
||||
},
|
||||
AgentSpawned {
|
||||
agent: RuntimeTaskRecord,
|
||||
},
|
||||
TeamCreated {
|
||||
team: WorkerTeamCreatedEvent,
|
||||
},
|
||||
TeamDeleted {
|
||||
team_name: String,
|
||||
},
|
||||
MailboxMessage {
|
||||
message: WorkerMailboxMessageEvent,
|
||||
},
|
||||
Completed {
|
||||
final_text: String,
|
||||
iterations: usize,
|
||||
@@ -97,4 +141,49 @@ pub struct WorkerStatusResponse {
|
||||
pub permission_mode: String,
|
||||
pub default_cwd: String,
|
||||
pub busy: bool,
|
||||
#[serde(default)]
|
||||
pub task_list_id: String,
|
||||
#[serde(default)]
|
||||
pub active_team: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct WorkerTaskListResponse {
|
||||
pub task_list_id: String,
|
||||
pub tasks: Vec<TaskListRecord>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct WorkerTaskSnapshotResponse {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub task: Option<TaskListRecord>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub runtime_task: Option<RuntimeTaskRecord>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct WorkerTeamSnapshotResponse {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub team: Option<TeamRecord>,
|
||||
pub task_list_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct WorkerAgentListResponse {
|
||||
pub agents: Vec<RuntimeTaskRecord>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct WorkerMailboxSummaryResponse {
|
||||
pub mailbox: MailboxSummary,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct WorkerMailboxPendingResponse {
|
||||
pub messages: Vec<runtime::MailboxMessage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct WorkerBackgroundApprovalListResponse {
|
||||
pub approvals: Vec<BackgroundApprovalRecord>,
|
||||
}
|
||||
|
||||
@@ -25,7 +25,10 @@ use runtime::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
use tools::{GlobalToolRegistry, RuntimeToolDefinition};
|
||||
use tools::{
|
||||
with_subagent_approval_handler, GlobalToolRegistry, RuntimeToolDefinition,
|
||||
SubagentApprovalHandler,
|
||||
};
|
||||
|
||||
const DEFAULT_PROMPT_DATE: &str = "2026-03-31";
|
||||
const APPROVAL_POLL_INTERVAL: Duration = Duration::from_millis(250);
|
||||
@@ -227,9 +230,14 @@ impl RuntimeHost {
|
||||
cancel_flag.clone(),
|
||||
);
|
||||
let prompt = compose_user_input(prompt, &attachments);
|
||||
let summary = runtime
|
||||
.run_turn_with_observer(prompt, Some(&mut prompter), &mut observer)
|
||||
.map_err(HostError::Runtime)?;
|
||||
let subagent_handler: Arc<dyn SubagentApprovalHandler> =
|
||||
Arc::new(SubagentApprovalBridge {
|
||||
relay: prompter.relay.clone(),
|
||||
});
|
||||
let summary = with_subagent_approval_handler(subagent_handler, || {
|
||||
runtime.run_turn_with_observer(prompt, Some(&mut prompter), &mut observer)
|
||||
})
|
||||
.map_err(HostError::Runtime)?;
|
||||
let final_text = final_assistant_text(&summary);
|
||||
let generated_files = collect_generated_files(&summary);
|
||||
let _ = event_tx.send(RuntimeEvent::Completed {
|
||||
@@ -1270,6 +1278,14 @@ impl BridgeToolExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_tool_input_json(input: &str) -> Result<Value, ToolError> {
|
||||
if input.trim().is_empty() {
|
||||
return Ok(json!({}));
|
||||
}
|
||||
serde_json::from_str(input)
|
||||
.map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))
|
||||
}
|
||||
|
||||
impl ToolExecutor for BridgeToolExecutor {
|
||||
fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
|
||||
if self
|
||||
@@ -1281,8 +1297,7 @@ impl ToolExecutor for BridgeToolExecutor {
|
||||
"tool `{tool_name}` is not enabled by the current allowed-tools setting"
|
||||
)));
|
||||
}
|
||||
let value = serde_json::from_str(input)
|
||||
.map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?;
|
||||
let value = parse_tool_input_json(input)?;
|
||||
if tool_name == "ToolSearch" {
|
||||
self.execute_search_tool(value)
|
||||
} else if self.tool_registry.has_runtime_tool(tool_name) {
|
||||
@@ -1354,15 +1369,55 @@ impl TurnEventObserver for BridgeObserver {
|
||||
}
|
||||
|
||||
struct BridgePermissionPrompter {
|
||||
relay: Arc<ApprovalRelay>,
|
||||
}
|
||||
|
||||
impl BridgePermissionPrompter {
|
||||
fn new(
|
||||
tx: UnboundedSender<RuntimeEvent>,
|
||||
current_mode: PermissionMode,
|
||||
timeout: Duration,
|
||||
approval_state: Arc<Mutex<SessionApprovalState>>,
|
||||
cancel_flag: Arc<AtomicBool>,
|
||||
) -> Self {
|
||||
Self {
|
||||
relay: Arc::new(ApprovalRelay::new(
|
||||
tx,
|
||||
current_mode,
|
||||
timeout,
|
||||
approval_state,
|
||||
cancel_flag,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PermissionPrompter for BridgePermissionPrompter {
|
||||
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
|
||||
self.relay.decide(request)
|
||||
}
|
||||
}
|
||||
|
||||
struct SubagentApprovalBridge {
|
||||
relay: Arc<ApprovalRelay>,
|
||||
}
|
||||
|
||||
impl SubagentApprovalHandler for SubagentApprovalBridge {
|
||||
fn decide(&self, request: &PermissionRequest) -> PermissionPromptDecision {
|
||||
self.relay.decide(request)
|
||||
}
|
||||
}
|
||||
|
||||
struct ApprovalRelay {
|
||||
tx: UnboundedSender<RuntimeEvent>,
|
||||
current_mode: PermissionMode,
|
||||
timeout: Duration,
|
||||
approval_state: Arc<Mutex<SessionApprovalState>>,
|
||||
cancel_flag: Arc<AtomicBool>,
|
||||
next_approval_id: u64,
|
||||
next_approval_id: std::sync::atomic::AtomicU64,
|
||||
}
|
||||
|
||||
impl BridgePermissionPrompter {
|
||||
impl ApprovalRelay {
|
||||
fn new(
|
||||
tx: UnboundedSender<RuntimeEvent>,
|
||||
current_mode: PermissionMode,
|
||||
@@ -1376,13 +1431,11 @@ impl BridgePermissionPrompter {
|
||||
timeout,
|
||||
approval_state,
|
||||
cancel_flag,
|
||||
next_approval_id: 0,
|
||||
next_approval_id: std::sync::atomic::AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PermissionPrompter for BridgePermissionPrompter {
|
||||
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
|
||||
fn decide(&self, request: &PermissionRequest) -> PermissionPromptDecision {
|
||||
if self
|
||||
.approval_state
|
||||
.lock()
|
||||
@@ -1398,8 +1451,10 @@ impl PermissionPrompter for BridgePermissionPrompter {
|
||||
};
|
||||
}
|
||||
|
||||
self.next_approval_id += 1;
|
||||
let approval_id = format!("approval-{}", self.next_approval_id);
|
||||
let approval_id = format!(
|
||||
"approval-{}",
|
||||
self.next_approval_id.fetch_add(1, Ordering::SeqCst) + 1
|
||||
);
|
||||
let (decision_tx, decision_rx) = mpsc::channel();
|
||||
let request_payload = ApprovalRequestPayload {
|
||||
approval_id,
|
||||
@@ -1431,7 +1486,17 @@ impl PermissionPrompter for BridgePermissionPrompter {
|
||||
}
|
||||
let wait_for = remaining.min(APPROVAL_POLL_INTERVAL);
|
||||
match decision_rx.recv_timeout(wait_for) {
|
||||
Ok(ApprovalDecision::ApproveOnce) => {
|
||||
Ok(ApprovalDecision::ApproveOnce) => return PermissionPromptDecision::Allow,
|
||||
Ok(ApprovalDecision::ApproveToolForSession) => {
|
||||
if let Ok(mut state) = self.approval_state.lock() {
|
||||
state.allow_tool(request.tool_name.clone());
|
||||
}
|
||||
return PermissionPromptDecision::Allow;
|
||||
}
|
||||
Ok(ApprovalDecision::ApproveAllForSession) => {
|
||||
if let Ok(mut state) = self.approval_state.lock() {
|
||||
state.allow_all();
|
||||
}
|
||||
return PermissionPromptDecision::Allow;
|
||||
}
|
||||
Ok(ApprovalDecision::ApproveToolForSession) => {
|
||||
@@ -1608,8 +1673,12 @@ mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use runtime::{ContentBlock, ConversationMessage};
|
||||
use serde_json::json;
|
||||
|
||||
use super::{collect_generated_files, compose_user_input, AttachmentKind, AttachmentRef};
|
||||
use super::{
|
||||
collect_generated_files, compose_user_input, parse_tool_input_json, AttachmentKind,
|
||||
AttachmentRef,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn compose_user_input_includes_attachment_paths() {
|
||||
@@ -1666,4 +1735,13 @@ mod tests {
|
||||
};
|
||||
assert!(collect_generated_files(&summary).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_input_json_treats_blank_input_as_empty_object() {
|
||||
let value = parse_tool_input_json("").expect("blank input should parse");
|
||||
assert_eq!(value, json!({}));
|
||||
|
||||
let whitespace = parse_tool_input_json(" \n\t").expect("whitespace should parse");
|
||||
assert_eq!(whitespace, json!({}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ axum = { version = "0.7", features = ["multipart"] }
|
||||
base64 = "0.22"
|
||||
channel-gateway-core = { path = "../channel-gateway-core" }
|
||||
futures-core = "0.3"
|
||||
runtime = { path = "../runtime" }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json.workspace = true
|
||||
subtle = "2"
|
||||
|
||||
@@ -16,9 +16,17 @@ use base64::Engine as _;
|
||||
use channel_gateway_core::{
|
||||
ApprovalDecision, ApprovalResponder, AttachmentRef, GeneratedFileDescriptor, HostError,
|
||||
RuntimeEvent, RuntimeHost, RuntimeHostConfig, SessionApprovalState, WorkerApprovalDecision,
|
||||
WorkerStatusResponse, WorkerTurnAccepted, WorkerTurnEvent, WorkerTurnRequest,
|
||||
WorkerAgentListResponse, WorkerBackgroundApprovalListResponse, WorkerMailboxMessageEvent,
|
||||
WorkerMailboxPendingResponse, WorkerMailboxSummaryResponse, WorkerStatusResponse,
|
||||
WorkerTaskListResponse, WorkerTaskSnapshotResponse, WorkerTeamCreatedEvent,
|
||||
WorkerTeamSnapshotResponse, WorkerTurnAccepted, WorkerTurnEvent, WorkerTurnRequest,
|
||||
};
|
||||
use runtime::{
|
||||
current_task_list_id, BackgroundApprovalDecision, BackgroundApprovalStore, RuntimeTaskKind,
|
||||
RuntimeTaskRecord, RuntimeTaskStore, TaskListStore, TeamStore,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use subtle::ConstantTimeEq;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::{broadcast, mpsc, Mutex as AsyncMutex};
|
||||
@@ -41,6 +49,28 @@ fn app_router(config: WorkerConfig, runtime: Arc<dyn WorkerRuntime>) -> Router {
|
||||
.route("/healthz", get(health))
|
||||
.route("/v1/status", get(status))
|
||||
.route("/v1/session/reset", post(reset_session))
|
||||
.route("/v1/tasks", get(list_tasks))
|
||||
.route("/v1/tasks/:task_id", get(get_task))
|
||||
.route("/v1/tasks/:task_id/stop", post(stop_task))
|
||||
.route("/v1/team", get(get_team))
|
||||
.route("/v1/agents", get(list_agents))
|
||||
.route("/v1/agents/:agent_id", get(get_agent))
|
||||
.route("/v1/agents/:agent_id/notified", post(mark_agent_notified))
|
||||
.route("/v1/background-approvals", get(list_background_approvals))
|
||||
.route(
|
||||
"/v1/background-approvals/:approval_id",
|
||||
post(post_background_approval),
|
||||
)
|
||||
.route(
|
||||
"/v1/background-approvals/:approval_id/notified",
|
||||
post(mark_background_approval_notified),
|
||||
)
|
||||
.route("/v1/mailbox", get(get_mailbox))
|
||||
.route("/v1/mailbox/pending/:recipient", get(get_pending_mailbox_messages))
|
||||
.route(
|
||||
"/v1/mailbox/:recipient/notified",
|
||||
post(mark_mailbox_messages_notified),
|
||||
)
|
||||
.route("/v1/turns", post(post_turn))
|
||||
.route("/v1/turns/:turn_id/events", get(stream_events))
|
||||
.route("/v1/turns/:turn_id/approval", post(post_approval))
|
||||
@@ -207,6 +237,11 @@ struct ApprovalRequest {
|
||||
decision: WorkerApprovalDecision,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MailboxNotifiedRequest {
|
||||
ids: Vec<String>,
|
||||
}
|
||||
|
||||
async fn health() -> impl IntoResponse {
|
||||
(StatusCode::OK, "ok")
|
||||
}
|
||||
@@ -229,9 +264,225 @@ async fn status(
|
||||
permission_mode: state.config.permission_mode.as_str().to_string(),
|
||||
default_cwd: state.config.default_cwd.display().to_string(),
|
||||
busy,
|
||||
task_list_id: current_task_list_id().unwrap_or_else(|_| state.config.profile_id.clone()),
|
||||
active_team: TeamStore::new()
|
||||
.current_team()
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(|team| team.team_name),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn list_tasks(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<WorkerTaskListResponse>, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
let store = TaskListStore::current().map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let tasks = store.list(false).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
Ok(Json(WorkerTaskListResponse {
|
||||
task_list_id: store.task_list_id().to_string(),
|
||||
tasks,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn get_task(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
AxumPath(task_id): AxumPath<String>,
|
||||
) -> Result<Json<WorkerTaskSnapshotResponse>, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
let task = TaskListStore::current()
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.get(&task_id)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let runtime_task = RuntimeTaskStore::new()
|
||||
.get(&task_id)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
if task.is_none() && runtime_task.is_none() {
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
Ok(Json(WorkerTaskSnapshotResponse { task, runtime_task }))
|
||||
}
|
||||
|
||||
async fn stop_task(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
AxumPath(task_id): AxumPath<String>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
match RuntimeTaskStore::new().stop(&task_id) {
|
||||
Ok(Some(_)) => Ok(StatusCode::ACCEPTED),
|
||||
Ok(None) => Err(StatusCode::NOT_FOUND),
|
||||
Err(_) => Err(StatusCode::BAD_REQUEST),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_team(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<WorkerTeamSnapshotResponse>, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
let task_list_id = current_task_list_id().unwrap_or_else(|_| state.config.profile_id.clone());
|
||||
let team = TeamStore::new()
|
||||
.current_team()
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
Ok(Json(WorkerTeamSnapshotResponse { team, task_list_id }))
|
||||
}
|
||||
|
||||
async fn list_agents(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<WorkerAgentListResponse>, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
let agents = RuntimeTaskStore::new()
|
||||
.list()
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.into_iter()
|
||||
.filter(|task| task.kind == RuntimeTaskKind::Agent)
|
||||
.collect();
|
||||
Ok(Json(WorkerAgentListResponse { agents }))
|
||||
}
|
||||
|
||||
async fn get_agent(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
AxumPath(agent_id): AxumPath<String>,
|
||||
) -> Result<Json<RuntimeTaskRecord>, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
let task = RuntimeTaskStore::new()
|
||||
.get(&agent_id)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.filter(|task| task.kind == RuntimeTaskKind::Agent)
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
Ok(Json(task))
|
||||
}
|
||||
|
||||
async fn mark_agent_notified(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
AxumPath(agent_id): AxumPath<String>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
RuntimeTaskStore::new()
|
||||
.mark_notified(&agent_id)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.filter(|task| task.kind == RuntimeTaskKind::Agent)
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
Ok(StatusCode::ACCEPTED)
|
||||
}
|
||||
|
||||
async fn list_background_approvals(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<WorkerBackgroundApprovalListResponse>, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
let approvals = BackgroundApprovalStore::new()
|
||||
.list_pending()
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
Ok(Json(WorkerBackgroundApprovalListResponse { approvals }))
|
||||
}
|
||||
|
||||
async fn post_background_approval(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
AxumPath(approval_id): AxumPath<String>,
|
||||
Json(request): Json<ApprovalRequest>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
if request.approval_id != approval_id {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
let store = BackgroundApprovalStore::new();
|
||||
let record = store
|
||||
.get(&approval_id)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.filter(|record| record.decision.is_none())
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
apply_session_scope_if_needed(
|
||||
&state.config,
|
||||
&state.approval_state,
|
||||
&record.tool_name,
|
||||
&request.decision,
|
||||
);
|
||||
store
|
||||
.resolve(&approval_id, map_background_approval_decision(request.decision))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
Ok(StatusCode::ACCEPTED)
|
||||
}
|
||||
|
||||
async fn mark_background_approval_notified(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
AxumPath(approval_id): AxumPath<String>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
BackgroundApprovalStore::new()
|
||||
.mark_notified(&approval_id)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.filter(|record| record.decision.is_none())
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
Ok(StatusCode::ACCEPTED)
|
||||
}
|
||||
|
||||
async fn get_mailbox(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<WorkerMailboxSummaryResponse>, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
let mailbox = match TeamStore::new()
|
||||
.current_team()
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
{
|
||||
Some(team) => TeamStore::new()
|
||||
.mailbox_summary(&team.team_name, 20)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
|
||||
None => runtime::MailboxSummary {
|
||||
team_name: None,
|
||||
recent_messages: Vec::new(),
|
||||
},
|
||||
};
|
||||
Ok(Json(WorkerMailboxSummaryResponse { mailbox }))
|
||||
}
|
||||
|
||||
async fn get_pending_mailbox_messages(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
AxumPath(recipient): AxumPath<String>,
|
||||
) -> Result<Json<WorkerMailboxPendingResponse>, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
let messages = match TeamStore::new()
|
||||
.current_team()
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
{
|
||||
Some(team) => TeamStore::new()
|
||||
.pending_messages(&team.team_name, &recipient, 20)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
|
||||
None => Vec::new(),
|
||||
};
|
||||
Ok(Json(WorkerMailboxPendingResponse { messages }))
|
||||
}
|
||||
|
||||
async fn mark_mailbox_messages_notified(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
AxumPath(recipient): AxumPath<String>,
|
||||
Json(request): Json<MailboxNotifiedRequest>,
|
||||
) -> Result<StatusCode, StatusCode> {
|
||||
authorize(&headers, &state.config.auth_token)?;
|
||||
let Some(team) = TeamStore::new()
|
||||
.current_team()
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
else {
|
||||
return Ok(StatusCode::ACCEPTED);
|
||||
};
|
||||
TeamStore::new()
|
||||
.mark_messages_notified(&team.team_name, &recipient, &request.ids)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
Ok(StatusCode::ACCEPTED)
|
||||
}
|
||||
|
||||
async fn reset_session(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
@@ -326,6 +577,7 @@ async fn post_turn(
|
||||
output,
|
||||
is_error,
|
||||
} => {
|
||||
let derived_events = derive_tool_events(&tool_name, &output, is_error);
|
||||
turn_state_for_events
|
||||
.push_event(WorkerTurnEvent::ToolResult {
|
||||
tool_use_id,
|
||||
@@ -334,6 +586,9 @@ async fn post_turn(
|
||||
is_error,
|
||||
})
|
||||
.await;
|
||||
for derived in derived_events {
|
||||
turn_state_for_events.push_event(derived).await;
|
||||
}
|
||||
}
|
||||
RuntimeEvent::ApprovalRequested { request, responder } => {
|
||||
turn_state_for_events
|
||||
@@ -614,6 +869,165 @@ fn map_approval_decision(decision: WorkerApprovalDecision) -> ApprovalDecision {
|
||||
}
|
||||
}
|
||||
|
||||
fn map_background_approval_decision(
|
||||
decision: WorkerApprovalDecision,
|
||||
) -> BackgroundApprovalDecision {
|
||||
match decision {
|
||||
WorkerApprovalDecision::ApproveOnce => BackgroundApprovalDecision::ApproveOnce,
|
||||
WorkerApprovalDecision::ApproveToolForSession => {
|
||||
BackgroundApprovalDecision::ApproveToolForSession
|
||||
}
|
||||
WorkerApprovalDecision::ApproveAllForSession => {
|
||||
BackgroundApprovalDecision::ApproveAllForSession
|
||||
}
|
||||
WorkerApprovalDecision::Deny { reason } => BackgroundApprovalDecision::Deny { reason },
|
||||
WorkerApprovalDecision::CancelTurn => BackgroundApprovalDecision::CancelTurn,
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_session_scope_if_needed(
|
||||
config: &WorkerConfig,
|
||||
approval_state: &Arc<StdMutex<SessionApprovalState>>,
|
||||
tool_name: &str,
|
||||
decision: &WorkerApprovalDecision,
|
||||
) {
|
||||
let updated = if let Ok(mut state) = approval_state.lock() {
|
||||
match decision {
|
||||
WorkerApprovalDecision::ApproveToolForSession => {
|
||||
state.allow_tool(tool_name.to_string());
|
||||
true
|
||||
}
|
||||
WorkerApprovalDecision::ApproveAllForSession => {
|
||||
state.allow_all();
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
if updated {
|
||||
let _ = persist_session_approval_state(config, approval_state);
|
||||
}
|
||||
}
|
||||
|
||||
fn derive_tool_events(tool_name: &str, output: &str, is_error: bool) -> Vec<WorkerTurnEvent> {
|
||||
if is_error {
|
||||
return Vec::new();
|
||||
}
|
||||
let Ok(value) = serde_json::from_str::<Value>(output) else {
|
||||
return Vec::new();
|
||||
};
|
||||
match tool_name {
|
||||
"TaskCreate" => derive_task_created_event(&value).into_iter().collect(),
|
||||
"TaskUpdate" => derive_task_updated_event(&value).into_iter().collect(),
|
||||
"TaskStop" => derive_task_stopped_event(&value).into_iter().collect(),
|
||||
"Agent" => derive_agent_spawned_event(&value).into_iter().collect(),
|
||||
"TeamCreate" => derive_team_created_event(&value).into_iter().collect(),
|
||||
"TeamDelete" => derive_team_deleted_event(&value).into_iter().collect(),
|
||||
"SendMessage" => derive_mailbox_message_event(&value).into_iter().collect(),
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn derive_task_created_event(value: &Value) -> Option<WorkerTurnEvent> {
|
||||
let task_id = value.get("task")?.get("id")?.as_str()?;
|
||||
let store = TaskListStore::current().ok()?;
|
||||
let task = store.get(task_id).ok()??;
|
||||
Some(WorkerTurnEvent::TaskCreated {
|
||||
task_list_id: store.task_list_id().to_string(),
|
||||
task,
|
||||
})
|
||||
}
|
||||
|
||||
fn derive_task_updated_event(value: &Value) -> Option<WorkerTurnEvent> {
|
||||
let task_id = value.get("id")?.as_str()?;
|
||||
let store = TaskListStore::current().ok()?;
|
||||
let task = store.get(task_id).ok()??;
|
||||
Some(WorkerTurnEvent::TaskUpdated {
|
||||
task_list_id: store.task_list_id().to_string(),
|
||||
task,
|
||||
})
|
||||
}
|
||||
|
||||
fn derive_task_stopped_event(value: &Value) -> Option<WorkerTurnEvent> {
|
||||
let task_id = value.get("task_id")?.as_str()?;
|
||||
let task = RuntimeTaskStore::new().get(task_id).ok()??;
|
||||
Some(WorkerTurnEvent::TaskStopped { task })
|
||||
}
|
||||
|
||||
fn derive_agent_spawned_event(value: &Value) -> Option<WorkerTurnEvent> {
|
||||
let task_id = value
|
||||
.get("taskId")
|
||||
.and_then(Value::as_str)
|
||||
.or_else(|| value.get("agentId").and_then(Value::as_str))?;
|
||||
let task = RuntimeTaskStore::new().get(task_id).ok()??;
|
||||
Some(WorkerTurnEvent::AgentSpawned { agent: task })
|
||||
}
|
||||
|
||||
fn derive_team_created_event(value: &Value) -> Option<WorkerTurnEvent> {
|
||||
let team_name = value.get("team_name")?.as_str()?;
|
||||
let team = TeamStore::new().get_team(team_name).ok()??;
|
||||
Some(WorkerTurnEvent::TeamCreated {
|
||||
team: WorkerTeamCreatedEvent {
|
||||
task_list_id: current_task_list_id().ok()?,
|
||||
team_file_path: value
|
||||
.get("team_file_path")
|
||||
.and_then(Value::as_str)
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
team,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn derive_team_deleted_event(value: &Value) -> Option<WorkerTurnEvent> {
|
||||
Some(WorkerTurnEvent::TeamDeleted {
|
||||
team_name: value.get("team_name")?.as_str()?.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn derive_mailbox_message_event(value: &Value) -> Option<WorkerTurnEvent> {
|
||||
let recipients = value
|
||||
.get("delivered")
|
||||
.and_then(Value::as_array)
|
||||
.map(|entries| {
|
||||
entries
|
||||
.iter()
|
||||
.filter_map(|entry| {
|
||||
entry
|
||||
.get("recipient")
|
||||
.and_then(Value::as_str)
|
||||
.map(ToString::to_string)
|
||||
.or_else(|| {
|
||||
entry.get("envelope")
|
||||
.and_then(|env| env.get("to"))
|
||||
.and_then(Value::as_str)
|
||||
.map(ToString::to_string)
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let summary = value
|
||||
.get("delivered")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|entries| entries.first())
|
||||
.and_then(|entry| entry.get("envelope"))
|
||||
.and_then(|env| env.get("summary"))
|
||||
.and_then(Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
Some(WorkerTurnEvent::MailboxMessage {
|
||||
message: WorkerMailboxMessageEvent {
|
||||
team_name: value.get("team_name")?.as_str()?.to_string(),
|
||||
sender: value.get("sender")?.as_str()?.to_string(),
|
||||
count: value.get("count")?.as_u64()? as usize,
|
||||
recipients,
|
||||
summary,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn approval_state_path(config: &WorkerConfig) -> PathBuf {
|
||||
config.state_root.join("approval-session.json")
|
||||
}
|
||||
@@ -744,8 +1158,18 @@ impl From<HostError> for ServerError {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::{Mutex, MutexGuard, OnceLock};
|
||||
|
||||
use super::*;
|
||||
use channel_gateway_core::{RuntimeEvent, TurnSource};
|
||||
use serde_json::json;
|
||||
|
||||
fn env_lock() -> MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.expect("env lock should not be poisoned")
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MockRuntime {
|
||||
@@ -909,4 +1333,74 @@ mod tests {
|
||||
ApprovalDecision::ApproveAllForSession
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derive_tool_events_emits_task_created_notice() {
|
||||
let _lock = env_lock();
|
||||
let root = std::env::temp_dir().join(format!("claw-worker-derived-task-{}", next_turn_id()));
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
std::env::set_var("CLAW_WORKER_STATE_ROOT", &root);
|
||||
std::env::set_var("CLAW_WORKER_PROFILE_ID", "makar");
|
||||
|
||||
let task = TaskListStore::current()
|
||||
.expect("task store")
|
||||
.create(
|
||||
"Ship parity slice".to_string(),
|
||||
"Verify derived events".to_string(),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.expect("task should persist");
|
||||
|
||||
let events = derive_tool_events(
|
||||
"TaskCreate",
|
||||
&json!({
|
||||
"task": { "id": task.id },
|
||||
"task_list_id": "makar"
|
||||
})
|
||||
.to_string(),
|
||||
false,
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
events.first(),
|
||||
Some(WorkerTurnEvent::TaskCreated { task, .. }) if task.subject == "Ship parity slice"
|
||||
));
|
||||
|
||||
std::env::remove_var("CLAW_WORKER_STATE_ROOT");
|
||||
std::env::remove_var("CLAW_WORKER_PROFILE_ID");
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derive_tool_events_emits_team_created_notice() {
|
||||
let _lock = env_lock();
|
||||
let root = std::env::temp_dir().join(format!("claw-worker-derived-team-{}", next_turn_id()));
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
std::env::set_var("CLAW_WORKER_STATE_ROOT", &root);
|
||||
std::env::set_var("CLAW_WORKER_PROFILE_ID", "makar");
|
||||
|
||||
let team = TeamStore::new()
|
||||
.create_team("alpha", Some("test team".to_string()), None)
|
||||
.expect("team should persist");
|
||||
|
||||
let events = derive_tool_events(
|
||||
"TeamCreate",
|
||||
&json!({
|
||||
"team_name": team.team_name,
|
||||
"team_file_path": root.join("teams").join("alpha").join("config.json").display().to_string(),
|
||||
})
|
||||
.to_string(),
|
||||
false,
|
||||
);
|
||||
|
||||
assert!(matches!(
|
||||
events.first(),
|
||||
Some(WorkerTurnEvent::TeamCreated { team }) if team.team.team_name == "alpha"
|
||||
));
|
||||
|
||||
std::env::remove_var("CLAW_WORKER_STATE_ROOT");
|
||||
std::env::remove_var("CLAW_WORKER_PROFILE_ID");
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ license.workspace = true
|
||||
publish.workspace = true
|
||||
|
||||
[dependencies]
|
||||
axum = "0.7"
|
||||
bollard = "0.17"
|
||||
channel-gateway-core = { path = "../channel-gateway-core" }
|
||||
futures-util = "0.3"
|
||||
@@ -17,8 +18,9 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "mul
|
||||
runtime = { path = "../runtime" }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json.workspace = true
|
||||
tokio = { version = "1", features = ["fs", "macros", "rt-multi-thread", "signal", "sync", "time"] }
|
||||
tokio = { version = "1", features = ["fs", "macros", "net", "rt-multi-thread", "signal", "sync", "time"] }
|
||||
tools = { path = "../tools" }
|
||||
url = "2"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
@@ -976,15 +976,21 @@ fn approval_keyboard(turn_id: &str, approval_id: &str) -> InlineKeyboardMarkup {
|
||||
inline_keyboard: vec![vec![
|
||||
InlineKeyboardButton {
|
||||
text: "Approve once".to_string(),
|
||||
callback_data: format!("cta:{turn_id}:{approval_id}:allow"),
|
||||
callback_data: Some(format!("cta:{turn_id}:{approval_id}:allow")),
|
||||
url: None,
|
||||
web_app: None,
|
||||
},
|
||||
InlineKeyboardButton {
|
||||
text: "Deny".to_string(),
|
||||
callback_data: format!("cta:{turn_id}:{approval_id}:deny"),
|
||||
callback_data: Some(format!("cta:{turn_id}:{approval_id}:deny")),
|
||||
url: None,
|
||||
web_app: None,
|
||||
},
|
||||
InlineKeyboardButton {
|
||||
text: "Cancel turn".to_string(),
|
||||
callback_data: format!("cta:{turn_id}:{approval_id}:cancel"),
|
||||
callback_data: Some(format!("cta:{turn_id}:{approval_id}:cancel")),
|
||||
url: None,
|
||||
web_app: None,
|
||||
},
|
||||
]],
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ const DEFAULT_GATEWAY_WORKER_DEFAULT_CWD: &str = "/workspace";
|
||||
const DEFAULT_GATEWAY_WORKER_PERMISSION_MODE: &str = "workspace-write";
|
||||
const DEFAULT_GATEWAY_WORKER_HOST_STATE_ROOT: &str = "/mnt/user/appdata/claw-workers";
|
||||
const DEFAULT_GATEWAY_WORKER_HOST_WORKSPACE_ROOT: &str = "/mnt/user/appdata/claw-workers";
|
||||
const DEFAULT_GATEWAY_MINIAPP_SESSION_TTL_SECS: u64 = 3600;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct TelegramBotConfig {
|
||||
@@ -149,6 +150,9 @@ pub struct GatewayConfig {
|
||||
pub worker_host_state_root: PathBuf,
|
||||
pub worker_host_workspace_root: PathBuf,
|
||||
pub inherited_env: Vec<String>,
|
||||
pub miniapp_bind_addr: Option<String>,
|
||||
pub miniapp_public_base_url: Option<String>,
|
||||
pub miniapp_session_ttl_secs: u64,
|
||||
bot_token: String,
|
||||
}
|
||||
|
||||
@@ -180,6 +184,9 @@ impl std::fmt::Debug for GatewayConfig {
|
||||
&self.worker_host_workspace_root,
|
||||
)
|
||||
.field("inherited_env", &self.inherited_env)
|
||||
.field("miniapp_bind_addr", &self.miniapp_bind_addr)
|
||||
.field("miniapp_public_base_url", &self.miniapp_public_base_url)
|
||||
.field("miniapp_session_ttl_secs", &self.miniapp_session_ttl_secs)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@@ -270,6 +277,16 @@ impl GatewayConfig {
|
||||
inherited_env: parse_csv_list(
|
||||
vars.get("CLAW_GATEWAY_INHERITED_ENV").map(String::as_str),
|
||||
),
|
||||
miniapp_bind_addr: optional_string(&vars, "CLAW_GATEWAY_MINIAPP_BIND_ADDR"),
|
||||
miniapp_public_base_url: optional_string(
|
||||
&vars,
|
||||
"CLAW_GATEWAY_MINIAPP_PUBLIC_BASE_URL",
|
||||
),
|
||||
miniapp_session_ttl_secs: parse_u64(
|
||||
vars.get("CLAW_GATEWAY_MINIAPP_SESSION_TTL_SECS"),
|
||||
DEFAULT_GATEWAY_MINIAPP_SESSION_TTL_SECS,
|
||||
"CLAW_GATEWAY_MINIAPP_SESSION_TTL_SECS",
|
||||
)?,
|
||||
bot_token,
|
||||
})
|
||||
}
|
||||
@@ -334,6 +351,17 @@ fn optional_path(vars: &std::collections::BTreeMap<String, String>, key: &str) -
|
||||
.map(PathBuf::from)
|
||||
}
|
||||
|
||||
fn optional_string(
|
||||
vars: &std::collections::BTreeMap<String, String>,
|
||||
key: &str,
|
||||
) -> Option<String> {
|
||||
vars.get(key)
|
||||
.map(String::as_str)
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(ToOwned::to_owned)
|
||||
}
|
||||
|
||||
fn default_state_root() -> PathBuf {
|
||||
default_cwd().join(".claw-telegram")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@ mod bot;
|
||||
mod config;
|
||||
mod docker_worker_manager;
|
||||
mod gateway;
|
||||
mod miniapp;
|
||||
mod registry;
|
||||
mod runtime_host;
|
||||
mod telegram_api;
|
||||
|
||||
@@ -0,0 +1,614 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use axum::extract::State;
|
||||
use axum::http::{header, HeaderMap, StatusCode};
|
||||
use axum::response::Html;
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use channel_gateway_core::ProfileRecord;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Mutex;
|
||||
use url::form_urlencoded;
|
||||
|
||||
use crate::config::GatewayConfig;
|
||||
use crate::docker_worker_manager::{DockerWorkerManager, WorkerManagerError};
|
||||
use crate::gateway::{load_manifest, GatewayError};
|
||||
use crate::worker_client::{WorkerClient, WorkerClientError};
|
||||
|
||||
const INDEX_HTML: &str = r#"<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, viewport-fit=cover">
|
||||
<title>Claw Mini App</title>
|
||||
<style>
|
||||
:root {
|
||||
--bg: #f6f1e8;
|
||||
--panel: #fffdf7;
|
||||
--ink: #182022;
|
||||
--muted: #6c736f;
|
||||
--line: #d7cfbf;
|
||||
--accent: #0e7a66;
|
||||
--accent-2: #d96b2b;
|
||||
--shadow: 0 14px 40px rgba(24,32,34,.08);
|
||||
}
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
margin: 0;
|
||||
font-family: "Iowan Old Style", "Palatino Linotype", Georgia, serif;
|
||||
background:
|
||||
radial-gradient(circle at top right, rgba(217,107,43,.10), transparent 30%),
|
||||
linear-gradient(180deg, #fbf7ef 0%, var(--bg) 100%);
|
||||
color: var(--ink);
|
||||
}
|
||||
.shell {
|
||||
max-width: 1100px;
|
||||
margin: 0 auto;
|
||||
padding: 20px 16px 48px;
|
||||
}
|
||||
.hero {
|
||||
margin-bottom: 20px;
|
||||
padding: 20px;
|
||||
border: 1px solid var(--line);
|
||||
border-radius: 18px;
|
||||
background: rgba(255,253,247,.88);
|
||||
box-shadow: var(--shadow);
|
||||
backdrop-filter: blur(12px);
|
||||
}
|
||||
.eyebrow {
|
||||
font-size: 12px;
|
||||
letter-spacing: .18em;
|
||||
text-transform: uppercase;
|
||||
color: var(--muted);
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 8px;
|
||||
font-size: 34px;
|
||||
line-height: 1.05;
|
||||
}
|
||||
.sub {
|
||||
margin: 0;
|
||||
color: var(--muted);
|
||||
font-size: 16px;
|
||||
}
|
||||
.toolbar {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
flex-wrap: wrap;
|
||||
margin-top: 18px;
|
||||
}
|
||||
button {
|
||||
border: 0;
|
||||
border-radius: 999px;
|
||||
padding: 10px 14px;
|
||||
cursor: pointer;
|
||||
background: var(--accent);
|
||||
color: white;
|
||||
font: inherit;
|
||||
}
|
||||
button.secondary {
|
||||
background: #ebe3d6;
|
||||
color: var(--ink);
|
||||
}
|
||||
.grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
|
||||
gap: 14px;
|
||||
}
|
||||
.card {
|
||||
border: 1px solid var(--line);
|
||||
border-radius: 16px;
|
||||
background: var(--panel);
|
||||
box-shadow: var(--shadow);
|
||||
overflow: hidden;
|
||||
}
|
||||
.card h2 {
|
||||
margin: 0;
|
||||
padding: 14px 16px;
|
||||
border-bottom: 1px solid var(--line);
|
||||
font-size: 17px;
|
||||
background: linear-gradient(90deg, rgba(14,122,102,.08), rgba(217,107,43,.04));
|
||||
}
|
||||
pre {
|
||||
margin: 0;
|
||||
padding: 14px 16px;
|
||||
font-family: "SFMono-Regular", Consolas, monospace;
|
||||
font-size: 12px;
|
||||
line-height: 1.55;
|
||||
overflow: auto;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
}
|
||||
.status {
|
||||
margin-top: 14px;
|
||||
padding: 12px 14px;
|
||||
border-radius: 14px;
|
||||
background: rgba(14,122,102,.08);
|
||||
color: var(--ink);
|
||||
}
|
||||
.status.error {
|
||||
background: rgba(217,107,43,.12);
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="shell">
|
||||
<section class="hero">
|
||||
<div class="eyebrow">Telegram Mini App</div>
|
||||
<h1>Claw Operator Console</h1>
|
||||
<p class="sub">Tasks, agents, team state, mailbox, and approvals for the currently mapped Telegram profile.</p>
|
||||
<div class="toolbar">
|
||||
<button id="refresh">Refresh</button>
|
||||
<button id="expand" class="secondary">Expand</button>
|
||||
</div>
|
||||
<div id="status" class="status">Connecting…</div>
|
||||
</section>
|
||||
<section class="grid">
|
||||
<article class="card"><h2>Bootstrap</h2><pre id="bootstrap"></pre></article>
|
||||
<article class="card"><h2>Status</h2><pre id="worker-status"></pre></article>
|
||||
<article class="card"><h2>Tasks</h2><pre id="tasks"></pre></article>
|
||||
<article class="card"><h2>Team</h2><pre id="team"></pre></article>
|
||||
<article class="card"><h2>Agents</h2><pre id="agents"></pre></article>
|
||||
<article class="card"><h2>Mailbox</h2><pre id="mailbox"></pre></article>
|
||||
<article class="card"><h2>Approvals</h2><pre id="approvals"></pre></article>
|
||||
</section>
|
||||
</div>
|
||||
<script>
|
||||
const tg = window.Telegram?.WebApp;
|
||||
const state = { token: null };
|
||||
|
||||
function pretty(value) {
|
||||
return JSON.stringify(value, null, 2);
|
||||
}
|
||||
|
||||
function setStatus(message, isError = false) {
|
||||
const el = document.getElementById("status");
|
||||
el.textContent = message;
|
||||
el.className = isError ? "status error" : "status";
|
||||
}
|
||||
|
||||
async function call(path) {
|
||||
const response = await fetch(path, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${state.token}`,
|
||||
},
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(`${path} -> ${response.status}`);
|
||||
}
|
||||
return response.json();
|
||||
}
|
||||
|
||||
async function loadAll() {
|
||||
const [bootstrap, workerStatus, tasks, team, agents, mailbox, approvals] = await Promise.all([
|
||||
call("/miniapp/api/bootstrap"),
|
||||
call("/miniapp/api/status"),
|
||||
call("/miniapp/api/tasks"),
|
||||
call("/miniapp/api/team"),
|
||||
call("/miniapp/api/agents"),
|
||||
call("/miniapp/api/mailbox"),
|
||||
call("/miniapp/api/approvals"),
|
||||
]);
|
||||
document.getElementById("bootstrap").textContent = pretty(bootstrap);
|
||||
document.getElementById("worker-status").textContent = pretty(workerStatus);
|
||||
document.getElementById("tasks").textContent = pretty(tasks);
|
||||
document.getElementById("team").textContent = pretty(team);
|
||||
document.getElementById("agents").textContent = pretty(agents);
|
||||
document.getElementById("mailbox").textContent = pretty(mailbox);
|
||||
document.getElementById("approvals").textContent = pretty(approvals);
|
||||
setStatus(`Connected as ${bootstrap.profile_id}`);
|
||||
}
|
||||
|
||||
async function authenticate() {
|
||||
tg?.ready?.();
|
||||
tg?.expand?.();
|
||||
const initData = tg?.initData || "";
|
||||
if (!initData) {
|
||||
setStatus("This page must be opened as a Telegram Mini App so the gateway can verify init data.", true);
|
||||
return;
|
||||
}
|
||||
const response = await fetch("/miniapp/auth", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ init_data: initData }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const message = await response.text();
|
||||
throw new Error(message || `auth failed with ${response.status}`);
|
||||
}
|
||||
const auth = await response.json();
|
||||
state.token = auth.token;
|
||||
await loadAll();
|
||||
}
|
||||
|
||||
document.getElementById("refresh").addEventListener("click", () => {
|
||||
if (state.token) {
|
||||
loadAll().catch(error => setStatus(error.message, true));
|
||||
}
|
||||
});
|
||||
document.getElementById("expand").addEventListener("click", () => tg?.expand?.());
|
||||
|
||||
authenticate().catch(error => setStatus(error.message, true));
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"#;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MiniAppState {
|
||||
config: GatewayConfig,
|
||||
worker_manager: Arc<DockerWorkerManager>,
|
||||
sessions: Arc<Mutex<BTreeMap<String, MiniAppSession>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct MiniAppSession {
|
||||
profile_id: String,
|
||||
display_name: Option<String>,
|
||||
expires_at: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct MiniAppAuthRequest {
|
||||
init_data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct MiniAppAuthResponse {
|
||||
token: String,
|
||||
profile_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
display_name: Option<String>,
|
||||
expires_at: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct MiniAppBootstrapResponse {
|
||||
profile_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
display_name: Option<String>,
|
||||
sections: Vec<&'static str>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TelegramMiniAppUser {
|
||||
id: i64,
|
||||
#[serde(default)]
|
||||
first_name: Option<String>,
|
||||
#[serde(default)]
|
||||
last_name: Option<String>,
|
||||
#[serde(default)]
|
||||
username: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn serve(config: GatewayConfig, bind_addr: String) -> Result<(), MiniAppError> {
|
||||
let worker_manager = Arc::new(DockerWorkerManager::new(config.clone())?);
|
||||
let state = MiniAppState {
|
||||
config,
|
||||
worker_manager,
|
||||
sessions: Arc::new(Mutex::new(BTreeMap::new())),
|
||||
};
|
||||
let app = Router::new()
|
||||
.route("/miniapp", get(index))
|
||||
.route("/miniapp/auth", post(authenticate))
|
||||
.route("/miniapp/api/bootstrap", get(bootstrap))
|
||||
.route("/miniapp/api/status", get(status))
|
||||
.route("/miniapp/api/tasks", get(tasks))
|
||||
.route("/miniapp/api/team", get(team))
|
||||
.route("/miniapp/api/agents", get(agents))
|
||||
.route("/miniapp/api/mailbox", get(mailbox))
|
||||
.route("/miniapp/api/approvals", get(approvals))
|
||||
.with_state(Arc::new(state));
|
||||
|
||||
let listener = TcpListener::bind(&bind_addr).await?;
|
||||
axum::serve(listener, app)
|
||||
.await
|
||||
.map_err(MiniAppError::Io)
|
||||
}
|
||||
|
||||
async fn index() -> Html<&'static str> {
|
||||
Html(INDEX_HTML)
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
State(state): State<Arc<MiniAppState>>,
|
||||
Json(request): Json<MiniAppAuthRequest>,
|
||||
) -> Result<Json<MiniAppAuthResponse>, StatusCode> {
|
||||
let user = verify_init_data(state.config.bot_token(), &request.init_data)
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
let manifest = load_manifest(&state.config).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let profile = manifest
|
||||
.resolve_profile_for_telegram_user(user.id)
|
||||
.ok_or(StatusCode::FORBIDDEN)?;
|
||||
let session = create_session(profile, state.config.miniapp_session_ttl_secs);
|
||||
let token = next_session_token(profile.profile_id.as_str(), user.id);
|
||||
{
|
||||
let mut sessions = state.sessions.lock().await;
|
||||
prune_sessions(&mut sessions);
|
||||
sessions.insert(token.clone(), session.clone());
|
||||
}
|
||||
Ok(Json(MiniAppAuthResponse {
|
||||
token,
|
||||
profile_id: session.profile_id,
|
||||
display_name: session.display_name,
|
||||
expires_at: session.expires_at,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn bootstrap(
|
||||
State(state): State<Arc<MiniAppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<MiniAppBootstrapResponse>, StatusCode> {
|
||||
let session = authorize(&state, &headers).await?;
|
||||
Ok(Json(MiniAppBootstrapResponse {
|
||||
profile_id: session.profile_id,
|
||||
display_name: session.display_name,
|
||||
sections: vec!["status", "tasks", "team", "agents", "mailbox", "approvals"],
|
||||
}))
|
||||
}
|
||||
|
||||
async fn status(
|
||||
State(state): State<Arc<MiniAppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<channel_gateway_core::WorkerStatusResponse>, StatusCode> {
|
||||
let session = authorize(&state, &headers).await?;
|
||||
let client = worker_client(&state, &session.profile_id).await?;
|
||||
client.status().await.map(Json).map_err(status_from_worker_error)
|
||||
}
|
||||
|
||||
async fn tasks(
|
||||
State(state): State<Arc<MiniAppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<channel_gateway_core::WorkerTaskListResponse>, StatusCode> {
|
||||
let session = authorize(&state, &headers).await?;
|
||||
let client = worker_client(&state, &session.profile_id).await?;
|
||||
client
|
||||
.list_tasks()
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(status_from_worker_error)
|
||||
}
|
||||
|
||||
async fn team(
|
||||
State(state): State<Arc<MiniAppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<channel_gateway_core::WorkerTeamSnapshotResponse>, StatusCode> {
|
||||
let session = authorize(&state, &headers).await?;
|
||||
let client = worker_client(&state, &session.profile_id).await?;
|
||||
client.team().await.map(Json).map_err(status_from_worker_error)
|
||||
}
|
||||
|
||||
async fn agents(
|
||||
State(state): State<Arc<MiniAppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<channel_gateway_core::WorkerAgentListResponse>, StatusCode> {
|
||||
let session = authorize(&state, &headers).await?;
|
||||
let client = worker_client(&state, &session.profile_id).await?;
|
||||
client.agents().await.map(Json).map_err(status_from_worker_error)
|
||||
}
|
||||
|
||||
async fn mailbox(
|
||||
State(state): State<Arc<MiniAppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<channel_gateway_core::WorkerMailboxSummaryResponse>, StatusCode> {
|
||||
let session = authorize(&state, &headers).await?;
|
||||
let client = worker_client(&state, &session.profile_id).await?;
|
||||
client.mailbox().await.map(Json).map_err(status_from_worker_error)
|
||||
}
|
||||
|
||||
async fn approvals(
|
||||
State(state): State<Arc<MiniAppState>>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<Json<channel_gateway_core::WorkerBackgroundApprovalListResponse>, StatusCode> {
|
||||
let session = authorize(&state, &headers).await?;
|
||||
let client = worker_client(&state, &session.profile_id).await?;
|
||||
client
|
||||
.background_approvals()
|
||||
.await
|
||||
.map(Json)
|
||||
.map_err(status_from_worker_error)
|
||||
}
|
||||
|
||||
async fn authorize(
|
||||
state: &MiniAppState,
|
||||
headers: &HeaderMap,
|
||||
) -> Result<MiniAppSession, StatusCode> {
|
||||
let token = headers
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.and_then(|value| value.strip_prefix("Bearer "))
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
let mut sessions = state.sessions.lock().await;
|
||||
prune_sessions(&mut sessions);
|
||||
sessions.get(token).cloned().ok_or(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
|
||||
async fn worker_client(
|
||||
state: &MiniAppState,
|
||||
profile_id: &str,
|
||||
) -> Result<WorkerClient, StatusCode> {
|
||||
let manifest = load_manifest(&state.config).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let profile = manifest
|
||||
.profiles
|
||||
.iter()
|
||||
.find(|profile| profile.profile_id.as_str() == profile_id)
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
let worker = state
|
||||
.worker_manager
|
||||
.ensure_profile_worker(&manifest, profile)
|
||||
.await
|
||||
.map_err(|error| {
|
||||
eprintln!("miniapp worker resolution failed for profile {profile_id}: {error}");
|
||||
StatusCode::BAD_GATEWAY
|
||||
})?;
|
||||
WorkerClient::new(&worker.base_url, &state.config.worker_auth_token)
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
fn create_session(profile: &ProfileRecord, ttl_secs: u64) -> MiniAppSession {
|
||||
MiniAppSession {
|
||||
profile_id: profile.profile_id.as_str().to_string(),
|
||||
display_name: profile.display_name.clone(),
|
||||
expires_at: now_secs().saturating_add(ttl_secs),
|
||||
}
|
||||
}
|
||||
|
||||
fn prune_sessions(sessions: &mut BTreeMap<String, MiniAppSession>) {
|
||||
let now = now_secs();
|
||||
sessions.retain(|_, session| session.expires_at > now);
|
||||
}
|
||||
|
||||
fn verify_init_data(bot_token: &str, init_data: &str) -> Result<TelegramMiniAppUser, MiniAppError> {
|
||||
let mut pairs = form_urlencoded::parse(init_data.as_bytes())
|
||||
.into_owned()
|
||||
.collect::<Vec<_>>();
|
||||
let hash_index = pairs
|
||||
.iter()
|
||||
.position(|(key, _)| key == "hash")
|
||||
.ok_or_else(|| MiniAppError::Auth("Telegram init data is missing hash".to_string()))?;
|
||||
let expected_hash = pairs.remove(hash_index).1;
|
||||
pairs.sort_by(|left, right| left.0.cmp(&right.0));
|
||||
let data_check_string = pairs
|
||||
.iter()
|
||||
.map(|(key, value)| format!("{key}={value}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let secret = hmac_sha256(b"WebAppData", bot_token.as_bytes());
|
||||
let actual_hash = hex_string(&hmac_sha256(&secret, data_check_string.as_bytes()));
|
||||
if actual_hash != expected_hash {
|
||||
return Err(MiniAppError::Auth(
|
||||
"Telegram init data signature mismatch".to_string(),
|
||||
));
|
||||
}
|
||||
let auth_date = pairs
|
||||
.iter()
|
||||
.find(|(key, _)| key == "auth_date")
|
||||
.and_then(|(_, value)| value.parse::<u64>().ok())
|
||||
.ok_or_else(|| MiniAppError::Auth("Telegram init data is missing auth_date".to_string()))?;
|
||||
if now_secs().saturating_sub(auth_date) > 86_400 {
|
||||
return Err(MiniAppError::Auth(
|
||||
"Telegram init data is too old".to_string(),
|
||||
));
|
||||
}
|
||||
let user_json = pairs
|
||||
.iter()
|
||||
.find(|(key, _)| key == "user")
|
||||
.map(|(_, value)| value.as_str())
|
||||
.ok_or_else(|| MiniAppError::Auth("Telegram init data is missing user".to_string()))?;
|
||||
serde_json::from_str(user_json)
|
||||
.map_err(|error| MiniAppError::Auth(format!("invalid Telegram user payload: {error}")))
|
||||
}
|
||||
|
||||
fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] {
|
||||
let mut normalized_key = [0_u8; 64];
|
||||
if key.len() > 64 {
|
||||
let digest = Sha256::digest(key);
|
||||
normalized_key[..32].copy_from_slice(&digest);
|
||||
} else {
|
||||
normalized_key[..key.len()].copy_from_slice(key);
|
||||
}
|
||||
|
||||
let mut inner_pad = [0_u8; 64];
|
||||
let mut outer_pad = [0_u8; 64];
|
||||
for (index, byte) in normalized_key.iter().enumerate() {
|
||||
inner_pad[index] = byte ^ 0x36;
|
||||
outer_pad[index] = byte ^ 0x5c;
|
||||
}
|
||||
|
||||
let mut inner = Sha256::new();
|
||||
inner.update(inner_pad);
|
||||
inner.update(data);
|
||||
let inner_digest = inner.finalize();
|
||||
|
||||
let mut outer = Sha256::new();
|
||||
outer.update(outer_pad);
|
||||
outer.update(inner_digest);
|
||||
let digest = outer.finalize();
|
||||
let mut output = [0_u8; 32];
|
||||
output.copy_from_slice(&digest);
|
||||
output
|
||||
}
|
||||
|
||||
fn hex_string(bytes: &[u8]) -> String {
|
||||
const HEX: &[u8; 16] = b"0123456789abcdef";
|
||||
let mut output = String::with_capacity(bytes.len() * 2);
|
||||
for byte in bytes {
|
||||
output.push(HEX[(byte >> 4) as usize] as char);
|
||||
output.push(HEX[(byte & 0x0f) as usize] as char);
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
fn next_session_token(profile_id: &str, user_id: i64) -> String {
|
||||
static COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
let sequence = COUNTER.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
let seed = format!("{profile_id}:{user_id}:{sequence}:{}", now_secs());
|
||||
hex_string(Sha256::digest(seed.as_bytes()).as_slice())
|
||||
}
|
||||
|
||||
fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map_or(0, |duration| duration.as_secs())
|
||||
}
|
||||
|
||||
fn status_from_worker_error(error: WorkerClientError) -> StatusCode {
|
||||
match error {
|
||||
WorkerClientError::Http(http) if http.status() == Some(StatusCode::NOT_FOUND) => {
|
||||
StatusCode::NOT_FOUND
|
||||
}
|
||||
WorkerClientError::Http(http) if http.status() == Some(StatusCode::CONFLICT) => {
|
||||
StatusCode::CONFLICT
|
||||
}
|
||||
_ => StatusCode::BAD_GATEWAY,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum MiniAppError {
|
||||
Io(std::io::Error),
|
||||
Gateway(GatewayError),
|
||||
Worker(WorkerManagerError),
|
||||
Auth(String),
|
||||
}
|
||||
|
||||
impl Display for MiniAppError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Io(error) => write!(f, "{error}"),
|
||||
Self::Gateway(error) => write!(f, "{error}"),
|
||||
Self::Worker(error) => write!(f, "{error}"),
|
||||
Self::Auth(message) => write!(f, "{message}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for MiniAppError {}
|
||||
|
||||
impl From<std::io::Error> for MiniAppError {
|
||||
fn from(error: std::io::Error) -> Self {
|
||||
Self::Io(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GatewayError> for MiniAppError {
|
||||
fn from(error: GatewayError) -> Self {
|
||||
Self::Gateway(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WorkerManagerError> for MiniAppError {
|
||||
fn from(error: WorkerManagerError) -> Self {
|
||||
Self::Worker(error)
|
||||
}
|
||||
}
|
||||
@@ -1191,6 +1191,14 @@ impl BridgeToolExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_tool_input_json(input: &str) -> Result<Value, ToolError> {
|
||||
if input.trim().is_empty() {
|
||||
return Ok(json!({}));
|
||||
}
|
||||
serde_json::from_str(input)
|
||||
.map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))
|
||||
}
|
||||
|
||||
impl ToolExecutor for BridgeToolExecutor {
|
||||
fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
|
||||
if self
|
||||
@@ -1202,8 +1210,7 @@ impl ToolExecutor for BridgeToolExecutor {
|
||||
"tool `{tool_name}` is not enabled by the current allowed-tools setting"
|
||||
)));
|
||||
}
|
||||
let value = serde_json::from_str(input)
|
||||
.map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?;
|
||||
let value = parse_tool_input_json(input)?;
|
||||
if tool_name == "ToolSearch" {
|
||||
self.execute_search_tool(value)
|
||||
} else if self.tool_registry.has_runtime_tool(tool_name) {
|
||||
@@ -1505,8 +1512,12 @@ mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use runtime::{ContentBlock, ConversationMessage};
|
||||
use serde_json::json;
|
||||
|
||||
use super::{collect_generated_files, compose_user_input, AttachmentKind, AttachmentRef};
|
||||
use super::{
|
||||
collect_generated_files, compose_user_input, parse_tool_input_json, AttachmentKind,
|
||||
AttachmentRef,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn compose_user_input_includes_attachment_paths() {
|
||||
@@ -1563,4 +1574,13 @@ mod tests {
|
||||
};
|
||||
assert!(collect_generated_files(&summary).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_input_json_treats_blank_input_as_empty_object() {
|
||||
let value = parse_tool_input_json("").expect("blank input should parse");
|
||||
assert_eq!(value, json!({}));
|
||||
|
||||
let whitespace = parse_tool_input_json(" \n\t").expect("whitespace should parse");
|
||||
assert_eq!(whitespace, json!({}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -612,7 +612,17 @@ pub struct InlineKeyboardMarkup {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct InlineKeyboardButton {
|
||||
pub text: String,
|
||||
pub callback_data: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub callback_data: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub url: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub web_app: Option<WebAppInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct WebAppInfo {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
||||
@@ -3,9 +3,13 @@ use std::path::{Path, PathBuf};
|
||||
|
||||
use base64::Engine as _;
|
||||
use channel_gateway_core::{
|
||||
AttachmentRef, GeneratedFileDescriptor, TurnSource, WorkerApprovalDecision,
|
||||
WorkerStatusResponse, WorkerTurnAccepted, WorkerTurnEvent, WorkerTurnRequest,
|
||||
AttachmentRef, GeneratedFileDescriptor, TurnSource, WorkerAgentListResponse,
|
||||
WorkerApprovalDecision, WorkerBackgroundApprovalListResponse, WorkerMailboxSummaryResponse,
|
||||
WorkerMailboxPendingResponse, WorkerStatusResponse, WorkerTaskListResponse,
|
||||
WorkerTaskSnapshotResponse, WorkerTeamSnapshotResponse, WorkerTurnAccepted,
|
||||
WorkerTurnEvent, WorkerTurnRequest,
|
||||
};
|
||||
use runtime::RuntimeTaskRecord;
|
||||
use futures_util::StreamExt;
|
||||
use serde::Serialize;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -55,6 +59,97 @@ impl WorkerClient {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn list_tasks(&self) -> Result<WorkerTaskListResponse, WorkerClientError> {
|
||||
self.get_json("/v1/tasks").await
|
||||
}
|
||||
|
||||
pub async fn get_task(
|
||||
&self,
|
||||
task_id: &str,
|
||||
) -> Result<WorkerTaskSnapshotResponse, WorkerClientError> {
|
||||
self.get_json(&format!("/v1/tasks/{task_id}")).await
|
||||
}
|
||||
|
||||
pub async fn stop_task(&self, task_id: &str) -> Result<(), WorkerClientError> {
|
||||
self.post_no_content(&format!("/v1/tasks/{task_id}/stop"), &serde_json::json!({}))
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn team(&self) -> Result<WorkerTeamSnapshotResponse, WorkerClientError> {
|
||||
self.get_json("/v1/team").await
|
||||
}
|
||||
|
||||
pub async fn agents(&self) -> Result<WorkerAgentListResponse, WorkerClientError> {
|
||||
self.get_json("/v1/agents").await
|
||||
}
|
||||
|
||||
pub async fn agent(&self, agent_id: &str) -> Result<RuntimeTaskRecord, WorkerClientError> {
|
||||
self.get_json(&format!("/v1/agents/{agent_id}")).await
|
||||
}
|
||||
|
||||
pub async fn mark_agent_notified(&self, agent_id: &str) -> Result<(), WorkerClientError> {
|
||||
self.post_no_content(
|
||||
&format!("/v1/agents/{agent_id}/notified"),
|
||||
&serde_json::json!({}),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn background_approvals(
|
||||
&self,
|
||||
) -> Result<WorkerBackgroundApprovalListResponse, WorkerClientError> {
|
||||
self.get_json("/v1/background-approvals").await
|
||||
}
|
||||
|
||||
pub async fn post_background_approval(
|
||||
&self,
|
||||
approval_id: &str,
|
||||
decision: WorkerApprovalDecision,
|
||||
) -> Result<(), WorkerClientError> {
|
||||
self.post_no_content(
|
||||
&format!("/v1/background-approvals/{approval_id}"),
|
||||
&serde_json::json!({
|
||||
"approval_id": approval_id,
|
||||
"decision": decision,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn mark_background_approval_notified(
|
||||
&self,
|
||||
approval_id: &str,
|
||||
) -> Result<(), WorkerClientError> {
|
||||
self.post_no_content(
|
||||
&format!("/v1/background-approvals/{approval_id}/notified"),
|
||||
&serde_json::json!({}),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn mailbox(&self) -> Result<WorkerMailboxSummaryResponse, WorkerClientError> {
|
||||
self.get_json("/v1/mailbox").await
|
||||
}
|
||||
|
||||
pub async fn pending_mailbox_messages(
|
||||
&self,
|
||||
recipient: &str,
|
||||
) -> Result<WorkerMailboxPendingResponse, WorkerClientError> {
|
||||
self.get_json(&format!("/v1/mailbox/pending/{recipient}")).await
|
||||
}
|
||||
|
||||
pub async fn mark_mailbox_messages_notified(
|
||||
&self,
|
||||
recipient: &str,
|
||||
ids: &[String],
|
||||
) -> Result<(), WorkerClientError> {
|
||||
self.post_no_content(
|
||||
&format!("/v1/mailbox/{recipient}/notified"),
|
||||
&serde_json::json!({ "ids": ids }),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn post_turn(
|
||||
&self,
|
||||
prompt: String,
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{sanitize_state_component, state_root, PermissionRequest};
|
||||
use crate::workflow_state::now_secs;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "decision", rename_all = "snake_case")]
|
||||
pub enum BackgroundApprovalDecision {
|
||||
ApproveOnce,
|
||||
ApproveToolForSession,
|
||||
ApproveAllForSession,
|
||||
Deny { reason: String },
|
||||
CancelTurn,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct BackgroundApprovalRecord {
|
||||
pub approval_id: String,
|
||||
pub task_id: String,
|
||||
pub tool_name: String,
|
||||
pub input: String,
|
||||
pub current_mode: String,
|
||||
pub required_mode: String,
|
||||
#[serde(default)]
|
||||
pub reason: Option<String>,
|
||||
pub created_at: u64,
|
||||
#[serde(default)]
|
||||
pub notified: bool,
|
||||
#[serde(default)]
|
||||
pub decision: Option<BackgroundApprovalDecision>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct BackgroundApprovalStore;
|
||||
|
||||
impl BackgroundApprovalStore {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
pub fn create_request(
|
||||
&self,
|
||||
task_id: &str,
|
||||
request: &PermissionRequest,
|
||||
) -> io::Result<BackgroundApprovalRecord> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let record = BackgroundApprovalRecord {
|
||||
approval_id: next_approval_id(),
|
||||
task_id: task_id.to_string(),
|
||||
tool_name: request.tool_name.clone(),
|
||||
input: request.input.clone(),
|
||||
current_mode: request.current_mode.as_str().to_string(),
|
||||
required_mode: request.required_mode.as_str().to_string(),
|
||||
reason: request.reason.clone(),
|
||||
created_at: now_secs(),
|
||||
notified: false,
|
||||
decision: None,
|
||||
};
|
||||
self.write_locked(&record)?;
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
pub fn get(&self, approval_id: &str) -> io::Result<Option<BackgroundApprovalRecord>> {
|
||||
let path = approval_path(approval_id)?;
|
||||
match fs::read_to_string(path) {
|
||||
Ok(contents) => serde_json::from_str(&contents)
|
||||
.map(Some)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)),
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(None),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list_pending(&self) -> io::Result<Vec<BackgroundApprovalRecord>> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let dir = approvals_dir(&state_root()?);
|
||||
if !dir.exists() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let mut records = Vec::new();
|
||||
for entry in fs::read_dir(dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if !path
|
||||
.extension()
|
||||
.and_then(|value| value.to_str())
|
||||
.is_some_and(|value| value.eq_ignore_ascii_case("json"))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
let contents = match fs::read_to_string(&path) {
|
||||
Ok(contents) => contents,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let Ok(record) = serde_json::from_str::<BackgroundApprovalRecord>(&contents) else {
|
||||
continue;
|
||||
};
|
||||
if record.decision.is_none() {
|
||||
records.push(record);
|
||||
}
|
||||
}
|
||||
records.sort_by(|left, right| {
|
||||
left.created_at
|
||||
.cmp(&right.created_at)
|
||||
.then_with(|| left.approval_id.cmp(&right.approval_id))
|
||||
});
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
pub fn resolve(
|
||||
&self,
|
||||
approval_id: &str,
|
||||
decision: BackgroundApprovalDecision,
|
||||
) -> io::Result<Option<BackgroundApprovalRecord>> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let Some(mut record) = self.get(approval_id)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
record.decision = Some(decision);
|
||||
self.write_locked(&record)?;
|
||||
Ok(Some(record))
|
||||
}
|
||||
|
||||
pub fn mark_notified(&self, approval_id: &str) -> io::Result<Option<BackgroundApprovalRecord>> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let Some(mut record) = self.get(approval_id)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
if record.notified {
|
||||
return Ok(Some(record));
|
||||
}
|
||||
record.notified = true;
|
||||
self.write_locked(&record)?;
|
||||
Ok(Some(record))
|
||||
}
|
||||
|
||||
fn write_locked(&self, record: &BackgroundApprovalRecord) -> io::Result<()> {
|
||||
let path = approval_path(&record.approval_id)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
let payload = serde_json::to_vec_pretty(record)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
fs::write(path, payload)
|
||||
}
|
||||
}
|
||||
|
||||
fn approvals_dir(root: &Path) -> PathBuf {
|
||||
root.join("approvals").join("background")
|
||||
}
|
||||
|
||||
fn approval_path(approval_id: &str) -> io::Result<PathBuf> {
|
||||
Ok(approvals_dir(&state_root()?).join(format!(
|
||||
"{}.json",
|
||||
sanitize_state_component(approval_id)
|
||||
)))
|
||||
}
|
||||
|
||||
fn next_approval_id() -> String {
|
||||
static NEXT_ID: AtomicU64 = AtomicU64::new(0);
|
||||
let sequence = NEXT_ID.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
format!("background-approval-{}-{sequence}", now_secs())
|
||||
}
|
||||
|
||||
fn store_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{BackgroundApprovalDecision, BackgroundApprovalStore};
|
||||
use crate::{PermissionMode, PermissionRequest, test_env_lock};
|
||||
|
||||
#[test]
|
||||
fn pending_background_approvals_round_trip() {
|
||||
let _lock = test_env_lock();
|
||||
let root = std::env::temp_dir().join("background-approval-store");
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
std::env::set_var("CLAW_WORKER_STATE_ROOT", &root);
|
||||
|
||||
let store = BackgroundApprovalStore::new();
|
||||
let created = store
|
||||
.create_request(
|
||||
"agent-123",
|
||||
&PermissionRequest {
|
||||
tool_name: "bash".to_string(),
|
||||
input: "{\"command\":\"git status\"}".to_string(),
|
||||
current_mode: PermissionMode::WorkspaceWrite,
|
||||
required_mode: PermissionMode::DangerFullAccess,
|
||||
reason: Some("needs shell access".to_string()),
|
||||
},
|
||||
)
|
||||
.expect("request should be created");
|
||||
let pending = store.list_pending().expect("pending approvals should load");
|
||||
assert_eq!(pending.len(), 1);
|
||||
assert_eq!(pending[0].approval_id, created.approval_id);
|
||||
assert!(!pending[0].notified);
|
||||
|
||||
let notified = store
|
||||
.mark_notified(&created.approval_id)
|
||||
.expect("mark notified should work")
|
||||
.expect("record should exist");
|
||||
assert!(notified.notified);
|
||||
|
||||
let resolved = store
|
||||
.resolve(
|
||||
&created.approval_id,
|
||||
BackgroundApprovalDecision::ApproveToolForSession,
|
||||
)
|
||||
.expect("resolve should work")
|
||||
.expect("record should exist");
|
||||
assert_eq!(
|
||||
resolved.decision,
|
||||
Some(BackgroundApprovalDecision::ApproveToolForSession)
|
||||
);
|
||||
assert!(
|
||||
store
|
||||
.list_pending()
|
||||
.expect("pending approvals should load")
|
||||
.is_empty()
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
std::env::remove_var("CLAW_WORKER_STATE_ROOT");
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::env;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::{Command, Stdio};
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::sandbox::{
|
||||
build_linux_sandbox_command, resolve_sandbox_status_for_request, FilesystemIsolationMode,
|
||||
SandboxConfig, SandboxStatus,
|
||||
};
|
||||
use crate::ConfigLoader;
|
||||
use crate::{current_execution_cwd, ConfigLoader};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct BashCommandInput {
|
||||
@@ -64,8 +64,16 @@ pub struct BashCommandOutput {
|
||||
pub sandbox_status: Option<SandboxStatus>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct BackgroundBashHandle {
|
||||
pub pid: u32,
|
||||
pub sandbox_status: SandboxStatus,
|
||||
pub output_path: PathBuf,
|
||||
pub exit_code_path: PathBuf,
|
||||
}
|
||||
|
||||
pub fn execute_bash(input: BashCommandInput) -> io::Result<BashCommandOutput> {
|
||||
let cwd = env::current_dir()?;
|
||||
let cwd = current_execution_cwd()?;
|
||||
let sandbox_status = sandbox_status_for_input(&input, &cwd);
|
||||
|
||||
if input.run_in_background.unwrap_or(false) {
|
||||
@@ -99,6 +107,43 @@ pub fn execute_bash(input: BashCommandInput) -> io::Result<BashCommandOutput> {
|
||||
runtime.block_on(execute_bash_async(input, sandbox_status, cwd))
|
||||
}
|
||||
|
||||
pub fn spawn_background_bash(
|
||||
input: &BashCommandInput,
|
||||
output_path: &Path,
|
||||
exit_code_path: &Path,
|
||||
) -> io::Result<BackgroundBashHandle> {
|
||||
let cwd = current_execution_cwd()?;
|
||||
let sandbox_status = sandbox_status_for_input(input, &cwd);
|
||||
if let Some(parent) = output_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
if let Some(parent) = exit_code_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
let stdout = std::fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(output_path)?;
|
||||
let stderr = stdout.try_clone()?;
|
||||
let wrapped_command = format!(
|
||||
"({}); __claw_code=$?; printf '%s' \"$__claw_code\" > {}; exit \"$__claw_code\"",
|
||||
input.command,
|
||||
shell_quote(&exit_code_path.display().to_string())
|
||||
);
|
||||
let mut child = prepare_command(&wrapped_command, &cwd, &sandbox_status, false);
|
||||
let child = child
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::from(stdout))
|
||||
.stderr(Stdio::from(stderr))
|
||||
.spawn()?;
|
||||
Ok(BackgroundBashHandle {
|
||||
pid: child.id(),
|
||||
sandbox_status,
|
||||
output_path: output_path.to_path_buf(),
|
||||
exit_code_path: exit_code_path.to_path_buf(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_bash_async(
|
||||
input: BashCommandInput,
|
||||
sandbox_status: SandboxStatus,
|
||||
@@ -238,6 +283,10 @@ fn prepare_sandbox_dirs(cwd: &std::path::Path) {
|
||||
let _ = std::fs::create_dir_all(cwd.join(".sandbox-tmp"));
|
||||
}
|
||||
|
||||
fn shell_quote(value: &str) -> String {
|
||||
format!("'{}'", value.replace('\'', "'\"'\"'"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{execute_bash, BashCommandInput};
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
use std::cell::RefCell;
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
|
||||
thread_local! {
|
||||
static EXECUTION_CWD_STACK: RefCell<Vec<PathBuf>> = const { RefCell::new(Vec::new()) };
|
||||
}
|
||||
|
||||
pub struct ExecutionCwdGuard;
|
||||
|
||||
impl Drop for ExecutionCwdGuard {
|
||||
fn drop(&mut self) {
|
||||
EXECUTION_CWD_STACK.with(|stack| {
|
||||
stack.borrow_mut().pop();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn push_execution_cwd(cwd: PathBuf) -> ExecutionCwdGuard {
|
||||
EXECUTION_CWD_STACK.with(|stack| {
|
||||
stack.borrow_mut().push(cwd);
|
||||
});
|
||||
ExecutionCwdGuard
|
||||
}
|
||||
|
||||
pub fn with_execution_cwd<T>(cwd: PathBuf, f: impl FnOnce() -> T) -> T {
|
||||
let _guard = push_execution_cwd(cwd);
|
||||
f()
|
||||
}
|
||||
|
||||
pub fn current_execution_cwd() -> io::Result<PathBuf> {
|
||||
let override_cwd = EXECUTION_CWD_STACK.with(|stack| stack.borrow().last().cloned());
|
||||
override_cwd.map_or_else(std::env::current_dir, Ok)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use super::{current_execution_cwd, with_execution_cwd};
|
||||
|
||||
#[test]
|
||||
fn execution_cwd_is_thread_local_and_nested() {
|
||||
let base = current_execution_cwd().expect("cwd");
|
||||
let first = PathBuf::from("/tmp/execution-context-a");
|
||||
let second = PathBuf::from("/tmp/execution-context-b");
|
||||
|
||||
with_execution_cwd(first.clone(), || {
|
||||
assert_eq!(current_execution_cwd().expect("cwd"), first);
|
||||
with_execution_cwd(second.clone(), || {
|
||||
assert_eq!(current_execution_cwd().expect("cwd"), second);
|
||||
});
|
||||
assert_eq!(current_execution_cwd().expect("cwd"), first);
|
||||
|
||||
let inherited = std::thread::spawn(|| current_execution_cwd().expect("cwd"))
|
||||
.join()
|
||||
.expect("thread should join");
|
||||
assert_eq!(inherited, base);
|
||||
});
|
||||
|
||||
assert_eq!(current_execution_cwd().expect("cwd"), base);
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,8 @@ use regex::RegexBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use walkdir::WalkDir;
|
||||
|
||||
use crate::current_execution_cwd;
|
||||
|
||||
/// Maximum file size that can be read (10 MB).
|
||||
const MAX_READ_SIZE: u64 = 10 * 1024 * 1024;
|
||||
|
||||
@@ -288,7 +290,7 @@ pub fn glob_search(pattern: &str, path: Option<&str>) -> io::Result<GlobSearchOu
|
||||
let base_dir = path
|
||||
.map(normalize_path)
|
||||
.transpose()?
|
||||
.unwrap_or(std::env::current_dir()?);
|
||||
.unwrap_or(current_execution_cwd()?);
|
||||
let search_pattern = if Path::new(pattern).is_absolute() {
|
||||
pattern.to_owned()
|
||||
} else {
|
||||
@@ -332,7 +334,7 @@ pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
|
||||
.as_deref()
|
||||
.map(normalize_path)
|
||||
.transpose()?
|
||||
.unwrap_or(std::env::current_dir()?);
|
||||
.unwrap_or(current_execution_cwd()?);
|
||||
|
||||
let regex = RegexBuilder::new(&input.pattern)
|
||||
.case_insensitive(input.case_insensitive.unwrap_or(false))
|
||||
@@ -515,7 +517,7 @@ fn normalize_path(path: &str) -> io::Result<PathBuf> {
|
||||
let candidate = if Path::new(path).is_absolute() {
|
||||
PathBuf::from(path)
|
||||
} else {
|
||||
std::env::current_dir()?.join(path)
|
||||
current_execution_cwd()?.join(path)
|
||||
};
|
||||
candidate.canonicalize()
|
||||
}
|
||||
@@ -524,7 +526,7 @@ fn normalize_path_allow_missing(path: &str) -> io::Result<PathBuf> {
|
||||
let candidate = if Path::new(path).is_absolute() {
|
||||
PathBuf::from(path)
|
||||
} else {
|
||||
std::env::current_dir()?.join(path)
|
||||
current_execution_cwd()?.join(path)
|
||||
};
|
||||
|
||||
if let Ok(canonical) = candidate.canonicalize() {
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
mod bash;
|
||||
pub mod background_approval_store;
|
||||
pub mod bash_validation;
|
||||
mod bootstrap;
|
||||
mod compact;
|
||||
mod config;
|
||||
mod conversation;
|
||||
mod execution_context;
|
||||
mod file_ops;
|
||||
pub mod green_contract;
|
||||
mod hooks;
|
||||
@@ -21,6 +23,7 @@ mod permissions;
|
||||
pub mod plugin_lifecycle;
|
||||
mod policy_engine;
|
||||
mod prompt;
|
||||
pub mod runtime_task_store;
|
||||
pub mod recovery_recipes;
|
||||
mod remote;
|
||||
pub mod sandbox;
|
||||
@@ -29,14 +32,24 @@ pub mod session_control;
|
||||
mod sse;
|
||||
pub mod stale_branch;
|
||||
pub mod summary_compression;
|
||||
pub mod task_list_store;
|
||||
pub mod task_cancellation;
|
||||
pub mod task_packet;
|
||||
pub mod task_registry;
|
||||
pub mod teamwork_store;
|
||||
pub mod team_cron_registry;
|
||||
pub mod trust_resolver;
|
||||
mod usage;
|
||||
pub mod workflow_state;
|
||||
pub mod worker_boot;
|
||||
|
||||
pub use bash::{execute_bash, BashCommandInput, BashCommandOutput};
|
||||
pub use bash::{
|
||||
execute_bash, spawn_background_bash, BackgroundBashHandle, BashCommandInput,
|
||||
BashCommandOutput,
|
||||
};
|
||||
pub use background_approval_store::{
|
||||
BackgroundApprovalDecision, BackgroundApprovalRecord, BackgroundApprovalStore,
|
||||
};
|
||||
pub use bootstrap::{BootstrapPhase, BootstrapPlan};
|
||||
pub use compact::{
|
||||
compact_session, estimate_session_tokens, format_compact_summary,
|
||||
@@ -55,6 +68,7 @@ pub use conversation::{
|
||||
AssistantEventObserver, AutoCompactionEvent, ConversationRuntime, PromptCacheEvent,
|
||||
RuntimeError, StaticToolExecutor, ToolError, ToolExecutor, TurnEventObserver, TurnSummary,
|
||||
};
|
||||
pub use execution_context::{current_execution_cwd, push_execution_cwd, with_execution_cwd};
|
||||
pub use file_ops::{
|
||||
edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput,
|
||||
GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload,
|
||||
@@ -110,6 +124,10 @@ pub use prompt::{
|
||||
load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError,
|
||||
SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
||||
};
|
||||
pub use runtime_task_store::{
|
||||
background_output_paths, RuntimeTaskKind, RuntimeTaskOutput, RuntimeTaskRecord,
|
||||
RuntimeTaskStatus, RuntimeTaskStore,
|
||||
};
|
||||
pub use recovery_recipes::{
|
||||
attempt_recovery, recipe_for, EscalationPolicy, FailureScenario, RecoveryContext,
|
||||
RecoveryEvent, RecoveryRecipe, RecoveryResult, RecoveryStep,
|
||||
@@ -138,10 +156,21 @@ pub use task_packet::{
|
||||
validate_packet, AcceptanceTest, BranchPolicy, CommitPolicy, RepoConfig, ReportingContract,
|
||||
TaskPacket, TaskPacketValidationError, TaskScope, ValidatedPacket,
|
||||
};
|
||||
pub use task_cancellation::{
|
||||
is_task_cancelled, register_task_cancel_flag, signal_task_cancel, unregister_task_cancel_flag,
|
||||
};
|
||||
pub use task_list_store::{TaskListPatch, TaskListRecord, TaskListStatus, TaskListStore};
|
||||
pub use teamwork_store::{
|
||||
MailboxMessage, MailboxSummary, MessageEnvelope, TeamMemberRecord, TeamRecord, TeamStore,
|
||||
};
|
||||
pub use trust_resolver::{TrustConfig, TrustDecision, TrustEvent, TrustPolicy, TrustResolver};
|
||||
pub use usage::{
|
||||
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
|
||||
};
|
||||
pub use workflow_state::{
|
||||
clear_team_context, current_task_list_id, default_session_identity, load_team_context,
|
||||
sanitize_state_component, state_root, TeamContext,
|
||||
};
|
||||
pub use worker_boot::{
|
||||
Worker, WorkerEvent, WorkerEventKind, WorkerFailure, WorkerFailureKind, WorkerReadySnapshot,
|
||||
WorkerRegistry, WorkerStatus,
|
||||
|
||||
@@ -0,0 +1,535 @@
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::signal_task_cancel;
|
||||
use crate::workflow_state::{now_secs, sanitize_state_component, state_root};
|
||||
|
||||
fn store_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RuntimeTaskKind {
|
||||
Agent,
|
||||
Shell,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RuntimeTaskStatus {
|
||||
Running,
|
||||
Completed,
|
||||
Failed,
|
||||
Stopped,
|
||||
}
|
||||
|
||||
impl RuntimeTaskStatus {
|
||||
#[must_use]
|
||||
pub fn is_terminal(self) -> bool {
|
||||
matches!(self, Self::Completed | Self::Failed | Self::Stopped)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RuntimeTaskStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Running => write!(f, "running"),
|
||||
Self::Completed => write!(f, "completed"),
|
||||
Self::Failed => write!(f, "failed"),
|
||||
Self::Stopped => write!(f, "stopped"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct RuntimeTaskRecord {
|
||||
pub task_id: String,
|
||||
pub kind: RuntimeTaskKind,
|
||||
pub status: RuntimeTaskStatus,
|
||||
pub description: String,
|
||||
pub prompt: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub output_file: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub exit_code_file: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub final_result: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub exit_code: Option<i32>,
|
||||
#[serde(default)]
|
||||
pub notified: bool,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub pid: Option<u32>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub agent_id: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub agent_name: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub team_name: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub cwd: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub worktree_path: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub worktree_branch: Option<String>,
|
||||
pub created_at: u64,
|
||||
pub started_at: u64,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub completed_at: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct RuntimeTaskOutput {
|
||||
pub task: RuntimeTaskRecord,
|
||||
pub output: String,
|
||||
pub has_output: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RuntimeTaskStore;
|
||||
|
||||
impl RuntimeTaskStore {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
pub fn create_shell_task(
|
||||
&self,
|
||||
description: String,
|
||||
command: String,
|
||||
pid: u32,
|
||||
output_file: PathBuf,
|
||||
exit_code_file: PathBuf,
|
||||
team_name: Option<String>,
|
||||
cwd: Option<String>,
|
||||
) -> io::Result<RuntimeTaskRecord> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let now = now_secs();
|
||||
let record = RuntimeTaskRecord {
|
||||
task_id: make_runtime_task_id("shell"),
|
||||
kind: RuntimeTaskKind::Shell,
|
||||
status: RuntimeTaskStatus::Running,
|
||||
description,
|
||||
prompt: command,
|
||||
output_file: Some(output_file.display().to_string()),
|
||||
exit_code_file: Some(exit_code_file.display().to_string()),
|
||||
final_result: None,
|
||||
error: None,
|
||||
exit_code: None,
|
||||
notified: false,
|
||||
pid: Some(pid),
|
||||
agent_id: None,
|
||||
agent_name: None,
|
||||
team_name,
|
||||
cwd,
|
||||
worktree_path: None,
|
||||
worktree_branch: None,
|
||||
created_at: now,
|
||||
started_at: now,
|
||||
completed_at: None,
|
||||
};
|
||||
self.write_locked(&record)?;
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
pub fn create_agent_task(
|
||||
&self,
|
||||
agent_id: String,
|
||||
agent_name: String,
|
||||
description: String,
|
||||
prompt: String,
|
||||
output_file: String,
|
||||
team_name: Option<String>,
|
||||
cwd: Option<String>,
|
||||
worktree_path: Option<String>,
|
||||
worktree_branch: Option<String>,
|
||||
) -> io::Result<RuntimeTaskRecord> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let now = now_secs();
|
||||
let record = RuntimeTaskRecord {
|
||||
task_id: agent_id.clone(),
|
||||
kind: RuntimeTaskKind::Agent,
|
||||
status: RuntimeTaskStatus::Running,
|
||||
description,
|
||||
prompt,
|
||||
output_file: Some(output_file),
|
||||
exit_code_file: None,
|
||||
final_result: None,
|
||||
error: None,
|
||||
exit_code: None,
|
||||
notified: false,
|
||||
pid: None,
|
||||
agent_id: Some(agent_id),
|
||||
agent_name: Some(agent_name),
|
||||
team_name,
|
||||
cwd,
|
||||
worktree_path,
|
||||
worktree_branch,
|
||||
created_at: now,
|
||||
started_at: now,
|
||||
completed_at: None,
|
||||
};
|
||||
self.write_locked(&record)?;
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
pub fn get(&self, task_id: &str) -> io::Result<Option<RuntimeTaskRecord>> {
|
||||
let mut record = match fs::read_to_string(task_path(task_id)?) {
|
||||
Ok(contents) => serde_json::from_str::<RuntimeTaskRecord>(&contents)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?,
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(None),
|
||||
Err(error) => return Err(error),
|
||||
};
|
||||
refresh_record(&mut record)?;
|
||||
Ok(Some(record))
|
||||
}
|
||||
|
||||
pub fn list(&self) -> io::Result<Vec<RuntimeTaskRecord>> {
|
||||
let dir = tasks_dir()?;
|
||||
let entries = match fs::read_dir(dir) {
|
||||
Ok(entries) => entries,
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(Vec::new()),
|
||||
Err(error) => return Err(error),
|
||||
};
|
||||
let mut tasks = Vec::new();
|
||||
for entry in entries {
|
||||
let path = entry?.path();
|
||||
if path.extension().and_then(|value| value.to_str()) != Some("json") {
|
||||
continue;
|
||||
}
|
||||
let contents = fs::read_to_string(&path)?;
|
||||
let mut record = serde_json::from_str::<RuntimeTaskRecord>(&contents)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
refresh_record(&mut record)?;
|
||||
tasks.push(record);
|
||||
}
|
||||
tasks.sort_by(|left, right| left.created_at.cmp(&right.created_at));
|
||||
Ok(tasks)
|
||||
}
|
||||
|
||||
pub fn mark_agent_terminal(
|
||||
&self,
|
||||
task_id: &str,
|
||||
status: RuntimeTaskStatus,
|
||||
final_result: Option<String>,
|
||||
error: Option<String>,
|
||||
) -> io::Result<Option<RuntimeTaskRecord>> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let Some(mut record) = self.get(task_id)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
if record.status == RuntimeTaskStatus::Stopped {
|
||||
return Ok(Some(record));
|
||||
}
|
||||
record.status = status;
|
||||
record.final_result = final_result;
|
||||
record.error = error;
|
||||
record.completed_at = Some(now_secs());
|
||||
self.write_locked(&record)?;
|
||||
Ok(Some(record))
|
||||
}
|
||||
|
||||
pub fn stop(&self, task_id: &str) -> io::Result<Option<RuntimeTaskRecord>> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let Some(mut record) = self.get(task_id)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
if record.status.is_terminal() {
|
||||
return Ok(Some(record));
|
||||
}
|
||||
if let Some(pid) = record.pid {
|
||||
stop_pid(pid)?;
|
||||
} else if record.kind == RuntimeTaskKind::Agent {
|
||||
if !signal_task_cancel(task_id) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Unsupported,
|
||||
"background agent task is not currently cancellable",
|
||||
));
|
||||
}
|
||||
}
|
||||
record.status = RuntimeTaskStatus::Stopped;
|
||||
record.final_result = None;
|
||||
record.error = None;
|
||||
record.completed_at = Some(now_secs());
|
||||
self.write_locked(&record)?;
|
||||
Ok(Some(record))
|
||||
}
|
||||
|
||||
pub fn mark_notified(&self, task_id: &str) -> io::Result<Option<RuntimeTaskRecord>> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let Some(mut record) = self.get(task_id)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
if record.notified {
|
||||
return Ok(Some(record));
|
||||
}
|
||||
record.notified = true;
|
||||
self.write_locked(&record)?;
|
||||
Ok(Some(record))
|
||||
}
|
||||
|
||||
pub fn output(
|
||||
&self,
|
||||
task_id: &str,
|
||||
block: bool,
|
||||
timeout_ms: Option<u64>,
|
||||
) -> io::Result<Option<RuntimeTaskOutput>> {
|
||||
let start = std::time::Instant::now();
|
||||
loop {
|
||||
let Some(record) = self.get(task_id)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
if !block || record.status.is_terminal() {
|
||||
let output = record
|
||||
.output_file
|
||||
.as_deref()
|
||||
.map(fs::read_to_string)
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
return Ok(Some(RuntimeTaskOutput {
|
||||
has_output: !output.trim().is_empty(),
|
||||
output,
|
||||
task: record,
|
||||
}));
|
||||
}
|
||||
if let Some(timeout_ms) = timeout_ms {
|
||||
if start.elapsed() >= std::time::Duration::from_millis(timeout_ms) {
|
||||
let output = record
|
||||
.output_file
|
||||
.as_deref()
|
||||
.map(fs::read_to_string)
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
return Ok(Some(RuntimeTaskOutput {
|
||||
has_output: !output.trim().is_empty(),
|
||||
output,
|
||||
task: record,
|
||||
}));
|
||||
}
|
||||
}
|
||||
std::thread::sleep(std::time::Duration::from_millis(200));
|
||||
}
|
||||
}
|
||||
|
||||
fn write_locked(&self, record: &RuntimeTaskRecord) -> io::Result<()> {
|
||||
let path = task_path(&record.task_id)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
let payload = serde_json::to_vec_pretty(record)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
fs::write(path, payload)
|
||||
}
|
||||
}
|
||||
|
||||
fn tasks_dir() -> io::Result<PathBuf> {
|
||||
Ok(state_root()?.join("runtime-tasks"))
|
||||
}
|
||||
|
||||
fn task_path(task_id: &str) -> io::Result<PathBuf> {
|
||||
Ok(tasks_dir()?.join(format!(
|
||||
"{}.json",
|
||||
sanitize_state_component(task_id)
|
||||
)))
|
||||
}
|
||||
|
||||
fn make_runtime_task_id(prefix: &str) -> String {
|
||||
format!("{prefix}-{}", now_secs())
|
||||
}
|
||||
|
||||
fn refresh_record(record: &mut RuntimeTaskRecord) -> io::Result<()> {
|
||||
if record.status.is_terminal() || record.kind != RuntimeTaskKind::Shell {
|
||||
return Ok(());
|
||||
}
|
||||
let Some(exit_code_path) = record.exit_code_file.as_deref() else {
|
||||
return Ok(());
|
||||
};
|
||||
if let Ok(contents) = fs::read_to_string(exit_code_path) {
|
||||
if let Ok(code) = contents.trim().parse::<i32>() {
|
||||
record.exit_code = Some(code);
|
||||
record.status = if code == 0 {
|
||||
RuntimeTaskStatus::Completed
|
||||
} else {
|
||||
RuntimeTaskStatus::Failed
|
||||
};
|
||||
record.completed_at = Some(now_secs());
|
||||
let payload = serde_json::to_vec_pretty(record)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
fs::write(task_path(&record.task_id)?, payload)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn stop_pid(pid: u32) -> io::Result<()> {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
let status = Command::new("kill")
|
||||
.arg("-TERM")
|
||||
.arg(pid.to_string())
|
||||
.status()?;
|
||||
if status.success() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(io::Error::other(format!("failed to stop pid {pid}")))
|
||||
}
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
let _ = pid;
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::Unsupported,
|
||||
"runtime task stopping is only supported on unix",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn background_output_paths(task_id: &str) -> io::Result<(PathBuf, PathBuf)> {
|
||||
let root = state_root()?.join("runtime-tasks").join("outputs");
|
||||
fs::create_dir_all(&root)?;
|
||||
let safe = sanitize_state_component(task_id);
|
||||
Ok((root.join(format!("{safe}.log")), root.join(format!("{safe}.exit"))))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs;
|
||||
use std::sync::{atomic::AtomicBool, Arc};
|
||||
|
||||
use super::{background_output_paths, RuntimeTaskStatus, RuntimeTaskStore};
|
||||
use crate::{register_task_cancel_flag, unregister_task_cancel_flag};
|
||||
use crate::test_env_lock;
|
||||
|
||||
#[test]
|
||||
fn shell_runtime_tasks_refresh_from_exit_file() {
|
||||
let _lock = test_env_lock();
|
||||
let root = std::env::temp_dir().join("runtime-task-store-tests");
|
||||
let _ = fs::remove_dir_all(&root);
|
||||
std::env::set_var("CLAW_WORKER_STATE_ROOT", &root);
|
||||
|
||||
let store = RuntimeTaskStore::new();
|
||||
let (output_path, exit_path) = background_output_paths("shell-1").expect("paths");
|
||||
fs::write(&output_path, "hello").expect("write output");
|
||||
let record = store
|
||||
.create_shell_task(
|
||||
"Run tests".to_string(),
|
||||
"cargo test".to_string(),
|
||||
123,
|
||||
output_path.clone(),
|
||||
exit_path.clone(),
|
||||
None,
|
||||
Some("/tmp/runtime-shell".to_string()),
|
||||
)
|
||||
.expect("create task");
|
||||
fs::write(&exit_path, "0").expect("write exit code");
|
||||
let output = store
|
||||
.output(&record.task_id, false, None)
|
||||
.expect("load output")
|
||||
.expect("task exists");
|
||||
assert_eq!(output.task.status, RuntimeTaskStatus::Completed);
|
||||
assert_eq!(output.output, "hello");
|
||||
assert_eq!(output.task.cwd.as_deref(), Some("/tmp/runtime-shell"));
|
||||
|
||||
let _ = fs::remove_dir_all(&root);
|
||||
std::env::remove_var("CLAW_WORKER_STATE_ROOT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agent_runtime_tasks_can_be_cancelled() {
|
||||
let _lock = test_env_lock();
|
||||
let root = std::env::temp_dir().join("runtime-task-store-agent-stop");
|
||||
let _ = fs::remove_dir_all(&root);
|
||||
std::env::set_var("CLAW_WORKER_STATE_ROOT", &root);
|
||||
|
||||
let store = RuntimeTaskStore::new();
|
||||
let record = store
|
||||
.create_agent_task(
|
||||
"agent-stop-1".to_string(),
|
||||
"agent-stop".to_string(),
|
||||
"Stop the agent".to_string(),
|
||||
"Run until cancelled".to_string(),
|
||||
root.join("agent.md").display().to_string(),
|
||||
None,
|
||||
Some("/tmp/runtime-agent".to_string()),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.expect("create agent task");
|
||||
register_task_cancel_flag(&record.task_id, Arc::new(AtomicBool::new(false)));
|
||||
|
||||
let stopped = store
|
||||
.stop(&record.task_id)
|
||||
.expect("stop should succeed")
|
||||
.expect("task should exist");
|
||||
assert_eq!(stopped.status, RuntimeTaskStatus::Stopped);
|
||||
|
||||
let persisted = store
|
||||
.get(&record.task_id)
|
||||
.expect("reload task")
|
||||
.expect("task exists");
|
||||
assert_eq!(persisted.status, RuntimeTaskStatus::Stopped);
|
||||
assert_eq!(persisted.cwd.as_deref(), Some("/tmp/runtime-agent"));
|
||||
|
||||
let _ = unregister_task_cancel_flag(&record.task_id);
|
||||
let _ = fs::remove_dir_all(&root);
|
||||
std::env::remove_var("CLAW_WORKER_STATE_ROOT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn runtime_tasks_can_be_marked_notified() {
|
||||
let _lock = test_env_lock();
|
||||
let root = std::env::temp_dir().join("runtime-task-store-notified");
|
||||
let _ = fs::remove_dir_all(&root);
|
||||
std::env::set_var("CLAW_WORKER_STATE_ROOT", &root);
|
||||
|
||||
let store = RuntimeTaskStore::new();
|
||||
let record = store
|
||||
.create_agent_task(
|
||||
"agent-notify-1".to_string(),
|
||||
"agent-notify".to_string(),
|
||||
"Notify me".to_string(),
|
||||
"Finish the work".to_string(),
|
||||
root.join("agent.md").display().to_string(),
|
||||
None,
|
||||
Some("/tmp/runtime-agent".to_string()),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.expect("create agent task");
|
||||
|
||||
let notified = store
|
||||
.mark_notified(&record.task_id)
|
||||
.expect("mark notified should succeed")
|
||||
.expect("task should exist");
|
||||
assert!(notified.notified);
|
||||
|
||||
let persisted = store
|
||||
.get(&record.task_id)
|
||||
.expect("reload task")
|
||||
.expect("task exists");
|
||||
assert!(persisted.notified);
|
||||
|
||||
let _ = fs::remove_dir_all(&root);
|
||||
std::env::remove_var("CLAW_WORKER_STATE_ROOT");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc, Mutex, OnceLock,
|
||||
};
|
||||
|
||||
fn registry() -> &'static Mutex<BTreeMap<String, Arc<AtomicBool>>> {
|
||||
static REGISTRY: OnceLock<Mutex<BTreeMap<String, Arc<AtomicBool>>>> = OnceLock::new();
|
||||
REGISTRY.get_or_init(|| Mutex::new(BTreeMap::new()))
|
||||
}
|
||||
|
||||
pub fn register_task_cancel_flag(
|
||||
task_id: &str,
|
||||
flag: Arc<AtomicBool>,
|
||||
) -> Option<Arc<AtomicBool>> {
|
||||
registry()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.insert(task_id.to_string(), flag)
|
||||
}
|
||||
|
||||
pub fn unregister_task_cancel_flag(task_id: &str) -> Option<Arc<AtomicBool>> {
|
||||
registry()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.remove(task_id)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn signal_task_cancel(task_id: &str) -> bool {
|
||||
let flag = registry()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.get(task_id)
|
||||
.cloned();
|
||||
if let Some(flag) = flag {
|
||||
flag.store(true, Ordering::SeqCst);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_task_cancelled(task_id: &str) -> bool {
|
||||
registry()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.get(task_id)
|
||||
.is_some_and(|flag| flag.load(Ordering::SeqCst))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::{atomic::AtomicBool, Arc};
|
||||
|
||||
use super::{
|
||||
is_task_cancelled, register_task_cancel_flag, signal_task_cancel,
|
||||
unregister_task_cancel_flag,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn cancel_flags_round_trip_through_registry() {
|
||||
let task_id = "task-cancel-registry";
|
||||
let _ = unregister_task_cancel_flag(task_id);
|
||||
assert!(!is_task_cancelled(task_id));
|
||||
|
||||
register_task_cancel_flag(task_id, Arc::new(AtomicBool::new(false)));
|
||||
assert!(signal_task_cancel(task_id));
|
||||
assert!(is_task_cancelled(task_id));
|
||||
|
||||
let removed = unregister_task_cancel_flag(task_id);
|
||||
assert!(removed.is_some());
|
||||
assert!(!signal_task_cancel(task_id));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,369 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::workflow_state::{
|
||||
current_task_list_id, now_secs, sanitize_state_component, state_root,
|
||||
};
|
||||
|
||||
fn store_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TaskListStatus {
|
||||
Pending,
|
||||
InProgress,
|
||||
Completed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TaskListStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Pending => write!(f, "pending"),
|
||||
Self::InProgress => write!(f, "in_progress"),
|
||||
Self::Completed => write!(f, "completed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct TaskListRecord {
|
||||
pub id: String,
|
||||
pub subject: String,
|
||||
pub description: String,
|
||||
#[serde(rename = "activeForm", default, skip_serializing_if = "Option::is_none")]
|
||||
pub active_form: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub owner: Option<String>,
|
||||
pub status: TaskListStatus,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub blocks: Vec<String>,
|
||||
#[serde(rename = "blockedBy", default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub blocked_by: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<BTreeMap<String, Value>>,
|
||||
#[serde(default)]
|
||||
pub internal: bool,
|
||||
#[serde(rename = "createdAt")]
|
||||
pub created_at: u64,
|
||||
#[serde(rename = "updatedAt")]
|
||||
pub updated_at: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
|
||||
pub struct TaskListPatch {
|
||||
#[serde(default)]
|
||||
pub subject: Option<String>,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
#[serde(rename = "activeForm", default)]
|
||||
pub active_form: Option<String>,
|
||||
#[serde(default)]
|
||||
pub status: Option<TaskListStatus>,
|
||||
#[serde(rename = "addBlocks", default)]
|
||||
pub add_blocks: Vec<String>,
|
||||
#[serde(rename = "addBlockedBy", default)]
|
||||
pub add_blocked_by: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub owner: Option<String>,
|
||||
#[serde(default)]
|
||||
pub metadata: Option<BTreeMap<String, Value>>,
|
||||
#[serde(default)]
|
||||
pub internal: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TaskListStore {
|
||||
task_list_id: String,
|
||||
}
|
||||
|
||||
impl TaskListStore {
|
||||
pub fn current() -> io::Result<Self> {
|
||||
Ok(Self {
|
||||
task_list_id: current_task_list_id()?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn for_task_list(task_list_id: impl Into<String>) -> Self {
|
||||
Self {
|
||||
task_list_id: sanitize_state_component(&task_list_id.into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn task_list_id(&self) -> &str {
|
||||
&self.task_list_id
|
||||
}
|
||||
|
||||
fn tasks_dir(&self) -> io::Result<PathBuf> {
|
||||
Ok(state_root()?.join("tasks").join(&self.task_list_id))
|
||||
}
|
||||
|
||||
fn task_path(&self, task_id: &str) -> io::Result<PathBuf> {
|
||||
Ok(self.tasks_dir()?.join(format!(
|
||||
"{}.json",
|
||||
sanitize_state_component(task_id)
|
||||
)))
|
||||
}
|
||||
|
||||
fn high_water_mark_path(&self) -> io::Result<PathBuf> {
|
||||
Ok(self.tasks_dir()?.join(".highwatermark"))
|
||||
}
|
||||
|
||||
pub fn create(
|
||||
&self,
|
||||
subject: String,
|
||||
description: String,
|
||||
active_form: Option<String>,
|
||||
metadata: Option<BTreeMap<String, Value>>,
|
||||
) -> io::Result<TaskListRecord> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let task_dir = self.tasks_dir()?;
|
||||
fs::create_dir_all(&task_dir)?;
|
||||
let next_id = self.next_task_id_locked()?;
|
||||
let now = now_secs();
|
||||
let record = TaskListRecord {
|
||||
id: next_id.clone(),
|
||||
subject,
|
||||
description,
|
||||
active_form,
|
||||
owner: None,
|
||||
status: TaskListStatus::Pending,
|
||||
blocks: Vec::new(),
|
||||
blocked_by: Vec::new(),
|
||||
metadata,
|
||||
internal: false,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
};
|
||||
self.write_record_locked(&record)?;
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
pub fn get(&self, task_id: &str) -> io::Result<Option<TaskListRecord>> {
|
||||
let path = self.task_path(task_id)?;
|
||||
read_record(&path)
|
||||
}
|
||||
|
||||
pub fn list(&self, include_internal: bool) -> io::Result<Vec<TaskListRecord>> {
|
||||
let mut records = Vec::new();
|
||||
let dir = self.tasks_dir()?;
|
||||
let entries = match fs::read_dir(dir) {
|
||||
Ok(entries) => entries,
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => return Ok(records),
|
||||
Err(error) => return Err(error),
|
||||
};
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|value| value.to_str()) != Some("json") {
|
||||
continue;
|
||||
}
|
||||
if let Some(record) = read_record(&path)? {
|
||||
if include_internal || !record.internal {
|
||||
records.push(record);
|
||||
}
|
||||
}
|
||||
}
|
||||
records.sort_by_key(|record| record.id.parse::<u64>().unwrap_or(0));
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
pub fn update(&self, task_id: &str, patch: TaskListPatch) -> io::Result<Option<TaskListRecord>> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let Some(mut existing) = self.get(task_id)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
if let Some(subject) = patch.subject {
|
||||
existing.subject = subject;
|
||||
}
|
||||
if let Some(description) = patch.description {
|
||||
existing.description = description;
|
||||
}
|
||||
if let Some(active_form) = patch.active_form {
|
||||
existing.active_form = Some(active_form);
|
||||
}
|
||||
if let Some(status) = patch.status {
|
||||
existing.status = status;
|
||||
}
|
||||
if let Some(owner) = patch.owner {
|
||||
existing.owner = if owner.trim().is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(owner)
|
||||
};
|
||||
}
|
||||
if let Some(metadata) = patch.metadata {
|
||||
existing.metadata = Some(metadata);
|
||||
}
|
||||
if let Some(internal) = patch.internal {
|
||||
existing.internal = internal;
|
||||
}
|
||||
merge_unique(&mut existing.blocks, patch.add_blocks);
|
||||
merge_unique(&mut existing.blocked_by, patch.add_blocked_by);
|
||||
existing.updated_at = now_secs();
|
||||
self.write_record_locked(&existing)?;
|
||||
Ok(Some(existing))
|
||||
}
|
||||
|
||||
pub fn delete(&self, task_id: &str) -> io::Result<bool> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let path = self.task_path(task_id)?;
|
||||
match fs::remove_file(&path) {
|
||||
Ok(()) => {
|
||||
for mut other in self.list(true)? {
|
||||
let original_blocks = other.blocks.len();
|
||||
let original_blocked_by = other.blocked_by.len();
|
||||
other.blocks.retain(|value| value != task_id);
|
||||
other.blocked_by.retain(|value| value != task_id);
|
||||
if other.blocks.len() != original_blocks
|
||||
|| other.blocked_by.len() != original_blocked_by
|
||||
{
|
||||
other.updated_at = now_secs();
|
||||
self.write_record_locked(&other)?;
|
||||
}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(false),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reset(&self) -> io::Result<()> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let dir = self.tasks_dir()?;
|
||||
let existing = self.list(true)?;
|
||||
let max_id = existing
|
||||
.iter()
|
||||
.filter_map(|task| task.id.parse::<u64>().ok())
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
fs::create_dir_all(&dir)?;
|
||||
for entry in fs::read_dir(&dir)? {
|
||||
let path = entry?.path();
|
||||
if path.extension().and_then(|value| value.to_str()) == Some("json") {
|
||||
let _ = fs::remove_file(path);
|
||||
}
|
||||
}
|
||||
if max_id > 0 {
|
||||
fs::write(self.high_water_mark_path()?, max_id.to_string())?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn next_task_id_locked(&self) -> io::Result<String> {
|
||||
let existing_max = self
|
||||
.list(true)?
|
||||
.into_iter()
|
||||
.filter_map(|task| task.id.parse::<u64>().ok())
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
let high_water_mark = fs::read_to_string(self.high_water_mark_path()?)
|
||||
.ok()
|
||||
.and_then(|value| value.trim().parse::<u64>().ok())
|
||||
.unwrap_or(0);
|
||||
let next = existing_max.max(high_water_mark) + 1;
|
||||
fs::write(self.high_water_mark_path()?, next.to_string())?;
|
||||
Ok(next.to_string())
|
||||
}
|
||||
|
||||
fn write_record_locked(&self, record: &TaskListRecord) -> io::Result<()> {
|
||||
let path = self.task_path(&record.id)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
let payload = serde_json::to_vec_pretty(record)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
fs::write(path, payload)
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_unique(target: &mut Vec<String>, additions: Vec<String>) {
|
||||
for value in additions {
|
||||
if !target.contains(&value) {
|
||||
target.push(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn read_record(path: &Path) -> io::Result<Option<TaskListRecord>> {
|
||||
match fs::read_to_string(path) {
|
||||
Ok(contents) => serde_json::from_str(&contents)
|
||||
.map(Some)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)),
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(None),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use super::{TaskListPatch, TaskListStatus, TaskListStore};
|
||||
use crate::test_env_lock;
|
||||
|
||||
#[test]
|
||||
fn create_update_and_delete_task_records() {
|
||||
let _lock = test_env_lock();
|
||||
let root = std::env::temp_dir().join("task-list-store-tests");
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
std::env::set_var("CLAW_WORKER_STATE_ROOT", &root);
|
||||
std::env::set_var("CLAW_WORKER_PROFILE_ID", "makar");
|
||||
|
||||
let store = TaskListStore::for_task_list("alpha");
|
||||
let created = store
|
||||
.create(
|
||||
"Investigate".to_string(),
|
||||
"Check the failing worker".to_string(),
|
||||
Some("Investigating".to_string()),
|
||||
Some(BTreeMap::from([("priority".to_string(), json!("high"))])),
|
||||
)
|
||||
.expect("task creates");
|
||||
assert_eq!(created.id, "1");
|
||||
|
||||
let updated = store
|
||||
.update(
|
||||
&created.id,
|
||||
TaskListPatch {
|
||||
status: Some(TaskListStatus::InProgress),
|
||||
owner: Some("agent-lead".to_string()),
|
||||
add_blocked_by: vec!["7".to_string()],
|
||||
..TaskListPatch::default()
|
||||
},
|
||||
)
|
||||
.expect("task updates")
|
||||
.expect("task exists");
|
||||
assert_eq!(updated.status, TaskListStatus::InProgress);
|
||||
assert_eq!(updated.owner.as_deref(), Some("agent-lead"));
|
||||
assert_eq!(updated.blocked_by, vec!["7"]);
|
||||
|
||||
let listed = store.list(false).expect("tasks list");
|
||||
assert_eq!(listed.len(), 1);
|
||||
assert!(store.delete(&created.id).expect("delete succeeds"));
|
||||
assert!(store.list(false).expect("tasks list").is_empty());
|
||||
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
std::env::remove_var("CLAW_WORKER_STATE_ROOT");
|
||||
std::env::remove_var("CLAW_WORKER_PROFILE_ID");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,501 @@
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::task_list_store::TaskListStore;
|
||||
use crate::workflow_state::{
|
||||
clear_team_context, now_secs, persist_team_context, sanitize_state_component, state_root,
|
||||
TeamContext,
|
||||
};
|
||||
|
||||
fn store_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct TeamMemberRecord {
|
||||
pub agent_id: String,
|
||||
pub name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub agent_type: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub status: Option<String>,
|
||||
pub joined_at: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct TeamRecord {
|
||||
pub team_name: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub agent_type: Option<String>,
|
||||
pub lead_agent_id: String,
|
||||
pub created_at: u64,
|
||||
pub updated_at: u64,
|
||||
#[serde(default)]
|
||||
pub deleted: bool,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub members: Vec<TeamMemberRecord>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MessageEnvelope {
|
||||
pub id: String,
|
||||
pub from: String,
|
||||
pub to: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub summary: Option<String>,
|
||||
pub message: Value,
|
||||
pub timestamp: u64,
|
||||
#[serde(default)]
|
||||
pub read: bool,
|
||||
#[serde(default)]
|
||||
pub notified: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MailboxMessage {
|
||||
pub recipient: String,
|
||||
pub envelope: MessageEnvelope,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MailboxSummary {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub team_name: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub recent_messages: Vec<MailboxMessage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TeamStore;
|
||||
|
||||
impl TeamStore {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
pub fn create_team(
|
||||
&self,
|
||||
requested_name: &str,
|
||||
description: Option<String>,
|
||||
agent_type: Option<String>,
|
||||
) -> io::Result<TeamRecord> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let final_name = self.unique_team_name(requested_name)?;
|
||||
let now = now_secs();
|
||||
let record = TeamRecord {
|
||||
team_name: final_name.clone(),
|
||||
description: description.clone(),
|
||||
agent_type: agent_type.clone(),
|
||||
lead_agent_id: format!("lead@{final_name}"),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
deleted: false,
|
||||
members: vec![TeamMemberRecord {
|
||||
agent_id: format!("lead@{final_name}"),
|
||||
name: "lead".to_string(),
|
||||
agent_type: agent_type.clone(),
|
||||
model: None,
|
||||
status: Some("active".to_string()),
|
||||
joined_at: now,
|
||||
}],
|
||||
};
|
||||
self.write_team_locked(&record)?;
|
||||
persist_team_context(&TeamContext {
|
||||
team_name: record.team_name.clone(),
|
||||
lead_agent_id: record.lead_agent_id.clone(),
|
||||
description,
|
||||
agent_type,
|
||||
task_list_id: record.team_name.clone(),
|
||||
created_at: now,
|
||||
})?;
|
||||
TaskListStore::for_task_list(record.team_name.clone()).reset()?;
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
pub fn delete_team(&self, team_name: &str) -> io::Result<Option<TeamRecord>> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let Some(mut record) = self.get_team(team_name)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
record.deleted = true;
|
||||
record.updated_at = now_secs();
|
||||
self.write_team_locked(&record)?;
|
||||
if record.team_name == team_name {
|
||||
let _ = clear_team_context();
|
||||
}
|
||||
Ok(Some(record))
|
||||
}
|
||||
|
||||
pub fn current_team(&self) -> io::Result<Option<TeamRecord>> {
|
||||
let Some(context) = crate::workflow_state::load_team_context()? else {
|
||||
return Ok(None);
|
||||
};
|
||||
self.get_team(&context.team_name)
|
||||
}
|
||||
|
||||
pub fn get_team(&self, team_name: &str) -> io::Result<Option<TeamRecord>> {
|
||||
let path = team_path(team_name)?;
|
||||
match fs::read_to_string(path) {
|
||||
Ok(contents) => serde_json::from_str(&contents)
|
||||
.map(Some)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)),
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(None),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn upsert_member(
|
||||
&self,
|
||||
team_name: &str,
|
||||
member: TeamMemberRecord,
|
||||
) -> io::Result<Option<TeamRecord>> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let Some(mut record) = self.get_team(team_name)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
if let Some(existing) = record.members.iter_mut().find(|value| value.name == member.name) {
|
||||
*existing = member;
|
||||
} else {
|
||||
record.members.push(member);
|
||||
}
|
||||
record.updated_at = now_secs();
|
||||
self.write_team_locked(&record)?;
|
||||
Ok(Some(record))
|
||||
}
|
||||
|
||||
pub fn set_member_status(
|
||||
&self,
|
||||
team_name: &str,
|
||||
member_name: &str,
|
||||
status: &str,
|
||||
) -> io::Result<Option<TeamRecord>> {
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let Some(mut record) = self.get_team(team_name)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
let Some(member) = record.members.iter_mut().find(|value| value.name == member_name) else {
|
||||
return Ok(None);
|
||||
};
|
||||
member.status = Some(status.to_string());
|
||||
record.updated_at = now_secs();
|
||||
self.write_team_locked(&record)?;
|
||||
Ok(Some(record))
|
||||
}
|
||||
|
||||
pub fn send_message(
|
||||
&self,
|
||||
team_name: &str,
|
||||
from: &str,
|
||||
to: &str,
|
||||
summary: Option<String>,
|
||||
message: Value,
|
||||
) -> io::Result<Vec<MailboxMessage>> {
|
||||
let recipients = if to == "*" {
|
||||
self.get_team(team_name)?
|
||||
.map(|team| {
|
||||
team.members
|
||||
.into_iter()
|
||||
.filter(|member| member.name != from)
|
||||
.map(|member| member.name)
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
vec![to.to_string()]
|
||||
};
|
||||
let mut written = Vec::new();
|
||||
for recipient in recipients {
|
||||
let envelope = MessageEnvelope {
|
||||
id: format!("msg-{}-{}", now_secs(), sanitize_state_component(&recipient)),
|
||||
from: from.to_string(),
|
||||
to: recipient.clone(),
|
||||
summary: summary.clone(),
|
||||
message: message.clone(),
|
||||
timestamp: now_secs(),
|
||||
read: false,
|
||||
notified: false,
|
||||
};
|
||||
self.append_mailbox(team_name, &recipient, &envelope)?;
|
||||
written.push(MailboxMessage {
|
||||
recipient,
|
||||
envelope,
|
||||
});
|
||||
}
|
||||
Ok(written)
|
||||
}
|
||||
|
||||
pub fn mailbox_summary(&self, team_name: &str, limit: usize) -> io::Result<MailboxSummary> {
|
||||
let inbox_root = mailboxes_dir(team_name)?;
|
||||
let entries = match fs::read_dir(&inbox_root) {
|
||||
Ok(entries) => entries,
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => {
|
||||
return Ok(MailboxSummary {
|
||||
team_name: Some(team_name.to_string()),
|
||||
recent_messages: Vec::new(),
|
||||
})
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
};
|
||||
let mut messages = Vec::new();
|
||||
for entry in entries {
|
||||
let path = entry?.path();
|
||||
if path.extension().and_then(|value| value.to_str()) != Some("json") {
|
||||
continue;
|
||||
}
|
||||
let recipient = path
|
||||
.file_stem()
|
||||
.and_then(|value| value.to_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
let contents = fs::read_to_string(&path)?;
|
||||
let inbox = serde_json::from_str::<Vec<MessageEnvelope>>(&contents)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
for envelope in inbox {
|
||||
messages.push(MailboxMessage {
|
||||
recipient: recipient.clone(),
|
||||
envelope,
|
||||
});
|
||||
}
|
||||
}
|
||||
messages.sort_by(|left, right| left.envelope.timestamp.cmp(&right.envelope.timestamp));
|
||||
if messages.len() > limit {
|
||||
messages = messages.split_off(messages.len() - limit);
|
||||
}
|
||||
Ok(MailboxSummary {
|
||||
team_name: Some(team_name.to_string()),
|
||||
recent_messages: messages,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn pending_messages(
|
||||
&self,
|
||||
team_name: &str,
|
||||
recipient: &str,
|
||||
limit: usize,
|
||||
) -> io::Result<Vec<MailboxMessage>> {
|
||||
let inbox = self.read_inbox(team_name, recipient)?;
|
||||
let mut messages = inbox
|
||||
.into_iter()
|
||||
.filter(|envelope| !envelope.notified)
|
||||
.map(|envelope| MailboxMessage {
|
||||
recipient: recipient.to_string(),
|
||||
envelope,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
messages.sort_by(|left, right| left.envelope.timestamp.cmp(&right.envelope.timestamp));
|
||||
if messages.len() > limit {
|
||||
messages = messages.split_off(messages.len() - limit);
|
||||
}
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
pub fn mark_messages_notified(
|
||||
&self,
|
||||
team_name: &str,
|
||||
recipient: &str,
|
||||
ids: &[String],
|
||||
) -> io::Result<usize> {
|
||||
if ids.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
let _lock = store_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let mut inbox = self.read_inbox(team_name, recipient)?;
|
||||
let ids = ids.iter().cloned().collect::<std::collections::BTreeSet<_>>();
|
||||
let mut changed = 0usize;
|
||||
for envelope in &mut inbox {
|
||||
if ids.contains(&envelope.id) && !envelope.notified {
|
||||
envelope.notified = true;
|
||||
changed += 1;
|
||||
}
|
||||
}
|
||||
if changed > 0 {
|
||||
self.write_inbox(team_name, recipient, &inbox)?;
|
||||
}
|
||||
Ok(changed)
|
||||
}
|
||||
|
||||
fn write_team_locked(&self, team: &TeamRecord) -> io::Result<()> {
|
||||
let path = team_path(&team.team_name)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
let payload = serde_json::to_vec_pretty(team)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
fs::write(path, payload)
|
||||
}
|
||||
|
||||
fn append_mailbox(
|
||||
&self,
|
||||
team_name: &str,
|
||||
recipient: &str,
|
||||
envelope: &MessageEnvelope,
|
||||
) -> io::Result<()> {
|
||||
let mut existing = self.read_inbox(team_name, recipient)?;
|
||||
existing.push(envelope.clone());
|
||||
self.write_inbox(team_name, recipient, &existing)
|
||||
}
|
||||
|
||||
fn read_inbox(&self, team_name: &str, recipient: &str) -> io::Result<Vec<MessageEnvelope>> {
|
||||
let path = mailbox_path(team_name, recipient)?;
|
||||
match fs::read_to_string(&path) {
|
||||
Ok(contents) => serde_json::from_str::<Vec<MessageEnvelope>>(&contents)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)),
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(Vec::new()),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_inbox(
|
||||
&self,
|
||||
team_name: &str,
|
||||
recipient: &str,
|
||||
inbox: &[MessageEnvelope],
|
||||
) -> io::Result<()> {
|
||||
let path = mailbox_path(team_name, recipient)?;
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
let payload = serde_json::to_vec_pretty(inbox)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
fs::write(path, payload)
|
||||
}
|
||||
|
||||
fn unique_team_name(&self, requested_name: &str) -> io::Result<String> {
|
||||
let requested_name = sanitize_state_component(requested_name);
|
||||
if self.get_team(&requested_name)?.is_none() {
|
||||
return Ok(requested_name);
|
||||
}
|
||||
for index in 2..=128 {
|
||||
let candidate = format!("{requested_name}-{index}");
|
||||
if self.get_team(&candidate)?.is_none() {
|
||||
return Ok(candidate);
|
||||
}
|
||||
}
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::AlreadyExists,
|
||||
"failed to allocate a unique team name",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn teams_dir() -> io::Result<PathBuf> {
|
||||
Ok(state_root()?.join("teams"))
|
||||
}
|
||||
|
||||
fn team_path(team_name: &str) -> io::Result<PathBuf> {
|
||||
Ok(teams_dir()?.join(sanitize_state_component(team_name)).join("config.json"))
|
||||
}
|
||||
|
||||
fn mailboxes_dir(team_name: &str) -> io::Result<PathBuf> {
|
||||
Ok(state_root()?
|
||||
.join("mailbox")
|
||||
.join(sanitize_state_component(team_name)))
|
||||
}
|
||||
|
||||
fn mailbox_path(team_name: &str, recipient: &str) -> io::Result<PathBuf> {
|
||||
Ok(mailboxes_dir(team_name)?.join(format!(
|
||||
"{}.json",
|
||||
sanitize_state_component(recipient)
|
||||
)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
use super::{TeamMemberRecord, TeamStore};
|
||||
use crate::test_env_lock;
|
||||
|
||||
#[test]
|
||||
fn team_creation_assigns_unique_names_and_mailbox_messages_round_trip() {
|
||||
let _lock = test_env_lock();
|
||||
let root = std::env::temp_dir().join("team-store-tests");
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
std::env::set_var("CLAW_WORKER_STATE_ROOT", &root);
|
||||
|
||||
let store = TeamStore::new();
|
||||
let first = store
|
||||
.create_team("alpha", Some("Team Alpha".to_string()), Some("researcher".to_string()))
|
||||
.expect("team creates");
|
||||
let second = store
|
||||
.create_team("alpha", None, None)
|
||||
.expect("second team creates");
|
||||
assert_eq!(first.team_name, "alpha");
|
||||
assert_eq!(second.team_name, "alpha-2");
|
||||
store
|
||||
.upsert_member(
|
||||
&first.team_name,
|
||||
TeamMemberRecord {
|
||||
agent_id: "agent-1".to_string(),
|
||||
name: "alice".to_string(),
|
||||
agent_type: Some("researcher".to_string()),
|
||||
model: Some("claude-opus-4-6".to_string()),
|
||||
status: Some("active".to_string()),
|
||||
joined_at: 1,
|
||||
},
|
||||
)
|
||||
.expect("member upserts");
|
||||
let messages = store
|
||||
.send_message(&first.team_name, "lead", "alice", Some("hello".to_string()), json!("hi"))
|
||||
.expect("message sends");
|
||||
assert_eq!(messages.len(), 1);
|
||||
let summary = store
|
||||
.mailbox_summary(&first.team_name, 10)
|
||||
.expect("summary loads");
|
||||
assert_eq!(summary.recent_messages.len(), 1);
|
||||
let pending = store
|
||||
.pending_messages(&first.team_name, "alice", 10)
|
||||
.expect("pending mailbox should load");
|
||||
assert_eq!(pending.len(), 1);
|
||||
let notified = store
|
||||
.mark_messages_notified(
|
||||
&first.team_name,
|
||||
"alice",
|
||||
&[pending[0].envelope.id.clone()],
|
||||
)
|
||||
.expect("mark notified should work");
|
||||
assert_eq!(notified, 1);
|
||||
assert!(
|
||||
store
|
||||
.pending_messages(&first.team_name, "alice", 10)
|
||||
.expect("pending mailbox should load")
|
||||
.is_empty()
|
||||
);
|
||||
let updated = store
|
||||
.set_member_status(&first.team_name, "alice", "completed")
|
||||
.expect("status update should succeed")
|
||||
.expect("team should exist");
|
||||
assert_eq!(
|
||||
updated
|
||||
.members
|
||||
.iter()
|
||||
.find(|member| member.name == "alice")
|
||||
.and_then(|member| member.status.as_deref()),
|
||||
Some("completed")
|
||||
);
|
||||
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
std::env::remove_var("CLAW_WORKER_STATE_ROOT");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct TeamContext {
|
||||
pub team_name: String,
|
||||
pub lead_agent_id: String,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
#[serde(default)]
|
||||
pub agent_type: Option<String>,
|
||||
pub task_list_id: String,
|
||||
pub created_at: u64,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn sanitize_state_component(value: &str) -> String {
|
||||
let sanitized = value
|
||||
.trim()
|
||||
.chars()
|
||||
.map(|ch| {
|
||||
if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_') {
|
||||
ch
|
||||
} else {
|
||||
'-'
|
||||
}
|
||||
})
|
||||
.collect::<String>();
|
||||
let collapsed = sanitized
|
||||
.split('-')
|
||||
.filter(|segment| !segment.is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
.join("-");
|
||||
if collapsed.is_empty() {
|
||||
"default".to_string()
|
||||
} else {
|
||||
collapsed
|
||||
}
|
||||
}
|
||||
|
||||
pub fn state_root() -> io::Result<PathBuf> {
|
||||
if let Some(root) = env::var_os("CLAW_WORKER_STATE_ROOT")
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(PathBuf::from)
|
||||
{
|
||||
return Ok(root);
|
||||
}
|
||||
if let Some(root) = env::var_os("CLAWD_STATE_ROOT")
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(PathBuf::from)
|
||||
{
|
||||
return Ok(root);
|
||||
}
|
||||
Ok(env::current_dir()?.join(".clawd-state"))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn default_session_identity() -> String {
|
||||
env::var("CLAW_WORKER_PROFILE_ID")
|
||||
.ok()
|
||||
.filter(|value| !value.trim().is_empty())
|
||||
.or_else(|| {
|
||||
env::current_dir()
|
||||
.ok()
|
||||
.and_then(|cwd| cwd.file_name().map(|value| value.to_string_lossy().to_string()))
|
||||
})
|
||||
.map(|value| sanitize_state_component(&value))
|
||||
.unwrap_or_else(|| "default".to_string())
|
||||
}
|
||||
|
||||
fn team_context_path(root: &Path) -> PathBuf {
|
||||
root.join("session").join("team-context.json")
|
||||
}
|
||||
|
||||
pub fn load_team_context() -> io::Result<Option<TeamContext>> {
|
||||
let path = team_context_path(&state_root()?);
|
||||
match fs::read_to_string(path) {
|
||||
Ok(contents) => serde_json::from_str(&contents)
|
||||
.map(Some)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)),
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(None),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn persist_team_context(context: &TeamContext) -> io::Result<()> {
|
||||
let root = state_root()?;
|
||||
let path = team_context_path(&root);
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
let serialized = serde_json::to_vec_pretty(context)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
fs::write(path, serialized)
|
||||
}
|
||||
|
||||
pub fn clear_team_context() -> io::Result<()> {
|
||||
let path = team_context_path(&state_root()?);
|
||||
match fs::remove_file(path) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(()),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn current_task_list_id() -> io::Result<String> {
|
||||
if let Some(explicit) = env::var_os("CLAW_CODE_TASK_LIST_ID")
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(|value| sanitize_state_component(&value.to_string_lossy()))
|
||||
{
|
||||
return Ok(explicit);
|
||||
}
|
||||
if let Some(team) = load_team_context()? {
|
||||
return Ok(sanitize_state_component(&team.task_list_id));
|
||||
}
|
||||
Ok(default_session_identity())
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
clear_team_context, current_task_list_id, default_session_identity, load_team_context,
|
||||
persist_team_context, sanitize_state_component, state_root, TeamContext,
|
||||
};
|
||||
use crate::test_env_lock;
|
||||
|
||||
#[test]
|
||||
fn sanitize_state_component_removes_path_chars() {
|
||||
assert_eq!(sanitize_state_component("../Team Alpha"), "Team-Alpha");
|
||||
assert_eq!(sanitize_state_component(""), "default");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn team_context_round_trips() {
|
||||
let _lock = test_env_lock();
|
||||
let root = std::env::temp_dir().join("workflow-state-roundtrip");
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
std::env::set_var("CLAW_WORKER_STATE_ROOT", &root);
|
||||
std::env::set_var("CLAW_WORKER_PROFILE_ID", "makar");
|
||||
|
||||
let context = TeamContext {
|
||||
team_name: "alpha".to_string(),
|
||||
lead_agent_id: "agent-1".to_string(),
|
||||
description: Some("test".to_string()),
|
||||
agent_type: Some("researcher".to_string()),
|
||||
task_list_id: "alpha".to_string(),
|
||||
created_at: 1,
|
||||
};
|
||||
persist_team_context(&context).expect("context persists");
|
||||
let loaded = load_team_context()
|
||||
.expect("context loads")
|
||||
.expect("context should exist");
|
||||
assert_eq!(loaded, context);
|
||||
assert_eq!(current_task_list_id().expect("list id"), "alpha");
|
||||
clear_team_context().expect("context clears");
|
||||
assert!(load_team_context().expect("loads").is_none());
|
||||
assert_eq!(default_session_identity(), "makar");
|
||||
assert_eq!(state_root().expect("state root"), root);
|
||||
let _ = std::fs::remove_dir_all(&root);
|
||||
std::env::remove_var("CLAW_WORKER_STATE_ROOT");
|
||||
std::env::remove_var("CLAW_WORKER_PROFILE_ID");
|
||||
}
|
||||
}
|
||||
@@ -1 +1,2 @@
|
||||
.clawd-agents/
|
||||
.clawd-state/
|
||||
|
||||
+1415
-328
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user