// gitdataai/lib/ai/agent/request.rs
//
// 256 lines
// 6.6 KiB
// Rust

use std::time::Duration;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio_util::sync::CancellationToken;
use super::persistence::AgentRunContext;
use crate::error::{AiError, AiResult};
/// Input to a single agent run.
///
/// Bundles the user input with conversation history, context chunks, optional
/// sub-agent experts and runtime-only controls. Build one with
/// [`AgentRequest::new`] plus the `with_*` / `add_*` methods, and use
/// [`AgentRequest::validate`] to check the basic limits.
///
/// Fields marked `#[serde(skip)]` are runtime-only: they are never serialized
/// and come back as their `Default` value after deserialization.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentRequest {
    /// User input for this run; `validate` rejects blank input.
    pub input: String,
    /// Conversation history as user/assistant turns.
    pub messages: Vec<AgentMessage>,
    /// Context chunks made available to the agent.
    pub context: Vec<AgentContextChunk>,
    /// Sub-agent experts; `validate` rejects more than 32.
    pub experts: Vec<AgentExpert>,
    /// Optional persistence context for the run.
    pub run_context: Option<AgentRunContext>,
    /// Pre-built `rig` messages (runtime-only, not serialized).
    #[serde(skip)]
    pub prefill_messages: Vec<rig::completion::Message>,
    /// Optional cancellation token, presumably observed by the runner
    /// (runtime-only, not serialized).
    #[serde(skip)]
    pub cancellation_token: Option<CancellationToken>,
    /// Optional time limit for the run (runtime-only, not serialized).
    /// NOTE(review): enforcement is not done by this type — confirm in the runner.
    #[serde(skip)]
    pub timeout: Option<Duration>,
}

impl AgentRequest {
    /// Upper bound on `input` length in bytes (`str::len`), checked by `validate`.
    /// Keep in sync with the "1MB" wording of the corresponding error message.
    const MAX_INPUT_BYTES: usize = 1_000_000;
    /// Upper bound on the number of experts, checked by `validate`.
    /// Keep in sync with the "(32)" wording of the corresponding error message.
    const MAX_EXPERTS: usize = 32;

    /// Creates a request with the given input and every other field empty / `None`.
    #[must_use]
    pub fn new(input: impl Into<String>) -> Self {
        Self {
            input: input.into(),
            messages: Vec::new(),
            context: Vec::new(),
            experts: Vec::new(),
            run_context: None,
            prefill_messages: Vec::new(),
            cancellation_token: None,
            timeout: None,
        }
    }

    /// Checks the request against basic size limits.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::Config`] when:
    /// - `input` is empty or whitespace-only,
    /// - `input` is longer than 1,000,000 bytes,
    /// - more than 32 experts are attached.
    pub fn validate(&self) -> AiResult<()> {
        if self.input.trim().is_empty() {
            return Err(AiError::Config(
                "agent request input is required".to_string(),
            ));
        }
        // Measured in bytes (UTF-8 length), not characters.
        if self.input.len() > Self::MAX_INPUT_BYTES {
            return Err(AiError::Config(
                "agent request input exceeds maximum length (1MB)".to_string(),
            ));
        }
        if self.experts.len() > Self::MAX_EXPERTS {
            return Err(AiError::Config(
                "agent request experts count exceeds maximum (32)".to_string(),
            ));
        }
        Ok(())
    }

    /// Replaces the conversation history.
    #[must_use]
    pub fn with_messages(mut self, messages: Vec<AgentMessage>) -> Self {
        self.messages = messages;
        self
    }

    /// Replaces all context chunks.
    #[must_use]
    pub fn with_context(mut self, context: Vec<AgentContextChunk>) -> Self {
        self.context = context;
        self
    }

    /// Appends one context chunk.
    #[must_use]
    pub fn add_context(mut self, chunk: AgentContextChunk) -> Self {
        self.context.push(chunk);
        self
    }

    /// Replaces all experts.
    #[must_use]
    pub fn with_experts(mut self, experts: Vec<AgentExpert>) -> Self {
        self.experts = experts;
        self
    }

    /// Appends one expert. The count limit is only enforced by `validate`.
    #[must_use]
    pub fn add_expert(mut self, expert: AgentExpert) -> Self {
        self.experts.push(expert);
        self
    }

    /// Sets the persistence context.
    #[must_use]
    pub fn with_run_context(mut self, run_context: AgentRunContext) -> Self {
        self.run_context = Some(run_context);
        self
    }

    /// Replaces the pre-built `rig` messages.
    #[must_use]
    pub fn with_prefill_messages(
        mut self,
        prefill_messages: Vec<rig::completion::Message>,
    ) -> Self {
        self.prefill_messages = prefill_messages;
        self
    }

    /// Sets the cancellation token.
    #[must_use]
    pub fn with_cancellation_token(mut self, cancellation_token: CancellationToken) -> Self {
        self.cancellation_token = Some(cancellation_token);
        self
    }

    /// Sets the time limit.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
}
/// One prior conversation turn supplied in `AgentRequest::messages`.
///
/// Variant order is part of the serde encoding for index-based formats; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AgentMessage {
    /// Text authored by the user.
    User(String),
    /// Text authored by the assistant.
    Assistant(String),
}
/// A sub-agent ("expert") attached to an agent request.
///
/// `temperature` and `max_completion_tokens` override the master agent's
/// settings for this expert; presumably `None` keeps the master's value.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentExpert {
    /// Identifier of the expert.
    pub id: String,
    /// Role label of the expert.
    pub role: String,
    /// Task the expert is asked to perform.
    pub task: String,
    /// Optional system prompt for the expert.
    pub system_prompt: Option<String>,
    /// Context chunks specific to this expert.
    pub context: Vec<AgentContextChunk>,
    /// Override the master agent's temperature for this subagent.
    pub temperature: Option<f64>,
    /// Override the master agent's max_completion_tokens for this subagent.
    pub max_completion_tokens: Option<u64>,
}

impl AgentExpert {
    /// Creates an expert with the given id, role and task. No system prompt,
    /// no context and no overrides are set.
    #[must_use]
    pub fn new(
        id: impl Into<String>,
        role: impl Into<String>,
        task: impl Into<String>,
    ) -> Self {
        Self {
            id: id.into(),
            role: role.into(),
            task: task.into(),
            system_prompt: None,
            context: Vec::new(),
            temperature: None,
            max_completion_tokens: None,
        }
    }

    /// Sets the system prompt.
    #[must_use]
    pub fn with_system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(system_prompt.into());
        self
    }

    /// Replaces the expert's context chunks.
    #[must_use]
    pub fn with_context(mut self, context: Vec<AgentContextChunk>) -> Self {
        self.context = context;
        self
    }

    /// Sets the temperature override. The value is not range-checked here.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the max-completion-tokens override.
    #[must_use]
    pub fn with_max_completion_tokens(mut self, max_tokens: u64) -> Self {
        self.max_completion_tokens = Some(max_tokens);
        self
    }
}
/// A piece of context handed to an agent.
///
/// Chunks are built directly with [`AgentContextChunk::new`] or converted
/// from a RAG search hit or a sub-agent output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentContextChunk {
    /// Identifier of the chunk.
    pub id: String,
    /// Text content of the chunk.
    pub content: String,
    /// Optional origin label. RAG hits use their session id here; sub-agent
    /// outputs use the expert role.
    pub source: Option<String>,
    /// Optional relevance score (set for RAG hits).
    pub score: Option<f32>,
    /// Free-form JSON metadata; `null` when nothing was attached.
    pub metadata: Value,
}

impl AgentContextChunk {
    /// Builds a chunk from an id and content. `source` and `score` start as
    /// `None`, and `metadata` starts as JSON `null`.
    pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
        let id: String = id.into();
        let content: String = content.into();
        Self {
            metadata: Value::Null,
            score: None,
            source: None,
            content,
            id,
        }
    }
}
/// Converts a RAG search hit into a context chunk.
///
/// The hit's `session_id` becomes the chunk `source`, its score is kept, and
/// its metadata entries are collected into a JSON object.
impl From<crate::rag::RagSearchHit> for AgentContextChunk {
    fn from(hit: crate::rag::RagSearchHit) -> Self {
        let crate::rag::RagSearchHit {
            id,
            content,
            session_id,
            score,
            metadata,
            ..
        } = hit;
        let metadata_object = Value::Object(metadata.into_iter().collect());
        Self {
            id,
            content,
            source: Some(session_id),
            score: Some(score),
            metadata: metadata_object,
        }
    }
}
/// Wraps a sub-agent's output as a context chunk.
///
/// The chunk id is `subagent:<output id>`, the source is the expert role, and
/// the metadata records `"kind": "subagent"` plus the task text.
impl From<&AgentExpertOutput> for AgentContextChunk {
    fn from(output: &AgentExpertOutput) -> Self {
        let chunk_id = format!("subagent:{}", output.id);
        let chunk_metadata = serde_json::json!({
            "kind": "subagent",
            "task": output.task,
        });
        Self {
            id: chunk_id,
            content: output.output.clone(),
            source: Some(output.role.clone()),
            score: None,
            metadata: chunk_metadata,
        }
    }
}
/// Outcome of an agent run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentResult {
    /// Final text output.
    pub output: String,
    /// Per-step trace of the run.
    pub steps: Vec<AgentStep>,
    /// Outputs produced by sub-agent experts.
    pub expert_outputs: Vec<AgentExpertOutput>,
    /// Input token count (signed).
    /// NOTE(review): confirm whether this includes expert usage.
    pub input_tokens: i64,
    /// Output token count (signed).
    pub output_tokens: i64,
}
/// One step of an agent run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentStep {
    /// Position of this step within the run (presumably 0-based — confirm).
    pub index: usize,
    /// Assistant text emitted in this step, if any.
    pub assistant: Option<String>,
    /// Reasoning text for this step, if any.
    pub reasoning_content: Option<String>,
    /// Tool invocations recorded for this step.
    pub tool_calls: Vec<ToolCallRecord>,
    /// Optional reflection text for this step.
    pub reflection: Option<String>,
}
/// Record of a single tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Identifier of the call.
    pub id: String,
    /// Name of the invoked tool.
    pub name: String,
    /// Arguments passed to the tool, as JSON.
    pub arguments: Value,
    /// Tool result as JSON, when one was produced.
    pub output: Option<Value>,
    /// Error text, when the call failed.
    pub error: Option<String>,
    /// Elapsed time in milliseconds, when recorded. It is omitted from
    /// serialized output when `None` and defaults to `None` when absent on input.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub elapsed_ms: Option<i64>,
}
/// Output of one sub-agent expert.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentExpertOutput {
    /// Expert identifier (presumably copied from `AgentExpert::id`).
    pub id: String,
    /// Expert role label.
    pub role: String,
    /// Task the expert worked on.
    pub task: String,
    /// Text produced by the expert.
    pub output: String,
    /// Input token count for this expert (signed).
    pub input_tokens: i64,
    /// Output token count for this expert (signed).
    pub output_tokens: i64,
}