gitdataai/lib/service/agent/config.rs

407 lines
16 KiB
Rust

use ai::{
agent::AgentConfig,
client::{AiClient, AiClientConfig, EmbedConfig, EndpointConfig},
};
use db::sqlx::{self, types::Decimal};
use model::{
agent::AgentSessionModel, ai::AiModelVersionModel, ai::AiProviderModel,
};
use uuid::Uuid;
use super::types::SessionContext;
use crate::AppService;
use crate::error::AppError;
impl AppService {
pub async fn agent_session_context(
&self,
session_id: Uuid,
user_id: Uuid,
) -> Result<SessionContext, AppError> {
let session = sqlx::query_as::<_, AgentSessionModel>(
"SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
system_prompt, temperature, max_output_tokens, tool_policy, \
knowledge_base_ids, variables, visibility, version, \
published_at, rollback_from_version, enabled, \
source, parent_session_id, toolset_json, \
memory_provider, memory_provider_config, iteration_budget, \
created_by, created_at, updated_at, deleted_at \
FROM agent_session WHERE id = $1 AND deleted_at IS NULL AND enabled = true",
)
.bind(session_id)
.fetch_optional(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?
.ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?;
if let Some(wk_id) = session.wk {
let _ = self.workspace_require_member(wk_id, user_id).await?;
} else if Some(user_id) != session.user {
return Err(AppError::PermissionDenied);
}
let model_version_id = session.model_version.ok_or_else(|| {
AppError::BadRequest(
"agent session has no model_version".to_string(),
)
})?;
let version = self.resolve_model_version(model_version_id).await?;
let billing_target = if session.wk.is_some() {
super::types::BillingTarget::Workspace
} else {
super::types::BillingTarget::User
};
Ok(SessionContext {
session_id,
user_id: session.user,
workspace_id: session.wk,
system_prompt: self
.build_system_prompt_with_context(&session, user_id)
.await,
model_version_id: version.id,
provider_model_name: version.provider_model_name,
temperature: session.temperature,
max_output_tokens: session.max_output_tokens,
tool_policy_json: session.tool_policy,
toolset_json: session.toolset_json,
variables_json: session.variables,
iteration_budget: session.iteration_budget,
memory_provider: session.memory_provider,
source: session.source,
parent_session_id: session.parent_session_id,
billing_target,
})
}
pub async fn agent_build_ai_client(
&self,
model_version_id: Uuid,
) -> Result<AiClient, AppError> {
let version = self.resolve_model_version(model_version_id).await?;
let model_record = sqlx::query_as::<_, model::ai::AiModelModel>(
"SELECT id, provider, name, display_name, description, modality, \
context_window, input_token_limit, output_token_limit, \
enabled, public, created_at, updated_at, deleted_at \
FROM ai_model WHERE id = $1",
)
.bind(version.model)
.fetch_optional(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?
.ok_or_else(|| AppError::NotFound("ai model not found".to_string()))?;
let provider = sqlx::query_as::<_, AiProviderModel>(
"SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \
FROM ai_provider WHERE id = $1 AND enabled = true",
)
.bind(model_record.provider)
.fetch_optional(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?
.ok_or_else(|| AppError::NotFound("ai provider not found".to_string()))?;
let base_url = provider
.base_url
.unwrap_or_else(|| self.config.ai_basic_url().unwrap_or_default());
let api_key = self.config.ai_api_key().map_err(|e| {
AppError::InternalServerError(format!("AI API key: {e}"))
})?;
let embed_base_url =
self.config.get_embed_model_base_url().map_err(|e| {
AppError::InternalServerError(format!("embed base url: {e}"))
})?;
let embed_api_key =
self.config.get_embed_model_api_key().map_err(|e| {
AppError::InternalServerError(format!("embed api key: {e}"))
})?;
let llm_config = EndpointConfig::new(&base_url, &api_key)
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
let embed_endpoint =
EndpointConfig::new(&embed_base_url, &embed_api_key)
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
let embed_config = EmbedConfig::new(
embed_endpoint,
self.config
.get_embed_model_name()
.map_err(|e| AppError::InternalServerError(e.to_string()))?,
self.config
.get_embed_model_dimensions()
.map_err(|e| AppError::InternalServerError(e.to_string()))?,
)
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
let client_config = AiClientConfig::new(llm_config, embed_config)
.map_err(|e| AppError::InternalServerError(e.to_string()))?;
AiClient::new(client_config)
.map_err(|e| AppError::InternalServerError(e.to_string()))
}
pub fn agent_build_config(
&self,
ctx: &SessionContext,
max_steps_override: Option<usize>,
) -> AgentConfig {
let mut config = AgentConfig::new(&ctx.provider_model_name)
.unwrap_or_else(|_| {
AgentConfig::new("gpt-4o").expect("default agent config")
});
config.model = ctx.provider_model_name.clone();
config.system_prompt = ctx.system_prompt.clone();
if let Some(ref vars_json) = ctx.variables_json {
if let Ok(vars) = serde_json::from_str::<
serde_json::Map<String, serde_json::Value>,
>(vars_json)
{
if !vars.is_empty() {
let mut prompt = config.system_prompt.clone();
let mut any_replaced = false;
for (key, val) in &vars {
let placeholder = format!("{{{{{}}}}}", key);
if prompt.contains(&placeholder) {
let replacement = match val {
serde_json::Value::String(s) => s.clone(),
other => other.to_string(),
};
prompt = prompt.replace(&placeholder, &replacement);
any_replaced = true;
}
}
if !any_replaced {
prompt.push_str("\n\n<variables>\n");
for (key, val) in &vars {
let val_str = match val {
serde_json::Value::String(s) => s.clone(),
other => other.to_string(),
};
prompt
.push_str(&format!("- {}: {}\n", key, val_str));
}
prompt.push_str("</variables>");
}
config.system_prompt = prompt;
}
}
}
if let Some(temp) = ctx.temperature {
config.temperature = Some(temp as f64);
}
if let Some(max_tok) = ctx.max_output_tokens {
config.max_completion_tokens = Some(max_tok as u64);
}
if let Some(max_steps) = max_steps_override {
config.max_iterations = max_steps;
}
if let Some(budget) = ctx.iteration_budget {
config.iteration_budget = budget as usize;
}
if let Some(ref policy_json) = ctx.tool_policy_json {
match serde_json::from_str::<serde_json::Value>(policy_json) {
Ok(policy) => {
let allowed: Vec<String> = policy
.get("allowed")
.and_then(|v| v.as_array())
.map(|a| {
a.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let denied: Vec<String> = policy
.get("denied")
.and_then(|v| v.as_array())
.map(|a| {
a.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
if !allowed.is_empty() || !denied.is_empty() {
config = config.with_tool_policy(allowed, denied);
}
}
Err(e) => {
tracing::warn!(
error = %e,
"failed to parse tool policy JSON, ignoring"
);
}
}
}
if let Some(ref toolset_json) = ctx.toolset_json {
if let Ok(policy) =
serde_json::from_str::<serde_json::Value>(toolset_json)
{
let enabled: Vec<String> = policy
.get("enabled")
.and_then(|v| v.as_array())
.map(|a| {
a.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let disabled: Vec<String> = policy
.get("disabled")
.and_then(|v| v.as_array())
.map(|a| {
a.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
if !enabled.is_empty() || !disabled.is_empty() {
config = config.with_toolset_policy(enabled, disabled);
}
}
}
if ctx.memory_provider.as_deref() == Some("none") {
config.skip_memory = true;
}
config
}
async fn resolve_model_version(
&self,
id: Uuid,
) -> Result<AiModelVersionModel, AppError> {
if let Some(version) = sqlx::query_as::<_, AiModelVersionModel>(
"SELECT id, model, version, provider_model_name, \
input_price_per_million, output_price_per_million, cached_input_price_per_million, \
training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \
FROM ai_model_version WHERE id = $1 AND enabled = true",
)
.bind(id)
.fetch_optional(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?
{
return Ok(version);
}
let version = sqlx::query_as::<_, AiModelVersionModel>(
"SELECT v.id, v.model, v.version, v.provider_model_name, \
v.input_price_per_million, v.output_price_per_million, v.cached_input_price_per_million, \
v.training_cutoff, v.released_at, v.deprecated_at, v.enabled, v.created_at, v.updated_at \
FROM ai_model_version v \
INNER JOIN ai_model m ON m.id = v.model \
WHERE m.id = $1 AND v.enabled = true AND m.enabled = true \
ORDER BY v.created_at DESC LIMIT 1",
)
.bind(id)
.fetch_optional(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?
.ok_or_else(|| AppError::NotFound("model version not found".to_string()))?;
Ok(version)
}
pub async fn agent_resolve_pricing(
&self,
model_version_id: Uuid,
) -> Result<(Option<Decimal>, Option<Decimal>), AppError> {
let version = sqlx::query_as::<_, AiModelVersionModel>(
"SELECT id, model, version, provider_model_name, \
input_price_per_million, output_price_per_million, cached_input_price_per_million, \
training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \
FROM ai_model_version WHERE id = $1",
)
.bind(model_version_id)
.fetch_optional(self.db.reader())
.await
.map_err(|e| AppError::DatabaseError(e.to_string()))?
.ok_or_else(|| AppError::NotFound("model version not found".to_string()))?;
Ok((
version.input_price_per_million,
version.output_price_per_million,
))
}
/// Build the system prompt enriched with workspace and user context.
async fn build_system_prompt_with_context(
&self,
session: &model::agent::AgentSessionModel,
user_id: Uuid,
) -> String {
let base = session.system_prompt.clone().unwrap_or_else(|| {
ai::agent::config::default_system_prompt().to_string()
});
let mut context_section = String::new();
// Workspace context
if let Some(wk_id) = session.wk {
let wk: Option<(String,)> =
sqlx::query_as("SELECT name FROM workspace WHERE id = $1")
.bind(wk_id)
.fetch_optional(self.db.reader())
.await
.unwrap_or(None);
if let Some((wk_name,)) = wk {
context_section.push_str(&format!(
"- You are operating in workspace \"{wk_name}\" (id: {wk_id}).\n"
));
context_section.push_str(
" All file operations, repo access, and code changes are scoped to this workspace.\n"
);
}
}
// User context
if let Some(session_user_id) = session.user {
let u: Option<(String, String)> = sqlx::query_as(
"SELECT display_name, username FROM \"user\" WHERE id = $1",
)
.bind(session_user_id)
.fetch_optional(self.db.reader())
.await
.unwrap_or(None);
if let Some((display_name, username)) = u {
let name = if display_name.is_empty() {
&username
} else {
&display_name
};
context_section.push_str(&format!(
"- The current user is {name} (username: {username}, id: {session_user_id}).\n"
));
}
} else {
let u: Option<(String, String)> = sqlx::query_as(
"SELECT display_name, username FROM \"user\" WHERE id = $1",
)
.bind(user_id)
.fetch_optional(self.db.reader())
.await
.unwrap_or(None);
if let Some((display_name, username)) = u {
let name = if display_name.is_empty() {
&username
} else {
&display_name
};
context_section.push_str(&format!(
"- The current user is {name} (username: {username}, id: {user_id}).\n"
));
}
}
if context_section.is_empty() {
return base;
}
format!(
"{base}\n\n<environment>\nThe following is provided for context. Always use this information\nwhen tailoring your responses, resolving references, and scoping operations.\n\n{context_section}</environment>"
)
}
}