use ai::{ agent::AgentConfig, client::{AiClient, AiClientConfig, EmbedConfig, EndpointConfig}, }; use db::sqlx::{self, types::Decimal}; use model::{ agent::AgentSessionModel, ai::AiModelVersionModel, ai::AiProviderModel, }; use uuid::Uuid; use super::types::SessionContext; use crate::error::AppError; use crate::AppService; impl AppService { pub(crate) async fn agent_session_context( &self, session_id: Uuid, user_id: Uuid, ) -> Result { let session = sqlx::query_as::<_, AgentSessionModel>( "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \ system_prompt, temperature, max_output_tokens, tool_policy, \ knowledge_base_ids, variables, visibility, version, \ published_at, rollback_from_version, enabled, \ source, parent_session_id, toolset_json, \ memory_provider, memory_provider_config, iteration_budget, \ created_by, created_at, updated_at, deleted_at \ FROM agent_session WHERE id = $1 AND deleted_at IS NULL AND enabled = true", ) .bind(session_id) .fetch_optional(self.db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))? .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?; if let Some(wk_id) = session.wk { let _ = self .workspace_require_member(wk_id, user_id) .await?; } else if Some(user_id) != session.user { return Err(AppError::PermissionDenied); } let model_version_id = session .model_version .ok_or_else(|| AppError::BadRequest("agent session has no model_version".to_string()))?; let version = self.resolve_model_version(model_version_id).await?; let billing_target = if session.wk.is_some() { super::types::BillingTarget::Workspace } else { super::types::BillingTarget::User }; Ok(SessionContext { session_id, user_id: session.user, workspace_id: session.wk, system_prompt: self.build_system_prompt_with_context( &session, user_id, ).await, model_version_id: version.id, provider_model_name: version.provider_model_name, temperature: session.temperature, max_output_tokens: session.max_output_tokens, tool_policy_json: session.tool_policy, toolset_json: session.toolset_json, variables_json: session.variables, iteration_budget: session.iteration_budget, memory_provider: session.memory_provider, source: session.source, parent_session_id: session.parent_session_id, billing_target, }) } pub(crate) async fn agent_build_ai_client( &self, model_version_id: Uuid, ) -> Result { let version = self.resolve_model_version(model_version_id).await?; let model_record = sqlx::query_as::<_, model::ai::AiModelModel>( "SELECT id, provider, name, display_name, description, modality, \ context_window, input_token_limit, output_token_limit, \ enabled, public, created_at, updated_at, deleted_at \ FROM ai_model WHERE id = $1", ) .bind(version.model) .fetch_optional(self.db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))? .ok_or_else(|| AppError::NotFound("ai model not found".to_string()))?; let provider = sqlx::query_as::<_, AiProviderModel>( "SELECT id, name, base_url, website_url, logo_url, enabled, created_at, updated_at \ FROM ai_provider WHERE id = $1 AND enabled = true", ) .bind(model_record.provider) .fetch_optional(self.db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))? .ok_or_else(|| AppError::NotFound("ai provider not found".to_string()))?; let base_url = provider .base_url .unwrap_or_else(|| self.config.ai_basic_url().unwrap_or_default()); let api_key = self .config .ai_api_key() .map_err(|e| AppError::InternalServerError(format!("AI API key: {e}")))?; let embed_base_url = self .config .get_embed_model_base_url() .map_err(|e| AppError::InternalServerError(format!("embed base url: {e}")))?; let embed_api_key = self .config .get_embed_model_api_key() .map_err(|e| AppError::InternalServerError(format!("embed api key: {e}")))?; let llm_config = EndpointConfig::new(&base_url, &api_key) .map_err(|e| AppError::InternalServerError(e.to_string()))?; let embed_endpoint = EndpointConfig::new(&embed_base_url, &embed_api_key) .map_err(|e| AppError::InternalServerError(e.to_string()))?; let embed_config = EmbedConfig::new( embed_endpoint, self.config .get_embed_model_name() .map_err(|e| AppError::InternalServerError(e.to_string()))?, self.config .get_embed_model_dimensions() .map_err(|e| AppError::InternalServerError(e.to_string()))?, ) .map_err(|e| AppError::InternalServerError(e.to_string()))?; let client_config = AiClientConfig::new(llm_config, embed_config) .map_err(|e| AppError::InternalServerError(e.to_string()))?; AiClient::new(client_config).map_err(|e| AppError::InternalServerError(e.to_string())) } pub(crate) fn agent_build_config( &self, ctx: &SessionContext, max_steps_override: Option, ) -> AgentConfig { let mut config = AgentConfig::new(&ctx.provider_model_name) .unwrap_or_else(|_| { AgentConfig::new("gpt-4o").expect("default agent config") }); config.model = ctx.provider_model_name.clone(); config.system_prompt = ctx.system_prompt.clone(); if let Some(ref vars_json) = ctx.variables_json { if let Ok(vars) = serde_json::from_str::>(vars_json) { if !vars.is_empty() { let mut prompt = config.system_prompt.clone(); let mut any_replaced = false; for (key, val) in &vars { let placeholder = format!("{{{{{}}}}}", key); if prompt.contains(&placeholder) { let replacement = match val { serde_json::Value::String(s) => s.clone(), other => other.to_string(), }; prompt = prompt.replace(&placeholder, &replacement); any_replaced = true; } } if !any_replaced { prompt.push_str("\n\n\n"); for (key, val) in &vars { let val_str = match val { serde_json::Value::String(s) => s.clone(), other => other.to_string(), }; prompt.push_str(&format!("- {}: {}\n", key, val_str)); } prompt.push_str(""); } config.system_prompt = prompt; } } } if let Some(temp) = ctx.temperature { config.temperature = Some(temp as f64); } if let Some(max_tok) = ctx.max_output_tokens { config.max_completion_tokens = Some(max_tok as u64); } if let Some(max_steps) = max_steps_override { config.max_iterations = max_steps; } if let Some(budget) = ctx.iteration_budget { config.iteration_budget = budget as usize; } if let Some(ref policy_json) = ctx.tool_policy_json { match serde_json::from_str::(policy_json) { Ok(policy) => { let allowed: Vec = policy .get("allowed") .and_then(|v| v.as_array()) .map(|a| { a.iter() .filter_map(|v| v.as_str().map(String::from)) .collect() }) .unwrap_or_default(); let denied: Vec = policy .get("denied") .and_then(|v| v.as_array()) .map(|a| { a.iter() .filter_map(|v| v.as_str().map(String::from)) .collect() }) .unwrap_or_default(); if !allowed.is_empty() || !denied.is_empty() { config = config.with_tool_policy(allowed, denied); } } Err(e) => { tracing::warn!( error = %e, "failed to parse tool policy JSON, ignoring" ); } } } if let Some(ref toolset_json) = ctx.toolset_json { if let Ok(policy) = serde_json::from_str::(toolset_json) { let enabled: Vec = policy .get("enabled") .and_then(|v| v.as_array()) .map(|a| { a.iter() .filter_map(|v| v.as_str().map(String::from)) .collect() }) .unwrap_or_default(); let disabled: Vec = policy .get("disabled") .and_then(|v| v.as_array()) .map(|a| { a.iter() .filter_map(|v| v.as_str().map(String::from)) .collect() }) .unwrap_or_default(); if !enabled.is_empty() || !disabled.is_empty() { config = config.with_toolset_policy(enabled, disabled); } } } if ctx.memory_provider.as_deref() == Some("none") { config.skip_memory = true; } config } async fn resolve_model_version( &self, id: Uuid, ) -> Result { if let Some(version) = sqlx::query_as::<_, AiModelVersionModel>( "SELECT id, model, version, provider_model_name, \ input_price_per_million, output_price_per_million, cached_input_price_per_million, \ training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \ FROM ai_model_version WHERE id = $1 AND enabled = true", ) .bind(id) .fetch_optional(self.db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))? { return Ok(version); } let version = sqlx::query_as::<_, AiModelVersionModel>( "SELECT v.id, v.model, v.version, v.provider_model_name, \ v.input_price_per_million, v.output_price_per_million, v.cached_input_price_per_million, \ v.training_cutoff, v.released_at, v.deprecated_at, v.enabled, v.created_at, v.updated_at \ FROM ai_model_version v \ INNER JOIN ai_model m ON m.id = v.model \ WHERE m.id = $1 AND v.enabled = true AND m.enabled = true \ ORDER BY v.created_at DESC LIMIT 1", ) .bind(id) .fetch_optional(self.db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))? .ok_or_else(|| AppError::NotFound("model version not found".to_string()))?; Ok(version) } pub(crate) async fn agent_resolve_pricing( &self, model_version_id: Uuid, ) -> Result<(Option, Option), AppError> { let version = sqlx::query_as::<_, AiModelVersionModel>( "SELECT id, model, version, provider_model_name, \ input_price_per_million, output_price_per_million, cached_input_price_per_million, \ training_cutoff, released_at, deprecated_at, enabled, created_at, updated_at \ FROM ai_model_version WHERE id = $1", ) .bind(model_version_id) .fetch_optional(self.db.reader()) .await .map_err(|e| AppError::DatabaseError(e.to_string()))? .ok_or_else(|| AppError::NotFound("model version not found".to_string()))?; Ok(( version.input_price_per_million, version.output_price_per_million, )) } /// Build the system prompt enriched with workspace and user context. async fn build_system_prompt_with_context( &self, session: &model::agent::AgentSessionModel, user_id: Uuid, ) -> String { let base = session .system_prompt .clone() .unwrap_or_else(|| ai::agent::config::default_system_prompt().to_string()); let mut context_section = String::new(); // Workspace context if let Some(wk_id) = session.wk { let wk: Option<(String,)> = sqlx::query_as( "SELECT name FROM workspace WHERE id = $1") .bind(wk_id) .fetch_optional(self.db.reader()) .await .unwrap_or(None); if let Some((wk_name,)) = wk { context_section.push_str(&format!( "- You are operating in workspace \"{wk_name}\" (id: {wk_id}).\n" )); context_section.push_str( " All file operations, repo access, and code changes are scoped to this workspace.\n" ); } } // User context if let Some(session_user_id) = session.user { let u: Option<(String, String)> = sqlx::query_as( "SELECT display_name, username FROM \"user\" WHERE id = $1") .bind(session_user_id) .fetch_optional(self.db.reader()) .await .unwrap_or(None); if let Some((display_name, username)) = u { let name = if display_name.is_empty() { &username } else { &display_name }; context_section.push_str(&format!( "- The current user is {name} (username: {username}, id: {session_user_id}).\n" )); } } else { let u: Option<(String, String)> = sqlx::query_as( "SELECT display_name, username FROM \"user\" WHERE id = $1") .bind(user_id) .fetch_optional(self.db.reader()) .await .unwrap_or(None); if let Some((display_name, username)) = u { let name = if display_name.is_empty() { &username } else { &display_name }; context_section.push_str(&format!( "- The current user is {name} (username: {username}, id: {user_id}).\n" )); } } if context_section.is_empty() { return base; } format!( "{base}\n\n\nThe following is provided for context. Always use this information\nwhen tailoring your responses, resolving references, and scoping operations.\n\n{context_section}" ) } }