gitdataai/lib/service/agent/session.rs

501 lines
19 KiB
Rust

use chrono::Utc;
use db::sqlx;
use model::agent::AgentSessionModel;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use uuid::Uuid;
use crate::AppService;
use crate::error::AppError;
/// Request body for creating an agent session.
#[derive(Debug, Clone, Deserialize, ToSchema)]
pub struct CreateAgentSession {
    /// Display name of the session.
    pub name: String,
    /// Agent kind identifier, stored verbatim.
    pub agent_kind: String,
    /// Id of the model version the session uses.
    pub model_version: Uuid,
    /// Optional free-form description.
    #[serde(default)]
    pub description: Option<String>,
    /// Optional system prompt.
    #[serde(default)]
    pub system_prompt: Option<String>,
    /// Optional model temperature.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Optional maximum output token count.
    pub max_output_tokens: Option<i32>,
    /// Optional tool policy, stored verbatim.
    pub tool_policy: Option<String>,
    /// Optional toolset configuration as JSON text.
    pub toolset_json: Option<String>,
    /// Optional memory provider name.
    pub memory_provider: Option<String>,
    /// Optional memory provider configuration text.
    pub memory_provider_config: Option<String>,
    /// Optional iteration budget.
    pub iteration_budget: Option<i32>,
    // NOTE(review): the create path pins `source` to 'api'; this field is not
    // read there — confirm whether it should be removed or honoured.
    pub source: Option<String>,
    /// Visibility label; defaults to "private" when omitted.
    pub visibility: Option<String>,
    /// Workspace *name* (not id). It is resolved to a workspace on create, and
    /// the caller must be a member of it.
    pub wk: Option<String>,
    /// Knowledge-base ids; persisted as a comma-separated string.
    pub knowledge_base_ids: Option<Vec<Uuid>>,
    /// Optional variables text, stored verbatim.
    pub variables: Option<String>,
}
/// Request body for a partial update of an agent session.
///
/// Every field is optional. An omitted (or `null`) field leaves the stored
/// value unchanged.
#[derive(Debug, Clone, Deserialize, ToSchema)]
pub struct UpdateAgentSession {
    // NOTE(review): because `None` means "keep the stored value", there is no
    // way to clear a nullable column (e.g. `description`) through this
    // request — confirm that is intended.
    pub name: Option<String>,
    pub description: Option<String>,
    pub system_prompt: Option<String>,
    pub temperature: Option<f32>,
    pub max_output_tokens: Option<i32>,
    pub model_version: Option<Uuid>,
    pub tool_policy: Option<String>,
    pub toolset_json: Option<String>,
    pub memory_provider: Option<String>,
    pub memory_provider_config: Option<String>,
    pub iteration_budget: Option<i32>,
    pub visibility: Option<String>,
    pub enabled: Option<bool>,
    /// Replaces the knowledge-base list; stored as comma-separated ids.
    pub knowledge_base_ids: Option<Vec<Uuid>>,
    pub variables: Option<String>,
}
/// API representation of an agent session, built from an `AgentSessionModel`
/// row through the `From` impl in this file.
#[derive(Debug, Clone, Serialize, ToSchema)]
pub struct AgentSessionResponse {
    // NOTE(review): `knowledge_base_ids` and `memory_provider_config` are part
    // of the create/update request structs but are not exposed here — confirm
    // intended.
    pub id: Uuid,
    pub name: String,
    pub description: Option<String>,
    pub agent_kind: String,
    pub model_version: Option<Uuid>,
    pub system_prompt: Option<String>,
    pub temperature: Option<f32>,
    pub max_output_tokens: Option<i32>,
    pub tool_policy: Option<String>,
    /// Toolset configuration as JSON text.
    pub toolset_json: Option<String>,
    pub memory_provider: Option<String>,
    pub iteration_budget: Option<i32>,
    pub source: Option<String>,
    pub parent_session_id: Option<Uuid>,
    pub visibility: String,
    /// Session version; newly created sessions start at 1.
    pub version: i32,
    /// Newly created sessions start enabled.
    pub enabled: bool,
    /// Owning user id, if any.
    pub user: Option<Uuid>,
    /// Workspace id the session belongs to, if any.
    pub wk: Option<Uuid>,
    pub variables: Option<String>,
    pub published_at: Option<chrono::DateTime<Utc>>,
    pub created_at: chrono::DateTime<Utc>,
    pub updated_at: chrono::DateTime<Utc>,
}
/// Visibility stored when a create request omits it.
const DEFAULT_VISIBILITY: &str = "private";
/// `toolset_json` stored when a create request omits it.
const DEFAULT_TOOLSET_JSON: &str = "{}";
/// `memory_provider` stored when a create request omits it.
const DEFAULT_MEMORY_PROVIDER: &str = "simple";
/// `memory_provider_config` stored when a create request omits it.
const DEFAULT_MEMORY_PROVIDER_CONFIG: &str = "{}";
/// `iteration_budget` stored when a create request omits it.
const DEFAULT_ITERATION_BUDGET: i32 = 90;

/// Access level an agent-session operation requires from a non-owner.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SessionAccess {
    /// Reading: workspace membership is sufficient.
    Read,
    /// Mutating or deleting: workspace admin rights are required.
    Manage,
}

/// Encodes knowledge-base ids as the comma-separated text stored in
/// `agent_session.knowledge_base_ids`.
fn join_knowledge_base_ids(ids: &[Uuid]) -> String {
    ids.iter()
        .map(|id| id.to_string())
        .collect::<Vec<_>>()
        .join(",")
}

/// Loads an agent session that has not been soft-deleted.
///
/// # Errors
/// Returns `AppError::NotFound` when no live row has `session_id`, and
/// `AppError::DatabaseError` when the query fails.
async fn fetch_live_agent_session(
    service: &AppService,
    session_id: Uuid,
) -> Result<AgentSessionModel, AppError> {
    sqlx::query_as::<_, AgentSessionModel>(
        "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
         system_prompt, temperature, max_output_tokens, tool_policy, \
         knowledge_base_ids, variables, visibility, version, \
         published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \
         created_at, updated_at, deleted_at \
         FROM agent_session \
         WHERE id = $1 AND deleted_at IS NULL",
    )
    .bind(session_id)
    .fetch_optional(service.db.reader())
    .await
    .map_err(|e| AppError::DatabaseError(e.to_string()))?
    .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))
}

/// Checks that `user_id` may access `session` at the requested level.
///
/// The owner is always allowed. Anyone else is allowed only through the
/// session's workspace: membership for [`SessionAccess::Read`], admin rights
/// for [`SessionAccess::Manage`].
///
/// # Errors
/// Returns `AppError::PermissionDenied` for a non-owner of a session without
/// a workspace. Otherwise it propagates the workspace check's error.
async fn authorize_agent_session(
    service: &AppService,
    session: &AgentSessionModel,
    user_id: Uuid,
    access: SessionAccess,
) -> Result<(), AppError> {
    if session.user == Some(user_id) {
        return Ok(());
    }
    let wk = match session.wk {
        Some(wk) => wk,
        None => return Err(AppError::PermissionDenied),
    };
    match access {
        SessionAccess::Read => {
            let _ = crate::AppService::workspace_require_member(service, wk, user_id).await?;
        }
        SessionAccess::Manage => {
            let _ = crate::AppService::workspace_require_admin(service, wk, user_id).await?;
        }
    }
    Ok(())
}

impl AppService {
    /// Creates an agent session owned by `user_id`.
    ///
    /// When `params.wk` names a workspace, it is resolved and the caller must
    /// be a member. Omitted optional settings fall back to the module-level
    /// defaults (`DEFAULT_*`). `version` starts at 1, `enabled` at true and
    /// `source` at `'api'`.
    ///
    /// # Errors
    /// Propagates workspace resolution/membership errors. Returns
    /// `AppError::DatabaseError` when the insert fails.
    pub async fn agent_session_create(
        &self,
        user_id: Uuid,
        params: CreateAgentSession,
    ) -> Result<AgentSessionResponse, AppError> {
        let wk_uuid: Option<Uuid> = if let Some(ref wk_name) = params.wk {
            let wk = crate::AppService::workspace_resolve(&*self, wk_name).await?;
            let _ = crate::AppService::workspace_require_member(&*self, wk.id, user_id).await?;
            Some(wk.id)
        } else {
            None
        };
        let id = Uuid::now_v7();
        let now = Utc::now();
        let visibility = params
            .visibility
            .unwrap_or_else(|| DEFAULT_VISIBILITY.to_string());
        let kb_ids = params
            .knowledge_base_ids
            .as_deref()
            .map(join_knowledge_base_ids);
        // Caller-supplied settings were previously discarded in favour of
        // hard-coded literals; honour them and keep the literals as defaults.
        let toolset_json = params
            .toolset_json
            .unwrap_or_else(|| DEFAULT_TOOLSET_JSON.to_string());
        let memory_provider = params
            .memory_provider
            .unwrap_or_else(|| DEFAULT_MEMORY_PROVIDER.to_string());
        let memory_provider_config = params
            .memory_provider_config
            .unwrap_or_else(|| DEFAULT_MEMORY_PROVIDER_CONFIG.to_string());
        let iteration_budget = params.iteration_budget.unwrap_or(DEFAULT_ITERATION_BUDGET);
        // NOTE(review): `params.source` is intentionally not bound; provenance
        // stays pinned to 'api' so clients cannot set it — confirm.
        let row = sqlx::query_as::<_, AgentSessionModel>(
            "INSERT INTO agent_session \
             (id, \"user\", wk, name, description, agent_kind, model_version, \
             system_prompt, temperature, max_output_tokens, tool_policy, \
             knowledge_base_ids, variables, visibility, version, enabled, \
             source, toolset_json, memory_provider, memory_provider_config, iteration_budget, \
             created_by, created_at, updated_at) \
             VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, 1, true, \
             'api', $15, $16, $17, $18, \
             $19, $20, $20) \
             RETURNING id, \"user\", wk, name, description, agent_kind, model_version, \
             system_prompt, temperature, max_output_tokens, tool_policy, \
             knowledge_base_ids, variables, visibility, version, \
             published_at, rollback_from_version, enabled, \
             source, parent_session_id, toolset_json, \
             memory_provider, memory_provider_config, iteration_budget, \
             created_by, created_at, updated_at, deleted_at",
        )
        .bind(id)
        .bind(user_id)
        .bind(wk_uuid)
        .bind(&params.name)
        .bind(&params.description)
        .bind(&params.agent_kind)
        .bind(params.model_version)
        .bind(&params.system_prompt)
        .bind(params.temperature)
        .bind(params.max_output_tokens)
        .bind(&params.tool_policy)
        .bind(&kb_ids)
        .bind(&params.variables)
        .bind(&visibility)
        .bind(&toolset_json)
        .bind(&memory_provider)
        .bind(&memory_provider_config)
        .bind(iteration_budget)
        .bind(user_id)
        .bind(now)
        .fetch_one(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        Ok(row.into())
    }

    /// Lists live sessions the user owns or that belong to one of the user's
    /// workspaces, most recently updated first.
    ///
    /// # Errors
    /// Returns `AppError::DatabaseError` when the query fails.
    pub async fn agent_session_list(
        &self,
        user_id: Uuid,
    ) -> Result<Vec<AgentSessionResponse>, AppError> {
        let rows = sqlx::query_as::<_, AgentSessionModel>(
            "SELECT id, \"user\", wk, name, description, agent_kind, model_version, \
             system_prompt, temperature, max_output_tokens, tool_policy, \
             knowledge_base_ids, variables, visibility, version, \
             published_at, rollback_from_version, enabled, \
             source, parent_session_id, toolset_json, \
             memory_provider, memory_provider_config, iteration_budget, \
             created_by, created_at, updated_at, deleted_at \
             FROM agent_session \
             WHERE (\"user\" = $1 OR wk IN (SELECT wk FROM wk_member WHERE \"user\" = $1)) \
             AND deleted_at IS NULL \
             ORDER BY updated_at DESC",
        )
        .bind(user_id)
        .fetch_all(self.db.reader())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        Ok(rows.into_iter().map(Into::into).collect())
    }

    /// Fetches one live session visible to `user_id` (owner or workspace member).
    ///
    /// # Errors
    /// Returns `AppError::NotFound`, `AppError::PermissionDenied`, a workspace
    /// membership error, or `AppError::DatabaseError`.
    pub async fn agent_session_get(
        &self,
        user_id: Uuid,
        session_id: Uuid,
    ) -> Result<AgentSessionResponse, AppError> {
        let row = fetch_live_agent_session(self, session_id).await?;
        authorize_agent_session(self, &row, user_id, SessionAccess::Read).await?;
        Ok(row.into())
    }

    /// Applies a partial update. Fields left `None` keep their stored value.
    ///
    /// Only the owner or a workspace admin may update.
    ///
    /// # Errors
    /// Returns `AppError::NotFound` (including when the row is deleted
    /// concurrently), `AppError::PermissionDenied`, a workspace admin error, or
    /// `AppError::DatabaseError`.
    pub async fn agent_session_update(
        &self,
        user_id: Uuid,
        session_id: Uuid,
        params: UpdateAgentSession,
    ) -> Result<AgentSessionResponse, AppError> {
        let existing = fetch_live_agent_session(self, session_id).await?;
        authorize_agent_session(self, &existing, user_id, SessionAccess::Manage).await?;
        let now = Utc::now();
        let name = params.name.unwrap_or(existing.name);
        let description = params.description.or(existing.description);
        let system_prompt = params.system_prompt.or(existing.system_prompt);
        let temperature = params.temperature.or(existing.temperature);
        let max_output_tokens = params.max_output_tokens.or(existing.max_output_tokens);
        let model_version = params.model_version.or(existing.model_version);
        let tool_policy = params.tool_policy.or(existing.tool_policy);
        let toolset_json = params.toolset_json.or(existing.toolset_json);
        let memory_provider = params.memory_provider.or(existing.memory_provider);
        let memory_provider_config = params
            .memory_provider_config
            .or(existing.memory_provider_config);
        let iteration_budget = params.iteration_budget.or(existing.iteration_budget);
        let visibility = params.visibility.unwrap_or(existing.visibility);
        let enabled = params.enabled.unwrap_or(existing.enabled);
        let kb_ids = params
            .knowledge_base_ids
            .as_deref()
            .map(join_knowledge_base_ids)
            .or(existing.knowledge_base_ids);
        let variables = params.variables.or(existing.variables);
        let row = sqlx::query_as::<_, AgentSessionModel>(
            "UPDATE agent_session SET \
             name = $1, description = $2, system_prompt = $3, temperature = $4, \
             max_output_tokens = $5, model_version = $6, tool_policy = $7, \
             toolset_json = $8, memory_provider = $9, \
             memory_provider_config = $10, iteration_budget = $11, \
             visibility = $12, enabled = $13, knowledge_base_ids = $14, \
             variables = $15, updated_at = $16 \
             WHERE id = $17 AND deleted_at IS NULL \
             RETURNING id, \"user\", wk, name, description, agent_kind, model_version, \
             system_prompt, temperature, max_output_tokens, tool_policy, \
             knowledge_base_ids, variables, visibility, version, \
             published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \
             created_at, updated_at, deleted_at",
        )
        .bind(&name)
        .bind(&description)
        .bind(&system_prompt)
        .bind(temperature)
        .bind(max_output_tokens)
        .bind(model_version)
        .bind(&tool_policy)
        .bind(&toolset_json)
        .bind(&memory_provider)
        .bind(&memory_provider_config)
        .bind(iteration_budget)
        .bind(&visibility)
        .bind(enabled)
        .bind(&kb_ids)
        .bind(&variables)
        .bind(now)
        .bind(session_id)
        // The row may have been soft-deleted since it was loaded; report that
        // as NotFound rather than as an opaque RowNotFound database error.
        .fetch_optional(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?
        .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?;
        Ok(row.into())
    }

    /// Soft-deletes a session by stamping `deleted_at`.
    ///
    /// Only the owner or a workspace admin may delete.
    ///
    /// # Errors
    /// Returns `AppError::NotFound`, `AppError::PermissionDenied`, a workspace
    /// admin error, or `AppError::DatabaseError`.
    pub async fn agent_session_delete(
        &self,
        user_id: Uuid,
        session_id: Uuid,
    ) -> Result<(), AppError> {
        let existing = fetch_live_agent_session(self, session_id).await?;
        authorize_agent_session(self, &existing, user_id, SessionAccess::Manage).await?;
        let now = Utc::now();
        // Guard on `deleted_at IS NULL` so a racing delete keeps the first
        // deletion timestamp instead of overwriting it.
        sqlx::query(
            "UPDATE agent_session SET deleted_at = $1, updated_at = $1 \
             WHERE id = $2 AND deleted_at IS NULL",
        )
        .bind(now)
        .bind(session_id)
        .execute(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        Ok(())
    }

    /// Full-text search over live messages. It returns, at most `limit`
    /// visible sessions that contain a matching message, most recently
    /// updated first.
    ///
    /// # Errors
    /// Returns `AppError::DatabaseError` when the query fails.
    pub async fn agent_session_search(
        &self,
        user_id: Uuid,
        query: &str,
        limit: u32,
    ) -> Result<Vec<AgentSessionResponse>, AppError> {
        // EXISTS is a semi-join: each session appears at most once, so there
        // is no need to materialise one row per matching message and then
        // collapse them with SELECT DISTINCT.
        let rows = sqlx::query_as::<_, AgentSessionModel>(
            "SELECT s.id, s.\"user\", s.wk, s.name, s.description, \
             s.agent_kind, s.model_version, \
             s.system_prompt, s.temperature, s.max_output_tokens, s.tool_policy, \
             s.knowledge_base_ids, s.variables, s.visibility, s.version, \
             s.published_at, s.rollback_from_version, s.enabled, \
             s.source, s.parent_session_id, s.toolset_json, \
             s.memory_provider, s.memory_provider_config, s.iteration_budget, \
             s.created_by, s.created_at, s.updated_at, s.deleted_at \
             FROM agent_session s \
             WHERE (s.\"user\" = $1 OR s.wk IN (SELECT wk FROM wk_member WHERE \"user\" = $1)) \
             AND s.deleted_at IS NULL \
             AND EXISTS ( \
             SELECT 1 FROM agent_message m \
             INNER JOIN agent_conversation c ON c.id = m.conversation \
             WHERE c.session = s.id \
             AND m.deleted_at IS NULL \
             AND m.search_vector @@ plainto_tsquery('english', $2) \
             ) \
             ORDER BY s.updated_at DESC \
             LIMIT $3",
        )
        .bind(user_id)
        .bind(query)
        .bind(i64::from(limit))
        .fetch_all(self.db.reader())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        Ok(rows.into_iter().map(Into::into).collect())
    }

    /// Overwrites the `enabled` and/or `disabled` tool lists inside the
    /// session's `toolset_json` object. Other keys are kept.
    ///
    /// Only the owner or a workspace admin may change toolsets.
    ///
    /// # Errors
    /// Returns `AppError::NotFound` (including when the row is deleted
    /// concurrently), `AppError::PermissionDenied`, a workspace admin error, or
    /// `AppError::DatabaseError`.
    pub async fn agent_session_update_toolsets(
        &self,
        user_id: Uuid,
        session_id: Uuid,
        enabled: Option<Vec<String>>,
        disabled: Option<Vec<String>>,
    ) -> Result<AgentSessionResponse, AppError> {
        let existing = fetch_live_agent_session(self, session_id).await?;
        authorize_agent_session(self, &existing, user_id, SessionAccess::Manage).await?;
        // Best-effort: a missing, malformed, or non-object stored value is
        // treated as an empty object.
        let mut toolsets: serde_json::Map<String, serde_json::Value> = existing
            .toolset_json
            .as_deref()
            .and_then(|raw| serde_json::from_str(raw).ok())
            .unwrap_or_default();
        let to_json_list = |names: Vec<String>| {
            serde_json::Value::Array(names.into_iter().map(serde_json::Value::String).collect())
        };
        if let Some(names) = enabled {
            toolsets.insert("enabled".to_string(), to_json_list(names));
        }
        if let Some(names) = disabled {
            toolsets.insert("disabled".to_string(), to_json_list(names));
        }
        // `Value`'s Display is compact JSON and cannot fail.
        let toolset_json = serde_json::Value::Object(toolsets).to_string();
        let now = Utc::now();
        let row = sqlx::query_as::<_, AgentSessionModel>(
            "UPDATE agent_session SET toolset_json = $1, updated_at = $2 \
             WHERE id = $3 AND deleted_at IS NULL \
             RETURNING id, \"user\", wk, name, description, agent_kind, model_version, \
             system_prompt, temperature, max_output_tokens, tool_policy, \
             knowledge_base_ids, variables, visibility, version, \
             published_at, rollback_from_version, enabled, source, parent_session_id, toolset_json, memory_provider, memory_provider_config, iteration_budget, created_by, \
             created_at, updated_at, deleted_at",
        )
        .bind(&toolset_json)
        .bind(now)
        .bind(session_id)
        .fetch_optional(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?
        .ok_or_else(|| AppError::NotFound("agent session not found".to_string()))?;
        Ok(row.into())
    }
}
impl From<AgentSessionModel> for AgentSessionResponse {
fn from(m: AgentSessionModel) -> Self {
Self {
id: m.id,
name: m.name,
description: m.description,
agent_kind: m.agent_kind,
model_version: m.model_version,
system_prompt: m.system_prompt,
temperature: m.temperature,
max_output_tokens: m.max_output_tokens,
tool_policy: m.tool_policy,
toolset_json: m.toolset_json,
memory_provider: m.memory_provider,
iteration_budget: m.iteration_budget,
source: m.source,
parent_session_id: m.parent_session_id,
visibility: m.visibility,
version: m.version,
enabled: m.enabled,
user: m.user,
wk: m.wk,
variables: m.variables,
published_at: m.published_at,
created_at: m.created_at,
updated_at: m.updated_at,
}
}
}