245 lines
6.4 KiB
Rust
245 lines
6.4 KiB
Rust
use ai::{
|
|
error::{AiError, AiResult},
|
|
tool::tools::FunctionCall,
|
|
};
|
|
use async_trait::async_trait;
|
|
use chrono::Utc;
|
|
use db::sqlx;
|
|
use serde_json::{Value, json};
|
|
use tracing::info;
|
|
use uuid::Uuid;
|
|
|
|
use super::run::AppAgentContext;
|
|
use crate::AppService;
|
|
use crate::error::AppError;
|
|
/// Agent tool that records a key/value memory in the agent context.
///
/// The type carries no state of its own; every instance behaves the same.
#[derive(Default)]
pub struct SaveMemoryTool;

impl SaveMemoryTool {
    /// Builds the tool. Equivalent to [`SaveMemoryTool::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
|
|
|
|
#[async_trait]
|
|
impl FunctionCall for SaveMemoryTool {
|
|
type Context = AppAgentContext;
|
|
|
|
fn name(&self) -> &'static str {
|
|
"save_memory"
|
|
}
|
|
|
|
fn schema(&self) -> Value {
|
|
json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"key": {
|
|
"type": "string",
|
|
"description": "A short, descriptive key for the memory (e.g. 'user_preference_language', 'project_architecture')"
|
|
},
|
|
"value": {
|
|
"type": "string",
|
|
"description": "The information to remember"
|
|
},
|
|
"importance": {
|
|
"type": "integer",
|
|
"description": "Importance level 0-10 (10 = most important). Default: 5",
|
|
"minimum": 0,
|
|
"maximum": 10
|
|
}
|
|
},
|
|
"required": ["key", "value"]
|
|
})
|
|
}
|
|
|
|
async fn call(
|
|
&self,
|
|
context: &mut Self::Context,
|
|
args: Value,
|
|
) -> AiResult<Value> {
|
|
let key =
|
|
args.get("key").and_then(|v| v.as_str()).ok_or_else(|| {
|
|
AiError::Config("key parameter is required".to_string())
|
|
})?;
|
|
|
|
let value =
|
|
args.get("value").and_then(|v| v.as_str()).ok_or_else(|| {
|
|
AiError::Config("value parameter is required".to_string())
|
|
})?;
|
|
|
|
let importance = args
|
|
.get("importance")
|
|
.and_then(|v| v.as_i64())
|
|
.unwrap_or(5)
|
|
.clamp(0, 10) as i32;
|
|
|
|
context.pending_memories.push(PendingMemory {
|
|
key: key.to_string(),
|
|
value: value.to_string(),
|
|
importance,
|
|
});
|
|
|
|
Ok(json!({
|
|
"success": true,
|
|
"key": key,
|
|
"message": format!("Memory '{}' saved (importance: {})", key, importance)
|
|
}))
|
|
}
|
|
}
|
|
/// Agent tool that returns a pre-built memory listing to the caller.
pub struct RecallMemoryTool {
    /// Text handed back verbatim by every invocation of the tool.
    // NOTE(review): despite the name, nothing visible here requires this to
    // be JSON — confirm against the code that constructs the tool.
    memories_json: String,
}

impl RecallMemoryTool {
    /// Wraps the given memory text so the tool can return it later.
    pub fn new(memories_json: String) -> Self {
        RecallMemoryTool { memories_json }
    }
}
|
|
|
|
#[async_trait]
|
|
impl FunctionCall for RecallMemoryTool {
|
|
type Context = AppAgentContext;
|
|
|
|
fn name(&self) -> &'static str {
|
|
"recall_memory"
|
|
}
|
|
|
|
fn schema(&self) -> Value {
|
|
json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"query": {
|
|
"type": "string",
|
|
"description": "Optional search query to filter memories by key or content"
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
async fn call(
|
|
&self,
|
|
_context: &mut Self::Context,
|
|
args: Value,
|
|
) -> AiResult<Value> {
|
|
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
|
|
|
|
if query.is_empty() {
|
|
return Ok(json!({
|
|
"memories": self.memories_json,
|
|
"count": "all"
|
|
}));
|
|
}
|
|
|
|
Ok(json!({
|
|
"memories": self.memories_json,
|
|
"query": query,
|
|
"hint": "Search the memories above for matches to your query"
|
|
}))
|
|
}
|
|
}
|
|
/// A memory collected during an agent run that has not been stored yet.
#[derive(Clone, Debug)]
pub struct PendingMemory {
    /// Identifier of the memory.
    pub key: String,
    /// Content to remember.
    pub value: String,
    /// Importance score; `SaveMemoryTool` clamps it to `0..=10` before
    /// queueing.
    pub importance: i32,
}
|
|
|
|
impl AppService {
|
|
pub async fn agent_load_memories(
|
|
&self,
|
|
session_id: Uuid,
|
|
) -> Result<(String, Vec<(Uuid, String, String, i32)>), AppError> {
|
|
let rows: Vec<(Uuid, String, String, i32)> = sqlx::query_as(
|
|
"SELECT id, key, value, importance \
|
|
FROM agent_long_term_memory \
|
|
WHERE session = $1 AND deleted_at IS NULL \
|
|
ORDER BY importance DESC, updated_at DESC \
|
|
LIMIT 50",
|
|
)
|
|
.bind(session_id)
|
|
.fetch_all(self.db.reader())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
if rows.is_empty() {
|
|
return Ok((String::new(), rows));
|
|
}
|
|
|
|
let mut formatted =
|
|
String::from("Long-term memories for this session:\n");
|
|
for (_, key, value, importance) in &rows {
|
|
formatted.push_str(&format!(
|
|
"- [{}] {} (importance: {})\n",
|
|
key, value, importance
|
|
));
|
|
}
|
|
|
|
Ok((formatted, rows))
|
|
}
|
|
pub async fn agent_persist_memories(
|
|
&self,
|
|
session_id: Uuid,
|
|
memories: &[PendingMemory],
|
|
) -> Result<(), AppError> {
|
|
if memories.is_empty() {
|
|
return Ok(());
|
|
}
|
|
|
|
let now = Utc::now();
|
|
for mem in memories {
|
|
sqlx::query(
|
|
"INSERT INTO agent_long_term_memory \
|
|
(id, session, key, value, importance, created_at, updated_at) \
|
|
VALUES ($1, $2, $3, $4, $5, $6, $6) \
|
|
ON CONFLICT (session, key) WHERE deleted_at IS NULL \
|
|
DO UPDATE SET value = $4, importance = $5, updated_at = $6",
|
|
)
|
|
.bind(Uuid::now_v7())
|
|
.bind(session_id)
|
|
.bind(&mem.key)
|
|
.bind(&mem.value)
|
|
.bind(mem.importance)
|
|
.bind(now)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
}
|
|
|
|
info!(
|
|
session_id = %session_id,
|
|
count = memories.len(),
|
|
"persisted long-term memories from agent run"
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
#[allow(dead_code)]
|
|
pub async fn agent_touch_memories(
|
|
&self,
|
|
memory_ids: &[Uuid],
|
|
) -> Result<(), AppError> {
|
|
if memory_ids.is_empty() {
|
|
return Ok(());
|
|
}
|
|
|
|
let now = Utc::now();
|
|
sqlx::query(
|
|
"UPDATE agent_long_term_memory \
|
|
SET last_used_at = $1 \
|
|
WHERE id = ANY($2::uuid[])",
|
|
)
|
|
.bind(now)
|
|
.bind(memory_ids)
|
|
.execute(self.db.writer())
|
|
.await
|
|
.map_err(|e| AppError::DatabaseError(e.to_string()))?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|