gitdataai/lib/ai/agent/hooks.rs

173 lines
3.7 KiB
Rust

use async_trait::async_trait;
use serde_json::Value;
use crate::agent::persistence::AgentRunContext;
use crate::error::AiResult;
/// Verdict a hook may return from [`AgentHook::pre_tool_call`].
///
/// `HookChain::run_pre_tool_call` treats `Allow` the same as "no opinion"
/// and short-circuits on the first `Block` or `RequireApproval`.
/// NOTE(review): how the caller acts on `Block` / `RequireApproval` is not
/// visible in this file — presumably the tool call is skipped or paused.
#[derive(Clone, Debug)]
pub enum ToolGuardrailDecision {
    /// The hook has no objection to the tool call.
    Allow,
    /// The hook objects to the tool call; `reason` explains why.
    Block { reason: String },
    /// The hook wants an approval step first; `message` describes it.
    RequireApproval { message: String },
}
/// Record of a finished tool invocation, handed to
/// [`AgentHook::post_tool_call`].
#[derive(Clone, Debug)]
pub struct ToolCallOutcome {
    /// Name of the tool that was invoked.
    pub name: String,
    /// Arguments the tool was called with, as raw JSON.
    pub arguments: Value,
    /// Tool result, if any.
    /// NOTE(review): presumably `None` when the call failed — confirm with
    /// the code that builds this value.
    pub output: Option<Value>,
    /// Error text, if the call failed.
    pub error: Option<String>,
    /// Call duration; the name suggests milliseconds — confirm at the call site.
    pub elapsed_ms: i64,
}
/// One conversation message, in the form exposed to
/// [`AgentHook::pre_llm_call`].
#[derive(Clone, Debug)]
pub struct HookMessage {
    /// Message author role.
    /// NOTE(review): presumably a chat role such as `"user"`, `"assistant"`
    /// or `"tool"` — confirm against the producer.
    pub role: String,
    /// Text content, when the message carries any.
    pub content: Option<String>,
    /// Tool calls attached to the message, as raw JSON.
    pub tool_calls: Option<Value>,
    /// Identifier linking this message to a tool call.
    /// Assumed to be set on tool-result messages — verify.
    pub tool_call_id: Option<String>,
}
/// Summary of a model response, handed to [`AgentHook::post_llm_call`].
#[derive(Clone, Debug)]
pub struct HookLlmResponse {
    /// Text the model produced, if any.
    pub content: Option<String>,
    /// Tool calls the model requested, as raw JSON.
    pub tool_calls: Option<Value>,
    /// Token count charged for the prompt side.
    pub input_tokens: u64,
    /// Token count charged for the completion side.
    pub output_tokens: u64,
    /// Provider-reported stop reason, passed through as an opaque string.
    pub finish_reason: Option<String>,
}
/// Lightweight description of a tool offered to the model, handed to
/// [`AgentHook::pre_llm_call`].
#[derive(Clone, Debug)]
pub struct HookToolDef {
    /// Tool name.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
}
/// Extension point that observes, and can veto, steps of an agent run.
///
/// Every method except [`AgentHook::name`] has a no-op default, so an
/// implementor only overrides the events it cares about. Hooks are combined
/// and invoked in order by `HookChain`.
///
/// The `Send + Sync` bound allows hooks to be stored as `Box<dyn AgentHook>`.
/// NOTE(review): presumably also so a chain can be shared across async
/// tasks — confirm with the runtime that owns the chain.
#[async_trait]
pub trait AgentHook: Send + Sync {
    /// Static identifier for this hook.
    fn name(&self) -> &'static str;

    /// Called when an agent session begins. The default does nothing.
    async fn on_session_start(&self, _ctx: &AgentRunContext) -> AiResult<()> {
        Ok(())
    }

    /// Called when an agent session ends; `_success` reports the outcome
    /// flag supplied by the caller. The default does nothing.
    async fn on_session_end(&self, _ctx: &AgentRunContext, _success: bool) -> AiResult<()> {
        Ok(())
    }

    /// Called before a model request, with the outgoing messages and the
    /// tools on offer. The default does nothing.
    async fn pre_llm_call(&self, _messages: &[HookMessage], _tools: &[HookToolDef]) -> AiResult<()> {
        Ok(())
    }

    /// Called after a model response arrives. The default does nothing.
    async fn post_llm_call(&self, _response: &HookLlmResponse) -> AiResult<()> {
        Ok(())
    }

    /// Called before a tool runs; may return a guardrail decision.
    ///
    /// Returning `Ok(None)` (the default) means "no opinion".
    /// `HookChain::run_pre_tool_call` treats `None` and
    /// `Some(ToolGuardrailDecision::Allow)` identically.
    async fn pre_tool_call(
        &self,
        _tool_name: &str,
        _arguments: &Value,
    ) -> AiResult<Option<ToolGuardrailDecision>> {
        Ok(None)
    }

    /// Called after a tool finishes, successfully or not. The default does
    /// nothing.
    async fn post_tool_call(&self, _outcome: &ToolCallOutcome) -> AiResult<()> {
        Ok(())
    }
}
/// Ordered collection of [`AgentHook`]s that dispatches each event to every
/// hook in insertion order.
pub struct HookChain {
    // Invocation order is the order of this vector.
    hooks: Vec<Box<dyn AgentHook>>,
}
impl HookChain {
    /// Builds a chain that invokes `hooks` in the given order.
    pub fn new(hooks: Vec<Box<dyn AgentHook>>) -> Self {
        Self { hooks }
    }

    /// Builds a chain with no hooks; every `run_*` method then returns
    /// `Ok` immediately.
    pub fn empty() -> Self {
        Self { hooks: Vec::new() }
    }

    /// Returns `true` when the chain holds no hooks.
    pub fn is_empty(&self) -> bool {
        self.hooks.is_empty()
    }

    /// Calls [`AgentHook::on_session_start`] on each hook in order.
    ///
    /// # Errors
    /// Returns the first hook error; hooks after the failing one are not
    /// called.
    pub async fn run_session_start(&self, ctx: &AgentRunContext) -> AiResult<()> {
        for hook in &self.hooks {
            hook.on_session_start(ctx).await?;
        }
        Ok(())
    }

    /// Calls [`AgentHook::on_session_end`] on every hook in order.
    ///
    /// Unlike the other `run_*` methods, this does not stop at the first
    /// failure. Session end is a teardown signal, so each hook is still
    /// notified even if an earlier hook returned an error. Otherwise,
    /// anything a later hook set up in `on_session_start` would never be
    /// released.
    ///
    /// # Errors
    /// After all hooks have run, returns the first error any of them
    /// produced. Later errors are dropped.
    pub async fn run_session_end(&self, ctx: &AgentRunContext, success: bool) -> AiResult<()> {
        let mut first_err = None;
        for hook in &self.hooks {
            if let Err(err) = hook.on_session_end(ctx, success).await {
                if first_err.is_none() {
                    first_err = Some(err);
                }
            }
        }
        match first_err {
            Some(err) => Err(err),
            None => Ok(()),
        }
    }

    /// Calls [`AgentHook::pre_llm_call`] on each hook in order.
    ///
    /// # Errors
    /// Returns the first hook error; remaining hooks are not called.
    pub async fn run_pre_llm_call(
        &self,
        messages: &[HookMessage],
        tools: &[HookToolDef],
    ) -> AiResult<()> {
        for hook in &self.hooks {
            hook.pre_llm_call(messages, tools).await?;
        }
        Ok(())
    }

    /// Calls [`AgentHook::post_llm_call`] on each hook in order.
    ///
    /// # Errors
    /// Returns the first hook error; remaining hooks are not called.
    pub async fn run_post_llm_call(&self, response: &HookLlmResponse) -> AiResult<()> {
        for hook in &self.hooks {
            hook.post_llm_call(response).await?;
        }
        Ok(())
    }

    /// Asks each hook in order for a guardrail decision on a pending tool
    /// call.
    ///
    /// Returns `Ok(Some(decision))` for the first `Block` or
    /// `RequireApproval` decision; later hooks are not consulted. Returns
    /// `Ok(None)` when every hook answered `None` or
    /// `Some(ToolGuardrailDecision::Allow)`. An explicit `Allow` does not
    /// override a later hook's objection.
    ///
    /// # Errors
    /// Returns the first hook error; remaining hooks are not called.
    pub async fn run_pre_tool_call(
        &self,
        tool_name: &str,
        arguments: &Value,
    ) -> AiResult<Option<ToolGuardrailDecision>> {
        for hook in &self.hooks {
            // Exhaustive on purpose: a new decision variant must be classified
            // here as either pass-through or short-circuiting.
            match hook.pre_tool_call(tool_name, arguments).await? {
                None | Some(ToolGuardrailDecision::Allow) => {}
                Some(
                    decision @ (ToolGuardrailDecision::Block { .. }
                    | ToolGuardrailDecision::RequireApproval { .. }),
                ) => return Ok(Some(decision)),
            }
        }
        Ok(None)
    }

    /// Calls [`AgentHook::post_tool_call`] on each hook in order.
    ///
    /// # Errors
    /// Returns the first hook error; remaining hooks are not called.
    pub async fn run_post_tool_call(&self, outcome: &ToolCallOutcome) -> AiResult<()> {
        for hook in &self.hooks {
            hook.post_tool_call(outcome).await?;
        }
        Ok(())
    }
}