use async_trait::async_trait; use serde_json::Value; use crate::agent::persistence::AgentRunContext; use crate::error::AiResult; #[derive(Debug, Clone)] pub enum ToolGuardrailDecision { Allow, Block { reason: String }, RequireApproval { message: String }, } #[derive(Debug, Clone)] pub struct ToolCallOutcome { pub name: String, pub arguments: Value, pub output: Option, pub error: Option, pub elapsed_ms: i64, } #[derive(Debug, Clone)] pub struct HookMessage { pub role: String, pub content: Option, pub tool_calls: Option, pub tool_call_id: Option, } #[derive(Debug, Clone)] pub struct HookLlmResponse { pub content: Option, pub tool_calls: Option, pub input_tokens: u64, pub output_tokens: u64, pub finish_reason: Option, } #[derive(Debug, Clone)] pub struct HookToolDef { pub name: String, pub description: String, } #[async_trait] pub trait AgentHook: Send + Sync { fn name(&self) -> &'static str; async fn on_session_start(&self, _ctx: &AgentRunContext) -> AiResult<()> { Ok(()) } async fn on_session_end( &self, _ctx: &AgentRunContext, _success: bool, ) -> AiResult<()> { Ok(()) } async fn pre_llm_call( &self, _messages: &[HookMessage], _tools: &[HookToolDef], ) -> AiResult<()> { Ok(()) } async fn post_llm_call(&self, _response: &HookLlmResponse) -> AiResult<()> { Ok(()) } async fn pre_tool_call( &self, _tool_name: &str, _arguments: &Value, ) -> AiResult> { Ok(None) } async fn post_tool_call(&self, _outcome: &ToolCallOutcome) -> AiResult<()> { Ok(()) } } pub struct HookChain { hooks: Vec>, } impl HookChain { pub fn new(hooks: Vec>) -> Self { Self { hooks } } pub fn empty() -> Self { Self { hooks: Vec::new() } } pub fn is_empty(&self) -> bool { self.hooks.is_empty() } pub async fn run_session_start( &self, ctx: &AgentRunContext, ) -> AiResult<()> { for hook in &self.hooks { hook.on_session_start(ctx).await?; } Ok(()) } pub async fn run_session_end( &self, ctx: &AgentRunContext, success: bool, ) -> AiResult<()> { for hook in &self.hooks { hook.on_session_end(ctx, success).await?; } Ok(()) } pub async fn run_pre_llm_call( &self, messages: &[HookMessage], tools: &[HookToolDef], ) -> AiResult<()> { for hook in &self.hooks { hook.pre_llm_call(messages, tools).await?; } Ok(()) } pub async fn run_post_llm_call( &self, response: &HookLlmResponse, ) -> AiResult<()> { for hook in &self.hooks { hook.post_llm_call(response).await?; } Ok(()) } pub async fn run_pre_tool_call( &self, tool_name: &str, arguments: &Value, ) -> AiResult> { for hook in &self.hooks { if let Some(decision) = hook.pre_tool_call(tool_name, arguments).await? { if !matches!(decision, ToolGuardrailDecision::Allow) { return Ok(Some(decision)); } } } Ok(None) } pub async fn run_post_tool_call( &self, outcome: &ToolCallOutcome, ) -> AiResult<()> { for hook in &self.hooks { hook.post_tool_call(outcome).await?; } Ok(()) } }