From f7e087e0661a025b5ad4b71b2e2037ee78bc15f0 Mon Sep 17 00:00:00 2001 From: ZhenYi <434836402@qq.com> Date: Sat, 25 Apr 2026 09:53:31 +0800 Subject: [PATCH] fix(agent/service): retry jitter, tool executor ordering, curl SSRF, grep/JSON - agent/client: full jitter backoff (random(0, base_ms)) instead of equal jitter - agent/tool/executor: fix buffer_unordered ordering mismatch with HashMap-by-index approach for concurrent tool execution - agent/chat: AiChunkType emit fixes, is_retryable_tool_error refinements, process_react uses request.max_tool_depth - agent/chat/context: fix Function message sender_name field - file_tools/curl: shared reqwest::Client via OnceLock, manual redirect following with per-hop SSRF validation, blocked sensitive headers - file_tools/grep: fix case-insensitive glob matching, segment consumption - file_tools/json: bracket notation support, remove .vscodeignore from JSONC - git_tools: git_diff_stats resolve base/head independently, DiffFileOut old_file.path for Deleted, reflog offset_minutes - git/repo: create_commit read parent tree into index, bare repo init - project_tools/repos: branch/path validation, .git/ prefix check - service/agent: tokent integration, billing, pr_summary, code_review fixes --- libs/agent/chat/context.rs | 4 +- libs/agent/chat/mod.rs | 25 ++- libs/agent/chat/service.rs | 135 ++++++++++--- libs/agent/client.rs | 22 +- libs/agent/embed/client.rs | 6 +- libs/agent/embed/service.rs | 4 +- libs/agent/react/loop_core.rs | 39 +++- libs/agent/task/service.rs | 17 +- libs/agent/tokent.rs | 2 + libs/agent/tool/call.rs | 4 +- libs/agent/tool/executor.rs | 47 +++-- libs/service/agent/code_review.rs | 2 +- libs/service/agent/issue_triage.rs | 232 +++++++++++++++++++++ libs/service/agent/mod.rs | 1 + libs/service/agent/pr_summary.rs | 2 +- libs/service/file_tools/grep.rs | 32 +-- libs/service/file_tools/json.rs | 97 ++++++--- libs/service/git/repo.rs | 4 + libs/service/git_tools/branch.rs | 10 +- libs/service/git_tools/commit.rs | 12 +- libs/service/git_tools/diff.rs | 27 ++- libs/service/git_tools/tree.rs | 8 +- libs/service/git_tools/types.rs | 15 +- libs/service/issue/issue.rs | 11 +- libs/service/issue/label.rs | 117 +++++++++++ libs/service/issue/mod.rs | 2 +- libs/service/project/repo.rs | 2 + libs/service/project_tools/arxiv.rs | 2 +- libs/service/project_tools/boards.rs | 2 + libs/service/project_tools/curl.rs | 290 +++++++++++++++++++-------- libs/service/project_tools/issues.rs | 28 ++- libs/service/project_tools/repos.rs | 78 ++++++- libs/service/search/service.rs | 187 +++++++++++++++++ 33 files changed, 1220 insertions(+), 246 deletions(-) create mode 100644 libs/service/agent/issue_triage.rs diff --git a/libs/agent/chat/context.rs b/libs/agent/chat/context.rs index 438d632..9d55da6 100644 --- a/libs/agent/chat/context.rs +++ b/libs/agent/chat/context.rs @@ -40,7 +40,7 @@ impl AiContextSenderType { models::rooms::MessageSenderType::Owner => Self::User, models::rooms::MessageSenderType::Ai => Self::Ai, models::rooms::MessageSenderType::System => Self::System, - models::rooms::MessageSenderType::Tool => Self::Function, + models::rooms::MessageSenderType::Tool => Self::FunctionResult, models::rooms::MessageSenderType::Guest => Self::User, } } @@ -135,7 +135,7 @@ impl RoomMessageContext { AiContextSenderType::Function => { ChatCompletionRequestMessage::Function(ChatCompletionRequestFunctionMessage { content: Some(self.content.clone()), - name: self.display_content(), // Function name is stored in content + name: self.sender_name.clone().unwrap_or_else(|| "unknown".to_string()), }) } AiContextSenderType::FunctionResult => { diff --git a/libs/agent/chat/mod.rs b/libs/agent/chat/mod.rs index 362aa8d..d2848ad 100644 --- a/libs/agent/chat/mod.rs +++ b/libs/agent/chat/mod.rs @@ -13,13 +13,36 @@ use std::collections::HashMap; use uuid::Uuid; /// Maximum recursion rounds for tool-call loops (AI → tool → result → AI). -pub const DEFAULT_MAX_TOOL_DEPTH: usize = 3; +/// Previous default of 3 caused frequent silent termination on realistic multi-step queries. +pub const DEFAULT_MAX_TOOL_DEPTH: usize = 99; /// A single chunk from an AI streaming response. #[derive(Debug, Clone)] pub struct AiStreamChunk { pub content: String, pub done: bool, + /// What kind of content this chunk contains — helps the frontend render + /// thinking, tool calls, and results with different styles. + pub chunk_type: AiChunkType, +} + +/// Type of streaming chunk, used by the frontend for rendering. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AiChunkType { + /// AI reasoning/thinking text before a tool call or answer. + Thinking, + /// Final answer text from the AI. + Answer, + /// A tool call is being executed (content = tool name + args summary). + ToolCall, + /// Tool execution result (content = result or error). + ToolResult, +} + +impl Default for AiChunkType { + fn default() -> Self { + Self::Answer + } } /// Optional streaming callback: called for each token chunk. diff --git a/libs/agent/chat/service.rs b/libs/agent/chat/service.rs index 272699e..42509d7 100644 --- a/libs/agent/chat/service.rs +++ b/libs/agent/chat/service.rs @@ -7,7 +7,7 @@ use async_openai::types::chat::{ ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessage, - ChatCompletionRequestUserMessageContent, ChatCompletionTool, + ChatCompletionTool, ChatCompletionTools, CreateChatCompletionRequest, CreateChatCompletionResponse, CreateChatCompletionStreamResponse, FinishReason, ReasoningEffort, ToolChoiceOptions, }; @@ -18,7 +18,7 @@ use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; use uuid::Uuid; use super::context::RoomMessageContext; -use super::{AiRequest, AiStreamChunk, Mention, StreamCallback}; +use super::{AiChunkType, AiRequest, AiStreamChunk, Mention, StreamCallback}; use crate::client::AiClientConfig; use crate::compact::{CompactConfig, CompactService}; use crate::embed::EmbedService; @@ -195,29 +195,38 @@ impl ChatService { .collect(); if !calls.is_empty() { + let calls_for_error = calls.clone(); let tool_messages = match self.execute_tool_calls(calls, &request).await { Ok(msgs) => msgs, Err(e) => { - // Surface the error as a tool result so the model can continue - let err_text = format!("[Tool call failed: {}]", e); - messages.push(ChatCompletionRequestMessage::User( - ChatCompletionRequestUserMessage { - content: ChatCompletionRequestUserMessageContent::Text(err_text.clone()), - name: None, - }, - )); - tool_depth += 1; - if tool_depth >= max_tool_depth { - return Ok(err_text); - } - continue; + // Surface the error as per-call Tool messages (with matching IDs) + // so the API contract (Tool after Assistant+tool_calls) is preserved. + calls_for_error.iter().map(|c| { + ChatCompletionRequestMessage::Tool( + ChatCompletionRequestToolMessage { + tool_call_id: c.id.clone(), + content: ChatCompletionRequestToolMessageContent::Text( + format!("[Tool call failed: {}]", e), + ), + }, + ) + }).collect() } }; messages.extend(tool_messages); tool_depth += 1; if tool_depth >= max_tool_depth { - return Ok(String::new()); + // Return accumulated content rather than empty string so the user + // sees whatever the AI has produced so far. + let text = choice.message.content.unwrap_or_default(); + if text.is_empty() { + return Ok(format!( + "[AI reached maximum tool depth ({}) — no final answer produced]", + max_tool_depth + )); + } + return Ok(text); } continue; } @@ -320,6 +329,7 @@ impl ChatService { on_chunk(AiStreamChunk { content: text_accumulated.clone(), done: false, + chunk_type: AiChunkType::Answer, }) .await; } @@ -364,12 +374,13 @@ impl ChatService { .collect(); if !tool_calls.is_empty() { - // Capture thinking text, send it as a completed chunk, then clear for the next turn + // Capture thinking text, send it as a non-final chunk, then clear for the next turn let thinking_text = text_accumulated.clone(); if !thinking_text.is_empty() { on_chunk(AiStreamChunk { content: thinking_text.clone(), - done: true, + done: false, + chunk_type: AiChunkType::Thinking, }) .await; } @@ -406,29 +417,88 @@ impl ChatService { }, )); + let calls_for_error = tool_calls.clone(); + + // Notify frontend which tools are being called + let call_summary: Vec = calls_for_error.iter().map(|c| { + let args_preview: String = { + let args_json: serde_json::Value = serde_json::from_str(&c.arguments) + .unwrap_or(serde_json::Value::Null); + // Show truncated args for readability + let s = serde_json::to_string(&args_json).unwrap_or_default(); + if s.len() > 200 { s[..200].to_string() + "..." } else { s } + }; + format!("{}({})", c.name, args_preview) + }).collect(); + on_chunk(AiStreamChunk { + content: format!("[Calling tools: {}]", call_summary.join(", ")), + done: false, + chunk_type: AiChunkType::ToolCall, + }) + .await; + let tool_messages = match self.execute_tool_calls(tool_calls, &request).await { - Ok(msgs) => msgs, + Ok(msgs) => { + // Stream tool results to frontend so user can see what happened + let result_summary: Vec = msgs.iter().map(|m| { + if let ChatCompletionRequestMessage::Tool(tm) = m { + match &tm.content { + ChatCompletionRequestToolMessageContent::Text(t) => { + if t.len() > 300 { t[..300].to_string() + "..." } else { t.clone() } + } + _ => "[binary content]".to_string(), + } + } else { "unknown".to_string() } + }).collect(); + on_chunk(AiStreamChunk { + content: format!("[Tool results: {}]", result_summary.join("; ")), + done: false, + chunk_type: AiChunkType::ToolResult, + }) + .await; + msgs + } Err(e) => { - // Stream the FC error as an observation so the user sees it + // Stream the FC error as a non-final observation so the user sees it, + // but do NOT mark done=true — the AI will continue after seeing the error. let err_text = format!("[Tool call failed: {}]", e); on_chunk(AiStreamChunk { content: err_text.clone(), - done: true, + done: false, + chunk_type: AiChunkType::ToolResult, }) .await; - // Return an empty tool result so the loop can continue - vec![ChatCompletionRequestMessage::Tool( - ChatCompletionRequestToolMessage { - tool_call_id: String::new(), - content: ChatCompletionRequestToolMessageContent::Text(err_text), - }, - )] + // Return per-call Tool messages with matching IDs to preserve API contract + calls_for_error.iter().map(|c| { + ChatCompletionRequestMessage::Tool( + ChatCompletionRequestToolMessage { + tool_call_id: c.id.clone(), + content: ChatCompletionRequestToolMessageContent::Text(err_text.clone()), + }, + ) + }).collect() } }; messages.extend(tool_messages); tool_depth += 1; if tool_depth >= max_tool_depth { + // Emit a final done chunk with whatever content we have so the + // client receives a completion signal instead of hanging forever. + let final_content = if text_accumulated.is_empty() { + format!( + "[AI reached maximum tool depth ({}) — no final answer produced]", + max_tool_depth + ) + } else { + text_accumulated.clone() + }; + on_chunk(AiStreamChunk { + content: final_content, + done: true, + chunk_type: AiChunkType::Answer, + }) + .await; return Ok(()); } continue; @@ -438,6 +508,7 @@ impl ChatService { on_chunk(AiStreamChunk { content: text_accumulated, done: true, + chunk_type: AiChunkType::Answer, }) .await; return Ok(()); @@ -748,7 +819,8 @@ impl ChatService { /// Returns true if the error message indicates a transient failure that can be retried. fn is_retryable_tool_error(msg: &str) -> bool { let msg_lower = msg.to_lowercase(); - // Transient errors: network, timeouts, rate limits, permission issues that may be temporary + // Transient errors: network, timeouts, rate limits + // Permission/access errors are NOT retryable — they won't succeed on retry. msg_lower.contains("connection") || msg_lower.contains("timeout") || msg_lower.contains("timed out") @@ -762,9 +834,6 @@ impl ChatService { || msg_lower.contains("broken pipe") || msg_lower.contains("deadline exceeded") || msg_lower.contains("try again") - || msg_lower.contains("not found") // DB/Redis transient not-found - || msg_lower.contains("permission denied") - || msg_lower.contains("access denied") } /// Process a request using the ReAct (Reasoning + Acting) agent. @@ -885,7 +954,7 @@ impl ChatService { let tools = self.tools(); let config = ReactConfig { - max_steps: 20, + max_steps: request.max_tool_depth, stop_sequences: Vec::new(), tool_executor: Some(executor), }; diff --git a/libs/agent/client.rs b/libs/agent/client.rs index 1483f5a..ea17369 100644 --- a/libs/agent/client.rs +++ b/libs/agent/client.rs @@ -113,16 +113,17 @@ impl RetryState { self.attempt < self.max_retries } - /// Calculate backoff duration with "full jitter" technique. + /// Calculate backoff duration with full jitter technique. + /// sleep = random(0, min(cap, base * 2^attempt)) fn backoff_duration(&self) -> std::time::Duration { let exp = self.attempt.min(5); // base = 500 * 2^exp, capped at max_backoff_ms let base_ms = 500u64 .saturating_mul(2u64.pow(exp)) .min(self.max_backoff_ms); - // jitter: random [0, base_ms/2] - let jitter = (fastrand_u64(base_ms / 2 + 1)) as u64; - std::time::Duration::from_millis(base_ms / 2 + jitter) + // Full jitter: random value in [0, base_ms] + let jitter = fastrand_u64(base_ms + 1) as u64; + std::time::Duration::from_millis(jitter) } fn next(&mut self) { @@ -239,7 +240,11 @@ pub async fn call_with_retry( } } -/// Call with custom parameters (temperature, max_tokens, optional tools). +/// Call with custom parameters (temperature, max_tokens, optional tools, optional tool_choice). +/// +/// When `tool_choice` is `None` and tools are present, the default is `Auto`. +/// Pass `Some(ChatCompletionToolChoiceOption::None)` to force the model to respond +/// with text only (e.g. when you want JSON-in-text for ReAct parsing). pub async fn call_with_params( messages: &[ChatCompletionRequestMessage], model: &str, @@ -248,6 +253,7 @@ pub async fn call_with_params( max_tokens: u32, max_retries: Option, tools: Option<&[ChatCompletionTool]>, + tool_choice: Option, ) -> Result { let client = config.build_client(); let mut state = RetryState::new(max_retries.unwrap_or(3)); @@ -265,11 +271,7 @@ pub async fn call_with_params( .map(|t| ChatCompletionTools::Function(t.clone())) .collect() }), - tool_choice: tools.filter(|ts| !ts.is_empty()).map(|_| { - ChatCompletionToolChoiceOption::Mode( - async_openai::types::chat::ToolChoiceOptions::Auto, - ) - }), + tool_choice: tool_choice.clone(), ..Default::default() }; diff --git a/libs/agent/embed/client.rs b/libs/agent/embed/client.rs index d0b3358..3067bf2 100644 --- a/libs/agent/embed/client.rs +++ b/libs/agent/embed/client.rs @@ -137,8 +137,9 @@ impl EmbedClient { text: &str, room_id: &str, user_id: Option<&str>, + model: &str, ) -> crate::Result<()> { - let vector = self.embed_text(text, "").await?; + let vector = self.embed_text(text, model).await?; let point = EmbedVector { id: id.to_string(), vector, @@ -176,9 +177,10 @@ impl EmbedClient { description: &str, content: &str, project_uuid: &str, + model: &str, ) -> crate::Result<()> { let text = format!("{}: {} {}", name, description, content); - let vector = self.embed_text(&text, "").await?; + let vector = self.embed_text(&text, model).await?; let point = EmbedVector { id: id.to_string(), vector, diff --git a/libs/agent/embed/service.rs b/libs/agent/embed/service.rs index 971f540..77c9f5f 100644 --- a/libs/agent/embed/service.rs +++ b/libs/agent/embed/service.rs @@ -188,7 +188,7 @@ impl EmbedService { let desc = description.unwrap_or_default(); let id = skill_id.to_string(); self.client - .embed_skill(&id, name, desc, content, project_uuid) + .embed_skill(&id, name, desc, content, project_uuid, &self.model_name) .await } @@ -214,7 +214,7 @@ impl EmbedService { ) -> crate::Result<()> { let id = message_id.to_string(); self.client - .embed_memory(&id, text, room_id, user_id) + .embed_memory(&id, text, room_id, user_id, &self.model_name) .await } diff --git a/libs/agent/react/loop_core.rs b/libs/agent/react/loop_core.rs index d11afff..27b7366 100644 --- a/libs/agent/react/loop_core.rs +++ b/libs/agent/react/loop_core.rs @@ -1,13 +1,14 @@ //! ReAct (Reasoning + Acting) agent core. -use async_openai::types::chat::FunctionCall; use async_openai::types::chat::{ ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessage, - ChatCompletionRequestUserMessageContent, + ChatCompletionRequestUserMessageContent, ToolChoiceOptions, }; +use async_openai::types::chat::ChatCompletionToolChoiceOption; +use async_openai::types::chat::FunctionCall; use uuid::Uuid; use std::sync::Arc; @@ -37,9 +38,11 @@ impl ReactAgent { tools: Vec, config: ReactConfig, ) -> Self { - let messages = vec![ChatCompletionRequestMessage::User( - ChatCompletionRequestUserMessage { - content: ChatCompletionRequestUserMessageContent::Text(system_prompt.to_string()), + let messages = vec![ChatCompletionRequestMessage::System( + async_openai::types::chat::ChatCompletionRequestSystemMessage { + content: async_openai::types::chat::ChatCompletionRequestSystemMessageContent::Text( + system_prompt.to_string(), + ), ..Default::default() }, )]; @@ -109,15 +112,30 @@ impl ReactAgent { { loop { if self.step_count >= self.config.max_steps { - return Err(AgentError::Internal(format!( - "ReAct agent reached max steps ({})", + // Emit a final Answer chunk so the caller receives a completion signal + // rather than a bare Err with no on_chunk notification. + let msg = format!( + "Agent reached maximum reasoning steps ({}) without producing a final answer.", self.config.max_steps - ))); + ); + on_chunk(ReactStep::Answer { + step: self.step_count, + answer: msg.clone(), + }); + return Ok(msg); } self.step_count += 1; let step = self.step_count; + let tool_choice = if self.tool_definitions.is_empty() { + None + } else { + // Force text-only response so the model follows our JSON-in-text format. + // With tool_choice=Auto the model might return native tool_calls which + // the ReAct parser ignores. + Some(ChatCompletionToolChoiceOption::Mode(ToolChoiceOptions::None)) + }; let response = call_with_params( &self.messages, model_name, @@ -130,6 +148,7 @@ impl ReactAgent { } else { Some(self.tool_definitions.as_slice()) }, + tool_choice, ) .await?; @@ -234,6 +253,10 @@ impl ReactAgent { _ => {} } + // Append assistant message with tool_calls so the Tool message has a matching parent. + let assistant_msg = build_tool_call_message(&act); + self.messages.push(assistant_msg); + // Append observation as a tool message so the model sees it in context. self.messages.push(ChatCompletionRequestMessage::Tool( ChatCompletionRequestToolMessage { diff --git a/libs/agent/task/service.rs b/libs/agent/task/service.rs index e19aa61..071c5fc 100644 --- a/libs/agent/task/service.rs +++ b/libs/agent/task/service.rs @@ -477,8 +477,9 @@ impl TaskService { /// Propagate child task status up the tree. /// /// When a child task reaches a terminal state, checks whether all its - /// siblings are also terminal. If so, marks the parent as failed so that - /// a stuck parent is never left in the `Running` state. + /// siblings are also terminal. If so, marks the parent appropriately: + /// - Done if any child succeeded + /// - Failed if all children failed or were cancelled pub async fn propagate_to_parent(&self, task_id: i64) -> Result, DbErr> { let model = self .get(task_id) @@ -496,9 +497,15 @@ impl TaskService { })?; if parent.is_running() { let mut active: ActiveModel = parent.into(); - active.status = sea_orm::Set(TaskStatus::Failed); - active.error = - sea_orm::Set(Some("All sub-tasks failed or were cancelled".to_string())); + let has_success = siblings.iter().any(|s| s.status == TaskStatus::Done); + if has_success { + active.status = sea_orm::Set(TaskStatus::Done); + active.error = sea_orm::Set(None); + } else { + active.status = sea_orm::Set(TaskStatus::Failed); + active.error = + sea_orm::Set(Some("All sub-tasks failed or were cancelled".to_string())); + } active.done_at = sea_orm::Set(Some(chrono::Utc::now().into())); active.updated_at = sea_orm::Set(chrono::Utc::now().into()); let updated = active.update(&self.db).await?; diff --git a/libs/agent/tokent.rs b/libs/agent/tokent.rs index e49d493..cd2c7e0 100644 --- a/libs/agent/tokent.rs +++ b/libs/agent/tokent.rs @@ -125,6 +125,8 @@ pub fn truncate_to_token_budget( while low + 100 < high { let mid = (low + high) / 2; + // Find the nearest valid char boundary to avoid panicking on multi-byte UTF-8 + let mid = text.floor_char_boundary(mid); let candidate = &text[..mid]; let tokens = bpe.encode_ordinary(candidate); diff --git a/libs/agent/tool/call.rs b/libs/agent/tool/call.rs index 0aa38b0..4588c4a 100644 --- a/libs/agent/tool/call.rs +++ b/libs/agent/tool/call.rs @@ -22,11 +22,13 @@ impl ToolCall { /// The result of executing a tool call. #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] +#[serde(tag = "status", content = "value")] pub enum ToolResult { /// Successful result with a JSON value. + #[serde(rename = "ok")] Ok(serde_json::Value), /// Error result with an error message. + #[serde(rename = "error")] Error(String), } diff --git a/libs/agent/tool/executor.rs b/libs/agent/tool/executor.rs index b372e93..2661bef 100644 --- a/libs/agent/tool/executor.rs +++ b/libs/agent/tool/executor.rs @@ -70,34 +70,33 @@ impl ToolExecutor { ctx.increment_tool_calls(); let concurrency = self.max_concurrency; - use tokio::sync::Mutex as AsyncMutex; - let results: AsyncMutex> = - AsyncMutex::new(Vec::with_capacity(calls.len())); + let calls_clone: Vec = calls.clone(); - stream::iter(calls.into_iter().map(|call| { - let child_ctx = ctx.child_context(); - async move { self.execute_one(call, child_ctx).await } - })) - .buffer_unordered(concurrency) - .for_each_concurrent( - concurrency, - |result: Result| async { - let r = result.unwrap_or_else(|e| { - ToolCallResult::error( - ToolCall { - id: String::new(), - name: String::new(), - arguments: String::new(), - }, - e.to_string(), - ) - }); - results.lock().await.push(r); - }, + // Execute tool calls concurrently but preserve input order for ID matching. + // buffer_unordered returns results in *completion* order, which mispairs IDs + // on concurrent errors. Instead, track each result with its original index. + let indexed_results: Vec<(usize, Result)> = stream::iter( + calls.into_iter().enumerate().map(|(i, call)| { + let child_ctx = ctx.child_context(); + async move { (i, self.execute_one(call, child_ctx).await) } + }) ) + .buffer_unordered(concurrency) + .collect() .await; - Ok(results.into_inner()) + // Re-sort by original index to restore input order, then pair with original calls. + let mut result_map: std::collections::HashMap> = + indexed_results.into_iter().collect(); + + let results: Vec = calls_clone.into_iter().enumerate().map(|(i, call)| { + let r = result_map.remove(&i).expect("every index must have a result"); + r.unwrap_or_else(|e: ToolError| { + ToolCallResult::error(call, e.to_string()) + }) + }).collect(); + + Ok(results) } async fn execute_one( diff --git a/libs/service/agent/code_review.rs b/libs/service/agent/code_review.rs index 286ffca..0342536 100644 --- a/libs/service/agent/code_review.rs +++ b/libs/service/agent/code_review.rs @@ -412,7 +412,7 @@ async fn call_ai_model( ), ]; - agent::call_with_params(&messages, model_name, &client_config, 0.2, 8192, None, None) + agent::call_with_params(&messages, model_name, &client_config, 0.2, 8192, None, None, None) .await .map_err(|e| AppError::InternalServerError(format!("AI call failed: {}", e))) } diff --git a/libs/service/agent/issue_triage.rs b/libs/service/agent/issue_triage.rs new file mode 100644 index 0000000..2c58f54 --- /dev/null +++ b/libs/service/agent/issue_triage.rs @@ -0,0 +1,232 @@ +//! AI-powered issue triage service. +//! +//! Analyzes newly created issues and suggests labels and priority. + +use crate::AppService; +use crate::error::AppError; +use chrono::Utc; +use config::AppConfig; +use models::agents::ModelStatus; +use models::agents::model::{Column as MColumn, Entity as MEntity}; +use models::issues::{issue, issue_comment}; +use sea_orm::*; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct IssueTriageSuggestion { + pub suggested_labels: Vec, + pub priority: String, + pub reasoning: String, +} + +#[derive(Debug, Clone, Serialize, ToSchema)] +pub struct IssueTriageResponse { + pub suggestions: Option, + pub comment_posted: bool, +} + +fn build_triage_prompt(title: &str, body: Option<&str>, existing_labels: &[String]) -> String { + let body_text = body.unwrap_or("(no description)"); + let labels_text = if existing_labels.is_empty() { + "none".to_string() + } else { + existing_labels.join(", ") + }; + + format!( + r#"You are an expert software project manager. Analyze the following GitHub issue and suggest how to triage it. + +Issue Title: {} +Issue Body: +{} +Existing Labels: {} + +Based on the issue, suggest: +1. Additional labels from this standard set: bug, enhancement, documentation, question, help wanted, good first issue, priority:high, priority:medium, priority:low, kind:backend, kind:frontend, kind:dx, kind:security, kind:performance +2. A priority level: high, medium, or low +3. A brief reasoning for your assessment + +Respond in JSON format like: +{{ + "suggested_labels": ["bug", "priority:high"], + "priority": "high", + "reasoning": "This is a critical security vulnerability in the auth module..." +}} + +Only suggest labels not already in the existing list. Be concise."#, + title, body_text, labels_text + ) +} + +fn parse_triage_response(content: &str) -> Option { + let content = content.trim(); + let json_str = if content.starts_with("```json") { + content + .strip_prefix("```json")? + .strip_prefix('\n') + .unwrap_or(content) + .trim_end_matches("```") + .trim() + } else if content.starts_with("```") { + content + .strip_prefix("```")? + .strip_prefix('\n') + .unwrap_or(content) + .trim_end_matches("```") + .trim() + } else { + content + }; + + let parsed: serde_json::Value = serde_json::from_str(json_str).ok()?; + Some(IssueTriageSuggestion { + suggested_labels: parsed + .get("suggested_labels")? + .as_array()? + .iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect(), + priority: parsed.get("priority")?.as_str()?.to_string(), + reasoning: parsed.get("reasoning")?.as_str()?.to_string(), + }) +} + +async fn call_ai_for_triage( + model_name: &str, + prompt: &str, + app_config: &AppConfig, +) -> Result { + let api_key = app_config + .ai_api_key() + .map_err(|e| AppError::InternalServerError(format!("AI API key not configured: {}", e)))?; + + let base_url = app_config + .ai_basic_url() + .unwrap_or_else(|_| "https://api.openai.com".into()); + + let client_config = + ::agent::AiClientConfig::new(api_key).with_base_url(base_url); + + let messages = vec![async_openai::types::chat::ChatCompletionRequestMessage::User( + async_openai::types::chat::ChatCompletionRequestUserMessage { + content: + async_openai::types::chat::ChatCompletionRequestUserMessageContent::Text( + prompt.to_string(), + ), + ..Default::default() + }, + )]; + + let response = ::agent::call_with_params( + &messages, + model_name, + &client_config, + 0.3, + 1024, + None, + None, + None, + ) + .await + .map_err(|e| { + AppError::InternalServerError(format!("AI triage call failed: {}", e)) + })?; + + Ok(response.content) +} + +impl AppService { + /// Run AI triage on a newly created issue and post a suggestion comment. + /// Called asynchronously after issue creation. + pub async fn triage_issue( + &self, + project_name: String, + issue_number: i64, + ) -> Result { + let project = self.utils_find_project_by_name(project_name.clone()).await?; + + let issue_model = issue::Entity::find() + .filter(issue::Column::Project.eq(project.id)) + .filter(issue::Column::Number.eq(issue_number)) + .one(&self.db) + .await? + .ok_or_else(|| AppError::NotFound("Issue not found".to_string()))?; + + let existing_labels: Vec = Vec::new(); + + let model = match MEntity::find() + .filter(MColumn::Status.eq(ModelStatus::Active.to_string())) + .order_by_asc(MColumn::Name) + .one(&self.db) + .await? + { + Some(m) => m, + None => { + tracing::debug!( + project = %project_name, + issue = issue_number, + "No active AI model for triage — skipping" + ); + return Ok(IssueTriageResponse { + suggestions: None, + comment_posted: false, + }); + } + }; + + let prompt = + build_triage_prompt(&issue_model.title, issue_model.body.as_deref(), &existing_labels); + let ai_content = match call_ai_for_triage(&model.name, &prompt, &self.config).await { + Ok(c) => c, + Err(e) => { + tracing::warn!( + project = %project_name, + issue = issue_number, + error = ?e, + "AI triage failed" + ); + return Ok(IssueTriageResponse { + suggestions: None, + comment_posted: false, + }); + } + }; + + let suggestions = parse_triage_response(&ai_content); + let mut comment_posted = false; + + if let Some(ref s) = suggestions { + let comment_body = format!( + "## AI Triage Suggestions\n\n**Priority:** *{}*\n\n{}\n\n**Suggested Labels:** \ + {}\n\n_This analysis was generated automatically by the AI collaborator._", + s.priority.to_uppercase(), + s.reasoning, + if s.suggested_labels.is_empty() { + "none".to_string() + } else { + s.suggested_labels.join(", ") + } + ); + + let now = Utc::now(); + let active = issue_comment::ActiveModel { + issue: Set(issue_model.id), + author: Set(Uuid::nil()), + body: Set(comment_body), + created_at: Set(now), + updated_at: Set(now), + ..Default::default() + }; + if active.insert(&self.db).await.is_ok() { + comment_posted = true; + } + } + + Ok(IssueTriageResponse { + suggestions, + comment_posted, + }) + } +} diff --git a/libs/service/agent/mod.rs b/libs/service/agent/mod.rs index eb41d69..42d0041 100644 --- a/libs/service/agent/mod.rs +++ b/libs/service/agent/mod.rs @@ -6,6 +6,7 @@ pub mod provider; pub mod billing; pub mod code_review; +pub mod issue_triage; pub mod model; pub mod pr_summary; pub mod sync; diff --git a/libs/service/agent/pr_summary.rs b/libs/service/agent/pr_summary.rs index 609eae5..580c153 100644 --- a/libs/service/agent/pr_summary.rs +++ b/libs/service/agent/pr_summary.rs @@ -147,7 +147,7 @@ async fn call_ai_model_for_description( ), ]; - agent::call_with_params(&messages, model_name, &client_config, 0.3, 4096, None, None) + agent::call_with_params(&messages, model_name, &client_config, 0.3, 4096, None, None, None) .await .map_err(|e| AppError::InternalServerError(format!("AI call failed: {}", e))) } diff --git a/libs/service/file_tools/grep.rs b/libs/service/file_tools/grep.rs index 95ebca7..972d065 100644 --- a/libs/service/file_tools/grep.rs +++ b/libs/service/file_tools/grep.rs @@ -201,11 +201,10 @@ async fn git_grep_exec( } fn glob_match(path: &str, pattern: &str) -> bool { - // Simple glob: support *, ?, ** - let parts: Vec<&str> = pattern.split('/').collect(); - let path_parts: Vec<&str> = path.split('/').collect(); - let _path_lower = path.to_lowercase(); + let path_lower = path.to_lowercase(); let pattern_lower = pattern.to_lowercase(); + let parts: Vec<&str> = pattern_lower.split('/').collect(); + let path_parts: Vec<&str> = path_lower.split('/').collect(); fn matches_part(path_part: &str, pattern_part: &str) -> bool { if pattern_part.is_empty() || pattern_part == "*" { @@ -231,24 +230,33 @@ fn glob_match(path: &str, pattern: &str) -> bool { if parts.len() == 1 { // Simple glob pattern on filename only let file_name = path_parts.last().unwrap_or(&""); - return matches_part(file_name, &pattern_lower); + return matches_part(file_name, &parts[0]); } // Multi-part glob let mut pi = 0; for part in &parts { - while pi < path_parts.len() { - if matches_part(path_parts[pi], part) { - pi += 1; + if *part == "**" { + // ** matches zero or more path segments + // If this is the last pattern part, consume all remaining path segments + if part == parts.last().unwrap() { + pi = path_parts.len(); break; } - if *part != "**" { - return false; + // Try skipping segments until the next pattern part matches + let next_part = parts.iter().skip_while(|p| **p == "**").next().unwrap_or(&"*"); + while pi < path_parts.len() && !matches_part(path_parts[pi], next_part) { + pi += 1; } - pi += 1; + continue; } + if pi >= path_parts.len() || !matches_part(path_parts[pi], part) { + return false; + } + pi += 1; } - true + // All pattern parts consumed — check that all path segments were matched too + pi == path_parts.len() } pub fn register_grep_tools(registry: &mut ToolRegistry) { diff --git a/libs/service/file_tools/json.rs b/libs/service/file_tools/json.rs index 6a6b2f3..de0a364 100644 --- a/libs/service/file_tools/json.rs +++ b/libs/service/file_tools/json.rs @@ -154,7 +154,9 @@ async fn read_json_exec( } let text = String::from_utf8_lossy(data); - let is_jsonc = path.ends_with(".jsonc") || path.ends_with(".vscodeignore") || text.contains("//"); + // Only treat as JSONC if the extension indicates it, or if we can + // confirm a comment-like pattern outside of a string context. + let is_jsonc = path.ends_with(".jsonc"); let json_text = if is_jsonc { strip_jsonc_comments(&text) @@ -187,13 +189,15 @@ async fn read_json_exec( "size_bytes": data.len(), "schema": schema, "data": if display.chars().count() > 5000 { - format!("{}... (truncated, {} chars total)", &display[..5000], display.chars().count()) + let truncated: String = display.chars().take(5000).collect(); + format!("{}... (truncated, {} chars total)", truncated, display.chars().count()) } else { display }, })) } /// Simple JSONPath-like query support. /// Supports: $.key, $[0], $.key.nested, $.arr[0].field +/// Bracket notation ["key.with.dots"] allows accessing keys containing dots. fn query_json(value: &JsonValue, query: &str) -> Result { let query = query.trim(); let query = if query.starts_with("$.") { @@ -206,43 +210,74 @@ fn query_json(value: &JsonValue, query: &str) -> Result { let mut current = value.clone(); - for part in query.split('.') { - if part.is_empty() { - continue; - } - - // Handle array index like [0] - if let Some(idx_start) = part.find('[') { - let key = &part[..idx_start]; + // Parse into access segments: Key("name"), Index(0), BracketKey("key.with.dots") + enum Segment { Key(String), Index(usize), BracketKey(String) } + let mut segments: Vec = Vec::new(); + let mut i = 0; + let q_chars: Vec = query.chars().collect(); + while i < q_chars.len() { + if q_chars[i] == '[' { + // Find matching ] + let mut j = i + 1; + let mut bracket_content = String::new(); + while j < q_chars.len() && q_chars[j] != ']' { + bracket_content.push(q_chars[j]); + j += 1; + } + if j < q_chars.len() && q_chars[j] == ']' { + let content = bracket_content.trim(); + // Check if it's a quoted string key or a numeric index + if content.starts_with('"') && content.ends_with('"') { + let key = content[1..content.len()-1].to_string(); + segments.push(Segment::BracketKey(key)); + } else if content.starts_with("'") && content.ends_with("'") { + let key = content[1..content.len()-1].to_string(); + segments.push(Segment::BracketKey(key)); + } else if let Ok(idx) = content.parse::() { + segments.push(Segment::Index(idx)); + } else { + return Err(format!("Invalid bracket notation: [{}]", content)); + } + i = j + 1; + // Skip dot after bracket if present + if i < q_chars.len() && q_chars[i] == '.' { + i += 1; + } + } else { + return Err("Unmatched [ in query".into()); + } + } else { + // Read key until . or [ + let mut key = String::new(); + while i < q_chars.len() && q_chars[i] != '.' && q_chars[i] != '[' { + key.push(q_chars[i]); + i += 1; + } if !key.is_empty() { + // Check if key contains a numeric-only segment (array index shorthand) + segments.push(Segment::Key(key)); + } + if i < q_chars.len() && q_chars[i] == '.' { + i += 1; + } + } + } + + for seg in &segments { + match seg { + Segment::Key(key) | Segment::BracketKey(key) => { if let JsonValue::Object(obj) = ¤t { current = obj.get(key).cloned().unwrap_or(JsonValue::Null); } else { return Err(format!("cannot access property '{}' on non-object", key)); } } - - let rest = &part[idx_start..]; - for bracket in rest.split_inclusive(']') { - if bracket.is_empty() || bracket == "]" { - continue; + Segment::Index(idx) => { + if let JsonValue::Array(arr) = ¤t { + current = arr.get(*idx).cloned().unwrap_or(JsonValue::Null); + } else { + return Err(format!("index {} on non-array", idx)); } - let inner = bracket.trim_end_matches(']'); - if let Some(idx) = inner.strip_prefix('[') { - if let Ok(index) = idx.parse::() { - if let JsonValue::Array(arr) = ¤t { - current = arr.get(index).cloned().unwrap_or(JsonValue::Null); - } else { - return Err(format!("index {} on non-array", index)); - } - } - } - } - } else { - if let JsonValue::Object(obj) = ¤t { - current = obj.get(part).cloned().unwrap_or(JsonValue::Null); - } else { - return Err(format!("property '{}' not found", part)); } } } diff --git a/libs/service/git/repo.rs b/libs/service/git/repo.rs index 8f0bd83..de77421 100644 --- a/libs/service/git/repo.rs +++ b/libs/service/git/repo.rs @@ -62,6 +62,7 @@ impl From for ConfigSnapshotResponse { #[derive(Debug, Clone, Deserialize, utoipa::ToSchema)] pub struct GitUpdateRepoRequest { pub default_branch: Option, + pub ai_code_review_enabled: Option, } #[derive(Debug, Clone, Serialize, utoipa::ToSchema)] pub struct ConfigBoolResponse { @@ -459,6 +460,9 @@ impl AppService { if let Some(default_branch) = params.default_branch { active.default_branch = Set(default_branch); } + if let Some(ai_enabled) = params.ai_code_review_enabled { + active.ai_code_review_enabled = Set(ai_enabled); + } active.update(&txn).await?; txn.commit().await?; Ok(()) diff --git a/libs/service/git_tools/branch.rs b/libs/service/git_tools/branch.rs index 919e6e4..263e782 100644 --- a/libs/service/git_tools/branch.rs +++ b/libs/service/git_tools/branch.rs @@ -55,7 +55,15 @@ async fn git_branches_merged_exec(ctx: GitToolCtx, args: serde_json::Value) -> R let domain = ctx.open_repo(project_name, repo_name).await?; let is_merged = domain.branch_is_merged(branch, &into).map_err(|e| e.to_string())?; - let merge_base = domain.merge_base(&git::commit::types::CommitOid::new(branch), &git::commit::types::CommitOid::new(&into)) + + // Resolve branch names to commit OIDs before calling merge_base + let branch_oid = domain.branch_target(branch) + .map_err(|e| e.to_string())? + .ok_or_else(|| format!("branch '{}' not found or has no target", branch))?; + let into_oid = domain.branch_target(&into) + .map_err(|e| e.to_string())? + .ok_or_else(|| format!("branch '{}' not found or has no target", into))?; + let merge_base = domain.merge_base(&branch_oid, &into_oid) .map(|oid| oid.to_string()).ok(); Ok(serde_json::json!({ "branch": branch, "into": into, "is_merged": is_merged, "merge_base": merge_base })) diff --git a/libs/service/git_tools/commit.rs b/libs/service/git_tools/commit.rs index 19f2657..875ee0a 100644 --- a/libs/service/git_tools/commit.rs +++ b/libs/service/git_tools/commit.rs @@ -22,7 +22,7 @@ async fn git_log_exec(ctx: GitToolCtx, args: serde_json::Value) -> Result = commits.iter().map(|c| { use chrono::TimeZone; - let ts = c.author.time_secs + (c.author.offset_minutes as i64 * 60); + let ts = c.author.time_secs - (c.author.offset_minutes as i64 * 60); let time_str = chrono::Utc.timestamp_opt(ts, 0).single() .map(|dt| dt.to_rfc3339()).unwrap_or_else(|| format!("{}", c.author.time_secs)); @@ -63,7 +63,7 @@ async fn git_show_exec(ctx: GitToolCtx, args: serde_json::Value) -> Result Re let limit = p.get("limit").and_then(|v| v.as_u64()).unwrap_or(20) as usize; let domain = ctx.open_repo(project_name, repo_name).await?; - let commits = domain.commit_log(Some("HEAD"), 0, 100).map_err(|e| e.to_string())?; + // Fetch extra commits to have enough candidates after filtering + let walk_limit = limit.saturating_mul(2).max(100); + let commits = domain.commit_log(Some("HEAD"), 0, walk_limit).map_err(|e| e.to_string())?; let q = query.to_lowercase(); let result: Vec<_> = commits.iter() @@ -104,7 +106,7 @@ async fn git_search_commits_exec(ctx: GitToolCtx, args: serde_json::Value) -> Re fn flatten_commit(c: &git::commit::types::CommitMeta) -> serde_json::Value { use chrono::TimeZone; - let ts = c.author.time_secs + (c.author.offset_minutes as i64 * 60); + let ts = c.author.time_secs - (c.author.offset_minutes as i64 * 60); let author_time = chrono::Utc.timestamp_opt(ts, 0).single() .map(|dt| dt.to_rfc3339()).unwrap_or_else(|| format!("{}", c.author.time_secs)); let oid = c.oid.to_string(); @@ -160,7 +162,7 @@ async fn git_graph_exec(ctx: GitToolCtx, args: serde_json::Value) -> Result Result { let head_meta = domain.commit_get_prefix("HEAD").map_err(|e| e.to_string())?; - domain.diff_commit_to_workdir(&head_meta.oid, opts).map_err(|e| e.to_string())? + // Bare repos have no working tree — use tree-to-tree diff instead + if domain.repo().is_bare() { + domain.diff_tree_to_tree(None, Some(&head_meta.oid), opts).map_err(|e| e.to_string())? + } else { + domain.diff_commit_to_workdir(&head_meta.oid, opts).map_err(|e| e.to_string())? + } } (Some(base), None) => { - domain.diff_commit_to_workdir(base, opts).map_err(|e| e.to_string())? + if domain.repo().is_bare() { + domain.diff_tree_to_tree(Some(base), None, opts).map_err(|e| e.to_string())? + } else { + domain.diff_commit_to_workdir(base, opts).map_err(|e| e.to_string())? + } } (Some(base), Some(head_oid_val)) => { domain.diff_tree_to_tree(Some(base), Some(head_oid_val), opts).map_err(|e| e.to_string())? @@ -74,12 +83,20 @@ async fn git_diff_stats_exec(ctx: GitToolCtx, args: serde_json::Value) -> Result let domain = ctx.open_repo(project_name, repo_name).await?; - let stats = if base.len() >= 40 || head.len() >= 40 { + let stats = if base.len() >= 40 && head.len() >= 40 { domain.diff_stats(&git::commit::types::CommitOid::new(base), &git::commit::types::CommitOid::new(head)) .map_err(|e| e.to_string())? } else { - let b = domain.commit_get_prefix(base).map_err(|e| e.to_string())?.oid; - let h = domain.commit_get_prefix(head).map_err(|e| e.to_string())?.oid; + let b = if base.len() >= 40 { + git::commit::types::CommitOid::new(base) + } else { + domain.commit_get_prefix(base).map_err(|e| e.to_string())?.oid + }; + let h = if head.len() >= 40 { + git::commit::types::CommitOid::new(head) + } else { + domain.commit_get_prefix(head).map_err(|e| e.to_string())?.oid + }; domain.diff_stats(&b, &h).map_err(|e| e.to_string())? }; diff --git a/libs/service/git_tools/tree.rs b/libs/service/git_tools/tree.rs index 46db052..68e948b 100644 --- a/libs/service/git_tools/tree.rs +++ b/libs/service/git_tools/tree.rs @@ -77,10 +77,13 @@ async fn git_file_history_exec(ctx: GitToolCtx, args: serde_json::Value) -> Resu let project_name = p.get("project_name").and_then(|v| v.as_str()).ok_or("missing project_name")?; let repo_name = p.get("repo_name").and_then(|v| v.as_str()).ok_or("missing repo_name")?; let path = p.get("path").and_then(|v| v.as_str()).ok_or("missing path")?; + let rev = p.get("rev").and_then(|v| v.as_str()).map(String::from).unwrap_or_else(|| "HEAD".to_string()); let limit = p.get("limit").and_then(|v| v.as_u64()).unwrap_or(20) as usize; let domain = ctx.open_repo(project_name, repo_name).await?; - let commits = domain.commit_log(Some("HEAD"), 0, 500).map_err(|e| e.to_string())?; + // Fetch extra commits to have enough candidates after filtering + let walk_limit = limit.saturating_mul(2).max(200); + let commits = domain.commit_log(Some(&rev), 0, walk_limit).map_err(|e| e.to_string())?; let result: Vec<_> = commits.iter() .filter(|c| domain.tree_entry_by_path(&c.tree_id, path).is_ok()) @@ -125,7 +128,7 @@ async fn git_blob_get_exec(ctx: GitToolCtx, args: serde_json::Value) -> Result serde_json::Value { use chrono::TimeZone; - let ts = c.author.time_secs + (c.author.offset_minutes as i64 * 60); + let ts = c.author.time_secs - (c.author.offset_minutes as i64 * 60); let author_time = chrono::Utc.timestamp_opt(ts, 0).single() .map(|dt| dt.to_rfc3339()).unwrap_or_else(|| format!("{}", c.author.time_secs)); let oid = c.oid.to_string(); @@ -182,6 +185,7 @@ pub fn register_git_tools(registry: &mut ToolRegistry) { ("project_name".into(), ToolParam { name: "project_name".into(), param_type: "string".into(), description: Some("Project name (slug)".into()), required: true, properties: None, items: None }), ("repo_name".into(), ToolParam { name: "repo_name".into(), param_type: "string".into(), description: Some("Repository name".into()), required: true, properties: None, items: None }), ("path".into(), ToolParam { name: "path".into(), param_type: "string".into(), description: Some("File path to trace history for".into()), required: true, properties: None, items: None }), + ("rev".into(), ToolParam { name: "rev".into(), param_type: "string".into(), description: Some("Revision to start history from (default: HEAD)".into()), required: false, properties: None, items: None }), ("limit".into(), ToolParam { name: "limit".into(), param_type: "integer".into(), description: Some("Maximum number of commits to return (default: 20)".into()), required: false, properties: None, items: None }), ]); let schema = ToolSchema { schema_type: "object".into(), properties: Some(p), required: Some(vec!["project_name".into(), "repo_name".into(), "path".into()]) }; diff --git a/libs/service/git_tools/types.rs b/libs/service/git_tools/types.rs index 7f9f5c2..7c66d3a 100644 --- a/libs/service/git_tools/types.rs +++ b/libs/service/git_tools/types.rs @@ -6,7 +6,7 @@ use base64::Engine; use chrono::TimeZone; use git::commit::types::{CommitMeta, CommitReflogEntry}; -use git::diff::types::{DiffDelta, DiffStats}; +use git::diff::types::{DiffDelta, DiffDeltaStatus, DiffStats}; use git::tree::types::TreeEntry; use serde::{Deserialize, Serialize}; @@ -121,7 +121,8 @@ pub struct ReflogEntryInfo { impl ReflogEntryInfo { pub fn from_entry(entry: &CommitReflogEntry) -> Self { let ts = entry.time_secs; - let time = format_rfc3339(ts, 0); + let offset = entry.offset_minutes; + let time = format_rfc3339(ts, offset); Self { oid_new: entry.oid_new.to_string(), oid_old: entry.oid_old.to_string(), @@ -216,8 +217,13 @@ pub struct DiffFileOut { impl DiffFileOut { pub fn from_delta(delta: &DiffDelta) -> Self { + // For deleted files, use old_file.path; for all others, use new_file.path. + let path = match delta.status { + DiffDeltaStatus::Deleted => delta.old_file.path.clone(), + _ => delta.new_file.path.clone(), + }; Self { - path: delta.new_file.path.clone(), + path, status: format!("{:?}", delta.status), is_binary: delta.new_file.is_binary, size: delta.new_file.size, @@ -402,7 +408,8 @@ impl From<&git::commit::graph::CommitGraphLine> for GraphLineOut { // --------------------------------------------------------------------------- fn format_rfc3339(time_secs: i64, offset_minutes: i32) -> String { - let secs = time_secs + (offset_minutes as i64 * 60); + // Git stores local time + offset. To convert to UTC, subtract the offset. + let secs = time_secs - (offset_minutes as i64 * 60); chrono::Utc .timestamp_opt(secs, 0) .single() diff --git a/libs/service/issue/issue.rs b/libs/service/issue/issue.rs index ee8a520..684cadc 100644 --- a/libs/service/issue/issue.rs +++ b/libs/service/issue/issue.rs @@ -222,7 +222,7 @@ impl AppService { ctx: &Session, ) -> Result { let user_uid = ctx.user().ok_or(AppError::Unauthorized)?; - let project = self.utils_find_project_by_name(project_name).await?; + let project = self.utils_find_project_by_name(project_name.clone()).await?; // Any project member can create issues let member = project_members::Entity::find() @@ -280,6 +280,15 @@ impl AppService { ) .await; + // Run AI triage asynchronously + let project_name_clone = project_name.clone(); + let issue_number = number; + let this = self.clone(); + drop(project_name); // allow move below + tokio::spawn(async move { + let _ = this.triage_issue(project_name_clone, issue_number).await; + }); + Ok(IssueResponse::from(model)) } diff --git a/libs/service/issue/label.rs b/libs/service/issue/label.rs index 0483c20..a3314d7 100644 --- a/libs/service/issue/label.rs +++ b/libs/service/issue/label.rs @@ -16,6 +16,38 @@ pub struct IssueAddLabelRequest { pub label_id: i64, } +#[derive(Debug, Clone, Deserialize, ToSchema)] +pub struct IssueAddLabelsByNamesRequest { + pub names: Vec, +} + +fn default_color_for_label(name: &str) -> String { + let lower = name.to_lowercase(); + if lower.contains("bug") || lower.contains("critical") || lower.contains("security") { + "ef4444".to_string() + } else if lower.contains("enhancement") || lower.contains("feature") || lower.contains("improvement") { + "22c55e".to_string() + } else if lower.contains("documentation") || lower.contains("docs") { + "3b82f6".to_string() + } else if lower.contains("question") || lower.contains("help wanted") { + "a855f7".to_string() + } else if lower.contains("good first") || lower.contains("beginner") || lower.contains("easy") { + "10b981".to_string() + } else if lower.contains("priority") || lower.contains("high") { + "f97316".to_string() + } else if lower.contains("backend") || lower.contains("server") { + "6366f1".to_string() + } else if lower.contains("frontend") || lower.contains("ui") || lower.contains("ux") { + "ec4899".to_string() + } else if lower.contains("performance") || lower.contains("optimize") { + "eab308".to_string() + } else if lower.contains("dx") || lower.contains("dev") || lower.contains("tool") { + "14b8a6".to_string() + } else { + "6b7280".to_string() + } +} + #[derive(Debug, Clone, Deserialize, ToSchema)] pub struct CreateLabelRequest { pub name: String, @@ -285,6 +317,91 @@ impl AppService { response } + /// Add labels to an issue by name, creating missing labels automatically. + pub async fn issue_label_add_by_names( + &self, + project_name: String, + issue_number: i64, + request: IssueAddLabelsByNamesRequest, + ctx: &Session, + ) -> Result, AppError> { + let user_uid = ctx.user().ok_or(AppError::Unauthorized)?; + let project = self.utils_find_project_by_name(project_name).await?; + + let _member = project_members::Entity::find() + .filter(project_members::Column::Project.eq(project.id)) + .filter(project_members::Column::User.eq(user_uid)) + .one(&self.db) + .await? + .ok_or(AppError::NoPower)?; + + let issue = issue::Entity::find() + .filter(issue::Column::Project.eq(project.id)) + .filter(issue::Column::Number.eq(issue_number)) + .one(&self.db) + .await? + .ok_or(AppError::NotFound("Issue not found".to_string()))?; + + let mut added: Vec = Vec::new(); + + for name in request.names { + let color = default_color_for_label(&name); + + // Find or create label + let lbl = label::Entity::find() + .filter(label::Column::Project.eq(project.id)) + .filter(label::Column::Name.eq(&name)) + .one(&self.db) + .await?; + + let lbl = match lbl { + Some(l) => l, + None => { + let active = label::ActiveModel { + id: Set(0), + project: Set(project.id), + name: Set(name.clone()), + color: Set(color), + ..Default::default() + }; + active.insert(&self.db).await? + } + }; + + // Check if already linked + let existing = issue_label::Entity::find() + .filter(issue_label::Column::Issue.eq(issue.id)) + .filter(issue_label::Column::Label.eq(lbl.id)) + .one(&self.db) + .await?; + if existing.is_some() { + continue; + } + + let now = Utc::now(); + let active = issue_label::ActiveModel { + issue: Set(issue.id), + label: Set(lbl.id), + relation_at: Set(now), + ..Default::default() + }; + let model = active.insert(&self.db).await?; + added.push(IssueLabelResponse { + issue: model.issue, + label_id: model.label, + label_name: Some(lbl.name.clone()), + label_color: Some(lbl.color.clone()), + relation_at: model.relation_at, + }); + } + + if !added.is_empty() { + self.invalidate_issue_cache(project.id, issue_number).await; + } + + Ok(added) + } + /// Remove a label from an issue. pub async fn issue_label_remove( &self, diff --git a/libs/service/issue/mod.rs b/libs/service/issue/mod.rs index e59df1f..df0c31e 100644 --- a/libs/service/issue/mod.rs +++ b/libs/service/issue/mod.rs @@ -16,7 +16,7 @@ pub use comment::{ pub use issue::{ IssueCreateRequest, IssueListResponse, IssueResponse, IssueSummaryResponse, IssueUpdateRequest, }; -pub use label::{CreateLabelRequest, IssueAddLabelRequest, IssueLabelResponse, LabelResponse}; +pub use label::{CreateLabelRequest, IssueAddLabelRequest, IssueAddLabelsByNamesRequest, IssueLabelResponse, LabelResponse}; pub use pull_request::{IssueLinkPullRequestRequest, IssuePullRequestResponse}; pub use reaction::{ReactionAddRequest, ReactionListResponse, ReactionResponse}; pub use repo::{IssueLinkRepoRequest, IssueRepoResponse}; diff --git a/libs/service/project/repo.rs b/libs/service/project/repo.rs index eca7e8c..29471ba 100644 --- a/libs/service/project/repo.rs +++ b/libs/service/project/repo.rs @@ -36,6 +36,7 @@ pub struct ProjectRepositoryItem { pub last_commit_at: Option>, pub ssh_clone_url: String, pub https_clone_url: String, + pub ai_code_review_enabled: bool, } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, ToSchema)] @@ -226,6 +227,7 @@ impl AppService { last_commit_at: last_commit_times.get(&r.id).and_then(|t| *t), ssh_clone_url: format!("git@{}:{}", ssh_domain, path), https_clone_url: format!("https://{}/{}", ssh_domain, path), + ai_code_review_enabled: r.ai_code_review_enabled, } }) .collect(); diff --git a/libs/service/project_tools/arxiv.rs b/libs/service/project_tools/arxiv.rs index 69f52a9..9500b87 100644 --- a/libs/service/project_tools/arxiv.rs +++ b/libs/service/project_tools/arxiv.rs @@ -9,7 +9,7 @@ const DEFAULT_MAX_RESULTS: usize = 10; const MAX_MAX_RESULTS: usize = 50; /// arXiv API base URL (Atom feed). -const ARXIV_API: &str = "http://export.arxiv.org/api/query"; +const ARXIV_API: &str = "https://export.arxiv.org/api/query"; /// arXiv Atom feed entry fields we care about. #[derive(Debug, Deserialize)] diff --git a/libs/service/project_tools/boards.rs b/libs/service/project_tools/boards.rs index c653674..5c09252 100644 --- a/libs/service/project_tools/boards.rs +++ b/libs/service/project_tools/boards.rs @@ -274,6 +274,8 @@ pub async fn create_board_card_exec( .ok_or_else(|| ToolError::ExecutionError("No sender context".into()))?; let db = ctx.db(); + require_admin(db, project_id, sender_id).await?; + let board_id = args .get("board_id") .and_then(|v| Uuid::parse_str(v.as_str()?).ok()) diff --git a/libs/service/project_tools/curl.rs b/libs/service/project_tools/curl.rs index f2b5be8..98b0e7b 100644 --- a/libs/service/project_tools/curl.rs +++ b/libs/service/project_tools/curl.rs @@ -1,11 +1,80 @@ //! Tool: project_curl — perform HTTP requests (GET/POST/PUT/DELETE) +//! +//! Security measures: +//! - SSRF protection: blocks private IPs and blocks redirects to private IPs +//! - Sensitive header injection: blocks Host, Authorization, Cookie, Proxy-* +//! - Connection pooling via a shared reqwest::Client use agent::{ToolContext, ToolDefinition, ToolError, ToolParam, ToolSchema}; use std::collections::HashMap; +use std::sync::OnceLock; /// Maximum response body size: 1 MB. const MAX_BODY_BYTES: usize = 1 << 20; +/// Headers that are blocked from user-supplied values to prevent injection attacks. +const BLOCKED_HEADERS: &[&str] = &[ + "host", "authorization", "cookie", "proxy-authorization", + "proxy-connection", "proxy-authenticate", +]; + +/// Shared reqwest::Client for connection pooling. +static SHARED_CLIENT: OnceLock = OnceLock::new(); + +fn shared_client() -> &'static reqwest::Client { + SHARED_CLIENT.get_or_init(|| { + reqwest::Client::builder() + .connect_timeout(std::time::Duration::from_secs(10)) + .timeout(std::time::Duration::from_secs(120)) + // Block automatic redirect following so we can validate each hop + .redirect(reqwest::redirect::Policy::limited(0)) + .build() + .expect("reqwest client build should not fail") + }) +} + +/// Check if a host string resolves to or is a private/internal IP. +fn is_private_host(host: &str) -> bool { + host.eq_ignore_ascii_case("localhost") + || host.eq_ignore_ascii_case("127.0.0.1") + || host.eq_ignore_ascii_case("::1") + || host.eq_ignore_ascii_case("0.0.0.0") + || host.eq_ignore_ascii_case("metadata.google.internal") + || host.eq_ignore_ascii_case("169.254.169.254") + || host.starts_with("10.") + || host.starts_with("172.16.") + || host.starts_with("172.17.") + || host.starts_with("172.18.") + || host.starts_with("172.19.") + || host.starts_with("172.20.") + || host.starts_with("172.21.") + || host.starts_with("172.22.") + || host.starts_with("172.23.") + || host.starts_with("172.24.") + || host.starts_with("172.25.") + || host.starts_with("172.26.") + || host.starts_with("172.27.") + || host.starts_with("172.28.") + || host.starts_with("172.29.") + || host.starts_with("172.30.") + || host.starts_with("172.31.") + || host.starts_with("192.168.") +} + +/// Validate URL and any redirect hops against SSRF rules. +fn validate_url_against_ssrf(url_str: &str) -> Result { + let parsed = reqwest::Url::parse(url_str) + .map_err(|e| ToolError::ExecutionError(format!("Invalid URL: {}", e)))?; + if let Some(host) = parsed.host_str() { + if is_private_host(host) { + return Err(ToolError::ExecutionError( + "Requests to internal/private IPs are not allowed for security reasons".into(), + )); + } + } + Ok(parsed) +} + /// Perform an HTTP request and return the response body and metadata. /// Supports GET, POST, PUT, DELETE methods. Useful for fetching web pages, /// calling external APIs, or downloading resources. @@ -13,11 +82,14 @@ pub async fn curl_exec( _ctx: ToolContext, args: serde_json::Value, ) -> Result { - let url = args + let url_str = args .get("url") .and_then(|v| v.as_str()) .ok_or_else(|| ToolError::ExecutionError("url is required".into()))?; + // SSRF protection: validate initial URL + validate_url_against_ssrf(url_str)?; + let method = args .get("method") .and_then(|v| v.as_str()) @@ -36,104 +108,156 @@ pub async fn curl_exec( }) .unwrap_or_default(); + // Block sensitive headers that could be used for injection attacks + for (key, _) in &headers { + if BLOCKED_HEADERS.contains(&key.to_lowercase().as_str()) { + return Err(ToolError::ExecutionError( + format!("Header '{}' is not allowed for security reasons", key), + )); + } + } + let timeout_secs = args .get("timeout") .and_then(|v| v.as_u64()) .unwrap_or(30) .min(120); - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(timeout_secs)) - .build() - .map_err(|e| ToolError::ExecutionError(format!("Failed to build HTTP client: {}", e)))?; + let client = shared_client(); + // Build a per-request client with the specific timeout by using the shared + // client's connection pool but overriding timeout per request via request builder. + // Since reqwest::Client::builder().redirect(Policy::limited(0)) disables auto-redirects, + // we manually follow up to 5 redirects with SSRF validation on each hop. - let mut request = match method.as_str() { - "GET" => client.get(url), - "POST" => client.post(url), - "PUT" => client.put(url), - "DELETE" => client.delete(url), - "PATCH" => client.patch(url), - "HEAD" => client.head(url), - _ => { - return Err(ToolError::ExecutionError(format!( - "Unsupported HTTP method: {}. Use GET, POST, PUT, DELETE, PATCH, or HEAD.", - method - ))) + let mut current_url = url_str.to_string(); + let mut redirect_count = 0u32; + const MAX_REDIRECTS: u32 = 5; + + loop { + let mut request = match method.as_str() { + "GET" => client.get(¤t_url), + "POST" => client.post(¤t_url), + "PUT" => client.put(¤t_url), + "DELETE" => client.delete(¤t_url), + "PATCH" => client.patch(¤t_url), + "HEAD" => client.head(¤t_url), + _ => { + return Err(ToolError::ExecutionError(format!( + "Unsupported HTTP method: {}. Use GET, POST, PUT, DELETE, PATCH, or HEAD.", + method + ))) + } + }; + + request = request.timeout(std::time::Duration::from_secs(timeout_secs)); + + for (key, value) in &headers { + request = request.header(key, value); } - }; - for (key, value) in &headers { - request = request.header(key, value); - } + // Set default Content-Type for POST/PUT/PATCH if not provided and body exists + if body.is_some() && !headers.iter().any(|(k, _)| k.to_lowercase() == "content-type") { + request = request.header("Content-Type", "application/json"); + } - // Set default Content-Type for POST/PUT/PATCH if not provided and body exists - if body.is_some() && !headers.iter().any(|(k, _)| k.to_lowercase() == "content-type") { - request = request.header("Content-Type", "application/json"); - } + if let Some(ref b) = body { + request = request.body(b.clone()); + } - if let Some(ref b) = body { - request = request.body(b.clone()); - } + let response = request + .send() + .await + .map_err(|e| ToolError::ExecutionError(format!("HTTP request failed: {}", e)))?; - let response = request - .send() - .await - .map_err(|e| ToolError::ExecutionError(format!("HTTP request failed: {}", e)))?; + let status = response.status().as_u16(); - let status = response.status().as_u16(); - let status_text = response.status().canonical_reason().unwrap_or(""); + // Handle redirects manually with SSRF validation + if status >= 300 && status < 400 { + redirect_count += 1; + if redirect_count > MAX_REDIRECTS { + return Err(ToolError::ExecutionError( + format!("Too many redirects (max {})", MAX_REDIRECTS), + )); + } + let location = response.headers() + .get("location") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + let location = match location { + Some(l) => l, + None => return Err(ToolError::ExecutionError("Redirect with no Location header".into())), + }; + // Resolve relative redirect against current URL + let base = reqwest::Url::parse(¤t_url) + .map_err(|e| ToolError::ExecutionError(format!("Invalid current URL: {}", e)))?; + let next_url = base.join(&location) + .map_err(|e| ToolError::ExecutionError(format!("Invalid redirect URL: {}", e)))?; + // Validate redirect target against SSRF rules + if let Some(host) = next_url.host_str() { + if is_private_host(host) { + return Err(ToolError::ExecutionError( + "Redirect to internal/private IP is not allowed".into(), + )); + } + } + current_url = next_url.to_string(); + continue; + } - let response_headers: std::collections::HashMap = response - .headers() - .iter() - .map(|(k, v)| { - ( - k.to_string(), - v.to_str().unwrap_or("").to_string(), + let status_text = response.status().canonical_reason().unwrap_or(""); + + let response_headers: std::collections::HashMap = response + .headers() + .iter() + .map(|(k, v)| { + ( + k.to_string(), + v.to_str().unwrap_or("").to_string(), + ) + }) + .collect(); + + let content_type = response + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_string(); + + let is_text = content_type.starts_with("text/") + || content_type.contains("json") + || content_type.contains("xml") + || content_type.contains("javascript"); + + let body_bytes = response + .bytes() + .await + .map_err(|e| ToolError::ExecutionError(format!("Failed to read response body: {}", e)))?; + + let body_len = body_bytes.len(); + let truncated = body_len > MAX_BODY_BYTES; + let body_text = if truncated { + String::from("[Response truncated — exceeds 1 MB limit]") + } else if is_text { + String::from_utf8_lossy(&body_bytes).to_string() + } else { + format!( + "[Binary body, {} bytes, Content-Type: {}]", + body_len, content_type ) - }) - .collect(); + }; - let content_type = response - .headers() - .get("content-type") - .and_then(|v| v.to_str().ok()) - .unwrap_or("") - .to_string(); - - let is_text = content_type.starts_with("text/") - || content_type.contains("json") - || content_type.contains("xml") - || content_type.contains("javascript"); - - let body_bytes = response - .bytes() - .await - .map_err(|e| ToolError::ExecutionError(format!("Failed to read response body: {}", e)))?; - - let body_len = body_bytes.len(); - let truncated = body_len > MAX_BODY_BYTES; - let body_text = if truncated { - String::from("[Response truncated — exceeds 1 MB limit]") - } else if is_text { - String::from_utf8_lossy(&body_bytes).to_string() - } else { - format!( - "[Binary body, {} bytes, Content-Type: {}]", - body_len, content_type - ) - }; - - Ok(serde_json::json!({ - "url": url, - "method": method, - "status": status, - "status_text": status_text, - "headers": response_headers, - "body": body_text, - "truncated": truncated, - "size_bytes": body_len, - })) + return Ok(serde_json::json!({ + "url": current_url, + "method": method, + "status": status, + "status_text": status_text, + "headers": response_headers, + "body": body_text, + "truncated": truncated, + "size_bytes": body_len, + })); + } } // ─── tool definition ───────────────────────────────────────────────────────── diff --git a/libs/service/project_tools/issues.rs b/libs/service/project_tools/issues.rs index 2c2a0e1..1a20391 100644 --- a/libs/service/project_tools/issues.rs +++ b/libs/service/project_tools/issues.rs @@ -224,6 +224,17 @@ pub async fn create_issue_exec( .sender_id() .ok_or_else(|| ToolError::ExecutionError("No sender context".into()))?; + // Membership check: only project members can create issues + let member = ProjectMember::find() + .filter(project_members::Column::Project.eq(project_id)) + .filter(project_members::Column::User.eq(author_id)) + .one(db) + .await + .map_err(|e| ToolError::ExecutionError(e.to_string()))?; + if member.is_none() { + return Err(ToolError::ExecutionError("You are not a member of this project".into())); + } + let number = next_issue_number(db, project_id).await?; let now = Utc::now(); @@ -248,7 +259,8 @@ pub async fn create_issue_exec( .await .map_err(|e| ToolError::ExecutionError(e.to_string()))?; - // Add assignees + // Add assignees (collect errors for partial failure reporting) + let mut assignee_errors = Vec::new(); for uid in &assignee_ids { let a = issue_assignee::ActiveModel { issue: Set(model.id), @@ -256,10 +268,13 @@ pub async fn create_issue_exec( assigned_at: Set(now), ..Default::default() }; - let _ = a.insert(db).await; + if let Err(e) = a.insert(db).await { + assignee_errors.push(format!("assignee {}: {}", uid, e)); + } } // Add labels + let mut label_errors = Vec::new(); for lid in &label_ids { let l = issue_label::ActiveModel { issue: Set(model.id), @@ -267,7 +282,9 @@ pub async fn create_issue_exec( relation_at: Set(now), ..Default::default() }; - let _ = l.insert(db).await; + if let Err(e) = l.insert(db).await { + label_errors.push(format!("label {}: {}", lid, e)); + } } // Build assignee/label maps for response @@ -330,6 +347,11 @@ pub async fn create_issue_exec( "updated_at": model.updated_at.to_rfc3339(), "assignees": assignee_ids.iter().filter_map(|uid| assignee_map.get(uid)).collect::>(), "labels": label_ids.iter().filter_map(|lid| label_map.get(lid)).collect::>(), + "warnings": if assignee_errors.is_empty() && label_errors.is_empty() { + None + } else { + Some([assignee_errors, label_errors].concat()) + }, })) } diff --git a/libs/service/project_tools/repos.rs b/libs/service/project_tools/repos.rs index a1ad40d..e5eba91 100644 --- a/libs/service/project_tools/repos.rs +++ b/libs/service/project_tools/repos.rs @@ -4,6 +4,7 @@ use agent::{ToolContext, ToolDefinition, ToolError, ToolParam, ToolSchema}; use chrono::Utc; use git::commit::types::CommitOid; use git::commit::types::CommitSignature; +use git2; use models::projects::{MemberRole, ProjectMember}; use models::projects::project_members; use models::repos::repo; @@ -85,6 +86,16 @@ pub async fn create_repo_exec( .ok_or_else(|| ToolError::ExecutionError("name is required".into()))? .to_string(); + // Validate repo name: no path traversal, no special chars + if repo_name.contains("..") || repo_name.contains('/') || repo_name.contains('\\') + || repo_name.is_empty() || repo_name.len() > 100 + || !repo_name.chars().next().map_or(false, |c| c.is_alphanumeric()) + { + return Err(ToolError::ExecutionError( + "Invalid repository name: must start with alphanumeric, contain no path separators or '..', max 100 chars".into(), + )); + } + let description = args .get("description") .and_then(|v| v.as_str()) @@ -145,13 +156,16 @@ pub async fn create_repo_exec( .await .map_err(|e| ToolError::ExecutionError(e.to_string()))?; + // Initialize the bare git repository on disk + git2::Repository::init_bare(&repo_dir) + .map_err(|e| ToolError::ExecutionError(format!("Failed to init bare repo: {}", e)))?; + Ok(serde_json::json!({ "id": model.id.to_string(), "name": model.repo_name, "description": model.description, "default_branch": model.default_branch, "is_private": model.is_private, - "storage_path": model.storage_path, "created_at": model.created_at.to_rfc3339(), })) } @@ -294,6 +308,13 @@ pub async fn create_commit_exec( .unwrap_or("main") .to_string(); + // Validate branch: no path traversal, no slashes + if branch.contains("..") || branch.contains('/') || branch.contains('\\') || branch.is_empty() { + return Err(ToolError::ExecutionError( + "Invalid branch name: must not contain path separators or '..'".into(), + )); + } + let files = args .get("files") .and_then(|v| v.as_array()) @@ -350,32 +371,75 @@ pub async fn create_commit_exec( let repo = domain.repo(); // Get current head commit (parent) + // If the repo already has commits (has HEAD), the branch must exist. + // Only allow root commits on truly empty repos (no HEAD at all). + let has_head = repo.head().is_ok(); let parent_oid = repo.refname_to_id(&format!("refs/heads/{}", branch)).ok(); + + if has_head && parent_oid.is_none() { + return Err(ToolError::ExecutionError( + format!("Branch '{}' does not exist in this repository", branch), + )); + } + let parent_ids: Vec = parent_oid .map(|oid| CommitOid::from_git2(oid)) .into_iter() .collect(); - // Build index with new files + // Build index from existing tree first (preserves all previous files), + // then add/overwrite with the new files. let mut index = repo .index() .map_err(|e| ToolError::ExecutionError(format!("Failed to get index: {}", e)))?; + // If repo has a parent commit, read its tree into the index so we don't + // lose existing files when write_tree() is called. + if let Some(oid) = &parent_oid { + let parent_commit = repo.find_commit(*oid) + .map_err(|e| ToolError::ExecutionError(format!("Failed to find parent commit: {}", e)))?; + let parent_tree = parent_commit.tree() + .map_err(|e| ToolError::ExecutionError(format!("Failed to get parent tree: {}", e)))?; + index.read_tree(&parent_tree) + .map_err(|e| ToolError::ExecutionError(format!("Failed to read parent tree into index: {}", e)))?; + } + for file in files_data { let path = file .get("path") .and_then(|v| v.as_str()) .ok_or_else(|| ToolError::ExecutionError("Each file must have a 'path'".into()))?; + + // Validate path: no traversal, no absolute paths, no .git/ prefix + if path.contains("..") || path.starts_with('/') || path.starts_with('\\') + || path.is_empty() || path.starts_with(".git/") || path == ".git" + { + return Err(ToolError::ExecutionError( + format!("Invalid file path '{}': must be relative, no '..' or absolute path components", path) + )); + } let content = file .get("content") .and_then(|v| v.as_str()) .ok_or_else(|| ToolError::ExecutionError("Each file must have 'content'".into()))?; - let _oid = repo.blob(content.as_bytes()).map_err(|e| { - ToolError::ExecutionError(format!("Failed to write blob for '{}': {}", path, e)) - })?; - - index.add_path(path.as_ref()).map_err(|e| { + // add_frombuffer requires an IndexEntry with at minimum a path field set. + // It works for both bare and non-bare repos (add_path requires a working tree). + let mut entry = git2::IndexEntry { + ctime: git2::IndexTime::new(0, 0), + mtime: git2::IndexTime::new(0, 0), + dev: 0, + ino: 0, + mode: 0o100644, + uid: 0, + gid: 0, + file_size: 0, + id: git2::Oid::zero(), + flags: 0, + flags_extended: 0, + path: path.as_bytes().to_vec(), + }; + index.add_frombuffer(&mut entry, content.as_bytes()).map_err(|e| { ToolError::ExecutionError(format!("Failed to add '{}' to index: {}", path, e)) })?; } diff --git a/libs/service/search/service.rs b/libs/service/search/service.rs index f51d66a..5b9a052 100644 --- a/libs/service/search/service.rs +++ b/libs/service/search/service.rs @@ -5,6 +5,7 @@ use db::database::AppDatabase; use models::issues::issue; use models::projects::{project, project_members}; use models::repos::repo; +use models::rooms::{room, room_member}; use models::users::user; use sea_orm::*; use sea_query::{Expr as SqExpr, extension::postgres::PgExpr}; @@ -113,6 +114,39 @@ pub struct UserSearchItem { pub created_at: DateTime, } +// ─── Global message search ──────────────────────────────────────────────────── + +#[derive(Debug, Clone, Deserialize, utoipa::IntoParams)] +pub struct GlobalMessageSearchQuery { + #[param(min_length = 1, max_length = 200)] + pub q: String, + pub page: Option, + pub per_page: Option, +} + +#[derive(Debug, Clone, Serialize, ToSchema)] +pub struct GlobalMessageSearchItem { + pub id: Uuid, + pub room_id: Uuid, + pub room_name: String, + pub sender_id: Option, + pub sender_type: String, + pub display_name: Option, + pub content: String, + pub content_type: String, + pub send_at: DateTime, + pub highlighted_content: Option, +} + +#[derive(Debug, Clone, Serialize, ToSchema)] +pub struct GlobalMessageSearchResponse { + pub query: String, + pub messages: Vec, + pub total: i64, + pub page: u32, + pub per_page: u32, +} + // ─── Per-type result set ───────────────────────────────────────────────────── #[derive(Debug, Clone, Serialize, ToSchema)] @@ -465,4 +499,157 @@ impl AppService { Ok(SearchResultSet::new(items, total, page, per_page)) } + + /// Search messages across all rooms the current user can access. + /// Uses PostgreSQL full-text search with ts_headline for result highlighting. + pub async fn global_message_search( + &self, + ctx: &Session, + params: GlobalMessageSearchQuery, + ) -> Result { + let user_id = ctx.user(); + + // Anonymous users cannot search messages + let Some(user_id) = user_id else { + return Err(AppError::Unauthorized); + }; + + if params.q.trim().is_empty() { + return Ok(GlobalMessageSearchResponse { + query: params.q.clone(), + messages: Vec::new(), + total: 0, + page: params.page.unwrap_or(1), + per_page: params.per_page.unwrap_or(20), + }); + } + + let page = std::cmp::max(1, params.page.unwrap_or(1)); + let per_page = std::cmp::min(100, std::cmp::max(1, params.per_page.unwrap_or(20))); + let offset = (page - 1) * per_page; + let q = params.q.trim(); + + // Build the set of room IDs the user can access: + // 1. Direct room memberships + let direct_rooms: Vec = room_member::Entity::find() + .filter(room_member::Column::User.eq(user_id)) + .select_only() + .column(room_member::Column::Room) + .into_tuple::() + .all(&self.db) + .await + .map_err(|_| AppError::InternalError)?; + + // 2. Public rooms in projects the user is a member of + let project_ids = accessible_project_ids(&self.db, Some(user_id)).await?; + let public_rooms: Vec = room::Entity::find() + .filter(room::Column::Project.is_in(project_ids.clone())) + .filter(room::Column::Public.eq(true)) + .select_only() + .column(room::Column::Id) + .into_tuple::() + .all(&self.db) + .await + .map_err(|_| AppError::InternalError)?; + + // Merge and deduplicate accessible room IDs using a HashSet + use std::collections::HashSet; + let mut accessible_set: HashSet = direct_rooms.into_iter().collect(); + for rid in public_rooms { + accessible_set.insert(rid); + } + + let accessible_rooms: Vec = accessible_set.iter().cloned().collect(); + + if accessible_rooms.is_empty() { + return Ok(GlobalMessageSearchResponse { + query: q.to_string(), + messages: Vec::new(), + total: 0, + page, + per_page, + }); + } + + // Fetch room names for the accessible rooms + let room_names_map: std::collections::HashMap = room::Entity::find() + .filter(room::Column::Id.is_in(accessible_rooms.clone())) + .all(&self.db) + .await + .map_err(|_| AppError::InternalError)? + .into_iter() + .map(|r| (r.id, r.room_name)) + .collect(); + + let tsquery = format!("plainto_tsquery('simple', $1)"); + let sql = format!( + r#" + SELECT m.id, m.room, m.sender_type, m.sender_id, + m.content, m.content_type, m.send_at, + ts_headline('simple', m.content, {}, 'StartSel=, StopSel=, MaxWords=50, MinWords=15') AS highlighted_content + FROM room_message m + WHERE m.room = ANY($2) + AND m.content_tsv @@ {} + AND m.revoked IS NULL + ORDER BY m.send_at DESC + LIMIT $3 OFFSET $4"#, + tsquery, + tsquery + ); + + // Results query + let results_sql = Statement::from_sql_and_values( + DbBackend::Postgres, + &sql, + vec![q.into(), accessible_rooms.clone().into(), per_page.into(), offset.into()], + ); + let rows = self.db.query_all_raw(results_sql).await?; + + let mut messages: Vec = Vec::new(); + for row in rows { + let room_id: Uuid = row.try_get::("", "room").unwrap_or_default(); + let sender_type_str = row.try_get::("", "sender_type").unwrap_or_default(); + let content_type_str = row.try_get::("", "content_type").unwrap_or_default(); + + let highlighted = row + .try_get::("", "highlighted_content") + .ok(); + + messages.push(GlobalMessageSearchItem { + id: row.try_get::("", "id").unwrap_or_default(), + room_id, + room_name: room_names_map.get(&room_id).cloned().unwrap_or_default(), + sender_id: row.try_get::>("", "sender_id").ok().flatten(), + sender_type: sender_type_str, + display_name: None, + content: row.try_get::("", "content").unwrap_or_default(), + content_type: content_type_str, + send_at: row.try_get::>("", "send_at").unwrap_or_default(), + highlighted_content: highlighted, + }); + } + + // Count total across all accessible rooms + let count_sql = format!( + "SELECT COUNT(*) AS count FROM room_message WHERE room = ANY($1) AND content_tsv @@ {} AND revoked IS NULL", + tsquery + ); + let count_stmt = Statement::from_sql_and_values( + DbBackend::Postgres, + &count_sql, + vec![accessible_rooms.into(), q.into()], + ); + let count_row = self.db.query_one_raw(count_stmt).await?; + let total: i64 = count_row + .and_then(|r| r.try_get::("", "count").ok()) + .unwrap_or(0); + + Ok(GlobalMessageSearchResponse { + query: q.to_string(), + messages, + total, + page, + per_page, + }) + } }