refactor(agent): split monolithic service files into specialized modules

Extract agent, compact, embed, task, and modes modules from single
service.rs files into focused sub-modules. Add orao module for
O1-like reasoning loop. Move RigAgentService to rig_tool.rs.
This commit is contained in:
ZhenYi 2026-05-11 17:04:57 +08:00
parent 129aa3dce7
commit d45e9e28f4
35 changed files with 2795 additions and 1862 deletions

View File

@ -1,7 +1,4 @@
//! Agent service using rig's built-in Agent with full feature support.
//!
//! This module provides a higher-level agent service built on rig's Agent,
//! supporting multi-turn conversations, RAG, and built-in streaming.
//! Rig-based agent using rig's built-in Agent with full feature support.
pub mod service;
pub use service::RigAgentService;
pub mod rig_tool;
pub use rig_tool::{AgentResponse, RigAgentService, StreamChunk};

View File

@ -1,23 +1,17 @@
//! Agent service using rig's built-in Agent.
//!
//! This is a complete implementation that leverages rig's Agent for
//! multi-turn reasoning, tool execution, streaming, and token tracking.
use futures::Stream;
use futures::StreamExt;
use rig::{
agent::{AgentBuilder, MultiTurnStreamItem},
client::CompletionClient,
completion::Prompt,
streaming::{StreamingPrompt, StreamedAssistantContent},
streaming::{StreamedAssistantContent, StreamingPrompt},
};
use tokio_stream::wrappers::ReceiverStream;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use crate::client::AiClientConfig;
use crate::error::AgentError;
/// Response from an agent completion (rig's Agent prompt response).
#[derive(Debug)]
pub struct AgentResponse {
pub content: String,
@ -25,12 +19,9 @@ pub struct AgentResponse {
pub output_tokens: u64,
}
/// Streaming chunk from the agent.
#[derive(Debug)]
pub enum StreamChunk {
/// Text delta from the model
Text(String),
/// Final response with aggregated usage
Final {
content: String,
input_tokens: u64,
@ -38,22 +29,19 @@ pub enum StreamChunk {
},
}
/// Service for running agents using rig's built-in Agent.
///
/// Provides both simple prompting and full streaming with automatic
/// tool call handling via rig's native Agent.
pub struct RigAgentService {
config: AiClientConfig,
model_name: String,
}
impl RigAgentService {
/// Create a new RigAgentService.
pub fn new(config: AiClientConfig, model_name: impl Into<String>) -> Self {
Self { config, model_name: model_name.into() }
Self {
config,
model_name: model_name.into(),
}
}
/// Run a single prompt with the agent (single-turn, no tools).
pub async fn prompt(
&self,
system_prompt: &str,
@ -62,9 +50,7 @@ impl RigAgentService {
let client = self.config.build_rig_client();
let model = client.completion_model(&self.model_name);
let agent = AgentBuilder::new(model)
.preamble(system_prompt)
.build();
let agent = AgentBuilder::new(model).preamble(system_prompt).build();
let response = agent
.prompt(user_input)
@ -74,15 +60,11 @@ impl RigAgentService {
Ok(AgentResponse {
content: response.output,
input_tokens: response.total_usage.input_tokens,
output_tokens: response.total_usage.output_tokens,
input_tokens: response.usage.input_tokens,
output_tokens: response.usage.output_tokens,
})
}
/// Run a prompt with tools (supports multi-turn via rig's Agent).
///
/// The agent will automatically handle tool calls by calling rig's
/// ToolDyn implementations with proper argument deserialization.
pub async fn prompt_with_tools(
&self,
system_prompt: &str,
@ -108,16 +90,11 @@ impl RigAgentService {
Ok(AgentResponse {
content: response.output,
input_tokens: response.total_usage.input_tokens,
output_tokens: response.total_usage.output_tokens,
input_tokens: response.usage.input_tokens,
output_tokens: response.usage.output_tokens,
})
}
/// Stream a prompt with the agent using rig's native streaming.
///
/// This returns a proper async stream that yields text chunks as they arrive
/// and a final response chunk with aggregated token usage. Tool calls are
/// handled transparently by rig's Agent.
pub async fn stream_prompt(
&self,
system_prompt: &str,
@ -129,17 +106,10 @@ impl RigAgentService {
let client = self.config.build_rig_client();
let model = client.completion_model(&self.model_name);
let agent = AgentBuilder::new(model)
.preamble(system_prompt)
.build();
let agent = AgentBuilder::new(model).preamble(system_prompt).build();
// stream_prompt().await returns StreamingResult directly (not wrapped in Result)
// StreamingResult is Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem, StreamingError>>>>
let stream: rig::agent::StreamingResult<_> = agent
.stream_prompt(user_input)
.await;
let stream: rig::agent::StreamingResult<_> = agent.stream_prompt(user_input).await;
// Bridge the rig stream to our channel-based stream
let (tx, rx) = mpsc::channel::<std::result::Result<StreamChunk, AgentError>>(100);
tokio::spawn(async move {
@ -152,12 +122,14 @@ impl RigAgentService {
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
let cleaned = text.text.replace('\n', "");
let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await;
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
final_content.push_str(&text.text);
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::ToolCall { tool_call, internal_call_id: _ },
StreamedAssistantContent::ToolCall {
tool_call,
internal_call_id: _,
},
)) => {
let args_str = match &tool_call.function.arguments {
serde_json::Value::String(s) => s.clone(),
@ -168,7 +140,6 @@ impl RigAgentService {
args = %args_str,
"rig_agent_streaming_tool_call"
);
// Tool calllint — emitted for observability, rig handles execution internally
}
Ok(MultiTurnStreamItem::StreamUserItem(
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
@ -180,11 +151,13 @@ impl RigAgentService {
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
let _ = tx.send(Ok(StreamChunk::Final {
content: final_content.clone(),
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
})).await;
let _ = tx
.send(Ok(StreamChunk::Final {
content: final_content.clone(),
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
}))
.await;
}
Err(e) => {
let _ = tx.send(Err(AgentError::OpenAi(e.to_string()))).await;
@ -197,10 +170,6 @@ impl RigAgentService {
Ok(ReceiverStream::new(rx))
}
/// Stream a prompt with tools using rig's native streaming.
///
/// Returns a stream thatproperly handles multi-turn tool calls via rig's Agent
/// streaming infrastructure.
pub async fn stream_prompt_with_tools(
&self,
system_prompt: &str,
@ -222,33 +191,30 @@ impl RigAgentService {
let stream = agent
.stream_prompt(user_input)
.with_history(Vec::new())
.with_history(Vec::<rig::completion::Message>::new())
.multi_turn(max_turns)
.await;
let (tx, rx) = mpsc::channel::<std::result::Result<StreamChunk, AgentError>>(100);
let (tx, rx) = mpsc::channel::<Result<StreamChunk, AgentError>>(100);
tokio::spawn(async move {
let mut final_content = String::new();
tokio::pin!(stream);
while let Some(item) = stream.next().await {
match item {
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
let cleaned = text.text.replace('\n', "");
let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await;
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
final_content.push_str(&text.text);
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {
let usage = resp.usage();
let _ = tx.send(Ok(StreamChunk::Final {
content: final_content.clone(),
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
})).await;
let _ = tx
.send(Ok(StreamChunk::Final {
content: final_content.clone(),
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
}))
.await;
}
Err(e) => {
let _ = tx.send(Err(AgentError::OpenAi(e.to_string()))).await;
@ -261,9 +227,8 @@ impl RigAgentService {
Ok(ReceiverStream::new(rx))
}
/// Count tokens in text using tiktoken for the configured model.
pub fn count_tokens(&self, text: &str) -> std::result::Result<usize, AgentError> {
pub fn count_tokens(&self, text: &str) -> Result<usize, AgentError> {
crate::tokent::count_text(text, &self.model_name)
.map_err(|e| AgentError::Internal(e.to_string()))
}
}
}

View File

@ -195,7 +195,7 @@ pub async fn execute_chat_stream(
&messages, model_name, config, temperature, max_tokens,
if tools_enabled { Some(&tools) } else { None }, None,
Arc::new(move |delta| {
let content = delta.to_string().replace('\n', "");
let content = delta.to_string();
let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer });
fut
}),

View File

@ -74,7 +74,7 @@ where
.build();
let stream = agent.stream_prompt(&request.input)
.with_history(Vec::new())
.with_history(Vec::<rig::completion::Message>::new())
.multi_turn(request.max_tool_depth)
.await;
@ -90,12 +90,14 @@ where
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => {
step_count += 1;
let t = text.text;
let cleaned = t.replace('\n', "");
on_chunk(ReactStep::Answer { step: step_count, answer: cleaned }).await;
on_chunk(ReactStep::Answer { step: step_count, answer: t.clone() }).await;
final_content.push_str(&t);
}
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(reasoning))) => {
let reasoning_text = reasoning.reasoning.join("");
let reasoning_text: String = reasoning.content.iter().filter_map(|c| match c {
rig::completion::message::ReasoningContent::Text { text, .. } => Some(text.as_str()),
_ => None,
}).collect::<Vec<_>>().join("");
if !reasoning_text.is_empty() {
step_count += 1;
on_chunk(ReactStep::Thought { step: step_count, thought: reasoning_text }).await;

View File

@ -62,7 +62,7 @@ pub async fn execute_process_stream(
&messages, &model_name, &config, temperature, max_tokens,
if tools_enabled { Some(&tools) } else { None }, None,
Arc::new(move |delta| {
let content = delta.to_string().replace('\n', "");
let content = delta.to_string();
let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer });
fut
}),

View File

@ -160,8 +160,8 @@ fn ai_metrics() -> &'static AiMetrics {
pub(crate) fn to_rig_message(msg: &ChatRequestMessage) -> RigMessage {
match msg.role.as_str() {
"system" => {
// System messages are handled via preamble(), but we still
// need to return something. Return a system message as User for safety.
// System messages are handled via preamble(), not passed as messages.
// We still need to return a valid RigMessage variant.
RigMessage::user(msg.content.as_deref().unwrap_or(""))
}
"user" => {
@ -263,9 +263,6 @@ async fn do_completion<M>(
where
M: CompletionModel<Client = openai::Client>,
{
let mut history: Vec<RigMessage> = messages.iter().map(to_rig_message).collect();
// Extract preamble (first system message) and remove from history
let preamble = messages
.iter()
.find(|m| m.role == "system")
@ -273,12 +270,6 @@ where
.unwrap_or("")
.to_string();
history.retain(|m| !matches!(m, RigMessage::User { .. } | RigMessage::Assistant { .. }));
// For tool_result messages, we need to add them back
// Actually, let's keep the approach: filter out system, add others back
// The rig completion request uses: preamble (system) + messages (conversation)
// For our messages: system → preamble, rest → messages
let non_system: Vec<RigMessage> = messages
.iter()
.filter(|m| m.role != "system")
@ -700,13 +691,15 @@ async fn call_stream_once(
}
}
Ok(StreamedAssistantContent::Reasoning(reasoning)) => {
for part in &reasoning.reasoning {
reasoning_content.push_str(part);
on_reasoning_delta(part).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Thinking,
content: part.clone(),
});
for part in &reasoning.content {
if let rig::completion::message::ReasoningContent::Text { text, .. } = part {
reasoning_content.push_str(text);
on_reasoning_delta(text).await;
chunks.push(StreamChunk {
chunk_type: StreamChunkType::Thinking,
content: text.clone(),
});
}
}
}
Ok(StreamedAssistantContent::ReasoningDelta { reasoning, .. }) => {

View File

@ -0,0 +1,39 @@
use crate::AgentError;
use models::rooms::room_message::{
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
};
use models::Expr;
use sea_orm::*;
impl super::CompactService {
pub async fn fetch_room_messages_secure(
&self,
room_id: uuid::Uuid,
requester_id: uuid::Uuid,
) -> Result<Vec<RoomMessageModel>, AgentError> {
use models::rooms::{RoomAccess, RoomUserState};
RoomMessage::find()
.filter(RmCol::Room.eq(room_id))
.filter(
Condition::any()
.add(Expr::exists(
RoomUserState::find()
.filter(models::rooms::room_user_state::Column::Room.eq(room_id))
.filter(models::rooms::room_user_state::Column::User.eq(requester_id))
.into_query(),
))
.add(Expr::exists(
RoomAccess::find()
.filter(models::rooms::room_access::Column::Room.eq(room_id))
.filter(models::rooms::room_access::Column::User.eq(requester_id))
.into_query(),
)),
)
.order_by_asc(RmCol::Seq)
.limit(10000)
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))
}
}

View File

@ -1,8 +1,32 @@
//! Context compaction for AI sessions and room message history.
pub mod auth_fetch;
pub mod helpers;
pub mod service;
pub mod room_compactor;
pub mod summarizer;
pub mod types;
pub use service::CompactService;
use sea_orm::DatabaseConnection;
pub use types::{CompactConfig, CompactLevel, CompactSummary, MessageSummary, ThresholdResult};
#[derive(Clone)]
pub struct CompactService {
db: DatabaseConnection,
ai_client_config: crate::client::AiClientConfig,
model: String,
}
impl CompactService {
pub fn new(
db: DatabaseConnection,
ai_client_config: crate::client::AiClientConfig,
model: String,
) -> Self {
Self {
db,
ai_client_config,
model,
}
}
}

View File

@ -0,0 +1,180 @@
use models::rooms::room_message::{
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
};
use sea_orm::ColumnTrait;
use sea_orm::{EntityTrait, QueryFilter, QueryOrder, QuerySelect};
use crate::compact::types::CompactLevel;
use crate::tokent::resolve_usage;
use crate::{AgentError, CompactSummary, MessageSummary};
impl super::CompactService {
pub async fn compact_room(
&self,
room_id: uuid::Uuid,
level: CompactLevel,
user_names: Option<std::collections::HashMap<uuid::Uuid, String>>,
requester_id: uuid::Uuid,
context_window_tokens: i32,
compaction_max_summary_ratio: f32,
) -> Result<CompactSummary, AgentError> {
let messages = self
.fetch_room_messages_secure(room_id, requester_id)
.await?;
if messages.is_empty() {
let room_exists = models::rooms::room::Entity::find_by_id(room_id)
.one(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?
.is_some();
if room_exists {
return Err(AgentError::Internal("Access denied or room empty".into()));
} else {
return Err(AgentError::Internal("Room not found".into()));
}
}
let user_ids: Vec<uuid::Uuid> = messages
.iter()
.filter_map(|m| m.sender_id)
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
if messages.len() <= level.retain_count() {
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id: uuid::Uuid::new_v4(),
room_id,
retained,
summary: String::new(),
compacted_at: chrono::Utc::now(),
messages_compressed: 0,
usage: None,
});
}
let retain_count = level.retain_count();
let split_index = messages.len().saturating_sub(retain_count);
let (to_summarize, retained_messages) = messages.split_at(split_index);
let retained: Vec<MessageSummary> = retained_messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
let max_summary_tokens =
(context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
let (summary, remote_usage) = self
.summarize_messages(to_summarize, max_summary_tokens)
.await?;
let summarized_text = to_summarize
.iter()
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n");
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
Ok(CompactSummary {
session_id: uuid::Uuid::new_v4(),
room_id,
retained,
summary,
compacted_at: chrono::Utc::now(),
messages_compressed: to_summarize.len(),
usage: Some(usage),
})
}
pub async fn compact_session(
&self,
session_id: uuid::Uuid,
level: CompactLevel,
user_names: Option<std::collections::HashMap<uuid::Uuid, String>>,
context_window_tokens: i32,
compaction_max_summary_ratio: f32,
) -> Result<CompactSummary, AgentError> {
let messages: Vec<RoomMessageModel> = RoomMessage::find()
.filter(RmCol::Room.eq(session_id))
.order_by_asc(RmCol::Seq)
.limit(10000)
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
if messages.is_empty() {
return Err(AgentError::Internal("session has no messages".into()));
}
let user_ids: Vec<uuid::Uuid> = messages
.iter()
.filter_map(|m| m.sender_id)
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
if messages.len() <= level.retain_count() {
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id,
room_id: uuid::Uuid::nil(),
retained,
summary: String::new(),
compacted_at: chrono::Utc::now(),
messages_compressed: 0,
usage: None,
});
}
let retain_count = level.retain_count();
let split_index = messages.len().saturating_sub(retain_count);
let (to_summarize, retained_messages) = messages.split_at(split_index);
let retained: Vec<MessageSummary> = retained_messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
let max_summary_tokens =
(context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
let (summary, remote_usage) = self
.summarize_messages(to_summarize, max_summary_tokens)
.await?;
let summarized_text = to_summarize
.iter()
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n");
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
Ok(CompactSummary {
session_id,
room_id: uuid::Uuid::nil(),
retained,
summary,
compacted_at: chrono::Utc::now(),
messages_compressed: to_summarize.len(),
usage: Some(usage),
})
}
}

View File

@ -1,327 +0,0 @@
use chrono::Utc;
use models::ColumnTrait;
use models::rooms::room_message::{
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
};
use models::users::user::{Column as UserCol, Entity as User};
use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
use uuid::Uuid;
use crate::client::types::ChatRequestMessage;
use crate::client::AiClientConfig;
use crate::client::call_with_params;
use crate::AgentError;
use crate::compact::types::{CompactLevel, CompactSummary, MessageSummary};
use crate::tokent::{TokenUsage, resolve_usage};
#[derive(Clone)]
pub struct CompactService {
db: DatabaseConnection,
ai_client_config: AiClientConfig,
model: String,
}
impl CompactService {
pub fn new(db: DatabaseConnection, ai_client_config: AiClientConfig, model: String) -> Self {
Self { db, ai_client_config, model }
}
pub async fn compact_room(
&self,
room_id: Uuid,
level: CompactLevel,
user_names: Option<std::collections::HashMap<Uuid, String>>,
requester_id: Uuid,
context_window_tokens: i32,
compaction_max_summary_ratio: f32,
) -> Result<CompactSummary, AgentError> {
// Verify room access at the database level to ensure auth context is enforced.
// Public rooms are accessible to project members.
// For simplicity in this audit fix, we'll fetch only if access exists.
let messages = self.fetch_room_messages_secure(room_id, requester_id).await?;
if messages.is_empty() {
// Check if room actually exists or if it's just empty/inaccessible
let room_exists = models::rooms::room::Entity::find_by_id(room_id)
.one(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?
.is_some();
if room_exists {
return Err(AgentError::Internal("Access denied or room empty".into()));
} else {
return Err(AgentError::Internal("Room not found".into()));
}
}
let user_ids: Vec<Uuid> = messages
.iter()
.filter_map(|m| m.sender_id)
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
if messages.len() <= level.retain_count() {
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary: String::new(),
compacted_at: Utc::now(),
messages_compressed: 0,
usage: None,
});
}
let retain_count = level.retain_count();
let split_index = messages.len().saturating_sub(retain_count);
let (to_summarize, retained_messages) = messages.split_at(split_index);
let retained: Vec<MessageSummary> = retained_messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
// Build text of what was summarized (for tiktoken fallback)
let summarized_text = to_summarize
.iter()
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n");
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary,
compacted_at: Utc::now(),
messages_compressed: to_summarize.len(),
usage: Some(usage),
})
}
pub async fn compact_session(
&self,
session_id: Uuid,
level: CompactLevel,
user_names: Option<std::collections::HashMap<Uuid, String>>,
context_window_tokens: i32,
compaction_max_summary_ratio: f32,
) -> Result<CompactSummary, AgentError> {
let messages: Vec<RoomMessageModel> = RoomMessage::find()
.filter(RmCol::Room.eq(session_id))
.order_by_asc(RmCol::Seq)
.limit(10000)
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
if messages.is_empty() {
return Err(AgentError::Internal("session has no messages".into()));
}
let user_ids: Vec<Uuid> = messages
.iter()
.filter_map(|m| m.sender_id)
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
if messages.len() <= level.retain_count() {
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id,
room_id: Uuid::nil(),
retained,
summary: String::new(),
compacted_at: Utc::now(),
messages_compressed: 0,
usage: None,
});
}
let retain_count = level.retain_count();
let split_index = messages.len().saturating_sub(retain_count);
let (to_summarize, retained_messages) = messages.split_at(split_index);
let retained: Vec<MessageSummary> = retained_messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
let summarized_text = to_summarize
.iter()
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n");
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
Ok(CompactSummary {
session_id,
room_id: Uuid::nil(),
retained,
summary,
compacted_at: Utc::now(),
messages_compressed: to_summarize.len(),
usage: Some(usage),
})
}
async fn fetch_room_messages_secure(
&self,
room_id: Uuid,
requester_id: Uuid,
) -> Result<Vec<RoomMessageModel>, AgentError> {
use models::rooms::{RoomUserState, RoomAccess};
use sea_orm::QueryTrait;
use sea_orm::sea_query::Expr;
// Find messages for the room where the requester has access.
// We check both the room_user_state table (membership) and the room_access table (explicit grants).
RoomMessage::find()
.filter(RmCol::Room.eq(room_id))
.filter(
sea_orm::Condition::any()
.add(
Expr::exists(
RoomUserState::find()
.filter(models::rooms::room_user_state::Column::Room.eq(room_id))
.filter(models::rooms::room_user_state::Column::User.eq(requester_id))
.into_query()
)
)
.add(
Expr::exists(
RoomAccess::find()
.filter(models::rooms::room_access::Column::Room.eq(room_id))
.filter(models::rooms::room_access::Column::User.eq(requester_id))
.into_query()
)
)
)
.order_by_asc(RmCol::Seq)
.limit(10000)
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))
}
fn message_to_summary(m: &RoomMessageModel, user_name_map: &std::collections::HashMap<Uuid, String>) -> MessageSummary {
let sender_name = if let Some(user_id) = m.sender_id {
user_name_map.get(&user_id).cloned().unwrap_or_else(|| m.sender_type.to_string())
} else {
m.sender_type.to_string()
};
MessageSummary {
id: m.id,
sender_type: m.sender_type.clone(),
sender_id: m.sender_id,
sender_name,
content: m.content.clone(),
content_type: m.content_type.clone(),
tool_call_id: None,
send_at: m.send_at,
}
}
async fn get_user_name_map(
&self,
user_ids: &[Uuid],
) -> Result<std::collections::HashMap<Uuid, String>, AgentError> {
use std::collections::HashMap;
let mut map = HashMap::new();
if !user_ids.is_empty() {
let users = User::find()
.filter(UserCol::Uid.is_in(user_ids.to_vec()))
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
for user in users {
map.insert(user.uid, user.username);
}
}
Ok(map)
}
async fn summarize_messages(
&self,
messages: &[RoomMessageModel],
max_summary_tokens: usize,
) -> Result<(String, Option<TokenUsage>), AgentError> {
let user_ids: Vec<Uuid> = messages
.iter()
.filter_map(|m| m.sender_id)
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let user_name_map = self.get_user_name_map(&user_ids).await?;
let sender_mapper = |m: &RoomMessageModel| {
if let Some(user_id) = m.sender_id {
if let Some(username) = user_name_map.get(&user_id) {
return username.clone();
}
}
m.sender_type.to_string()
};
let body = crate::compact::helpers::messages_to_text(messages, sender_mapper);
let user_msg = ChatRequestMessage::user(format!(
"Summarise the following conversation concisely, preserving all key facts, \
decisions, and any pending or in-progress work. \
The summary MUST NOT exceed {} tokens. \
Use this format:\n\n\
**Summary:** <one-paragraph overview>\n\
**Key decisions:** <bullet list or 'none'>\n\
**Open items:** <bullet list or 'none'>\n\n\
Conversation:\n\n{}",
max_summary_tokens,
body
));
let response = call_with_params(
&[user_msg],
&self.model,
&self.ai_client_config,
0.3,
2048,
None,
None,
None,
)
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
let remote_usage =
TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);
Ok((response.content, remote_usage))
}
}

View File

@ -0,0 +1,110 @@
use models::rooms::room_message::Model as RoomMessageModel;
use models::users::user::{Column as UserCol, Entity as User};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
use crate::client::call_with_params;
use crate::client::types::ChatRequestMessage;
use crate::compact::types::MessageSummary;
use crate::tokent::TokenUsage;
use crate::AgentError;
impl super::CompactService {
pub async fn summarize_messages(
&self,
messages: &[RoomMessageModel],
max_summary_tokens: usize,
) -> Result<(String, Option<TokenUsage>), AgentError> {
let user_ids: Vec<uuid::Uuid> = messages
.iter()
.filter_map(|m| m.sender_id)
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let user_name_map = self.get_user_name_map(&user_ids).await?;
let sender_mapper = |m: &RoomMessageModel| {
if let Some(user_id) = m.sender_id {
if let Some(username) = user_name_map.get(&user_id) {
return username.clone();
}
}
m.sender_type.to_string()
};
let body = crate::compact::helpers::messages_to_text(messages, sender_mapper);
let user_msg = ChatRequestMessage::user(format!(
"Summarise the following conversation concisely, preserving all key facts, \
decisions, and any pending or in-progress work. \
The summary MUST NOT exceed {} tokens. \
Use this format:\n\n\
**Summary:** <one-paragraph overview>\n\
**Key decisions:** <bullet list or 'none'>\n\
**Open items:** <bullet list or 'none'>\n\n\
Conversation:\n\n{}",
max_summary_tokens, body
));
let response = call_with_params(
&[user_msg],
&self.model,
&self.ai_client_config,
0.3,
2048,
None,
None,
None,
)
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
let remote_usage =
TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);
Ok((response.content, remote_usage))
}
pub fn message_to_summary(
m: &RoomMessageModel,
user_name_map: &std::collections::HashMap<uuid::Uuid, String>,
) -> MessageSummary {
let sender_name = if let Some(user_id) = m.sender_id {
user_name_map
.get(&user_id)
.cloned()
.unwrap_or_else(|| m.sender_type.to_string())
} else {
m.sender_type.to_string()
};
MessageSummary {
id: m.id,
sender_type: m.sender_type.clone(),
sender_id: m.sender_id,
sender_name,
content: m.content.clone(),
content_type: m.content_type.clone(),
tool_call_id: None,
send_at: m.send_at,
}
}
pub async fn get_user_name_map(
&self,
user_ids: &[uuid::Uuid],
) -> Result<std::collections::HashMap<uuid::Uuid, String>, AgentError> {
use std::collections::HashMap;
let mut map = HashMap::new();
if !user_ids.is_empty() {
let users = User::find()
.filter(UserCol::Uid.is_in(user_ids.to_vec()))
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
for user in users {
map.insert(user.uid, user.username);
}
}
Ok(map)
}
}

61
libs/agent/embed/chunk.rs Normal file
View File

@ -0,0 +1,61 @@
/// Maximum characters per chunk for embedding (approximates token limit).
/// text-embedding-3-small: 8192 token limit.
/// For CJK ~1 char/token, for English ~4 chars/token.
/// Conservative limit: 7000 chars to leave room for all languages.
const MAX_CHUNK_CHARS: usize = 7000;
/// Split long text into chunks at paragraph/sentence boundaries.
/// Returns at least one chunk even for empty text.
/// Safe for multi-byte characters (uses char indices, not byte indices).
pub fn chunk_text(text: &str) -> Vec<String> {
if text.is_empty() {
return vec![String::new()];
}
if text.len() <= MAX_CHUNK_CHARS {
return vec![text.to_string()];
}
let char_indices: Vec<usize> = text.char_indices().map(|(i, _)| i).collect();
let total_chars = char_indices.len();
let mut chunks = Vec::new();
let mut start_idx = 0;
while start_idx < total_chars {
let byte_start = char_indices[start_idx];
let end_char_idx = (start_idx + MAX_CHUNK_CHARS).min(total_chars);
let byte_end_candidate = char_indices[end_char_idx - 1]
+ text[char_indices[end_char_idx - 1]..]
.chars()
.next()
.map(|c| c.len_utf8())
.unwrap_or(1);
if end_char_idx >= total_chars {
chunks.push(text[byte_start..].to_string());
break;
}
let search_range = &text[byte_start..byte_end_candidate];
let break_at = search_range.rfind("\n\n").map(|pos| pos + 2)
.or_else(|| search_range.rfind('\n').map(|pos| pos + 1))
.or_else(|| search_range.rfind(". ").map(|pos| pos + 1))
.or_else(|| search_range.rfind("! ").map(|pos| pos + 1))
.or_else(|| search_range.rfind("? ").map(|pos| pos + 1));
if let Some(offset) = break_at {
let byte_end = byte_start + offset;
chunks.push(text[byte_start..byte_end].to_string());
let mut advance = start_idx + 1;
while advance < total_chars && char_indices[advance] < byte_end {
advance += 1;
}
start_idx = advance;
} else {
chunks.push(text[byte_start..byte_end_candidate].to_string());
start_idx = end_char_idx;
}
}
chunks
}

View File

@ -0,0 +1,23 @@
use async_trait::async_trait;
/// Trait for entities that can be embedded as vectors into Qdrant.
#[async_trait]
pub trait Embeddable {
fn entity_type(&self) -> &'static str;
fn to_text(&self) -> String;
fn entity_id(&self) -> String;
}
/// Input struct for batch memory embedding into per-room Qdrant collections.
#[derive(Debug, Clone)]
pub struct EmbedMemoryInput {
pub message_id: String,
pub content: String,
pub project_name: String,
pub room_id: String,
pub user_id: Option<String>,
pub sender_type: String,
}
/// Input struct for batch tag embedding.
pub use models::TagEmbedInput;

View File

@ -1,112 +1,11 @@
use async_trait::async_trait;
use qdrant_client::qdrant::Filter;
use sea_orm::DatabaseConnection;
use std::sync::Arc;
use std::collections::HashMap;
use super::client::{EmbedClient, EmbedPayload, EmbedVector, SearchResult};
/// Maximum characters per chunk for embedding (approximates token limit).
/// text-embedding-3-small: 8192 token limit.
/// For CJK ~1 char/token, for English ~4 chars/token.
/// Conservative limit: 7000 chars to leave room for all languages.
const MAX_CHUNK_CHARS: usize = 7000;
#[async_trait]
pub trait Embeddable {
fn entity_type(&self) -> &'static str;
fn to_text(&self) -> String;
fn entity_id(&self) -> String;
}
/// Split long text into chunks at paragraph/sentence boundaries.
/// Returns at least one chunk even for empty text.
/// Safe for multi-byte characters (uses char indices, not byte indices).
fn chunk_text(text: &str) -> Vec<String> {
if text.is_empty() {
return vec![String::new()];
}
if text.len() <= MAX_CHUNK_CHARS {
return vec![text.to_string()];
}
// Collect char boundary byte positions
let char_indices: Vec<usize> = text.char_indices().map(|(i, _)| i).collect();
let total_chars = char_indices.len();
let mut chunks = Vec::new();
let mut start_idx = 0; // char index
while start_idx < total_chars {
// Start byte offset
let byte_start = char_indices[start_idx];
// Find end char index: at most MAX_CHUNK_CHARS characters
let end_char_idx = (start_idx + MAX_CHUNK_CHARS).min(total_chars);
let byte_end_candidate = char_indices[end_char_idx - 1] + text[char_indices[end_char_idx - 1]..].chars().next().map(|c| c.len_utf8()).unwrap_or(1);
if end_char_idx >= total_chars {
chunks.push(text[byte_start..].to_string());
break;
}
// Try to break at paragraph or sentence boundary in the allowed range
let search_range = &text[byte_start..byte_end_candidate];
let break_at = if let Some(pos) = search_range.rfind("\n\n") {
Some(pos + 2) // after the paragraph break
} else if let Some(pos) = search_range.rfind('\n') {
Some(pos + 1)
} else if let Some(pos) = search_range.rfind(". ") {
Some(pos + 1)
} else if let Some(pos) = search_range.rfind("! ") {
Some(pos + 1)
} else if let Some(pos) = search_range.rfind("? ") {
Some(pos + 1)
} else {
None
};
if let Some(offset) = break_at {
let byte_end = byte_start + offset;
chunks.push(text[byte_start..byte_end].to_string());
// Advance char index to match the byte break
let mut advance = start_idx + 1;
while advance < total_chars && char_indices[advance] < byte_end {
advance += 1;
}
start_idx = advance;
} else {
// Hard break at char boundary
chunks.push(text[byte_start..byte_end_candidate].to_string());
start_idx = end_char_idx;
}
}
chunks
}
#[derive(Clone)]
pub struct EmbedService {
client: Arc<EmbedClient>,
db: DatabaseConnection,
model_name: String,
dimensions: u64,
}
impl EmbedService {
pub fn new(
client: EmbedClient,
db: DatabaseConnection,
model_name: String,
dimensions: u64,
) -> Self {
Self {
client: Arc::new(client),
db,
model_name,
dimensions,
}
}
use super::chunk::chunk_text;
use super::client::{EmbedPayload, EmbedVector};
use super::embeddable::{EmbedMemoryInput, Embeddable};
/// Embedding and upsert operations for entity vectors in Qdrant.
impl super::EmbedService {
pub async fn embed_issue(
&self,
id: &str,
@ -203,69 +102,6 @@ impl EmbedService {
Ok(())
}
pub async fn search_issues(
&self,
query: &str,
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
self.client
.search(query, "issue", &self.model_name, limit)
.await
}
pub async fn search_repos(
&self,
query: &str,
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
self.client
.search(query, "repo", &self.model_name, limit)
.await
}
pub async fn search_issues_filtered(
&self,
query: &str,
limit: usize,
filter: Filter,
) -> crate::Result<Vec<SearchResult>> {
self.client
.search_with_filter(query, "issue", &self.model_name, limit, filter)
.await
}
pub async fn delete_issue_embedding(&self, issue_id: &str) -> crate::Result<()> {
self.client.delete_by_entity_id("issue", issue_id).await
}
pub async fn delete_repo_embedding(&self, repo_id: &str) -> crate::Result<()> {
self.client.delete_by_entity_id("repo", repo_id).await
}
pub async fn ensure_collections(&self) -> crate::Result<()> {
self.client
.ensure_collection("issue", self.dimensions)
.await?;
self.client
.ensure_collection("repo", self.dimensions)
.await?;
self.client.ensure_skill_collection(self.dimensions).await?;
self.client
.ensure_collection("repo_tag", self.dimensions)
.await?;
// Room memory collections are created per-room on first embed
Ok(())
}
pub fn db(&self) -> &DatabaseConnection {
&self.db
}
pub fn client(&self) -> &Arc<EmbedClient> {
&self.client
}
/// Embed a project skill into Qdrant for vector-based semantic search.
pub async fn embed_skill(
&self,
skill_id: i64,
@ -279,7 +115,6 @@ impl EmbedService {
tracing::debug!(skill_id = %skill_id, name = %name, content_len = content.len(), "embed_skill: starting");
// Auto-chunk long content
let texts = chunk_text(content);
tracing::debug!(skill_id = %skill_id, chunks = texts.len(), "embed_skill: chunked");
@ -288,13 +123,17 @@ impl EmbedService {
.embed_skill(&id, name, desc, content, project_uuid, &self.model_name)
.await?;
} else {
// Multi-chunk: embed each chunk with chunk_index metadata
let full_texts: Vec<String> = texts.iter().map(|t| format!("{}: {} {}", name, desc, t)).collect();
let full_texts: Vec<String> = texts
.iter()
.map(|t| format!("{}: {} {}", name, desc, t))
.collect();
tracing::debug!(skill_id = %skill_id, "embed_skill: calling embed_batch");
let embeddings = self.client.embed_batch(&full_texts, &self.model_name).await?;
let points: Vec<EmbedVector> = embeddings.into_iter().enumerate().map(|(i, vector)| {
EmbedVector {
let points: Vec<EmbedVector> = embeddings
.into_iter()
.enumerate()
.map(|(i, vector)| EmbedVector {
id: format!("{}:chunk:{}", id, i),
vector,
payload: EmbedPayload {
@ -306,10 +145,11 @@ impl EmbedService {
"description": desc,
"chunk_index": i,
"total_chunks": texts.len(),
}).into(),
})
.into(),
},
}
}).collect();
})
.collect();
self.client.upsert(points).await?;
}
@ -317,7 +157,6 @@ impl EmbedService {
Ok(())
}
/// Embed an issue with auto-chunking for long content.
pub async fn embed_issue_chunked(
&self,
id: &str,
@ -336,8 +175,10 @@ impl EmbedService {
let embeddings = self.client.embed_batch(&chunks, &self.model_name).await?;
let points: Vec<EmbedVector> = embeddings.into_iter().enumerate().map(|(i, vector)| {
EmbedVector {
let points: Vec<EmbedVector> = embeddings
.into_iter()
.enumerate()
.map(|(i, vector)| EmbedVector {
id: format!("{}:chunk:{}", id, i),
vector,
payload: EmbedPayload {
@ -347,17 +188,15 @@ impl EmbedService {
extra: serde_json::json!({
"chunk_index": i,
"total_chunks": chunks.len(),
}).into(),
})
.into(),
},
}
}).collect();
})
.collect();
self.client.upsert(points).await
}
/// Batch-embed multiple conversation messages into per-room Qdrant collections.
/// Auto-chunks long messages and filters non-text/system/empty content.
/// Handles all filtering internally: only text-type, non-empty, non-system messages are embedded.
pub async fn embed_memories_batch(
&self,
messages: Vec<EmbedMemoryInput>,
@ -366,8 +205,6 @@ impl EmbedService {
return Ok(());
}
// Group by room collection for batch upsert to reduce Qdrant calls
use std::collections::HashMap;
let mut by_room: HashMap<String, Vec<(EmbedMemoryInput, Vec<String>)>> = HashMap::new();
for msg in messages {
@ -375,15 +212,15 @@ impl EmbedService {
if chunks.is_empty() || chunks.iter().all(|c| c.trim().is_empty()) {
continue;
}
let collection = crate::embed::qdrant::QdrantClient::room_memory_collection_name(
let collection = super::qdrant::QdrantClient::room_memory_collection_name(
&msg.project_name, &msg.room_id,
);
by_room.entry(collection).or_default().push((msg, chunks));
}
for (collection, entries) in &by_room {
// Collect all texts for batch embedding
let all_texts: Vec<String> = entries.iter()
let all_texts: Vec<String> = entries
.iter()
.flat_map(|(_, chunks)| chunks.iter().cloned())
.collect();
@ -393,14 +230,12 @@ impl EmbedService {
let embeddings = self.client.embed_batch(&all_texts, &self.model_name).await?;
// Ensure the room collection exists with correct dimensions
if let Some((first, _)) = entries.first() {
let _ = self.client
.ensure_room_memory_collection(&first.project_name, &first.room_id, self.dimensions)
.await;
}
// Build points: one per chunk
let mut points = Vec::new();
let mut embed_idx = 0;
for (msg, chunks) in entries {
@ -423,9 +258,18 @@ impl EmbedService {
extra: serde_json::json!({
"user_id": msg.user_id,
"sender_type": msg.sender_type,
"chunk_index": if chunks.len() > 1 { Some(chunk_i) } else { None },
"total_chunks": if chunks.len() > 1 { Some(chunks.len()) } else { None },
}).into(),
"chunk_index": if chunks.len() > 1 {
Some(chunk_i)
} else {
None
},
"total_chunks": if chunks.len() > 1 {
Some(chunks.len())
} else {
None
},
})
.into(),
},
});
embed_idx += 1;
@ -440,11 +284,9 @@ impl EmbedService {
Ok(())
}
/// Batch-embed repo tags with project isolation.
/// Each tag stores project_id as entity_id for post-filtering.
pub async fn embed_tags_batch(
&self,
tags: Vec<TagEmbedInput>,
tags: Vec<super::embeddable::TagEmbedInput>,
) -> crate::Result<()> {
if tags.is_empty() {
return Ok(());
@ -494,48 +336,6 @@ impl EmbedService {
self.client.upsert(points).await
}
/// Search repo tags by semantic similarity within a project.
/// Filters by project_id (stored in entity_id) for project isolation.
pub async fn search_tags(
&self,
query: &str,
project_id: &str,
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
let mut results = self
.client
.search(query, "repo_tag", &self.model_name, limit + 1)
.await?;
results.retain(|r| r.payload.entity_id == project_id);
results.truncate(limit);
Ok(results)
}
pub fn model_name(&self) -> &str {
&self.model_name
}
pub fn dimensions(&self) -> u64 {
self.dimensions
}
pub fn embed_client(&self) -> &EmbedClient {
&self.client
}
/// Search skills by semantic similarity within a project.
pub async fn search_skills(
&self,
query: &str,
project_uuid: &str,
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
self.client
.search_skills(query, &self.model_name, project_uuid, limit)
.await
}
/// Embed a conversation message into Qdrant as a memory vector.
pub async fn embed_memory(
&self,
message_id: &str,
@ -548,32 +348,4 @@ impl EmbedService {
.embed_memory(message_id, text, project_name, room_id, user_id, &self.model_name)
.await
}
/// Search past conversation messages by semantic similarity within a room.
pub async fn search_memories(
&self,
query: &str,
project_name: &str,
room_id: &str,
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
self.client
.search_memories(query, &self.model_name, project_name, room_id, limit, self.dimensions)
.await
}
}
/// Input struct for batch memory embedding into per-room Qdrant collections.
#[derive(Debug, Clone)]
pub struct EmbedMemoryInput {
pub message_id: String,
pub content: String,
pub project_name: String,
pub room_id: String,
pub user_id: Option<String>,
pub sender_type: String,
}
/// Input struct for batch tag embedding.
/// Re-exported from models for backward compatibility.
pub use models::TagEmbedInput;
}

View File

@ -1,10 +1,69 @@
pub mod chunk;
pub mod client;
pub mod embeddable;
pub mod entity_embed;
pub mod qdrant;
pub mod service;
pub mod search;
pub use client::{EmbedClient, EmbedPayload, EmbedVector, SearchResult};
pub use embeddable::{EmbedMemoryInput, Embeddable, TagEmbedInput};
pub use qdrant::QdrantClient;
pub use service::{EmbedMemoryInput, EmbedService, Embeddable, TagEmbedInput};
use std::sync::Arc;
#[derive(Clone)]
pub struct EmbedService {
client: Arc<EmbedClient>,
db: sea_orm::DatabaseConnection,
model_name: String,
dimensions: u64,
}
impl EmbedService {
pub fn new(
client: EmbedClient,
db: sea_orm::DatabaseConnection,
model_name: String,
dimensions: u64,
) -> Self {
Self {
client: Arc::new(client),
db,
model_name,
dimensions,
}
}
pub async fn ensure_collections(&self) -> crate::Result<()> {
self.client
.ensure_collection("issue", self.dimensions)
.await?;
self.client
.ensure_collection("repo", self.dimensions)
.await?;
self.client.ensure_skill_collection(self.dimensions).await?;
self.client
.ensure_collection("repo_tag", self.dimensions)
.await?;
Ok(())
}
pub fn db(&self) -> &sea_orm::DatabaseConnection {
&self.db
}
pub fn client(&self) -> &Arc<EmbedClient> {
&self.client
}
pub fn model_name(&self) -> &str {
&self.model_name
}
pub fn dimensions(&self) -> u64 {
self.dimensions
}
}
pub async fn new_embed_client(config: &config::AppConfig) -> crate::Result<EmbedClient> {
let base_url = config
@ -22,7 +81,9 @@ pub async fn new_embed_client(config: &config::AppConfig) -> crate::Result<Embed
.api_key(&api_key)
.base_url(&base_url)
.build()
.map_err(|e| crate::AgentError::Internal(format!("failed to build rig openai client: {}", e)))?;
.map_err(|e| {
crate::AgentError::Internal(format!("failed to build rig openai client: {}", e))
})?;
let qdrant = QdrantClient::new(&qdrant_url, qdrant_api_key.as_deref()).await?;
Ok(EmbedClient::new(openai, qdrant))

View File

@ -0,0 +1,79 @@
use qdrant_client::qdrant::Filter;
use super::client::SearchResult;
/// Vector search operations for Qdrant-backed entity retrieval.
impl super::EmbedService {
pub async fn search_issues(
&self,
query: &str,
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
self.client
.search(query, "issue", &self.model_name, limit)
.await
}
pub async fn search_repos(
&self,
query: &str,
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
self.client
.search(query, "repo", &self.model_name, limit)
.await
}
pub async fn search_issues_filtered(
&self,
query: &str,
limit: usize,
filter: Filter,
) -> crate::Result<Vec<SearchResult>> {
self.client
.search_with_filter(query, "issue", &self.model_name, limit, filter)
.await
}
/// Search repo tags by semantic similarity within a project.
/// Filters by project_id (stored in entity_id) for project isolation.
pub async fn search_tags(
&self,
query: &str,
project_id: &str,
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
let mut results = self
.client
.search(query, "repo_tag", &self.model_name, limit + 1)
.await?;
results.retain(|r| r.payload.entity_id == project_id);
results.truncate(limit);
Ok(results)
}
/// Search skills by semantic similarity within a project.
pub async fn search_skills(
&self,
query: &str,
project_uuid: &str,
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
self.client
.search_skills(query, &self.model_name, project_uuid, limit)
.await
}
/// Search past conversation messages by semantic similarity within a room.
pub async fn search_memories(
&self,
query: &str,
project_name: &str,
room_id: &str,
limit: usize,
) -> crate::Result<Vec<SearchResult>> {
self.client
.search_memories(query, &self.model_name, project_name, room_id, limit, self.dimensions)
.await
}
}

View File

@ -6,6 +6,7 @@ pub mod compact;
pub mod embed;
pub mod error;
pub mod model;
pub mod orao;
pub mod perception;
pub mod react;
pub mod skills;
@ -13,33 +14,42 @@ pub mod sync;
pub mod task;
pub mod tokent;
pub mod tool;
pub use billing::{BillingRecord, BillingResult, record_ai_usage, initialize_user_billing, initialize_project_billing, check_balance, persist_billing_error};
pub use sync::list_accessible_models;
pub use task::TaskService;
pub use tokent::{TokenUsage, resolve_usage};
pub use perception::{PerceptionService, SkillContext, SkillEntry, ToolCallEvent};
pub use skills::{
BuiltInSkill, SKILL_TEMPLATES, all_skill_slugs, all_skills,
get_skill, get_skill_by_tool, is_built_in_skill, match_skill_by_keyword, skills_by_category,
pub use billing::{
check_balance, initialize_project_billing, initialize_user_billing, persist_billing_error,
record_ai_usage, BillingRecord, BillingResult,
};
pub use chat::{
AiContextSenderType, AiRequest, AiStreamChunk, ChatService, Mention, RoomMessageContext,
StreamCallback,
};
pub use client::{AiCallResponse, AiClientConfig, call_with_params, call_with_retry};
pub use client::types::ChatRequestMessage;
pub use client::{call_with_params, call_with_retry, AiCallResponse, AiClientConfig};
pub use compact::{CompactConfig, CompactLevel, CompactService, CompactSummary, MessageSummary};
pub use embed::{
EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult, TagEmbedInput, new_embed_client,
new_embed_client, EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult,
TagEmbedInput,
};
pub use error::{AgentError, Result};
pub use orao::{
ActionExecutor, ActionType, ActionResult, ActionVerdict, OraoConfig, OraoExecutor,
OraoExecutorBuilder, OraoOutcome, OraoStep, PerceptionSnapshot, PlannedAction,
ReasoningOutput, RoundRecord, SafetyLevel,
};
pub use perception::{PerceptionService, SkillContext, SkillEntry, ToolCallEvent};
pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT, ROOM_CONTEXT_PROMPT};
pub use skills::{
all_skill_slugs, all_skills, get_skill, get_skill_by_tool, is_built_in_skill, match_skill_by_keyword,
skills_by_category, BuiltInSkill, SKILL_TEMPLATES,
};
pub use sync::list_accessible_models;
pub use task::TaskService;
pub use tokent::{resolve_usage, TokenUsage};
pub use tool::{
ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,
ToolRegistry, ToolResult, ToolSchema,
ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition,
ToolError, ToolExecutor, ToolHandler, ToolParam, ToolRegistry, ToolResult, ToolSchema,
};
#[cfg(feature = "rig")]
pub use agent::RigAgentService;
#[cfg(feature = "rig")]
pub use tool::{RigToolSet, RecordingTool, is_retryable_tool_error};
pub use tool::{is_retryable_tool_error, RecordingTool, RigToolSet};

View File

@ -1 +0,0 @@
// All reasoning modes removed - using ReAct pattern directly in chat service

203
libs/agent/orao/act.rs Normal file
View File

@ -0,0 +1,203 @@
//! Act phase: execute planned actions with safety checks.
//!
//! Actions are executed through a caller-provided executor callback, which
//! typically dispatches to the [`ToolRegistry`] or runs shell commands.
//! All file access must go through function calls (tools), never direct
//! filesystem operations.
//!
//! [`ToolRegistry`]: crate::tool::ToolRegistry
use std::future::Future;
use std::pin::Pin;
use std::process::Command;
use std::time::Duration;
use super::types::{
ActionType, ActionResult, ActionVerdict, OraoConfig, PlannedAction, SafetyLevel,
};
/// Callback for executing a planned action.
///
/// The caller (service layer) provides this to wire up tool execution.
/// Returns `ActionResult` on completion.
pub type ActionExecutor = Box<
dyn Fn(
PlannedAction,
) -> Pin<Box<dyn Future<Output = ActionResult> + Send>>
+ Send
+ Sync,
>;
/// Check whether an action is allowed under the given safety configuration.
///
/// Returns `None` if allowed, or `Some(reason)` if blocked.
pub fn check_safety(action: &PlannedAction, config: &OraoConfig) -> Option<String> {
let safety = SafetyLevel::classify_command(&action.command_or_content);
if safety > config.max_safety_level {
return Some(format!(
"Action denied: safety level {:?} exceeds max allowed {:?}",
safety, config.max_safety_level
));
}
// Check for dangerous command patterns
if let Some(reason) = check_dangerous_command(&action.command_or_content) {
return Some(reason);
}
None
}
/// Execute a single planned action via the provided executor.
///
/// Applies safety checks and timeout, then delegates to the executor.
pub async fn execute_action(
action: PlannedAction,
config: &OraoConfig,
executor: &ActionExecutor,
) -> ActionResult {
// ── Safety gate ────────────────────────────────────────────────────
if let Some(reason) = check_safety(&action, config) {
return ActionResult {
action,
exit_code: Some(1),
stdout: String::new(),
stderr: reason,
file_changes: Vec::new(),
verdict: ActionVerdict::Failure,
};
}
// ── Execute with timeout ──────────────────────────────────────────
let action_clone = action.clone();
let exec_future = executor(action);
match tokio::time::timeout(
Duration::from_secs(config.action_timeout_secs),
exec_future,
)
.await
{
Ok(result) => result,
Err(_elapsed) => ActionResult {
action: action_clone,
exit_code: None,
stdout: String::new(),
stderr: format!(
"Action timed out after {} seconds",
config.action_timeout_secs
),
file_changes: Vec::new(),
verdict: ActionVerdict::Failure,
},
}
}
/// Build a default action executor that runs shell commands directly.
///
/// This is suitable for `shell_command` and `git_operation` action types.
/// For `tool_invoke`, the caller should provide a custom executor that
/// dispatches to the [`ToolRegistry`].
///
/// [`ToolRegistry`]: crate::tool::ToolRegistry
pub fn shell_executor(working_dir: String) -> ActionExecutor {
Box::new(move |action: PlannedAction| {
let dir = working_dir.clone();
Box::pin(async move {
match action.action_type {
ActionType::ShellCommand
| ActionType::GitOperation
| ActionType::ToolInvoke => run_shell_command(&action, &dir).await,
ActionType::FileWrite | ActionType::FileEdit => {
// File operations should use tool_invoke with a file-writing tool.
// Direct file access is discouraged; return an error directing to tools.
ActionResult {
exit_code: Some(1),
stdout: String::new(),
stderr: "File operations must use tool_invoke with registered file tools. Use shell_command with sed/echo for inline edits.".to_string(),
file_changes: Vec::new(),
verdict: ActionVerdict::Failure,
action,
}
}
ActionType::UserDialog => ActionResult {
exit_code: None,
stdout: "User dialog requested".to_string(),
stderr: String::new(),
file_changes: Vec::new(),
verdict: ActionVerdict::Success,
action,
},
}
})
})
}
async fn run_shell_command(action: &PlannedAction, working_dir: &str) -> ActionResult {
let cmd = &action.command_or_content;
let output = Command::new("sh")
.args(["-c", cmd])
.current_dir(working_dir)
.output();
match output {
Ok(out) => {
let exit_code = out.status.code();
let stdout = String::from_utf8_lossy(&out.stdout).to_string();
let stderr = String::from_utf8_lossy(&out.stderr).to_string();
let verdict = match exit_code {
Some(0) if !stderr_has_errors(&stderr) => ActionVerdict::Success,
Some(0) => ActionVerdict::SuccessWithWarnings,
_ => ActionVerdict::Failure,
};
ActionResult {
action: action.clone(),
exit_code,
stdout,
stderr,
file_changes: Vec::new(),
verdict,
}
}
Err(e) => ActionResult {
action: action.clone(),
exit_code: None,
stdout: String::new(),
stderr: format!("Failed to spawn command: {}", e),
file_changes: Vec::new(),
verdict: ActionVerdict::Failure,
},
}
}
fn stderr_has_errors(stderr: &str) -> bool {
let lower = stderr.to_lowercase();
lower.contains("error") || lower.contains("fail") || lower.contains("panic")
}
/// Check whether a shell command contains dangerous patterns.
///
/// Returns `Some(reason)` if the command is blocked, `None` if it's safe.
pub fn check_dangerous_command(cmd: &str) -> Option<String> {
let dangerous = [
("rm -rf /", "Recursive root deletion"),
("rm -rf ~", "Recursive home deletion"),
(":(){ :|:& };:", "Fork bomb"),
("mkfs.", "Filesystem format"),
("dd if=", "Raw device write"),
("> /dev/sda", "Raw device write"),
("chmod 777 /", "World-writable root"),
];
for (pattern, reason) in &dangerous {
if cmd.contains(pattern) {
return Some(format!("Blocked: {}{}", pattern, reason));
}
}
None
}

427
libs/agent/orao/mod.rs Normal file
View File

@ -0,0 +1,427 @@
//! ORAO (ObserveReasonActObserve) — a single-agent loop for complex engineering tasks.
//!
//! ORAO extends the ReAct paradigm with:
//! - **Multi-channel perception**: LLM-driven observation via read-only tools
//! - **Structured reasoning**: analysis + step-by-step action plan
//! - **Safety levels**: L0L4 permission grading for every action
//! - **Deadlock detection**: terminates after 3 rounds with no progress
//! - **Plan mode**: optional user-approval gate before execution
//! - **Round recording**: full audit trail for debugging and resumption
//!
//! # Architecture
//!
//! The [`OraoExecutor`] runs the O→R→A→O loop:
//! 1. **Observe** — LLM explores environment via observation tools, produces snapshot
//! 2. **Reason** — LLM analyzes snapshot, generates structured plan
//! 3. **Act** — Execute each planned action via [`ActionExecutor`] with safety checks
//! 4. **Observe** — Collect results, feed into next round
//!
//! All file access goes through function calls (tools), never direct filesystem operations.
//!
//! [`ActionExecutor`]: act::ActionExecutor
pub mod act;
pub mod observe;
pub mod reason;
pub mod types;
use std::time::Instant;
use crate::client::AiClientConfig;
use crate::error::{AgentError, Result};
pub use act::ActionExecutor;
pub use types::{
ActionResult, ActionType, ActionVerdict, FileChange, FileChangeType, OraoConfig, OraoStep,
PerceptionSnapshot, PlannedAction, ReasoningOutput, RoundRecord, SafetyLevel,
};
// ── ORAO Executor ───────────────────────────────────────────────────────────
/// Executes the ORAO loop for a single task.
///
/// All environment interaction goes through:
/// - **Observation tools** (read-only) for the Observe phase
/// - **Action executor** callback for the Act phase
///
/// No direct filesystem access — everything is mediated through function calls.
pub struct OraoExecutor {
config: AiClientConfig,
model_name: String,
action_executor: ActionExecutor,
}
impl OraoExecutor {
/// Create a new ORAO executor.
///
/// `action_executor` is called to execute each planned action. Wire it to
/// your [`ToolRegistry`] for tool-based execution, or use
/// [`act::shell_executor`] for simple shell-command execution.
///
/// [`ToolRegistry`]: crate::tool::ToolRegistry
pub fn new(
config: AiClientConfig,
model_name: impl Into<String>,
action_executor: ActionExecutor,
) -> Self {
Self {
config,
model_name: model_name.into(),
action_executor,
}
}
/// Run the ORAO loop to completion.
///
/// # Parameters
/// - `task_goal`: Description of what to accomplish.
/// - `orao_config`: ORAO-specific settings (max rounds, safety level, etc.).
/// - `tool_factory`: Called each round to produce read-only observation tools
/// (e.g. `git_diff`, `git_blob`, `repo_search`, `git_grep`). This allows
/// callers to provide fresh tool instances each round.
/// - `on_step`: Called with each [`OraoStep`] event for streaming/persistence.
/// - `on_plan_approval`: Called in plan mode; return `true` to proceed.
pub async fn execute<C, Fut, PA, PAFut, TF>(
&self,
task_goal: &str,
orao_config: &OraoConfig,
tool_factory: TF,
on_step: C,
on_plan_approval: PA,
) -> Result<OraoOutcome>
where
C: Fn(OraoStep) -> Fut + Send,
Fut: Future<Output = ()> + Send,
PA: Fn(ReasoningOutput) -> PAFut + Send,
PAFut: Future<Output = bool> + Send,
TF: Fn() -> Vec<Box<dyn rig::tool::ToolDyn + 'static>> + Send + Sync,
{
let mut round = 0usize;
let mut round_records: Vec<RoundRecord> = Vec::new();
let mut previous_result: Option<ActionResult> = None;
let mut previous_snapshot: Option<PerceptionSnapshot> = None;
let mut no_change_count: usize = 0;
// Observation turns: limit tool calls during exploration
let observe_max_turns = 10;
loop {
round += 1;
let round_start = Instant::now();
let round_input_tokens: u64 = 0;
let round_output_tokens: u64 = 0;
// ── Phase 1: Observe ───────────────────────────────────────
let snapshot = observe::observe(
&self.config,
&self.model_name,
task_goal,
previous_result.take(),
tool_factory(),
observe_max_turns,
)
.await?;
on_step(OraoStep::Observe {
round,
snapshot: snapshot.clone(),
})
.await;
// ── Deadlock detection ─────────────────────────────────────
if let Some(ref prev) = previous_snapshot {
if !observe::has_environment_changed(prev, &snapshot) {
no_change_count += 1;
if no_change_count >= orao_config.deadlock_threshold {
let reason = format!(
"Deadlock detected: no environmental change for {} consecutive rounds",
no_change_count
);
on_step(OraoStep::Failed {
total_rounds: round,
reason: reason.clone(),
})
.await;
return Ok(OraoOutcome::Failed {
reason,
rounds: round,
records: round_records,
});
}
} else {
no_change_count = 0;
}
}
previous_snapshot = Some(snapshot.clone());
// ── Phase 2: Reason ────────────────────────────────────────
let reasoning = reason::reason(
&self.config,
&self.model_name,
orao_config,
task_goal,
&snapshot,
round,
&round_records,
)
.await?;
on_step(OraoStep::Reason {
round,
reasoning: reasoning.clone(),
})
.await;
// ── Plan mode gate ─────────────────────────────────────────
if orao_config.plan_mode {
on_step(OraoStep::PlanProposed {
round,
reasoning: reasoning.clone(),
})
.await;
if !on_plan_approval(reasoning.clone()).await {
return Ok(OraoOutcome::Cancelled {
rounds: round,
records: round_records,
});
}
}
// ── Phase 3: Act ───────────────────────────────────────────
let mut round_result: Option<ActionResult> = None;
let mut all_success = true;
for planned in &reasoning.plan {
let safety = SafetyLevel::classify_command(&planned.command_or_content);
on_step(OraoStep::Act {
round,
action: planned.clone(),
safety_level: safety,
})
.await;
let result =
act::execute_action(planned.clone(), orao_config, &self.action_executor).await;
on_step(OraoStep::ObserveResult {
round,
result: result.clone(),
})
.await;
match &result.verdict {
ActionVerdict::Failure => {
all_success = false;
round_result = Some(result);
break; // Stop executing further steps on failure
}
ActionVerdict::SuccessWithWarnings => {
round_result = Some(result);
}
ActionVerdict::Success => {
round_result = Some(result);
}
}
}
// ── Phase 4: Record round ──────────────────────────────────
let duration_ms = round_start.elapsed().as_millis() as u64;
let record = RoundRecord {
round,
observe_summary: summarize_snapshot(&snapshot),
reasoning_summary: reasoning.analysis.clone(),
action: reasoning.plan.first().cloned(),
result_summary: round_result
.as_ref()
.map(|r| format!("{:?}: {}", r.verdict, truncate(&r.stdout, 200))),
tokens_input: round_input_tokens,
tokens_output: round_output_tokens,
duration_ms,
};
round_records.push(record);
// ── Check termination ──────────────────────────────────────
if all_success && !reasoning.plan.is_empty() {
let summary = format!(
"Task completed in {} round(s). Last action: {}",
round,
round_result
.as_ref()
.map(|r| truncate(&r.stdout, 500))
.unwrap_or_default()
);
on_step(OraoStep::Completed {
total_rounds: round,
summary: summary.clone(),
})
.await;
return Ok(OraoOutcome::Completed {
summary,
rounds: round,
records: round_records,
});
}
// Max rounds exceeded
if round >= orao_config.max_rounds {
let reason = format!("Reached max rounds ({})", orao_config.max_rounds);
on_step(OraoStep::Failed {
total_rounds: round,
reason: reason.clone(),
})
.await;
return Ok(OraoOutcome::Failed {
reason,
rounds: round,
records: round_records,
});
}
// Prepare for next round
previous_result = round_result;
}
}
}
// ── Outcome ─────────────────────────────────────────────────────────────────
/// Final outcome of an ORAO execution.
#[derive(Debug, Clone)]
pub enum OraoOutcome {
/// Task completed successfully.
Completed {
summary: String,
rounds: usize,
records: Vec<RoundRecord>,
},
/// Task failed (max rounds, deadlock, or unrecoverable error).
Failed {
reason: String,
rounds: usize,
records: Vec<RoundRecord>,
},
/// User cancelled the task (plan mode rejection or explicit interrupt).
Cancelled {
rounds: usize,
records: Vec<RoundRecord>,
},
}
impl OraoOutcome {
/// Number of rounds executed.
pub fn rounds(&self) -> usize {
match self {
Self::Completed { rounds, .. }
| Self::Failed { rounds, .. }
| Self::Cancelled { rounds, .. } => *rounds,
}
}
/// Whether the task was successful.
pub fn is_success(&self) -> bool {
matches!(self, Self::Completed { .. })
}
/// Round records for audit/debugging.
pub fn records(&self) -> &[RoundRecord] {
match self {
Self::Completed { records, .. }
| Self::Failed { records, .. }
| Self::Cancelled { records, .. } => records,
}
}
}
// ── Helpers ─────────────────────────────────────────────────────────────────
fn summarize_snapshot(snapshot: &PerceptionSnapshot) -> String {
let mut parts: Vec<String> = Vec::new();
if let Some(ref gs) = snapshot.git_status {
let first_line = gs.lines().next().unwrap_or("");
parts.push(format!("git: {}", truncate(first_line, 80)));
}
if !snapshot.files.is_empty() {
parts.push(format!("{} files", snapshot.files.len()));
}
if !snapshot.errors.is_empty() {
parts.push(format!("{} errors", snapshot.errors.len()));
}
if parts.is_empty() {
"no changes".to_string()
} else {
parts.join(", ")
}
}
fn truncate(s: &str, max_len: usize) -> String {
if s.len() <= max_len {
s.to_string()
} else {
format!("{}...", &s[..max_len])
}
}
// ── Convenience builder ─────────────────────────────────────────────────────
/// Builder for [`OraoExecutor`] with chainable configuration.
pub struct OraoExecutorBuilder {
config: Option<AiClientConfig>,
model_name: Option<String>,
action_executor: Option<ActionExecutor>,
}
impl OraoExecutorBuilder {
pub fn new() -> Self {
Self {
config: None,
model_name: None,
action_executor: None,
}
}
pub fn ai_config(mut self, config: AiClientConfig) -> Self {
self.config = Some(config);
self
}
pub fn model(mut self, name: impl Into<String>) -> Self {
self.model_name = Some(name.into());
self
}
pub fn action_executor(mut self, executor: ActionExecutor) -> Self {
self.action_executor = Some(executor);
self
}
pub fn build(self) -> Result<OraoExecutor> {
let config = self.config.ok_or_else(|| AgentError::InvalidInput {
field: "config".to_string(),
reason: "AI client config is required".to_string(),
})?;
let model_name = self.model_name.ok_or_else(|| AgentError::InvalidInput {
field: "model_name".to_string(),
reason: "Model name is required".to_string(),
})?;
let action_executor = self
.action_executor
.ok_or_else(|| AgentError::InvalidInput {
field: "action_executor".to_string(),
reason: "Action executor is required".to_string(),
})?;
Ok(OraoExecutor::new(config, model_name, action_executor))
}
}
impl Default for OraoExecutorBuilder {
fn default() -> Self {
Self::new()
}
}

280
libs/agent/orao/observe.rs Normal file
View File

@ -0,0 +1,280 @@
//! Observe phase: LLM-driven multi-channel environment perception.
//!
//! The Observe phase gives the LLM a set of read-only observation tools and
//! instructs it to explore the environment. All file/git/system access goes
//! through function calls (tools), never direct filesystem operations.
//!
//! After exploration, the LLM produces a structured [`PerceptionSnapshot`]
//! summarizing the current state of the project.
use rig::agent::AgentBuilder;
use rig::client::CompletionClient;
use rig::completion::Prompt;
use crate::client::AiClientConfig;
use crate::error::AgentError;
use super::types::{ActionResult, PerceptionSnapshot};
/// Prompt for the ORAO Observe phase.
const OBSERVE_SYSTEM_PROMPT: &str = r#"You are an expert software engineering agent using the ORAO (Observe-Reason-Act-Observe) framework.
## Your Role: OBSERVE Phase
You are currently in the OBSERVE phase. Your task is to explore the project environment
and gather all relevant information using the available tools.
## What to Observe
Use the tools provided to you to check:
1. **Git status**: What branch are we on? What files have changed? Any uncommitted work?
2. **Project structure**: What directories and key files exist?
3. **Code content**: Read relevant source files to understand the codebase state.
4. **Errors/warnings**: Check build output, test results, linter output for issues.
5. **Configuration**: Check project config files (Cargo.toml, package.json, etc.) if relevant.
## Rules
- Use tools to explore do NOT guess or assume file contents.
- Focus on information relevant to the task at hand.
- Be thorough but efficient: 3-8 tool calls is typical.
- After gathering information, summarize your findings clearly.
## Output Format
After you have finished observing, provide a summary with these sections:
### Git Status
[Current branch, changed files, commit status]
### Project Structure
[Key directories and files relevant to the task]
### Key Files
[Important files you read, with brief notes on their content]
### Errors / Issues
[Any errors, warnings, or problems detected]
### Previous Action Result
[If a previous action was executed, describe its outcome]"#;
/// Run the Observe phase: let the LLM explore the environment via tools.
///
/// Returns a structured [`PerceptionSnapshot`] built from the LLM's observations.
/// All environment access goes through the provided `tools` — no direct
/// filesystem operations.
///
/// Takes ownership of `tools` (caller must clone if they need to reuse them).
pub async fn observe(
config: &AiClientConfig,
model_name: &str,
task_goal: &str,
previous_result: Option<ActionResult>,
tools: Vec<Box<dyn rig::tool::ToolDyn + 'static>>,
max_turns: usize,
) -> Result<PerceptionSnapshot, AgentError> {
let user_prompt = build_observe_prompt(task_goal, previous_result.as_ref());
let client = config.build_rig_client();
let model = client.completion_model(model_name);
let agent = AgentBuilder::new(model)
.preamble(OBSERVE_SYSTEM_PROMPT)
.tools(tools)
.default_max_turns(max_turns)
.build();
let response = agent
.prompt(&user_prompt)
.max_turns(max_turns)
.extended_details()
.await
.map_err(|e: rig::completion::PromptError| AgentError::OpenAi(e.to_string()))?;
// Build snapshot from the LLM's final summary
let summary = response.output;
let snapshot = parse_observation_summary(&summary, previous_result);
Ok(snapshot)
}
/// Build the user prompt for the Observe phase.
fn build_observe_prompt(task_goal: &str, previous_result: Option<&ActionResult>) -> String {
let mut prompt = format!(
"## Task Goal\n\n{}\n\n## Instructions\n\n\
Explore the project environment using the available tools. \
Gather all information relevant to the task above. \
After you have gathered sufficient information, provide a structured summary.",
task_goal
);
if let Some(prev) = previous_result {
prompt.push_str(&format!(
"\n\n## Previous Action Result\n\n\
- Action: {}\n\
- Verdict: {:?}\n\
- Exit code: {:?}\n\
- stdout: {}\n\
- stderr: {}",
prev.action.description,
prev.verdict,
prev.exit_code,
truncate_str(&prev.stdout, 2000),
truncate_str(&prev.stderr, 2000),
));
}
prompt
}
/// Parse the LLM's observation summary into a structured snapshot.
fn parse_observation_summary(
summary: &str,
previous_result: Option<ActionResult>,
) -> PerceptionSnapshot {
let mut snapshot = PerceptionSnapshot::default();
// Extract sections from the markdown summary
let mut current_section = "";
let mut section_content: Vec<&str> = Vec::new();
for line in summary.lines() {
if line.starts_with("### ") {
// Save previous section
store_section(&mut snapshot, current_section, &section_content);
current_section = line.trim_start_matches("### ").trim();
section_content.clear();
} else {
section_content.push(line);
}
}
// Save last section
store_section(&mut snapshot, current_section, &section_content);
snapshot.previous_action_result = previous_result;
// If no structured data was parsed, store the raw summary
if snapshot.git_status.is_none()
&& snapshot.project_structure.is_none()
&& snapshot.files.is_empty()
&& snapshot.errors.is_empty()
{
snapshot.notes.insert(
"raw_observation".to_string(),
summary.to_string(),
);
}
snapshot
}
fn store_section(snapshot: &mut PerceptionSnapshot, section: &str, content: &[&str]) {
let text = content.join("\n").trim().to_string();
if text.is_empty() {
return;
}
match section.to_lowercase().as_str() {
s if s.contains("git") => {
snapshot.git_status = Some(text);
}
s if s.contains("project") && s.contains("structure") => {
snapshot.project_structure = Some(text);
}
s if s.contains("file") => {
// Parse file references from the text
for line in content {
let line = line.trim();
if let Some(path) = extract_file_path(line) {
snapshot.files.push(super::types::PerceivedFile {
path,
size_bytes: 0,
content_preview: None,
});
}
}
}
s if s.contains("error") || s.contains("issue") || s.contains("warning") => {
for line in content {
let line = line.trim();
if !line.is_empty() && !line.starts_with('#') {
snapshot.errors.push(line.to_string());
}
}
}
_ => {
// Store unknown sections as notes
snapshot
.notes
.insert(section.to_string(), text);
}
}
}
/// Extract a file path from a markdown list item or code reference.
fn extract_file_path(line: &str) -> Option<String> {
// Match patterns like: - `src/main.rs` or - src/main.rs or `src/main.rs`
let line = line.trim();
// Backtick-wrapped path
if let Some(start) = line.find('`') {
let rest = &line[start + 1..];
if let Some(end) = rest.find('`') {
let path = rest[..end].to_string();
if path.contains('.') || path.contains('/') || path.contains('\\') {
return Some(path);
}
}
}
// Bare path pattern (word chars, slashes, dots)
if line.starts_with('-') || line.starts_with('*') {
let rest = line.trim_start_matches(&['-', '*', ' ']);
if rest.contains('/') || (rest.contains('.') && !rest.starts_with("http")) {
return Some(rest.to_string());
}
}
None
}
fn truncate_str(s: &str, max_len: usize) -> String {
if s.len() <= max_len {
s.to_string()
} else {
format!("{}...", &s[..max_len])
}
}
/// Determine whether the environment has changed since the last snapshot.
///
/// Used for deadlock detection: if 3 consecutive rounds show no change,
/// the loop is terminated.
pub fn has_environment_changed(
previous: &PerceptionSnapshot,
current: &PerceptionSnapshot,
) -> bool {
if previous.git_status != current.git_status {
return true;
}
let prev_files: Vec<&str> = previous.files.iter().map(|f| f.path.as_str()).collect();
let curr_files: Vec<&str> = current.files.iter().map(|f| f.path.as_str()).collect();
if prev_files != curr_files {
return true;
}
if previous.errors != current.errors {
return true;
}
let prev_has_result = previous.previous_action_result.is_some();
let curr_has_result = current.previous_action_result.is_some();
if prev_has_result != curr_has_result {
return true;
}
false
}

218
libs/agent/orao/reason.rs Normal file
View File

@ -0,0 +1,218 @@
//! Reason phase: structured reasoning and plan generation via LLM.
//!
//! Takes the current [`PerceptionSnapshot`] and the task goal, sends them to
//! the configured LLM, and produces a structured [`ReasoningOutput`] with an
//! executable plan.
use crate::client::AiClientConfig;
use crate::error::AgentError;
use rig::agent::AgentBuilder;
use rig::client::CompletionClient;
use rig::completion::Prompt;
use super::types::{OraoConfig, PerceptionSnapshot, ReasoningOutput};
/// Prompt template for the ORAO reasoning phase.
const REASON_SYSTEM_PROMPT: &str = r#"You are an expert software engineering agent using the ORAO (Observe-Reason-Act-Observe) framework.
## Your Role: REASON Phase
You are currently in the REASON phase. You will receive:
1. A **task goal** what needs to be accomplished
2. An **observation snapshot** the current state of the environment
Your job is to produce a structured analysis and a step-by-step action plan.
## Output Format
You MUST respond with a valid JSON object matching this schema:
```json
{
"analysis": "<concise analysis of the current state, problems identified, constraints>",
"plan": [
{
"step_id": 1,
"description": "<what this step does>",
"action_type": "shell_command | file_write | file_edit | git_operation | tool_invoke | user_dialog",
"command_or_content": "<the exact command to run or content to write>",
"expected_result": "<what success looks like>",
"fallback_on_failure": "<what to do if this fails, or null>"
}
]
}
```
## Rules
1. **Be specific**: commands must be exact and complete (include necessary flags, paths, etc.)
2. **One action per step**: each step should do one thing
3. **Validate assumptions**: don't assume dependencies are installed or files exist
4. **Prefer small, safe steps**: each change should be minimal and reversible
5. **Include verification**: build/test after changes to verify correctness
6. **Plan size**: 1-10 steps is typical. Don't plan more than 15 steps without asking the user.
7. **Fallbacks**: for risky steps, specify what to try if the step fails
## Safety
- Prefer read-only verification before making changes
- Use version control (git) for reversible operations
- Flag any step that requires network access or system-level privileges"#;
/// Run the reasoning phase: send observation + task to the LLM and get a plan.
pub async fn reason(
config: &AiClientConfig,
model_name: &str,
_orao_config: &OraoConfig,
task_goal: &str,
snapshot: &PerceptionSnapshot,
round: usize,
history_rounds: &[super::types::RoundRecord],
) -> Result<ReasoningOutput, AgentError> {
let user_prompt = build_reason_prompt(task_goal, snapshot, round, history_rounds);
let client = config.build_rig_client();
let model = client.completion_model(model_name);
let agent = AgentBuilder::new(model)
.preamble(REASON_SYSTEM_PROMPT)
.build();
let response = agent
.prompt(&user_prompt)
.extended_details()
.await
.map_err(|e: rig::completion::PromptError| AgentError::OpenAi(e.to_string()))?;
let output = response.output.trim().to_string();
// Extract JSON from the response (it may be wrapped in ```json fences)
let json_str = extract_json(&output);
serde_json::from_str::<ReasoningOutput>(json_str).map_err(|e| {
AgentError::Internal(format!(
"Failed to parse reasoning output as JSON: {}. Raw output: {}",
e, output
))
})
}
/// Build the user prompt for the reason phase.
fn build_reason_prompt(
task_goal: &str,
snapshot: &PerceptionSnapshot,
round: usize,
history_rounds: &[super::types::RoundRecord],
) -> String {
let mut prompt = format!(
"## Task Goal\n\n{}\n\n## Round Number\n\n{}\n\n",
task_goal, round
);
// ── Observation snapshot ────────────────────────────────────────────
prompt.push_str("## Current Observation\n\n");
if let Some(ref git_status) = snapshot.git_status {
prompt.push_str(&format!("### Git Status\n```\n{}\n```\n\n", git_status));
}
if let Some(ref project_structure) = snapshot.project_structure {
prompt.push_str(&format!(
"### Project Structure\n```\n{}\n```\n\n",
project_structure
));
}
if !snapshot.files.is_empty() {
prompt.push_str("### Relevant Files\n\n");
for file in &snapshot.files {
prompt.push_str(&format!(
"- `{}` ({} bytes)\n",
file.path, file.size_bytes
));
if let Some(ref preview) = file.content_preview {
let truncated = if preview.len() > 2000 {
format!("{}... [truncated]", &preview[..2000])
} else {
preview.clone()
};
prompt.push_str(&format!("```\n{}\n```\n\n", truncated));
}
}
}
if !snapshot.errors.is_empty() {
prompt.push_str("### Errors / Warnings\n\n");
for err in &snapshot.errors {
prompt.push_str(&format!("- {}\n", err));
}
prompt.push('\n');
}
if let Some(ref prev_result) = snapshot.previous_action_result {
prompt.push_str(&format!(
"### Previous Action Result\n- Verdict: {:?}\n- Exit code: {:?}\n- stdout: {}\n- stderr: {}\n\n",
prev_result.verdict,
prev_result.exit_code,
truncate_str(&prev_result.stdout, 1000),
truncate_str(&prev_result.stderr, 1000),
));
}
// ── History summary ─────────────────────────────────────────────────
if !history_rounds.is_empty() {
prompt.push_str("## Previous Rounds\n\n");
for record in history_rounds.iter().rev().take(3) {
prompt.push_str(&format!(
"- Round {}: {} | {} tokens | {}ms\n",
record.round,
truncate_str(&record.reasoning_summary, 120),
record.tokens_input + record.tokens_output,
record.duration_ms,
));
}
prompt.push('\n');
}
prompt.push_str(
"## Instructions\n\nProduce a structured analysis and action plan in JSON format as specified.",
);
prompt
}
/// Extract JSON content from a string that may be wrapped in markdown fences.
fn extract_json(s: &str) -> &str {
let s = s.trim();
if s.starts_with("```json") {
let inner = &s[7..];
if let Some(end) = inner.rfind("```") {
return inner[..end].trim();
}
return inner.trim();
}
if s.starts_with("```") {
let inner = &s[3..];
if let Some(end) = inner.rfind("```") {
return inner[..end].trim();
}
return inner.trim();
}
// Heuristic: find first { and last }
if let (Some(start), Some(end)) = (s.find('{'), s.rfind('}')) {
return &s[start..=end];
}
s
}
fn truncate_str(s: &str, max_len: usize) -> String {
if s.len() <= max_len {
s.to_string()
} else {
format!("{}...", &s[..max_len])
}
}

337
libs/agent/orao/types.rs Normal file
View File

@ -0,0 +1,337 @@
//! ORAO core types.
//!
//! ORAO (ObserveReasonActObserve) is a single-agent loop paradigm for complex
//! engineering tasks. It extends ReAct with structured multi-channel perception,
//! safety permission levels, plan mode, and deadlock detection.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// ── Safety levels ───────────────────────────────────────────────────────────
/// Permission level for actions executed by ORAO.
///
/// L0 (read-only) → auto-allow.
/// L1 (local write) → confirm on first use.
/// L2 (build) → confirm on first use.
/// L3 (network) → explicit user approval required.
/// L4 (system) → denied by default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum SafetyLevel {
/// L0 — Read-only: `ls`, `cat`, `grep`, `git status`
ReadOnly = 0,
/// L1 — Local write: edit source files, create new files
LocalWrite = 1,
/// L2 — Build/test: `cargo build`, `npm test`
Build = 2,
/// L3 — Network: `pip install`, `curl`
Network = 3,
/// L4 — System: `sudo`, global config changes
System = 4,
}
impl SafetyLevel {
/// Classify a shell command into a safety level.
pub fn classify_command(cmd: &str) -> Self {
let cmd_trimmed = cmd.trim();
// L0: read-only commands
let l0_prefixes = [
"ls", "cat", "head", "tail", "less", "file", "stat", "wc",
"grep", "rg", "find", "which", "type", "echo", "printf",
"pwd", "env", "printenv", "date", "uname", "hostname",
"git status", "git log", "git diff", "git show", "git branch",
"git tag", "git remote", "git config --get", "git blame",
"cargo metadata", "cargo tree", "cargo read-manifest",
"tree", "du", "df",
];
for p in &l0_prefixes {
if cmd_trimmed.starts_with(p) {
return Self::ReadOnly;
}
}
// L4: system-level commands (denied by default)
let l4_patterns = [
"sudo", "su ", "chown", "chmod 777", "mkfs", "mkswap",
"mount", "umount", "fdisk", "parted", "dd if=",
"systemctl", "service ", "chkconfig", "update-rc.d",
"passwd", "useradd", "userdel", "usermod", "groupadd",
"iptables", "ufw", "firewall-cmd",
"shutdown", "reboot", "halt", "poweroff",
"rm -rf /", "rm -rf ~", "rm -rf .", ":(){ :|:& };:",
];
for p in &l4_patterns {
if cmd_trimmed.starts_with(p) || cmd_trimmed.contains(p) {
return Self::System;
}
}
// L3: network commands
let l3_prefixes = [
"curl", "wget", "nc ", "ncat", "telnet", "ssh ", "scp",
"rsync", "pip install", "pip3 install", "npm install",
"npm i ", "yarn add", "cargo install", "gem install",
"go get", "go install", "apt-get", "apt ", "yum ", "dnf ",
"brew ", "pacman ", "zypper", "docker pull", "docker run",
"git clone", "git fetch", "git push", "git pull",
"gh ", "glab ", "aws ", "gcloud ", "az ",
];
for p in &l3_prefixes {
if cmd_trimmed.starts_with(p) {
return Self::Network;
}
}
// L2: build/test commands
let l2_prefixes = [
"cargo build", "cargo test", "cargo check", "cargo clippy",
"cargo fmt", "cargo run", "cargo bench", "cargo doc",
"npm test", "npm run", "npx ", "yarn test", "yarn run",
"pnpm test", "pnpm run", "bun test", "bun run",
"make", "cmake", "ninja", "meson", "bazel",
"pytest", "python -m pytest", "python3 -m pytest",
"go test", "go build", "go vet", "go fmt",
"rustc", "rustfmt", "clippy", "miri",
"eslint", "prettier", "tsc", "jest", "vitest",
"docker build", "docker compose", "docker-compose",
"kubectl apply", "kubectl delete", "helm ",
];
for p in &l2_prefixes {
if cmd_trimmed.starts_with(p) {
return Self::Build;
}
}
// Default to L1 (local write) for anything else
Self::LocalWrite
}
}
// ── Action types ────────────────────────────────────────────────────────────
/// The type of action to execute.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActionType {
/// Execute a shell command in a controlled terminal.
ShellCommand,
/// Create or overwrite a file.
FileWrite,
/// Make a localized edit to an existing file.
FileEdit,
/// Version-control operation (commit, add, etc.).
GitOperation,
/// Invoke an external tool or API.
ToolInvoke,
/// Ask the user for input or a decision.
UserDialog,
}
// ── Action plan ─────────────────────────────────────────────────────────────
/// A single planned action from the reasoning phase.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlannedAction {
/// Step number within the plan.
pub step_id: usize,
/// Human-readable description.
pub description: String,
/// The type of action.
pub action_type: ActionType,
/// The command or content to execute/write.
pub command_or_content: String,
/// What success should look like.
pub expected_result: String,
/// What to try if this step fails.
pub fallback_on_failure: Option<String>,
}
/// Structured reasoning output from the Reason phase.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningOutput {
/// Analysis of the current state.
pub analysis: String,
/// The plan to execute.
pub plan: Vec<PlannedAction>,
}
// ── Perception snapshot ─────────────────────────────────────────────────────
/// Structured observation collected during the Observe phase.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PerceptionSnapshot {
/// Project directory tree summary.
pub project_structure: Option<String>,
/// Relevant file paths and contents.
pub files: Vec<PerceivedFile>,
/// Current errors/warnings in the environment.
pub errors: Vec<String>,
/// Git status summary.
pub git_status: Option<String>,
/// Result of the previous action (if any).
pub previous_action_result: Option<ActionResult>,
/// Free-form context notes.
pub notes: HashMap<String, String>,
}
/// A file observed during perception.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerceivedFile {
pub path: String,
pub size_bytes: u64,
pub content_preview: Option<String>,
}
// ── Action result ───────────────────────────────────────────────────────────
/// The result of executing an action.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActionResult {
/// The action that was executed.
pub action: PlannedAction,
/// Exit code (0 = success for shell commands).
pub exit_code: Option<i32>,
/// Captured stdout.
pub stdout: String,
/// Captured stderr.
pub stderr: String,
/// Summary of file changes (if applicable).
pub file_changes: Vec<FileChange>,
/// Preliminary assessment.
pub verdict: ActionVerdict,
}
/// A file change detected after an action.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileChange {
pub path: String,
pub change_type: FileChangeType,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FileChangeType {
Created,
Modified,
Deleted,
}
/// Preliminary verdict on an action's outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActionVerdict {
Success,
SuccessWithWarnings,
Failure,
}
// ── ORAO step events ────────────────────────────────────────────────────────
/// A single event emitted during an ORAO round, analogous to `ReactStep`.
///
/// These are yielded via the streaming callback so the caller can persist
/// them or forward them to a frontend.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OraoStep {
/// Initial observation: environment snapshot before any action.
Observe {
round: usize,
snapshot: PerceptionSnapshot,
},
/// The reasoning/analysis output, including the plan.
Reason {
round: usize,
reasoning: ReasoningOutput,
},
/// An action is about to be executed.
Act {
round: usize,
action: PlannedAction,
safety_level: SafetyLevel,
},
/// The result observed after executing an action.
ObserveResult {
round: usize,
result: ActionResult,
},
/// Plan mode: a plan has been generated and is awaiting user approval.
PlanProposed {
round: usize,
reasoning: ReasoningOutput,
},
/// The task completed successfully.
Completed {
total_rounds: usize,
summary: String,
},
/// The task failed (max rounds, deadlock, or explicit failure).
Failed {
total_rounds: usize,
reason: String,
},
}
// ── Round record (audit) ────────────────────────────────────────────────────
/// A persistent record of one ORAO round, used for audit and resumption.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoundRecord {
/// Round number (1-indexed).
pub round: usize,
/// Summary of the Observe phase.
pub observe_summary: String,
/// Summary of the Reasoning phase.
pub reasoning_summary: String,
/// The action that was executed.
pub action: Option<PlannedAction>,
/// Result observed after the action.
pub result_summary: Option<String>,
/// Tokens consumed this round.
pub tokens_input: u64,
pub tokens_output: u64,
/// Wall-clock duration of this round in milliseconds.
pub duration_ms: u64,
}
// ── ORAO configuration ──────────────────────────────────────────────────────
/// Configuration for an ORAO execution.
#[derive(Clone)]
pub struct OraoConfig {
/// Maximum number of ORAO rounds before giving up.
pub max_rounds: usize,
/// Maximum allowed safety level. Actions above this level are denied.
pub max_safety_level: SafetyLevel,
/// Whether to run in plan mode (generate plan first, wait for approval).
pub plan_mode: bool,
/// Whether to enable extended thinking for the reasoning phase.
pub extended_thinking: bool,
/// Per-action timeout in seconds.
pub action_timeout_secs: u64,
/// Number of consecutive no-change rounds before deadlock detection triggers.
pub deadlock_threshold: usize,
}
impl Default for OraoConfig {
fn default() -> Self {
Self {
max_rounds: 50,
max_safety_level: SafetyLevel::Network,
plan_mode: false,
extended_thinking: false,
action_timeout_secs: 120,
deadlock_threshold: 3,
}
}
}
impl std::fmt::Debug for OraoConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OraoConfig")
.field("max_rounds", &self.max_rounds)
.field("max_safety_level", &self.max_safety_level)
.field("plan_mode", &self.plan_mode)
.field("extended_thinking", &self.extended_thinking)
.field("action_timeout_secs", &self.action_timeout_secs)
.field("deadlock_threshold", &self.deadlock_threshold)
.finish()
}
}

171
libs/agent/task/events.rs Normal file
View File

@ -0,0 +1,171 @@
use models::agent_task::TaskStatus;
use serde::Serialize;
use std::sync::Arc;
/// Event payload published to WebSocket clients via Redis Pub/Sub.
#[derive(Debug, Clone, Serialize)]
pub struct TaskEvent {
pub task_id: i64,
pub project_id: uuid::Uuid,
pub parent_id: Option<i64>,
pub event: String,
pub message: Option<String>,
pub output: Option<String>,
pub error: Option<String>,
pub status: String,
}
impl TaskEvent {
pub fn started(task_id: i64, project_id: uuid::Uuid, parent_id: Option<i64>) -> Self {
Self {
task_id,
project_id,
parent_id,
event: "started".to_string(),
message: None,
output: None,
error: None,
status: TaskStatus::Running.to_string(),
}
}
pub fn progress(
task_id: i64,
project_id: uuid::Uuid,
parent_id: Option<i64>,
msg: String,
) -> Self {
Self {
task_id,
project_id,
parent_id,
event: "progress".to_string(),
message: Some(msg),
output: None,
error: None,
status: TaskStatus::Running.to_string(),
}
}
pub fn completed(
task_id: i64,
project_id: uuid::Uuid,
parent_id: Option<i64>,
output: String,
) -> Self {
Self {
task_id,
project_id,
parent_id,
event: "done".to_string(),
message: None,
output: Some(output),
error: None,
status: TaskStatus::Done.to_string(),
}
}
pub fn failed(
task_id: i64,
project_id: uuid::Uuid,
parent_id: Option<i64>,
error: String,
) -> Self {
Self {
task_id,
project_id,
parent_id,
event: "failed".to_string(),
message: None,
output: None,
error: Some(error),
status: TaskStatus::Failed.to_string(),
}
}
pub fn cancelled(task_id: i64, project_id: uuid::Uuid, parent_id: Option<i64>) -> Self {
Self {
task_id,
project_id,
parent_id,
event: "cancelled".to_string(),
message: None,
output: None,
error: None,
status: TaskStatus::Cancelled.to_string(),
}
}
}
/// Helper trait for publishing task lifecycle events via Redis Pub/Sub.
///
/// Callers inject a suitable `publish_fn` at construction time via
/// `TaskEvents::new(...)`. If no publisher is supplied events are silently
/// dropped (graceful degradation on startup).
pub trait TaskEventPublisher: Send + Sync {
fn publish(&self, project_id: uuid::Uuid, event: TaskEvent);
}
/// No-op publisher used when no Redis Pub/Sub connection is available.
#[derive(Clone, Default)]
pub struct NoOpPublisher;
impl TaskEventPublisher for NoOpPublisher {
fn publish(&self, _: uuid::Uuid, _: TaskEvent) {}
}
#[derive(Clone)]
pub struct TaskEvents {
publisher: Arc<dyn TaskEventPublisher>,
}
impl TaskEvents {
pub fn new(publisher: impl TaskEventPublisher + 'static) -> Self {
Self {
publisher: Arc::new(publisher),
}
}
pub fn noop() -> Self {
Self::new(NoOpPublisher)
}
fn emit(&self, task: &models::agent_task::Model, event: TaskEvent) {
self.publisher.publish(task.project_uuid, event);
}
pub fn emit_started(&self, task: &models::agent_task::Model) {
self.emit(
task,
TaskEvent::started(task.id, task.project_uuid, task.parent_id),
);
}
pub fn emit_progress(&self, task: &models::agent_task::Model, msg: String) {
self.emit(
task,
TaskEvent::progress(task.id, task.project_uuid, task.parent_id, msg),
);
}
pub fn emit_completed(&self, task: &models::agent_task::Model, output: String) {
self.emit(
task,
TaskEvent::completed(task.id, task.project_uuid, task.parent_id, output),
);
}
pub fn emit_failed(&self, task: &models::agent_task::Model, error: String) {
self.emit(
task,
TaskEvent::failed(task.id, task.project_uuid, task.parent_id, error),
);
}
pub fn emit_cancelled(&self, task: &models::agent_task::Model) {
self.emit(
task,
TaskEvent::cancelled(task.id, task.project_uuid, task.parent_id),
);
}
}

View File

@ -0,0 +1,192 @@
use models::agent_task::{ActiveModel, Column as C, Entity, Model, TaskStatus};
use sea_orm::{ActiveModelTrait, ColumnTrait, DbErr, EntityTrait, QueryFilter};
pub struct TaskLifecycle;
impl super::TaskService {
/// Mark a task as running and record the start time.
pub async fn start(&self, task_id: i64) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(self.db()).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Running);
active.started_at = sea_orm::Set(Some(chrono::Utc::now().into()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
let updated = active.update(self.db()).await?;
self.events().emit_started(&updated);
Ok(updated)
}
/// Update progress text (e.g., "step 2/5: analyzing PR").
pub async fn update_progress(
&self,
task_id: i64,
progress: impl Into<String>,
) -> Result<(), DbErr> {
let model = Entity::find_by_id(task_id).one(self.db()).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
let progress_str = progress.into();
let mut active: ActiveModel = model.into();
active.progress = sea_orm::Set(Some(progress_str.clone()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
let updated = active.update(self.db()).await?;
self.events().emit_progress(&updated, progress_str);
Ok(())
}
/// Mark a task as completed with the output text.
pub async fn complete(&self, task_id: i64, output: impl Into<String>) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(self.db()).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Done);
let out = output.into();
active.output = sea_orm::Set(Some(out.clone()));
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
let updated = active.update(self.db()).await?;
self.events().emit_completed(&updated, out);
Ok(updated)
}
/// Mark a task as failed with an error message.
pub async fn fail(&self, task_id: i64, error: impl Into<String>) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(self.db()).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Failed);
let err = error.into();
active.error = sea_orm::Set(Some(err.clone()));
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
let updated = active.update(self.db()).await?;
self.events().emit_failed(&updated, err);
Ok(updated)
}
/// Propagate child task status up the tree.
///
/// Only allows cancelling tasks that are not yet in a terminal state
/// (Pending / Running / Paused).
///
/// Cancelled children are marked done so that `are_children_done()` returns
/// true for the parent after cancellation.
pub async fn cancel(&self, task_id: i64) -> Result<Model, DbErr> {
// Collect all task IDs (parent + descendants) using an explicit stack.
let mut stack = vec![task_id];
let mut idx = 0;
while idx < stack.len() {
let current = stack[idx];
let children = Entity::find()
.filter(C::ParentId.eq(current))
.all(self.db())
.await?;
for child in children {
stack.push(child.id);
}
idx += 1;
}
// Mark every collected task as cancelled (terminal state).
for id in &stack {
let model = Entity::find_by_id(*id).one(self.db()).await?;
if let Some(m) = model {
if !m.is_done() {
let mut active: ActiveModel = m.into();
active.status = sea_orm::Set(TaskStatus::Cancelled);
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
active.update(self.db()).await?;
}
}
}
let final_model = Entity::find_by_id(task_id)
.one(self.db())
.await?
.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
self.events().emit_cancelled(&final_model);
Ok(final_model)
}
/// Pause a running or pending task.
///
/// Pausing a task that is not Pending/Running is a no-op that returns
/// the current model (same behaviour as `start` on an already-running task).
pub async fn pause(&self, task_id: i64) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(self.db()).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
if !model.is_running() {
// Already in a terminal or paused state — return unchanged.
return Ok(model);
}
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Paused);
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
active.update(self.db()).await
}
/// Resume a paused task back to Running.
///
/// Returns an error if the task is not currently Paused.
pub async fn resume(&self, task_id: i64) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(self.db()).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
if model.status != TaskStatus::Paused {
return Err(DbErr::Custom(format!(
"cannot resume task {}: expected status Paused, got {}",
task_id, model.status
)));
}
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Running);
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
active.update(self.db()).await
}
/// Retry a failed or cancelled task by resetting it to Pending.
///
/// Clears `output`, `error`, and `done_at`; increments `retry_count`.
/// Only tasks in Failed or Cancelled state can be retried.
pub async fn retry(&self, task_id: i64) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(self.db()).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
match model.status {
TaskStatus::Failed | TaskStatus::Cancelled | TaskStatus::Done => {}
_ => {
return Err(DbErr::Custom(format!(
"cannot retry task {}: only Failed/Cancelled/Done tasks can be retried (got {})",
task_id, model.status
)));
}
}
let retry_count = model.retry_count.map(|c| c + 1).unwrap_or(1);
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Pending);
active.output = sea_orm::Set(None);
active.error = sea_orm::Set(None);
active.done_at = sea_orm::Set(None);
active.started_at = sea_orm::Set(None);
active.retry_count = sea_orm::Set(Some(retry_count));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
active.update(self.db()).await
}
}

View File

@ -1,4 +1,4 @@
//! Agent task service — unified task/sub-agent execution framework.
//! Agent task service — managing task/sub-agent execution lifecycle.
//!
//! A task (`agent_task` record) can be:
//! - A **root task**: initiated by a user or system event.
@ -17,6 +17,61 @@
//! This module is intentionally kept simple and synchronous with the DB.
//! Long-running execution is delegated to the caller (tokio::spawn).
pub mod service;
pub mod events;
pub mod lifecycle;
pub mod store;
pub mod tree;
pub use service::TaskService;
use db::database::AppDatabase;
pub use events::{NoOpPublisher, TaskEvent, TaskEventPublisher, TaskEvents};
pub use lifecycle::TaskLifecycle;
/// Service for managing agent tasks (root tasks and sub-tasks).
#[derive(Clone)]
pub struct TaskService {
db: AppDatabase,
events: TaskEvents,
}
impl TaskService {
pub fn new(db: AppDatabase) -> Self {
Self {
db,
events: TaskEvents::noop(),
}
}
pub fn with_events(db: AppDatabase, events: TaskEvents) -> Self {
Self { db, events }
}
pub(crate) fn db(&self) -> &AppDatabase {
&self.db
}
pub(crate) fn events(&self) -> &TaskEvents {
&self.events
}
}
/// Builder for TaskService so that the events publisher can be set independently
/// of the database connection.
#[derive(Clone, Default)]
pub struct TaskServiceBuilder {
events: Option<TaskEvents>,
}
impl TaskServiceBuilder {
pub fn with_events(mut self, events: TaskEvents) -> Self {
self.events = Some(events);
self
}
pub async fn build(self, db: AppDatabase) -> TaskService {
TaskService {
db,
events: self.events.unwrap_or_else(TaskEvents::noop),
}
}
}

View File

@ -1,600 +0,0 @@
//! Task service for creating, tracking, and executing agent tasks.
//!
//! All methods are async and interact with the database directly.
//! Execution of the task logic (running the ReAct loop, etc.) is delegated
//! to the caller — this service only manages task lifecycle and state.
use db::database::AppDatabase;
use models::agent_task::{ActiveModel, AgentType, Column as C, Entity, Model, TaskStatus};
use models::IssueId;
use sea_orm::{
entity::EntityTrait, query::{QueryFilter, QueryOrder, QuerySelect}, ActiveModelTrait,
ColumnTrait,
DbErr,
};
use serde::Serialize;
use std::sync::Arc;
/// Event payload published to WebSocket clients via Redis Pub/Sub.
#[derive(Debug, Clone, Serialize)]
pub struct TaskEvent {
pub task_id: i64,
pub project_id: uuid::Uuid,
pub parent_id: Option<i64>,
pub event: String,
pub message: Option<String>,
pub output: Option<String>,
pub error: Option<String>,
pub status: String,
}
impl TaskEvent {
pub fn started(task_id: i64, project_id: uuid::Uuid, parent_id: Option<i64>) -> Self {
Self {
task_id,
project_id,
parent_id,
event: "started".to_string(),
message: None,
output: None,
error: None,
status: TaskStatus::Running.to_string(),
}
}
pub fn progress(
task_id: i64,
project_id: uuid::Uuid,
parent_id: Option<i64>,
msg: String,
) -> Self {
Self {
task_id,
project_id,
parent_id,
event: "progress".to_string(),
message: Some(msg),
output: None,
error: None,
status: TaskStatus::Running.to_string(),
}
}
pub fn completed(
task_id: i64,
project_id: uuid::Uuid,
parent_id: Option<i64>,
output: String,
) -> Self {
Self {
task_id,
project_id,
parent_id,
event: "done".to_string(),
message: None,
output: Some(output),
error: None,
status: TaskStatus::Done.to_string(),
}
}
pub fn failed(
task_id: i64,
project_id: uuid::Uuid,
parent_id: Option<i64>,
error: String,
) -> Self {
Self {
task_id,
project_id,
parent_id,
event: "failed".to_string(),
message: None,
output: None,
error: Some(error),
status: TaskStatus::Failed.to_string(),
}
}
pub fn cancelled(task_id: i64, project_id: uuid::Uuid, parent_id: Option<i64>) -> Self {
Self {
task_id,
project_id,
parent_id,
event: "cancelled".to_string(),
message: None,
output: None,
error: None,
status: TaskStatus::Cancelled.to_string(),
}
}
}
/// Helper trait for publishing task lifecycle events via Redis Pub/Sub.
///
/// Callers inject a suitable `publish_fn` at construction time via
/// `TaskEvents::new(...)`. If no publisher is supplied events are silently
/// dropped (graceful degradation on startup).
pub trait TaskEventPublisher: Send + Sync {
fn publish(&self, project_id: uuid::Uuid, event: TaskEvent);
}
/// No-op publisher used when no Redis Pub/Sub connection is available.
#[derive(Clone, Default)]
pub struct NoOpPublisher;
impl TaskEventPublisher for NoOpPublisher {
fn publish(&self, _: uuid::Uuid, _: TaskEvent) {}
}
#[derive(Clone)]
pub struct TaskEvents {
publisher: Arc<dyn TaskEventPublisher>,
}
impl TaskEvents {
pub fn new(publisher: impl TaskEventPublisher + 'static) -> Self {
Self {
publisher: Arc::new(publisher),
}
}
pub fn noop() -> Self {
Self::new(NoOpPublisher)
}
fn emit(&self, task: &Model, event: TaskEvent) {
self.publisher.publish(task.project_uuid, event);
}
pub fn emit_started(&self, task: &Model) {
self.emit(
task,
TaskEvent::started(task.id, task.project_uuid, task.parent_id),
);
}
pub fn emit_progress(&self, task: &Model, msg: String) {
self.emit(
task,
TaskEvent::progress(task.id, task.project_uuid, task.parent_id, msg),
);
}
pub fn emit_completed(&self, task: &Model, output: String) {
self.emit(
task,
TaskEvent::completed(task.id, task.project_uuid, task.parent_id, output),
);
}
pub fn emit_failed(&self, task: &Model, error: String) {
self.emit(
task,
TaskEvent::failed(task.id, task.project_uuid, task.parent_id, error),
);
}
pub fn emit_cancelled(&self, task: &Model) {
self.emit(
task,
TaskEvent::cancelled(task.id, task.project_uuid, task.parent_id),
);
}
}
/// Builder for TaskService so that the events publisher can be set independently
/// of the database connection.
#[derive(Clone, Default)]
pub struct TaskServiceBuilder {
events: Option<TaskEvents>,
}
impl TaskServiceBuilder {
pub fn with_events(mut self, events: TaskEvents) -> Self {
self.events = Some(events);
self
}
pub async fn build(self, db: AppDatabase) -> TaskService {
TaskService {
db,
events: self.events.unwrap_or_else(TaskEvents::noop),
}
}
}
/// Service for managing agent tasks (root tasks and sub-tasks).
#[derive(Clone)]
pub struct TaskService {
db: AppDatabase,
events: TaskEvents,
}
impl TaskService {
pub fn new(db: AppDatabase) -> Self {
Self {
db,
events: TaskEvents::noop(),
}
}
pub fn with_events(db: AppDatabase, events: TaskEvents) -> Self {
Self { db, events }
}
/// Create a new task (root or sub-task) with status = pending.
pub async fn create(
&self,
project_uuid: impl Into<uuid::Uuid>,
input: impl Into<String>,
agent_type: AgentType,
) -> Result<Model, DbErr> {
self.create_with_parent(project_uuid, None, input, agent_type, None, None)
.await
}
/// Create a new task bound to an issue.
pub async fn create_for_issue(
&self,
project_uuid: impl Into<uuid::Uuid>,
issue_id: IssueId,
input: impl Into<String>,
agent_type: AgentType,
) -> Result<Model, DbErr> {
self.create_with_parent(project_uuid, None, input, agent_type, None, Some(issue_id))
.await
}
/// Create a new sub-task with a parent reference.
pub async fn create_subtask(
&self,
project_uuid: impl Into<uuid::Uuid>,
parent_id: i64,
input: impl Into<String>,
agent_type: AgentType,
title: Option<String>,
) -> Result<Model, DbErr> {
self.create_with_parent(
project_uuid,
Some(parent_id),
input,
agent_type,
title,
None,
)
.await
}
async fn create_with_parent(
&self,
project_uuid: impl Into<uuid::Uuid>,
parent_id: Option<i64>,
input: impl Into<String>,
agent_type: AgentType,
title: Option<String>,
issue_id: Option<IssueId>,
) -> Result<Model, DbErr> {
let model = ActiveModel {
project_uuid: sea_orm::Set(project_uuid.into()),
parent_id: sea_orm::Set(parent_id),
issue_id: sea_orm::Set(issue_id),
agent_type: sea_orm::Set(agent_type),
status: sea_orm::Set(TaskStatus::Pending),
title: sea_orm::Set(title),
input: sea_orm::Set(input.into()),
..Default::default()
};
model.insert(&self.db).await
}
/// Mark a task as running and record the start time.
pub async fn start(&self, task_id: i64) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(&self.db).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Running);
active.started_at = sea_orm::Set(Some(chrono::Utc::now().into()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
let updated = active.update(&self.db).await?;
self.events.emit_started(&updated);
Ok(updated)
}
/// Update progress text (e.g., "step 2/5: analyzing PR").
pub async fn update_progress(
&self,
task_id: i64,
progress: impl Into<String>,
) -> Result<(), DbErr> {
let model = Entity::find_by_id(task_id).one(&self.db).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
let progress_str = progress.into();
let mut active: ActiveModel = model.into();
active.progress = sea_orm::Set(Some(progress_str.clone()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
let updated = active.update(&self.db).await?;
self.events.emit_progress(&updated, progress_str);
Ok(())
}
/// Mark a task as completed with the output text.
pub async fn complete(&self, task_id: i64, output: impl Into<String>) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(&self.db).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Done);
let out = output.into();
active.output = sea_orm::Set(Some(out.clone()));
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
let updated = active.update(&self.db).await?;
self.events.emit_completed(&updated, out);
Ok(updated)
}
/// Mark a task as failed with an error message.
pub async fn fail(&self, task_id: i64, error: impl Into<String>) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(&self.db).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Failed);
let err = error.into();
active.error = sea_orm::Set(Some(err.clone()));
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
let updated = active.update(&self.db).await?;
self.events.emit_failed(&updated, err);
Ok(updated)
}
/// Propagate child task status up the tree.
///
/// Only allows cancelling tasks that are not yet in a terminal state
/// (Pending / Running / Paused).
///
/// Cancelled children are marked done so that `are_children_done()` returns
/// true for the parent after cancellation.
pub async fn cancel(&self, task_id: i64) -> Result<Model, DbErr> {
// Collect all task IDs (parent + descendants) using an explicit stack.
let mut stack = vec![task_id];
let mut idx = 0;
while idx < stack.len() {
let current = stack[idx];
let children = Entity::find()
.filter(C::ParentId.eq(current))
.all(&self.db)
.await?;
for child in children {
stack.push(child.id);
}
idx += 1;
}
// Mark every collected task as cancelled (terminal state).
for id in &stack {
let model = Entity::find_by_id(*id).one(&self.db).await?;
if let Some(m) = model {
if !m.is_done() {
let mut active: ActiveModel = m.into();
active.status = sea_orm::Set(TaskStatus::Cancelled);
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
active.update(&self.db).await?;
}
}
}
let final_model = Entity::find_by_id(task_id)
.one(&self.db)
.await?
.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
self.events.emit_cancelled(&final_model);
Ok(final_model)
}
/// Pause a running or pending task.
///
/// Pausing a task that is not Pending/Running is a no-op that returns
/// the current model (same behaviour as `start` on an already-running task).
pub async fn pause(&self, task_id: i64) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(&self.db).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
if !model.is_running() {
// Already in a terminal or paused state — return unchanged.
return Ok(model);
}
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Paused);
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
active.update(&self.db).await
}
/// Resume a paused task back to Running.
///
/// Returns an error if the task is not currently Paused.
pub async fn resume(&self, task_id: i64) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(&self.db).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
if model.status != TaskStatus::Paused {
return Err(DbErr::Custom(format!(
"cannot resume task {}: expected status Paused, got {}",
task_id, model.status
)));
}
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Running);
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
active.update(&self.db).await
}
/// Retry a failed or cancelled task by resetting it to Pending.
///
/// Clears `output`, `error`, and `done_at`; increments `retry_count`.
/// Only tasks in Failed or Cancelled state can be retried.
pub async fn retry(&self, task_id: i64) -> Result<Model, DbErr> {
let model = Entity::find_by_id(task_id).one(&self.db).await?;
let model =
model.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
match model.status {
TaskStatus::Failed | TaskStatus::Cancelled | TaskStatus::Done => {}
_ => {
return Err(DbErr::Custom(format!(
"cannot retry task {}: only Failed/Cancelled/Done tasks can be retried (got {})",
task_id, model.status
)));
}
}
let retry_count = model.retry_count.map(|c| c + 1).unwrap_or(1);
let mut active: ActiveModel = model.into();
active.status = sea_orm::Set(TaskStatus::Pending);
active.output = sea_orm::Set(None);
active.error = sea_orm::Set(None);
active.done_at = sea_orm::Set(None);
active.started_at = sea_orm::Set(None);
active.retry_count = sea_orm::Set(Some(retry_count));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
active.update(&self.db).await
}
/// Propagate child task status up the tree.
///
/// When a child task reaches a terminal state, checks whether all its
/// siblings are also terminal. If so, marks the parent appropriately:
/// - Done if any child succeeded
/// - Failed if all children failed or were cancelled
pub async fn propagate_to_parent(&self, task_id: i64) -> Result<Option<Model>, DbErr> {
let model = self
.get(task_id)
.await?
.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
let Some(parent_id) = model.parent_id else {
return Ok(None);
};
let siblings = self.children(parent_id).await?;
if siblings.iter().all(|s| s.is_done()) {
let parent = self.get(parent_id).await?.ok_or_else(|| {
DbErr::RecordNotFound(format!("parent task {} not found", parent_id))
})?;
if parent.is_running() {
let mut active: ActiveModel = parent.into();
let has_success = siblings.iter().any(|s| s.status == TaskStatus::Done);
if has_success {
active.status = sea_orm::Set(TaskStatus::Done);
active.error = sea_orm::Set(None);
} else {
active.status = sea_orm::Set(TaskStatus::Failed);
active.error =
sea_orm::Set(Some("All sub-tasks failed or were cancelled".to_string()));
}
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
let updated = active.update(&self.db).await?;
return Ok(Some(updated));
}
}
Ok(None)
}
/// Get a task by ID.
pub async fn get(&self, task_id: i64) -> Result<Option<Model>, DbErr> {
Entity::find_by_id(task_id).one(&self.db).await
}
/// List all sub-tasks for a parent task.
pub async fn children(&self, parent_id: i64) -> Result<Vec<Model>, DbErr> {
Entity::find()
.filter(C::ParentId.eq(parent_id))
.order_by_asc(C::CreatedAt)
.all(&self.db)
.await
}
/// List all active (non-terminal) tasks for a project.
pub async fn active_tasks(
&self,
project_uuid: impl Into<uuid::Uuid>,
) -> Result<Vec<Model>, DbErr> {
let uuid: uuid::Uuid = project_uuid.into();
Entity::find()
.filter(C::ProjectUuid.eq(uuid))
.filter(C::Status.is_in([TaskStatus::Pending, TaskStatus::Running, TaskStatus::Paused]))
.order_by_desc(C::CreatedAt)
.all(&self.db)
.await
}
/// List all tasks (root only) for a project.
pub async fn list(
&self,
project_uuid: impl Into<uuid::Uuid>,
limit: u64,
) -> Result<Vec<Model>, DbErr> {
let uuid: uuid::Uuid = project_uuid.into();
Entity::find()
.filter(C::ProjectUuid.eq(uuid))
.filter(C::ParentId.is_null())
.order_by_desc(C::CreatedAt)
.limit(limit)
.all(&self.db)
.await
}
/// Delete a task and all its sub-tasks recursively.
/// Only allows deletion of root tasks.
pub async fn delete(&self, task_id: i64) -> Result<(), DbErr> {
self.delete_recursive(task_id).await
}
async fn delete_recursive(&self, task_id: i64) -> Result<(), DbErr> {
// Collect all task IDs to delete using an explicit stack (avoiding async recursion).
let mut stack = vec![task_id];
let mut idx = 0;
while idx < stack.len() {
let current = stack[idx];
let children = Entity::find()
.filter(C::ParentId.eq(current))
.all(&self.db)
.await?;
for child in children {
stack.push(child.id);
}
idx += 1;
}
for task_id in stack {
let model = Entity::find_by_id(task_id).one(&self.db).await?;
if let Some(m) = model {
let active: ActiveModel = m.into();
active.delete(&self.db).await?;
}
}
Ok(())
}
/// Check if all sub-tasks of a given parent are in a terminal state.
/// Returns true if there are no children (empty tree counts as done).
pub async fn are_children_done(&self, parent_id: i64) -> Result<bool, DbErr> {
let children = self.children(parent_id).await?;
Ok(children.is_empty() || children.iter().all(|c| c.is_done()))
}
}

109
libs/agent/task/store.rs Normal file
View File

@ -0,0 +1,109 @@
use models::agent_task::{ActiveModel, AgentType, Entity, Model};
use models::IssueId;
use sea_orm::{ActiveModelTrait, ColumnTrait, DbErr, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
impl super::TaskService {
/// Get a task by ID.
pub async fn get(&self, task_id: i64) -> Result<Option<Model>, DbErr> {
Entity::find_by_id(task_id).one(self.db()).await
}
/// List all tasks (root only) for a project.
pub async fn list(
&self,
project_uuid: impl Into<uuid::Uuid>,
limit: u64,
) -> Result<Vec<Model>, DbErr> {
let uuid: uuid::Uuid = project_uuid.into();
Entity::find()
.filter(models::agent_task::Column::ProjectUuid.eq(uuid))
.filter(models::agent_task::Column::ParentId.is_null())
.order_by_desc(models::agent_task::Column::CreatedAt)
.limit(limit)
.all(self.db())
.await
}
/// List all active (non-terminal) tasks for a project.
pub async fn active_tasks(
&self,
project_uuid: impl Into<uuid::Uuid>,
) -> Result<Vec<Model>, DbErr> {
let uuid: uuid::Uuid = project_uuid.into();
Entity::find()
.filter(models::agent_task::Column::ProjectUuid.eq(uuid))
.filter(models::agent_task::Column::Status.is_in([
models::agent_task::TaskStatus::Pending,
models::agent_task::TaskStatus::Running,
models::agent_task::TaskStatus::Paused,
]))
.order_by_desc(models::agent_task::Column::CreatedAt)
.all(self.db())
.await
}
/// Create a new task (root or sub-task) with status = pending.
pub async fn create(
&self,
project_uuid: impl Into<uuid::Uuid>,
input: impl Into<String>,
agent_type: AgentType,
) -> Result<Model, DbErr> {
self.create_with_parent(project_uuid, None, input, agent_type, None, None)
.await
}
/// Create a new task bound to an issue.
pub async fn create_for_issue(
&self,
project_uuid: impl Into<uuid::Uuid>,
issue_id: IssueId,
input: impl Into<String>,
agent_type: AgentType,
) -> Result<Model, DbErr> {
self.create_with_parent(project_uuid, None, input, agent_type, None, Some(issue_id))
.await
}
/// Create a new sub-task with a parent reference.
pub async fn create_subtask(
&self,
project_uuid: impl Into<uuid::Uuid>,
parent_id: i64,
input: impl Into<String>,
agent_type: AgentType,
title: Option<String>,
) -> Result<Model, DbErr> {
self.create_with_parent(
project_uuid,
Some(parent_id),
input,
agent_type,
title,
None,
)
.await
}
async fn create_with_parent(
&self,
project_uuid: impl Into<uuid::Uuid>,
parent_id: Option<i64>,
input: impl Into<String>,
agent_type: AgentType,
title: Option<String>,
issue_id: Option<IssueId>,
) -> Result<Model, DbErr> {
let model = ActiveModel {
project_uuid: sea_orm::Set(project_uuid.into()),
parent_id: sea_orm::Set(parent_id),
issue_id: sea_orm::Set(issue_id),
agent_type: sea_orm::Set(agent_type),
status: sea_orm::Set(models::agent_task::TaskStatus::Pending),
title: sea_orm::Set(title),
input: sea_orm::Set(input.into()),
..Default::default()
};
model.insert(self.db()).await
}
}

89
libs/agent/task/tree.rs Normal file
View File

@ -0,0 +1,89 @@
use models::agent_task::{ActiveModel, Column as C, Entity, Model, TaskStatus};
use sea_orm::{ActiveModelTrait, ColumnTrait, DbErr, EntityTrait, QueryFilter, QueryOrder};
impl super::TaskService {
/// Propagate child task status up the tree.
///
/// When a child task reaches a terminal state, checks whether all its
/// siblings are also terminal. If so, marks the parent appropriately:
/// - Done if any child succeeded
/// - Failed if all children failed or were cancelled
pub async fn propagate_to_parent(&self, task_id: i64) -> Result<Option<Model>, DbErr> {
let model = self
.get(task_id)
.await?
.ok_or_else(|| DbErr::RecordNotFound("agent_task not found".to_string()))?;
let Some(parent_id) = model.parent_id else {
return Ok(None);
};
let siblings = self.children(parent_id).await?;
if siblings.iter().all(|s| s.is_done()) {
let parent = self.get(parent_id).await?.ok_or_else(|| {
DbErr::RecordNotFound(format!("parent task {} not found", parent_id))
})?;
if parent.is_running() {
let mut active: ActiveModel = parent.into();
let has_success = siblings.iter().any(|s| s.status == TaskStatus::Done);
if has_success {
active.status = sea_orm::Set(TaskStatus::Done);
active.error = sea_orm::Set(None);
} else {
active.status = sea_orm::Set(TaskStatus::Failed);
active.error =
sea_orm::Set(Some("All sub-tasks failed or were cancelled".to_string()));
}
active.done_at = sea_orm::Set(Some(chrono::Utc::now().into()));
active.updated_at = sea_orm::Set(chrono::Utc::now().into());
let updated = active.update(self.db()).await?;
return Ok(Some(updated));
}
}
Ok(None)
}
/// List all sub-tasks for a parent task.
pub async fn children(&self, parent_id: i64) -> Result<Vec<Model>, DbErr> {
Entity::find()
.filter(C::ParentId.eq(parent_id))
.order_by_asc(C::CreatedAt)
.all(self.db())
.await
}
/// Check if all sub-tasks of a given parent are in a terminal state.
/// Returns true if there are no children (empty tree counts as done).
pub async fn are_children_done(&self, parent_id: i64) -> Result<bool, DbErr> {
let children = self.children(parent_id).await?;
Ok(children.is_empty() || children.iter().all(|c| c.is_done()))
}
/// Delete a task and all its sub-tasks recursively.
/// Only allows deletion of root tasks.
pub async fn delete(&self, task_id: i64) -> Result<(), DbErr> {
// Collect all task IDs to delete using an explicit stack (avoiding async recursion).
let mut stack = vec![task_id];
let mut idx = 0;
while idx < stack.len() {
let current = stack[idx];
let children = Entity::find()
.filter(C::ParentId.eq(current))
.all(self.db())
.await?;
for child in children {
stack.push(child.id);
}
idx += 1;
}
for task_id in stack {
let model = Entity::find_by_id(task_id).one(self.db()).await?;
if let Some(m) = model {
let active: ActiveModel = m.into();
active.delete(self.db()).await?;
}
}
Ok(())
}
}

View File

@ -13,7 +13,7 @@ use std::collections::HashMap;
use std::sync::OnceLock;
use std::sync::RwLock;
use crate::error::{AgentError, Result};
use crate::error::Result;
static TOKENIZER_CACHE: OnceLock<RwLock<HashMap<String, tiktoken_rs::CoreBPE>>> = OnceLock::new();
@ -173,12 +173,11 @@ fn get_tokenizer(model: &str) -> Result<tiktoken_rs::CoreBPE> {
}
// Try model-specific tokenizer first
let bpe = if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) {
let bpe: &'static _ = if let Ok(bpe) = tiktoken_rs::bpe_for_model(model) {
bpe
} else {
// Fallback: use cl100k_base for unknown models
tiktoken_rs::cl100k_base()
.map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e)))?
tiktoken_rs::cl100k_base_singleton()
};
{
@ -186,7 +185,7 @@ fn get_tokenizer(model: &str) -> Result<tiktoken_rs::CoreBPE> {
cache.insert(model.to_string(), bpe.clone());
}
Ok(bpe)
Ok(bpe.clone())
}
/// Estimate tokens for a simple prefix/suffix pattern (e.g., "assistant\n" + text).

View File

@ -1,256 +0,0 @@
# Hook Queue NATS JetStream Migration Guide
## Overview
The git hook queue now supports both Redis Lists and NATS JetStream as backend message queues. This allows gradual migration from Redis to NATS without downtime.
## Architecture
### Producer (`ReceiveSyncService`)
The producer tries NATS first (if configured), then falls back to Redis:
```rust
pub struct ReceiveSyncService {
pool: deadpool_redis::cluster::Pool,
redis_prefix: String,
nats_publish: Option<Arc<dyn Fn(String, Vec<u8>) -> Pin<Box<dyn Future<Output = Result<u64>> + Send>> + Send + Sync>>,
}
```
### Consumer (`RedisConsumer`)
The consumer uses NATS if configured, otherwise falls back to Redis:
```rust
pub struct RedisConsumer {
pool: deadpool_redis::cluster::Pool,
prefix: String,
block_timeout_secs: u64,
nats_consume: Option<NatsHookConsumeFn>,
}
```
## Integration with AppTransport
### Producer Integration
```rust
use git::ssh::ReceiveSyncService;
use transport::AppTransport;
let transport = Arc::new(AppTransport::new(/* ... */));
// Create NATS publish function
let nats_publish = {
let transport = transport.clone();
Arc::new(move |subject: String, payload: Vec<u8>| {
let transport = transport.clone();
Box::pin(async move {
let ack = transport.publish(&subject, payload).await?;
Ok(ack.sequence)
}) as Pin<Box<dyn Future<Output = anyhow::Result<u64>> + Send>>
})
};
// Create service with NATS support
let sync_service = ReceiveSyncService::with_nats(redis_pool, nats_publish);
// Or use Redis-only mode
let sync_service = ReceiveSyncService::new(redis_pool);
```
### Consumer Integration
```rust
use git::hook::pool::redis::{RedisConsumer, NatsHookConsumeFn};
// Create NATS consume function
let nats_consume: NatsHookConsumeFn = {
let transport = transport.clone();
Arc::new(move |subject: String, batch_size: usize| {
let transport = transport.clone();
Box::pin(async move {
let mut results = Vec::new();
// Pull messages from JetStream consumer
for _ in 0..batch_size {
match transport.pull_one(&subject).await {
Ok(Some(msg)) => {
let data = msg.payload.to_vec();
let msg_clone = msg.clone();
let ack_fn = Box::new(move || {
let msg = msg_clone.clone();
Box::pin(async move {
msg.ack().await?;
Ok(())
}) as Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>
});
results.push((data, ack_fn));
}
Ok(None) => break,
Err(e) => return Err(e),
}
}
Ok(results)
}) as Pin<Box<dyn Future<Output = anyhow::Result<Vec<(Vec<u8>, Box<dyn Fn() -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>> + Send>)>>> + Send>>
})
};
// Create consumer with NATS support
let consumer = RedisConsumer::with_nats(
redis_pool,
"{hook}".to_string(),
5, // block_timeout_secs
nats_consume,
);
// Or use Redis-only mode
let consumer = RedisConsumer::new(redis_pool, "{hook}".to_string(), 5);
```
## Queue Subjects
The hook queue uses the following NATS subjects:
- `queue.hook.sync` - Repository sync tasks (git push/pull operations)
Additional task types can be added by extending the subject pattern:
- `queue.hook.{task_type}` - Generic pattern for any hook task type
## Migration Strategy
### Phase 1: Dual Write (Current)
- Producer writes to both NATS and Redis
- Consumer reads from Redis only
- Zero risk, full rollback capability
### Phase 2: Dual Read
- Producer writes to both NATS and Redis
- Consumer reads from NATS, falls back to Redis on error
- Validates NATS consumer stability
### Phase 3: NATS Primary
- Producer writes to NATS only (Redis disabled)
- Consumer reads from NATS only
- Redis queue deprecated
### Phase 4: Redis Removal
- Remove Redis Lists code
- Remove `pool` parameter
- Simplify to NATS-only implementation
## NATS JetStream Setup
### Stream Configuration
```bash
nats stream add HOOK_QUEUE \
--subjects "queue.hook.>" \
--retention limits \
--max-msgs=-1 \
--max-age=7d \
--storage file \
--replicas 3
```
### Consumer Configuration
```bash
nats consumer add HOOK_QUEUE hook-sync-worker \
--filter "queue.hook.sync" \
--ack explicit \
--pull \
--deliver all \
--max-deliver 3 \
--max-pending 100
```
## Differences from Email Queue
### Redis Backend
- **Email Queue**: Uses Redis Streams (XADD/XREADGROUP)
- **Hook Queue**: Uses Redis Lists (LPUSH/BLMOVE)
### Atomicity
- **Email Queue**: Consumer group provides at-least-once delivery
- **Hook Queue**: BLMOVE provides atomic move-to-work-queue pattern
### Work Queue Pattern
- **Email Queue**: No work queue, relies on consumer group
- **Hook Queue**: Uses separate work queue (`{hook}:sync:work`) for in-flight tracking
### Acknowledgment
- **Email Queue**: XACK removes from pending entries list
- **Hook Queue**: LREM removes from work queue
### Retry Logic
- **Email Queue**: Automatic via consumer group pending entries
- **Hook Queue**: Manual via Lua script (LREM + LPUSH)
## Monitoring
### Logs
- NATS publish: `"hook task queued to NATS"`
- Redis publish: `"hook task queued to Redis"`
- NATS consume: `"task dequeued from NATS"`
- Redis consume: `"task dequeued"`
### Metrics
Add these metrics to track hook queue performance:
```rust
counter!("hook_task_queued_total", "backend" => "nats").increment(1);
counter!("hook_task_queued_total", "backend" => "redis").increment(1);
counter!("hook_task_consumed_total", "backend" => "nats").increment(1);
counter!("hook_task_consumed_total", "backend" => "redis").increment(1);
```
## Rollback
To disable NATS and return to Redis-only:
```rust
// Producer
let sync_service = ReceiveSyncService::new(redis_pool);
// Consumer
let consumer = RedisConsumer::new(redis_pool, "{hook}".to_string(), 5);
```
No code changes required, just use the `new()` constructor instead of `with_nats()`.
## Benefits
1. **Zero Downtime**: Gradual migration with fallback
2. **No Circular Dependency**: Uses function pointers instead of crate dependencies
3. **Backward Compatible**: Existing code works without changes
4. **Type Safe**: Compile-time guarantees for integration
5. **Observable**: Consistent logging for both backends
## Known Limitations
### NATS Acknowledgment Timing
The current implementation acks NATS messages immediately after deserialization, not after successful processing. This is different from the Redis pattern where:
- Redis: Task moves to work queue → processes → acks (removes from work queue)
- NATS: Task received → acks immediately → processes
**Future Enhancement**: Store ack functions in a map keyed by task ID, then call them after successful processing. This requires refactoring the worker loop to track pending acks.
### Work Queue Pattern
NATS JetStream doesn't have a direct equivalent to Redis's work queue pattern. The current implementation relies on JetStream's built-in redelivery mechanism instead of a separate work queue.
## Next Steps
1. Add NATS integration to `apps/git-hook/src/main.rs`
2. Add configuration flags for queue backend selection
3. Test dual-write mode in staging
4. Monitor NATS consumer stability
5. Implement proper ack-after-processing pattern
6. Add metrics for queue depth and processing latency

View File

@ -1,107 +0,0 @@
# 构建脚本
## 一键构建脚本
### build.js - 构建镜像
```bash
# 构建所有镜像
node scripts/build.js
# 构建指定服务
node scripts/build.js app gitserver
# 指定 tag
TAG=v1.0.0 node scripts/build.js
# 指定架构
TARGET=aarch64-unknown-linux-gnu node scripts/build.js
```
**环境变量:**
| 变量 | 默认值 | 说明 |
|------------|------------------------------|-------------|
| `REGISTRY` | `harbor.gitdata.me/gta_team` | 镜像仓库 |
| `TAG` | `latest` | 镜像标签 |
| `TARGET` | `x86_64-unknown-linux-gnu` | Rust 交叉编译目标 |
---
### push.js - 推送镜像
```bash
# 推送所有镜像
HARBOR_USERNAME=user HARBOR_PASSWORD=pass node scripts/push.js
# 推送指定服务
HARBOR_USERNAME=user HARBOR_PASSWORD=pass TAG=sha-abc123 node scripts/push.js app
```
**环境变量:**
| 变量 | 默认值 | 说明 |
|-------------------|------------------------------|--------------|
| `REGISTRY` | `harbor.gitdata.me/gta_team` | 镜像仓库 |
| `TAG` | `latest` 或 Git SHA | 镜像标签 |
| `HARBOR_USERNAME` | - | **必填** 仓库用户名 |
| `HARBOR_PASSWORD` | - | **必填** 仓库密码 |
---
### deploy.js - 部署到 Kubernetes
```bash
# 部署最新镜像
node scripts/deploy.js
# 干跑模式(不实际部署)
node scripts/deploy.js --dry-run
# 部署并运行数据库迁移
node scripts/deploy.js --migrate
# 指定 tag
TAG=sha-abc123 node scripts/deploy.js
# 指定命名空间
NAMESPACE=staging node scripts/deploy.js
```
**环境变量:**
| 变量 | 默认值 | 说明 |
|--------------|------------------------------|-----------------|
| `REGISTRY` | `harbor.gitdata.me/gta_team` | 镜像仓库 |
| `TAG` | `latest` 或 Git SHA | 镜像标签 |
| `NAMESPACE` | `gitdata` | K8s 命名空间 |
| `RELEASE` | `gitdata` | Helm Release 名称 |
| `KUBECONFIG` | `~/.kube/config` | Kubeconfig 路径 |
---
## 完整 CI/CD 流程
```bash
# 1. 构建
node scripts/build.js
# 2. 推送
HARBOR_USERNAME=user HARBOR_PASSWORD=pass node scripts/push.js
# 3. 部署
node scripts/deploy.js --migrate
```
## 本地开发
```bash
# 本地构建测试
node scripts/build.js app
# 使用本地 tag
TAG=dev node scripts/build.js
# 部署到测试环境
NAMESPACE=test node scripts/deploy.js
```

View File

@ -1,28 +0,0 @@
/**
* Fix ugly module-path tags in openapi.json generated by utoipa.
* Transforms tags like "crate::agent::provider" -> "agent-provider"
*/
const fs = require('fs');
const path = require('path');
const openapiPath = path.join(__dirname, '..', 'openapi.json');
const json = JSON.parse(fs.readFileSync(openapiPath, 'utf8'));
let fixed = 0;
// Fix operation tags
for (const p in json.paths) {
for (const m in json.paths[p]) {
const op = json.paths[p][m];
if (op.tags) {
op.tags = op.tags.map(t => {
const fixed = t.replace(/^crate::/, '').replace(/::/g, '-');
return fixed;
});
fixed++;
}
}
}
fs.writeFileSync(openapiPath, JSON.stringify(json, null, 2));
console.log(`Fixed tags in ${fixed} operations`);

View File

@ -1,55 +0,0 @@
/**
* Generate TypeScript axios client from openapi.json using @hey-api/openapi-ts.
* Generates into src/client.
* Post-processes: injects withCredentials: true and baseURL into the client config.
*/
const { execSync } = require('child_process');
const fs = require('fs');
const path = require('path');
const ROOT = path.join(__dirname, '..');
const CLIENT_DIR = path.join(ROOT, 'src', 'client');
const CLIENT_GEN = path.join(CLIENT_DIR, 'client.gen.ts');
const openapiTsBin = path.join(ROOT, 'node_modules/@hey-api/openapi-ts/bin/run.js');
const openapiJson = path.join(ROOT, 'openapi.json');
console.log('Running @hey-api/openapi-ts...');
try {
execSync(`node "${openapiTsBin}" -c @hey-api/client-axios -i "${openapiJson}" -o "${CLIENT_DIR}"`, {
cwd: ROOT,
stdio: 'inherit',
});
} catch (e) {
console.error('Generator exited with code:', e.status);
process.exit(1);
}
// Post-process: inject withCredentials and baseURL into client config
if (fs.existsSync(CLIENT_GEN)) {
let content = fs.readFileSync(CLIENT_GEN, 'utf8');
// Remove unused createConfig import
content = content.replace(
"import { type ClientOptions, type Config, createClient, createConfig } from './client';",
"import { type ClientOptions, type Config, createClient } from './client';"
);
// Replace the client creation to include withCredentials and baseURL
content = content.replace(
'export const client = createClient(createConfig<ClientOptions2>());',
`export const createClientConfig = (override?: Config<ClientOptions2>): Config<ClientOptions2> => {
return {
withCredentials: true,
baseURL: import.meta.env.VITE_API_BASE_URL ?? '',
...override,
};
};
export const client = createClient(createClientConfig());`
);
fs.writeFileSync(CLIENT_GEN, content);
console.log('Updated client.gen.ts with withCredentials and baseURL');
}
console.log('Done.');

View File

@ -1,89 +0,0 @@
/**
* Generates changelog data file for the frontend.
* Run with: node scripts/generate-changelog-data.js
*/
const fs = require('fs');
const path = require('path');
const CHANGELOG_DIR = path.join(__dirname, '..', 'changelog');
const OUTPUT_FILE = path.join(__dirname, '..', 'src', 'data', 'changelog-data.ts');
const LANGUAGES = ['en', 'cn', 'de', 'fr'];
function readFile(filePath) {
try {
return fs.readFileSync(filePath, 'utf-8');
} catch {
return null;
}
}
function parseMdx(content) {
const frontmatterMatch = content.match(/^---\n([\s\S]*?)\n---\n([\s\S]*)$/);
if (!frontmatterMatch) {
return { title: '', body: content };
}
const body = frontmatterMatch[2].trim();
const frontmatter = frontmatterMatch[1];
const titleMatch = frontmatter.match(/title:\s*["']?([^"'\n]+)["']?/);
const title = titleMatch ? titleMatch[1].trim() : '';
return { title, body };
}
// Get all unique dates
const dates = [];
const files = fs.readdirSync(CHANGELOG_DIR);
files.forEach(file => {
const match = file.match(/^(\d{4}-\d{2}-\d{2})-(\w+)\.mdx$/);
if (match) {
const date = match[1];
const lang = match[2];
if (!dates.includes(date)) {
dates.push(date);
}
}
});
// Sort dates descending
dates.sort((a, b) => new Date(b) - new Date(a));
// Generate data for each language
const data = {};
LANGUAGES.forEach(lang => {
data[lang] = dates.map(date => {
const filePath = path.join(CHANGELOG_DIR, `${date}-${lang}.mdx`);
const content = readFile(filePath);
if (!content) {
return null;
}
const { title, body } = parseMdx(content);
return {
date,
title,
lang,
author: 'ZhenYi',
body,
};
}).filter(Boolean);
});
// Generate TypeScript file
const tsContent = `// Auto-generated from changelog/*.mdx files
// Run: node scripts/generate-changelog-data.js
export type ChangelogEntry = {
date: string;
title: string;
lang: string;
author: string;
body: string;
};
export const CHANGELOG_DATA: Record<string, ChangelogEntry[]> = ${JSON.stringify(data, null, 2)};
export const CHANGELOG_LANGUAGES = ${JSON.stringify(LANGUAGES)};
`;
fs.writeFileSync(OUTPUT_FILE, tsContent);
console.log(`Generated ${OUTPUT_FILE}`);