// File: gitdataai/lib/ai/agent/hooks.rs
// Repository web-view metadata: last modified 2026-05-30 01:38:40 +08:00; 146 lines; 3.5 KiB; Rust.

use async_trait::async_trait;
use serde_json::Value;
use crate::agent::persistence::AgentRunContext;
use crate::error::AiResult;
/// Verdict a hook returns from [`AgentHook::pre_tool_call`] about a pending tool call.
///
/// Enforcing the verdict is up to whoever consults the hooks.
/// NOTE(review): the enforcing caller is not visible in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolGuardrailDecision {
    /// The hook raises no objection to the tool call.
    Allow,
    /// The tool call must not run; `reason` explains why.
    Block { reason: String },
    /// The tool call needs explicit approval first.
    ///
    /// NOTE(review): `message` is presumably shown to the approver; confirm.
    RequireApproval { message: String },
}
/// Record of one finished tool invocation, as handed to [`AgentHook::post_tool_call`].
#[derive(Debug, Clone, PartialEq)]
pub struct ToolCallOutcome {
    /// Name of the tool that was invoked.
    pub name: String,
    /// Arguments the tool was invoked with, as raw JSON.
    pub arguments: Value,
    /// The tool's result, if any.
    ///
    /// NOTE(review): presumably `None` whenever `error` is set; the type does not enforce this.
    pub output: Option<Value>,
    /// Error message if the invocation failed.
    pub error: Option<String>,
    /// Duration of the invocation in milliseconds.
    ///
    /// NOTE(review): signed type; confirm how the caller measures it.
    pub elapsed_ms: i64,
}
/// One conversation message as seen by hooks; a slice of these is passed to
/// [`AgentHook::pre_llm_call`].
#[derive(Debug, Clone, PartialEq)]
pub struct HookMessage {
    /// Message role.
    ///
    /// NOTE(review): presumably values such as "user"/"assistant"/"tool"; confirm.
    pub role: String,
    /// Text content, if the message has any.
    pub content: Option<String>,
    /// Tool calls attached to the message, as raw JSON; the shape is not fixed by this type.
    pub tool_calls: Option<Value>,
    /// Identifier of the tool call this message answers.
    ///
    /// NOTE(review): presumably set only on tool-result messages; confirm.
    pub tool_call_id: Option<String>,
}
/// Summary of one LLM response, as handed to [`AgentHook::post_llm_call`].
#[derive(Debug, Clone, PartialEq)]
pub struct HookLlmResponse {
    /// Text content returned by the model, if any.
    pub content: Option<String>,
    /// Tool calls requested by the model, as raw JSON; the shape is not fixed by this type.
    pub tool_calls: Option<Value>,
    /// Number of input (prompt) tokens reported for the call.
    pub input_tokens: u64,
    /// Number of output (completion) tokens reported for the call.
    pub output_tokens: u64,
    /// Finish reason reported for the response, if any.
    ///
    /// NOTE(review): the values are presumably provider-specific.
    pub finish_reason: Option<String>,
}
/// Name and description of a tool offered to the model; a slice of these is passed to
/// [`AgentHook::pre_llm_call`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HookToolDef {
    /// Tool name.
    pub name: String,
    /// Human-readable tool description.
    pub description: String,
}
/// Extension point for observing and gating an agent run.
///
/// Every method except `name` has a default no-op implementation. The defaults return
/// `Ok(())`, or `Ok(None)` for `pre_tool_call`, so implementors override only the events
/// they care about. Implementors must be `Send + Sync`.
#[async_trait]
pub trait AgentHook: Send + Sync {
/// Identifier for this hook.
///
/// NOTE(review): not read anywhere in this file; presumably used by callers for
/// logging or diagnostics. Confirm.
fn name(&self) -> &'static str;
/// Called when an agent session starts. The default does nothing.
async fn on_session_start(&self, _ctx: &AgentRunContext) -> AiResult<()> {
Ok(())
}
/// Called when an agent session ends. `_success` reports whether the run succeeded.
/// The default does nothing.
async fn on_session_end(&self, _ctx: &AgentRunContext, _success: bool) -> AiResult<()> {
Ok(())
}
/// Called before an LLM request with the messages and tool definitions about to be sent.
/// The default does nothing.
async fn pre_llm_call(&self, _messages: &[HookMessage], _tools: &[HookToolDef]) -> AiResult<()> {
Ok(())
}
/// Called after an LLM response is received. The default does nothing.
async fn post_llm_call(&self, _response: &HookLlmResponse) -> AiResult<()> {
Ok(())
}
/// Called before a tool runs, to let the hook veto it or request approval.
///
/// Return `Ok(None)` or `Ok(Some(ToolGuardrailDecision::Allow))` to raise no objection.
/// Return `Some(Block { .. })` to veto the call, or `Some(RequireApproval { .. })` to ask
/// for approval. The default returns `Ok(None)`.
async fn pre_tool_call(
&self,
_tool_name: &str,
_arguments: &Value,
) -> AiResult<Option<ToolGuardrailDecision>> {
Ok(None)
}
/// Called after a tool invocation finishes, with its result or error.
/// The default does nothing.
async fn post_tool_call(&self, _outcome: &ToolCallOutcome) -> AiResult<()> {
Ok(())
}
}
/// Ordered collection of [`AgentHook`]s; each event is dispatched to the hooks in the order
/// they were supplied.
///
/// `HookChain::default()` yields an empty chain with no hooks.
#[derive(Default)]
pub struct HookChain {
    /// Hooks in dispatch order.
    hooks: Vec<Box<dyn AgentHook>>,
}
impl HookChain {
    /// Builds a chain that dispatches to `hooks` in the given order.
    pub fn new(hooks: Vec<Box<dyn AgentHook>>) -> Self {
        Self { hooks }
    }

    /// Builds a chain with no hooks. Every `run_*` method on it is then a no-op.
    pub fn empty() -> Self {
        Self { hooks: Vec::new() }
    }

    /// Returns `true` when the chain holds no hooks.
    pub fn is_empty(&self) -> bool {
        self.hooks.is_empty()
    }

    /// Runs [`AgentHook::on_session_start`] on each hook in order.
    ///
    /// # Errors
    /// Returns the first hook error immediately; later hooks are not called.
    pub async fn run_session_start(&self, ctx: &AgentRunContext) -> AiResult<()> {
        for hook in &self.hooks {
            hook.on_session_start(ctx).await?;
        }
        Ok(())
    }

    /// Runs [`AgentHook::on_session_end`] on every hook in order.
    ///
    /// This is the teardown notification, so a failing hook does not stop later hooks from
    /// being told the session ended.
    ///
    /// # Errors
    /// Returns the first error produced by any hook, after all hooks have run.
    pub async fn run_session_end(&self, ctx: &AgentRunContext, success: bool) -> AiResult<()> {
        let mut first_error = None;
        for hook in &self.hooks {
            if let Err(err) = hook.on_session_end(ctx, success).await {
                // Keep the earliest failure; later ones are dropped so every hook still runs.
                if first_error.is_none() {
                    first_error = Some(err);
                }
            }
        }
        match first_error {
            Some(err) => Err(err),
            None => Ok(()),
        }
    }

    /// Runs [`AgentHook::pre_llm_call`] on each hook in order.
    ///
    /// # Errors
    /// Returns the first hook error immediately; later hooks are not called.
    pub async fn run_pre_llm_call(
        &self,
        messages: &[HookMessage],
        tools: &[HookToolDef],
    ) -> AiResult<()> {
        for hook in &self.hooks {
            hook.pre_llm_call(messages, tools).await?;
        }
        Ok(())
    }

    /// Runs [`AgentHook::post_llm_call`] on each hook in order.
    ///
    /// # Errors
    /// Returns the first hook error immediately; later hooks are not called.
    pub async fn run_post_llm_call(&self, response: &HookLlmResponse) -> AiResult<()> {
        for hook in &self.hooks {
            hook.post_llm_call(response).await?;
        }
        Ok(())
    }

    /// Asks the hooks whether a tool call may proceed and returns the most restrictive verdict.
    ///
    /// - `Ok(None)`: no hook objected. `None` and `Some(Allow)` both count as no objection.
    /// - `Ok(Some(Block { .. }))`: the first blocking verdict, returned as soon as it is seen.
    /// - `Ok(Some(RequireApproval { .. }))`: the first approval request. It is returned only
    ///   after all hooks have been consulted and none of them blocked.
    ///
    /// # Errors
    /// Returns the first hook error immediately; later hooks are not called.
    pub async fn run_pre_tool_call(
        &self,
        tool_name: &str,
        arguments: &Value,
    ) -> AiResult<Option<ToolGuardrailDecision>> {
        let mut pending_approval: Option<ToolGuardrailDecision> = None;
        for hook in &self.hooks {
            match hook.pre_tool_call(tool_name, arguments).await? {
                Some(decision @ ToolGuardrailDecision::Block { .. }) => {
                    return Ok(Some(decision));
                }
                Some(decision @ ToolGuardrailDecision::RequireApproval { .. }) => {
                    // Keep consulting later hooks: one of them may Block, and a Block must
                    // never be downgraded into an approval prompt a human could accept.
                    if pending_approval.is_none() {
                        pending_approval = Some(decision);
                    }
                }
                Some(ToolGuardrailDecision::Allow) | None => {}
            }
        }
        Ok(pending_approval)
    }

    /// Runs [`AgentHook::post_tool_call`] on each hook in order.
    ///
    /// # Errors
    /// Returns the first hook error immediately; later hooks are not called.
    pub async fn run_post_tool_call(&self, outcome: &ToolCallOutcome) -> AiResult<()> {
        for hook in &self.hooks {
            hook.post_tool_call(outcome).await?;
        }
        Ok(())
    }
}