gitdataai/libs/agent/chat/chat_execution.rs

487 lines
18 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream};
use crate::embed::EmbedService;
use crate::error::Result;
use crate::tool::registry::ToolRegistry;
use crate::tool::{
ToolCall as AgentToolCall, ToolContext, ToolDefinition, ToolExecutor, ToolHandler, ToolParam,
};
use sea_orm::{ActiveModelTrait, EntityTrait, Set};
use super::service::StreamResult;
use super::{AiChunkType, AiStreamChunk, StreamCallback};
/// Keyword-extraction-based title generator for a conversation.
///
/// Loads the conversation, reads its three most recent `user` messages (newest first),
/// extracts up to five meaningful words from the newest message that yields any, and
/// stores them as the conversation title (also refreshing `updated_at`). If no message
/// yields a keyword the title falls back to `"New Chat"`.
///
/// Returns a JSON object with the `conversation_id` and the saved `title`.
///
/// # Errors
///
/// * `AgentError::NotFound` when no conversation has the given id.
/// * `AgentError::ToolExecutionFailed` when a database read or the title update fails,
///   or when the conversation has no user messages.
async fn generate_title_for_conversation(
    ctx: &ToolContext,
    conversation_id: Uuid,
) -> Result<serde_json::Value> {
    use models::ai::{AiMessage, ai_conversation, ai_message};
    use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect};

    // Sentence punctuation stripped from both ends of a word, so that "Please," or
    // "thanks!" are still recognised as stop words and titles do not end in "?" or ",".
    const EDGE_PUNCTUATION: &[char] = &[
        '.', ',', '!', '?', ';', ':', '"', '\'', '(', ')', '[', ']', '{', '}',
    ];

    // Picks at most five words longer than two bytes that are not stop words.
    fn title_keywords(text: &str) -> Vec<&str> {
        text.split_whitespace()
            .map(|w| w.trim_matches(EDGE_PUNCTUATION))
            .filter(|w| w.len() > 2 && !is_stop_word(w))
            .take(5)
            .collect()
    }

    let db_reader = ctx.db().reader();
    let db_writer = ctx.db().writer();

    let conv = ai_conversation::Entity::find_by_id(conversation_id)
        .one(db_reader)
        .await
        .map_err(|e| crate::error::AgentError::ToolExecutionFailed {
            tool: "generate_title".into(),
            cause: format!("db error: {}", e),
        })?
        .ok_or_else(|| crate::error::AgentError::NotFound("Conversation not found".into()))?;

    let recent_messages = AiMessage::find()
        .filter(ai_message::Column::ConversationId.eq(conversation_id))
        .filter(ai_message::Column::Role.eq("user"))
        .order_by_desc(ai_message::Column::CreatedAt)
        .limit(3)
        .all(db_reader)
        .await
        .map_err(|e| crate::error::AgentError::ToolExecutionFailed {
            tool: "generate_title".into(),
            cause: format!("db error: {}", e),
        })?;
    if recent_messages.is_empty() {
        return Err(crate::error::AgentError::ToolExecutionFailed {
            tool: "generate_title".into(),
            cause: "No user messages found".into(),
        });
    }

    // Walk the fetched messages newest-first and use the first one whose text produces
    // keywords; a newest message without usable text no longer forces the fallback title
    // while older fetched messages still have content.
    let words: Vec<&str> = recent_messages
        .iter()
        .map(|m| {
            m.content
                .as_array()
                .and_then(|arr| arr.first())
                .and_then(|v| v.get("content"))
                .and_then(|c| c.as_str())
                .map(title_keywords)
                .unwrap_or_default()
        })
        .find(|kw| !kw.is_empty())
        .unwrap_or_default();

    let title = if words.is_empty() {
        "New Chat".to_string()
    } else {
        words.join(" ")
    };

    let mut active: ai_conversation::ActiveModel = conv.into();
    active.title = Set(Some(title.clone()));
    active.updated_at = Set(chrono::Utc::now());
    active
        .update(db_writer)
        .await
        .map_err(|e| crate::error::AgentError::ToolExecutionFailed {
            tool: "generate_title".into(),
            cause: format!("failed to update title: {}", e),
        })?;

    Ok(serde_json::json!({ "conversation_id": conversation_id.to_string(), "title": title }))
}
/// Reports whether `w` is a filler word that should be left out of a generated title.
///
/// The comparison is case-insensitive: the input is lowercased before it is looked up
/// in the fixed stop-word table.
fn is_stop_word(w: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "this", "that", "what", "which", "when", "where", "why", "how", "can", "could",
        "would", "should", "please", "help", "thanks", "thank", "you", "your", "have", "has",
        "had", "with", "for", "from", "into", "about", "also", "just", "now", "very", "really",
    ];

    let lowered = w.to_lowercase();
    STOP_WORDS.contains(&lowered.as_str())
}
/// Reference-counted stream callback: a closure that receives an [`AiStreamChunk`] and
/// returns a boxed, `Send` future resolving to `()`.
///
/// The `Arc` lets the same callback be cloned into the per-request streaming closures
/// and passed by reference to `execute_tools`.
type SharedCallback = Arc<
dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync,
>;
/// Simplified ReAct execution for Chat API.
///
/// Unlike `execute_process_stream` (which requires `AiRequest` with room-specific data),
/// this function takes messages and tools directly. It does NOT record AI sessions to
/// the `ai_session` table — the caller is responsible for persisting results.
///
/// Each round streams one model response through `on_chunk`. If the response contains
/// tool calls (and tools are enabled) the calls are executed, their results are appended
/// to the message list, and another round starts. The loop ends when a response has no
/// tool calls, or after `max_tool_depth` tool rounds.
///
/// When `conversation_id` and `tool_registry` are both given and the conversation has no
/// title (or cannot be loaded), a `chat_generate_title` tool is registered on a copy of
/// the registry, a system message asking the model to call it is prepended, and the tool
/// list sent to the model is taken from that copy. The same copy is used to execute tool
/// calls so the injected tool is actually resolvable.
///
/// # Returns
///
/// A [`StreamResult`] with the final content, the token usage accumulated over all
/// rounds, and every chunk emitted along the way.
///
/// # Errors
///
/// Propagates any error returned by `call_stream`.
pub async fn execute_chat_stream(
    messages: Vec<ChatRequestMessage>,
    tools: Vec<serde_json::Value>,
    model_name: &str,
    config: &AiClientConfig,
    temperature: f32,
    max_tokens: u32,
    max_tool_depth: usize,
    tool_registry: Option<&ToolRegistry>,
    db: db::database::AppDatabase,
    cache: db::cache::AppCache,
    app_config: config::AppConfig,
    project_id: Uuid,
    sender_uid: Uuid,
    embed_service: Option<EmbedService>,
    on_chunk: StreamCallback,
    conversation_id: Option<uuid::Uuid>,
) -> Result<StreamResult> {
    // Shortens tool-call arguments for display, cutting on the last char boundary at or
    // below byte 100 so the slice can never split a multi-byte character.
    fn abbreviate_args(args: &str) -> String {
        const MAX_BYTES: usize = 100;
        if args.len() <= MAX_BYTES {
            return args.to_string();
        }
        let end = args
            .char_indices()
            .map(|(i, _)| i)
            .take_while(|&i| i <= MAX_BYTES)
            .last()
            .unwrap_or(0);
        format!("{}...", &args[..end])
    }

    let on_chunk: SharedCallback = Arc::from(on_chunk);
    let mut messages = messages;
    let mut tool_depth = 0;
    let mut total_input_tokens = 0i64;
    let mut total_output_tokens = 0i64;
    let mut full_content = String::new();
    let mut all_chunks: Vec<StreamChunk> = Vec::new();
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();

    // Conditionally inject chat_generate_title tool if conversation has no title.
    // The augmented registry is kept alive: it is the only registry that knows the
    // injected handler, so tool execution below must use it as well.
    let mut title_registry: Option<ToolRegistry> = None;
    if let (Some(conv_id), Some(registry)) = (conversation_id, tool_registry) {
        let db_reader = db.reader();
        // A failed or empty lookup is treated as "no title" (best effort).
        let has_title = models::ai::ai_conversation::Entity::find_by_id(conv_id)
            .one(db_reader)
            .await
            .map(|c| c.map(|m| m.title.is_some()).unwrap_or(false))
            .unwrap_or(false);
        if !has_title {
            let mut reg = registry.clone();
            reg.register(
                ToolDefinition::new("chat_generate_title")
                    .description(
                        "Generate a concise title (5 words or fewer) for the current conversation \
                         based on its message history, and save it to the conversation record. \
                         Call this tool at the start of a new conversation if it has no title.",
                    )
                    .parameters(crate::tool::ToolSchema {
                        schema_type: "object".into(),
                        properties: Some({
                            let mut p = std::collections::HashMap::new();
                            p.insert("conversation_id".into(), ToolParam {
                                name: "conversation_id".into(),
                                param_type: "string".into(),
                                description: Some("The UUID of the conversation (required).".into()),
                                required: true,
                                properties: None,
                                items: None,
                            });
                            p
                        }),
                        required: Some(vec!["conversation_id".into()]),
                    }),
                ToolHandler::new(|ctx, args| {
                    let conv_id = args.get("conversation_id")
                        .and_then(|v| v.as_str())
                        .and_then(|s| Uuid::parse_str(s).ok());
                    Box::pin(async move {
                        match conv_id {
                            Some(id) => generate_title_for_conversation(&ctx, id).await
                                .map_err(|e| crate::tool::ToolError::ExecutionError(e.to_string())),
                            None => Err(crate::tool::ToolError::ExecutionError("conversation_id missing".into())),
                        }
                    })
                }),
            );
            // Prepend system message instructing the model to generate title first
            messages.insert(0, ChatRequestMessage::system(
                "IMPORTANT: If the conversation has no title, you MUST call chat_generate_title \
                 with the conversation_id immediately before answering any user question. \
                 The title must be 5 words or fewer and should summarize the user's intent.".to_string(),
            ));
            title_registry = Some(reg);
        }
    }

    let tools = match &title_registry {
        Some(reg) => reg.to_openai_tools(),
        None => tools,
    };
    // Decided after the injection: the system prompt above demands a tool call, so the
    // tool list must actually be sent even when the caller supplied no tools.
    let tools_enabled = !tools.is_empty();
    let active_registry: Option<&ToolRegistry> = title_registry.as_ref().or(tool_registry);

    loop {
        let answer_cb = on_chunk.clone();
        let thinking_cb = on_chunk.clone();
        let tool_tx = tx.clone();
        let response = call_stream(
            &messages,
            model_name,
            config,
            temperature,
            max_tokens,
            if tools_enabled { Some(&tools) } else { None },
            None,
            Arc::new(move |delta| {
                answer_cb(AiStreamChunk {
                    content: delta.to_string(),
                    done: false,
                    chunk_type: AiChunkType::Answer,
                    metadata: None,
                })
            }),
            Arc::new(move |delta| {
                thinking_cb(AiStreamChunk {
                    content: delta.to_string(),
                    done: false,
                    chunk_type: AiChunkType::Thinking,
                    metadata: None,
                })
            }),
            Arc::new(move |tc: &StreamedToolCall| {
                let tx = tool_tx.clone();
                let tc_owned = tc.clone();
                Box::pin(async move {
                    // The receiver lives for the whole function; a send error is ignorable.
                    let _ = tx.send(tc_owned);
                }) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
            }),
        )
        .await?;

        total_input_tokens += response.input_tokens;
        total_output_tokens += response.output_tokens;
        all_chunks.extend(response.chunks.clone());

        let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
        if !has_tool_calls {
            // Don't push full content as a chunk — incremental deltas in
            // response.chunks (already added above) sum to the same text.
            // merge_consecutive_blocks would concatenate delta_sum + full =
            // 2× full, causing duplicate content in DB persistence.
            return Ok(StreamResult {
                content: response.content,
                reasoning_content: response.reasoning_content,
                input_tokens: total_input_tokens,
                output_tokens: total_output_tokens,
                chunks: all_chunks,
            });
        }

        full_content.push_str(&response.content);

        // Echo the assistant turn (text + requested tool calls) back into the history.
        let tool_calls: Vec<ToolCall> = response
            .tool_calls
            .iter()
            .map(|tc| ToolCall {
                id: tc.id.clone(),
                type_: "function".into(),
                function: crate::client::types::ToolCallFunction {
                    name: tc.name.clone(),
                    arguments: tc.arguments.clone(),
                },
            })
            .collect();
        messages.push(ChatRequestMessage::assistant(
            Some(response.content.clone()),
            Some(tool_calls),
        ));

        // Drain tool call notifications queued by the streaming callback.
        while let Ok(tc) = rx.try_recv() {
            let tool_display = format!("🔧 {}({})", tc.name, abbreviate_args(&tc.arguments));
            on_chunk(AiStreamChunk {
                content: tool_display.clone(),
                done: false,
                chunk_type: AiChunkType::ToolCall,
                metadata: None,
            })
            .await;
            all_chunks.push(StreamChunk {
                chunk_type: StreamChunkType::ToolCall,
                content: tool_display,
            });
        }

        let calls: Vec<AgentToolCall> = response
            .tool_calls
            .iter()
            .map(|tc| AgentToolCall {
                id: tc.id.clone(),
                name: tc.name.clone(),
                arguments: tc.arguments.clone(),
            })
            .collect();
        let tool_messages = execute_tools(
            &calls,
            &db,
            &cache,
            &app_config,
            project_id,
            sender_uid,
            active_registry,
            embed_service.as_ref(),
            &on_chunk,
            &mut all_chunks,
        )
        .await;
        messages.extend(tool_messages);

        tool_depth += 1;
        if tool_depth >= max_tool_depth {
            let max_depth_text = format!(
                "[AI reached maximum tool depth ({}) — no final answer produced]",
                max_tool_depth
            );
            on_chunk(AiStreamChunk {
                content: max_depth_text.clone(),
                done: true,
                chunk_type: AiChunkType::Answer,
                metadata: None,
            })
            .await;
            all_chunks.push(StreamChunk {
                chunk_type: StreamChunkType::Answer,
                content: max_depth_text,
            });
            // Report the tokens actually consumed by the rounds that ran; returning
            // zeros here would under-count usage for every depth-limited request.
            return Ok(StreamResult {
                content: full_content,
                reasoning_content: String::new(),
                input_tokens: total_input_tokens,
                output_tokens: total_output_tokens,
                chunks: all_chunks,
            });
        }
    }
}
async fn execute_tools(
calls: &[AgentToolCall],
db: &db::database::AppDatabase,
cache: &db::cache::AppCache,
app_config: &config::AppConfig,
project_id: Uuid,
sender_uid: Uuid,
tool_registry: Option<&ToolRegistry>,
embed_service: Option<&EmbedService>,
on_chunk: &SharedCallback,
all_chunks: &mut Vec<StreamChunk>,
) -> Vec<ChatRequestMessage> {
let mut tool_messages = Vec::new();
let mut ctx = ToolContext::new(
db.clone(),
cache.clone(),
app_config.clone(),
Uuid::nil(),
Some(sender_uid),
)
.with_project(project_id);
if let Some(es) = embed_service {
ctx = ctx.with_embed_service(es.clone());
}
if let Some(registry) = tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let mut join_set = tokio::task::JoinSet::new();
for call in calls {
let call_clone = call.clone();
let mut ctx_clone = ctx.clone();
join_set.spawn(async move {
let executor = ToolExecutor::new();
let res = executor
.execute_batch(vec![call_clone.clone()], &mut ctx_clone)
.await;
(call_clone, res)
});
}
let heartbeat_dur = std::time::Duration::from_secs(10);
while !join_set.is_empty() {
tokio::select! {
Some(res) = join_set.join_next() => {
if let Ok((call, results)) = res {
match results {
Ok(results) => {
for result in &results {
let preview = match &result.result {
crate::tool::ToolResult::Ok(v) => {
let t = v.to_string();
if t.len() > 300 {
let end = t.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
format!("{}...", &t[..end])
} else { t.clone() }
}
crate::tool::ToolResult::Error(msg) => msg.clone(),
};
tracing::debug!("tool_result: {} — {}", call.name, preview);
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
Err(e) => {
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_text = format!("[Tool call failed: {}]", e);
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult, metadata: None }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
}
}
}
},
_ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall, metadata: None }).await;
}
}
}
tool_messages
}