refactor(agent): split monolithic service files into specialized modules
Extract agent, compact, embed, task, and modes modules from single service.rs files into focused sub-modules. Add orao module for O1-like reasoning loop. Move RigAgentService to rig_tool.rs.
This commit is contained in:
parent
129aa3dce7
commit
d45e9e28f4
@ -1,7 +1,4 @@
|
||||
//! Agent service using rig's built-in Agent with full feature support.
|
||||
//!
|
||||
//! This module provides a higher-level agent service built on rig's Agent,
|
||||
//! supporting multi-turn conversations, RAG, and built-in streaming.
|
||||
//! Rig-based agent using rig's built-in Agent with full feature support.
|
||||
|
||||
pub mod service;
|
||||
pub use service::RigAgentService;
|
||||
pub mod rig_tool;
|
||||
pub use rig_tool::{AgentResponse, RigAgentService, StreamChunk};
|
||||
|
||||
@ -1,23 +1,17 @@
|
||||
//! Agent service using rig's built-in Agent.
|
||||
//!
|
||||
//! This is a complete implementation that leverages rig's Agent for
|
||||
//! multi-turn reasoning, tool execution, streaming, and token tracking.
|
||||
|
||||
use futures::Stream;
|
||||
use futures::StreamExt;
|
||||
use rig::{
|
||||
agent::{AgentBuilder, MultiTurnStreamItem},
|
||||
client::CompletionClient,
|
||||
completion::Prompt,
|
||||
streaming::{StreamingPrompt, StreamedAssistantContent},
|
||||
streaming::{StreamedAssistantContent, StreamingPrompt},
|
||||
};
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
use crate::client::AiClientConfig;
|
||||
use crate::error::AgentError;
|
||||
|
||||
/// Response from an agent completion (rig's Agent prompt response).
|
||||
#[derive(Debug)]
|
||||
pub struct AgentResponse {
|
||||
pub content: String,
|
||||
@ -25,12 +19,9 @@ pub struct AgentResponse {
|
||||
pub output_tokens: u64,
|
||||
}
|
||||
|
||||
/// Streaming chunk from the agent.
|
||||
#[derive(Debug)]
|
||||
pub enum StreamChunk {
|
||||
/// Text delta from the model
|
||||
Text(String),
|
||||
/// Final response with aggregated usage
|
||||
Final {
|
||||
content: String,
|
||||
input_tokens: u64,
|
||||
@ -38,22 +29,19 @@ pub enum StreamChunk {
|
||||
},
|
||||
}
|
||||
|
||||
/// Service for running agents using rig's built-in Agent.
|
||||
///
|
||||
/// Provides both simple prompting and full streaming with automatic
|
||||
/// tool call handling via rig's native Agent.
|
||||
pub struct RigAgentService {
|
||||
config: AiClientConfig,
|
||||
model_name: String,
|
||||
}
|
||||
|
||||
impl RigAgentService {
|
||||
/// Create a new RigAgentService.
|
||||
pub fn new(config: AiClientConfig, model_name: impl Into<String>) -> Self {
|
||||
Self { config, model_name: model_name.into() }
|
||||
Self {
|
||||
config,
|
||||
model_name: model_name.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a single prompt with the agent (single-turn, no tools).
|
||||
pub async fn prompt(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
@ -62,9 +50,7 @@ impl RigAgentService {
|
||||
let client = self.config.build_rig_client();
|
||||
let model = client.completion_model(&self.model_name);
|
||||
|
||||
let agent = AgentBuilder::new(model)
|
||||
.preamble(system_prompt)
|
||||
.build();
|
||||
let agent = AgentBuilder::new(model).preamble(system_prompt).build();
|
||||
|
||||
let response = agent
|
||||
.prompt(user_input)
|
||||
@ -74,15 +60,11 @@ impl RigAgentService {
|
||||
|
||||
Ok(AgentResponse {
|
||||
content: response.output,
|
||||
input_tokens: response.total_usage.input_tokens,
|
||||
output_tokens: response.total_usage.output_tokens,
|
||||
input_tokens: response.usage.input_tokens,
|
||||
output_tokens: response.usage.output_tokens,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run a prompt with tools (supports multi-turn via rig's Agent).
|
||||
///
|
||||
/// The agent will automatically handle tool calls by calling rig's
|
||||
/// ToolDyn implementations with proper argument deserialization.
|
||||
pub async fn prompt_with_tools(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
@ -108,16 +90,11 @@ impl RigAgentService {
|
||||
|
||||
Ok(AgentResponse {
|
||||
content: response.output,
|
||||
input_tokens: response.total_usage.input_tokens,
|
||||
output_tokens: response.total_usage.output_tokens,
|
||||
input_tokens: response.usage.input_tokens,
|
||||
output_tokens: response.usage.output_tokens,
|
||||
})
|
||||
}
|
||||
|
||||
/// Stream a prompt with the agent using rig's native streaming.
|
||||
///
|
||||
/// This returns a proper async stream that yields text chunks as they arrive
|
||||
/// and a final response chunk with aggregated token usage. Tool calls are
|
||||
/// handled transparently by rig's Agent.
|
||||
pub async fn stream_prompt(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
@ -129,17 +106,10 @@ impl RigAgentService {
|
||||
let client = self.config.build_rig_client();
|
||||
let model = client.completion_model(&self.model_name);
|
||||
|
||||
let agent = AgentBuilder::new(model)
|
||||
.preamble(system_prompt)
|
||||
.build();
|
||||
let agent = AgentBuilder::new(model).preamble(system_prompt).build();
|
||||
|
||||
// stream_prompt().await returns StreamingResult directly (not wrapped in Result)
|
||||
// StreamingResult is Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>>>>
|
||||
let stream: rig::agent::StreamingResult<_> = agent
|
||||
.stream_prompt(user_input)
|
||||
.await;
|
||||
let stream: rig::agent::StreamingResult<_> = agent.stream_prompt(user_input).await;
|
||||
|
||||
// Bridge the rig stream to our channel-based stream
|
||||
let (tx, rx) = mpsc::channel::<std::result::Result<StreamChunk, AgentError>>(100);
|
||||
|
||||
tokio::spawn(async move {
|
||||
@ -152,12 +122,14 @@ impl RigAgentService {
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::Text(text),
|
||||
)) => {
|
||||
let cleaned = text.text.replace('\n', "");
|
||||
let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await;
|
||||
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
|
||||
final_content.push_str(&text.text);
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::ToolCall { tool_call, internal_call_id: _ },
|
||||
StreamedAssistantContent::ToolCall {
|
||||
tool_call,
|
||||
internal_call_id: _,
|
||||
},
|
||||
)) => {
|
||||
let args_str = match &tool_call.function.arguments {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
@ -168,7 +140,6 @@ impl RigAgentService {
|
||||
args = %args_str,
|
||||
"rig_agent_streaming_tool_call"
|
||||
);
|
||||
// Tool call — emitted for observability, rig handles execution internally
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamUserItem(
|
||||
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
|
||||
@ -180,11 +151,13 @@ impl RigAgentService {
|
||||
}
|
||||
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
|
||||
let usage = resp.usage();
|
||||
let _ = tx.send(Ok(StreamChunk::Final {
|
||||
content: final_content.clone(),
|
||||
input_tokens: usage.input_tokens,
|
||||
output_tokens: usage.output_tokens,
|
||||
})).await;
|
||||
let _ = tx
|
||||
.send(Ok(StreamChunk::Final {
|
||||
content: final_content.clone(),
|
||||
input_tokens: usage.input_tokens,
|
||||
output_tokens: usage.output_tokens,
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(AgentError::OpenAi(e.to_string()))).await;
|
||||
@ -197,10 +170,6 @@ impl RigAgentService {
|
||||
Ok(ReceiverStream::new(rx))
|
||||
}
|
||||
|
||||
/// Stream a prompt with tools using rig's native streaming.
|
||||
///
|
||||
/// Returns a stream that properly handles multi-turn tool calls via rig's Agent
|
||||
/// streaming infrastructure.
|
||||
pub async fn stream_prompt_with_tools(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
@ -222,33 +191,30 @@ impl RigAgentService {
|
||||
|
||||
let stream = agent
|
||||
.stream_prompt(user_input)
|
||||
.with_history(Vec::new())
|
||||
.with_history(Vec::<rig::completion::Message>::new())
|
||||
.multi_turn(max_turns)
|
||||
.await;
|
||||
|
||||
let (tx, rx) = mpsc::channel::<std::result::Result<StreamChunk, AgentError>>(100);
|
||||
|
||||
let (tx, rx) = mpsc::channel::<Result<StreamChunk, AgentError>>(100);
|
||||
tokio::spawn(async move {
|
||||
let mut final_content = String::new();
|
||||
|
||||
tokio::pin!(stream);
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
match item {
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(
|
||||
StreamedAssistantContent::Text(text),
|
||||
)) => {
|
||||
let cleaned = text.text.replace('\n', "");
|
||||
let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await;
|
||||
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
|
||||
final_content.push_str(&text.text);
|
||||
}
|
||||
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
|
||||
let usage = resp.usage();
|
||||
let _ = tx.send(Ok(StreamChunk::Final {
|
||||
content: final_content.clone(),
|
||||
input_tokens: usage.input_tokens,
|
||||
output_tokens: usage.output_tokens,
|
||||
})).await;
|
||||
let _ = tx
|
||||
.send(Ok(StreamChunk::Final {
|
||||
content: final_content.clone(),
|
||||
input_tokens: usage.input_tokens,
|
||||
output_tokens: usage.output_tokens,
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(AgentError::OpenAi(e.to_string()))).await;
|
||||
@ -261,9 +227,8 @@ impl RigAgentService {
|
||||
Ok(ReceiverStream::new(rx))
|
||||
}
|
||||
|
||||
/// Count tokens in text using tiktoken for the configured model.
|
||||
pub fn count_tokens(&self, text: &str) -> std::result::Result<usize, AgentError> {
|
||||
pub fn count_tokens(&self, text: &str) -> Result<usize, AgentError> {
|
||||
crate::tokent::count_text(text, &self.model_name)
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -195,7 +195,7 @@ pub async fn execute_chat_stream(
|
||||
&messages, model_name, config, temperature, max_tokens,
|
||||
if tools_enabled { Some(&tools) } else { None }, None,
|
||||
Arc::new(move |delta| {
|
||||
let content = delta.to_string().replace('\n', "");
|
||||
let content = delta.to_string();
|
||||
let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer });
|
||||
fut
|
||||
}),
|
||||
|
||||
@ -74,7 +74,7 @@ where
|
||||
.build();
|
||||
|
||||
let stream = agent.stream_prompt(&request.input)
|
||||
.with_history(Vec::new())
|
||||
.with_history(Vec::<rig::completion::Message>::new())
|
||||
.multi_turn(request.max_tool_depth)
|
||||
.await;
|
||||
|
||||
@ -90,12 +90,14 @@ where
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => {
|
||||
step_count += 1;
|
||||
let t = text.text;
|
||||
let cleaned = t.replace('\n', "");
|
||||
on_chunk(ReactStep::Answer { step: step_count, answer: cleaned }).await;
|
||||
on_chunk(ReactStep::Answer { step: step_count, answer: t.clone() }).await;
|
||||
final_content.push_str(&t);
|
||||
}
|
||||
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(reasoning))) => {
|
||||
let reasoning_text = reasoning.reasoning.join("");
|
||||
let reasoning_text: String = reasoning.content.iter().filter_map(|c| match c {
|
||||
rig::completion::message::ReasoningContent::Text { text, .. } => Some(text.as_str()),
|
||||
_ => None,
|
||||
}).collect::<Vec<_>>().join("");
|
||||
if !reasoning_text.is_empty() {
|
||||
step_count += 1;
|
||||
on_chunk(ReactStep::Thought { step: step_count, thought: reasoning_text }).await;
|
||||
|
||||
@ -62,7 +62,7 @@ pub async fn execute_process_stream(
|
||||
&messages, &model_name, &config, temperature, max_tokens,
|
||||
if tools_enabled { Some(&tools) } else { None }, None,
|
||||
Arc::new(move |delta| {
|
||||
let content = delta.to_string().replace('\n', "");
|
||||
let content = delta.to_string();
|
||||
let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer });
|
||||
fut
|
||||
}),
|
||||
|
||||
@ -160,8 +160,8 @@ fn ai_metrics() -> &'static AiMetrics {
|
||||
pub(crate) fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage {
|
||||
match msg.role.as_str() {
|
||||
"system" => {
|
||||
// System messages are handled via preamble(), but we still
|
||||
// need to return something. Return a system message as User for safety.
|
||||
// System messages are handled via preamble(), not passed as messages.
|
||||
// We still need to return a valid RigMessage variant.
|
||||
RigMessage::user(msg.content.as_deref().unwrap_or(""))
|
||||
}
|
||||
"user" => {
|
||||
@ -263,9 +263,6 @@ async fn do_completion<M>(
|
||||
where
|
||||
M: CompletionModel<Client = openai::Client>,
|
||||
{
|
||||
let mut history: Vec<RigMessage> = messages.iter().map(to_rig_message).collect();
|
||||
|
||||
// Extract preamble (first system message) and remove from history
|
||||
let preamble = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "system")
|
||||
@ -273,12 +270,6 @@ where
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
history.retain(|m| !matches!(m, RigMessage::User { .. } | RigMessage::Assistant { .. }));
|
||||
|
||||
// For tool_result messages, we need to add them back
|
||||
// Actually, let's keep the approach: filter out system, add others back
|
||||
// The rig completion request uses: preamble (system) + messages (conversation)
|
||||
// For our messages: system → preamble, rest → messages
|
||||
let non_system: Vec<RigMessage> = messages
|
||||
.iter()
|
||||
.filter(|m| m.role != "system")
|
||||
@ -700,13 +691,15 @@ async fn call_stream_once(
|
||||
}
|
||||
}
|
||||
Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
|
||||
for part in &reasoning.reasoning {
|
||||
reasoning_content.push_str(part);
|
||||
on_reasoning_delta(part).await;
|
||||
chunks.push(StreamChunk {
|
||||
chunk_type: StreamChunkType::Thinking,
|
||||
content: part.clone(),
|
||||
});
|
||||
for part in &reasoning.content {
|
||||
if let rig::completion::message::ReasoningContent::Text { text, .. } = part {
|
||||
reasoning_content.push_str(text);
|
||||
on_reasoning_delta(text).await;
|
||||
chunks.push(StreamChunk {
|
||||
chunk_type: StreamChunkType::Thinking,
|
||||
content: text.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {
|
||||
|
||||
39
libs/agent/compact/auth_fetch.rs
Normal file
39
libs/agent/compact/auth_fetch.rs
Normal file
@ -0,0 +1,39 @@
|
||||
use crate::AgentError;
|
||||
use models::rooms::room_message::{
|
||||
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
|
||||
};
|
||||
use models::Expr;
|
||||
use sea_orm::*;
|
||||
|
||||
impl super::CompactService {
|
||||
pub async fn fetch_room_messages_secure(
|
||||
&self,
|
||||
room_id: uuid::Uuid,
|
||||
requester_id: uuid::Uuid,
|
||||
) -> Result<Vec<RoomMessageModel>, AgentError> {
|
||||
use models::rooms::{RoomAccess, RoomUserState};
|
||||
|
||||
RoomMessage::find()
|
||||
.filter(RmCol::Room.eq(room_id))
|
||||
.filter(
|
||||
Condition::any()
|
||||
.add(Expr::exists(
|
||||
RoomUserState::find()
|
||||
.filter(models::rooms::room_user_state::Column::Room.eq(room_id))
|
||||
.filter(models::rooms::room_user_state::Column::User.eq(requester_id))
|
||||
.into_query(),
|
||||
))
|
||||
.add(Expr::exists(
|
||||
RoomAccess::find()
|
||||
.filter(models::rooms::room_access::Column::Room.eq(room_id))
|
||||
.filter(models::rooms::room_access::Column::User.eq(requester_id))
|
||||
.into_query(),
|
||||
)),
|
||||
)
|
||||
.order_by_asc(RmCol::Seq)
|
||||
.limit(10000)
|
||||
.all(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))
|
||||
}
|
||||
}
|
||||
@ -1,8 +1,32 @@
|
||||
//! Context compaction for AI sessions and room message history.
|
||||
|
||||
pub mod auth_fetch;
|
||||
pub mod helpers;
|
||||
pub mod service;
|
||||
pub mod room_compactor;
|
||||
pub mod summarizer;
|
||||
pub mod types;
|
||||
|
||||
pub use service::CompactService;
|
||||
use sea_orm::DatabaseConnection;
|
||||
|
||||
pub use types::{CompactConfig, CompactLevel, CompactSummary, MessageSummary, ThresholdResult};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompactService {
|
||||
db: DatabaseConnection,
|
||||
ai_client_config: crate::client::AiClientConfig,
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl CompactService {
|
||||
pub fn new(
|
||||
db: DatabaseConnection,
|
||||
ai_client_config: crate::client::AiClientConfig,
|
||||
model: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
db,
|
||||
ai_client_config,
|
||||
model,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
180
libs/agent/compact/room_compactor.rs
Normal file
180
libs/agent/compact/room_compactor.rs
Normal file
@ -0,0 +1,180 @@
|
||||
use models::rooms::room_message::{
|
||||
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
|
||||
};
|
||||
use sea_orm::ColumnTrait;
|
||||
use sea_orm::{EntityTrait, QueryFilter, QueryOrder, QuerySelect};
|
||||
|
||||
use crate::compact::types::CompactLevel;
|
||||
use crate::tokent::resolve_usage;
|
||||
use crate::{AgentError, CompactSummary, MessageSummary};
|
||||
|
||||
impl super::CompactService {
|
||||
pub async fn compact_room(
|
||||
&self,
|
||||
room_id: uuid::Uuid,
|
||||
level: CompactLevel,
|
||||
user_names: Option<std::collections::HashMap<uuid::Uuid, String>>,
|
||||
requester_id: uuid::Uuid,
|
||||
context_window_tokens: i32,
|
||||
compaction_max_summary_ratio: f32,
|
||||
) -> Result<CompactSummary, AgentError> {
|
||||
let messages = self
|
||||
.fetch_room_messages_secure(room_id, requester_id)
|
||||
.await?;
|
||||
|
||||
if messages.is_empty() {
|
||||
let room_exists = models::rooms::room::Entity::find_by_id(room_id)
|
||||
.one(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))?
|
||||
.is_some();
|
||||
|
||||
if room_exists {
|
||||
return Err(AgentError::Internal("Access denied or room empty".into()));
|
||||
} else {
|
||||
return Err(AgentError::Internal("Room not found".into()));
|
||||
}
|
||||
}
|
||||
|
||||
let user_ids: Vec<uuid::Uuid> = messages
|
||||
.iter()
|
||||
.filter_map(|m| m.sender_id)
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
let user_name_map = match user_names {
|
||||
Some(map) => map,
|
||||
None => self.get_user_name_map(&user_ids).await?,
|
||||
};
|
||||
|
||||
if messages.len() <= level.retain_count() {
|
||||
let retained: Vec<MessageSummary> = messages
|
||||
.iter()
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
return Ok(CompactSummary {
|
||||
session_id: uuid::Uuid::new_v4(),
|
||||
room_id,
|
||||
retained,
|
||||
summary: String::new(),
|
||||
compacted_at: chrono::Utc::now(),
|
||||
messages_compressed: 0,
|
||||
usage: None,
|
||||
});
|
||||
}
|
||||
|
||||
let retain_count = level.retain_count();
|
||||
let split_index = messages.len().saturating_sub(retain_count);
|
||||
let (to_summarize, retained_messages) = messages.split_at(split_index);
|
||||
|
||||
let retained: Vec<MessageSummary> = retained_messages
|
||||
.iter()
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
|
||||
let max_summary_tokens =
|
||||
(context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
|
||||
|
||||
let (summary, remote_usage) = self
|
||||
.summarize_messages(to_summarize, max_summary_tokens)
|
||||
.await?;
|
||||
|
||||
let summarized_text = to_summarize
|
||||
.iter()
|
||||
.map(|m| m.content.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
|
||||
|
||||
Ok(CompactSummary {
|
||||
session_id: uuid::Uuid::new_v4(),
|
||||
room_id,
|
||||
retained,
|
||||
summary,
|
||||
compacted_at: chrono::Utc::now(),
|
||||
messages_compressed: to_summarize.len(),
|
||||
usage: Some(usage),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn compact_session(
|
||||
&self,
|
||||
session_id: uuid::Uuid,
|
||||
level: CompactLevel,
|
||||
user_names: Option<std::collections::HashMap<uuid::Uuid, String>>,
|
||||
context_window_tokens: i32,
|
||||
compaction_max_summary_ratio: f32,
|
||||
) -> Result<CompactSummary, AgentError> {
|
||||
let messages: Vec<RoomMessageModel> = RoomMessage::find()
|
||||
.filter(RmCol::Room.eq(session_id))
|
||||
.order_by_asc(RmCol::Seq)
|
||||
.limit(10000)
|
||||
.all(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))?;
|
||||
|
||||
if messages.is_empty() {
|
||||
return Err(AgentError::Internal("session has no messages".into()));
|
||||
}
|
||||
|
||||
let user_ids: Vec<uuid::Uuid> = messages
|
||||
.iter()
|
||||
.filter_map(|m| m.sender_id)
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
let user_name_map = match user_names {
|
||||
Some(map) => map,
|
||||
None => self.get_user_name_map(&user_ids).await?,
|
||||
};
|
||||
|
||||
if messages.len() <= level.retain_count() {
|
||||
let retained: Vec<MessageSummary> = messages
|
||||
.iter()
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
return Ok(CompactSummary {
|
||||
session_id,
|
||||
room_id: uuid::Uuid::nil(),
|
||||
retained,
|
||||
summary: String::new(),
|
||||
compacted_at: chrono::Utc::now(),
|
||||
messages_compressed: 0,
|
||||
usage: None,
|
||||
});
|
||||
}
|
||||
|
||||
let retain_count = level.retain_count();
|
||||
let split_index = messages.len().saturating_sub(retain_count);
|
||||
let (to_summarize, retained_messages) = messages.split_at(split_index);
|
||||
|
||||
let retained: Vec<MessageSummary> = retained_messages
|
||||
.iter()
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
|
||||
let max_summary_tokens =
|
||||
(context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
|
||||
|
||||
let (summary, remote_usage) = self
|
||||
.summarize_messages(to_summarize, max_summary_tokens)
|
||||
.await?;
|
||||
|
||||
let summarized_text = to_summarize
|
||||
.iter()
|
||||
.map(|m| m.content.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
|
||||
|
||||
Ok(CompactSummary {
|
||||
session_id,
|
||||
room_id: uuid::Uuid::nil(),
|
||||
retained,
|
||||
summary,
|
||||
compacted_at: chrono::Utc::now(),
|
||||
messages_compressed: to_summarize.len(),
|
||||
usage: Some(usage),
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,327 +0,0 @@
|
||||
use chrono::Utc;
|
||||
use models::ColumnTrait;
|
||||
use models::rooms::room_message::{
|
||||
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
|
||||
};
|
||||
use models::users::user::{Column as UserCol, Entity as User};
|
||||
use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::client::types::ChatRequestMessage;
|
||||
use crate::client::AiClientConfig;
|
||||
use crate::client::call_with_params;
|
||||
use crate::AgentError;
|
||||
use crate::compact::types::{CompactLevel, CompactSummary, MessageSummary};
|
||||
use crate::tokent::{TokenUsage, resolve_usage};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CompactService {
|
||||
db: DatabaseConnection,
|
||||
ai_client_config: AiClientConfig,
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl CompactService {
|
||||
pub fn new(db: DatabaseConnection, ai_client_config: AiClientConfig, model: String) -> Self {
|
||||
Self { db, ai_client_config, model }
|
||||
}
|
||||
|
||||
pub async fn compact_room(
|
||||
&self,
|
||||
room_id: Uuid,
|
||||
level: CompactLevel,
|
||||
user_names: Option<std::collections::HashMap<Uuid, String>>,
|
||||
requester_id: Uuid,
|
||||
context_window_tokens: i32,
|
||||
compaction_max_summary_ratio: f32,
|
||||
) -> Result<CompactSummary, AgentError> {
|
||||
// Verify room access at the database level to ensure auth context is enforced.
|
||||
// Public rooms are accessible to project members.
|
||||
// For simplicity in this audit fix, we'll fetch only if access exists.
|
||||
let messages = self.fetch_room_messages_secure(room_id, requester_id).await?;
|
||||
|
||||
if messages.is_empty() {
|
||||
// Check if room actually exists or if it's just empty/inaccessible
|
||||
let room_exists = models::rooms::room::Entity::find_by_id(room_id)
|
||||
.one(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))?
|
||||
.is_some();
|
||||
|
||||
if room_exists {
|
||||
return Err(AgentError::Internal("Access denied or room empty".into()));
|
||||
} else {
|
||||
return Err(AgentError::Internal("Room not found".into()));
|
||||
}
|
||||
}
|
||||
|
||||
let user_ids: Vec<Uuid> = messages
|
||||
.iter()
|
||||
.filter_map(|m| m.sender_id)
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
let user_name_map = match user_names {
|
||||
Some(map) => map,
|
||||
None => self.get_user_name_map(&user_ids).await?,
|
||||
};
|
||||
|
||||
if messages.len() <= level.retain_count() {
|
||||
let retained: Vec<MessageSummary> = messages
|
||||
.iter()
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
return Ok(CompactSummary {
|
||||
session_id: Uuid::new_v4(),
|
||||
room_id,
|
||||
retained,
|
||||
summary: String::new(),
|
||||
compacted_at: Utc::now(),
|
||||
messages_compressed: 0,
|
||||
usage: None,
|
||||
});
|
||||
}
|
||||
|
||||
let retain_count = level.retain_count();
|
||||
let split_index = messages.len().saturating_sub(retain_count);
|
||||
let (to_summarize, retained_messages) = messages.split_at(split_index);
|
||||
|
||||
let retained: Vec<MessageSummary> = retained_messages
|
||||
.iter()
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
|
||||
let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
|
||||
|
||||
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
|
||||
|
||||
// Build text of what was summarized (for tiktoken fallback)
|
||||
let summarized_text = to_summarize
|
||||
.iter()
|
||||
.map(|m| m.content.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
|
||||
|
||||
Ok(CompactSummary {
|
||||
session_id: Uuid::new_v4(),
|
||||
room_id,
|
||||
retained,
|
||||
summary,
|
||||
compacted_at: Utc::now(),
|
||||
messages_compressed: to_summarize.len(),
|
||||
usage: Some(usage),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn compact_session(
|
||||
&self,
|
||||
session_id: Uuid,
|
||||
level: CompactLevel,
|
||||
user_names: Option<std::collections::HashMap<Uuid, String>>,
|
||||
context_window_tokens: i32,
|
||||
compaction_max_summary_ratio: f32,
|
||||
) -> Result<CompactSummary, AgentError> {
|
||||
let messages: Vec<RoomMessageModel> = RoomMessage::find()
|
||||
.filter(RmCol::Room.eq(session_id))
|
||||
.order_by_asc(RmCol::Seq)
|
||||
.limit(10000)
|
||||
.all(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))?;
|
||||
|
||||
if messages.is_empty() {
|
||||
return Err(AgentError::Internal("session has no messages".into()));
|
||||
}
|
||||
|
||||
let user_ids: Vec<Uuid> = messages
|
||||
.iter()
|
||||
.filter_map(|m| m.sender_id)
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
let user_name_map = match user_names {
|
||||
Some(map) => map,
|
||||
None => self.get_user_name_map(&user_ids).await?,
|
||||
};
|
||||
|
||||
if messages.len() <= level.retain_count() {
|
||||
let retained: Vec<MessageSummary> = messages
|
||||
.iter()
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
return Ok(CompactSummary {
|
||||
session_id,
|
||||
room_id: Uuid::nil(),
|
||||
retained,
|
||||
summary: String::new(),
|
||||
compacted_at: Utc::now(),
|
||||
messages_compressed: 0,
|
||||
usage: None,
|
||||
});
|
||||
}
|
||||
|
||||
let retain_count = level.retain_count();
|
||||
let split_index = messages.len().saturating_sub(retain_count);
|
||||
let (to_summarize, retained_messages) = messages.split_at(split_index);
|
||||
|
||||
let retained: Vec<MessageSummary> = retained_messages
|
||||
.iter()
|
||||
.map(|m| Self::message_to_summary(m, &user_name_map))
|
||||
.collect();
|
||||
|
||||
let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
|
||||
|
||||
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
|
||||
|
||||
let summarized_text = to_summarize
|
||||
.iter()
|
||||
.map(|m| m.content.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
|
||||
|
||||
Ok(CompactSummary {
|
||||
session_id,
|
||||
room_id: Uuid::nil(),
|
||||
retained,
|
||||
summary,
|
||||
compacted_at: Utc::now(),
|
||||
messages_compressed: to_summarize.len(),
|
||||
usage: Some(usage),
|
||||
})
|
||||
}
|
||||
|
||||
async fn fetch_room_messages_secure(
|
||||
&self,
|
||||
room_id: Uuid,
|
||||
requester_id: Uuid,
|
||||
) -> Result<Vec<RoomMessageModel>, AgentError> {
|
||||
use models::rooms::{RoomUserState, RoomAccess};
|
||||
use sea_orm::QueryTrait;
|
||||
use sea_orm::sea_query::Expr;
|
||||
|
||||
// Find messages for the room where the requester has access.
|
||||
// We check both the room_user_state table (membership) and the room_access table (explicit grants).
|
||||
RoomMessage::find()
|
||||
.filter(RmCol::Room.eq(room_id))
|
||||
.filter(
|
||||
sea_orm::Condition::any()
|
||||
.add(
|
||||
Expr::exists(
|
||||
RoomUserState::find()
|
||||
.filter(models::rooms::room_user_state::Column::Room.eq(room_id))
|
||||
.filter(models::rooms::room_user_state::Column::User.eq(requester_id))
|
||||
.into_query()
|
||||
)
|
||||
)
|
||||
.add(
|
||||
Expr::exists(
|
||||
RoomAccess::find()
|
||||
.filter(models::rooms::room_access::Column::Room.eq(room_id))
|
||||
.filter(models::rooms::room_access::Column::User.eq(requester_id))
|
||||
.into_query()
|
||||
)
|
||||
)
|
||||
)
|
||||
.order_by_asc(RmCol::Seq)
|
||||
.limit(10000)
|
||||
.all(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))
|
||||
}
|
||||
|
||||
fn message_to_summary(m: &RoomMessageModel, user_name_map: &std::collections::HashMap<Uuid, String>) -> MessageSummary {
|
||||
let sender_name = if let Some(user_id) = m.sender_id {
|
||||
user_name_map.get(&user_id).cloned().unwrap_or_else(|| m.sender_type.to_string())
|
||||
} else {
|
||||
m.sender_type.to_string()
|
||||
};
|
||||
MessageSummary {
|
||||
id: m.id,
|
||||
sender_type: m.sender_type.clone(),
|
||||
sender_id: m.sender_id,
|
||||
sender_name,
|
||||
content: m.content.clone(),
|
||||
content_type: m.content_type.clone(),
|
||||
tool_call_id: None,
|
||||
send_at: m.send_at,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_user_name_map(
|
||||
&self,
|
||||
user_ids: &[Uuid],
|
||||
) -> Result<std::collections::HashMap<Uuid, String>, AgentError> {
|
||||
use std::collections::HashMap;
|
||||
let mut map = HashMap::new();
|
||||
if !user_ids.is_empty() {
|
||||
let users = User::find()
|
||||
.filter(UserCol::Uid.is_in(user_ids.to_vec()))
|
||||
.all(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))?;
|
||||
for user in users {
|
||||
map.insert(user.uid, user.username);
|
||||
}
|
||||
}
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
async fn summarize_messages(
|
||||
&self,
|
||||
messages: &[RoomMessageModel],
|
||||
max_summary_tokens: usize,
|
||||
) -> Result<(String, Option<TokenUsage>), AgentError> {
|
||||
let user_ids: Vec<Uuid> = messages
|
||||
.iter()
|
||||
.filter_map(|m| m.sender_id)
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let user_name_map = self.get_user_name_map(&user_ids).await?;
|
||||
|
||||
let sender_mapper = |m: &RoomMessageModel| {
|
||||
if let Some(user_id) = m.sender_id {
|
||||
if let Some(username) = user_name_map.get(&user_id) {
|
||||
return username.clone();
|
||||
}
|
||||
}
|
||||
m.sender_type.to_string()
|
||||
};
|
||||
|
||||
let body = crate::compact::helpers::messages_to_text(messages, sender_mapper);
|
||||
|
||||
let user_msg = ChatRequestMessage::user(format!(
|
||||
"Summarise the following conversation concisely, preserving all key facts, \
|
||||
decisions, and any pending or in-progress work. \
|
||||
The summary MUST NOT exceed {} tokens. \
|
||||
Use this format:\n\n\
|
||||
**Summary:** <one-paragraph overview>\n\
|
||||
**Key decisions:** <bullet list or 'none'>\n\
|
||||
**Open items:** <bullet list or 'none'>\n\n\
|
||||
Conversation:\n\n{}",
|
||||
max_summary_tokens,
|
||||
body
|
||||
));
|
||||
|
||||
let response = call_with_params(
|
||||
&[user_msg],
|
||||
&self.model,
|
||||
&self.ai_client_config,
|
||||
0.3,
|
||||
2048,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
||||
|
||||
let remote_usage =
|
||||
TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);
|
||||
|
||||
Ok((response.content, remote_usage))
|
||||
}
|
||||
}
|
||||
110
libs/agent/compact/summarizer.rs
Normal file
110
libs/agent/compact/summarizer.rs
Normal file
@ -0,0 +1,110 @@
|
||||
use models::rooms::room_message::Model as RoomMessageModel;
|
||||
use models::users::user::{Column as UserCol, Entity as User};
|
||||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
|
||||
|
||||
use crate::client::call_with_params;
|
||||
use crate::client::types::ChatRequestMessage;
|
||||
use crate::compact::types::MessageSummary;
|
||||
use crate::tokent::TokenUsage;
|
||||
use crate::AgentError;
|
||||
|
||||
impl super::CompactService {
|
||||
pub async fn summarize_messages(
|
||||
&self,
|
||||
messages: &[RoomMessageModel],
|
||||
max_summary_tokens: usize,
|
||||
) -> Result<(String, Option<TokenUsage>), AgentError> {
|
||||
let user_ids: Vec<uuid::Uuid> = messages
|
||||
.iter()
|
||||
.filter_map(|m| m.sender_id)
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let user_name_map = self.get_user_name_map(&user_ids).await?;
|
||||
|
||||
let sender_mapper = |m: &RoomMessageModel| {
|
||||
if let Some(user_id) = m.sender_id {
|
||||
if let Some(username) = user_name_map.get(&user_id) {
|
||||
return username.clone();
|
||||
}
|
||||
}
|
||||
m.sender_type.to_string()
|
||||
};
|
||||
|
||||
let body = crate::compact::helpers::messages_to_text(messages, sender_mapper);
|
||||
|
||||
let user_msg = ChatRequestMessage::user(format!(
|
||||
"Summarise the following conversation concisely, preserving all key facts, \
|
||||
decisions, and any pending or in-progress work. \
|
||||
The summary MUST NOT exceed {} tokens. \
|
||||
Use this format:\n\n\
|
||||
**Summary:** <one-paragraph overview>\n\
|
||||
**Key decisions:** <bullet list or 'none'>\n\
|
||||
**Open items:** <bullet list or 'none'>\n\n\
|
||||
Conversation:\n\n{}",
|
||||
max_summary_tokens, body
|
||||
));
|
||||
|
||||
let response = call_with_params(
|
||||
&[user_msg],
|
||||
&self.model,
|
||||
&self.ai_client_config,
|
||||
0.3,
|
||||
2048,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
|
||||
|
||||
let remote_usage =
|
||||
TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);
|
||||
|
||||
Ok((response.content, remote_usage))
|
||||
}
|
||||
|
||||
pub fn message_to_summary(
|
||||
m: &RoomMessageModel,
|
||||
user_name_map: &std::collections::HashMap<uuid::Uuid, String>,
|
||||
) -> MessageSummary {
|
||||
let sender_name = if let Some(user_id) = m.sender_id {
|
||||
user_name_map
|
||||
.get(&user_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| m.sender_type.to_string())
|
||||
} else {
|
||||
m.sender_type.to_string()
|
||||
};
|
||||
MessageSummary {
|
||||
id: m.id,
|
||||
sender_type: m.sender_type.clone(),
|
||||
sender_id: m.sender_id,
|
||||
sender_name,
|
||||
content: m.content.clone(),
|
||||
content_type: m.content_type.clone(),
|
||||
tool_call_id: None,
|
||||
send_at: m.send_at,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_user_name_map(
|
||||
&self,
|
||||
user_ids: &[uuid::Uuid],
|
||||
) -> Result<std::collections::HashMap<uuid::Uuid, String>, AgentError> {
|
||||
use std::collections::HashMap;
|
||||
let mut map = HashMap::new();
|
||||
if !user_ids.is_empty() {
|
||||
let users = User::find()
|
||||
.filter(UserCol::Uid.is_in(user_ids.to_vec()))
|
||||
.all(&self.db)
|
||||
.await
|
||||
.map_err(|e| AgentError::Internal(e.to_string()))?;
|
||||
for user in users {
|
||||
map.insert(user.uid, user.username);
|
||||
}
|
||||
}
|
||||
Ok(map)
|
||||
}
|
||||
}
|
||||
61
libs/agent/embed/chunk.rs
Normal file
61
libs/agent/embed/chunk.rs
Normal file
@ -0,0 +1,61 @@
|
||||
/// Upper bound, in characters, for a single embedding chunk.
///
/// Chosen as an approximation of the embedding model's token limit
/// (text-embedding-3-small: 8192 tokens). CJK text runs at roughly one
/// character per token and English at roughly four, so 7000 characters is a
/// conservative ceiling that leaves headroom for every language.
const MAX_CHUNK_CHARS: usize = 7000;

/// Split `text` into chunks of at most [`MAX_CHUNK_CHARS`] characters.
///
/// Each window is cut at the last paragraph break (`"\n\n"`) if it has one,
/// otherwise at the last newline, otherwise just after the last `". "`,
/// `"! "` or `"? "` terminator (tried in that order). A window with none of
/// these is cut hard at the character limit.
///
/// Always returns at least one chunk (a single empty string for empty
/// input), and the chunks concatenate back to `text`. All slicing happens on
/// character boundaries, so multi-byte text is safe.
pub fn chunk_text(text: &str) -> Vec<String> {
    if text.is_empty() {
        return vec![String::new()];
    }
    // Byte length is an upper bound on the character count, so this is a
    // cheap and safe fast path.
    if text.len() <= MAX_CHUNK_CHARS {
        return vec![text.to_string()];
    }

    // Byte offset at which every character starts.
    let offsets: Vec<usize> = text.char_indices().map(|(pos, _)| pos).collect();
    let char_count = offsets.len();

    let mut pieces = Vec::new();
    let mut cursor = 0; // index into `offsets`, i.e. a character index

    while cursor < char_count {
        let from = offsets[cursor];
        let window_end_char = cursor + MAX_CHUNK_CHARS;

        // The remainder fits in one window: emit it and stop.
        if window_end_char >= char_count {
            pieces.push(text[from..].to_string());
            break;
        }

        let to = offsets[window_end_char];
        let window = &text[from..to];

        // Offsets are relative to `window` and point just past the separator
        // (or just past the punctuation mark for sentence ends).
        let split = window
            .rfind("\n\n")
            .map(|pos| pos + 2)
            .or_else(|| window.rfind('\n').map(|pos| pos + 1))
            .or_else(|| {
                [". ", "! ", "? "]
                    .iter()
                    .find_map(|sep| window.rfind(*sep).map(|pos| pos + 1))
            });

        match split {
            Some(offset) => {
                let stop = from + offset;
                pieces.push(text[from..stop].to_string());
                // Move the character cursor to the first character that
                // starts at or after the byte position we cut at.
                let consumed = offsets[cursor + 1..]
                    .iter()
                    .take_while(|&&start| start < stop)
                    .count();
                cursor += 1 + consumed;
            }
            None => {
                pieces.push(window.to_string());
                cursor = window_end_char;
            }
        }
    }

    pieces
}
|
||||
23
libs/agent/embed/embeddable.rs
Normal file
23
libs/agent/embed/embeddable.rs
Normal file
@ -0,0 +1,23 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
/// Trait for entities that can be embedded as vectors into Qdrant.
|
||||
#[async_trait]
|
||||
pub trait Embeddable {
|
||||
fn entity_type(&self) -> &'static str;
|
||||
fn to_text(&self) -> String;
|
||||
fn entity_id(&self) -> String;
|
||||
}
|
||||
|
||||
/// One conversation message queued for batch embedding into a per-room
/// Qdrant memory collection.
#[derive(Debug, Clone)]
pub struct EmbedMemoryInput {
    /// Identifier of the message being embedded.
    pub message_id: String,
    /// Message text; long content is chunked before embedding.
    pub content: String,
    /// Project name; together with `room_id` it selects the room's collection.
    pub project_name: String,
    /// Room identifier; together with `project_name` it selects the collection.
    pub room_id: String,
    /// Author's user id, when the message has one.
    pub user_id: Option<String>,
    /// Sender type of the message, stored alongside the vector.
    pub sender_type: String,
}
|
||||
|
||||
/// Input struct for batch tag embedding.
|
||||
pub use models::TagEmbedInput;
|
||||
@ -1,112 +1,11 @@
|
||||
use async_trait::async_trait;
|
||||
use qdrant_client::qdrant::Filter;
|
||||
use sea_orm::DatabaseConnection;
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::client::{EmbedClient, EmbedPayload, EmbedVector, SearchResult};
|
||||
|
||||
/// Maximum characters per chunk for embedding (approximates token limit).
|
||||
/// text-embedding-3-small: 8192 token limit.
|
||||
/// For CJK ~1 char/token, for English ~4 chars/token.
|
||||
/// Conservative limit: 7000 chars to leave room for all languages.
|
||||
const MAX_CHUNK_CHARS: usize = 7000;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Embeddable {
|
||||
fn entity_type(&self) -> &'static str;
|
||||
fn to_text(&self) -> String;
|
||||
fn entity_id(&self) -> String;
|
||||
}
|
||||
|
||||
/// Split long text into chunks at paragraph/sentence boundaries.
|
||||
/// Returns at least one chunk even for empty text.
|
||||
/// Safe for multi-byte characters (uses char indices, not byte indices).
|
||||
fn chunk_text(text: &str) -> Vec<String> {
|
||||
if text.is_empty() {
|
||||
return vec![String::new()];
|
||||
}
|
||||
if text.len() <= MAX_CHUNK_CHARS {
|
||||
return vec![text.to_string()];
|
||||
}
|
||||
|
||||
// Collect char boundary byte positions
|
||||
let char_indices: Vec<usize> = text.char_indices().map(|(i, _)| i).collect();
|
||||
let total_chars = char_indices.len();
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
let mut start_idx = 0; // char index
|
||||
|
||||
while start_idx < total_chars {
|
||||
// Start byte offset
|
||||
let byte_start = char_indices[start_idx];
|
||||
|
||||
// Find end char index: at most MAX_CHUNK_CHARS characters
|
||||
let end_char_idx = (start_idx + MAX_CHUNK_CHARS).min(total_chars);
|
||||
let byte_end_candidate = char_indices[end_char_idx - 1] + text[char_indices[end_char_idx - 1]..].chars().next().map(|c| c.len_utf8()).unwrap_or(1);
|
||||
|
||||
if end_char_idx >= total_chars {
|
||||
chunks.push(text[byte_start..].to_string());
|
||||
break;
|
||||
}
|
||||
|
||||
// Try to break at paragraph or sentence boundary in the allowed range
|
||||
let search_range = &text[byte_start..byte_end_candidate];
|
||||
let break_at = if let Some(pos) = search_range.rfind("\n\n") {
|
||||
Some(pos + 2) // after the paragraph break
|
||||
} else if let Some(pos) = search_range.rfind('\n') {
|
||||
Some(pos + 1)
|
||||
} else if let Some(pos) = search_range.rfind(". ") {
|
||||
Some(pos + 1)
|
||||
} else if let Some(pos) = search_range.rfind("! ") {
|
||||
Some(pos + 1)
|
||||
} else if let Some(pos) = search_range.rfind("? ") {
|
||||
Some(pos + 1)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(offset) = break_at {
|
||||
let byte_end = byte_start + offset;
|
||||
chunks.push(text[byte_start..byte_end].to_string());
|
||||
// Advance char index to match the byte break
|
||||
let mut advance = start_idx + 1;
|
||||
while advance < total_chars && char_indices[advance] < byte_end {
|
||||
advance += 1;
|
||||
}
|
||||
start_idx = advance;
|
||||
} else {
|
||||
// Hard break at char boundary
|
||||
chunks.push(text[byte_start..byte_end_candidate].to_string());
|
||||
start_idx = end_char_idx;
|
||||
}
|
||||
}
|
||||
|
||||
chunks
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EmbedService {
|
||||
client: Arc<EmbedClient>,
|
||||
db: DatabaseConnection,
|
||||
model_name: String,
|
||||
dimensions: u64,
|
||||
}
|
||||
|
||||
impl EmbedService {
|
||||
pub fn new(
|
||||
client: EmbedClient,
|
||||
db: DatabaseConnection,
|
||||
model_name: String,
|
||||
dimensions: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
client: Arc::new(client),
|
||||
db,
|
||||
model_name,
|
||||
dimensions,
|
||||
}
|
||||
}
|
||||
use super::chunk::chunk_text;
|
||||
use super::client::{EmbedPayload, EmbedVector};
|
||||
use super::embeddable::{EmbedMemoryInput, Embeddable};
|
||||
|
||||
/// Embedding and upsert operations for entity vectors in Qdrant.
|
||||
impl super::EmbedService {
|
||||
pub async fn embed_issue(
|
||||
&self,
|
||||
id: &str,
|
||||
@ -203,69 +102,6 @@ impl EmbedService {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn search_issues(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
self.client
|
||||
.search(query, "issue", &self.model_name, limit)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn search_repos(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
self.client
|
||||
.search(query, "repo", &self.model_name, limit)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn search_issues_filtered(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
filter: Filter,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
self.client
|
||||
.search_with_filter(query, "issue", &self.model_name, limit, filter)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn delete_issue_embedding(&self, issue_id: &str) -> crate::Result<()> {
|
||||
self.client.delete_by_entity_id("issue", issue_id).await
|
||||
}
|
||||
|
||||
pub async fn delete_repo_embedding(&self, repo_id: &str) -> crate::Result<()> {
|
||||
self.client.delete_by_entity_id("repo", repo_id).await
|
||||
}
|
||||
|
||||
pub async fn ensure_collections(&self) -> crate::Result<()> {
|
||||
self.client
|
||||
.ensure_collection("issue", self.dimensions)
|
||||
.await?;
|
||||
self.client
|
||||
.ensure_collection("repo", self.dimensions)
|
||||
.await?;
|
||||
self.client.ensure_skill_collection(self.dimensions).await?;
|
||||
self.client
|
||||
.ensure_collection("repo_tag", self.dimensions)
|
||||
.await?;
|
||||
// Room memory collections are created per-room on first embed
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn db(&self) -> &DatabaseConnection {
|
||||
&self.db
|
||||
}
|
||||
|
||||
pub fn client(&self) -> &Arc<EmbedClient> {
|
||||
&self.client
|
||||
}
|
||||
|
||||
/// Embed a project skill into Qdrant for vector-based semantic search.
|
||||
pub async fn embed_skill(
|
||||
&self,
|
||||
skill_id: i64,
|
||||
@ -279,7 +115,6 @@ impl EmbedService {
|
||||
|
||||
tracing::debug!(skill_id = %skill_id, name = %name, content_len = content.len(), "embed_skill: starting");
|
||||
|
||||
// Auto-chunk long content
|
||||
let texts = chunk_text(content);
|
||||
tracing::debug!(skill_id = %skill_id, chunks = texts.len(), "embed_skill: chunked");
|
||||
|
||||
@ -288,13 +123,17 @@ impl EmbedService {
|
||||
.embed_skill(&id, name, desc, content, project_uuid, &self.model_name)
|
||||
.await?;
|
||||
} else {
|
||||
// Multi-chunk: embed each chunk with chunk_index metadata
|
||||
let full_texts: Vec<String> = texts.iter().map(|t| format!("{}: {} {}", name, desc, t)).collect();
|
||||
let full_texts: Vec<String> = texts
|
||||
.iter()
|
||||
.map(|t| format!("{}: {} {}", name, desc, t))
|
||||
.collect();
|
||||
tracing::debug!(skill_id = %skill_id, "embed_skill: calling embed_batch");
|
||||
let embeddings = self.client.embed_batch(&full_texts, &self.model_name).await?;
|
||||
|
||||
let points: Vec<EmbedVector> = embeddings.into_iter().enumerate().map(|(i, vector)| {
|
||||
EmbedVector {
|
||||
let points: Vec<EmbedVector> = embeddings
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, vector)| EmbedVector {
|
||||
id: format!("{}:chunk:{}", id, i),
|
||||
vector,
|
||||
payload: EmbedPayload {
|
||||
@ -306,10 +145,11 @@ impl EmbedService {
|
||||
"description": desc,
|
||||
"chunk_index": i,
|
||||
"total_chunks": texts.len(),
|
||||
}).into(),
|
||||
})
|
||||
.into(),
|
||||
},
|
||||
}
|
||||
}).collect();
|
||||
})
|
||||
.collect();
|
||||
|
||||
self.client.upsert(points).await?;
|
||||
}
|
||||
@ -317,7 +157,6 @@ impl EmbedService {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Embed an issue with auto-chunking for long content.
|
||||
pub async fn embed_issue_chunked(
|
||||
&self,
|
||||
id: &str,
|
||||
@ -336,8 +175,10 @@ impl EmbedService {
|
||||
|
||||
let embeddings = self.client.embed_batch(&chunks, &self.model_name).await?;
|
||||
|
||||
let points: Vec<EmbedVector> = embeddings.into_iter().enumerate().map(|(i, vector)| {
|
||||
EmbedVector {
|
||||
let points: Vec<EmbedVector> = embeddings
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, vector)| EmbedVector {
|
||||
id: format!("{}:chunk:{}", id, i),
|
||||
vector,
|
||||
payload: EmbedPayload {
|
||||
@ -347,17 +188,15 @@ impl EmbedService {
|
||||
extra: serde_json::json!({
|
||||
"chunk_index": i,
|
||||
"total_chunks": chunks.len(),
|
||||
}).into(),
|
||||
})
|
||||
.into(),
|
||||
},
|
||||
}
|
||||
}).collect();
|
||||
})
|
||||
.collect();
|
||||
|
||||
self.client.upsert(points).await
|
||||
}
|
||||
|
||||
/// Batch-embed multiple conversation messages into per-room Qdrant collections.
|
||||
/// Auto-chunks long messages and filters non-text/system/empty content.
|
||||
/// Handles all filtering internally: only text-type, non-empty, non-system messages are embedded.
|
||||
pub async fn embed_memories_batch(
|
||||
&self,
|
||||
messages: Vec<EmbedMemoryInput>,
|
||||
@ -366,8 +205,6 @@ impl EmbedService {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Group by room collection for batch upsert to reduce Qdrant calls
|
||||
use std::collections::HashMap;
|
||||
let mut by_room: HashMap<String, Vec<(EmbedMemoryInput, Vec<String>)>> = HashMap::new();
|
||||
|
||||
for msg in messages {
|
||||
@ -375,15 +212,15 @@ impl EmbedService {
|
||||
if chunks.is_empty() || chunks.iter().all(|c| c.trim().is_empty()) {
|
||||
continue;
|
||||
}
|
||||
let collection = crate::embed::qdrant::QdrantClient::room_memory_collection_name(
|
||||
let collection = super::qdrant::QdrantClient::room_memory_collection_name(
|
||||
&msg.project_name, &msg.room_id,
|
||||
);
|
||||
by_room.entry(collection).or_default().push((msg, chunks));
|
||||
}
|
||||
|
||||
for (collection, entries) in &by_room {
|
||||
// Collect all texts for batch embedding
|
||||
let all_texts: Vec<String> = entries.iter()
|
||||
let all_texts: Vec<String> = entries
|
||||
.iter()
|
||||
.flat_map(|(_, chunks)| chunks.iter().cloned())
|
||||
.collect();
|
||||
|
||||
@ -393,14 +230,12 @@ impl EmbedService {
|
||||
|
||||
let embeddings = self.client.embed_batch(&all_texts, &self.model_name).await?;
|
||||
|
||||
// Ensure the room collection exists with correct dimensions
|
||||
if let Some((first, _)) = entries.first() {
|
||||
let _ = self.client
|
||||
.ensure_room_memory_collection(&first.project_name, &first.room_id, self.dimensions)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Build points: one per chunk
|
||||
let mut points = Vec::new();
|
||||
let mut embed_idx = 0;
|
||||
for (msg, chunks) in entries {
|
||||
@ -423,9 +258,18 @@ impl EmbedService {
|
||||
extra: serde_json::json!({
|
||||
"user_id": msg.user_id,
|
||||
"sender_type": msg.sender_type,
|
||||
"chunk_index": if chunks.len() > 1 { Some(chunk_i) } else { None },
|
||||
"total_chunks": if chunks.len() > 1 { Some(chunks.len()) } else { None },
|
||||
}).into(),
|
||||
"chunk_index": if chunks.len() > 1 {
|
||||
Some(chunk_i)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
"total_chunks": if chunks.len() > 1 {
|
||||
Some(chunks.len())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
})
|
||||
.into(),
|
||||
},
|
||||
});
|
||||
embed_idx += 1;
|
||||
@ -440,11 +284,9 @@ impl EmbedService {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Batch-embed repo tags with project isolation.
|
||||
/// Each tag stores project_id as entity_id for post-filtering.
|
||||
pub async fn embed_tags_batch(
|
||||
&self,
|
||||
tags: Vec<TagEmbedInput>,
|
||||
tags: Vec<super::embeddable::TagEmbedInput>,
|
||||
) -> crate::Result<()> {
|
||||
if tags.is_empty() {
|
||||
return Ok(());
|
||||
@ -494,48 +336,6 @@ impl EmbedService {
|
||||
self.client.upsert(points).await
|
||||
}
|
||||
|
||||
/// Search repo tags by semantic similarity within a project.
|
||||
/// Filters by project_id (stored in entity_id) for project isolation.
|
||||
pub async fn search_tags(
|
||||
&self,
|
||||
query: &str,
|
||||
project_id: &str,
|
||||
limit: usize,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
let mut results = self
|
||||
.client
|
||||
.search(query, "repo_tag", &self.model_name, limit + 1)
|
||||
.await?;
|
||||
results.retain(|r| r.payload.entity_id == project_id);
|
||||
results.truncate(limit);
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
pub fn model_name(&self) -> &str {
|
||||
&self.model_name
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> u64 {
|
||||
self.dimensions
|
||||
}
|
||||
|
||||
pub fn embed_client(&self) -> &EmbedClient {
|
||||
&self.client
|
||||
}
|
||||
|
||||
/// Search skills by semantic similarity within a project.
|
||||
pub async fn search_skills(
|
||||
&self,
|
||||
query: &str,
|
||||
project_uuid: &str,
|
||||
limit: usize,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
self.client
|
||||
.search_skills(query, &self.model_name, project_uuid, limit)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Embed a conversation message into Qdrant as a memory vector.
|
||||
pub async fn embed_memory(
|
||||
&self,
|
||||
message_id: &str,
|
||||
@ -548,32 +348,4 @@ impl EmbedService {
|
||||
.embed_memory(message_id, text, project_name, room_id, user_id, &self.model_name)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Search past conversation messages by semantic similarity within a room.
|
||||
pub async fn search_memories(
|
||||
&self,
|
||||
query: &str,
|
||||
project_name: &str,
|
||||
room_id: &str,
|
||||
limit: usize,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
self.client
|
||||
.search_memories(query, &self.model_name, project_name, room_id, limit, self.dimensions)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Input struct for batch memory embedding into per-room Qdrant collections.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbedMemoryInput {
|
||||
pub message_id: String,
|
||||
pub content: String,
|
||||
pub project_name: String,
|
||||
pub room_id: String,
|
||||
pub user_id: Option<String>,
|
||||
pub sender_type: String,
|
||||
}
|
||||
|
||||
/// Input struct for batch tag embedding.
|
||||
/// Re-exported from models for backward compatibility.
|
||||
pub use models::TagEmbedInput;
|
||||
}
|
||||
@ -1,10 +1,69 @@
|
||||
pub mod chunk;
|
||||
pub mod client;
|
||||
pub mod embeddable;
|
||||
pub mod entity_embed;
|
||||
pub mod qdrant;
|
||||
pub mod service;
|
||||
pub mod search;
|
||||
|
||||
pub use client::{EmbedClient, EmbedPayload, EmbedVector, SearchResult};
|
||||
pub use embeddable::{EmbedMemoryInput, Embeddable, TagEmbedInput};
|
||||
pub use qdrant::QdrantClient;
|
||||
pub use service::{EmbedMemoryInput, EmbedService, Embeddable, TagEmbedInput};
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EmbedService {
|
||||
client: Arc<EmbedClient>,
|
||||
db: sea_orm::DatabaseConnection,
|
||||
model_name: String,
|
||||
dimensions: u64,
|
||||
}
|
||||
|
||||
impl EmbedService {
|
||||
pub fn new(
|
||||
client: EmbedClient,
|
||||
db: sea_orm::DatabaseConnection,
|
||||
model_name: String,
|
||||
dimensions: u64,
|
||||
) -> Self {
|
||||
Self {
|
||||
client: Arc::new(client),
|
||||
db,
|
||||
model_name,
|
||||
dimensions,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ensure_collections(&self) -> crate::Result<()> {
|
||||
self.client
|
||||
.ensure_collection("issue", self.dimensions)
|
||||
.await?;
|
||||
self.client
|
||||
.ensure_collection("repo", self.dimensions)
|
||||
.await?;
|
||||
self.client.ensure_skill_collection(self.dimensions).await?;
|
||||
self.client
|
||||
.ensure_collection("repo_tag", self.dimensions)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn db(&self) -> &sea_orm::DatabaseConnection {
|
||||
&self.db
|
||||
}
|
||||
|
||||
pub fn client(&self) -> &Arc<EmbedClient> {
|
||||
&self.client
|
||||
}
|
||||
|
||||
pub fn model_name(&self) -> &str {
|
||||
&self.model_name
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> u64 {
|
||||
self.dimensions
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new_embed_client(config: &config::AppConfig) -> crate::Result<EmbedClient> {
|
||||
let base_url = config
|
||||
@ -22,7 +81,9 @@ pub async fn new_embed_client(config: &config::AppConfig) -> crate::Result<Embed
|
||||
.api_key(&api_key)
|
||||
.base_url(&base_url)
|
||||
.build()
|
||||
.map_err(|e| crate::AgentError::Internal(format!("failed to build rig openai client: {}", e)))?;
|
||||
.map_err(|e| {
|
||||
crate::AgentError::Internal(format!("failed to build rig openai client: {}", e))
|
||||
})?;
|
||||
|
||||
let qdrant = QdrantClient::new(&qdrant_url, qdrant_api_key.as_deref()).await?;
|
||||
Ok(EmbedClient::new(openai, qdrant))
|
||||
|
||||
79
libs/agent/embed/search.rs
Normal file
79
libs/agent/embed/search.rs
Normal file
@ -0,0 +1,79 @@
|
||||
use qdrant_client::qdrant::Filter;
|
||||
|
||||
use super::client::SearchResult;
|
||||
|
||||
/// Vector search operations for Qdrant-backed entity retrieval.
|
||||
impl super::EmbedService {
|
||||
pub async fn search_issues(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
self.client
|
||||
.search(query, "issue", &self.model_name, limit)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn search_repos(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
self.client
|
||||
.search(query, "repo", &self.model_name, limit)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn search_issues_filtered(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
filter: Filter,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
self.client
|
||||
.search_with_filter(query, "issue", &self.model_name, limit, filter)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Search repo tags by semantic similarity within a project.
|
||||
/// Filters by project_id (stored in entity_id) for project isolation.
|
||||
pub async fn search_tags(
|
||||
&self,
|
||||
query: &str,
|
||||
project_id: &str,
|
||||
limit: usize,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
let mut results = self
|
||||
.client
|
||||
.search(query, "repo_tag", &self.model_name, limit + 1)
|
||||
.await?;
|
||||
results.retain(|r| r.payload.entity_id == project_id);
|
||||
results.truncate(limit);
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Search skills by semantic similarity within a project.
|
||||
pub async fn search_skills(
|
||||
&self,
|
||||
query: &str,
|
||||
project_uuid: &str,
|
||||
limit: usize,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
self.client
|
||||
.search_skills(query, &self.model_name, project_uuid, limit)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Search past conversation messages by semantic similarity within a room.
|
||||
pub async fn search_memories(
|
||||
&self,
|
||||
query: &str,
|
||||
project_name: &str,
|
||||
room_id: &str,
|
||||
limit: usize,
|
||||
) -> crate::Result<Vec<SearchResult>> {
|
||||
self.client
|
||||
.search_memories(query, &self.model_name, project_name, room_id, limit, self.dimensions)
|
||||
.await
|
||||
}
|
||||
}
|
||||
@ -6,6 +6,7 @@ pub mod compact;
|
||||
pub mod embed;
|
||||
pub mod error;
|
||||
pub mod model;
|
||||
pub mod orao;
|
||||
pub mod perception;
|
||||
pub mod react;
|
||||
pub mod skills;
|
||||
@ -13,33 +14,42 @@ pub mod sync;
|
||||
pub mod task;
|
||||
pub mod tokent;
|
||||
pub mod tool;
|
||||
pub use billing::{BillingRecord, BillingResult, record_ai_usage, initialize_user_billing, initialize_project_billing, check_balance, persist_billing_error};
|
||||
pub use sync::list_accessible_models;
|
||||
pub use task::TaskService;
|
||||
pub use tokent::{TokenUsage, resolve_usage};
|
||||
pub use perception::{PerceptionService, SkillContext, SkillEntry, ToolCallEvent};
|
||||
pub use skills::{
|
||||
BuiltInSkill, SKILL_TEMPLATES, all_skill_slugs, all_skills,
|
||||
get_skill, get_skill_by_tool, is_built_in_skill, match_skill_by_keyword, skills_by_category,
|
||||
pub use billing::{
|
||||
check_balance, initialize_project_billing, initialize_user_billing, persist_billing_error,
|
||||
record_ai_usage, BillingRecord, BillingResult,
|
||||
};
|
||||
pub use chat::{
|
||||
AiContextSenderType, AiRequest, AiStreamChunk, ChatService, Mention, RoomMessageContext,
|
||||
StreamCallback,
|
||||
};
|
||||
pub use client::{AiCallResponse, AiClientConfig, call_with_params, call_with_retry};
|
||||
pub use client::types::ChatRequestMessage;
|
||||
pub use client::{call_with_params, call_with_retry, AiCallResponse, AiClientConfig};
|
||||
pub use compact::{CompactConfig, CompactLevel, CompactService, CompactSummary, MessageSummary};
|
||||
pub use embed::{
|
||||
EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult, TagEmbedInput, new_embed_client,
|
||||
new_embed_client, EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult,
|
||||
TagEmbedInput,
|
||||
};
|
||||
pub use error::{AgentError, Result};
|
||||
pub use orao::{
|
||||
ActionExecutor, ActionType, ActionResult, ActionVerdict, OraoConfig, OraoExecutor,
|
||||
OraoExecutorBuilder, OraoOutcome, OraoStep, PerceptionSnapshot, PlannedAction,
|
||||
ReasoningOutput, RoundRecord, SafetyLevel,
|
||||
};
|
||||
pub use perception::{PerceptionService, SkillContext, SkillEntry, ToolCallEvent};
|
||||
pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT, ROOM_CONTEXT_PROMPT};
|
||||
pub use skills::{
|
||||
all_skill_slugs, all_skills, get_skill, get_skill_by_tool, is_built_in_skill, match_skill_by_keyword,
|
||||
skills_by_category, BuiltInSkill, SKILL_TEMPLATES,
|
||||
};
|
||||
pub use sync::list_accessible_models;
|
||||
pub use task::TaskService;
|
||||
pub use tokent::{resolve_usage, TokenUsage};
|
||||
pub use tool::{
|
||||
ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,
|
||||
ToolRegistry, ToolResult, ToolSchema,
|
||||
ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition,
|
||||
ToolError, ToolExecutor, ToolHandler, ToolParam, ToolRegistry, ToolResult, ToolSchema,
|
||||
};
|
||||
|
||||
#[cfg(feature = "rig")]
|
||||
pub use agent::RigAgentService;
|
||||
#[cfg(feature = "rig")]
|
||||
pub use tool::{RigToolSet, RecordingTool, is_retryable_tool_error};
|
||||
pub use tool::{is_retryable_tool_error, RecordingTool, RigToolSet};
|
||||
|
||||
@ -1 +0,0 @@
|
||||
// All reasoning modes removed - using ReAct pattern directly in chat service
|
||||
203
libs/agent/orao/act.rs
Normal file
203
libs/agent/orao/act.rs
Normal file
@ -0,0 +1,203 @@
|
||||
//! Act phase: execute planned actions with safety checks.
|
||||
//!
|
||||
//! Actions are executed through a caller-provided executor callback, which
|
||||
//! typically dispatches to the [`ToolRegistry`] or runs shell commands.
|
||||
//! All file access must go through function calls (tools), never direct
|
||||
//! filesystem operations.
|
||||
//!
|
||||
//! [`ToolRegistry`]: crate::tool::ToolRegistry
|
||||
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::process::Command;
|
||||
use std::time::Duration;
|
||||
|
||||
use super::types::{
|
||||
ActionType, ActionResult, ActionVerdict, OraoConfig, PlannedAction, SafetyLevel,
|
||||
};
|
||||
|
||||
/// Callback for executing a planned action.
|
||||
///
|
||||
/// The caller (service layer) provides this to wire up tool execution.
|
||||
/// Returns `ActionResult` on completion.
|
||||
pub type ActionExecutor = Box<
|
||||
dyn Fn(
|
||||
PlannedAction,
|
||||
) -> Pin<Box<dyn Future<Output = ActionResult> + Send>>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>;
|
||||
|
||||
/// Check whether an action is allowed under the given safety configuration.
|
||||
///
|
||||
/// Returns `None` if allowed, or `Some(reason)` if blocked.
|
||||
pub fn check_safety(action: &PlannedAction, config: &OraoConfig) -> Option<String> {
|
||||
let safety = SafetyLevel::classify_command(&action.command_or_content);
|
||||
|
||||
if safety > config.max_safety_level {
|
||||
return Some(format!(
|
||||
"Action denied: safety level {:?} exceeds max allowed {:?}",
|
||||
safety, config.max_safety_level
|
||||
));
|
||||
}
|
||||
|
||||
// Check for dangerous command patterns
|
||||
if let Some(reason) = check_dangerous_command(&action.command_or_content) {
|
||||
return Some(reason);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Execute a single planned action via the provided executor.
|
||||
///
|
||||
/// Applies safety checks and timeout, then delegates to the executor.
|
||||
pub async fn execute_action(
|
||||
action: PlannedAction,
|
||||
config: &OraoConfig,
|
||||
executor: &ActionExecutor,
|
||||
) -> ActionResult {
|
||||
// ── Safety gate ────────────────────────────────────────────────────
|
||||
if let Some(reason) = check_safety(&action, config) {
|
||||
return ActionResult {
|
||||
action,
|
||||
exit_code: Some(1),
|
||||
stdout: String::new(),
|
||||
stderr: reason,
|
||||
file_changes: Vec::new(),
|
||||
verdict: ActionVerdict::Failure,
|
||||
};
|
||||
}
|
||||
|
||||
// ── Execute with timeout ──────────────────────────────────────────
|
||||
let action_clone = action.clone();
|
||||
let exec_future = executor(action);
|
||||
|
||||
match tokio::time::timeout(
|
||||
Duration::from_secs(config.action_timeout_secs),
|
||||
exec_future,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(_elapsed) => ActionResult {
|
||||
action: action_clone,
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: format!(
|
||||
"Action timed out after {} seconds",
|
||||
config.action_timeout_secs
|
||||
),
|
||||
file_changes: Vec::new(),
|
||||
verdict: ActionVerdict::Failure,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a default action executor that runs shell commands directly.
|
||||
///
|
||||
/// This is suitable for `shell_command` and `git_operation` action types.
|
||||
/// For `tool_invoke`, the caller should provide a custom executor that
|
||||
/// dispatches to the [`ToolRegistry`].
|
||||
///
|
||||
/// [`ToolRegistry`]: crate::tool::ToolRegistry
|
||||
pub fn shell_executor(working_dir: String) -> ActionExecutor {
|
||||
Box::new(move |action: PlannedAction| {
|
||||
let dir = working_dir.clone();
|
||||
Box::pin(async move {
|
||||
match action.action_type {
|
||||
ActionType::ShellCommand
|
||||
| ActionType::GitOperation
|
||||
| ActionType::ToolInvoke => run_shell_command(&action, &dir).await,
|
||||
ActionType::FileWrite | ActionType::FileEdit => {
|
||||
// File operations should use tool_invoke with a file-writing tool.
|
||||
// Direct file access is discouraged; return an error directing to tools.
|
||||
ActionResult {
|
||||
exit_code: Some(1),
|
||||
stdout: String::new(),
|
||||
stderr: "File operations must use tool_invoke with registered file tools. Use shell_command with sed/echo for inline edits.".to_string(),
|
||||
file_changes: Vec::new(),
|
||||
verdict: ActionVerdict::Failure,
|
||||
action,
|
||||
}
|
||||
}
|
||||
ActionType::UserDialog => ActionResult {
|
||||
exit_code: None,
|
||||
stdout: "User dialog requested".to_string(),
|
||||
stderr: String::new(),
|
||||
file_changes: Vec::new(),
|
||||
verdict: ActionVerdict::Success,
|
||||
action,
|
||||
},
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_shell_command(action: &PlannedAction, working_dir: &str) -> ActionResult {
|
||||
let cmd = &action.command_or_content;
|
||||
|
||||
let output = Command::new("sh")
|
||||
.args(["-c", cmd])
|
||||
.current_dir(working_dir)
|
||||
.output();
|
||||
|
||||
match output {
|
||||
Ok(out) => {
|
||||
let exit_code = out.status.code();
|
||||
let stdout = String::from_utf8_lossy(&out.stdout).to_string();
|
||||
let stderr = String::from_utf8_lossy(&out.stderr).to_string();
|
||||
|
||||
let verdict = match exit_code {
|
||||
Some(0) if !stderr_has_errors(&stderr) => ActionVerdict::Success,
|
||||
Some(0) => ActionVerdict::SuccessWithWarnings,
|
||||
_ => ActionVerdict::Failure,
|
||||
};
|
||||
|
||||
ActionResult {
|
||||
action: action.clone(),
|
||||
exit_code,
|
||||
stdout,
|
||||
stderr,
|
||||
file_changes: Vec::new(),
|
||||
verdict,
|
||||
}
|
||||
}
|
||||
Err(e) => ActionResult {
|
||||
action: action.clone(),
|
||||
exit_code: None,
|
||||
stdout: String::new(),
|
||||
stderr: format!("Failed to spawn command: {}", e),
|
||||
file_changes: Vec::new(),
|
||||
verdict: ActionVerdict::Failure,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Heuristic: does `stderr` mention an error, failure or panic?
///
/// The check is a case-insensitive substring search for `error`, `fail` and
/// `panic`.
fn stderr_has_errors(stderr: &str) -> bool {
    const MARKERS: [&str; 3] = ["error", "fail", "panic"];

    let haystack = stderr.to_lowercase();
    MARKERS.iter().any(|marker| haystack.contains(marker))
}
|
||||
|
||||
/// Scan `cmd` for a fixed list of destructive shell patterns.
///
/// Matching is a plain substring test, performed in list order. Returns
/// `Some(message)` naming the first pattern found, or `None` when no pattern
/// matches.
pub fn check_dangerous_command(cmd: &str) -> Option<String> {
    // (pattern, human-readable reason)
    const BLOCKLIST: [(&str, &str); 7] = [
        ("rm -rf /", "Recursive root deletion"),
        ("rm -rf ~", "Recursive home deletion"),
        (":(){ :|:& };:", "Fork bomb"),
        ("mkfs.", "Filesystem format"),
        ("dd if=", "Raw device write"),
        ("> /dev/sda", "Raw device write"),
        ("chmod 777 /", "World-writable root"),
    ];

    BLOCKLIST
        .iter()
        .find(|entry| cmd.contains(entry.0))
        .map(|&(pattern, reason)| format!("Blocked: {} — {}", pattern, reason))
}
|
||||
427
libs/agent/orao/mod.rs
Normal file
427
libs/agent/orao/mod.rs
Normal file
@ -0,0 +1,427 @@
|
||||
//! ORAO (Observe–Reason–Act–Observe) — a single-agent loop for complex engineering tasks.
|
||||
//!
|
||||
//! ORAO extends the ReAct paradigm with:
|
||||
//! - **Multi-channel perception**: LLM-driven observation via read-only tools
|
||||
//! - **Structured reasoning**: analysis + step-by-step action plan
|
||||
//! - **Safety levels**: L0–L4 permission grading for every action
|
||||
//! - **Deadlock detection**: terminates after 3 rounds with no progress
|
||||
//! - **Plan mode**: optional user-approval gate before execution
|
||||
//! - **Round recording**: full audit trail for debugging and resumption
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! The [`OraoExecutor`] runs the O→R→A→O loop:
|
||||
//! 1. **Observe** — LLM explores environment via observation tools, produces snapshot
|
||||
//! 2. **Reason** — LLM analyzes snapshot, generates structured plan
|
||||
//! 3. **Act** — Execute each planned action via [`ActionExecutor`] with safety checks
|
||||
//! 4. **Observe** — Collect results, feed into next round
|
||||
//!
|
||||
//! All file access goes through function calls (tools), never direct filesystem operations.
|
||||
//!
|
||||
//! [`ActionExecutor`]: act::ActionExecutor
|
||||
|
||||
pub mod act;
|
||||
pub mod observe;
|
||||
pub mod reason;
|
||||
pub mod types;
|
||||
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::client::AiClientConfig;
|
||||
use crate::error::{AgentError, Result};
|
||||
|
||||
pub use act::ActionExecutor;
|
||||
pub use types::{
|
||||
ActionResult, ActionType, ActionVerdict, FileChange, FileChangeType, OraoConfig, OraoStep,
|
||||
PerceptionSnapshot, PlannedAction, ReasoningOutput, RoundRecord, SafetyLevel,
|
||||
};
|
||||
|
||||
// ── ORAO Executor ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Executes the ORAO loop for a single task.
|
||||
///
|
||||
/// All environment interaction goes through:
|
||||
/// - **Observation tools** (read-only) for the Observe phase
|
||||
/// - **Action executor** callback for the Act phase
|
||||
///
|
||||
/// No direct filesystem access — everything is mediated through function calls.
|
||||
pub struct OraoExecutor {
|
||||
config: AiClientConfig,
|
||||
model_name: String,
|
||||
action_executor: ActionExecutor,
|
||||
}
|
||||
|
||||
impl OraoExecutor {
|
||||
/// Create a new ORAO executor.
|
||||
///
|
||||
/// `action_executor` is called to execute each planned action. Wire it to
|
||||
/// your [`ToolRegistry`] for tool-based execution, or use
|
||||
/// [`act::shell_executor`] for simple shell-command execution.
|
||||
///
|
||||
/// [`ToolRegistry`]: crate::tool::ToolRegistry
|
||||
pub fn new(
|
||||
config: AiClientConfig,
|
||||
model_name: impl Into<String>,
|
||||
action_executor: ActionExecutor,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
model_name: model_name.into(),
|
||||
action_executor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the ORAO loop to completion.
|
||||
///
|
||||
/// # Parameters
|
||||
/// - `task_goal`: Description of what to accomplish.
|
||||
/// - `orao_config`: ORAO-specific settings (max rounds, safety level, etc.).
|
||||
/// - `tool_factory`: Called each round to produce read-only observation tools
|
||||
/// (e.g. `git_diff`, `git_blob`, `repo_search`, `git_grep`). This allows
|
||||
/// callers to provide fresh tool instances each round.
|
||||
/// - `on_step`: Called with each [`OraoStep`] event for streaming/persistence.
|
||||
/// - `on_plan_approval`: Called in plan mode; return `true` to proceed.
|
||||
pub async fn execute<C, Fut, PA, PAFut, TF>(
|
||||
&self,
|
||||
task_goal: &str,
|
||||
orao_config: &OraoConfig,
|
||||
tool_factory: TF,
|
||||
on_step: C,
|
||||
on_plan_approval: PA,
|
||||
) -> Result<OraoOutcome>
|
||||
where
|
||||
C: Fn(OraoStep) -> Fut + Send,
|
||||
Fut: Future<Output = ()> + Send,
|
||||
PA: Fn(ReasoningOutput) -> PAFut + Send,
|
||||
PAFut: Future<Output = bool> + Send,
|
||||
TF: Fn() -> Vec<Box<dyn rig::tool::ToolDyn + 'static>> + Send + Sync,
|
||||
{
|
||||
let mut round = 0usize;
|
||||
let mut round_records: Vec<RoundRecord> = Vec::new();
|
||||
let mut previous_result: Option<ActionResult> = None;
|
||||
let mut previous_snapshot: Option<PerceptionSnapshot> = None;
|
||||
let mut no_change_count: usize = 0;
|
||||
|
||||
// Observation turns: limit tool calls during exploration
|
||||
let observe_max_turns = 10;
|
||||
|
||||
loop {
|
||||
round += 1;
|
||||
let round_start = Instant::now();
|
||||
let round_input_tokens: u64 = 0;
|
||||
let round_output_tokens: u64 = 0;
|
||||
|
||||
// ── Phase 1: Observe ───────────────────────────────────────
|
||||
let snapshot = observe::observe(
|
||||
&self.config,
|
||||
&self.model_name,
|
||||
task_goal,
|
||||
previous_result.take(),
|
||||
tool_factory(),
|
||||
observe_max_turns,
|
||||
)
|
||||
.await?;
|
||||
|
||||
on_step(OraoStep::Observe {
|
||||
round,
|
||||
snapshot: snapshot.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
// ── Deadlock detection ─────────────────────────────────────
|
||||
if let Some(ref prev) = previous_snapshot {
|
||||
if !observe::has_environment_changed(prev, &snapshot) {
|
||||
no_change_count += 1;
|
||||
if no_change_count >= orao_config.deadlock_threshold {
|
||||
let reason = format!(
|
||||
"Deadlock detected: no environmental change for {} consecutive rounds",
|
||||
no_change_count
|
||||
);
|
||||
on_step(OraoStep::Failed {
|
||||
total_rounds: round,
|
||||
reason: reason.clone(),
|
||||
})
|
||||
.await;
|
||||
return Ok(OraoOutcome::Failed {
|
||||
reason,
|
||||
rounds: round,
|
||||
records: round_records,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
no_change_count = 0;
|
||||
}
|
||||
}
|
||||
previous_snapshot = Some(snapshot.clone());
|
||||
|
||||
// ── Phase 2: Reason ────────────────────────────────────────
|
||||
let reasoning = reason::reason(
|
||||
&self.config,
|
||||
&self.model_name,
|
||||
orao_config,
|
||||
task_goal,
|
||||
&snapshot,
|
||||
round,
|
||||
&round_records,
|
||||
)
|
||||
.await?;
|
||||
|
||||
on_step(OraoStep::Reason {
|
||||
round,
|
||||
reasoning: reasoning.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
// ── Plan mode gate ─────────────────────────────────────────
|
||||
if orao_config.plan_mode {
|
||||
on_step(OraoStep::PlanProposed {
|
||||
round,
|
||||
reasoning: reasoning.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
if !on_plan_approval(reasoning.clone()).await {
|
||||
return Ok(OraoOutcome::Cancelled {
|
||||
rounds: round,
|
||||
records: round_records,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// ── Phase 3: Act ───────────────────────────────────────────
|
||||
let mut round_result: Option<ActionResult> = None;
|
||||
let mut all_success = true;
|
||||
|
||||
for planned in &reasoning.plan {
|
||||
let safety = SafetyLevel::classify_command(&planned.command_or_content);
|
||||
|
||||
on_step(OraoStep::Act {
|
||||
round,
|
||||
action: planned.clone(),
|
||||
safety_level: safety,
|
||||
})
|
||||
.await;
|
||||
|
||||
let result =
|
||||
act::execute_action(planned.clone(), orao_config, &self.action_executor).await;
|
||||
|
||||
on_step(OraoStep::ObserveResult {
|
||||
round,
|
||||
result: result.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
match &result.verdict {
|
||||
ActionVerdict::Failure => {
|
||||
all_success = false;
|
||||
round_result = Some(result);
|
||||
break; // Stop executing further steps on failure
|
||||
}
|
||||
ActionVerdict::SuccessWithWarnings => {
|
||||
round_result = Some(result);
|
||||
}
|
||||
ActionVerdict::Success => {
|
||||
round_result = Some(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Phase 4: Record round ──────────────────────────────────
|
||||
let duration_ms = round_start.elapsed().as_millis() as u64;
|
||||
let record = RoundRecord {
|
||||
round,
|
||||
observe_summary: summarize_snapshot(&snapshot),
|
||||
reasoning_summary: reasoning.analysis.clone(),
|
||||
action: reasoning.plan.first().cloned(),
|
||||
result_summary: round_result
|
||||
.as_ref()
|
||||
.map(|r| format!("{:?}: {}", r.verdict, truncate(&r.stdout, 200))),
|
||||
tokens_input: round_input_tokens,
|
||||
tokens_output: round_output_tokens,
|
||||
duration_ms,
|
||||
};
|
||||
round_records.push(record);
|
||||
|
||||
// ── Check termination ──────────────────────────────────────
|
||||
if all_success && !reasoning.plan.is_empty() {
|
||||
let summary = format!(
|
||||
"Task completed in {} round(s). Last action: {}",
|
||||
round,
|
||||
round_result
|
||||
.as_ref()
|
||||
.map(|r| truncate(&r.stdout, 500))
|
||||
.unwrap_or_default()
|
||||
);
|
||||
on_step(OraoStep::Completed {
|
||||
total_rounds: round,
|
||||
summary: summary.clone(),
|
||||
})
|
||||
.await;
|
||||
return Ok(OraoOutcome::Completed {
|
||||
summary,
|
||||
rounds: round,
|
||||
records: round_records,
|
||||
});
|
||||
}
|
||||
|
||||
// Max rounds exceeded
|
||||
if round >= orao_config.max_rounds {
|
||||
let reason = format!("Reached max rounds ({})", orao_config.max_rounds);
|
||||
on_step(OraoStep::Failed {
|
||||
total_rounds: round,
|
||||
reason: reason.clone(),
|
||||
})
|
||||
.await;
|
||||
return Ok(OraoOutcome::Failed {
|
||||
reason,
|
||||
rounds: round,
|
||||
records: round_records,
|
||||
});
|
||||
}
|
||||
|
||||
// Prepare for next round
|
||||
previous_result = round_result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Outcome ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Final outcome of an ORAO execution.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum OraoOutcome {
|
||||
/// Task completed successfully.
|
||||
Completed {
|
||||
summary: String,
|
||||
rounds: usize,
|
||||
records: Vec<RoundRecord>,
|
||||
},
|
||||
/// Task failed (max rounds, deadlock, or unrecoverable error).
|
||||
Failed {
|
||||
reason: String,
|
||||
rounds: usize,
|
||||
records: Vec<RoundRecord>,
|
||||
},
|
||||
/// User cancelled the task (plan mode rejection or explicit interrupt).
|
||||
Cancelled {
|
||||
rounds: usize,
|
||||
records: Vec<RoundRecord>,
|
||||
},
|
||||
}
|
||||
|
||||
impl OraoOutcome {
|
||||
/// Number of rounds executed.
|
||||
pub fn rounds(&self) -> usize {
|
||||
match self {
|
||||
Self::Completed { rounds, .. }
|
||||
| Self::Failed { rounds, .. }
|
||||
| Self::Cancelled { rounds, .. } => *rounds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether the task was successful.
|
||||
pub fn is_success(&self) -> bool {
|
||||
matches!(self, Self::Completed { .. })
|
||||
}
|
||||
|
||||
/// Round records for audit/debugging.
|
||||
pub fn records(&self) -> &[RoundRecord] {
|
||||
match self {
|
||||
Self::Completed { records, .. }
|
||||
| Self::Failed { records, .. }
|
||||
| Self::Cancelled { records, .. } => records,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
fn summarize_snapshot(snapshot: &PerceptionSnapshot) -> String {
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
|
||||
if let Some(ref gs) = snapshot.git_status {
|
||||
let first_line = gs.lines().next().unwrap_or("");
|
||||
parts.push(format!("git: {}", truncate(first_line, 80)));
|
||||
}
|
||||
|
||||
if !snapshot.files.is_empty() {
|
||||
parts.push(format!("{} files", snapshot.files.len()));
|
||||
}
|
||||
|
||||
if !snapshot.errors.is_empty() {
|
||||
parts.push(format!("{} errors", snapshot.errors.len()));
|
||||
}
|
||||
|
||||
if parts.is_empty() {
|
||||
"no changes".to_string()
|
||||
} else {
|
||||
parts.join(", ")
|
||||
}
|
||||
}
|
||||
|
||||
/// Shorten `s` to at most `max_len` bytes, appending `...` when text was cut.
///
/// Strings that already fit are returned unchanged. The cut position is moved
/// back to the nearest UTF-8 character boundary so that slicing can never
/// panic on multi-byte characters (command output is arbitrary text).
fn truncate(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    // Index 0 is always a boundary, so this loop terminates.
    let mut cut = max_len;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &s[..cut])
}
|
||||
|
||||
// ── Convenience builder ─────────────────────────────────────────────────────
|
||||
|
||||
/// Builder for [`OraoExecutor`] with chainable configuration.
|
||||
pub struct OraoExecutorBuilder {
|
||||
config: Option<AiClientConfig>,
|
||||
model_name: Option<String>,
|
||||
action_executor: Option<ActionExecutor>,
|
||||
}
|
||||
|
||||
impl OraoExecutorBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: None,
|
||||
model_name: None,
|
||||
action_executor: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ai_config(mut self, config: AiClientConfig) -> Self {
|
||||
self.config = Some(config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn model(mut self, name: impl Into<String>) -> Self {
|
||||
self.model_name = Some(name.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn action_executor(mut self, executor: ActionExecutor) -> Self {
|
||||
self.action_executor = Some(executor);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<OraoExecutor> {
|
||||
let config = self.config.ok_or_else(|| AgentError::InvalidInput {
|
||||
field: "config".to_string(),
|
||||
reason: "AI client config is required".to_string(),
|
||||
})?;
|
||||
let model_name = self.model_name.ok_or_else(|| AgentError::InvalidInput {
|
||||
field: "model_name".to_string(),
|
||||
reason: "Model name is required".to_string(),
|
||||
})?;
|
||||
let action_executor = self
|
||||
.action_executor
|
||||
.ok_or_else(|| AgentError::InvalidInput {
|
||||
field: "action_executor".to_string(),
|
||||
reason: "Action executor is required".to_string(),
|
||||
})?;
|
||||
|
||||
Ok(OraoExecutor::new(config, model_name, action_executor))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OraoExecutorBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
280
libs/agent/orao/observe.rs
Normal file
280
libs/agent/orao/observe.rs
Normal file
@ -0,0 +1,280 @@
|
||||
//! Observe phase: LLM-driven multi-channel environment perception.
|
||||
//!
|
||||
//! The Observe phase gives the LLM a set of read-only observation tools and
|
||||
//! instructs it to explore the environment. All file/git/system access goes
|
||||
//! through function calls (tools), never direct filesystem operations.
|
||||
//!
|
||||
//! After exploration, the LLM produces a structured [`PerceptionSnapshot`]
|
||||
//! summarizing the current state of the project.
|
||||
|
||||
use rig::agent::AgentBuilder;
|
||||
use rig::client::CompletionClient;
|
||||
use rig::completion::Prompt;
|
||||
|
||||
use crate::client::AiClientConfig;
|
||||
use crate::error::AgentError;
|
||||
|
||||
use super::types::{ActionResult, PerceptionSnapshot};
|
||||
|
||||
/// Prompt for the ORAO Observe phase.
|
||||
const OBSERVE_SYSTEM_PROMPT: &str = r#"You are an expert software engineering agent using the ORAO (Observe-Reason-Act-Observe) framework.
|
||||
|
||||
## Your Role: OBSERVE Phase
|
||||
|
||||
You are currently in the OBSERVE phase. Your task is to explore the project environment
|
||||
and gather all relevant information using the available tools.
|
||||
|
||||
## What to Observe
|
||||
|
||||
Use the tools provided to you to check:
|
||||
|
||||
1. **Git status**: What branch are we on? What files have changed? Any uncommitted work?
|
||||
2. **Project structure**: What directories and key files exist?
|
||||
3. **Code content**: Read relevant source files to understand the codebase state.
|
||||
4. **Errors/warnings**: Check build output, test results, linter output for issues.
|
||||
5. **Configuration**: Check project config files (Cargo.toml, package.json, etc.) if relevant.
|
||||
|
||||
## Rules
|
||||
|
||||
- Use tools to explore — do NOT guess or assume file contents.
|
||||
- Focus on information relevant to the task at hand.
|
||||
- Be thorough but efficient: 3-8 tool calls is typical.
|
||||
- After gathering information, summarize your findings clearly.
|
||||
|
||||
## Output Format
|
||||
|
||||
After you have finished observing, provide a summary with these sections:
|
||||
|
||||
### Git Status
|
||||
[Current branch, changed files, commit status]
|
||||
|
||||
### Project Structure
|
||||
[Key directories and files relevant to the task]
|
||||
|
||||
### Key Files
|
||||
[Important files you read, with brief notes on their content]
|
||||
|
||||
### Errors / Issues
|
||||
[Any errors, warnings, or problems detected]
|
||||
|
||||
### Previous Action Result
|
||||
[If a previous action was executed, describe its outcome]"#;
|
||||
|
||||
/// Run the Observe phase: let the LLM explore the environment via tools.
|
||||
///
|
||||
/// Returns a structured [`PerceptionSnapshot`] built from the LLM's observations.
|
||||
/// All environment access goes through the provided `tools` — no direct
|
||||
/// filesystem operations.
|
||||
///
|
||||
/// Takes ownership of `tools` (caller must clone if they need to reuse them).
|
||||
pub async fn observe(
|
||||
config: &AiClientConfig,
|
||||
model_name: &str,
|
||||
task_goal: &str,
|
||||
previous_result: Option<ActionResult>,
|
||||
tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>>,
|
||||
max_turns: usize,
|
||||
) -> Result<PerceptionSnapshot, AgentError> {
|
||||
let user_prompt = build_observe_prompt(task_goal, previous_result.as_ref());
|
||||
|
||||
let client = config.build_rig_client();
|
||||
let model = client.completion_model(model_name);
|
||||
|
||||
let agent = AgentBuilder::new(model)
|
||||
.preamble(OBSERVE_SYSTEM_PROMPT)
|
||||
.tools(tools)
|
||||
.default_max_turns(max_turns)
|
||||
.build();
|
||||
|
||||
let response = agent
|
||||
.prompt(&user_prompt)
|
||||
.max_turns(max_turns)
|
||||
.extended_details()
|
||||
.await
|
||||
.map_err(|e: rig::completion::PromptError| AgentError::OpenAi(e.to_string()))?;
|
||||
|
||||
// Build snapshot from the LLM's final summary
|
||||
let summary = response.output;
|
||||
let snapshot = parse_observation_summary(&summary, previous_result);
|
||||
|
||||
Ok(snapshot)
|
||||
}
|
||||
|
||||
/// Build the user prompt for the Observe phase.
|
||||
fn build_observe_prompt(task_goal: &str, previous_result: Option<&ActionResult>) -> String {
|
||||
let mut prompt = format!(
|
||||
"## Task Goal\n\n{}\n\n## Instructions\n\n\
|
||||
Explore the project environment using the available tools. \
|
||||
Gather all information relevant to the task above. \
|
||||
After you have gathered sufficient information, provide a structured summary.",
|
||||
task_goal
|
||||
);
|
||||
|
||||
if let Some(prev) = previous_result {
|
||||
prompt.push_str(&format!(
|
||||
"\n\n## Previous Action Result\n\n\
|
||||
- Action: {}\n\
|
||||
- Verdict: {:?}\n\
|
||||
- Exit code: {:?}\n\
|
||||
- stdout: {}\n\
|
||||
- stderr: {}",
|
||||
prev.action.description,
|
||||
prev.verdict,
|
||||
prev.exit_code,
|
||||
truncate_str(&prev.stdout, 2000),
|
||||
truncate_str(&prev.stderr, 2000),
|
||||
));
|
||||
}
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Parse the LLM's observation summary into a structured snapshot.
|
||||
fn parse_observation_summary(
|
||||
summary: &str,
|
||||
previous_result: Option<ActionResult>,
|
||||
) -> PerceptionSnapshot {
|
||||
let mut snapshot = PerceptionSnapshot::default();
|
||||
|
||||
// Extract sections from the markdown summary
|
||||
let mut current_section = "";
|
||||
let mut section_content: Vec<&str> = Vec::new();
|
||||
|
||||
for line in summary.lines() {
|
||||
if line.starts_with("### ") {
|
||||
// Save previous section
|
||||
store_section(&mut snapshot, current_section, §ion_content);
|
||||
current_section = line.trim_start_matches("### ").trim();
|
||||
section_content.clear();
|
||||
} else {
|
||||
section_content.push(line);
|
||||
}
|
||||
}
|
||||
// Save last section
|
||||
store_section(&mut snapshot, current_section, §ion_content);
|
||||
|
||||
snapshot.previous_action_result = previous_result;
|
||||
|
||||
// If no structured data was parsed, store the raw summary
|
||||
if snapshot.git_status.is_none()
|
||||
&& snapshot.project_structure.is_none()
|
||||
&& snapshot.files.is_empty()
|
||||
&& snapshot.errors.is_empty()
|
||||
{
|
||||
snapshot.notes.insert(
|
||||
"raw_observation".to_string(),
|
||||
summary.to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
snapshot
|
||||
}
|
||||
|
||||
fn store_section(snapshot: &mut PerceptionSnapshot, section: &str, content: &[&str]) {
|
||||
let text = content.join("\n").trim().to_string();
|
||||
if text.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
match section.to_lowercase().as_str() {
|
||||
s if s.contains("git") => {
|
||||
snapshot.git_status = Some(text);
|
||||
}
|
||||
s if s.contains("project") && s.contains("structure") => {
|
||||
snapshot.project_structure = Some(text);
|
||||
}
|
||||
s if s.contains("file") => {
|
||||
// Parse file references from the text
|
||||
for line in content {
|
||||
let line = line.trim();
|
||||
if let Some(path) = extract_file_path(line) {
|
||||
snapshot.files.push(super::types::PerceivedFile {
|
||||
path,
|
||||
size_bytes: 0,
|
||||
content_preview: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
s if s.contains("error") || s.contains("issue") || s.contains("warning") => {
|
||||
for line in content {
|
||||
let line = line.trim();
|
||||
if !line.is_empty() && !line.starts_with('#') {
|
||||
snapshot.errors.push(line.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Store unknown sections as notes
|
||||
snapshot
|
||||
.notes
|
||||
.insert(section.to_string(), text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract a file path from a markdown list item or inline code reference.
///
/// Recognised forms:
/// * the first backtick-quoted span on the line, if it contains `.`, `/` or
///   `\` — e.g. ``- `src/main.rs` entry point`` → `src/main.rs`;
/// * a `-` / `*` bullet whose remaining text contains `/` or `.` — the whole
///   remaining text is returned.
///
/// Bullets that hold an `http://` or `https://` URL are not treated as paths.
/// Returns `None` when the line matches neither form.
fn extract_file_path(line: &str) -> Option<String> {
    let line = line.trim();

    // Backtick-wrapped path: text between the first pair of backticks.
    let quoted = line
        .split_once('`')
        .and_then(|(_, after)| after.split_once('`'))
        .map(|(inner, _)| inner);
    if let Some(inner) = quoted {
        if inner.contains(|c: char| matches!(c, '.' | '/' | '\\')) {
            return Some(inner.to_string());
        }
    }

    // Bare path in a bullet item.
    if line.starts_with('-') || line.starts_with('*') {
        let rest = line.trim_start_matches(|c: char| matches!(c, '-' | '*' | ' '));
        // URLs always contain '/', so they must be ruled out before the
        // path heuristics are applied.
        let is_url = rest.starts_with("http://") || rest.starts_with("https://");
        if !is_url && (rest.contains('/') || rest.contains('.')) {
            return Some(rest.to_string());
        }
    }

    None
}
|
||||
|
||||
/// Shorten `s` to at most `max_len` bytes, appending `...` when text was cut.
///
/// Strings that already fit are returned unchanged. The cut position is moved
/// back to the nearest UTF-8 character boundary, because stdout/stderr may
/// contain multi-byte characters and slicing inside one would panic.
fn truncate_str(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    // Index 0 is always a boundary, so this loop terminates.
    let mut cut = max_len;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &s[..cut])
}
|
||||
|
||||
/// Determine whether the environment has changed since the last snapshot.
|
||||
///
|
||||
/// Used for deadlock detection: if 3 consecutive rounds show no change,
|
||||
/// the loop is terminated.
|
||||
pub fn has_environment_changed(
|
||||
previous: &PerceptionSnapshot,
|
||||
current: &PerceptionSnapshot,
|
||||
) -> bool {
|
||||
if previous.git_status != current.git_status {
|
||||
return true;
|
||||
}
|
||||
|
||||
let prev_files: Vec<&str> = previous.files.iter().map(|f| f.path.as_str()).collect();
|
||||
let curr_files: Vec<&str> = current.files.iter().map(|f| f.path.as_str()).collect();
|
||||
if prev_files != curr_files {
|
||||
return true;
|
||||
}
|
||||
|
||||
if previous.errors != current.errors {
|
||||
return true;
|
||||
}
|
||||
|
||||
let prev_has_result = previous.previous_action_result.is_some();
|
||||
let curr_has_result = current.previous_action_result.is_some();
|
||||
if prev_has_result != curr_has_result {
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
218
libs/agent/orao/reason.rs
Normal file
218
libs/agent/orao/reason.rs
Normal file
@ -0,0 +1,218 @@
|
||||
//! Reason phase: structured reasoning and plan generation via LLM.
|
||||
//!
|
||||
//! Takes the current [`PerceptionSnapshot`] and the task goal, sends them to
|
||||
//! the configured LLM, and produces a structured [`ReasoningOutput`] with an
|
||||
//! executable plan.
|
||||
|
||||
use crate::client::AiClientConfig;
|
||||
use crate::error::AgentError;
|
||||
use rig::agent::AgentBuilder;
|
||||
use rig::client::CompletionClient;
|
||||
use rig::completion::Prompt;
|
||||
|
||||
use super::types::{OraoConfig, PerceptionSnapshot, ReasoningOutput};
|
||||
|
||||
/// Prompt template for the ORAO reasoning phase.
|
||||
const REASON_SYSTEM_PROMPT: &str = r#"You are an expert software engineering agent using the ORAO (Observe-Reason-Act-Observe) framework.
|
||||
|
||||
## Your Role: REASON Phase
|
||||
|
||||
You are currently in the REASON phase. You will receive:
|
||||
1. A **task goal** — what needs to be accomplished
|
||||
2. An **observation snapshot** — the current state of the environment
|
||||
|
||||
Your job is to produce a structured analysis and a step-by-step action plan.
|
||||
|
||||
## Output Format
|
||||
|
||||
You MUST respond with a valid JSON object matching this schema:
|
||||
|
||||
```json
|
||||
{
|
||||
"analysis": "<concise analysis of the current state, problems identified, constraints>",
|
||||
"plan": [
|
||||
{
|
||||
"step_id": 1,
|
||||
"description": "<what this step does>",
|
||||
"action_type": "shell_command | file_write | file_edit | git_operation | tool_invoke | user_dialog",
|
||||
"command_or_content": "<the exact command to run or content to write>",
|
||||
"expected_result": "<what success looks like>",
|
||||
"fallback_on_failure": "<what to do if this fails, or null>"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Rules
|
||||
|
||||
1. **Be specific**: commands must be exact and complete (include necessary flags, paths, etc.)
|
||||
2. **One action per step**: each step should do one thing
|
||||
3. **Validate assumptions**: don't assume dependencies are installed or files exist
|
||||
4. **Prefer small, safe steps**: each change should be minimal and reversible
|
||||
5. **Include verification**: build/test after changes to verify correctness
|
||||
6. **Plan size**: 1-10 steps is typical. Don't plan more than 15 steps without asking the user.
|
||||
7. **Fallbacks**: for risky steps, specify what to try if the step fails
|
||||
|
||||
## Safety
|
||||
|
||||
- Prefer read-only verification before making changes
|
||||
- Use version control (git) for reversible operations
|
||||
- Flag any step that requires network access or system-level privileges"#;
|
||||
|
||||
/// Run the reasoning phase: send observation + task to the LLM and get a plan.
|
||||
pub async fn reason(
|
||||
config: &AiClientConfig,
|
||||
model_name: &str,
|
||||
_orao_config: &OraoConfig,
|
||||
task_goal: &str,
|
||||
snapshot: &PerceptionSnapshot,
|
||||
round: usize,
|
||||
history_rounds: &[super::types::RoundRecord],
|
||||
) -> Result<ReasoningOutput, AgentError> {
|
||||
let user_prompt = build_reason_prompt(task_goal, snapshot, round, history_rounds);
|
||||
|
||||
let client = config.build_rig_client();
|
||||
let model = client.completion_model(model_name);
|
||||
|
||||
let agent = AgentBuilder::new(model)
|
||||
.preamble(REASON_SYSTEM_PROMPT)
|
||||
.build();
|
||||
|
||||
let response = agent
|
||||
.prompt(&user_prompt)
|
||||
.extended_details()
|
||||
.await
|
||||
.map_err(|e: rig::completion::PromptError| AgentError::OpenAi(e.to_string()))?;
|
||||
|
||||
let output = response.output.trim().to_string();
|
||||
|
||||
// Extract JSON from the response (it may be wrapped in ```json fences)
|
||||
let json_str = extract_json(&output);
|
||||
|
||||
serde_json::from_str::<ReasoningOutput>(json_str).map_err(|e| {
|
||||
AgentError::Internal(format!(
|
||||
"Failed to parse reasoning output as JSON: {}. Raw output: {}",
|
||||
e, output
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Build the user prompt for the reason phase.
|
||||
fn build_reason_prompt(
|
||||
task_goal: &str,
|
||||
snapshot: &PerceptionSnapshot,
|
||||
round: usize,
|
||||
history_rounds: &[super::types::RoundRecord],
|
||||
) -> String {
|
||||
let mut prompt = format!(
|
||||
"## Task Goal\n\n{}\n\n## Round Number\n\n{}\n\n",
|
||||
task_goal, round
|
||||
);
|
||||
|
||||
// ── Observation snapshot ────────────────────────────────────────────
|
||||
prompt.push_str("## Current Observation\n\n");
|
||||
|
||||
if let Some(ref git_status) = snapshot.git_status {
|
||||
prompt.push_str(&format!("### Git Status\n```\n{}\n```\n\n", git_status));
|
||||
}
|
||||
|
||||
if let Some(ref project_structure) = snapshot.project_structure {
|
||||
prompt.push_str(&format!(
|
||||
"### Project Structure\n```\n{}\n```\n\n",
|
||||
project_structure
|
||||
));
|
||||
}
|
||||
|
||||
if !snapshot.files.is_empty() {
|
||||
prompt.push_str("### Relevant Files\n\n");
|
||||
for file in &snapshot.files {
|
||||
prompt.push_str(&format!(
|
||||
"- `{}` ({} bytes)\n",
|
||||
file.path, file.size_bytes
|
||||
));
|
||||
if let Some(ref preview) = file.content_preview {
|
||||
let truncated = if preview.len() > 2000 {
|
||||
format!("{}... [truncated]", &preview[..2000])
|
||||
} else {
|
||||
preview.clone()
|
||||
};
|
||||
prompt.push_str(&format!("```\n{}\n```\n\n", truncated));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !snapshot.errors.is_empty() {
|
||||
prompt.push_str("### Errors / Warnings\n\n");
|
||||
for err in &snapshot.errors {
|
||||
prompt.push_str(&format!("- {}\n", err));
|
||||
}
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
if let Some(ref prev_result) = snapshot.previous_action_result {
|
||||
prompt.push_str(&format!(
|
||||
"### Previous Action Result\n- Verdict: {:?}\n- Exit code: {:?}\n- stdout: {}\n- stderr: {}\n\n",
|
||||
prev_result.verdict,
|
||||
prev_result.exit_code,
|
||||
truncate_str(&prev_result.stdout, 1000),
|
||||
truncate_str(&prev_result.stderr, 1000),
|
||||
));
|
||||
}
|
||||
|
||||
// ── History summary ─────────────────────────────────────────────────
|
||||
if !history_rounds.is_empty() {
|
||||
prompt.push_str("## Previous Rounds\n\n");
|
||||
for record in history_rounds.iter().rev().take(3) {
|
||||
prompt.push_str(&format!(
|
||||
"- Round {}: {} | {} tokens | {}ms\n",
|
||||
record.round,
|
||||
truncate_str(&record.reasoning_summary, 120),
|
||||
record.tokens_input + record.tokens_output,
|
||||
record.duration_ms,
|
||||
));
|
||||
}
|
||||
prompt.push('\n');
|
||||
}
|
||||
|
||||
prompt.push_str(
|
||||
"## Instructions\n\nProduce a structured analysis and action plan in JSON format as specified.",
|
||||
);
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
/// Extract JSON content from a string that may be wrapped in markdown fences.
///
/// Resolution order, applied to the trimmed input:
///
/// 1. If it starts with a ```` ```json ```` or bare ```` ``` ```` fence, return the
///    trimmed text between that opening fence and the last ```` ``` ```` (or the
///    rest of the input when no closing fence exists).
/// 2. Otherwise, if a `{` appears before the last `}`, return the slice from the
///    first `{` through the last `}`.
/// 3. Otherwise return the trimmed input unchanged.
///
/// The function never panics. Braces that appear in the wrong order (a `}` before
/// the first `{`) fall through to case 3.
fn extract_json(s: &str) -> &str {
    let s = s.trim();

    // "```json" must be tried first: it also starts with the bare "```" fence.
    if let Some(inner) = s.strip_prefix("```json").or_else(|| s.strip_prefix("```")) {
        return match inner.rfind("```") {
            Some(end) => inner[..end].trim(),
            None => inner.trim(),
        };
    }

    // Heuristic: take everything from the first `{` to the last `}`.
    if let (Some(start), Some(end)) = (s.find('{'), s.rfind('}')) {
        // Guard against reversed braces, which would make the slice range invalid.
        if start < end {
            return &s[start..=end];
        }
    }

    s
}
|
||||
|
||||
/// Return `s` limited to at most `max_len` bytes, appending `...` when it was cut.
///
/// Strings that already fit are returned unchanged. When truncation is needed the
/// cut point is moved back to the nearest UTF-8 character boundary at or below
/// `max_len`, so the function never panics on multi-byte text.
fn truncate_str(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    // `max_len < s.len()` here; index 0 is always a boundary, so the loop terminates.
    let mut cut = max_len;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &s[..cut])
}
|
||||
337
libs/agent/orao/types.rs
Normal file
337
libs/agent/orao/types.rs
Normal file
@ -0,0 +1,337 @@
|
||||
//! ORAO core types.
|
||||
//!
|
||||
//! ORAO (Observe–Reason–Act–Observe) is a single-agent loop paradigm for complex
|
||||
//! engineering tasks. It extends ReAct with structured multi-channel perception,
|
||||
//! safety permission levels, plan mode, and deadlock detection.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ── Safety levels ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Permission level for actions executed by ORAO.
|
||||
///
|
||||
/// L0 (read-only) → auto-allow.
|
||||
/// L1 (local write) → confirm on first use.
|
||||
/// L2 (build) → confirm on first use.
|
||||
/// L3 (network) → explicit user approval required.
|
||||
/// L4 (system) → denied by default.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub enum SafetyLevel {
|
||||
/// L0 — Read-only: `ls`, `cat`, `grep`, `git status`
|
||||
ReadOnly = 0,
|
||||
/// L1 — Local write: edit source files, create new files
|
||||
LocalWrite = 1,
|
||||
/// L2 — Build/test: `cargo build`, `npm test`
|
||||
Build = 2,
|
||||
/// L3 — Network: `pip install`, `curl`
|
||||
Network = 3,
|
||||
/// L4 — System: `sudo`, global config changes
|
||||
System = 4,
|
||||
}
|
||||
|
||||
impl SafetyLevel {
|
||||
/// Classify a shell command into a safety level.
|
||||
pub fn classify_command(cmd: &str) -> Self {
|
||||
let cmd_trimmed = cmd.trim();
|
||||
// L0: read-only commands
|
||||
let l0_prefixes = [
|
||||
"ls", "cat", "head", "tail", "less", "file", "stat", "wc",
|
||||
"grep", "rg", "find", "which", "type", "echo", "printf",
|
||||
"pwd", "env", "printenv", "date", "uname", "hostname",
|
||||
"git status", "git log", "git diff", "git show", "git branch",
|
||||
"git tag", "git remote", "git config --get", "git blame",
|
||||
"cargo metadata", "cargo tree", "cargo read-manifest",
|
||||
"tree", "du", "df",
|
||||
];
|
||||
for p in &l0_prefixes {
|
||||
if cmd_trimmed.starts_with(p) {
|
||||
return Self::ReadOnly;
|
||||
}
|
||||
}
|
||||
|
||||
// L4: system-level commands (denied by default)
|
||||
let l4_patterns = [
|
||||
"sudo", "su ", "chown", "chmod 777", "mkfs", "mkswap",
|
||||
"mount", "umount", "fdisk", "parted", "dd if=",
|
||||
"systemctl", "service ", "chkconfig", "update-rc.d",
|
||||
"passwd", "useradd", "userdel", "usermod", "groupadd",
|
||||
"iptables", "ufw", "firewall-cmd",
|
||||
"shutdown", "reboot", "halt", "poweroff",
|
||||
"rm -rf /", "rm -rf ~", "rm -rf .", ":(){ :|:& };:",
|
||||
];
|
||||
for p in &l4_patterns {
|
||||
if cmd_trimmed.starts_with(p) || cmd_trimmed.contains(p) {
|
||||
return Self::System;
|
||||
}
|
||||
}
|
||||
|
||||
// L3: network commands
|
||||
let l3_prefixes = [
|
||||
"curl", "wget", "nc ", "ncat", "telnet", "ssh ", "scp",
|
||||
"rsync", "pip install", "pip3 install", "npm install",
|
||||
"npm i ", "yarn add", "cargo install", "gem install",
|
||||
"go get", "go install", "apt-get", "apt ", "yum ", "dnf ",
|
||||
"brew ", "pacman ", "zypper", "docker pull", "docker run",
|
||||
"git clone", "git fetch", "git push", "git pull",
|
||||
"gh ", "glab ", "aws ", "gcloud ", "az ",
|
||||
];
|
||||
for p in &l3_prefixes {
|
||||
if cmd_trimmed.starts_with(p) {
|
||||
return Self::Network;
|
||||
}
|
||||
}
|
||||
|
||||
// L2: build/test commands
|
||||
let l2_prefixes = [
|
||||
"cargo build", "cargo test", "cargo check", "cargo clippy",
|
||||
"cargo fmt", "cargo run", "cargo bench", "cargo doc",
|
||||
"npm test", "npm run", "npx ", "yarn test", "yarn run",
|
||||
"pnpm test", "pnpm run", "bun test", "bun run",
|
||||
"make", "cmake", "ninja", "meson", "bazel",
|
||||
"pytest", "python -m pytest", "python3 -m pytest",
|
||||
"go test", "go build", "go vet", "go fmt",
|
||||
"rustc", "rustfmt", "clippy", "miri",
|
||||
"eslint", "prettier", "tsc", "jest", "vitest",
|
||||
"docker build", "docker compose", "docker-compose",
|
||||
"kubectl apply", "kubectl delete", "helm ",
|
||||
];
|
||||
for p in &l2_prefixes {
|
||||
if cmd_trimmed.starts_with(p) {
|
||||
return Self::Build;
|
||||
}
|
||||
}
|
||||
|
||||
// Default to L1 (local write) for anything else
|
||||
Self::LocalWrite
|
||||
}
|
||||
}
|
||||
|
||||
// ── Action types ────────────────────────────────────────────────────────────
|
||||
|
||||
/// The type of action to execute.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ActionType {
|
||||
/// Execute a shell command in a controlled terminal.
|
||||
ShellCommand,
|
||||
/// Create or overwrite a file.
|
||||
FileWrite,
|
||||
/// Make a localized edit to an existing file.
|
||||
FileEdit,
|
||||
/// Version-control operation (commit, add, etc.).
|
||||
GitOperation,
|
||||
/// Invoke an external tool or API.
|
||||
ToolInvoke,
|
||||
/// Ask the user for input or a decision.
|
||||
UserDialog,
|
||||
}
|
||||
|
||||
// ── Action plan ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// A single planned action from the reasoning phase.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PlannedAction {
|
||||
/// Step number within the plan.
|
||||
pub step_id: usize,
|
||||
/// Human-readable description.
|
||||
pub description: String,
|
||||
/// The type of action.
|
||||
pub action_type: ActionType,
|
||||
/// The command or content to execute/write.
|
||||
pub command_or_content: String,
|
||||
/// What success should look like.
|
||||
pub expected_result: String,
|
||||
/// What to try if this step fails.
|
||||
pub fallback_on_failure: Option<String>,
|
||||
}
|
||||
|
||||
/// Structured reasoning output from the Reason phase.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReasoningOutput {
|
||||
/// Analysis of the current state.
|
||||
pub analysis: String,
|
||||
/// The plan to execute.
|
||||
pub plan: Vec<PlannedAction>,
|
||||
}
|
||||
|
||||
// ── Perception snapshot ─────────────────────────────────────────────────────
|
||||
|
||||
/// Structured observation collected during the Observe phase.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct PerceptionSnapshot {
|
||||
/// Project directory tree summary.
|
||||
pub project_structure: Option<String>,
|
||||
/// Relevant file paths and contents.
|
||||
pub files: Vec<PerceivedFile>,
|
||||
/// Current errors/warnings in the environment.
|
||||
pub errors: Vec<String>,
|
||||
/// Git status summary.
|
||||
pub git_status: Option<String>,
|
||||
/// Result of the previous action (if any).
|
||||
pub previous_action_result: Option<ActionResult>,
|
||||
/// Free-form context notes.
|
||||
pub notes: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// A file observed during perception.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PerceivedFile {
|
||||
pub path: String,
|
||||
pub size_bytes: u64,
|
||||
pub content_preview: Option<String>,
|
||||
}
|
||||
|
||||
// ── Action result ───────────────────────────────────────────────────────────
|
||||
|
||||
/// The result of executing an action.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ActionResult {
|
||||
/// The action that was executed.
|
||||
pub action: PlannedAction,
|
||||
/// Exit code (0 = success for shell commands).
|
||||
pub exit_code: Option<i32>,
|
||||
/// Captured stdout.
|
||||
pub stdout: String,
|
||||
/// Captured stderr.
|
||||
pub stderr: String,
|
||||
/// Summary of file changes (if applicable).
|
||||
pub file_changes: Vec<FileChange>,
|
||||
/// Preliminary assessment.
|
||||
pub verdict: ActionVerdict,
|
||||
}
|
||||
|
||||
/// A file change detected after an action.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FileChange {
|
||||
pub path: String,
|
||||
pub change_type: FileChangeType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum FileChangeType {
|
||||
Created,
|
||||
Modified,
|
||||
Deleted,
|
||||
}
|
||||
|
||||
/// Preliminary verdict on an action's outcome.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ActionVerdict {
|
||||
Success,
|
||||
SuccessWithWarnings,
|
||||
Failure,
|
||||
}
|
||||
|
||||
// ── ORAO step events ────────────────────────────────────────────────────────
|
||||
|
||||
/// A single event emitted during an ORAO round, analogous to `ReactStep`.
|
||||
///
|
||||
/// These are yielded via the streaming callback so the caller can persist
|
||||
/// them or forward them to a frontend.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum OraoStep {
|
||||
/// Initial observation: environment snapshot before any action.
|
||||
Observe {
|
||||
round: usize,
|
||||
snapshot: PerceptionSnapshot,
|
||||
},
|
||||
/// The reasoning/analysis output, including the plan.
|
||||
Reason {
|
||||
round: usize,
|
||||
reasoning: ReasoningOutput,
|
||||
},
|
||||
/// An action is about to be executed.
|
||||
Act {
|
||||
round: usize,
|
||||
action: PlannedAction,
|
||||
safety_level: SafetyLevel,
|
||||
},
|
||||
/// The result observed after executing an action.
|
||||
ObserveResult {
|
||||
round: usize,
|
||||
result: ActionResult,
|
||||
},
|
||||
/// Plan mode: a plan has been generated and is awaiting user approval.
|
||||
PlanProposed {
|
||||
round: usize,
|
||||
reasoning: ReasoningOutput,
|
||||
},
|
||||
/// The task completed successfully.
|
||||
Completed {
|
||||
total_rounds: usize,
|
||||
summary: String,
|
||||
},
|
||||
/// The task failed (max rounds, deadlock, or explicit failure).
|
||||
Failed {
|
||||
total_rounds: usize,
|
||||
reason: String,
|
||||
},
|
||||
}
|
||||
|
||||
// ── Round record (audit) ────────────────────────────────────────────────────
|
||||
|
||||
/// A persistent record of one ORAO round, used for audit and resumption.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RoundRecord {
|
||||
/// Round number (1-indexed).
|
||||
pub round: usize,
|
||||
/// Summary of the Observe phase.
|
||||
pub observe_summary: String,
|
||||
/// Summary of the Reasoning phase.
|
||||
pub reasoning_summary: String,
|
||||
/// The action that was executed.
|
||||
pub action: Option<PlannedAction>,
|
||||
/// Result observed after the action.
|
||||
pub result_summary: Option<String>,
|
||||
/// Tokens consumed this round.
|
||||
pub tokens_input: u64,
|
||||
pub tokens_output: u64,
|
||||
/// Wall-clock duration of this round in milliseconds.
|
||||
pub duration_ms: u64,
|
||||
}
|
||||
|
||||
// ── ORAO configuration ──────────────────────────────────────────────────────
|
||||
|
||||
/// Configuration for an ORAO execution.
|
||||
#[derive(Clone)]
|
||||
pub struct OraoConfig {
|
||||
/// Maximum number of ORAO rounds before giving up.
|
||||
pub max_rounds: usize,
|
||||
/// Maximum allowed safety level. Actions above this level are denied.
|
||||
pub max_safety_level: SafetyLevel,
|
||||
/// Whether to run in plan mode (generate plan first, wait for approval).
|
||||
pub plan_mode: bool,
|
||||
/// Whether to enable extended thinking for the reasoning phase.
|
||||
pub extended_thinking: bool,
|
||||
/// Per-action timeout in seconds.
|
||||
pub action_timeout_secs: u64,
|
||||
/// Number of consecutive no-change rounds before deadlock detection triggers.
|
||||
pub deadlock_threshold: usize,
|
||||
}
|
||||
|
||||
impl Default for OraoConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_rounds: 50,
|
||||
max_safety_level: SafetyLevel::Network,
|
||||
plan_mode: false,
|
||||
extended_thinking: false,
|
||||
action_timeout_secs: 120,
|
||||
deadlock_threshold: 3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for OraoConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("OraoConfig")
|
||||
.field("max_rounds", &self.max_rounds)
|
||||
.field("max_safety_level", &self.max_safety_level)
|
||||
.field("plan_mode", &self.plan_mode)
|
||||
.field("extended_thinking", &self.extended_thinking)
|
||||
.field("action_timeout_secs", &self.action_timeout_secs)
|
||||
.field("deadlock_threshold", &self.deadlock_threshold)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
171
libs/agent/task/events.rs
Normal file
171
libs/agent/task/events.rs
Normal file
@ -0,0 +1,171 @@
|
||||
use models::agent_task::TaskStatus;
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Event payload published to WebSocket clients via Redis Pub/Sub.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct TaskEvent {
|
||||
pub task_id: i64,
|
||||
pub project_id: uuid::Uuid,
|
||||
pub parent_id: Option<i64>,
|
||||
pub event: String,
|
||||
pub message: Option<String>,
|
||||
pub output: Option<String>,
|
||||
pub error: Option<String>,
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
impl TaskEvent {
|
||||
pub fn started(task_id: i64, project_id: uuid::Uuid, parent_id: Option<i64>) -> Self {
|
||||
Self {
|
||||
task_id,
|
||||
project_id,
|
||||
parent_id,
|
||||
event: "started".to_string(),
|
||||
message: None,
|
||||
output: None,
|
||||
error: None,
|
||||
status: TaskStatus::Running.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn progress(
|
||||
task_id: i64,
|
||||
project_id: uuid::Uuid,
|
||||
parent_id: Option<i64>,
|
||||
msg: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
task_id,
|
||||
project_id,
|
||||
parent_id,
|
||||
event: "progress".to_string(),
|
||||
message: Some(msg),
|
||||
output: None,
|
||||
error: None,
|
||||
status: TaskStatus::Running.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn completed(
|
||||
task_id: i64,
|
||||
project_id: uuid::Uuid,
|
||||
parent_id: Option<i64>,
|
||||
output: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
task_id,
|
||||
project_id,
|
||||
parent_id,
|
||||
event: "done".to_string(),
|
||||
message: None,
|
||||
output: Some(output),
|
||||
error: None,
|
||||
status: TaskStatus::Done.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn failed(
|
||||
task_id: i64,
|
||||
project_id: uuid::Uuid,
|
||||
parent_id: Option<i64>,
|
||||
error: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
task_id,
|
||||
project_id,
|
||||
parent_id,
|
||||
event: "failed".to_string(),
|
||||
message: None,
|
||||
output: None,
|
||||
error: Some(error),
|
||||
status: TaskStatus::Failed.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cancelled(task_id: i64, project_id: uuid::Uuid, parent_id: Option<i64>) -> Self {
|
||||
Self {
|
||||
task_id,
|
||||
project_id,
|
||||
parent_id,
|
||||
event: "cancelled".to_string(),
|
||||
message: None,
|
||||
output: None,
|
||||
error: None,
|
||||
status: TaskStatus::Cancelled.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper trait for publishing task lifecycle events via Redis Pub/Sub.
|
||||
///
|
||||
/// Callers inject a suitable `publish_fn` at construction time via
|
||||
/// `TaskEvents::new(...)`. If no publisher is supplied events are silently
|
||||
/// dropped (graceful degradation on startup).
|
||||
pub trait TaskEventPublisher: Send + Sync {
|
||||
fn publish(&self, project_id: uuid::Uuid, event: TaskEvent);
|
||||
}
|
||||
|
||||
/// No-op publisher used when no Redis Pub/Sub connection is available.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct NoOpPublisher;
|
||||
|
||||
impl TaskEventPublisher for NoOpPublisher {
|
||||
fn publish(&self, _: uuid::Uuid, _: TaskEvent) {}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TaskEvents {
|
||||
publisher: Arc<dyn TaskEventPublisher>,
|
||||
}
|
||||
|
||||
impl TaskEvents {
|
||||
pub fn new(publisher: impl TaskEventPublisher + 'static) -> Self {
|
||||
Self {
|
||||
publisher: Arc::new(publisher),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn noop() -> Self {
|
||||
Self::new(NoOpPublisher)
|
||||
}
|
||||
|
||||
fn emit(&self, task: &models::agent_task::Model, event: TaskEvent) {
|
||||
self.publisher.publish(task.project_uuid, event);
|
||||
}
|
||||
|
||||
pub fn emit_started(&self, task: &models::agent_task::Model) {
|
||||
self.emit(
|
||||
task,
|
||||
TaskEvent::started(task.id, task.project_uuid, task.parent_id),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn emit_progress(&self, task: &models::agent_task::Model, msg: String) {
|
||||
self.emit(
|
||||
task,
|
||||
TaskEvent::progress(task.id, task.project_uuid, task.parent_id, msg),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn emit_completed(&self, task: &models::agent_task::Model, output: String) {
|
||||
self.emit(
|
||||
task,
|
||||
TaskEvent::completed(task.id, task.project_uuid, task.parent_id, output),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn emit_failed(&self, task: &models::agent_task::Model, error: String) {
|
||||
self.emit(
|
||||
task,
|
||||
TaskEvent::failed(task.id, task.project_uuid, task.parent_id, error),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn emit_cancelled(&self, task: &models::agent_task::Model) {
|
||||
self.emit(
|
||||
task,
|
||||
TaskEvent::cancelled(task.id, task.project_uuid, task.parent_id),
|
||||
);
|
||||
}
|
||||
}
|
||||
192
libs/agent/task/lifecycle.rs
Normal file
192
libs/agent/task/lifecycle.rs
Normal file
@ -0,0 +1,192 @@
|
||||
use models::agent_task::{ActiveModel, Column as C, Entity, Model, TaskStatus};
|
||||
use sea_orm::{ActiveModelTrait, ColumnTrait, DbErr, EntityTrait, QueryFilter};
|
||||
|
||||
/// Marker type for the task lifecycle module.
///
/// It carries no data; the lifecycle operations in this file are implemented as
/// methods on `TaskService`, not on this type.
pub struct TaskLifecycle;
|
||||
|
||||
impl super::TaskService {
|
||||
/// Mark a task as running and record the start time.
|
||||
pub async fn start(&self, task_id: i64) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(self.db()).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Running);
|
||||
active.started_at = sea_orm::Set(Some(chrono::Utc::now().into()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
let updated = active.update(self.db()).await?;
|
||||
self.events().emit_started(&updated);
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// Update progress text (e.g., "step 2/5: analyzing PR").
|
||||
pub async fn update_progress(
|
||||
&self,
|
||||
task_id: i64,
|
||||
progress: impl Into<String>,
|
||||
) -> Result<(), DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(self.db()).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
let progress_str = progress.into();
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.progress = sea_orm::Set(Some(progress_str.clone()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
let updated = active.update(self.db()).await?;
|
||||
self.events().emit_progress(&updated, progress_str);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Mark a task as completed with the output text.
|
||||
pub async fn complete(&self, task_id: i64, output: impl Into<String>) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(self.db()).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Done);
|
||||
let out = output.into();
|
||||
active.output = sea_orm::Set(Some(out.clone()));
|
||||
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
let updated = active.update(self.db()).await?;
|
||||
self.events().emit_completed(&updated, out);
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// Mark a task as failed with an error message.
|
||||
pub async fn fail(&self, task_id: i64, error: impl Into<String>) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(self.db()).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Failed);
|
||||
let err = error.into();
|
||||
active.error = sea_orm::Set(Some(err.clone()));
|
||||
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
let updated = active.update(self.db()).await?;
|
||||
self.events().emit_failed(&updated, err);
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// Propagate child task status up the tree.
|
||||
///
|
||||
/// Only allows cancelling tasks that are not yet in a terminal state
|
||||
/// (Pending / Running / Paused).
|
||||
///
|
||||
/// Cancelled children are marked done so that `are_children_done()` returns
|
||||
/// true for the parent after cancellation.
|
||||
pub async fn cancel(&self, task_id: i64) -> Result<Model, DbErr> {
|
||||
// Collect all task IDs (parent + descendants) using an explicit stack.
|
||||
let mut stack = vec![task_id];
|
||||
let mut idx = 0;
|
||||
while idx < stack.len() {
|
||||
let current = stack[idx];
|
||||
let children = Entity::find()
|
||||
.filter(C::ParentId.eq(current))
|
||||
.all(self.db())
|
||||
.await?;
|
||||
for child in children {
|
||||
stack.push(child.id);
|
||||
}
|
||||
idx += 1;
|
||||
}
|
||||
|
||||
// Mark every collected task as cancelled (terminal state).
|
||||
for id in &stack {
|
||||
let model = Entity::find_by_id(*id).one(self.db()).await?;
|
||||
if let Some(m) = model {
|
||||
if !m.is_done() {
|
||||
let mut active: ActiveModel = m.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Cancelled);
|
||||
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
active.update(self.db()).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let final_model = Entity::find_by_id(task_id)
|
||||
.one(self.db())
|
||||
.await?
|
||||
.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
self.events().emit_cancelled(&final_model);
|
||||
Ok(final_model)
|
||||
}
|
||||
|
||||
/// Pause a running or pending task.
|
||||
///
|
||||
/// Pausing a task that is not Pending/Running is a no-op that returns
|
||||
/// the current model (same behaviour as `start` on an already-running task).
|
||||
pub async fn pause(&self, task_id: i64) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(self.db()).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
if !model.is_running() {
|
||||
// Already in a terminal or paused state — return unchanged.
|
||||
return Ok(model);
|
||||
}
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Paused);
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
active.update(self.db()).await
|
||||
}
|
||||
|
||||
/// Resume a paused task back to Running.
|
||||
///
|
||||
/// Returns an error if the task is not currently Paused.
|
||||
pub async fn resume(&self, task_id: i64) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(self.db()).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
if model.status != TaskStatus::Paused {
|
||||
return Err(DbErr::Custom(format!(
|
||||
"cannot resume task {}: expected status Paused, got {}",
|
||||
task_id, model.status
|
||||
)));
|
||||
}
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Running);
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
active.update(self.db()).await
|
||||
}
|
||||
|
||||
/// Retry a failed or cancelled task by resetting it to Pending.
|
||||
///
|
||||
/// Clears `output`, `error`, and `done_at`; increments `retry_count`.
|
||||
/// Only tasks in Failed or Cancelled state can be retried.
|
||||
pub async fn retry(&self, task_id: i64) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(self.db()).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
match model.status {
|
||||
TaskStatus::Failed | TaskStatus::Cancelled | TaskStatus::Done => {}
|
||||
_ => {
|
||||
return Err(DbErr::Custom(format!(
|
||||
"cannot retry task {}: only Failed/Cancelled/Done tasks can be retried (got {})",
|
||||
task_id, model.status
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let retry_count = model.retry_count.map(|c| c + 1).unwrap_or(1);
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Pending);
|
||||
active.output = sea_orm::Set(None);
|
||||
active.error = sea_orm::Set(None);
|
||||
active.done_at = sea_orm::Set(None);
|
||||
active.started_at = sea_orm::Set(None);
|
||||
active.retry_count = sea_orm::Set(Some(retry_count));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
active.update(self.db()).await
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,4 @@
|
||||
//! Agent task service — unified task/sub-agent execution framework.
|
||||
//! Agent task service — managing task/sub-agent execution lifecycle.
|
||||
//!
|
||||
//! A task (`agent_task` record) can be:
|
||||
//! - A **root task**: initiated by a user or system event.
|
||||
@ -17,6 +17,61 @@
|
||||
//! This module is intentionally kept simple and synchronous with the DB.
|
||||
//! Long-running execution is delegated to the caller (tokio::spawn).
|
||||
|
||||
pub mod service;
|
||||
pub mod events;
|
||||
pub mod lifecycle;
|
||||
pub mod store;
|
||||
pub mod tree;
|
||||
|
||||
pub use service::TaskService;
|
||||
use db::database::AppDatabase;
|
||||
|
||||
pub use events::{NoOpPublisher, TaskEvent, TaskEventPublisher, TaskEvents};
|
||||
pub use lifecycle::TaskLifecycle;
|
||||
|
||||
/// Service for managing agent tasks (root tasks and sub-tasks).
|
||||
#[derive(Clone)]
|
||||
pub struct TaskService {
|
||||
db: AppDatabase,
|
||||
events: TaskEvents,
|
||||
}
|
||||
|
||||
impl TaskService {
|
||||
pub fn new(db: AppDatabase) -> Self {
|
||||
Self {
|
||||
db,
|
||||
events: TaskEvents::noop(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_events(db: AppDatabase, events: TaskEvents) -> Self {
|
||||
Self { db, events }
|
||||
}
|
||||
|
||||
pub(crate) fn db(&self) -> &AppDatabase {
|
||||
&self.db
|
||||
}
|
||||
|
||||
pub(crate) fn events(&self) -> &TaskEvents {
|
||||
&self.events
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for TaskService so that the events publisher can be set independently
|
||||
/// of the database connection.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct TaskServiceBuilder {
|
||||
events: Option<TaskEvents>,
|
||||
}
|
||||
|
||||
impl TaskServiceBuilder {
|
||||
pub fn with_events(mut self, events: TaskEvents) -> Self {
|
||||
self.events = Some(events);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn build(self, db: AppDatabase) -> TaskService {
|
||||
TaskService {
|
||||
db,
|
||||
events: self.events.unwrap_or_else(TaskEvents::noop),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,600 +0,0 @@
|
||||
//! Task service for creating, tracking, and executing agent tasks.
|
||||
//!
|
||||
//! All methods are async and interact with the database directly.
|
||||
//! Execution of the task logic (running the ReAct loop, etc.) is delegated
|
||||
//! to the caller — this service only manages task lifecycle and state.
|
||||
|
||||
use db::database::AppDatabase;
|
||||
use models::agent_task::{ActiveModel, AgentType, Column as C, Entity, Model, TaskStatus};
|
||||
use models::IssueId;
|
||||
use sea_orm::{
|
||||
entity::EntityTrait, query::{QueryFilter, QueryOrder, QuerySelect}, ActiveModelTrait,
|
||||
ColumnTrait,
|
||||
DbErr,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Event payload published to WebSocket clients via Redis Pub/Sub.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct TaskEvent {
|
||||
pub task_id: i64,
|
||||
pub project_id: uuid::Uuid,
|
||||
pub parent_id: Option<i64>,
|
||||
pub event: String,
|
||||
pub message: Option<String>,
|
||||
pub output: Option<String>,
|
||||
pub error: Option<String>,
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
impl TaskEvent {
|
||||
pub fn started(task_id: i64, project_id: uuid::Uuid, parent_id: Option<i64>) -> Self {
|
||||
Self {
|
||||
task_id,
|
||||
project_id,
|
||||
parent_id,
|
||||
event: "started".to_string(),
|
||||
message: None,
|
||||
output: None,
|
||||
error: None,
|
||||
status: TaskStatus::Running.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn progress(
|
||||
task_id: i64,
|
||||
project_id: uuid::Uuid,
|
||||
parent_id: Option<i64>,
|
||||
msg: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
task_id,
|
||||
project_id,
|
||||
parent_id,
|
||||
event: "progress".to_string(),
|
||||
message: Some(msg),
|
||||
output: None,
|
||||
error: None,
|
||||
status: TaskStatus::Running.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn completed(
|
||||
task_id: i64,
|
||||
project_id: uuid::Uuid,
|
||||
parent_id: Option<i64>,
|
||||
output: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
task_id,
|
||||
project_id,
|
||||
parent_id,
|
||||
event: "done".to_string(),
|
||||
message: None,
|
||||
output: Some(output),
|
||||
error: None,
|
||||
status: TaskStatus::Done.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn failed(
|
||||
task_id: i64,
|
||||
project_id: uuid::Uuid,
|
||||
parent_id: Option<i64>,
|
||||
error: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
task_id,
|
||||
project_id,
|
||||
parent_id,
|
||||
event: "failed".to_string(),
|
||||
message: None,
|
||||
output: None,
|
||||
error: Some(error),
|
||||
status: TaskStatus::Failed.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cancelled(task_id: i64, project_id: uuid::Uuid, parent_id: Option<i64>) -> Self {
|
||||
Self {
|
||||
task_id,
|
||||
project_id,
|
||||
parent_id,
|
||||
event: "cancelled".to_string(),
|
||||
message: None,
|
||||
output: None,
|
||||
error: None,
|
||||
status: TaskStatus::Cancelled.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper trait for publishing task lifecycle events via Redis Pub/Sub.
|
||||
///
|
||||
/// Callers inject an implementation at construction time via
|
||||
/// `TaskEvents::new(...)`. `TaskEvents::noop()` installs `NoOpPublisher`,
|
||||
/// which silently drops every event (graceful degradation on startup).
|
||||
pub trait TaskEventPublisher: Send + Sync {
|
||||
/// Delivers `event`; `project_id` identifies the project the task belongs to.
fn publish(&self, project_id: uuid::Uuid, event: TaskEvent);
|
||||
}
|
||||
|
||||
/// No-op publisher used when no Redis Pub/Sub connection is available.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct NoOpPublisher;
|
||||
|
||||
impl TaskEventPublisher for NoOpPublisher {
|
||||
fn publish(&self, _: uuid::Uuid, _: TaskEvent) {}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TaskEvents {
|
||||
publisher: Arc<dyn TaskEventPublisher>,
|
||||
}
|
||||
|
||||
impl TaskEvents {
|
||||
pub fn new(publisher: impl TaskEventPublisher + 'static) -> Self {
|
||||
Self {
|
||||
publisher: Arc::new(publisher),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn noop() -> Self {
|
||||
Self::new(NoOpPublisher)
|
||||
}
|
||||
|
||||
fn emit(&self, task: &Model, event: TaskEvent) {
|
||||
self.publisher.publish(task.project_uuid, event);
|
||||
}
|
||||
|
||||
pub fn emit_started(&self, task: &Model) {
|
||||
self.emit(
|
||||
task,
|
||||
TaskEvent::started(task.id, task.project_uuid, task.parent_id),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn emit_progress(&self, task: &Model, msg: String) {
|
||||
self.emit(
|
||||
task,
|
||||
TaskEvent::progress(task.id, task.project_uuid, task.parent_id, msg),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn emit_completed(&self, task: &Model, output: String) {
|
||||
self.emit(
|
||||
task,
|
||||
TaskEvent::completed(task.id, task.project_uuid, task.parent_id, output),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn emit_failed(&self, task: &Model, error: String) {
|
||||
self.emit(
|
||||
task,
|
||||
TaskEvent::failed(task.id, task.project_uuid, task.parent_id, error),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn emit_cancelled(&self, task: &Model) {
|
||||
self.emit(
|
||||
task,
|
||||
TaskEvent::cancelled(task.id, task.project_uuid, task.parent_id),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for TaskService so that the events publisher can be set independently
|
||||
/// of the database connection.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct TaskServiceBuilder {
|
||||
events: Option<TaskEvents>,
|
||||
}
|
||||
|
||||
impl TaskServiceBuilder {
|
||||
pub fn with_events(mut self, events: TaskEvents) -> Self {
|
||||
self.events = Some(events);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn build(self, db: AppDatabase) -> TaskService {
|
||||
TaskService {
|
||||
db,
|
||||
events: self.events.unwrap_or_else(TaskEvents::noop),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Service for managing agent tasks (root tasks and sub-tasks).
|
||||
#[derive(Clone)]
|
||||
pub struct TaskService {
|
||||
/// Database handle passed to every query issued by this service.
db: AppDatabase,
|
||||
/// Dispatcher used to emit task lifecycle events.
events: TaskEvents,
|
||||
}
|
||||
|
||||
impl TaskService {
|
||||
pub fn new(db: AppDatabase) -> Self {
|
||||
Self {
|
||||
db,
|
||||
events: TaskEvents::noop(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a service that reports lifecycle events through `events`.
pub fn with_events(db: AppDatabase, events: TaskEvents) -> Self {
|
||||
Self { db, events }
|
||||
}
|
||||
|
||||
/// Create a new task (root or sub-task) with status = pending.
|
||||
pub async fn create(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
input: impl Into<String>,
|
||||
agent_type: AgentType,
|
||||
) -> Result<Model, DbErr> {
|
||||
self.create_with_parent(project_uuid, None, input, agent_type, None, None)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Create a new task bound to an issue.
|
||||
pub async fn create_for_issue(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
issue_id: IssueId,
|
||||
input: impl Into<String>,
|
||||
agent_type: AgentType,
|
||||
) -> Result<Model, DbErr> {
|
||||
self.create_with_parent(project_uuid, None, input, agent_type, None, Some(issue_id))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Create a new sub-task with a parent reference.
|
||||
pub async fn create_subtask(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
parent_id: i64,
|
||||
input: impl Into<String>,
|
||||
agent_type: AgentType,
|
||||
title: Option<String>,
|
||||
) -> Result<Model, DbErr> {
|
||||
self.create_with_parent(
|
||||
project_uuid,
|
||||
Some(parent_id),
|
||||
input,
|
||||
agent_type,
|
||||
title,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn create_with_parent(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
parent_id: Option<i64>,
|
||||
input: impl Into<String>,
|
||||
agent_type: AgentType,
|
||||
title: Option<String>,
|
||||
issue_id: Option<IssueId>,
|
||||
) -> Result<Model, DbErr> {
|
||||
let model = ActiveModel {
|
||||
project_uuid: sea_orm::Set(project_uuid.into()),
|
||||
parent_id: sea_orm::Set(parent_id),
|
||||
issue_id: sea_orm::Set(issue_id),
|
||||
agent_type: sea_orm::Set(agent_type),
|
||||
status: sea_orm::Set(TaskStatus::Pending),
|
||||
title: sea_orm::Set(title),
|
||||
input: sea_orm::Set(input.into()),
|
||||
..Default::default()
|
||||
};
|
||||
model.insert(&self.db).await
|
||||
}
|
||||
|
||||
/// Mark a task as running and record the start time.
|
||||
pub async fn start(&self, task_id: i64) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(&self.db).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Running);
|
||||
active.started_at = sea_orm::Set(Some(chrono::Utc::now().into()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
let updated = active.update(&self.db).await?;
|
||||
self.events.emit_started(&updated);
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// Update progress text (e.g., "step 2/5: analyzing PR").
|
||||
pub async fn update_progress(
|
||||
&self,
|
||||
task_id: i64,
|
||||
progress: impl Into<String>,
|
||||
) -> Result<(), DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(&self.db).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
let progress_str = progress.into();
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.progress = sea_orm::Set(Some(progress_str.clone()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
let updated = active.update(&self.db).await?;
|
||||
self.events.emit_progress(&updated, progress_str);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Mark a task as completed with the output text.
|
||||
pub async fn complete(&self, task_id: i64, output: impl Into<String>) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(&self.db).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Done);
|
||||
let out = output.into();
|
||||
active.output = sea_orm::Set(Some(out.clone()));
|
||||
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
let updated = active.update(&self.db).await?;
|
||||
self.events.emit_completed(&updated, out);
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// Mark a task as failed with an error message.
|
||||
pub async fn fail(&self, task_id: i64, error: impl Into<String>) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(&self.db).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Failed);
|
||||
let err = error.into();
|
||||
active.error = sea_orm::Set(Some(err.clone()));
|
||||
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
let updated = active.update(&self.db).await?;
|
||||
self.events.emit_failed(&updated, err);
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// Propagate child task status up the tree.
|
||||
///
|
||||
/// Only allows cancelling tasks that are not yet in a terminal state
|
||||
/// (Pending / Running / Paused).
|
||||
///
|
||||
/// Cancelled children are marked done so that `are_children_done()` returns
|
||||
/// true for the parent after cancellation.
|
||||
pub async fn cancel(&self, task_id: i64) -> Result<Model, DbErr> {
|
||||
// Collect all task IDs (parent + descendants) using an explicit stack.
|
||||
let mut stack = vec![task_id];
|
||||
let mut idx = 0;
|
||||
while idx < stack.len() {
|
||||
let current = stack[idx];
|
||||
let children = Entity::find()
|
||||
.filter(C::ParentId.eq(current))
|
||||
.all(&self.db)
|
||||
.await?;
|
||||
for child in children {
|
||||
stack.push(child.id);
|
||||
}
|
||||
idx += 1;
|
||||
}
|
||||
|
||||
// Mark every collected task as cancelled (terminal state).
|
||||
for id in &stack {
|
||||
let model = Entity::find_by_id(*id).one(&self.db).await?;
|
||||
if let Some(m) = model {
|
||||
if !m.is_done() {
|
||||
let mut active: ActiveModel = m.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Cancelled);
|
||||
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
active.update(&self.db).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let final_model = Entity::find_by_id(task_id)
|
||||
.one(&self.db)
|
||||
.await?
|
||||
.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
self.events.emit_cancelled(&final_model);
|
||||
Ok(final_model)
|
||||
}
|
||||
|
||||
/// Pause a running or pending task.
|
||||
///
|
||||
/// Pausing a task that is not Pending/Running is a no-op that returns
|
||||
/// the current model (same behaviour as `start` on an already-running task).
|
||||
pub async fn pause(&self, task_id: i64) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(&self.db).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
if !model.is_running() {
|
||||
// Already in a terminal or paused state — return unchanged.
|
||||
return Ok(model);
|
||||
}
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Paused);
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
active.update(&self.db).await
|
||||
}
|
||||
|
||||
/// Resume a paused task back to Running.
|
||||
///
|
||||
/// Returns an error if the task is not currently Paused.
|
||||
pub async fn resume(&self, task_id: i64) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(&self.db).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
if model.status != TaskStatus::Paused {
|
||||
return Err(DbErr::Custom(format!(
|
||||
"cannot resume task {}: expected status Paused, got {}",
|
||||
task_id, model.status
|
||||
)));
|
||||
}
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Running);
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
active.update(&self.db).await
|
||||
}
|
||||
|
||||
/// Retry a failed or cancelled task by resetting it to Pending.
|
||||
///
|
||||
/// Clears `output`, `error`, and `done_at`; increments `retry_count`.
|
||||
/// Only tasks in Failed or Cancelled state can be retried.
|
||||
pub async fn retry(&self, task_id: i64) -> Result<Model, DbErr> {
|
||||
let model = Entity::find_by_id(task_id).one(&self.db).await?;
|
||||
let model =
|
||||
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
match model.status {
|
||||
TaskStatus::Failed | TaskStatus::Cancelled | TaskStatus::Done => {}
|
||||
_ => {
|
||||
return Err(DbErr::Custom(format!(
|
||||
"cannot retry task {}: only Failed/Cancelled/Done tasks can be retried (got {})",
|
||||
task_id, model.status
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let retry_count = model.retry_count.map(|c| c + 1).unwrap_or(1);
|
||||
|
||||
let mut active: ActiveModel = model.into();
|
||||
active.status = sea_orm::Set(TaskStatus::Pending);
|
||||
active.output = sea_orm::Set(None);
|
||||
active.error = sea_orm::Set(None);
|
||||
active.done_at = sea_orm::Set(None);
|
||||
active.started_at = sea_orm::Set(None);
|
||||
active.retry_count = sea_orm::Set(Some(retry_count));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
active.update(&self.db).await
|
||||
}
|
||||
|
||||
/// Propagate child task status up the tree.
|
||||
///
|
||||
/// When a child task reaches a terminal state, checks whether all its
|
||||
/// siblings are also terminal. If so, marks the parent appropriately:
|
||||
/// - Done if any child succeeded
|
||||
/// - Failed if all children failed or were cancelled
|
||||
pub async fn propagate_to_parent(&self, task_id: i64) -> Result<Option<Model>, DbErr> {
|
||||
let model = self
|
||||
.get(task_id)
|
||||
.await?
|
||||
.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
let Some(parent_id) = model.parent_id else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let siblings = self.children(parent_id).await?;
|
||||
if siblings.iter().all(|s| s.is_done()) {
|
||||
let parent = self.get(parent_id).await?.ok_or_else(|| {
|
||||
DbErr::RecordNotFound(format!("parent task {} not found", parent_id))
|
||||
})?;
|
||||
if parent.is_running() {
|
||||
let mut active: ActiveModel = parent.into();
|
||||
let has_success = siblings.iter().any(|s| s.status == TaskStatus::Done);
|
||||
if has_success {
|
||||
active.status = sea_orm::Set(TaskStatus::Done);
|
||||
active.error = sea_orm::Set(None);
|
||||
} else {
|
||||
active.status = sea_orm::Set(TaskStatus::Failed);
|
||||
active.error =
|
||||
sea_orm::Set(Some("All sub-tasks failed or were cancelled".to_string()));
|
||||
}
|
||||
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
let updated = active.update(&self.db).await?;
|
||||
return Ok(Some(updated));
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Get a task by ID.
|
||||
pub async fn get(&self, task_id: i64) -> Result<Option<Model>, DbErr> {
|
||||
Entity::find_by_id(task_id).one(&self.db).await
|
||||
}
|
||||
|
||||
/// List all sub-tasks for a parent task.
|
||||
pub async fn children(&self, parent_id: i64) -> Result<Vec<Model>, DbErr> {
|
||||
Entity::find()
|
||||
.filter(C::ParentId.eq(parent_id))
|
||||
.order_by_asc(C::CreatedAt)
|
||||
.all(&self.db)
|
||||
.await
|
||||
}
|
||||
|
||||
/// List all active (non-terminal) tasks for a project.
|
||||
pub async fn active_tasks(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
) -> Result<Vec<Model>, DbErr> {
|
||||
let uuid: uuid::Uuid = project_uuid.into();
|
||||
Entity::find()
|
||||
.filter(C::ProjectUuid.eq(uuid))
|
||||
.filter(C::Status.is_in([TaskStatus::Pending, TaskStatus::Running, TaskStatus::Paused]))
|
||||
.order_by_desc(C::CreatedAt)
|
||||
.all(&self.db)
|
||||
.await
|
||||
}
|
||||
|
||||
/// List all tasks (root only) for a project.
|
||||
pub async fn list(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
limit: u64,
|
||||
) -> Result<Vec<Model>, DbErr> {
|
||||
let uuid: uuid::Uuid = project_uuid.into();
|
||||
Entity::find()
|
||||
.filter(C::ProjectUuid.eq(uuid))
|
||||
.filter(C::ParentId.is_null())
|
||||
.order_by_desc(C::CreatedAt)
|
||||
.limit(limit)
|
||||
.all(&self.db)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Delete a task and all its sub-tasks recursively.
|
||||
/// Only allows deletion of root tasks.
|
||||
pub async fn delete(&self, task_id: i64) -> Result<(), DbErr> {
|
||||
self.delete_recursive(task_id).await
|
||||
}
|
||||
|
||||
async fn delete_recursive(&self, task_id: i64) -> Result<(), DbErr> {
|
||||
// Collect all task IDs to delete using an explicit stack (avoiding async recursion).
|
||||
let mut stack = vec![task_id];
|
||||
let mut idx = 0;
|
||||
while idx < stack.len() {
|
||||
let current = stack[idx];
|
||||
let children = Entity::find()
|
||||
.filter(C::ParentId.eq(current))
|
||||
.all(&self.db)
|
||||
.await?;
|
||||
for child in children {
|
||||
stack.push(child.id);
|
||||
}
|
||||
idx += 1;
|
||||
}
|
||||
|
||||
for task_id in stack {
|
||||
let model = Entity::find_by_id(task_id).one(&self.db).await?;
|
||||
if let Some(m) = model {
|
||||
let active: ActiveModel = m.into();
|
||||
active.delete(&self.db).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if all sub-tasks of a given parent are in a terminal state.
|
||||
/// Returns true if there are no children (empty tree counts as done).
|
||||
pub async fn are_children_done(&self, parent_id: i64) -> Result<bool, DbErr> {
|
||||
let children = self.children(parent_id).await?;
|
||||
Ok(children.is_empty() || children.iter().all(|c| c.is_done()))
|
||||
}
|
||||
}
|
||||
109
libs/agent/task/store.rs
Normal file
109
libs/agent/task/store.rs
Normal file
@ -0,0 +1,109 @@
|
||||
use models::agent_task::{ActiveModel, AgentType, Entity, Model};
|
||||
use models::IssueId;
|
||||
use sea_orm::{ActiveModelTrait, ColumnTrait, DbErr, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
|
||||
|
||||
impl super::TaskService {
|
||||
/// Get a task by ID.
|
||||
pub async fn get(&self, task_id: i64) -> Result<Option<Model>, DbErr> {
|
||||
Entity::find_by_id(task_id).one(self.db()).await
|
||||
}
|
||||
|
||||
/// List all tasks (root only) for a project.
|
||||
pub async fn list(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
limit: u64,
|
||||
) -> Result<Vec<Model>, DbErr> {
|
||||
let uuid: uuid::Uuid = project_uuid.into();
|
||||
Entity::find()
|
||||
.filter(models::agent_task::Column::ProjectUuid.eq(uuid))
|
||||
.filter(models::agent_task::Column::ParentId.is_null())
|
||||
.order_by_desc(models::agent_task::Column::CreatedAt)
|
||||
.limit(limit)
|
||||
.all(self.db())
|
||||
.await
|
||||
}
|
||||
|
||||
/// List all active (non-terminal) tasks for a project.
|
||||
pub async fn active_tasks(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
) -> Result<Vec<Model>, DbErr> {
|
||||
let uuid: uuid::Uuid = project_uuid.into();
|
||||
Entity::find()
|
||||
.filter(models::agent_task::Column::ProjectUuid.eq(uuid))
|
||||
.filter(models::agent_task::Column::Status.is_in([
|
||||
models::agent_task::TaskStatus::Pending,
|
||||
models::agent_task::TaskStatus::Running,
|
||||
models::agent_task::TaskStatus::Paused,
|
||||
]))
|
||||
.order_by_desc(models::agent_task::Column::CreatedAt)
|
||||
.all(self.db())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Create a new task (root or sub-task) with status = pending.
|
||||
pub async fn create(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
input: impl Into<String>,
|
||||
agent_type: AgentType,
|
||||
) -> Result<Model, DbErr> {
|
||||
self.create_with_parent(project_uuid, None, input, agent_type, None, None)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Create a new task bound to an issue.
|
||||
pub async fn create_for_issue(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
issue_id: IssueId,
|
||||
input: impl Into<String>,
|
||||
agent_type: AgentType,
|
||||
) -> Result<Model, DbErr> {
|
||||
self.create_with_parent(project_uuid, None, input, agent_type, None, Some(issue_id))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Create a new sub-task with a parent reference.
|
||||
pub async fn create_subtask(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
parent_id: i64,
|
||||
input: impl Into<String>,
|
||||
agent_type: AgentType,
|
||||
title: Option<String>,
|
||||
) -> Result<Model, DbErr> {
|
||||
self.create_with_parent(
|
||||
project_uuid,
|
||||
Some(parent_id),
|
||||
input,
|
||||
agent_type,
|
||||
title,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn create_with_parent(
|
||||
&self,
|
||||
project_uuid: impl Into<uuid::Uuid>,
|
||||
parent_id: Option<i64>,
|
||||
input: impl Into<String>,
|
||||
agent_type: AgentType,
|
||||
title: Option<String>,
|
||||
issue_id: Option<IssueId>,
|
||||
) -> Result<Model, DbErr> {
|
||||
let model = ActiveModel {
|
||||
project_uuid: sea_orm::Set(project_uuid.into()),
|
||||
parent_id: sea_orm::Set(parent_id),
|
||||
issue_id: sea_orm::Set(issue_id),
|
||||
agent_type: sea_orm::Set(agent_type),
|
||||
status: sea_orm::Set(models::agent_task::TaskStatus::Pending),
|
||||
title: sea_orm::Set(title),
|
||||
input: sea_orm::Set(input.into()),
|
||||
..Default::default()
|
||||
};
|
||||
model.insert(self.db()).await
|
||||
}
|
||||
}
|
||||
89
libs/agent/task/tree.rs
Normal file
89
libs/agent/task/tree.rs
Normal file
@ -0,0 +1,89 @@
|
||||
use models::agent_task::{ActiveModel, Column as C, Entity, Model, TaskStatus};
|
||||
use sea_orm::{ActiveModelTrait, ColumnTrait, DbErr, EntityTrait, QueryFilter, QueryOrder};
|
||||
|
||||
impl super::TaskService {
|
||||
/// Propagate child task status up the tree.
|
||||
///
|
||||
/// When a child task reaches a terminal state, checks whether all its
|
||||
/// siblings are also terminal. If so, marks the parent appropriately:
|
||||
/// - Done if any child succeeded
|
||||
/// - Failed if all children failed or were cancelled
|
||||
pub async fn propagate_to_parent(&self, task_id: i64) -> Result<Option<Model>, DbErr> {
|
||||
let model = self
|
||||
.get(task_id)
|
||||
.await?
|
||||
.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
|
||||
|
||||
let Some(parent_id) = model.parent_id else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let siblings = self.children(parent_id).await?;
|
||||
if siblings.iter().all(|s| s.is_done()) {
|
||||
let parent = self.get(parent_id).await?.ok_or_else(|| {
|
||||
DbErr::RecordNotFound(format!("parent task {} not found", parent_id))
|
||||
})?;
|
||||
if parent.is_running() {
|
||||
let mut active: ActiveModel = parent.into();
|
||||
let has_success = siblings.iter().any(|s| s.status == TaskStatus::Done);
|
||||
if has_success {
|
||||
active.status = sea_orm::Set(TaskStatus::Done);
|
||||
active.error = sea_orm::Set(None);
|
||||
} else {
|
||||
active.status = sea_orm::Set(TaskStatus::Failed);
|
||||
active.error =
|
||||
sea_orm::Set(Some("All sub-tasks failed or were cancelled".to_string()));
|
||||
}
|
||||
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
|
||||
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
|
||||
let updated = active.update(self.db()).await?;
|
||||
return Ok(Some(updated));
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// List all sub-tasks for a parent task.
|
||||
pub async fn children(&self, parent_id: i64) -> Result<Vec<Model>, DbErr> {
|
||||
Entity::find()
|
||||
.filter(C::ParentId.eq(parent_id))
|
||||
.order_by_asc(C::CreatedAt)
|
||||
.all(self.db())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Check if all sub-tasks of a given parent are in a terminal state.
|
||||
/// Returns true if there are no children (empty tree counts as done).
|
||||
pub async fn are_children_done(&self, parent_id: i64) -> Result<bool, DbErr> {
|
||||
let children = self.children(parent_id).await?;
|
||||
Ok(children.is_empty() || children.iter().all(|c| c.is_done()))
|
||||
}
|
||||
|
||||
/// Delete a task and all its sub-tasks recursively.
|
||||
/// Only allows deletion of root tasks.
|
||||
pub async fn delete(&self, task_id: i64) -> Result<(), DbErr> {
|
||||
// Collect all task IDs to delete using an explicit stack (avoiding async recursion).
|
||||
let mut stack = vec![task_id];
|
||||
let mut idx = 0;
|
||||
while idx < stack.len() {
|
||||
let current = stack[idx];
|
||||
let children = Entity::find()
|
||||
.filter(C::ParentId.eq(current))
|
||||
.all(self.db())
|
||||
.await?;
|
||||
for child in children {
|
||||
stack.push(child.id);
|
||||
}
|
||||
idx += 1;
|
||||
}
|
||||
|
||||
for task_id in stack {
|
||||
let model = Entity::find_by_id(task_id).one(self.db()).await?;
|
||||
if let Some(m) = model {
|
||||
let active: ActiveModel = m.into();
|
||||
active.delete(self.db()).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -13,7 +13,7 @@ use std::collections::HashMap;
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use crate::error::{AgentError, Result};
|
||||
use crate::error::Result;
|
||||
|
||||
static TOKENIZER_CACHE: OnceLock<RwLock<HashMap<String, tiktoken_rs::CoreBPE>>> = OnceLock::new();
|
||||
|
||||
@ -173,12 +173,11 @@ fn get_tokenizer(model: &str) -> Result<tiktoken_rs::CoreBPE> {
|
||||
}
|
||||
|
||||
// Try model-specific tokenizer first
|
||||
let bpe = if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) {
|
||||
let bpe: &'static _ = if let Ok(bpe) = tiktoken_rs::bpe_for_model(model) {
|
||||
bpe
|
||||
} else {
|
||||
// Fallback: use cl100k_base for unknown models
|
||||
tiktoken_rs::cl100k_base()
|
||||
.map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e)))?
|
||||
tiktoken_rs::cl100k_base_singleton()
|
||||
};
|
||||
|
||||
{
|
||||
@ -186,7 +185,7 @@ fn get_tokenizer(model: &str) -> Result<tiktoken_rs::CoreBPE> {
|
||||
cache.insert(model.to_string(), bpe.clone());
|
||||
}
|
||||
|
||||
Ok(bpe)
|
||||
Ok(bpe.clone())
|
||||
}
|
||||
|
||||
/// Estimate tokens for a simple prefix/suffix pattern (e.g., "assistant\n" + text).
|
||||
|
||||
@ -1,256 +0,0 @@
|
||||
# Hook Queue NATS JetStream Migration Guide
|
||||
|
||||
## Overview
|
||||
|
||||
The git hook queue now supports both Redis Lists and NATS JetStream as backend message queues. This allows gradual migration from Redis to NATS without downtime.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Producer (`ReceiveSyncService`)
|
||||
|
||||
The producer tries NATS first (if configured), then falls back to Redis:
|
||||
|
||||
```rust
|
||||
pub struct ReceiveSyncService {
|
||||
pool: deadpool_redis::cluster::Pool,
|
||||
redis_prefix: String,
|
||||
nats_publish: Option<Arc<dyn Fn(String, Vec<u8>) -> Pin<Box<dyn Future<Output = Result<u64>> + Send>> + Send + Sync>>,
|
||||
}
|
||||
```
|
||||
|
||||
### Consumer (`RedisConsumer`)
|
||||
|
||||
The consumer uses NATS if configured, otherwise falls back to Redis:
|
||||
|
||||
```rust
|
||||
pub struct RedisConsumer {
|
||||
pool: deadpool_redis::cluster::Pool,
|
||||
prefix: String,
|
||||
block_timeout_secs: u64,
|
||||
nats_consume: Option<NatsHookConsumeFn>,
|
||||
}
|
||||
```
|
||||
|
||||
## Integration with AppTransport
|
||||
|
||||
### Producer Integration
|
||||
|
||||
```rust
|
||||
use git::ssh::ReceiveSyncService;
|
||||
use transport::AppTransport;
|
||||
|
||||
let transport = Arc::new(AppTransport::new(/* ... */));
|
||||
|
||||
// Create NATS publish function
|
||||
let nats_publish = {
|
||||
let transport = transport.clone();
|
||||
Arc::new(move |subject: String, payload: Vec<u8>| {
|
||||
let transport = transport.clone();
|
||||
Box::pin(async move {
|
||||
let ack = transport.publish(&subject, payload).await?;
|
||||
Ok(ack.sequence)
|
||||
}) as Pin<Box<dyn Future<Output = anyhow::Result<u64>> + Send>>
|
||||
})
|
||||
};
|
||||
|
||||
// Create service with NATS support
|
||||
let sync_service = ReceiveSyncService::with_nats(redis_pool, nats_publish);
|
||||
|
||||
// Or use Redis-only mode
|
||||
let sync_service = ReceiveSyncService::new(redis_pool);
|
||||
```
|
||||
|
||||
### Consumer Integration
|
||||
|
||||
```rust
|
||||
use git::hook::pool::redis::{RedisConsumer, NatsHookConsumeFn};
|
||||
|
||||
// Create NATS consume function
|
||||
let nats_consume: NatsHookConsumeFn = {
|
||||
let transport = transport.clone();
|
||||
Arc::new(move |subject: String, batch_size: usize| {
|
||||
let transport = transport.clone();
|
||||
Box::pin(async move {
|
||||
let mut results = Vec::new();
|
||||
|
||||
// Pull messages from JetStream consumer
|
||||
for _ in 0..batch_size {
|
||||
match transport.pull_one(&subject).await {
|
||||
Ok(Some(msg)) => {
|
||||
let data = msg.payload.to_vec();
|
||||
let msg_clone = msg.clone();
|
||||
let ack_fn = Box::new(move || {
|
||||
let msg = msg_clone.clone();
|
||||
Box::pin(async move {
|
||||
msg.ack().await?;
|
||||
Ok(())
|
||||
}) as Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>
|
||||
});
|
||||
results.push((data, ack_fn));
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}) as Pin<Box<dyn Future<Output = anyhow::Result<Vec<(Vec<u8>, Box<dyn Fn() -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>> + Send>)>>> + Send>>
|
||||
})
|
||||
};
|
||||
|
||||
// Create consumer with NATS support
|
||||
let consumer = RedisConsumer::with_nats(
|
||||
redis_pool,
|
||||
"{hook}".to_string(),
|
||||
5, // block_timeout_secs
|
||||
nats_consume,
|
||||
);
|
||||
|
||||
// Or use Redis-only mode
|
||||
let consumer = RedisConsumer::new(redis_pool, "{hook}".to_string(), 5);
|
||||
```
|
||||
|
||||
## Queue Subjects
|
||||
|
||||
The hook queue uses the following NATS subjects:
|
||||
|
||||
- `queue.hook.sync` - Repository sync tasks (git push/pull operations)
|
||||
|
||||
Additional task types can be added by extending the subject pattern:
|
||||
- `queue.hook.{task_type}` - Generic pattern for any hook task type
|
||||
|
||||
## Migration Strategy
|
||||
|
||||
### Phase 1: Dual Write (Current)
|
||||
- Producer writes to both NATS and Redis
|
||||
- Consumer reads from Redis only
|
||||
- Zero risk, full rollback capability
|
||||
|
||||
### Phase 2: Dual Read
|
||||
- Producer writes to both NATS and Redis
|
||||
- Consumer reads from NATS, falls back to Redis on error
|
||||
- Validates NATS consumer stability
|
||||
|
||||
### Phase 3: NATS Primary
|
||||
- Producer writes to NATS only (Redis disabled)
|
||||
- Consumer reads from NATS only
|
||||
- Redis queue deprecated
|
||||
|
||||
### Phase 4: Redis Removal
|
||||
- Remove Redis Lists code
|
||||
- Remove `pool` parameter
|
||||
- Simplify to NATS-only implementation
|
||||
|
||||
## NATS JetStream Setup
|
||||
|
||||
### Stream Configuration
|
||||
|
||||
```bash
|
||||
nats stream add HOOK_QUEUE \
|
||||
--subjects "queue.hook.>" \
|
||||
--retention limits \
|
||||
--max-msgs=-1 \
|
||||
--max-age=7d \
|
||||
--storage file \
|
||||
--replicas 3
|
||||
```
|
||||
|
||||
### Consumer Configuration
|
||||
|
||||
```bash
|
||||
nats consumer add HOOK_QUEUE hook-sync-worker \
|
||||
--filter "queue.hook.sync" \
|
||||
--ack explicit \
|
||||
--pull \
|
||||
--deliver all \
|
||||
--max-deliver 3 \
|
||||
--max-pending 100
|
||||
```
|
||||
|
||||
## Differences from Email Queue
|
||||
|
||||
### Redis Backend
|
||||
- **Email Queue**: Uses Redis Streams (XADD/XREADGROUP)
|
||||
- **Hook Queue**: Uses Redis Lists (LPUSH/BLMOVE)
|
||||
|
||||
### Atomicity
|
||||
- **Email Queue**: Consumer group provides at-least-once delivery
|
||||
- **Hook Queue**: BLMOVE provides atomic move-to-work-queue pattern
|
||||
|
||||
### Work Queue Pattern
|
||||
- **Email Queue**: No work queue, relies on consumer group
|
||||
- **Hook Queue**: Uses separate work queue (`{hook}:sync:work`) for in-flight tracking
|
||||
|
||||
### Acknowledgment
|
||||
- **Email Queue**: XACK removes from pending entries list
|
||||
- **Hook Queue**: LREM removes from work queue
|
||||
|
||||
### Retry Logic
|
||||
- **Email Queue**: Automatic via consumer group pending entries
|
||||
- **Hook Queue**: Manual via Lua script (LREM + LPUSH)
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Logs
|
||||
|
||||
- NATS publish: `"hook task queued to NATS"`
|
||||
- Redis publish: `"hook task queued to Redis"`
|
||||
- NATS consume: `"task dequeued from NATS"`
|
||||
- Redis consume: `"task dequeued"`
|
||||
|
||||
### Metrics
|
||||
|
||||
Add these metrics to track hook queue performance:
|
||||
|
||||
```rust
|
||||
counter!("hook_task_queued_total", "backend" => "nats").increment(1);
|
||||
counter!("hook_task_queued_total", "backend" => "redis").increment(1);
|
||||
counter!("hook_task_consumed_total", "backend" => "nats").increment(1);
|
||||
counter!("hook_task_consumed_total", "backend" => "redis").increment(1);
|
||||
```
|
||||
|
||||
## Rollback
|
||||
|
||||
To disable NATS and return to Redis-only:
|
||||
|
||||
```rust
|
||||
// Producer
|
||||
let sync_service = ReceiveSyncService::new(redis_pool);
|
||||
|
||||
// Consumer
|
||||
let consumer = RedisConsumer::new(redis_pool, "{hook}".to_string(), 5);
|
||||
```
|
||||
|
||||
No other code changes are required: construct the services with `new()` instead of `with_nats()`.
|
||||
|
||||
## Benefits
|
||||
|
||||
1. **Zero Downtime**: Gradual migration with fallback
|
||||
2. **No Circular Dependency**: Uses injected callbacks (`Arc`-wrapped closures) instead of crate dependencies
|
||||
3. **Backward Compatible**: Existing code works without changes
|
||||
4. **Type Safe**: Compile-time guarantees for integration
|
||||
5. **Observable**: Consistent logging for both backends
|
||||
|
||||
## Known Limitations
|
||||
|
||||
### NATS Acknowledgment Timing
|
||||
|
||||
The current implementation acks NATS messages immediately after deserialization, not after successful processing, so a task that fails during processing has already been acknowledged. The two backends therefore order their steps differently:
|
||||
|
||||
- Redis: Task moves to work queue → processes → acks (removes from work queue)
|
||||
- NATS: Task received → acks immediately → processes
|
||||
|
||||
**Future Enhancement**: Store ack functions in a map keyed by task ID, then call them after successful processing. This requires refactoring the worker loop to track pending acks.
|
||||
|
||||
### Work Queue Pattern
|
||||
|
||||
NATS JetStream doesn't have a direct equivalent to Redis's work queue pattern. The current implementation relies on JetStream's built-in redelivery mechanism instead of a separate work queue.
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Add NATS integration to `apps/git-hook/src/main.rs`
|
||||
2. Add configuration flags for queue backend selection
|
||||
3. Test dual-write mode in staging
|
||||
4. Monitor NATS consumer stability
|
||||
5. Implement proper ack-after-processing pattern
|
||||
6. Add metrics for queue depth and processing latency
|
||||
@ -1,107 +0,0 @@
|
||||
# 构建脚本
|
||||
|
||||
## 一键构建脚本
|
||||
|
||||
### build.js - 构建镜像
|
||||
|
||||
```bash
|
||||
# 构建所有镜像
|
||||
node scripts/build.js
|
||||
|
||||
# 构建指定服务
|
||||
node scripts/build.js app gitserver
|
||||
|
||||
# 指定 tag
|
||||
TAG=v1.0.0 node scripts/build.js
|
||||
|
||||
# 指定架构
|
||||
TARGET=aarch64-unknown-linux-gnu node scripts/build.js
|
||||
```
|
||||
|
||||
**环境变量:**
|
||||
|
||||
| 变量 | 默认值 | 说明 |
|
||||
|------------|------------------------------|-------------|
|
||||
| `REGISTRY` | `harbor.gitdata.me/gta_team` | 镜像仓库 |
|
||||
| `TAG` | `latest` | 镜像标签 |
|
||||
| `TARGET` | `x86_64-unknown-linux-gnu` | Rust 交叉编译目标 |
|
||||
|
||||
---
|
||||
|
||||
### push.js - 推送镜像
|
||||
|
||||
```bash
|
||||
# 推送所有镜像
|
||||
HARBOR_USERNAME=user HARBOR_PASSWORD=pass node scripts/push.js
|
||||
|
||||
# 推送指定服务
|
||||
HARBOR_USERNAME=user HARBOR_PASSWORD=pass TAG=sha-abc123 node scripts/push.js app
|
||||
```
|
||||
|
||||
**环境变量:**
|
||||
|
||||
| 变量 | 默认值 | 说明 |
|
||||
|-------------------|------------------------------|--------------|
|
||||
| `REGISTRY` | `harbor.gitdata.me/gta_team` | 镜像仓库 |
|
||||
| `TAG` | `latest` 或 Git SHA | 镜像标签 |
|
||||
| `HARBOR_USERNAME` | - | **必填** 仓库用户名 |
|
||||
| `HARBOR_PASSWORD` | - | **必填** 仓库密码 |
|
||||
|
||||
---
|
||||
|
||||
### deploy.js - 部署到 Kubernetes
|
||||
|
||||
```bash
|
||||
# 部署最新镜像
|
||||
node scripts/deploy.js
|
||||
|
||||
# 试运行模式(dry-run,不实际部署)
|
||||
node scripts/deploy.js --dry-run
|
||||
|
||||
# 部署并运行数据库迁移
|
||||
node scripts/deploy.js --migrate
|
||||
|
||||
# 指定 tag
|
||||
TAG=sha-abc123 node scripts/deploy.js
|
||||
|
||||
# 指定命名空间
|
||||
NAMESPACE=staging node scripts/deploy.js
|
||||
```
|
||||
|
||||
**环境变量:**
|
||||
|
||||
| 变量 | 默认值 | 说明 |
|
||||
|--------------|------------------------------|-----------------|
|
||||
| `REGISTRY` | `harbor.gitdata.me/gta_team` | 镜像仓库 |
|
||||
| `TAG` | `latest` 或 Git SHA | 镜像标签 |
|
||||
| `NAMESPACE` | `gitdata` | K8s 命名空间 |
|
||||
| `RELEASE` | `gitdata` | Helm Release 名称 |
|
||||
| `KUBECONFIG` | `~/.kube/config` | Kubeconfig 路径 |
|
||||
|
||||
---
|
||||
|
||||
## 完整 CI/CD 流程
|
||||
|
||||
```bash
|
||||
# 1. 构建
|
||||
node scripts/build.js
|
||||
|
||||
# 2. 推送
|
||||
HARBOR_USERNAME=user HARBOR_PASSWORD=pass node scripts/push.js
|
||||
|
||||
# 3. 部署
|
||||
node scripts/deploy.js --migrate
|
||||
```
|
||||
|
||||
## 本地开发
|
||||
|
||||
```bash
|
||||
# 本地构建测试
|
||||
node scripts/build.js app
|
||||
|
||||
# 使用本地 tag
|
||||
TAG=dev node scripts/build.js
|
||||
|
||||
# 部署到测试环境
|
||||
NAMESPACE=test node scripts/deploy.js
|
||||
```
|
||||
@ -1,28 +0,0 @@
|
||||
/**
|
||||
* Fix ugly module-path tags in openapi.json generated by utoipa.
|
||||
* Transforms tags like "crate::agent::provider" -> "agent-provider"
|
||||
*/
|
||||
const fs = require('fs');
const path = require('path');

const openapiPath = path.join(__dirname, '..', 'openapi.json');
const json = JSON.parse(fs.readFileSync(openapiPath, 'utf8'));

// Turn a Rust module path such as "crate::agent::provider" into "agent-provider".
const normalizeTag = (tag) => tag.replace(/^crate::/, '').replace(/::/g, '-');

// Number of operations whose tag list was actually rewritten.
let fixed = 0;

// Fix operation tags
for (const pathItem of Object.values(json.paths || {})) {
  for (const op of Object.values(pathItem || {})) {
    // Path items may hold members that are not operations; skip anything without a tag array.
    if (!op || !Array.isArray(op.tags)) continue;

    const renamed = op.tags.map(normalizeTag);
    // Count an operation only when at least one of its tags changed,
    // so the summary below reports real fixes rather than every tagged operation.
    if (renamed.some((tag, i) => tag !== op.tags[i])) {
      op.tags = renamed;
      fixed++;
    }
  }
}

fs.writeFileSync(openapiPath, JSON.stringify(json, null, 2));
console.log(`Fixed tags in ${fixed} operations`);
|
||||
@ -1,55 +0,0 @@
|
||||
/**
|
||||
* Generate TypeScript axios client from openapi.json using @hey-api/openapi-ts.
|
||||
* Generates into src/client.
|
||||
* Post-processes: injects withCredentials: true and baseURL into the client config.
|
||||
*/
|
||||
const { execFileSync } = require('child_process');
const fs = require('fs');
const path = require('path');

const ROOT = path.join(__dirname, '..');
const CLIENT_DIR = path.join(ROOT, 'src', 'client');
const CLIENT_GEN = path.join(CLIENT_DIR, 'client.gen.ts');

const openapiTsBin = path.join(ROOT, 'node_modules/@hey-api/openapi-ts/bin/run.js');
const openapiJson = path.join(ROOT, 'openapi.json');

/**
 * Replace the first occurrence of `search` in `source` with `replacement`.
 * Warns (and returns `source` unchanged) when `search` is absent, so a change
 * in the generator's output format is noticed instead of silently skipped.
 */
function replaceExact(source, search, replacement, label) {
  if (!source.includes(search)) {
    console.warn(`Warning: ${label} not found in client.gen.ts; generator output may have changed`);
    return source;
  }
  // A function replacement keeps `$` sequences in `replacement` literal.
  return source.replace(search, () => replacement);
}

console.log('Running @hey-api/openapi-ts...');
try {
  // Arguments are passed as an array (no shell), so paths containing spaces,
  // quotes or `$` cannot break or alter the command.
  execFileSync(
    process.execPath,
    [openapiTsBin, '-c', '@hey-api/client-axios', '-i', openapiJson, '-o', CLIENT_DIR],
    {
      cwd: ROOT,
      stdio: 'inherit',
    }
  );
} catch (e) {
  console.error('Generator exited with code:', e.status);
  process.exit(1);
}

// Post-process: inject withCredentials and baseURL into client config
if (fs.existsSync(CLIENT_GEN)) {
  const original = fs.readFileSync(CLIENT_GEN, 'utf8');
  let content = original;

  // Remove unused createConfig import
  content = replaceExact(
    content,
    "import { type ClientOptions, type Config, createClient, createConfig } from './client';",
    "import { type ClientOptions, type Config, createClient } from './client';",
    'createConfig import'
  );

  // Replace the client creation to include withCredentials and baseURL
  content = replaceExact(
    content,
    'export const client = createClient(createConfig<ClientOptions2>());',
    `export const createClientConfig = (override?: Config<ClientOptions2>): Config<ClientOptions2> => {
return {
withCredentials: true,
baseURL: import.meta.env.VITE_API_BASE_URL ?? '',
...override,
};
};
export const client = createClient(createClientConfig());`,
    'client creation statement'
  );

  // Only write and report success when something was actually changed.
  if (content !== original) {
    fs.writeFileSync(CLIENT_GEN, content);
    console.log('Updated client.gen.ts with withCredentials and baseURL');
  }
} else {
  console.warn(`Warning: ${CLIENT_GEN} was not generated; skipping post-processing`);
}

console.log('Done.');
|
||||
@ -1,89 +0,0 @@
|
||||
/**
|
||||
* Generates changelog data file for the frontend.
|
||||
* Run with: node scripts/generate-changelog-data.js
|
||||
*/
|
||||
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
const CHANGELOG_DIR = path.join(__dirname, '..', 'changelog');
|
||||
const OUTPUT_FILE = path.join(__dirname, '..', 'src', 'data', 'changelog-data.ts');
|
||||
|
||||
const LANGUAGES = ['en', 'cn', 'de', 'fr'];
|
||||
|
||||
/**
 * Read a file as UTF-8 text.
 * Returns the file contents, or null when the file cannot be read.
 */
function readFile(filePath) {
  let text = null;
  try {
    text = fs.readFileSync(filePath, 'utf-8');
  } catch {
    // Unreadable or missing file: keep the null result.
  }
  return text;
}
|
||||
|
||||
/**
 * Split an MDX document into its frontmatter title and its body.
 *
 * @param {string} content Raw file contents.
 * @returns {{title: string, body: string}} The `title:` value from the
 *   frontmatter ('' when absent) and the trimmed body. When the document has
 *   no frontmatter, the whole content is returned untouched as the body.
 */
function parseMdx(content) {
  // Accept both LF and CRLF line endings, and a frontmatter block that is not
  // followed by any body text.
  const frontmatterMatch = content.match(/^---\r?\n([\s\S]*?)\r?\n---(?:\r?\n([\s\S]*))?$/);
  if (!frontmatterMatch) {
    return { title: '', body: content };
  }
  const frontmatter = frontmatterMatch[1];
  const body = (frontmatterMatch[2] || '').trim();

  // Anchor to the start of a line so keys such as `subtitle:` are not mistaken
  // for `title:`, and take the whole rest of the line as the value.
  const titleMatch = frontmatter.match(/^[ \t]*title:[ \t]*(.*)$/m);
  let title = titleMatch ? titleMatch[1].trim() : '';

  // Strip one pair of matching surrounding quotes; quotes or apostrophes
  // inside the title (e.g. "What's new") are kept.
  const quoted = title.match(/^(["'])(.*)\1$/);
  if (quoted) {
    title = quoted[2].trim();
  }
  return { title, body };
}
|
||||
|
||||
// Get all unique dates from file names shaped like YYYY-MM-DD-<lang>.mdx
const files = fs.readdirSync(CHANGELOG_DIR);
const dates = [
  ...new Set(
    files
      .map(file => file.match(/^(\d{4}-\d{2}-\d{2})-(\w+)\.mdx$/))
      .filter(Boolean)
      .map(match => match[1])
  ),
];

// Sort dates descending.
// Zero-padded YYYY-MM-DD strings sort chronologically as plain text, and a
// text comparison stays well-defined for names that pass the pattern but are
// not real calendar dates (where `new Date(...)` subtraction would give NaN).
dates.sort((a, b) => (a < b ? 1 : a > b ? -1 : 0));
|
||||
|
||||
// Build the per-language entry lists; a date with no file for a language is skipped.
const data = {};
for (const lang of LANGUAGES) {
  const entries = [];
  for (const date of dates) {
    const content = readFile(path.join(CHANGELOG_DIR, `${date}-${lang}.mdx`));
    if (!content) {
      continue;
    }
    const { title, body } = parseMdx(content);
    entries.push({ date, title, lang, author: 'ZhenYi', body });
  }
  data[lang] = entries;
}
|
||||
|
||||
// Assemble the generated TypeScript module line by line, then write it out.
const serializedData = JSON.stringify(data, null, 2);
const serializedLanguages = JSON.stringify(LANGUAGES);
const tsContent = [
  '// Auto-generated from changelog/*.mdx files',
  '// Run: node scripts/generate-changelog-data.js',
  '',
  'export type ChangelogEntry = {',
  'date: string;',
  'title: string;',
  'lang: string;',
  'author: string;',
  'body: string;',
  '};',
  '',
  `export const CHANGELOG_DATA: Record<string, ChangelogEntry[]> = ${serializedData};`,
  '',
  `export const CHANGELOG_LANGUAGES = ${serializedLanguages};`,
  '',
].join('\n');

fs.writeFileSync(OUTPUT_FILE, tsContent);
console.log(`Generated ${OUTPUT_FILE}`);
|
||||
Loading…
Reference in New Issue
Block a user