gitdataai/lib/service/agent/memory.rs

245 lines
6.4 KiB
Rust

use ai::{
error::{AiError, AiResult},
tool::tools::FunctionCall,
};
use async_trait::async_trait;
use chrono::Utc;
use db::sqlx;
use serde_json::{Value, json};
use tracing::info;
use uuid::Uuid;
use super::run::AppAgentContext;
use crate::AppService;
use crate::error::AppError;
/// Agent tool, exposed to the model as `save_memory`, that queues a key/value
/// fact for long-term storage.
///
/// The type carries no state: its `FunctionCall` implementation only appends a
/// `PendingMemory` to the run context. It derives `Debug`, `Clone` and `Copy`
/// so it can be logged and duplicated freely, matching the derives already
/// present on `PendingMemory`.
#[derive(Debug, Clone, Copy)]
pub struct SaveMemoryTool;
impl SaveMemoryTool {
pub fn new() -> Self {
Self
}
}
impl Default for SaveMemoryTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl FunctionCall for SaveMemoryTool {
type Context = AppAgentContext;
fn name(&self) -> &'static str {
"save_memory"
}
fn schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"key": {
"type": "string",
"description": "A short, descriptive key for the memory (e.g. 'user_preference_language', 'project_architecture')"
},
"value": {
"type": "string",
"description": "The information to remember"
},
"importance": {
"type": "integer",
"description": "Importance level 0-10 (10 = most important). Default: 5",
"minimum": 0,
"maximum": 10
}
},
"required": ["key", "value"]
})
}
async fn call(
&self,
context: &mut Self::Context,
args: Value,
) -> AiResult<Value> {
let key =
args.get("key").and_then(|v| v.as_str()).ok_or_else(|| {
AiError::Config("key parameter is required".to_string())
})?;
let value =
args.get("value").and_then(|v| v.as_str()).ok_or_else(|| {
AiError::Config("value parameter is required".to_string())
})?;
let importance = args
.get("importance")
.and_then(|v| v.as_i64())
.unwrap_or(5)
.clamp(0, 10) as i32;
context.pending_memories.push(PendingMemory {
key: key.to_string(),
value: value.to_string(),
importance,
});
Ok(json!({
"success": true,
"key": key,
"message": format!("Memory '{}' saved (importance: {})", key, importance)
}))
}
}
/// Agent tool, exposed to the model as `recall_memory`, that returns the
/// memory listing it was constructed with.
///
/// Derives `Debug` and `Clone` so the tool can be logged and duplicated,
/// matching the derives already present on `PendingMemory`.
#[derive(Debug, Clone)]
pub struct RecallMemoryTool {
    /// Text returned verbatim under the `memories` key of every tool response.
    /// NOTE(review): despite the field name, nothing in this file requires the
    /// content to be JSON; presumably it is the listing produced by
    /// `agent_load_memories` - confirm at the construction site.
    memories_json: String,
}
impl RecallMemoryTool {
pub fn new(memories_json: String) -> Self {
Self { memories_json }
}
}
#[async_trait]
impl FunctionCall for RecallMemoryTool {
type Context = AppAgentContext;
fn name(&self) -> &'static str {
"recall_memory"
}
fn schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Optional search query to filter memories by key or content"
}
}
})
}
async fn call(
&self,
_context: &mut Self::Context,
args: Value,
) -> AiResult<Value> {
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
if query.is_empty() {
return Ok(json!({
"memories": self.memories_json,
"count": "all"
}));
}
Ok(json!({
"memories": self.memories_json,
"query": query,
"hint": "Search the memories above for matches to your query"
}))
}
}
/// A memory queued during an agent run that has not yet been written to the
/// database.
///
/// `SaveMemoryTool` pushes these onto the run context, and
/// `AppService::agent_persist_memories` later upserts them.
#[derive(Debug, Clone)]
pub struct PendingMemory {
/// Identifier of the memory; the persistence upsert conflicts on
/// `(session, key)`, so a repeated key overwrites the stored entry.
pub key: String,
/// The text to remember.
pub value: String,
/// Importance score. `SaveMemoryTool` stores values clamped to `0..=10`;
/// the field itself enforces no range.
pub importance: i32,
}
impl AppService {
    /// Loads up to 50 non-deleted long-term memories for `session_id`.
    ///
    /// Rows are ordered by importance (highest first), then by most recent
    /// update. The return value pairs a text listing of those rows with the
    /// rows themselves as `(id, key, value, importance)` tuples. The listing is
    /// an empty string when the session has no memories.
    ///
    /// # Errors
    /// Returns `AppError::DatabaseError` carrying the driver's message when the
    /// query fails.
    pub async fn agent_load_memories(
        &self,
        session_id: Uuid,
    ) -> Result<(String, Vec<(Uuid, String, String, i32)>), AppError> {
        const SELECT_SQL: &str = "SELECT id, key, value, importance \
             FROM agent_long_term_memory \
             WHERE session = $1 AND deleted_at IS NULL \
             ORDER BY importance DESC, updated_at DESC \
             LIMIT 50";

        let memories: Vec<(Uuid, String, String, i32)> =
            sqlx::query_as(SELECT_SQL)
                .bind(session_id)
                .fetch_all(self.db.reader())
                .await
                .map_err(|err| AppError::DatabaseError(err.to_string()))?;

        let listing = if memories.is_empty() {
            // No header is emitted when there is nothing to list.
            String::new()
        } else {
            let mut text =
                String::from("Long-term memories for this session:\n");
            text.extend(memories.iter().map(|(_id, key, value, importance)| {
                format!("- [{}] {} (importance: {})\n", key, value, importance)
            }));
            text
        };

        Ok((listing, memories))
    }

    /// Upserts every entry of `memories` for `session_id`.
    ///
    /// Each entry is inserted with a fresh UUIDv7. If a non-deleted row with the
    /// same `(session, key)` already exists, its value, importance and
    /// `updated_at` are overwritten instead. All statements issued by one call
    /// share a single timestamp. An empty slice is a no-op and logs nothing.
    ///
    /// NOTE(review): each statement is awaited on its own and no transaction is
    /// opened in this function, so a failure part-way through returns an error
    /// after earlier entries have already been sent to the database.
    ///
    /// # Errors
    /// Returns `AppError::DatabaseError` on the first statement that fails.
    pub async fn agent_persist_memories(
        &self,
        session_id: Uuid,
        memories: &[PendingMemory],
    ) -> Result<(), AppError> {
        const UPSERT_SQL: &str = "INSERT INTO agent_long_term_memory \
             (id, session, key, value, importance, created_at, updated_at) \
             VALUES ($1, $2, $3, $4, $5, $6, $6) \
             ON CONFLICT (session, key) WHERE deleted_at IS NULL \
             DO UPDATE SET value = $4, importance = $5, updated_at = $6";

        if memories.is_empty() {
            return Ok(());
        }

        let timestamp = Utc::now();
        for pending in memories {
            sqlx::query(UPSERT_SQL)
                .bind(Uuid::now_v7())
                .bind(session_id)
                .bind(&pending.key)
                .bind(&pending.value)
                .bind(pending.importance)
                .bind(timestamp)
                .execute(self.db.writer())
                .await
                .map_err(|err| AppError::DatabaseError(err.to_string()))?;
        }

        info!(
            session_id = %session_id,
            count = memories.len(),
            "persisted long-term memories from agent run"
        );
        Ok(())
    }

    /// Sets `last_used_at` to the current time on every row whose id appears in
    /// `memory_ids`. An empty slice issues no query.
    ///
    /// NOTE(review): no caller is visible in this file, which is presumably why
    /// the dead-code lint is silenced - confirm before removing the attribute.
    ///
    /// # Errors
    /// Returns `AppError::DatabaseError` when the update fails.
    #[allow(dead_code)]
    pub async fn agent_touch_memories(
        &self,
        memory_ids: &[Uuid],
    ) -> Result<(), AppError> {
        const TOUCH_SQL: &str = "UPDATE agent_long_term_memory \
             SET last_used_at = $1 \
             WHERE id = ANY($2::uuid[])";

        if !memory_ids.is_empty() {
            sqlx::query(TOUCH_SQL)
                .bind(Utc::now())
                .bind(memory_ids)
                .execute(self.db.writer())
                .await
                .map_err(|err| AppError::DatabaseError(err.to_string()))?;
        }
        Ok(())
    }
}