Add session-scoped Telegram approvals
Build Claw Telegram / build (push) Successful in 4m38s
Build Claw Telegram / cleanup (push) Successful in 1s

This commit is contained in:
Wylabb
2026-04-05 07:40:39 +02:00
parent 04b482fbc8
commit d164dc5f8e
5 changed files with 224 additions and 27 deletions
+1 -1
View File
@@ -15,7 +15,7 @@ pub use protocol::{
};
pub use runtime_host::{
ApprovalDecision, ApprovalRequestPayload, ApprovalResponder, AttachmentKind, AttachmentRef,
HostError, RuntimeEvent, RuntimeHost, RuntimeHostConfig,
HostError, RuntimeEvent, RuntimeHost, RuntimeHostConfig, SessionApprovalState,
};
pub use unraid_template::{
ManagedTemplateRecord, UnraidTemplateConfigEntry, UnraidTemplateEntryType, UnraidTemplateSpec,
@@ -83,6 +83,8 @@ pub enum WorkerTurnEvent {
#[serde(tag = "decision", rename_all = "snake_case")]
pub enum WorkerApprovalDecision {
ApproveOnce,
ApproveToolForSession,
ApproveAllForSession,
Deny { reason: String },
CancelTurn,
}
@@ -63,10 +63,39 @@ pub struct ApprovalRequestPayload {
#[serde(tag = "decision", rename_all = "snake_case")]
pub enum ApprovalDecision {
ApproveOnce,
ApproveToolForSession,
ApproveAllForSession,
Deny { reason: String },
CancelTurn,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct SessionApprovalState {
pub allow_all_tools: bool,
#[serde(default)]
pub allowed_tools: BTreeSet<String>,
}
impl SessionApprovalState {
#[must_use]
pub fn allows(&self, tool_name: &str) -> bool {
self.allow_all_tools || self.allowed_tools.contains(tool_name)
}
pub fn allow_tool(&mut self, tool_name: impl Into<String>) {
self.allowed_tools.insert(tool_name.into());
}
pub fn allow_all(&mut self) {
self.allow_all_tools = true;
}
pub fn clear(&mut self) {
self.allow_all_tools = false;
self.allowed_tools.clear();
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TurnCancellation {
pub reason: String,
@@ -173,6 +202,7 @@ impl RuntimeHost {
session_path: &Path,
prompt: String,
attachments: Vec<AttachmentRef>,
approval_state: Arc<Mutex<SessionApprovalState>>,
cancel_flag: Arc<AtomicBool>,
event_tx: UnboundedSender<RuntimeEvent>,
) -> Result<(), HostError> {
@@ -193,6 +223,7 @@ impl RuntimeHost {
event_tx.clone(),
self.config.permission_mode,
self.config.approval_timeout,
approval_state,
cancel_flag.clone(),
);
let prompt = compose_user_input(prompt, &attachments);
@@ -1326,6 +1357,7 @@ struct BridgePermissionPrompter {
tx: UnboundedSender<RuntimeEvent>,
current_mode: PermissionMode,
timeout: Duration,
approval_state: Arc<Mutex<SessionApprovalState>>,
cancel_flag: Arc<AtomicBool>,
next_approval_id: u64,
}
@@ -1335,12 +1367,14 @@ impl BridgePermissionPrompter {
tx: UnboundedSender<RuntimeEvent>,
current_mode: PermissionMode,
timeout: Duration,
approval_state: Arc<Mutex<SessionApprovalState>>,
cancel_flag: Arc<AtomicBool>,
) -> Self {
Self {
tx,
current_mode,
timeout,
approval_state,
cancel_flag,
next_approval_id: 0,
}
@@ -1349,6 +1383,15 @@ impl BridgePermissionPrompter {
impl PermissionPrompter for BridgePermissionPrompter {
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
if self
.approval_state
.lock()
.map(|state| state.allows(&request.tool_name))
.unwrap_or(false)
{
return PermissionPromptDecision::Allow;
}
if self.cancel_flag.load(Ordering::SeqCst) {
return PermissionPromptDecision::Deny {
reason: "turn cancelled by user".to_string(),
@@ -1391,6 +1434,18 @@ impl PermissionPrompter for BridgePermissionPrompter {
Ok(ApprovalDecision::ApproveOnce) => {
return PermissionPromptDecision::Allow;
}
Ok(ApprovalDecision::ApproveToolForSession) => {
if let Ok(mut state) = self.approval_state.lock() {
state.allow_tool(request.tool_name.clone());
}
return PermissionPromptDecision::Allow;
}
Ok(ApprovalDecision::ApproveAllForSession) => {
if let Ok(mut state) = self.approval_state.lock() {
state.allow_all();
}
return PermissionPromptDecision::Allow;
}
Ok(ApprovalDecision::Deny { reason }) => {
return PermissionPromptDecision::Deny { reason };
}
+89 -6
View File
@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use std::fmt::{Display, Formatter};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::{Arc, Mutex as StdMutex};
use std::time::{SystemTime, UNIX_EPOCH};
use async_stream::stream;
@@ -15,8 +15,8 @@ use axum::{Json, Router};
use base64::Engine as _;
use channel_gateway_core::{
ApprovalDecision, ApprovalResponder, AttachmentRef, GeneratedFileDescriptor, HostError,
RuntimeEvent, RuntimeHost, RuntimeHostConfig, WorkerApprovalDecision, WorkerStatusResponse,
WorkerTurnAccepted, WorkerTurnEvent, WorkerTurnRequest,
RuntimeEvent, RuntimeHost, RuntimeHostConfig, SessionApprovalState, WorkerApprovalDecision,
WorkerStatusResponse, WorkerTurnAccepted, WorkerTurnEvent, WorkerTurnRequest,
};
use serde::Deserialize;
use subtle::ConstantTimeEq;
@@ -54,11 +54,13 @@ struct AppState {
runtime: Arc<dyn WorkerRuntime>,
active_turn: AsyncMutex<Option<Arc<TurnState>>>,
completed_turns: AsyncMutex<BTreeMap<String, Arc<TurnState>>>,
approval_state: Arc<StdMutex<SessionApprovalState>>,
}
impl AppState {
fn new(config: WorkerConfig, runtime: Arc<dyn WorkerRuntime>) -> Arc<Self> {
Arc::new(Self {
approval_state: Arc::new(StdMutex::new(load_session_approval_state(&config))),
config,
runtime,
active_turn: AsyncMutex::new(None),
@@ -139,6 +141,7 @@ trait WorkerRuntime: Send + Sync {
session_path: &Path,
prompt: String,
attachments: Vec<AttachmentRef>,
approval_state: Arc<StdMutex<SessionApprovalState>>,
cancel_flag: Arc<AtomicBool>,
event_tx: mpsc::UnboundedSender<RuntimeEvent>,
) -> Result<(), HostError>;
@@ -172,11 +175,19 @@ impl WorkerRuntime for RealWorkerRuntime {
session_path: &Path,
prompt: String,
attachments: Vec<AttachmentRef>,
approval_state: Arc<StdMutex<SessionApprovalState>>,
cancel_flag: Arc<AtomicBool>,
event_tx: mpsc::UnboundedSender<RuntimeEvent>,
) -> Result<(), HostError> {
self.host
.run_turn(session_path, prompt, attachments, cancel_flag, event_tx)
.run_turn(
session_path,
prompt,
attachments,
approval_state,
cancel_flag,
event_tx,
)
}
fn session_message_count(&self, session_path: &Path) -> Result<usize, HostError> {
@@ -239,6 +250,8 @@ async fn reset_session(
.runtime
.reset_session(&session_path, &archive_dir)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
clear_session_approval_state(&state.config, &state.approval_state)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
drop(active_turn);
Ok(StatusCode::ACCEPTED)
}
@@ -269,6 +282,7 @@ async fn post_turn(
let (runtime_tx, mut runtime_rx) = mpsc::unbounded_channel();
let runtime = state.runtime.clone();
let cancel_flag = turn_state.cancel_flag.clone();
let approval_state = state.approval_state.clone();
let prompt = request.prompt.clone();
let profile_id = state.config.profile_id.clone();
let model = state.config.model.clone();
@@ -277,13 +291,13 @@ async fn post_turn(
&session_path,
prompt,
attachments,
approval_state,
cancel_flag,
runtime_tx.clone(),
) {
eprintln!(
"worker turn failed: profile_id={} model={} error={error}",
profile_id,
model,
profile_id, model,
);
let _ = runtime_tx.send(RuntimeEvent::Failed {
message: error.to_string(),
@@ -362,6 +376,10 @@ async fn post_turn(
state_for_events
.finalize_turn(turn_state_for_events.clone())
.await;
let _ = persist_session_approval_state(
&state_for_events.config,
&state_for_events.approval_state,
);
return;
}
RuntimeEvent::Failed { message } => {
@@ -371,6 +389,10 @@ async fn post_turn(
state_for_events
.finalize_turn(turn_state_for_events.clone())
.await;
let _ = persist_session_approval_state(
&state_for_events.config,
&state_for_events.approval_state,
);
return;
}
}
@@ -378,6 +400,10 @@ async fn post_turn(
state_for_events
.finalize_turn(turn_state_for_events.clone())
.await;
let _ = persist_session_approval_state(
&state_for_events.config,
&state_for_events.approval_state,
);
});
Ok((StatusCode::ACCEPTED, Json(WorkerTurnAccepted { turn_id })))
@@ -581,11 +607,53 @@ fn authorize(headers: &HeaderMap, expected_token: &str) -> Result<(), StatusCode
fn map_approval_decision(decision: WorkerApprovalDecision) -> ApprovalDecision {
match decision {
WorkerApprovalDecision::ApproveOnce => ApprovalDecision::ApproveOnce,
WorkerApprovalDecision::ApproveToolForSession => ApprovalDecision::ApproveToolForSession,
WorkerApprovalDecision::ApproveAllForSession => ApprovalDecision::ApproveAllForSession,
WorkerApprovalDecision::Deny { reason } => ApprovalDecision::Deny { reason },
WorkerApprovalDecision::CancelTurn => ApprovalDecision::CancelTurn,
}
}
fn approval_state_path(config: &WorkerConfig) -> PathBuf {
config.state_root.join("approval-session.json")
}
fn load_session_approval_state(config: &WorkerConfig) -> SessionApprovalState {
let path = approval_state_path(config);
std::fs::read_to_string(path)
.ok()
.and_then(|contents| serde_json::from_str::<SessionApprovalState>(&contents).ok())
.unwrap_or_default()
}
fn persist_session_approval_state(
config: &WorkerConfig,
approval_state: &Arc<StdMutex<SessionApprovalState>>,
) -> Result<(), ServerError> {
std::fs::create_dir_all(&config.state_root)?;
let snapshot = approval_state
.lock()
.map(|state| state.clone())
.unwrap_or_default();
let serialized = serde_json::to_vec_pretty(&snapshot)
.map_err(|error| ServerError::StatePersistence(error.to_string()))?;
std::fs::write(approval_state_path(config), serialized)?;
Ok(())
}
fn clear_session_approval_state(
config: &WorkerConfig,
approval_state: &Arc<StdMutex<SessionApprovalState>>,
) -> Result<(), ServerError> {
{
let mut state = approval_state
.lock()
.map_err(|error| ServerError::StatePersistence(error.to_string()))?;
state.clear();
}
persist_session_approval_state(config, approval_state)
}
fn next_turn_id() -> String {
static COUNTER: AtomicU64 = AtomicU64::new(1);
let now = SystemTime::now()
@@ -646,6 +714,7 @@ pub enum ServerError {
Io(std::io::Error),
Host(HostError),
Base64(base64::DecodeError),
StatePersistence(String),
}
impl Display for ServerError {
@@ -654,6 +723,7 @@ impl Display for ServerError {
Self::Io(error) => write!(f, "{error}"),
Self::Host(error) => write!(f, "{error}"),
Self::Base64(error) => write!(f, "{error}"),
Self::StatePersistence(message) => write!(f, "{message}"),
}
}
}
@@ -688,6 +758,7 @@ mod tests {
_session_path: &Path,
_prompt: String,
_attachments: Vec<AttachmentRef>,
_approval_state: Arc<StdMutex<SessionApprovalState>>,
_cancel_flag: Arc<AtomicBool>,
event_tx: mpsc::UnboundedSender<RuntimeEvent>,
) -> Result<(), HostError> {
@@ -826,4 +897,16 @@ mod tests {
}
);
}
#[test]
fn approval_mapping_supports_session_scoped_allows() {
assert_eq!(
map_approval_decision(WorkerApprovalDecision::ApproveToolForSession),
ApprovalDecision::ApproveToolForSession
);
assert_eq!(
map_approval_decision(WorkerApprovalDecision::ApproveAllForSession),
ApprovalDecision::ApproveAllForSession
);
}
}
+77 -20
View File
@@ -464,7 +464,8 @@ impl TelegramGateway {
}
}
WorkerTurnEvent::ApprovalRequested { request } => {
let keyboard = approval_keyboard(&turn_id, &request.approval_id);
let keyboard =
approval_keyboard(&turn_id, &request.approval_id, &request.tool_name);
match api
.send_message(
chat_id,
@@ -1089,7 +1090,9 @@ struct ApprovalAction {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ApprovalKind {
Allow,
AllowOnce,
AllowToolForSession,
AllowAllForSession,
Deny,
Cancel,
}
@@ -1101,7 +1104,9 @@ impl ApprovalAction {
let turn_id = parts.next()?.to_string();
let approval_id = parts.next()?.to_string();
let action = match parts.next()? {
"allow" => ApprovalKind::Allow,
"allow" | "allow_once" => ApprovalKind::AllowOnce,
"allow_tool_session" => ApprovalKind::AllowToolForSession,
"allow_all_session" => ApprovalKind::AllowAllForSession,
"deny" => ApprovalKind::Deny,
"cancel" => ApprovalKind::Cancel,
_ => return None,
@@ -1115,7 +1120,9 @@ impl ApprovalAction {
fn decision(&self) -> WorkerApprovalDecision {
match self.action {
ApprovalKind::Allow => WorkerApprovalDecision::ApproveOnce,
ApprovalKind::AllowOnce => WorkerApprovalDecision::ApproveOnce,
ApprovalKind::AllowToolForSession => WorkerApprovalDecision::ApproveToolForSession,
ApprovalKind::AllowAllForSession => WorkerApprovalDecision::ApproveAllForSession,
ApprovalKind::Deny => WorkerApprovalDecision::Deny {
reason: "tool call denied from Telegram approval prompt".to_string(),
},
@@ -1125,32 +1132,61 @@ impl ApprovalAction {
fn label(&self) -> &'static str {
match self.action {
ApprovalKind::Allow => "Approved.",
ApprovalKind::AllowOnce => "Approved once.",
ApprovalKind::AllowToolForSession => "Tool allowed for this session.",
ApprovalKind::AllowAllForSession => "All tools allowed for this session.",
ApprovalKind::Deny => "Denied.",
ApprovalKind::Cancel => "Cancelled.",
}
}
}
fn approval_keyboard(turn_id: &str, approval_id: &str) -> InlineKeyboardMarkup {
fn approval_keyboard(turn_id: &str, approval_id: &str, tool_name: &str) -> InlineKeyboardMarkup {
let tool_label = session_tool_button_label(tool_name);
InlineKeyboardMarkup {
inline_keyboard: vec![vec![
InlineKeyboardButton {
text: "Approve once".to_string(),
callback_data: format!("cta:{turn_id}:{approval_id}:allow"),
},
InlineKeyboardButton {
text: "Deny".to_string(),
callback_data: format!("cta:{turn_id}:{approval_id}:deny"),
},
InlineKeyboardButton {
inline_keyboard: vec![
vec![
InlineKeyboardButton {
text: "Approve once".to_string(),
callback_data: format!("cta:{turn_id}:{approval_id}:allow_once"),
},
InlineKeyboardButton {
text: tool_label,
callback_data: format!("cta:{turn_id}:{approval_id}:allow_tool_session"),
},
],
vec![
InlineKeyboardButton {
text: "Allow all for session".to_string(),
callback_data: format!("cta:{turn_id}:{approval_id}:allow_all_session"),
},
InlineKeyboardButton {
text: "Deny".to_string(),
callback_data: format!("cta:{turn_id}:{approval_id}:deny"),
},
],
vec![InlineKeyboardButton {
text: "Cancel turn".to_string(),
callback_data: format!("cta:{turn_id}:{approval_id}:cancel"),
},
]],
}],
],
}
}
fn session_tool_button_label(tool_name: &str) -> String {
const MAX_LABEL_LEN: usize = 26;
let trimmed = tool_name.trim();
let short_name = if trimmed.chars().count() > MAX_LABEL_LEN {
let shortened = trimmed.chars().take(MAX_LABEL_LEN - 1).collect::<String>();
format!("{shortened}")
} else if trimmed.is_empty() {
"tool".to_string()
} else {
trimmed.to_string()
};
format!("Allow {short_name} for session")
}
fn sanitize_extension(extension: &str) -> String {
let trimmed = extension.trim();
if trimmed.starts_with('.') {
@@ -1266,7 +1302,10 @@ impl From<TemplateError> for GatewayError {
mod tests {
use std::time::{SystemTime, UNIX_EPOCH};
use super::{load_or_init_manifest, parse_command, split_text, ApprovalAction};
use super::{
load_or_init_manifest, parse_command, session_tool_button_label, split_text,
ApprovalAction,
};
use crate::config::GatewayConfig;
use crate::telegram_api::{Chat, Message, User};
@@ -1281,7 +1320,25 @@ mod tests {
let parsed = ApprovalAction::parse("cta:turn:approval:allow").expect("action should parse");
assert_eq!(parsed.turn_id, "turn");
assert_eq!(parsed.approval_id, "approval");
assert_eq!(parsed.label(), "Approved.");
assert_eq!(parsed.label(), "Approved once.");
}
#[test]
fn approval_action_supports_session_scope_buttons() {
let tool = ApprovalAction::parse("cta:turn:approval:allow_tool_session")
.expect("tool approval should parse");
assert_eq!(tool.label(), "Tool allowed for this session.");
let all = ApprovalAction::parse("cta:turn:approval:allow_all_session")
.expect("global approval should parse");
assert_eq!(all.label(), "All tools allowed for this session.");
}
#[test]
fn session_tool_button_label_truncates_long_names() {
let label = session_tool_button_label("very-long-tool-name-that-keeps-going");
assert!(label.starts_with("Allow very-long-tool-name-tha"));
assert!(label.ends_with("… for session"));
}
#[test]