// (file-viewer header residue: 501 lines, 19 KiB, Rust)
use chrono::Utc;
|
|
use db::sqlx;
|
|
use model::agent::AgentSessionModel;
|
|
use serde::{Deserialize, Serialize};
|
|
use utoipa::ToSchema;
|
|
use uuid::Uuid;
|
|
|
|
use crate::AppService;
|
|
use crate::error::AppError;
|
|
|
|
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
|
pub struct CreateAgentSession {
|
|
pub name: String,
|
|
pub agent_kind: String,
|
|
pub model_version: Uuid,
|
|
#[serde(default)]
|
|
pub description: Option<String>,
|
|
#[serde(default)]
|
|
pub system_prompt: Option<String>,
|
|
#[serde(default)]
|
|
pub temperature: Option<f32>,
|
|
pub max_output_tokens: Option<i32>,
|
|
pub tool_policy: Option<String>,
|
|
pub toolset_json: Option<String>,
|
|
pub memory_provider: Option<String>,
|
|
pub memory_provider_config: Option<String>,
|
|
pub iteration_budget: Option<i32>,
|
|
pub source: Option<String>,
|
|
pub visibility: Option<String>,
|
|
pub wk: Option<String>,
|
|
pub knowledge_base_ids: Option<Vec<Uuid>>,
|
|
pub variables: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, ToSchema)]
|
|
pub struct UpdateAgentSession {
|
|
pub name: Option<String>,
|
|
pub description: Option<String>,
|
|
pub system_prompt: Option<String>,
|
|
pub temperature: Option<f32>,
|
|
pub max_output_tokens: Option<i32>,
|
|
pub model_version: Option<Uuid>,
|
|
pub tool_policy: Option<String>,
|
|
pub toolset_json: Option<String>,
|
|
pub memory_provider: Option<String>,
|
|
pub memory_provider_config: Option<String>,
|
|
pub iteration_budget: Option<i32>,
|
|
pub visibility: Option<String>,
|
|
pub enabled: Option<bool>,
|
|
pub knowledge_base_ids: Option<Vec<Uuid>>,
|
|
pub variables: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, ToSchema)]
|
|
pub struct AgentSessionResponse {
|
|
pub id: Uuid,
|
|
pub name: String,
|
|
pub description: Option<String>,
|
|
pub agent_kind: String,
|
|
pub model_version: Option<Uuid>,
|
|
pub system_prompt: Option<String>,
|
|
pub temperature: Option<f32>,
|
|
pub max_output_tokens: Option<i32>,
|
|
pub tool_policy: Option<String>,
|
|
pub toolset_json: Option<String>,
|
|
pub memory_provider: Option<String>,
|
|
pub iteration_budget: Option<i32>,
|
|
pub source: Option<String>,
|
|
pub parent_session_id: Option<Uuid>,
|
|
pub visibility: String,
|
|
pub version: i32,
|
|
pub enabled: bool,
|
|
pub user: Option<Uuid>,
|
|
pub wk: Option<Uuid>,
|
|
pub variables: Option<String>,
|
|
pub published_at: Option<chrono::DateTime<Utc>>,
|
|
pub created_at: chrono::DateTime<Utc>,
|
|
pub updated_at: chrono::DateTime<Utc>,
|
|
}
|
|
|
|
impl AppService {
|
|
pub async fn agent_session_create(
|
|
&self,
|
|
user_id: Uuid,
|
|
params: CreateAgentSession,
|
|
) -> Result<AgentSessionResponse, AppError> {
|
|
let wk_uuid: Option<Uuid> = if let Some(ref wk_name) = params.wk {
|
|
let wk =
|
|
crate::AppService::workspace_resolve(&*self, wk_name).await?;
|
|
let _ = crate::AppService::workspace_require_member(
|
|
&*self, wk.id, user_id,
|
|
)
|
|
.await?;
|
|
Some(wk.id)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let id = Uuid::now_v7();
|
|
let now = Utc::now();
|
|
let visibility =
|
|
params.visibility.unwrap_or_else(|| "private".to_string());
|
|
let kb_ids = params.knowledge_base_ids.map(|ids| {
|
|
ids.iter()
|
|
.map(|id| id.to_string())
|
|
.collect::<Vec<_>>()
|
|
.join(",")
|
|
});
|
|
|
|
let row = sqlx::query_as::<_, AgentSessionModel>(
|
|
"INSERT INTO agent_session \
|
|
(id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, tool_policy, \
|
|
knowledge_base_ids, variables, visibility, version, enabled, \
|
|
source, toolset_json, memory_provider, memory_provider_config, iteration_budget, \
|
|
created_by, created_at, updated_at) \
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, 1, true, \
|
|
'api', '{}', 'simple', '{}', 90, \
|
|
$15, $16, $16) \
|
|
RETURNING id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, tool_policy, \
|
|
knowledge_base_ids, variables, visibility, version, \
|
|
published_at, rollback_from_version, enabled, \
|
|
source, parent_session_id, toolset_json, \
|
|
memory_provider, memory_provider_config, iteration_budget, \
|
|
created_by, created_at, updated_at, deleted_at",
|
|
)
|
|
.bind(id)
|
|
.bind(user_id)
|
|
.bind(wk_uuid)
|
|
.bind(¶ms.name)
|
|
.bind(¶ms.description)
|
|
.bind(¶ms.agent_kind)
|
|
.bind(params.model_version)
|
|
.bind(¶ms.system_prompt)
|
|
.bind(params.temperature)
|
|
.bind(params.max_output_tokens)
|
|
.bind(¶ms.tool_policy)
|
|
.bind(&kb_ids)
|
|
.bind(¶ms.variables)
|
|
.bind(&visibility)
|
|
.bind(user_id)
|
|
.bind(now)
|
|
.fetch_one(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(row.into())
|
|
}
|
|
|
|
pub async fn agent_session_list(
|
|
&self,
|
|
user_id: Uuid,
|
|
) -> Result<Vec<AgentSessionResponse>, AppError> {
|
|
let rows = sqlx::query_as::<_, AgentSessionModel>(
|
|
"SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, tool_policy, \
|
|
knowledge_base_ids, variables, visibility, version, \
|
|
published_at, rollback_from_version, enabled, \
|
|
source, parent_session_id, toolset_json, \
|
|
memory_provider, memory_provider_config, iteration_budget, \
|
|
created_by, created_at, updated_at, deleted_at \
|
|
FROM agent_session \
|
|
WHERE (\"user\" = $1 OR wk IN (SELECT wk FROM wk_member WHERE \"user\" = $1)) \
|
|
AND deleted_at IS NULL \
|
|
ORDER BY updated_at DESC",
|
|
)
|
|
.bind(user_id)
|
|
.fetch_all(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(rows.into_iter().map(Into::into).collect())
|
|
}
|
|
|
|
pub async fn agent_session_get(
|
|
&self,
|
|
user_id: Uuid,
|
|
session_id: Uuid,
|
|
) -> Result<AgentSessionResponse, AppError> {
|
|
let row = sqlx::query_as::<_, AgentSessionModel>(
|
|
"SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, tool_policy, \
|
|
knowledge_base_ids, variables, visibility, version, \
|
|
published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \
|
|
created_at, updated_at, deleted_at \
|
|
FROM agent_session \
|
|
WHERE id = $1 AND deleted_at IS NULL",
|
|
)
|
|
.bind(session_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?;
|
|
|
|
if row.user != Some(user_id) {
|
|
if let Some(wk) = row.wk {
|
|
let _ = crate::AppService::workspace_require_member(
|
|
&*self, wk, user_id,
|
|
)
|
|
.await?;
|
|
} else {
|
|
return Err(AppError::PermissionDenied);
|
|
}
|
|
}
|
|
|
|
Ok(row.into())
|
|
}
|
|
|
|
pub async fn agent_session_update(
|
|
&self,
|
|
user_id: Uuid,
|
|
session_id: Uuid,
|
|
params: UpdateAgentSession,
|
|
) -> Result<AgentSessionResponse, AppError> {
|
|
let existing = sqlx::query_as::<_, AgentSessionModel>(
|
|
"SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, tool_policy, \
|
|
knowledge_base_ids, variables, visibility, version, \
|
|
published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \
|
|
created_at, updated_at, deleted_at \
|
|
FROM agent_session \
|
|
WHERE id = $1 AND deleted_at IS NULL",
|
|
)
|
|
.bind(session_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?;
|
|
|
|
if existing.user != Some(user_id) {
|
|
if let Some(wk) = existing.wk {
|
|
let _ = crate::AppService::workspace_require_admin(
|
|
&*self, wk, user_id,
|
|
)
|
|
.await?;
|
|
} else {
|
|
return Err(AppError::PermissionDenied);
|
|
}
|
|
}
|
|
|
|
let now = Utc::now();
|
|
let name = params.name.unwrap_or(existing.name);
|
|
let description = params.description.or(existing.description);
|
|
let system_prompt = params.system_prompt.or(existing.system_prompt);
|
|
let temperature = params.temperature.or(existing.temperature);
|
|
let max_output_tokens =
|
|
params.max_output_tokens.or(existing.max_output_tokens);
|
|
let model_version = params.model_version.or(existing.model_version);
|
|
let tool_policy = params.tool_policy.or(existing.tool_policy);
|
|
let toolset_json = params.toolset_json.or(existing.toolset_json);
|
|
let memory_provider =
|
|
params.memory_provider.or(existing.memory_provider);
|
|
let memory_provider_config = params
|
|
.memory_provider_config
|
|
.or(existing.memory_provider_config);
|
|
let iteration_budget =
|
|
params.iteration_budget.or(existing.iteration_budget);
|
|
let visibility = params.visibility.unwrap_or(existing.visibility);
|
|
let enabled = params.enabled.unwrap_or(existing.enabled);
|
|
let kb_ids = params
|
|
.knowledge_base_ids
|
|
.map(|ids| {
|
|
ids.iter()
|
|
.map(|id| id.to_string())
|
|
.collect::<Vec<_>>()
|
|
.join(",")
|
|
})
|
|
.or(existing.knowledge_base_ids);
|
|
let variables = params.variables.or(existing.variables);
|
|
|
|
let row = sqlx::query_as::<_, AgentSessionModel>(
|
|
"UPDATE agent_session SET \
|
|
name = $1, description = $2, system_prompt = $3, temperature = $4, \
|
|
max_output_tokens = $5, model_version = $6, tool_policy = $7, \
|
|
toolset_json = $8, memory_provider = $9, \
|
|
memory_provider_config = $10, iteration_budget = $11, \
|
|
visibility = $12, enabled = $13, knowledge_base_ids = $14, \
|
|
variables = $15, updated_at = $16 \
|
|
WHERE id = $17 AND deleted_at IS NULL \
|
|
RETURNING id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, tool_policy, \
|
|
knowledge_base_ids, variables, visibility, version, \
|
|
published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \
|
|
created_at, updated_at, deleted_at",
|
|
)
|
|
.bind(&name)
|
|
.bind(&description)
|
|
.bind(&system_prompt)
|
|
.bind(temperature)
|
|
.bind(max_output_tokens)
|
|
.bind(model_version)
|
|
.bind(&tool_policy)
|
|
.bind(&toolset_json)
|
|
.bind(&memory_provider)
|
|
.bind(&memory_provider_config)
|
|
.bind(iteration_budget)
|
|
.bind(&visibility)
|
|
.bind(enabled)
|
|
.bind(&kb_ids)
|
|
.bind(&variables)
|
|
.bind(now)
|
|
.bind(session_id)
|
|
.fetch_one(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(row.into())
|
|
}
|
|
|
|
pub async fn agent_session_delete(
|
|
&self,
|
|
user_id: Uuid,
|
|
session_id: Uuid,
|
|
) -> Result<(), AppError> {
|
|
let existing = sqlx::query_as::<_, AgentSessionModel>(
|
|
"SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, tool_policy, \
|
|
knowledge_base_ids, variables, visibility, version, \
|
|
published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \
|
|
created_at, updated_at, deleted_at \
|
|
FROM agent_session \
|
|
WHERE id = $1 AND deleted_at IS NULL",
|
|
)
|
|
.bind(session_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?;
|
|
|
|
if existing.user != Some(user_id) {
|
|
if let Some(wk) = existing.wk {
|
|
let _ = crate::AppService::workspace_require_admin(
|
|
&*self, wk, user_id,
|
|
)
|
|
.await?;
|
|
} else {
|
|
return Err(AppError::PermissionDenied);
|
|
}
|
|
}
|
|
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"UPDATE agent_session SET deleted_at = $1, updated_at = $1 WHERE id = $2",
|
|
)
|
|
.bind(now)
|
|
.bind(session_id)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
pub async fn agent_session_search(
|
|
&self,
|
|
user_id: Uuid,
|
|
query: &str,
|
|
limit: u32,
|
|
) -> Result<Vec<AgentSessionResponse>, AppError> {
|
|
let rows = sqlx::query_as::<_, AgentSessionModel>(
|
|
"SELECT DISTINCT s.id, s.\"user\", s.wk, s.name, s.description, \
|
|
s.agent_kind, s.model_version, \
|
|
s.system_prompt, s.temperature, s.max_output_tokens, s.tool_policy, \
|
|
s.knowledge_base_ids, s.variables, s.visibility, s.version, \
|
|
s.published_at, s.rollback_from_version, s.enabled, \
|
|
s.source, s.parent_session_id, s.toolset_json, \
|
|
s.memory_provider, s.memory_provider_config, s.iteration_budget, \
|
|
s.created_by, s.created_at, s.updated_at, s.deleted_at \
|
|
FROM agent_session s \
|
|
INNER JOIN agent_message m ON m.conversation IN ( \
|
|
SELECT id FROM agent_conversation WHERE session = s.id \
|
|
) \
|
|
WHERE (s.\"user\" = $1 OR s.wk IN (SELECT wk FROM wk_member WHERE \"user\" = $1)) \
|
|
AND s.deleted_at IS NULL \
|
|
AND m.deleted_at IS NULL \
|
|
AND m.search_vector @@ plainto_tsquery('english', $2) \
|
|
ORDER BY s.updated_at DESC \
|
|
LIMIT $3",
|
|
)
|
|
.bind(user_id)
|
|
.bind(query)
|
|
.bind(limit as i64)
|
|
.fetch_all(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(rows.into_iter().map(Into::into).collect())
|
|
}
|
|
pub async fn agent_session_update_toolsets(
|
|
&self,
|
|
user_id: Uuid,
|
|
session_id: Uuid,
|
|
enabled: Option<Vec<String>>,
|
|
disabled: Option<Vec<String>>,
|
|
) -> Result<AgentSessionResponse, AppError> {
|
|
let existing = sqlx::query_as::<_, AgentSessionModel>(
|
|
"SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, tool_policy, \
|
|
knowledge_base_ids, variables, visibility, version, \
|
|
published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \
|
|
created_at, updated_at, deleted_at \
|
|
FROM agent_session \
|
|
WHERE id = $1 AND deleted_at IS NULL",
|
|
)
|
|
.bind(session_id)
|
|
.fetch_optional(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?
|
|
.ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?;
|
|
|
|
if existing.user != Some(user_id) {
|
|
if let Some(wk) = existing.wk {
|
|
let _ = crate::AppService::workspace_require_admin(
|
|
&*self, wk, user_id,
|
|
)
|
|
.await?;
|
|
} else {
|
|
return Err(AppError::PermissionDenied);
|
|
}
|
|
}
|
|
|
|
let toolset_json = {
|
|
let mut current: serde_json::Map<String, serde_json::Value> =
|
|
existing
|
|
.toolset_json
|
|
.as_deref()
|
|
.and_then(|s| serde_json::from_str(s).ok())
|
|
.unwrap_or_default();
|
|
|
|
if let Some(en) = enabled {
|
|
current.insert(
|
|
"enabled".to_string(),
|
|
serde_json::Value::Array(
|
|
en.into_iter().map(serde_json::Value::String).collect(),
|
|
),
|
|
);
|
|
}
|
|
if let Some(dis) = disabled {
|
|
current.insert(
|
|
"disabled".to_string(),
|
|
serde_json::Value::Array(
|
|
dis.into_iter()
|
|
.map(serde_json::Value::String)
|
|
.collect(),
|
|
),
|
|
);
|
|
}
|
|
Some(serde_json::to_string(¤t).unwrap_or_default())
|
|
};
|
|
|
|
let now = Utc::now();
|
|
let row = sqlx::query_as::<_, AgentSessionModel>(
|
|
"UPDATE agent_session SET toolset_json = $1, updated_at = $2 \
|
|
WHERE id = $3 AND deleted_at IS NULL \
|
|
RETURNING id, \"user\", wk, name, description, agent_kind, model_version, \
|
|
system_prompt, temperature, max_output_tokens, tool_policy, \
|
|
knowledge_base_ids, variables, visibility, version, \
|
|
published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \
|
|
created_at, updated_at, deleted_at",
|
|
)
|
|
.bind(&toolset_json)
|
|
.bind(now)
|
|
.bind(session_id)
|
|
.fetch_one(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(row.into())
|
|
}
|
|
}
|
|
|
|
impl From<AgentSessionModel> for AgentSessionResponse {
|
|
fn from(m: AgentSessionModel) -> Self {
|
|
Self {
|
|
id: m.id,
|
|
name: m.name,
|
|
description: m.description,
|
|
agent_kind: m.agent_kind,
|
|
model_version: m.model_version,
|
|
system_prompt: m.system_prompt,
|
|
temperature: m.temperature,
|
|
max_output_tokens: m.max_output_tokens,
|
|
tool_policy: m.tool_policy,
|
|
toolset_json: m.toolset_json,
|
|
memory_provider: m.memory_provider,
|
|
iteration_budget: m.iteration_budget,
|
|
source: m.source,
|
|
parent_session_id: m.parent_session_id,
|
|
visibility: m.visibility,
|
|
version: m.version,
|
|
enabled: m.enabled,
|
|
user: m.user,
|
|
wk: m.wk,
|
|
variables: m.variables,
|
|
published_at: m.published_at,
|
|
created_at: m.created_at,
|
|
updated_at: m.updated_at,
|
|
}
|
|
}
|
|
}
|