487 lines
18 KiB
Rust
487 lines
18 KiB
Rust
use std::pin::Pin;
|
||
use std::sync::Arc;
|
||
use uuid::Uuid;
|
||
|
||
use crate::client::AiClientConfig;
|
||
use crate::client::types::{ChatRequestMessage, ToolCall};
|
||
use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream};
|
||
use crate::embed::EmbedService;
|
||
use crate::error::Result;
|
||
use crate::tool::registry::ToolRegistry;
|
||
use crate::tool::{
|
||
ToolCall as AgentToolCall, ToolContext, ToolDefinition, ToolExecutor, ToolHandler, ToolParam,
|
||
};
|
||
use sea_orm::{ActiveModelTrait, EntityTrait, Set};
|
||
|
||
use super::service::StreamResult;
|
||
use super::{AiChunkType, AiStreamChunk, StreamCallback};
|
||
|
||
// Keyword-extraction-based title generator: reads conversation messages, extracts
|
||
// meaningful words, and updates the conversation record with a short title.
|
||
async fn generate_title_for_conversation(
|
||
ctx: &ToolContext,
|
||
conversation_id: Uuid,
|
||
) -> Result<serde_json::Value> {
|
||
use models::ai::{AiMessage, ai_conversation, ai_message};
|
||
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
|
||
|
||
let db_reader = ctx.db().reader();
|
||
let db_writer = ctx.db().writer();
|
||
let conv = ai_conversation::Entity::find_by_id(conversation_id)
|
||
.one(db_reader)
|
||
.await
|
||
.map_err(|e| crate::error::AgentError::ToolExecutionFailed {
|
||
tool: "generate_title".into(),
|
||
cause: format!("db error: {}", e),
|
||
})?
|
||
.ok_or_else(|| crate::error::AgentError::NotFound("Conversation not found".into()))?;
|
||
|
||
let recent_messages = AiMessage::find()
|
||
.filter(ai_message::Column::ConversationId.eq(conversation_id))
|
||
.filter(ai_message::Column::Role.eq("user"))
|
||
.order_by_desc(ai_message::Column::CreatedAt)
|
||
.limit(3)
|
||
.all(db_reader)
|
||
.await
|
||
.map_err(|e| crate::error::AgentError::ToolExecutionFailed {
|
||
tool: "generate_title".into(),
|
||
cause: format!("db error: {}", e),
|
||
})?;
|
||
|
||
if recent_messages.is_empty() {
|
||
return Err(crate::error::AgentError::ToolExecutionFailed {
|
||
tool: "generate_title".into(),
|
||
cause: "No user messages found".into(),
|
||
});
|
||
}
|
||
|
||
let content = recent_messages
|
||
.first()
|
||
.and_then(|m| m.content.as_array())
|
||
.and_then(|arr| arr.first())
|
||
.and_then(|v| v.get("content"))
|
||
.and_then(|c| c.as_str())
|
||
.unwrap_or("");
|
||
|
||
let words: Vec<&str> = content
|
||
.split_whitespace()
|
||
.filter(|w| w.len() > 2 && !is_stop_word(w))
|
||
.take(5)
|
||
.collect();
|
||
|
||
let title = if words.is_empty() {
|
||
"New Chat".to_string()
|
||
} else {
|
||
words.join(" ")
|
||
};
|
||
|
||
let mut active: ai_conversation::ActiveModel = conv.into();
|
||
active.title = Set(Some(title.clone()));
|
||
active.updated_at = Set(chrono::Utc::now());
|
||
active
|
||
.update(db_writer)
|
||
.await
|
||
.map_err(|e| crate::error::AgentError::ToolExecutionFailed {
|
||
tool: "generate_title".into(),
|
||
cause: format!("failed to update title: {}", e),
|
||
})?;
|
||
|
||
Ok(serde_json::json!({ "conversation_id": conversation_id.to_string(), "title": title }))
|
||
}
|
||
|
||
/// Reports whether `w` is a filler word with no topical value for a conversation title.
///
/// The check is case-insensitive: `w` is lowercased (Unicode-aware `to_lowercase`) and
/// then looked up in a fixed list of English stop words. Punctuation is not stripped, so
/// `"the?"` is *not* considered a stop word.
fn is_stop_word(w: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "this", "that", "what", "which", "when", "where", "why", "how", "can", "could",
        "would", "should", "please", "help", "thanks", "thank", "you", "your", "have", "has",
        "had", "with", "for", "from", "into", "about", "also", "just", "now", "very", "really",
    ];

    let lowered = w.to_lowercase();
    STOP_WORDS.contains(&lowered.as_str())
}
|
||
|
||
type SharedCallback = Arc<
|
||
dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync,
|
||
>;
|
||
|
||
/// Simplified ReAct execution for Chat API.
|
||
///
|
||
/// Unlike `execute_process_stream` (which requires `AiRequest` with room-specific data),
|
||
/// this function takes messages and tools directly. It does NOT record AI sessions to
|
||
/// the `ai_session` table — the caller is responsible for persisting results.
|
||
pub async fn execute_chat_stream(
|
||
messages: Vec<ChatRequestMessage>,
|
||
tools: Vec<serde_json::Value>,
|
||
model_name: &str,
|
||
config: &AiClientConfig,
|
||
temperature: f32,
|
||
max_tokens: u32,
|
||
max_tool_depth: usize,
|
||
tool_registry: Option<&ToolRegistry>,
|
||
db: db::database::AppDatabase,
|
||
cache: db::cache::AppCache,
|
||
app_config: config::AppConfig,
|
||
project_id: Uuid,
|
||
sender_uid: Uuid,
|
||
embed_service: Option<EmbedService>,
|
||
on_chunk: StreamCallback,
|
||
conversation_id: Option<uuid::Uuid>,
|
||
) -> Result<StreamResult> {
|
||
let on_chunk: SharedCallback = Arc::from(on_chunk);
|
||
let tools_enabled = !tools.is_empty();
|
||
let mut messages = messages;
|
||
let mut tool_depth = 0;
|
||
let mut total_input_tokens = 0i64;
|
||
let mut total_output_tokens = 0i64;
|
||
let mut full_content = String::new();
|
||
let mut all_chunks: Vec<StreamChunk> = Vec::new();
|
||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
|
||
|
||
// Conditionally inject chat_generate_title tool if conversation has no title
|
||
let (tools, _tools_injected) = if let Some(conv_id) = conversation_id {
|
||
if let Some(registry) = tool_registry {
|
||
let db_reader = db.reader();
|
||
let has_title = models::ai::ai_conversation::Entity::find_by_id(conv_id)
|
||
.one(db_reader)
|
||
.await
|
||
.map(|c| c.map(|m| m.title.is_some()).unwrap_or(false))
|
||
.unwrap_or(false);
|
||
if !has_title {
|
||
let mut reg = registry.clone();
|
||
reg.register(
|
||
ToolDefinition::new("chat_generate_title")
|
||
.description(
|
||
"Generate a concise title (5 words or fewer) for the current conversation \
|
||
based on its message history, and save it to the conversation record. \
|
||
Call this tool at the start of a new conversation if it has no title.",
|
||
)
|
||
.parameters(crate::tool::ToolSchema {
|
||
schema_type: "object".into(),
|
||
properties: Some({
|
||
let mut p = std::collections::HashMap::new();
|
||
p.insert("conversation_id".into(), ToolParam {
|
||
name: "conversation_id".into(),
|
||
param_type: "string".into(),
|
||
description: Some("The UUID of the conversation (required).".into()),
|
||
required: true,
|
||
properties: None,
|
||
items: None,
|
||
});
|
||
p
|
||
}),
|
||
required: Some(vec!["conversation_id".into()]),
|
||
}),
|
||
ToolHandler::new(|ctx, args| {
|
||
let conv_id = args.get("conversation_id")
|
||
.and_then(|v| v.as_str())
|
||
.and_then(|s| Uuid::parse_str(s).ok());
|
||
Box::pin(async move {
|
||
match conv_id {
|
||
Some(id) => generate_title_for_conversation(&ctx, id).await
|
||
.map_err(|e| crate::tool::ToolError::ExecutionError(e.to_string())),
|
||
None => Err(crate::tool::ToolError::ExecutionError("conversation_id missing".into())),
|
||
}
|
||
})
|
||
}),
|
||
);
|
||
// Prepend system message instructing the model to generate title first
|
||
messages.insert(0, ChatRequestMessage::system(
|
||
"IMPORTANT: If the conversation has no title, you MUST call chat_generate_title \
|
||
with the conversation_id immediately before answering any user question. \
|
||
The title must be 5 words or fewer and should summarize the user's intent.".to_string(),
|
||
));
|
||
(reg.to_openai_tools(), true)
|
||
} else {
|
||
(tools.clone(), false)
|
||
}
|
||
} else {
|
||
(tools.clone(), false)
|
||
}
|
||
} else {
|
||
(tools.clone(), false)
|
||
};
|
||
|
||
loop {
|
||
let on_chunk_cb = on_chunk.clone();
|
||
let on_chunk_cb2 = on_chunk.clone();
|
||
let tx_arc = Arc::new(tx.clone());
|
||
let tx_arc2 = tx_arc.clone();
|
||
|
||
let response = call_stream(
|
||
&messages,
|
||
model_name,
|
||
config,
|
||
temperature,
|
||
max_tokens,
|
||
if tools_enabled { Some(&tools) } else { None },
|
||
None,
|
||
Arc::new(move |delta| {
|
||
let content = delta.to_string();
|
||
let fut = on_chunk_cb(AiStreamChunk {
|
||
content,
|
||
done: false,
|
||
chunk_type: AiChunkType::Answer,
|
||
metadata: None,
|
||
});
|
||
fut
|
||
}),
|
||
Arc::new(move |delta| {
|
||
let fut = on_chunk_cb2(AiStreamChunk {
|
||
content: delta.to_string(),
|
||
done: false,
|
||
chunk_type: AiChunkType::Thinking,
|
||
metadata: None,
|
||
});
|
||
fut
|
||
}),
|
||
Arc::new(move |tc: &StreamedToolCall| {
|
||
let tx = tx_arc2.clone();
|
||
let tc_owned = tc.clone();
|
||
Box::pin(async move {
|
||
let _ = tx.send(tc_owned);
|
||
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
|
||
}),
|
||
)
|
||
.await?;
|
||
|
||
total_input_tokens += response.input_tokens;
|
||
total_output_tokens += response.output_tokens;
|
||
all_chunks.extend(response.chunks.clone());
|
||
|
||
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
|
||
if !has_tool_calls {
|
||
let final_content = response.content.clone();
|
||
// Don't push full content as a chunk — incremental deltas in
|
||
// response.chunks (already added above) sum to the same text.
|
||
// merge_consecutive_blocks would concatenate delta_sum + full =
|
||
// 2× full, causing duplicate content in DB persistence.
|
||
return Ok(StreamResult {
|
||
content: final_content,
|
||
reasoning_content: response.reasoning_content,
|
||
input_tokens: total_input_tokens,
|
||
output_tokens: total_output_tokens,
|
||
chunks: all_chunks,
|
||
});
|
||
}
|
||
|
||
full_content.push_str(&response.content);
|
||
|
||
let tool_calls: Vec<ToolCall> = response
|
||
.tool_calls
|
||
.iter()
|
||
.map(|tc| ToolCall {
|
||
id: tc.id.clone(),
|
||
type_: "function".into(),
|
||
function: crate::client::types::ToolCallFunction {
|
||
name: tc.name.clone(),
|
||
arguments: tc.arguments.clone(),
|
||
},
|
||
})
|
||
.collect();
|
||
|
||
messages.push(ChatRequestMessage::assistant(
|
||
Some(response.content.clone()),
|
||
Some(tool_calls.clone()),
|
||
));
|
||
|
||
// Drain tool call notifications
|
||
loop {
|
||
match rx.try_recv() {
|
||
Ok(tc) => {
|
||
let args_display = if tc.arguments.len() > 100 {
|
||
let end = tc
|
||
.arguments
|
||
.char_indices()
|
||
.map(|(i, _)| i)
|
||
.take_while(|&i| i <= 100)
|
||
.last()
|
||
.unwrap_or(100);
|
||
format!("{}...", &tc.arguments[..end])
|
||
} else {
|
||
tc.arguments.clone()
|
||
};
|
||
let tool_display = format!("🔧 {}({})", tc.name, args_display);
|
||
on_chunk(AiStreamChunk {
|
||
content: tool_display.clone(),
|
||
done: false,
|
||
chunk_type: AiChunkType::ToolCall,
|
||
metadata: None,
|
||
})
|
||
.await;
|
||
all_chunks.push(StreamChunk {
|
||
chunk_type: StreamChunkType::ToolCall,
|
||
content: tool_display,
|
||
});
|
||
}
|
||
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
|
||
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
|
||
}
|
||
}
|
||
|
||
let calls: Vec<AgentToolCall> = response
|
||
.tool_calls
|
||
.iter()
|
||
.map(|tc| AgentToolCall {
|
||
id: tc.id.clone(),
|
||
name: tc.name.clone(),
|
||
arguments: tc.arguments.clone(),
|
||
})
|
||
.collect();
|
||
|
||
let tool_messages = execute_tools(
|
||
&calls,
|
||
&db,
|
||
&cache,
|
||
&app_config,
|
||
project_id,
|
||
sender_uid,
|
||
tool_registry,
|
||
embed_service.as_ref(),
|
||
&on_chunk,
|
||
&mut all_chunks,
|
||
)
|
||
.await;
|
||
|
||
messages.extend(tool_messages);
|
||
|
||
tool_depth += 1;
|
||
if tool_depth >= max_tool_depth {
|
||
let max_depth_text = format!(
|
||
"[AI reached maximum tool depth ({}) — no final answer produced]",
|
||
max_tool_depth
|
||
);
|
||
on_chunk(AiStreamChunk {
|
||
content: max_depth_text.clone(),
|
||
done: true,
|
||
chunk_type: AiChunkType::Answer,
|
||
metadata: None,
|
||
})
|
||
.await;
|
||
all_chunks.push(StreamChunk {
|
||
chunk_type: StreamChunkType::Answer,
|
||
content: max_depth_text,
|
||
});
|
||
return Ok(StreamResult {
|
||
content: full_content,
|
||
reasoning_content: String::new(),
|
||
input_tokens: 0,
|
||
output_tokens: 0,
|
||
chunks: all_chunks,
|
||
});
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn execute_tools(
|
||
calls: &[AgentToolCall],
|
||
db: &db::database::AppDatabase,
|
||
cache: &db::cache::AppCache,
|
||
app_config: &config::AppConfig,
|
||
project_id: Uuid,
|
||
sender_uid: Uuid,
|
||
tool_registry: Option<&ToolRegistry>,
|
||
embed_service: Option<&EmbedService>,
|
||
on_chunk: &SharedCallback,
|
||
all_chunks: &mut Vec<StreamChunk>,
|
||
) -> Vec<ChatRequestMessage> {
|
||
let mut tool_messages = Vec::new();
|
||
let mut ctx = ToolContext::new(
|
||
db.clone(),
|
||
cache.clone(),
|
||
app_config.clone(),
|
||
Uuid::nil(),
|
||
Some(sender_uid),
|
||
)
|
||
.with_project(project_id);
|
||
if let Some(es) = embed_service {
|
||
ctx = ctx.with_embed_service(es.clone());
|
||
}
|
||
if let Some(registry) = tool_registry {
|
||
ctx.registry_mut().merge(registry.clone());
|
||
}
|
||
|
||
let mut join_set = tokio::task::JoinSet::new();
|
||
for call in calls {
|
||
let call_clone = call.clone();
|
||
let mut ctx_clone = ctx.clone();
|
||
join_set.spawn(async move {
|
||
let executor = ToolExecutor::new();
|
||
let res = executor
|
||
.execute_batch(vec![call_clone.clone()], &mut ctx_clone)
|
||
.await;
|
||
(call_clone, res)
|
||
});
|
||
}
|
||
|
||
let heartbeat_dur = std::time::Duration::from_secs(10);
|
||
while !join_set.is_empty() {
|
||
tokio::select! {
|
||
Some(res) = join_set.join_next() => {
|
||
if let Ok((call, results)) = res {
|
||
match results {
|
||
Ok(results) => {
|
||
for result in &results {
|
||
let preview = match &result.result {
|
||
crate::tool::ToolResult::Ok(v) => {
|
||
let t = v.to_string();
|
||
if t.len() > 300 {
|
||
let end = t.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
|
||
format!("{}...", &t[..end])
|
||
} else { t.clone() }
|
||
}
|
||
crate::tool::ToolResult::Error(msg) => msg.clone(),
|
||
};
|
||
tracing::debug!("tool_result: {} — {}", call.name, preview);
|
||
}
|
||
let success_display = format!("✅ {}", call.name);
|
||
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None }).await;
|
||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
|
||
let msgs = ToolExecutor::to_tool_messages(&results);
|
||
tool_messages.extend(msgs);
|
||
}
|
||
Err(e) => {
|
||
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
|
||
let err_text = format!("[Tool call failed: {}]", e);
|
||
let err_display = format!("❌ {} (failed)", call.name);
|
||
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None }).await;
|
||
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
|
||
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
|
||
}
|
||
}
|
||
}
|
||
},
|
||
_ = tokio::time::sleep(heartbeat_dur) => {
|
||
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall, metadata: None }).await;
|
||
}
|
||
}
|
||
}
|
||
tool_messages
|
||
}
|