// Listing metadata: 407 lines, 16 KiB, Rust.
use ai::{
|
|
agent::AgentConfig,
|
|
client::{AiClient, AiClientConfig, EmbedConfig, EndpointConfig},
|
|
};
|
|
use db::sqlx::{self, types::Decimal};
|
|
use model::{
|
|
agent::AgentSessionModel, ai::AiModelVersionModel, ai::AiProviderModel,
|
|
};
|
|
use uuid::Uuid;
|
|
|
|
use super::types::SessionContext;
|
|
use crate::AppService;
|
|
use crate::error::AppError;
|
|
|
|
impl AppService {
|
|
pub async fn agent_session_context(
|
|
&self,
|
|
session_id: Uuid,
|
|
user_id: Uuid,
|
|
) -> Result<SessionContext, AppError> {
|
|
let session = sqlx::query_as::<_, AgentSessionModel>(
|
|
"SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, tool_policy, \
|
|
knowledge_base_ids, variables, visibility, version, \
|
|
published_at, rollback_from_version, enabled, \
|
|
source, parent_session_id, toolset_json, \
|
|
memory_provider, memory_provider_config, iteration_budget, \
|
|
created_by, created_at, updated_at, deleted_at \
|
|
FROM agent_session WHERE id = $1 AND deleted_at IS NULL AND enabled = true",
|
|
)
|
|
.bind(session_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?;
|
|
|
|
if let Some(wk_id) = session.wk {
|
|
let _ = self.workspace_require_member(wk_id, user_id).await?;
|
|
} else if Some(user_id) != session.user {
|
|
return Err(AppError::PermissionDenied);
|
|
}
|
|
|
|
let model_version_id = session.model_version.ok_or_else(|| {
|
|
AppError::BadRequest(
|
|
"agent session has no model_version".to_string(),
|
|
)
|
|
})?;
|
|
|
|
let version = self.resolve_model_version(model_version_id).await?;
|
|
|
|
let billing_target = if session.wk.is_some() {
|
|
super::types::BillingTarget::Workspace
|
|
} else {
|
|
super::types::BillingTarget::User
|
|
};
|
|
|
|
Ok(SessionContext {
|
|
session_id,
|
|
user_id: session.user,
|
|
workspace_id: session.wk,
|
|
system_prompt: self
|
|
.build_system_prompt_with_context(&session, user_id)
|
|
.await,
|
|
model_version_id: version.id,
|
|
provider_model_name: version.provider_model_name,
|
|
temperature: session.temperature,
|
|
max_output_tokens: session.max_output_tokens,
|
|
tool_policy_json: session.tool_policy,
|
|
toolset_json: session.toolset_json,
|
|
variables_json: session.variables,
|
|
iteration_budget: session.iteration_budget,
|
|
memory_provider: session.memory_provider,
|
|
source: session.source,
|
|
parent_session_id: session.parent_session_id,
|
|
billing_target,
|
|
})
|
|
}
|
|
pub async fn agent_build_ai_client(
|
|
&self,
|
|
model_version_id: Uuid,
|
|
) -> Result<AiClient, AppError> {
|
|
let version = self.resolve_model_version(model_version_id).await?;
|
|
|
|
let model_record = sqlx::query_as::<_, model::ai::AiModelModel>(
|
|
"SELECT id, provider, name, display_name, description, modality, \
|
|
context_window, input_token_limit, output_token_limit, \
|
|
enabled, public, created_at, updated_at, deleted_at \
|
|
FROM ai_model WHERE id = $1",
|
|
)
|
|
.bind(version.model)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.ok_or_else(|| AppError::NotFound("ai model not found".to_string()))?;
|
|
|
|
let provider = sqlx::query_as::<_, AiProviderModel>(
|
|
"SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \
|
|
FROM ai_provider WHERE id = $1 AND enabled = true",
|
|
)
|
|
.bind(model_record.provider)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.ok_or_else(|| AppError::NotFound("ai provider not found".to_string()))?;
|
|
|
|
let base_url = provider
|
|
.base_url
|
|
.unwrap_or_else(|| self.config.ai_basic_url().unwrap_or_default());
|
|
let api_key = self.config.ai_api_key().map_err(|e| {
|
|
AppError::InternalServerError(format!("AI API key: {e}"))
|
|
})?;
|
|
|
|
let embed_base_url =
|
|
self.config.get_embed_model_base_url().map_err(|e| {
|
|
AppError::InternalServerError(format!("embed base url: {e}"))
|
|
})?;
|
|
let embed_api_key =
|
|
self.config.get_embed_model_api_key().map_err(|e| {
|
|
AppError::InternalServerError(format!("embed api key: {e}"))
|
|
})?;
|
|
|
|
let llm_config = EndpointConfig::new(&base_url, &api_key)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
let embed_endpoint =
|
|
EndpointConfig::new(&embed_base_url, &embed_api_key)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
let embed_config = EmbedConfig::new(
|
|
embed_endpoint,
|
|
self.config
|
|
.get_embed_model_name()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?,
|
|
self.config
|
|
.get_embed_model_dimensions()
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?,
|
|
)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
let client_config = AiClientConfig::new(llm_config, embed_config)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
|
|
|
|
AiClient::new(client_config)
|
|
.map_err(|e| AppError::InternalServerError(e.to_string()))
|
|
}
|
|
pub fn agent_build_config(
|
|
&self,
|
|
ctx: &SessionContext,
|
|
max_steps_override: Option<usize>,
|
|
) -> AgentConfig {
|
|
let mut config = AgentConfig::new(&ctx.provider_model_name)
|
|
.unwrap_or_else(|_| {
|
|
AgentConfig::new("gpt-4o").expect("default agent config")
|
|
});
|
|
|
|
config.model = ctx.provider_model_name.clone();
|
|
config.system_prompt = ctx.system_prompt.clone();
|
|
if let Some(ref vars_json) = ctx.variables_json {
|
|
if let Ok(vars) = serde_json::from_str::<
|
|
serde_json::Map<String, serde_json::Value>,
|
|
>(vars_json)
|
|
{
|
|
if !vars.is_empty() {
|
|
let mut prompt = config.system_prompt.clone();
|
|
let mut any_replaced = false;
|
|
for (key, val) in &vars {
|
|
let placeholder = format!("{{{{{}}}}}", key);
|
|
if prompt.contains(&placeholder) {
|
|
let replacement = match val {
|
|
serde_json::Value::String(s) => s.clone(),
|
|
other => other.to_string(),
|
|
};
|
|
prompt = prompt.replace(&placeholder, &replacement);
|
|
any_replaced = true;
|
|
}
|
|
}
|
|
if !any_replaced {
|
|
prompt.push_str("\n\n<variables>\n");
|
|
for (key, val) in &vars {
|
|
let val_str = match val {
|
|
serde_json::Value::String(s) => s.clone(),
|
|
other => other.to_string(),
|
|
};
|
|
prompt
|
|
.push_str(&format!("- {}: {}\n", key, val_str));
|
|
}
|
|
prompt.push_str("</variables>");
|
|
}
|
|
config.system_prompt = prompt;
|
|
}
|
|
}
|
|
}
|
|
|
|
if let Some(temp) = ctx.temperature {
|
|
config.temperature = Some(temp as f64);
|
|
}
|
|
if let Some(max_tok) = ctx.max_output_tokens {
|
|
config.max_completion_tokens = Some(max_tok as u64);
|
|
}
|
|
if let Some(max_steps) = max_steps_override {
|
|
config.max_iterations = max_steps;
|
|
}
|
|
if let Some(budget) = ctx.iteration_budget {
|
|
config.iteration_budget = budget as usize;
|
|
}
|
|
if let Some(ref policy_json) = ctx.tool_policy_json {
|
|
match serde_json::from_str::<serde_json::Value>(policy_json) {
|
|
Ok(policy) => {
|
|
let allowed: Vec<String> = policy
|
|
.get("allowed")
|
|
.and_then(|v| v.as_array())
|
|
.map(|a| {
|
|
a.iter()
|
|
.filter_map(|v| v.as_str().map(String::from))
|
|
.collect()
|
|
})
|
|
.unwrap_or_default();
|
|
let denied: Vec<String> = policy
|
|
.get("denied")
|
|
.and_then(|v| v.as_array())
|
|
.map(|a| {
|
|
a.iter()
|
|
.filter_map(|v| v.as_str().map(String::from))
|
|
.collect()
|
|
})
|
|
.unwrap_or_default();
|
|
|
|
if !allowed.is_empty() || !denied.is_empty() {
|
|
config = config.with_tool_policy(allowed, denied);
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!(
|
|
error = %e,
|
|
"failed to parse tool policy JSON, ignoring"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
if let Some(ref toolset_json) = ctx.toolset_json {
|
|
if let Ok(policy) =
|
|
serde_json::from_str::<serde_json::Value>(toolset_json)
|
|
{
|
|
let enabled: Vec<String> = policy
|
|
.get("enabled")
|
|
.and_then(|v| v.as_array())
|
|
.map(|a| {
|
|
a.iter()
|
|
.filter_map(|v| v.as_str().map(String::from))
|
|
.collect()
|
|
})
|
|
.unwrap_or_default();
|
|
let disabled: Vec<String> = policy
|
|
.get("disabled")
|
|
.and_then(|v| v.as_array())
|
|
.map(|a| {
|
|
a.iter()
|
|
.filter_map(|v| v.as_str().map(String::from))
|
|
.collect()
|
|
})
|
|
.unwrap_or_default();
|
|
|
|
if !enabled.is_empty() || !disabled.is_empty() {
|
|
config = config.with_toolset_policy(enabled, disabled);
|
|
}
|
|
}
|
|
}
|
|
if ctx.memory_provider.as_deref() == Some("none") {
|
|
config.skip_memory = true;
|
|
}
|
|
|
|
config
|
|
}
|
|
async fn resolve_model_version(
|
|
&self,
|
|
id: Uuid,
|
|
) -> Result<AiModelVersionModel, AppError> {
|
|
if let Some(version) = sqlx::query_as::<_, AiModelVersionModel>(
|
|
"SELECT id, model, version, provider_model_name, \
|
|
input_price_per_million, output_price_per_million, cached_input_price_per_million, \
|
|
training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \
|
|
FROM ai_model_version WHERE id = $1 AND enabled = true",
|
|
)
|
|
.bind(id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
{
|
|
return Ok(version);
|
|
}
|
|
|
|
let version = sqlx::query_as::<_, AiModelVersionModel>(
|
|
"SELECT v.id, v.model, v.version, v.provider_model_name, \
|
|
v.input_price_per_million, v.output_price_per_million, v.cached_input_price_per_million, \
|
|
v.training_cutoff, v.released_at, v.deprecated_at, v.enabled, v.created_at, v.updated_at \
|
|
FROM ai_model_version v \
|
|
INNER JOIN ai_model m ON m.id = v.model \
|
|
WHERE m.id = $1 AND v.enabled = true AND m.enabled = true \
|
|
ORDER BY v.created_at DESC LIMIT 1",
|
|
)
|
|
.bind(id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.ok_or_else(|| AppError::NotFound("model version not found".to_string()))?;
|
|
|
|
Ok(version)
|
|
}
|
|
pub async fn agent_resolve_pricing(
|
|
&self,
|
|
model_version_id: Uuid,
|
|
) -> Result<(Option<Decimal>, Option<Decimal>), AppError> {
|
|
let version = sqlx::query_as::<_, AiModelVersionModel>(
|
|
"SELECT id, model, version, provider_model_name, \
|
|
input_price_per_million, output_price_per_million, cached_input_price_per_million, \
|
|
training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \
|
|
FROM ai_model_version WHERE id = $1",
|
|
)
|
|
.bind(model_version_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.ok_or_else(|| AppError::NotFound("model version not found".to_string()))?;
|
|
|
|
Ok((
|
|
version.input_price_per_million,
|
|
version.output_price_per_million,
|
|
))
|
|
}
|
|
|
|
/// Build the system prompt enriched with workspace and user context.
|
|
async fn build_system_prompt_with_context(
|
|
&self,
|
|
session: &model::agent::AgentSessionModel,
|
|
user_id: Uuid,
|
|
) -> String {
|
|
let base = session.system_prompt.clone().unwrap_or_else(|| {
|
|
ai::agent::config::default_system_prompt().to_string()
|
|
});
|
|
|
|
let mut context_section = String::new();
|
|
|
|
// Workspace context
|
|
if let Some(wk_id) = session.wk {
|
|
let wk: Option<(String,)> =
|
|
sqlx::query_as("SELECT name FROM workspace WHERE id = $1")
|
|
.bind(wk_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.unwrap_or(None);
|
|
if let Some((wk_name,)) = wk {
|
|
context_section.push_str(&format!(
|
|
"- You are operating in workspace \"{wk_name}\" (id: {wk_id}).\n"
|
|
));
|
|
context_section.push_str(
|
|
" All file operations, repo access, and code changes are scoped to this workspace.\n"
|
|
);
|
|
}
|
|
}
|
|
|
|
// User context
|
|
if let Some(session_user_id) = session.user {
|
|
let u: Option<(String, String)> = sqlx::query_as(
|
|
"SELECT display_name, username FROM \"user\" WHERE id = $1",
|
|
)
|
|
.bind(session_user_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.unwrap_or(None);
|
|
if let Some((display_name, username)) = u {
|
|
let name = if display_name.is_empty() {
|
|
&username
|
|
} else {
|
|
&display_name
|
|
};
|
|
context_section.push_str(&format!(
|
|
"- The current user is {name} (username: {username}, id: {session_user_id}).\n"
|
|
));
|
|
}
|
|
} else {
|
|
let u: Option<(String, String)> = sqlx::query_as(
|
|
"SELECT display_name, username FROM \"user\" WHERE id = $1",
|
|
)
|
|
.bind(user_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.unwrap_or(None);
|
|
if let Some((display_name, username)) = u {
|
|
let name = if display_name.is_empty() {
|
|
&username
|
|
} else {
|
|
&display_name
|
|
};
|
|
context_section.push_str(&format!(
|
|
"- The current user is {name} (username: {username}, id: {user_id}).\n"
|
|
));
|
|
}
|
|
}
|
|
|
|
if context_section.is_empty() {
|
|
return base;
|
|
}
|
|
|
|
format!(
|
|
"{base}\n\n<environment>\nThe following is provided for context. Always use this information\nwhen tailoring your responses, resolving references, and scoping operations.\n\n{context_section}</environment>"
|
|
)
|
|
}
|
|
}
|