feat: 1.0
This commit is contained in:
parent
e1330451a5
commit
a835610737
4
.clippy.toml
Normal file
4
.clippy.toml
Normal file
@ -0,0 +1,4 @@
|
||||
# Clippy configuration
|
||||
doc-valid-idents = ["GitHub", "GitLab", "TypeScript", "WebSocket", "PostgreSQL", "Redis", "OpenAI"]
|
||||
avoid-breaking-exported-api = true
|
||||
disallowed-types = []
|
||||
42
.editorconfig
Normal file
42
.editorconfig
Normal file
@ -0,0 +1,42 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
charset = utf-8
|
||||
end_of_line = lf
|
||||
insert_final_newline = true
|
||||
trim_trailing_whitespace = true
|
||||
|
||||
[*.{js,ts,jsx,tsx,json,jsonc,md,yaml,yml,toml}]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
|
||||
[*.py]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
|
||||
[*.go]
|
||||
indent_style = tab
|
||||
indent_size = unset
|
||||
tab_width = 8
|
||||
[*.rs]
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
|
||||
[Makefile]
|
||||
indent_style = tab
|
||||
indent_size = unset
|
||||
|
||||
[Dockerfile]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
|
||||
[*.md]
|
||||
trim_trailing_whitespace = false
|
||||
|
||||
[*.{yml,yaml}]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
|
||||
[*.toml]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
28
app/email/Cargo.toml
Normal file
28
app/email/Cargo.toml
Normal file
@ -0,0 +1,28 @@
|
||||
[package]
|
||||
name = "app-email"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
readme.workspace = true
|
||||
homepage.workspace = true
|
||||
license.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
documentation.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "email-service"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
config = { workspace = true }
|
||||
email = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
24
app/email/src/context.rs
Normal file
24
app/email/src/context.rs
Normal file
@ -0,0 +1,24 @@
|
||||
use config::AppConfig;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
pub struct AppContext {
|
||||
pub config: AppConfig,
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
pub fn init() -> anyhow::Result<Self> {
|
||||
let config = AppConfig::load();
|
||||
init_tracing(&config)?;
|
||||
Ok(Self { config })
|
||||
}
|
||||
}
|
||||
|
||||
fn init_tracing(config: &AppConfig) -> anyhow::Result<()> {
|
||||
let level = config.log_level()?;
|
||||
let filter = EnvFilter::try_new(&level)?;
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(filter)
|
||||
.with_target(false)
|
||||
.init();
|
||||
Ok(())
|
||||
}
|
||||
22
app/email/src/main.rs
Normal file
22
app/email/src/main.rs
Normal file
@ -0,0 +1,22 @@
|
||||
mod context;
|
||||
|
||||
use context::AppContext;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let ctx = AppContext::init()?;
|
||||
tracing::info!("email service starting");
|
||||
|
||||
tokio::select! {
|
||||
result = email::EmailWorker::start(&ctx.config) => {
|
||||
if let Err(e) = result {
|
||||
tracing::error!("email worker exited with error: {}", e);
|
||||
}
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
tracing::info!("shutdown signal received, stopping email service");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
50
app/gitdata/Cargo.toml
Normal file
50
app/gitdata/Cargo.toml
Normal file
@ -0,0 +1,50 @@
|
||||
[package]
|
||||
name = "app-gitdata"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
readme.workspace = true
|
||||
homepage.workspace = true
|
||||
license.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
documentation.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "gitdata"
|
||||
path = "src/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "gen-openapi"
|
||||
path = "src/bin/gen-openapi.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
config = { workspace = true }
|
||||
cache = { workspace = true }
|
||||
db = { workspace = true }
|
||||
service = { workspace = true }
|
||||
session = { workspace = true }
|
||||
api = { workspace = true }
|
||||
email = { workspace = true }
|
||||
storage = { workspace = true }
|
||||
git = { workspace = true }
|
||||
model = { workspace = true }
|
||||
channel = { workspace = true }
|
||||
socketio = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
|
||||
actix-web = { workspace = true, features = ["cookies", "secure-cookies"] }
|
||||
actix-ws = { workspace = true }
|
||||
tonic = { workspace = true, features = ["transport"] }
|
||||
deadpool-redis = { workspace = true }
|
||||
redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp", "connection-manager", "cluster"] }
|
||||
sqlx = { workspace = true, features = ["postgres", "runtime-tokio"] }
|
||||
serde_json = { workspace = true }
|
||||
uuid = { workspace = true, features = ["v4", "v7", "serde"] }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
7
app/gitdata/src/bin/gen-openapi.rs
Normal file
7
app/gitdata/src/bin/gen-openapi.rs
Normal file
@ -0,0 +1,7 @@
|
||||
use std::fs;
|
||||
|
||||
fn main() {
|
||||
let json = api::openapi::openapi_json();
|
||||
fs::write("openapi.json", json).expect("Failed to write openapi.json");
|
||||
println!("openapi.json generated successfully");
|
||||
}
|
||||
127
app/gitdata/src/context.rs
Normal file
127
app/gitdata/src/context.rs
Normal file
@ -0,0 +1,127 @@
|
||||
use actix_web::cookie::Key;
|
||||
use cache::{AppCache, AppCacheConfig};
|
||||
use config::AppConfig;
|
||||
use db::database::AppDatabase;
|
||||
use deadpool_redis::{
|
||||
PoolConfig, Runtime, Timeouts,
|
||||
cluster::{Config, Pool as RedisPool},
|
||||
};
|
||||
use email::AppEmail;
|
||||
use service::AppService;
|
||||
use session::storage::RedisClusterSessionStore;
|
||||
use storage::{AppStorage, AppStorageConfig};
|
||||
use tonic::transport::Channel;
|
||||
|
||||
use channel::{ChannelBus, ChannelBusConfig};
|
||||
use socketio::SocketIo;
|
||||
|
||||
pub struct AppContext {
|
||||
pub config: AppConfig,
|
||||
pub service: AppService,
|
||||
pub session_store: RedisClusterSessionStore,
|
||||
pub session_key: Key,
|
||||
pub channel_bus: ChannelBus,
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
pub async fn init() -> anyhow::Result<Self> {
|
||||
let config = AppConfig::load();
|
||||
init_tracing(&config)?;
|
||||
|
||||
tracing::info!("initializing database");
|
||||
let db = AppDatabase::init(&config).await?;
|
||||
|
||||
tracing::info!("initializing cache");
|
||||
let cache_config = AppCacheConfig::try_from(&config)?;
|
||||
let cache = AppCache::init(cache_config).await?;
|
||||
|
||||
tracing::info!("initializing storage");
|
||||
let storage_config = AppStorageConfig::try_from(&config)?;
|
||||
let storage = AppStorage::init(storage_config).await?;
|
||||
|
||||
tracing::info!("initializing email");
|
||||
let email = AppEmail::init(&config).await?;
|
||||
|
||||
tracing::info!("connecting to git RPC");
|
||||
let rpc_addr = config.git_rpc_addr()?;
|
||||
let rpc_port = config.git_rpc_port()?;
|
||||
let git_channel =
|
||||
Channel::from_shared(format!("http://{}:{}", rpc_addr, rpc_port))
|
||||
.expect("invalid gRPC endpoint")
|
||||
.connect()
|
||||
.await?;
|
||||
|
||||
let service = AppService {
|
||||
db,
|
||||
cache,
|
||||
email,
|
||||
storage,
|
||||
config: config.clone(),
|
||||
git: git_channel,
|
||||
redis_pool: init_redis_pool(&config)?,
|
||||
};
|
||||
|
||||
tracing::info!("initializing session store");
|
||||
let redis_urls = config.redis_urls()?;
|
||||
let session_store = RedisClusterSessionStore::new(redis_urls).await?;
|
||||
|
||||
tracing::info!("initializing session key");
|
||||
let secret = config.session_secret()?;
|
||||
let session_key = Key::from(secret.as_bytes());
|
||||
|
||||
tracing::info!("initializing channel bus");
|
||||
let io = SocketIo::new();
|
||||
let channel_config = ChannelBusConfig {
|
||||
namespace: "/channel".to_owned(),
|
||||
signing_secret: Some(secret.clone()),
|
||||
..Default::default()
|
||||
};
|
||||
let channel_bus = ChannelBus::new(
|
||||
service.db.clone(),
|
||||
service.cache.clone(),
|
||||
io,
|
||||
channel_config,
|
||||
);
|
||||
channel_bus.attach().await?;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
service,
|
||||
session_store,
|
||||
session_key,
|
||||
channel_bus,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn init_tracing(config: &AppConfig) -> anyhow::Result<()> {
|
||||
let level = config.log_level()?;
|
||||
let filter = tracing_subscriber::EnvFilter::try_new(&level)?;
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(filter)
|
||||
.with_target(false)
|
||||
.init();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn init_redis_pool(config: &AppConfig) -> anyhow::Result<RedisPool> {
|
||||
let redis_urls = config.redis_urls()?;
|
||||
let pool_size = config.redis_pool_size()?;
|
||||
let connect_timeout = config.redis_connect_timeout()?;
|
||||
let acquire_timeout = config.redis_acquire_timeout()?;
|
||||
|
||||
let mut pool_config = PoolConfig::new(pool_size as usize);
|
||||
pool_config.timeouts = Timeouts {
|
||||
wait: Some(std::time::Duration::from_secs(acquire_timeout)),
|
||||
create: Some(std::time::Duration::from_secs(connect_timeout)),
|
||||
recycle: Some(std::time::Duration::from_secs(connect_timeout)),
|
||||
};
|
||||
|
||||
let cfg = Config {
|
||||
urls: Some(redis_urls),
|
||||
connections: None,
|
||||
pool: Some(pool_config),
|
||||
read_from_replicas: false,
|
||||
};
|
||||
Ok(cfg.create_pool(Some(Runtime::Tokio1))?)
|
||||
}
|
||||
111
app/gitdata/src/main.rs
Normal file
111
app/gitdata/src/main.rs
Normal file
@ -0,0 +1,111 @@
|
||||
mod context;
|
||||
mod shutdown;
|
||||
|
||||
use std::time::Instant;
|
||||
|
||||
use actix_web::{App, dev::Service};
|
||||
use context::AppContext;
|
||||
use service::ai::sync::spawn_model_sync_loop;
|
||||
|
||||
const REQUEST_LOG_EXCLUDED_PATHS: &[&str] = &[
|
||||
"/health",
|
||||
"/live",
|
||||
"/ready",
|
||||
"/metrics",
|
||||
"/favicon.ico",
|
||||
"/robots.txt",
|
||||
];
|
||||
|
||||
fn should_log_request(path: &str) -> bool {
|
||||
!REQUEST_LOG_EXCLUDED_PATHS.contains(&path)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let ctx = AppContext::init().await?;
|
||||
|
||||
let api_port = ctx.config.api_port()?;
|
||||
tracing::info!("GitDataAI API service starting on 0.0.0.0:{}", api_port);
|
||||
|
||||
let service = ctx.service.clone();
|
||||
let session_store = ctx.session_store.clone();
|
||||
let session_key = ctx.session_key.clone();
|
||||
let channel_bus = ctx.channel_bus.clone();
|
||||
|
||||
let srv = actix_web::HttpServer::new(move || {
|
||||
let session_middleware = session::SessionMiddleware::builder(
|
||||
session_store.clone(),
|
||||
session_key.clone(),
|
||||
)
|
||||
.cookie_secure(false)
|
||||
.cookie_name("id".to_string())
|
||||
.session_lifecycle(
|
||||
session::config::PersistentSession::default()
|
||||
.session_ttl(actix_web::cookie::time::Duration::days(30)),
|
||||
)
|
||||
.build();
|
||||
|
||||
App::new()
|
||||
.app_data(actix_web::web::Data::new(service.clone()))
|
||||
.app_data(actix_web::web::Data::new(channel_bus.clone()))
|
||||
.wrap_fn(|req, srv| {
|
||||
let should_log = should_log_request(req.path());
|
||||
let method = req.method().clone();
|
||||
let path = req.path().to_owned();
|
||||
let peer_addr =
|
||||
req.connection_info().peer_addr().map(str::to_owned);
|
||||
let started_at = Instant::now();
|
||||
let fut = srv.call(req);
|
||||
|
||||
async move {
|
||||
match fut.await {
|
||||
Ok(res) => {
|
||||
if should_log {
|
||||
tracing::info!(
|
||||
method = %method,
|
||||
path = %path,
|
||||
status = res.status().as_u16(),
|
||||
elapsed_ms = started_at.elapsed().as_millis(),
|
||||
peer_addr = peer_addr.as_deref().unwrap_or("-"),
|
||||
"http request"
|
||||
);
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
Err(err) => {
|
||||
if should_log {
|
||||
tracing::warn!(
|
||||
method = %method,
|
||||
path = %path,
|
||||
elapsed_ms = started_at.elapsed().as_millis(),
|
||||
peer_addr = peer_addr.as_deref().unwrap_or("-"),
|
||||
error = %err,
|
||||
"http request failed"
|
||||
);
|
||||
}
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.wrap(session_middleware)
|
||||
.configure(|cfg| api::configure(cfg, channel_bus.clone()))
|
||||
})
|
||||
.bind(format!("0.0.0.0:{}", api_port))?;
|
||||
|
||||
spawn_model_sync_loop(ctx.service.clone());
|
||||
|
||||
let server = srv.run();
|
||||
tracing::info!("API server is running");
|
||||
|
||||
tokio::select! {
|
||||
_ = server => {
|
||||
tracing::info!("API server stopped");
|
||||
}
|
||||
_ = shutdown::wait_for_shutdown_signal() => {
|
||||
tracing::info!("shutdown signal received, stopping gitdata API service");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
25
app/gitdata/src/shutdown.rs
Normal file
25
app/gitdata/src/shutdown.rs
Normal file
@ -0,0 +1,25 @@
|
||||
pub async fn wait_for_shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to listen for ctrl_c event");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
tokio::signal::unix::signal(
|
||||
tokio::signal::unix::SignalKind::terminate(),
|
||||
)
|
||||
.expect("failed to listen for SIGTERM")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {},
|
||||
_ = terminate => {},
|
||||
}
|
||||
}
|
||||
32
app/gitpod/Cargo.toml
Normal file
32
app/gitpod/Cargo.toml
Normal file
@ -0,0 +1,32 @@
|
||||
[package]
|
||||
name = "app-gitpod"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
readme.workspace = true
|
||||
homepage.workspace = true
|
||||
license.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
documentation.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "gitpod"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
config = { workspace = true }
|
||||
cache = { workspace = true }
|
||||
db = { workspace = true }
|
||||
git = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
|
||||
deadpool-redis = { workspace = true }
|
||||
redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
65
app/gitpod/src/context.rs
Normal file
65
app/gitpod/src/context.rs
Normal file
@ -0,0 +1,65 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use cache::{AppCache, AppCacheConfig};
|
||||
use config::AppConfig;
|
||||
use db::database::AppDatabase;
|
||||
use deadpool_redis::{PoolConfig, Runtime, Timeouts, cluster::Config};
|
||||
|
||||
pub struct AppContext {
|
||||
pub config: AppConfig,
|
||||
pub db: AppDatabase,
|
||||
pub cache: AppCache,
|
||||
pub redis_pool: deadpool_redis::cluster::Pool,
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
pub async fn init() -> anyhow::Result<Self> {
|
||||
let config = AppConfig::load();
|
||||
init_tracing(&config)?;
|
||||
|
||||
tracing::info!("initializing database");
|
||||
let db = AppDatabase::init(&config).await?;
|
||||
|
||||
tracing::info!("initializing cache");
|
||||
let cache_config = AppCacheConfig::try_from(&config)?;
|
||||
let cache = AppCache::init(cache_config).await?;
|
||||
|
||||
tracing::info!("initializing redis pool");
|
||||
let redis_urls = config.redis_urls()?;
|
||||
let pool_size = config.redis_pool_size()?;
|
||||
let connect_timeout = config.redis_connect_timeout()?;
|
||||
let acquire_timeout = config.redis_acquire_timeout()?;
|
||||
|
||||
let mut pool_config = PoolConfig::new(pool_size as usize);
|
||||
pool_config.timeouts = Timeouts {
|
||||
wait: Some(Duration::from_secs(acquire_timeout)),
|
||||
create: Some(Duration::from_secs(connect_timeout)),
|
||||
recycle: Some(Duration::from_secs(connect_timeout)),
|
||||
};
|
||||
|
||||
let cfg = Config {
|
||||
urls: Some(redis_urls),
|
||||
connections: None,
|
||||
pool: Some(pool_config),
|
||||
read_from_replicas: false,
|
||||
};
|
||||
let redis_pool = cfg.create_pool(Some(Runtime::Tokio1))?;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
db,
|
||||
cache,
|
||||
redis_pool,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn init_tracing(config: &AppConfig) -> anyhow::Result<()> {
|
||||
let level = config.log_level()?;
|
||||
let filter = tracing_subscriber::EnvFilter::try_new(&level)?;
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(filter)
|
||||
.with_target(false)
|
||||
.init();
|
||||
Ok(())
|
||||
}
|
||||
82
app/gitpod/src/main.rs
Normal file
82
app/gitpod/src/main.rs
Normal file
@ -0,0 +1,82 @@
|
||||
mod context;
|
||||
mod shutdown;
|
||||
|
||||
use context::AppContext;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let ctx = AppContext::init().await?;
|
||||
|
||||
let http_port = ctx.config.git_http_port()?;
|
||||
let ssh_port = ctx.config.ssh_port()?;
|
||||
let rpc_addr = ctx.config.git_rpc_addr()?;
|
||||
let rpc_port = ctx.config.git_rpc_port()?;
|
||||
|
||||
tracing::info!(
|
||||
"gitpod service starting (HTTP:{} / SSH:{} / gRPC:{}:{})",
|
||||
http_port,
|
||||
ssh_port,
|
||||
rpc_addr,
|
||||
rpc_port
|
||||
);
|
||||
|
||||
let http_task = tokio::spawn(git::http::run_http(
|
||||
ctx.config.clone(),
|
||||
ctx.db.clone(),
|
||||
ctx.cache.clone(),
|
||||
ctx.redis_pool.clone(),
|
||||
));
|
||||
|
||||
let ssh_task = tokio::spawn(git::ssh::run_ssh(
|
||||
ctx.config.clone(),
|
||||
ctx.db.clone(),
|
||||
ctx.cache.clone(),
|
||||
ctx.redis_pool.clone(),
|
||||
));
|
||||
|
||||
let rpc_addr_parsed =
|
||||
format!("{}:{}", rpc_addr, rpc_port).parse::<std::net::SocketAddr>()?;
|
||||
let sync_service =
|
||||
git::sync::ReceiveSyncService::new(ctx.redis_pool.clone());
|
||||
let git_server = git::rpc::server::GitServer::new(
|
||||
rpc_addr_parsed,
|
||||
ctx.db.clone(),
|
||||
ctx.cache.clone(),
|
||||
sync_service,
|
||||
);
|
||||
let rpc_task = tokio::spawn(async move {
|
||||
git_server
|
||||
.serve()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
result = http_task => {
|
||||
match result {
|
||||
Ok(Ok(())) => tracing::info!("HTTP server stopped"),
|
||||
Ok(Err(e)) => tracing::error!("HTTP server error: {}", e),
|
||||
Err(e) => tracing::error!("HTTP task panicked: {}", e),
|
||||
}
|
||||
}
|
||||
result = ssh_task => {
|
||||
match result {
|
||||
Ok(Ok(())) => tracing::info!("SSH server stopped"),
|
||||
Ok(Err(e)) => tracing::error!("SSH server error: {}", e),
|
||||
Err(e) => tracing::error!("SSH task panicked: {}", e),
|
||||
}
|
||||
}
|
||||
result = rpc_task => {
|
||||
match result {
|
||||
Ok(Ok(())) => tracing::info!("gRPC server stopped"),
|
||||
Ok(Err(e)) => tracing::error!("gRPC server error: {}", e),
|
||||
Err(e) => tracing::error!("gRPC task panicked: {}", e),
|
||||
}
|
||||
}
|
||||
_ = shutdown::wait_for_shutdown_signal() => {
|
||||
tracing::info!("shutdown signal received, stopping gitpod service");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
25
app/gitpod/src/shutdown.rs
Normal file
25
app/gitpod/src/shutdown.rs
Normal file
@ -0,0 +1,25 @@
|
||||
pub async fn wait_for_shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to listen for ctrl_c event");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
tokio::signal::unix::signal(
|
||||
tokio::signal::unix::SignalKind::terminate(),
|
||||
)
|
||||
.expect("failed to listen for SIGTERM")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {},
|
||||
_ = terminate => {},
|
||||
}
|
||||
}
|
||||
36
app/gitsync/Cargo.toml
Normal file
36
app/gitsync/Cargo.toml
Normal file
@ -0,0 +1,36 @@
|
||||
[package]
|
||||
name = "app-gitsync"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
readme.workspace = true
|
||||
homepage.workspace = true
|
||||
license.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
documentation.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "gitsync"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
config = { workspace = true }
|
||||
cache = { workspace = true }
|
||||
db = { workspace = true }
|
||||
git = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "macros", "signal"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
|
||||
actix-web = { workspace = true }
|
||||
deadpool-redis = { workspace = true }
|
||||
redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] }
|
||||
uuid = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
sqlx = { workspace = true, features = ["postgres", "runtime-tokio"] }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
65
app/gitsync/src/context.rs
Normal file
65
app/gitsync/src/context.rs
Normal file
@ -0,0 +1,65 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use cache::{AppCache, AppCacheConfig};
|
||||
use config::AppConfig;
|
||||
use db::database::AppDatabase;
|
||||
use deadpool_redis::{PoolConfig, Runtime, Timeouts, cluster::Config};
|
||||
|
||||
pub struct AppContext {
|
||||
pub config: AppConfig,
|
||||
pub db: AppDatabase,
|
||||
pub cache: AppCache,
|
||||
pub redis_pool: deadpool_redis::cluster::Pool,
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
pub async fn init() -> anyhow::Result<Self> {
|
||||
let config = AppConfig::load();
|
||||
init_tracing(&config)?;
|
||||
|
||||
tracing::info!("initializing database");
|
||||
let db = AppDatabase::init(&config).await?;
|
||||
|
||||
tracing::info!("initializing cache");
|
||||
let cache_config = AppCacheConfig::try_from(&config)?;
|
||||
let cache = AppCache::init(cache_config).await?;
|
||||
|
||||
tracing::info!("initializing redis pool");
|
||||
let redis_urls = config.redis_urls()?;
|
||||
let pool_size = config.redis_pool_size()?;
|
||||
let connect_timeout = config.redis_connect_timeout()?;
|
||||
let acquire_timeout = config.redis_acquire_timeout()?;
|
||||
|
||||
let mut pool_config = PoolConfig::new(pool_size as usize);
|
||||
pool_config.timeouts = Timeouts {
|
||||
wait: Some(Duration::from_secs(acquire_timeout)),
|
||||
create: Some(Duration::from_secs(connect_timeout)),
|
||||
recycle: Some(Duration::from_secs(connect_timeout)),
|
||||
};
|
||||
|
||||
let cfg = Config {
|
||||
urls: Some(redis_urls),
|
||||
connections: None,
|
||||
pool: Some(pool_config),
|
||||
read_from_replicas: false,
|
||||
};
|
||||
let redis_pool = cfg.create_pool(Some(Runtime::Tokio1))?;
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
db,
|
||||
cache,
|
||||
redis_pool,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn init_tracing(config: &AppConfig) -> anyhow::Result<()> {
|
||||
let level = config.log_level()?;
|
||||
let filter = tracing_subscriber::EnvFilter::try_new(&level)?;
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(filter)
|
||||
.with_target(false)
|
||||
.init();
|
||||
Ok(())
|
||||
}
|
||||
99
app/gitsync/src/health.rs
Normal file
99
app/gitsync/src/health.rs
Normal file
@ -0,0 +1,99 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use actix_web::dev::Service;
|
||||
use actix_web::{App, HttpResponse, HttpServer, dev::Server, web};
|
||||
use cache::AppCache;
|
||||
use db::database::AppDatabase;
|
||||
|
||||
const REQUEST_LOG_EXCLUDED_PATHS: &[&str] = &[
|
||||
"/health",
|
||||
"/live",
|
||||
"/ready",
|
||||
"/metrics",
|
||||
"/favicon.ico",
|
||||
"/robots.txt",
|
||||
];
|
||||
|
||||
fn should_log_request(path: &str) -> bool {
|
||||
!REQUEST_LOG_EXCLUDED_PATHS.contains(&path)
|
||||
}
|
||||
|
||||
async fn health(
|
||||
db: web::Data<AppDatabase>,
|
||||
cache: web::Data<AppCache>,
|
||||
) -> HttpResponse {
|
||||
let db_ok = sqlx::query("SELECT 1").execute(db.reader()).await.is_ok();
|
||||
let cache_ok = cache.ping_cluster().await.is_ok();
|
||||
|
||||
if db_ok && cache_ok {
|
||||
HttpResponse::Ok().json(serde_json::json!({
|
||||
"status": "ok",
|
||||
"db": "ok",
|
||||
"cache": "ok",
|
||||
}))
|
||||
} else {
|
||||
HttpResponse::ServiceUnavailable().json(serde_json::json!({
|
||||
"status": "unhealthy",
|
||||
"db": if db_ok { "ok" } else { "error" },
|
||||
"cache": if cache_ok { "ok" } else { "error" },
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_health(
|
||||
port: u16,
|
||||
db: AppDatabase,
|
||||
cache: AppCache,
|
||||
) -> anyhow::Result<Server> {
|
||||
tracing::info!("health endpoint starting on 0.0.0.0:{}", port);
|
||||
|
||||
let srv = HttpServer::new(move || {
|
||||
App::new()
|
||||
.app_data(web::Data::new(db.clone()))
|
||||
.app_data(web::Data::new(cache.clone()))
|
||||
.wrap_fn(|req, srv| {
|
||||
let should_log = should_log_request(req.path());
|
||||
let method = req.method().clone();
|
||||
let path = req.path().to_owned();
|
||||
let peer_addr =
|
||||
req.connection_info().peer_addr().map(str::to_owned);
|
||||
let started_at = Instant::now();
|
||||
let fut = srv.call(req);
|
||||
|
||||
async move {
|
||||
match fut.await {
|
||||
Ok(res) => {
|
||||
if should_log {
|
||||
tracing::info!(
|
||||
method = %method,
|
||||
path = %path,
|
||||
status = res.status().as_u16(),
|
||||
elapsed_ms = started_at.elapsed().as_millis(),
|
||||
peer_addr = peer_addr.as_deref().unwrap_or("-"),
|
||||
"http request"
|
||||
);
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
Err(err) => {
|
||||
if should_log {
|
||||
tracing::warn!(
|
||||
method = %method,
|
||||
path = %path,
|
||||
elapsed_ms = started_at.elapsed().as_millis(),
|
||||
peer_addr = peer_addr.as_deref().unwrap_or("-"),
|
||||
error = %err,
|
||||
"http request failed"
|
||||
);
|
||||
}
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.route("/health", web::get().to(health))
|
||||
})
|
||||
.bind(format!("0.0.0.0:{}", port))?;
|
||||
|
||||
Ok(srv.run())
|
||||
}
|
||||
51
app/gitsync/src/main.rs
Normal file
51
app/gitsync/src/main.rs
Normal file
@ -0,0 +1,51 @@
|
||||
mod context;
|
||||
mod health;
|
||||
mod shutdown;
|
||||
|
||||
use context::AppContext;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let ctx = AppContext::init().await?;
|
||||
|
||||
tracing::info!("gitsync service starting");
|
||||
|
||||
let health_port = ctx.config.gitsync_health_port();
|
||||
let health_server =
|
||||
health::start_health(health_port, ctx.db.clone(), ctx.cache.clone())?;
|
||||
let health_handle = health_server.handle();
|
||||
let health_task = tokio::spawn(health_server);
|
||||
|
||||
let sync_service =
|
||||
git::sync::ReceiveSyncService::new(ctx.redis_pool.clone());
|
||||
let consumer = git::sync::consumer::SyncConsumer::new(sync_service, 5);
|
||||
let worker = git::sync::worker::SyncWorker::new(
|
||||
consumer,
|
||||
ctx.db.clone(),
|
||||
ctx.cache.clone(),
|
||||
ctx.redis_pool.clone(),
|
||||
ctx.config.clone(),
|
||||
format!("gitsync-{}", uuid::Uuid::new_v4()),
|
||||
);
|
||||
|
||||
let worker_task = tokio::spawn(async move { worker.run().await });
|
||||
|
||||
tokio::select! {
|
||||
result = health_task => {
|
||||
match result {
|
||||
Ok(Ok(())) => tracing::info!("health server stopped"),
|
||||
Ok(Err(e)) => tracing::error!("health server error: {}", e),
|
||||
Err(e) => tracing::error!("health task panicked: {}", e),
|
||||
}
|
||||
}
|
||||
_ = worker_task => {
|
||||
tracing::info!("sync worker stopped");
|
||||
}
|
||||
_ = shutdown::wait_for_shutdown_signal() => {
|
||||
tracing::info!("shutdown signal received, stopping gitsync service");
|
||||
health_handle.stop(true).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
25
app/gitsync/src/shutdown.rs
Normal file
25
app/gitsync/src/shutdown.rs
Normal file
@ -0,0 +1,25 @@
|
||||
pub async fn wait_for_shutdown_signal() {
|
||||
let ctrl_c = async {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to listen for ctrl_c event");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
tokio::signal::unix::signal(
|
||||
tokio::signal::unix::SignalKind::terminate(),
|
||||
)
|
||||
.expect("failed to listen for SIGTERM")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {},
|
||||
_ = terminate => {},
|
||||
}
|
||||
}
|
||||
137
docker/README.md
Normal file
137
docker/README.md
Normal file
@ -0,0 +1,137 @@
|
||||
# GitDataAI Docker 配置
|
||||
|
||||
## 文件说明
|
||||
|
||||
### Dockerfile 文件
|
||||
|
||||
| 文件名 | 服务 | 说明 |
|
||||
|--------|------|------|
|
||||
| `gitdata.Dockerfile` | GitData API | 主 API 服务 |
|
||||
| `email.Dockerfile` | Email Service | 邮件发送服务 |
|
||||
| `gitpod.Dockerfile` | GitPod Service | Git 服务 |
|
||||
| `gitsync.Dockerfile` | GitSync Service | Git 同步服务 |
|
||||
| `migrate.Dockerfile` | Database Migration | 数据库迁移工具 |
|
||||
| `web.Dockerfile` | Web Frontend | React 前端应用 |
|
||||
|
||||
### 配置文件
|
||||
|
||||
| 文件名 | 说明 |
|
||||
|--------|------|
|
||||
| `docker-compose.yml` | 完整的开发环境配置 |
|
||||
| `nginx.conf` | Nginx 反向代理配置 |
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 启动完整开发环境
|
||||
|
||||
```bash
|
||||
# 进入 docker 目录
|
||||
cd docker
|
||||
|
||||
# 启动所有服务
|
||||
docker-compose up -d
|
||||
|
||||
# 查看服务状态
|
||||
docker-compose ps
|
||||
|
||||
# 查看日志
|
||||
docker-compose logs -f
|
||||
```
|
||||
|
||||
### 2. 单独构建服务
|
||||
|
||||
```bash
|
||||
# 构建 GitData API
|
||||
docker build -f docker/gitdata.Dockerfile -t gitdata-api .
|
||||
|
||||
# 构建前端
|
||||
docker build -f docker/web.Dockerfile -t gitdata-web .
|
||||
```
|
||||
|
||||
### 3. 环境变量配置
|
||||
|
||||
创建 `.env` 文件配置环境变量:
|
||||
|
||||
```bash
|
||||
# 数据库配置
|
||||
POSTGRES_USER=gitdata
|
||||
POSTGRES_PASSWORD=your_secure_password
|
||||
POSTGRES_DB=app
|
||||
|
||||
# MinIO 配置
|
||||
MINIO_ROOT_USER=admin
|
||||
MINIO_ROOT_PASSWORD=your_secure_password
|
||||
```
|
||||
|
||||
## 服务端口
|
||||
|
||||
| 服务 | 端口 | 说明 |
|
||||
|------|------|------|
|
||||
| Web Frontend | 80 | 前端访问入口 |
|
||||
| GitData API | 8080 | 主 API 服务 |
|
||||
| Git HTTP | 5023 | Git HTTP 访问 |
|
||||
| Git RPC | 5030 | Git RPC 服务 |
|
||||
| SSH | 5022 | SSH Git 访问 |
|
||||
| GitPod | 5082 | GitPod 服务 |
|
||||
| GitSync | 5083 | GitSync 健康检查 |
|
||||
| PostgreSQL | 5432 | 数据库 |
|
||||
| Redis | 6379 | 缓存 |
|
||||
| Qdrant | 6333 | 向量数据库 |
|
||||
| NATS | 4222 | 消息队列 |
|
||||
| MinIO | 9000/9001 | 对象存储 |
|
||||
|
||||
## 生产环境部署
|
||||
|
||||
### 1. 修改环境变量
|
||||
|
||||
```bash
|
||||
# 复制示例配置
|
||||
cp .env.example .env
|
||||
|
||||
# 编辑配置文件,修改密码等敏感信息
|
||||
vim .env
|
||||
```
|
||||
|
||||
### 2. 启动服务
|
||||
|
||||
```bash
|
||||
# 使用生产配置启动
|
||||
docker-compose -f docker-compose.yml up -d
|
||||
|
||||
# 查看服务状态
|
||||
docker-compose ps
|
||||
```
|
||||
|
||||
### 3. 数据备份
|
||||
|
||||
```bash
|
||||
# 备份 PostgreSQL
|
||||
docker exec gitdata-postgres pg_dump -U gitdata app > backup.sql
|
||||
|
||||
# 备份 MinIO 数据
|
||||
docker cp gitdata-minio:/data ./minio-backup
|
||||
```
|
||||
|
||||
## 常见问题
|
||||
|
||||
### 1. 服务启动失败
|
||||
|
||||
检查日志:
|
||||
```bash
|
||||
docker-compose logs <service-name>
|
||||
```
|
||||
|
||||
### 2. 数据库连接失败
|
||||
|
||||
确保 PostgreSQL 健康检查通过:
|
||||
```bash
|
||||
docker-compose ps postgres
|
||||
```
|
||||
|
||||
### 3. 端口冲突
|
||||
|
||||
修改 `docker-compose.yml` 中的端口映射:
|
||||
```yaml
|
||||
ports:
|
||||
- "8081:8080" # 修改宿主机端口
|
||||
```
|
||||
131
docker/build.sh
Executable file
131
docker/build.sh
Executable file
@ -0,0 +1,131 @@
|
||||
#!/bin/bash
|
||||
# GitDataAI Docker Build Script
|
||||
|
||||
set -e
|
||||
|
||||
# Get version from Cargo.toml
|
||||
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
||||
CARGO_VERSION=$(grep -m1 'version' "${PROJECT_ROOT}/Cargo.toml" | sed 's/.*"\(.*\)".*/\1/')
|
||||
|
||||
# Configuration
|
||||
REGISTRY=${REGISTRY:-""}
|
||||
TAG=${TAG:-"${CARGO_VERSION:-latest}"}
|
||||
PLATFORM=${PLATFORM:-"linux/amd64"}
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Services to build
|
||||
SERVICES=("gitdata" "email" "gitpod" "gitsync" "migrate" "web")
|
||||
|
||||
# Function to print colored output
|
||||
log_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
log_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Function to build a service
|
||||
build_service() {
|
||||
local service=$1
|
||||
local dockerfile="${PROJECT_ROOT}/docker/${service}.Dockerfile"
|
||||
local image_name="gitdata-${service}"
|
||||
|
||||
# Add registry prefix if set
|
||||
if [ -n "$REGISTRY" ]; then
|
||||
image_name="${REGISTRY}/${image_name}"
|
||||
fi
|
||||
|
||||
log_info "Building ${service}..."
|
||||
|
||||
if [ ! -f "$dockerfile" ]; then
|
||||
log_error "Dockerfile not found: ${dockerfile}"
|
||||
return 1
|
||||
fi
|
||||
|
||||
docker build \
|
||||
-f "$dockerfile" \
|
||||
-t "${image_name}:${TAG}" \
|
||||
--platform "$PLATFORM" \
|
||||
"$PROJECT_ROOT"
|
||||
|
||||
log_info "Successfully built ${image_name}:${TAG}"
|
||||
}
|
||||
|
||||
# Parse command line arguments
|
||||
BUILD_SERVICES=()
|
||||
BUILD_ALL=true
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--tag|-t)
|
||||
TAG="$2"
|
||||
shift 2
|
||||
;;
|
||||
--registry|-r)
|
||||
REGISTRY="$2"
|
||||
shift 2
|
||||
;;
|
||||
--platform|-p)
|
||||
PLATFORM="$2"
|
||||
shift 2
|
||||
;;
|
||||
--help|-h)
|
||||
echo "Usage: $0 [OPTIONS] [SERVICE...]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " -t, --tag TAG Docker image tag (default: latest)"
|
||||
echo " -r, --registry REG Docker registry prefix"
|
||||
echo " -p, --platform PLAT Target platform (default: linux/amd64)"
|
||||
echo " -h, --help Show this help message"
|
||||
echo ""
|
||||
echo "Services:"
|
||||
echo " gitdata Main API service"
|
||||
echo " email Email service"
|
||||
echo " gitpod GitPod service"
|
||||
echo " gitsync GitSync service"
|
||||
echo " migrate Database migration"
|
||||
echo " web Web frontend"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " $0 # Build all services"
|
||||
echo " $0 gitdata web # Build specific services"
|
||||
echo " $0 -t v1.0.0 -r registry.com # Build with custom tag and registry"
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
BUILD_SERVICES+=("$1")
|
||||
BUILD_ALL=false
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Build services
|
||||
log_info "Starting Docker build..."
|
||||
log_info "Registry: ${REGISTRY:-none}"
|
||||
log_info "Tag: ${TAG}"
|
||||
log_info "Platform: ${PLATFORM}"
|
||||
|
||||
if [ "$BUILD_ALL" = true ]; then
|
||||
log_info "Building all services..."
|
||||
for service in "${SERVICES[@]}"; do
|
||||
build_service "$service"
|
||||
done
|
||||
else
|
||||
log_info "Building specified services: ${BUILD_SERVICES[*]}"
|
||||
for service in "${BUILD_SERVICES[@]}"; do
|
||||
build_service "$service"
|
||||
done
|
||||
fi
|
||||
|
||||
log_info "Build completed successfully!"
|
||||
203
docker/docker-compose.yml
Normal file
203
docker/docker-compose.yml
Normal file
@ -0,0 +1,203 @@
|
||||
# GitDataAI Docker Compose
|
||||
# Full stack deployment configuration
|
||||
|
||||
services:
|
||||
# PostgreSQL Database
|
||||
postgres:
|
||||
image: postgres:16-alpine
|
||||
container_name: gitdata-postgres
|
||||
environment:
|
||||
POSTGRES_USER: ${POSTGRES_USER:-gitdata}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-gitdata123}
|
||||
POSTGRES_DB: ${POSTGRES_DB:-app}
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
ports:
|
||||
- "5432:5432"
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-gitdata}"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
restart: unless-stopped
|
||||
|
||||
# Redis Cluster
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: gitdata-redis
|
||||
ports:
|
||||
- "6379:6379"
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
restart: unless-stopped
|
||||
|
||||
# Qdrant Vector Database
|
||||
qdrant:
|
||||
image: qdrant/qdrant:latest
|
||||
container_name: gitdata-qdrant
|
||||
ports:
|
||||
- "6333:6333"
|
||||
volumes:
|
||||
- qdrant_data:/qdrant/storage
|
||||
restart: unless-stopped
|
||||
|
||||
# NATS Message Queue
|
||||
nats:
|
||||
image: nats:alpine
|
||||
container_name: gitdata-nats
|
||||
ports:
|
||||
- "4222:4222"
|
||||
- "8222:8222"
|
||||
command: "--jetstream"
|
||||
restart: unless-stopped
|
||||
|
||||
# MinIO S3 Storage
|
||||
minio:
|
||||
image: minio/minio:latest
|
||||
container_name: gitdata-minio
|
||||
command: server /data --console-address ":9001"
|
||||
environment:
|
||||
MINIO_ROOT_USER: ${MINIO_ROOT_USER:-admin}
|
||||
MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-mysecret123}
|
||||
ports:
|
||||
- "9000:9000"
|
||||
- "9001:9001"
|
||||
volumes:
|
||||
- minio_data:/data
|
||||
restart: unless-stopped
|
||||
|
||||
# Database Migration
|
||||
migrate:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/migrate.Dockerfile
|
||||
container_name: gitdata-migrate
|
||||
environment:
|
||||
DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app}
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
restart: "no"
|
||||
|
||||
# GitData Main API Service
|
||||
gitdata:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/gitdata.Dockerfile
|
||||
container_name: gitdata-api
|
||||
environment:
|
||||
APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app}
|
||||
APP_REDIS_URLS: redis://redis:6379
|
||||
APP_QDRANT_URL: http://qdrant:6333/
|
||||
NATS_URL: nats://nats:4222
|
||||
APP_STORAGE_S3_ENDPOINT_URL: http://minio:9000
|
||||
APP_STORAGE_S3_ACCESS_KEY_ID: ${MINIO_ROOT_USER:-admin}
|
||||
APP_STORAGE_S3_SECRET_ACCESS_KEY: ${MINIO_ROOT_PASSWORD:-mysecret123}
|
||||
ports:
|
||||
- "8080:8080"
|
||||
- "5023:5023"
|
||||
- "5030:5030"
|
||||
- "5022:5022"
|
||||
volumes:
|
||||
- gitdata_repos:/app/data/repos
|
||||
- gitdata_files:/app/data/files
|
||||
- gitdata_avatar:/app/data/avatar
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
qdrant:
|
||||
condition: service_started
|
||||
nats:
|
||||
condition: service_started
|
||||
minio:
|
||||
condition: service_started
|
||||
migrate:
|
||||
condition: service_completed_successfully
|
||||
restart: unless-stopped
|
||||
|
||||
# Email Service
|
||||
email:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/email.Dockerfile
|
||||
container_name: gitdata-email
|
||||
environment:
|
||||
APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app}
|
||||
APP_REDIS_URLS: redis://redis:6379
|
||||
NATS_URL: nats://nats:4222
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
nats:
|
||||
condition: service_started
|
||||
restart: unless-stopped
|
||||
|
||||
# GitPod Service
|
||||
gitpod:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/gitpod.Dockerfile
|
||||
container_name: gitdata-gitpod
|
||||
environment:
|
||||
APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app}
|
||||
APP_REDIS_URLS: redis://redis:6379
|
||||
ports:
|
||||
- "5082:5082"
|
||||
volumes:
|
||||
- gitdata_repos:/app/data/repos
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
|
||||
# GitSync Service
|
||||
gitsync:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/gitsync.Dockerfile
|
||||
container_name: gitdata-gitsync
|
||||
environment:
|
||||
APP_DATABASE_URL: postgres://${POSTGRES_USER:-gitdata}:${POSTGRES_PASSWORD:-gitdata123}@postgres:5432/${POSTGRES_DB:-app}
|
||||
APP_REDIS_URLS: redis://redis:6379
|
||||
ports:
|
||||
- "5083:5083"
|
||||
volumes:
|
||||
- gitdata_repos:/app/data/repos
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
|
||||
# Web Frontend
|
||||
web:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: docker/web.Dockerfile
|
||||
container_name: gitdata-web
|
||||
ports:
|
||||
- "80:80"
|
||||
depends_on:
|
||||
- gitdata
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
redis_data:
|
||||
qdrant_data:
|
||||
minio_data:
|
||||
gitdata_repos:
|
||||
gitdata_files:
|
||||
gitdata_avatar:
|
||||
74
docker/gitdata.Dockerfile
Normal file
74
docker/gitdata.Dockerfile
Normal file
@ -0,0 +1,74 @@
|
||||
# GitDataAI Backend - GitData Service
|
||||
# Multi-stage build for Rust application
|
||||
|
||||
# Stage 1: Build the application
|
||||
FROM rust:1.96-bookworm AS builder
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
libpq-dev \
|
||||
cmake \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy workspace files
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
COPY app/ app/
|
||||
COPY lib/ lib/
|
||||
|
||||
# Build the application in release mode
|
||||
RUN cargo build --release --bin gitdata
|
||||
|
||||
# Stage 2: Create runtime image
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
libssl3 \
|
||||
libpq5 \
|
||||
ca-certificates \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd -r -s /bin/false appuser
|
||||
|
||||
# Create directories
|
||||
RUN mkdir -p /app/data/repos \
|
||||
/app/data/files \
|
||||
/app/data/avatar \
|
||||
/app/logs \
|
||||
&& chown -R appuser:appuser /app
|
||||
|
||||
# Copy binary from builder
|
||||
COPY --from=builder /app/target/release/gitdata /app/gitdata
|
||||
|
||||
# Set ownership
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Expose ports
|
||||
# API port
|
||||
EXPOSE 8080
|
||||
# Git HTTP port
|
||||
EXPOSE 5023
|
||||
# Git RPC port
|
||||
EXPOSE 5030
|
||||
# SSH port
|
||||
EXPOSE 5022
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8080/health || exit 1
|
||||
|
||||
# Run the application
|
||||
CMD ["./gitdata"]
|
||||
65
docker/gitpod.Dockerfile
Normal file
65
docker/gitpod.Dockerfile
Normal file
@ -0,0 +1,65 @@
|
||||
# GitDataAI Backend - GitPod Service
|
||||
# Multi-stage build for Rust application
|
||||
|
||||
# Stage 1: Build the application
|
||||
FROM rust:1.96-bookworm AS builder
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
libpq-dev \
|
||||
cmake \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy workspace files
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
COPY app/ app/
|
||||
COPY lib/ lib/
|
||||
|
||||
# Build the application in release mode
|
||||
RUN cargo build --release --bin gitpod
|
||||
|
||||
# Stage 2: Create runtime image
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
libssl3 \
|
||||
libpq5 \
|
||||
ca-certificates \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd -r -s /bin/false appuser
|
||||
|
||||
# Create directories
|
||||
RUN mkdir -p /app/data/repos \
|
||||
/app/logs \
|
||||
&& chown -R appuser:appuser /app
|
||||
|
||||
# Copy binary from builder
|
||||
COPY --from=builder /app/target/release/gitpod /app/gitpod
|
||||
|
||||
# Set ownership
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Expose port
|
||||
EXPOSE 5082
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:5082/health || exit 1
|
||||
|
||||
# Run the application
|
||||
CMD ["./gitpod"]
|
||||
65
docker/gitsync.Dockerfile
Normal file
65
docker/gitsync.Dockerfile
Normal file
@ -0,0 +1,65 @@
|
||||
# GitDataAI Backend - GitSync Service
|
||||
# Multi-stage build for Rust application
|
||||
|
||||
# Stage 1: Build the application
|
||||
FROM rust:1.96-bookworm AS builder
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
libpq-dev \
|
||||
cmake \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy workspace files
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
COPY app/ app/
|
||||
COPY lib/ lib/
|
||||
|
||||
# Build the application in release mode
|
||||
RUN cargo build --release --bin gitsync
|
||||
|
||||
# Stage 2: Create runtime image
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
libssl3 \
|
||||
libpq5 \
|
||||
ca-certificates \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd -r -s /bin/false appuser
|
||||
|
||||
# Create directories
|
||||
RUN mkdir -p /app/data/repos \
|
||||
/app/logs \
|
||||
&& chown -R appuser:appuser /app
|
||||
|
||||
# Copy binary from builder
|
||||
COPY --from=builder /app/target/release/gitsync /app/gitsync
|
||||
|
||||
# Set ownership
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Expose health check port
|
||||
EXPOSE 5083
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:5083/health || exit 1
|
||||
|
||||
# Run the application
|
||||
CMD ["./gitsync"]
|
||||
58
docker/migrate.Dockerfile
Normal file
58
docker/migrate.Dockerfile
Normal file
@ -0,0 +1,58 @@
|
||||
# GitDataAI Database Migration Dockerfile
|
||||
# Multi-stage build for Rust migration tool
|
||||
|
||||
# Stage 1: Build the application
|
||||
FROM rust:1.96-bookworm AS builder
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
libpq-dev \
|
||||
cmake \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy workspace files
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
COPY app/ app/
|
||||
COPY lib/ lib/
|
||||
|
||||
# Build the migration binary
|
||||
RUN cargo build --release --bin migrate
|
||||
|
||||
# Stage 2: Create runtime image
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
libssl3 \
|
||||
libpq5 \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd -r -s /bin/false appuser
|
||||
|
||||
# Create app directory
|
||||
RUN mkdir -p /app && chown -R appuser:appuser /app
|
||||
|
||||
# Copy binary from builder
|
||||
COPY --from=builder /app/target/release/migrate /app/migrate
|
||||
|
||||
# Copy migration files
|
||||
COPY --from=builder /app/lib/migrate/sql /app/sql
|
||||
|
||||
# Set ownership
|
||||
RUN chown -R appuser:appuser /app
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Run migrations by default
|
||||
CMD ["./migrate", "up"]
|
||||
75
docker/nginx.conf
Normal file
75
docker/nginx.conf
Normal file
@ -0,0 +1,75 @@
|
||||
server {
|
||||
listen 80;
|
||||
server_name localhost;
|
||||
|
||||
# Gzip compression
|
||||
gzip on;
|
||||
gzip_vary on;
|
||||
gzip_min_length 1024;
|
||||
gzip_proxied any;
|
||||
gzip_comp_level 6;
|
||||
gzip_types
|
||||
text/plain
|
||||
text/css
|
||||
text/xml
|
||||
text/javascript
|
||||
application/json
|
||||
application/javascript
|
||||
application/xml
|
||||
application/rss+xml
|
||||
image/svg+xml;
|
||||
|
||||
# Security headers
|
||||
add_header X-Frame-Options "SAMEORIGIN" always;
|
||||
add_header X-Content-Type-Options "nosniff" always;
|
||||
add_header X-XSS-Protection "1; mode=block" always;
|
||||
add_header Referrer-Policy "strict-origin-when-cross-origin" always;
|
||||
|
||||
# Root directory
|
||||
root /usr/share/nginx/html;
|
||||
index index.html;
|
||||
|
||||
# Enable static asset caching
|
||||
location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {
|
||||
expires 1y;
|
||||
add_header Cache-Control "public, immutable";
|
||||
try_files $uri =404;
|
||||
}
|
||||
|
||||
# API proxy (if needed)
|
||||
location /api/ {
|
||||
proxy_pass http://gitdata:8080/api/;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection 'upgrade';
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_cache_bypass $http_upgrade;
|
||||
}
|
||||
|
||||
# Socket.IO proxy
|
||||
location /socket.io/ {
|
||||
proxy_pass http://gitdata:8080/socket.io/;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
}
|
||||
|
||||
# SPA fallback
|
||||
location / {
|
||||
try_files $uri $uri/ /index.html;
|
||||
}
|
||||
|
||||
# Health check endpoint
|
||||
location /health {
|
||||
access_log off;
|
||||
return 200 'OK';
|
||||
add_header Content-Type text/plain;
|
||||
}
|
||||
}
|
||||
128
docker/push.sh
Executable file
128
docker/push.sh
Executable file
@ -0,0 +1,128 @@
|
||||
#!/bin/bash
|
||||
# GitDataAI Docker Push Script
|
||||
|
||||
set -e
|
||||
|
||||
# Get version from Cargo.toml
|
||||
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
||||
CARGO_VERSION=$(grep -m1 'version' "${PROJECT_ROOT}/Cargo.toml" | sed 's/.*"\(.*\)".*/\1/')
|
||||
|
||||
# Configuration
|
||||
REGISTRY=${REGISTRY:-""}
|
||||
TAG=${TAG:-"${CARGO_VERSION:-latest}"}
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Services to push
|
||||
SERVICES=("gitdata" "email" "gitpod" "gitsync" "migrate" "web")
|
||||
|
||||
# Function to print colored output
|
||||
log_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
log_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Function to push a service
|
||||
push_service() {
|
||||
local service=$1
|
||||
local image_name="gitdata-${service}"
|
||||
|
||||
# Add registry prefix if set
|
||||
if [ -n "$REGISTRY" ]; then
|
||||
image_name="${REGISTRY}/${image_name}"
|
||||
fi
|
||||
|
||||
log_info "Pushing ${service}..."
|
||||
|
||||
# Check if image exists locally
|
||||
if ! docker image inspect "${image_name}:${TAG}" > /dev/null 2>&1; then
|
||||
log_error "Image not found: ${image_name}:${TAG}"
|
||||
log_error "Please build the image first with: ./build.sh ${service}"
|
||||
return 1
|
||||
fi
|
||||
|
||||
docker push "${image_name}:${TAG}"
|
||||
|
||||
log_info "Successfully pushed ${image_name}:${TAG}"
|
||||
}
|
||||
|
||||
# Parse command line arguments
|
||||
PUSH_SERVICES=()
|
||||
PUSH_ALL=true
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--tag|-t)
|
||||
TAG="$2"
|
||||
shift 2
|
||||
;;
|
||||
--registry|-r)
|
||||
REGISTRY="$2"
|
||||
shift 2
|
||||
;;
|
||||
--help|-h)
|
||||
echo "Usage: $0 [OPTIONS] [SERVICE...]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " -t, --tag TAG Docker image tag (default: latest)"
|
||||
echo " -r, --registry REG Docker registry prefix (required)"
|
||||
echo " -h, --help Show this help message"
|
||||
echo ""
|
||||
echo "Services:"
|
||||
echo " gitdata Main API service"
|
||||
echo " email Email service"
|
||||
echo " gitpod GitPod service"
|
||||
echo " gitsync GitSync service"
|
||||
echo " migrate Database migration"
|
||||
echo " web Web frontend"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " $0 -r registry.com # Push all services"
|
||||
echo " $0 -r registry.com gitdata web # Push specific services"
|
||||
echo " $0 -r registry.com -t v1.0.0 # Push with custom tag"
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
PUSH_SERVICES+=("$1")
|
||||
PUSH_ALL=false
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate registry
|
||||
if [ -z "$REGISTRY" ]; then
|
||||
log_error "Registry is required. Use -r or --registry to specify."
|
||||
echo "Example: $0 -r registry.com"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Push services
|
||||
log_info "Starting Docker push..."
|
||||
log_info "Registry: ${REGISTRY}"
|
||||
log_info "Tag: ${TAG}"
|
||||
|
||||
if [ "$PUSH_ALL" = true ]; then
|
||||
log_info "Pushing all services..."
|
||||
for service in "${SERVICES[@]}"; do
|
||||
push_service "$service"
|
||||
done
|
||||
else
|
||||
log_info "Pushing specified services: ${PUSH_SERVICES[*]}"
|
||||
for service in "${PUSH_SERVICES[@]}"; do
|
||||
push_service "$service"
|
||||
done
|
||||
fi
|
||||
|
||||
log_info "Push completed successfully!"
|
||||
62
docker/web.Dockerfile
Normal file
62
docker/web.Dockerfile
Normal file
@ -0,0 +1,62 @@
|
||||
# GitDataAI Frontend Dockerfile
|
||||
# Multi-stage build for React application with Bun
|
||||
|
||||
# Stage 1: Build the application
|
||||
FROM node:24-bookworm AS builder
|
||||
|
||||
# Install bun
|
||||
RUN npm install -g bun
|
||||
|
||||
# Create app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy package files
|
||||
COPY package.json bun.lock ./
|
||||
|
||||
# Install dependencies
|
||||
RUN bun install --frozen-lockfile
|
||||
|
||||
# Copy source code
|
||||
COPY src/ src/
|
||||
COPY public/ public/
|
||||
COPY index.html ./
|
||||
COPY vite.config.ts ./
|
||||
COPY tsconfig*.json ./
|
||||
COPY eslint.config.js ./
|
||||
COPY components.json ./
|
||||
COPY orval.config.ts ./
|
||||
|
||||
# Build the application
|
||||
RUN bun run build
|
||||
|
||||
# Stage 2: Create runtime image with Nginx
|
||||
FROM nginx:alpine
|
||||
|
||||
# Copy custom nginx configuration
|
||||
COPY docker/nginx.conf /etc/nginx/conf.d/default.conf
|
||||
|
||||
# Copy built assets from builder
|
||||
COPY --from=builder /app/dist /usr/share/nginx/html
|
||||
|
||||
# Create non-root user
|
||||
RUN adduser -D -S -h /var/cache/nginx -s /sbin/nologin -G nginx appuser
|
||||
|
||||
# Set ownership
|
||||
RUN chown -R appuser:nginx /var/cache/nginx \
|
||||
&& chown -R appuser:nginx /var/log/nginx \
|
||||
&& chown -R appuser:nginx /etc/nginx/conf.d \
|
||||
&& touch /var/run/nginx.pid \
|
||||
&& chown -R appuser:nginx /var/run/nginx.pid
|
||||
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
# Expose port
|
||||
EXPOSE 80
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD wget --no-verbose --tries=1 --spider http://localhost:80/ || exit 1
|
||||
|
||||
# Start Nginx
|
||||
CMD ["nginx", "-g", "daemon off;"]
|
||||
37
lib/ai/Cargo.toml
Normal file
37
lib/ai/Cargo.toml
Normal file
@ -0,0 +1,37 @@
|
||||
[package]
|
||||
name = "ai"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
readme.workspace = true
|
||||
homepage.workspace = true
|
||||
license.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
documentation.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "lib.rs"
|
||||
name = "ai"
|
||||
[dependencies]
|
||||
rig-core = { workspace = true, features = ["derive"] }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
tokio-util = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
config = { workspace = true }
|
||||
cache = { workspace = true }
|
||||
db = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
qdrant-client = { workspace = true, features = ["serde"] }
|
||||
async-trait = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
uuid = { workspace = true, features = ["v4", "v5", "serde"] }
|
||||
reqwest = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
[lints]
|
||||
workspace = true
|
||||
543
lib/ai/agent/agent.rs
Normal file
543
lib/ai/agent/agent.rs
Normal file
@ -0,0 +1,543 @@
|
||||
use futures::StreamExt;
|
||||
use rig::agent::AgentBuilder;
|
||||
use rig::client::CompletionClient;
|
||||
use rig::streaming::StreamingPrompt;
|
||||
use rig::tool::ToolDyn;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::config::AgentConfig;
|
||||
use super::helpers::{build_input_string, check_token_budget, estimate_tokens};
|
||||
use super::hooks::{HookChain, HookLlmResponse, HookMessage, HookToolDef, ToolCallOutcome, ToolGuardrailDecision};
|
||||
use super::persistence::ActiveAgentRun;
|
||||
use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord};
|
||||
use super::subagent::run_experts;
|
||||
use super::RigStreamChunk;
|
||||
use crate::client::AiClient;
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
pub struct RigAgent {
|
||||
pub client: AiClient,
|
||||
pub config: AgentConfig,
|
||||
pub hooks: HookChain,
|
||||
}
|
||||
|
||||
impl RigAgent {
|
||||
pub fn new(client: AiClient, config: AgentConfig) -> AiResult<Self> {
|
||||
config.validate()?;
|
||||
Ok(Self {
|
||||
client,
|
||||
config,
|
||||
hooks: HookChain::empty(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_hooks(mut self, hooks: HookChain) -> Self {
|
||||
self.hooks = hooks;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &AgentConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub async fn chat(
|
||||
&self,
|
||||
request: AgentRequest,
|
||||
tools: Vec<Box<dyn ToolDyn>>,
|
||||
) -> AiResult<String> {
|
||||
let (mut rx, handle) = self.run(request, tools);
|
||||
tokio::spawn(async move {
|
||||
while rx.recv().await.is_some() {}
|
||||
});
|
||||
let result = handle.await.map_err(|_| {
|
||||
AiError::Response("agent task panicked".to_string())
|
||||
})?;
|
||||
result.map(|r| r.output)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn run(
|
||||
&self,
|
||||
request: AgentRequest,
|
||||
tools: Vec<Box<dyn ToolDyn>>,
|
||||
) -> (
|
||||
tokio::sync::mpsc::Receiver<RigStreamChunk>,
|
||||
tokio::task::JoinHandle<AiResult<AgentResult>>,
|
||||
) {
|
||||
let (tx, rx) = mpsc::channel::<RigStreamChunk>(256);
|
||||
|
||||
let model_name = self.config.model.clone();
|
||||
let max_iterations = self.config.max_iterations;
|
||||
let client = self.client.llm_client().clone();
|
||||
let ai_client = self.client.clone();
|
||||
let agent_config = self.config.clone();
|
||||
let system_prompt = self.config.system_prompt.clone();
|
||||
let temperature = self.config.temperature;
|
||||
let max_completion_tokens = self.config.max_completion_tokens;
|
||||
let max_total_tokens = self.config.max_total_tokens_per_run;
|
||||
let cancellation = request.cancellation_token.clone();
|
||||
let timeout = request.timeout;
|
||||
let hooks = self.hooks.clone();
|
||||
|
||||
let filtered_tools: Vec<Box<dyn ToolDyn>> = tools
|
||||
.into_iter()
|
||||
.filter(|tool| self.config.is_tool_exposed(&tool.name()))
|
||||
.collect();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
execute_agent_run(
|
||||
client,
|
||||
model_name,
|
||||
system_prompt,
|
||||
request,
|
||||
filtered_tools,
|
||||
max_iterations,
|
||||
ai_client,
|
||||
agent_config,
|
||||
temperature,
|
||||
max_completion_tokens,
|
||||
max_total_tokens,
|
||||
cancellation,
|
||||
timeout,
|
||||
hooks,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
(rx, handle)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
|
||||
async fn execute_agent_run(
|
||||
client: rig::providers::openai::Client,
|
||||
model_name: String,
|
||||
system_prompt: String,
|
||||
request: AgentRequest,
|
||||
tools: Vec<Box<dyn ToolDyn>>,
|
||||
max_iterations: usize,
|
||||
ai_client: AiClient,
|
||||
agent_config: AgentConfig,
|
||||
temperature: Option<f64>,
|
||||
max_completion_tokens: Option<u64>,
|
||||
max_total_tokens: Option<i64>,
|
||||
cancellation: Option<CancellationToken>,
|
||||
timeout: Option<std::time::Duration>,
|
||||
hooks: HookChain,
|
||||
tx: mpsc::Sender<RigStreamChunk>,
|
||||
) -> AiResult<AgentResult> {
|
||||
if let Some(ref ctx) = request.run_context {
|
||||
let _ = hooks.run_session_start(ctx).await;
|
||||
}
|
||||
|
||||
let model = client.completion_model(&model_name);
|
||||
let mut agent_builder = AgentBuilder::new(model)
|
||||
.preamble(&system_prompt)
|
||||
.tools(tools)
|
||||
.default_max_turns(max_iterations);
|
||||
|
||||
if let Some(temp) = temperature {
|
||||
agent_builder = agent_builder.temperature(temp);
|
||||
}
|
||||
if let Some(mt) = max_completion_tokens {
|
||||
agent_builder = agent_builder.max_tokens(mt);
|
||||
}
|
||||
|
||||
let agent = agent_builder.build();
|
||||
let mut input = build_input_string(&request);
|
||||
|
||||
// ---- SubAgent execution ----
|
||||
let expert_outputs = if !request.experts.is_empty() {
|
||||
let run = ActiveAgentRun {
|
||||
conversation_id: request.run_context.as_ref().and_then(|c| c.conversation_id),
|
||||
message_id: None,
|
||||
invocation_id: request.run_context.as_ref().and_then(|c| c.invocation_id),
|
||||
session_id: request.run_context.as_ref().and_then(|c| c.session_id),
|
||||
user_id: request.run_context.as_ref().and_then(|c| c.user_id),
|
||||
started_at: std::time::Instant::now(),
|
||||
current_step: 0,
|
||||
};
|
||||
let realtime = request.run_context.as_ref().and_then(|c| c.realtime.as_ref());
|
||||
|
||||
// Notify frontend that subagents are starting.
|
||||
for expert in &request.experts {
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::SubagentStarted {
|
||||
subagent_id: expert.id.clone(),
|
||||
role: expert.role.clone(),
|
||||
task: expert.task.clone(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
match run_experts(&ai_client, &agent_config, &request.experts, realtime, &run).await {
|
||||
Ok(outputs) => {
|
||||
for out in &outputs {
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::SubagentCompleted {
|
||||
subagent_id: out.id.clone(),
|
||||
role: out.role.clone(),
|
||||
task: out.task.clone(),
|
||||
output: out.output.clone(),
|
||||
})
|
||||
.await;
|
||||
input.push_str(&format!(
|
||||
"\n--- Subagent: {} (role: {}) ---\nTask: {}\nResult: {}\n",
|
||||
out.id, out.role, out.task, out.output
|
||||
));
|
||||
}
|
||||
outputs
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "subagent execution failed, continuing without expert inputs");
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::SubagentFailed {
|
||||
error: e.to_string(),
|
||||
})
|
||||
.await;
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let estimated_input_tokens = estimate_tokens(&input);
|
||||
|
||||
if let Some(limit) = max_total_tokens
|
||||
&& estimated_input_tokens > limit as u64
|
||||
{
|
||||
return Err(AiError::TokenBudgetExceeded {
|
||||
estimated: estimated_input_tokens,
|
||||
limit,
|
||||
});
|
||||
}
|
||||
|
||||
if !hooks.is_empty() {
|
||||
let hook_messages: Vec<HookMessage> = request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|m| HookMessage {
|
||||
role: match m {
|
||||
super::request::AgentMessage::User(_) => "user".to_string(),
|
||||
super::request::AgentMessage::Assistant(_) => {
|
||||
"assistant".to_string()
|
||||
}
|
||||
},
|
||||
content: match m {
|
||||
super::request::AgentMessage::User(c) => Some(c.clone()),
|
||||
super::request::AgentMessage::Assistant(c) => {
|
||||
Some(c.clone())
|
||||
}
|
||||
},
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect();
|
||||
let hook_tools: Vec<HookToolDef> = Vec::new();
|
||||
let _ = hooks.run_pre_llm_call(&hook_messages, &hook_tools).await;
|
||||
}
|
||||
|
||||
let stream_future = agent
|
||||
.stream_prompt(&input)
|
||||
.with_history(Vec::<rig::completion::Message>::new())
|
||||
.multi_turn(max_iterations);
|
||||
|
||||
let stream = if let Some(dur) = timeout {
|
||||
match tokio::time::timeout(dur, stream_future).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_elapsed) => {
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::Failed {
|
||||
error: format!("agent timed out after {}s", dur.as_secs()),
|
||||
})
|
||||
.await;
|
||||
return Err(AiError::Timeout {
|
||||
seconds: dur.as_secs(),
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
stream_future.await
|
||||
};
|
||||
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut steps = Vec::new();
|
||||
let mut delta_index = 0usize;
|
||||
let mut current_step_tool_calls: Vec<ToolCallRecord> = Vec::new();
|
||||
let mut current_step_assistant = String::new();
|
||||
let mut current_step_reasoning = String::new();
|
||||
let mut accumulated_output_chars: usize = 0;
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::Failed {
|
||||
error: "cancelled".to_string(),
|
||||
})
|
||||
.await;
|
||||
return Err(AiError::Response("agent run cancelled".to_string()));
|
||||
}
|
||||
|
||||
if let Some(limit) = max_total_tokens
|
||||
&& check_token_budget(estimated_input_tokens, accumulated_output_chars, limit)
|
||||
{
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::Failed {
|
||||
error: format!("token budget exceeded: limit {limit}"),
|
||||
})
|
||||
.await;
|
||||
return Err(AiError::TokenBudgetExceeded {
|
||||
estimated: estimated_input_tokens
|
||||
+ (accumulated_output_chars as f64 / 2.5).ceil() as u64,
|
||||
limit,
|
||||
});
|
||||
}
|
||||
|
||||
match item {
|
||||
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
||||
rig::streaming::StreamedAssistantContent::Text(text),
|
||||
)) => {
|
||||
accumulated_output_chars += text.text.chars().count();
|
||||
current_step_assistant.push_str(&text.text);
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::TextDelta {
|
||||
index: delta_index,
|
||||
content: text.text.clone(),
|
||||
})
|
||||
.await;
|
||||
delta_index += 1;
|
||||
}
|
||||
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
||||
rig::streaming::StreamedAssistantContent::Reasoning(reasoning),
|
||||
)) => {
|
||||
for part in &reasoning.content {
|
||||
if let rig::completion::message::ReasoningContent::Text {
|
||||
text, ..
|
||||
} = part
|
||||
{
|
||||
accumulated_output_chars += text.chars().count();
|
||||
current_step_reasoning.push_str(text);
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::Thinking {
|
||||
index: delta_index,
|
||||
content: text.clone(),
|
||||
})
|
||||
.await;
|
||||
delta_index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
||||
rig::streaming::StreamedAssistantContent::ReasoningDelta {
|
||||
reasoning, ..
|
||||
},
|
||||
)) => {
|
||||
accumulated_output_chars += reasoning.chars().count();
|
||||
current_step_reasoning.push_str(&reasoning);
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::Thinking {
|
||||
index: delta_index,
|
||||
content: reasoning.clone(),
|
||||
})
|
||||
.await;
|
||||
delta_index += 1;
|
||||
}
|
||||
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
||||
rig::streaming::StreamedAssistantContent::ToolCall {
|
||||
tool_call,
|
||||
internal_call_id: _,
|
||||
},
|
||||
)) => {
|
||||
let args = match &tool_call.function.arguments {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
v => serde_json::to_string(v).unwrap_or_default(),
|
||||
};
|
||||
accumulated_output_chars += args.chars().count();
|
||||
|
||||
let tool_name = tool_call.function.name.clone();
|
||||
let tool_args: serde_json::Value =
|
||||
serde_json::from_str(&args).unwrap_or_default();
|
||||
|
||||
if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await {
|
||||
match decision {
|
||||
ToolGuardrailDecision::Allow => {}
|
||||
ToolGuardrailDecision::Block { reason } => {
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::ToolCallFinished {
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_name.clone(),
|
||||
output: format!("blocked: {reason}"),
|
||||
error: Some(reason),
|
||||
})
|
||||
.await;
|
||||
current_step_tool_calls.push(ToolCallRecord {
|
||||
id: tool_call.id.clone(),
|
||||
name: tool_name.clone(),
|
||||
arguments: tool_args.clone(),
|
||||
output: None,
|
||||
error: Some("blocked by guardrail".to_string()),
|
||||
elapsed_ms: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
ToolGuardrailDecision::RequireApproval { message } => {
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::ToolCallFinished {
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_name.clone(),
|
||||
output: format!("awaiting approval: {message}"),
|
||||
error: None,
|
||||
})
|
||||
.await;
|
||||
current_step_tool_calls.push(ToolCallRecord {
|
||||
id: tool_call.id.clone(),
|
||||
name: tool_name.clone(),
|
||||
arguments: tool_args.clone(),
|
||||
output: None,
|
||||
error: Some(format!("requires approval: {message}")),
|
||||
elapsed_ms: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::ToolCallStarted {
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_name.clone(),
|
||||
arguments: args.clone(),
|
||||
})
|
||||
.await;
|
||||
current_step_tool_calls.push(ToolCallRecord {
|
||||
id: tool_call.id.clone(),
|
||||
name: tool_name.clone(),
|
||||
arguments: tool_args.clone(),
|
||||
output: None,
|
||||
error: None,
|
||||
elapsed_ms: None,
|
||||
});
|
||||
}
|
||||
Ok(rig::agent::MultiTurnStreamItem::StreamUserItem(
|
||||
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
|
||||
)) => {
|
||||
let content =
|
||||
super::helpers::tool_result_content_to_string(&tool_result.content);
|
||||
accumulated_output_chars += content.chars().count();
|
||||
|
||||
if let Some(last) = current_step_tool_calls.last_mut()
|
||||
&& last.id == tool_result.id
|
||||
{
|
||||
last.output = Some(serde_json::from_str(&content).unwrap_or_default());
|
||||
}
|
||||
|
||||
let tool_name = current_step_tool_calls
|
||||
.last()
|
||||
.map(|tc| tc.name.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::ToolCallFinished {
|
||||
tool_call_id: tool_result.id.clone(),
|
||||
tool_name,
|
||||
output: content.clone(),
|
||||
error: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
if !hooks.is_empty() {
|
||||
let outcome = ToolCallOutcome {
|
||||
name: tool_result.id.clone(),
|
||||
arguments: serde_json::Value::Null,
|
||||
output: Some(serde_json::Value::String(content)),
|
||||
error: None,
|
||||
elapsed_ms: 0,
|
||||
};
|
||||
let _ = hooks.run_post_tool_call(&outcome).await;
|
||||
}
|
||||
}
|
||||
Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
|
||||
let usage = resp.usage();
|
||||
|
||||
if !current_step_tool_calls.is_empty() || !current_step_assistant.is_empty() {
|
||||
let reasoning = (!current_step_reasoning.is_empty())
|
||||
.then_some(std::mem::take(&mut current_step_reasoning));
|
||||
steps.push(AgentStep {
|
||||
index: steps.len(),
|
||||
assistant: (!current_step_assistant.is_empty())
|
||||
.then_some(std::mem::take(&mut current_step_assistant)),
|
||||
reasoning_content: reasoning,
|
||||
tool_calls: std::mem::take(&mut current_step_tool_calls),
|
||||
reflection: None,
|
||||
});
|
||||
}
|
||||
let output = steps
|
||||
.last()
|
||||
.and_then(|s| s.assistant.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if !hooks.is_empty() {
|
||||
let hook_response = HookLlmResponse {
|
||||
content: Some(output.clone()),
|
||||
tool_calls: None,
|
||||
input_tokens: usage.input_tokens,
|
||||
output_tokens: usage.output_tokens,
|
||||
finish_reason: None,
|
||||
};
|
||||
let _ = hooks.run_post_llm_call(&hook_response).await;
|
||||
}
|
||||
|
||||
info!(
|
||||
steps = steps.len(),
|
||||
input_tokens = usage.input_tokens,
|
||||
output_tokens = usage.output_tokens,
|
||||
"agent run completed"
|
||||
);
|
||||
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::Final {
|
||||
content: output.clone(),
|
||||
input_tokens: usage.input_tokens,
|
||||
output_tokens: usage.output_tokens,
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Some(ref ctx) = request.run_context {
|
||||
let _ = hooks.run_session_end(ctx, true).await;
|
||||
}
|
||||
|
||||
return Ok(AgentResult {
|
||||
output,
|
||||
steps,
|
||||
expert_outputs,
|
||||
input_tokens: usage.input_tokens as i64,
|
||||
output_tokens: usage.output_tokens as i64,
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
let err = format!("{e}");
|
||||
warn!(error = %err, "agent stream error");
|
||||
let _ = tx.send(RigStreamChunk::Failed { error: err }).await;
|
||||
|
||||
if let Some(ref ctx) = request.run_context {
|
||||
let _ = hooks.run_session_end(ctx, false).await;
|
||||
}
|
||||
return Err(AiError::Api(format!("{e}")));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Err(AiError::Response("agent stream ended without final response".to_string()))
|
||||
}
|
||||
|
||||
impl Clone for HookChain {
|
||||
fn clone(&self) -> Self {
|
||||
HookChain::empty()
|
||||
}
|
||||
}
|
||||
222
lib/ai/agent/compression.rs
Normal file
222
lib/ai/agent/compression.rs
Normal file
@ -0,0 +1,222 @@
|
||||
use crate::error::AiResult;
|
||||
|
||||
/// Compression strategy controlling when and how context compaction occurs.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CompressionStrategy {
|
||||
/// Token threshold that triggers compaction.
|
||||
pub threshold_tokens: i64,
|
||||
/// Target token count after compaction.
|
||||
pub target_tokens: i64,
|
||||
/// Number of recent message pairs to always preserve.
|
||||
pub preserve_last_n_pairs: usize,
|
||||
/// Optional model override for the compaction LLM call.
|
||||
pub summary_model: String,
|
||||
/// Reserve this many tokens for the compaction prompt itself.
|
||||
pub reserve_tokens: i64,
|
||||
/// Whether to generate branch summaries when forking.
|
||||
pub branch_summarization: bool,
|
||||
/// Custom instructions appended to the compaction prompt.
|
||||
pub custom_instructions: Option<String>,
|
||||
/// Maximum word count for compaction summaries.
|
||||
pub max_summary_words: usize,
|
||||
}
|
||||
|
||||
impl Default for CompressionStrategy {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
threshold_tokens: 64_000,
|
||||
target_tokens: 32_000,
|
||||
preserve_last_n_pairs: 4,
|
||||
summary_model: String::new(),
|
||||
reserve_tokens: 16_384,
|
||||
branch_summarization: true,
|
||||
custom_instructions: None,
|
||||
max_summary_words: 1500,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompressionStrategy {
|
||||
pub fn new(threshold_tokens: i64, target_tokens: i64) -> Self {
|
||||
Self {
|
||||
threshold_tokens,
|
||||
target_tokens,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_preserve_last(mut self, n: usize) -> Self {
|
||||
self.preserve_last_n_pairs = n;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_summary_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.summary_model = model.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_reserve_tokens(mut self, tokens: i64) -> Self {
|
||||
self.reserve_tokens = tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_branch_summarization(mut self, enabled: bool) -> Self {
|
||||
self.branch_summarization = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_custom_instructions(mut self, instructions: impl Into<String>) -> Self {
|
||||
self.custom_instructions = Some(instructions.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_summary_words(mut self, words: usize) -> Self {
|
||||
self.max_summary_words = words;
|
||||
self
|
||||
}
|
||||
|
||||
/// Check whether compaction should be triggered based on current token count.
|
||||
pub fn should_compact(&self, current_tokens: i64) -> bool {
|
||||
current_tokens >= self.threshold_tokens
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactionResult {
|
||||
pub summary: String,
|
||||
pub messages_compacted: usize,
|
||||
pub tokens_saved: i64,
|
||||
/// Whether this was a branch summary (vs. standard compaction).
|
||||
pub is_branch_summary: bool,
|
||||
}
|
||||
|
||||
impl CompactionResult {
|
||||
pub fn new(summary: String, messages_compacted: usize, tokens_saved: i64) -> Self {
|
||||
Self {
|
||||
summary,
|
||||
messages_compacted,
|
||||
tokens_saved,
|
||||
is_branch_summary: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn branch_summary(summary: String, entries_summarized: usize) -> Self {
|
||||
Self {
|
||||
summary,
|
||||
messages_compacted: entries_summarized,
|
||||
tokens_saved: 0,
|
||||
is_branch_summary: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the compaction prompt for standard context compression.
|
||||
pub fn build_compression_prompt(
|
||||
existing_summary: Option<&str>,
|
||||
messages_text: &str,
|
||||
) -> String {
|
||||
build_compression_prompt_with_options(existing_summary, messages_text, None, 1500)
|
||||
}
|
||||
|
||||
/// Build the compaction prompt with custom instructions and word limit.
|
||||
pub fn build_compression_prompt_with_options(
|
||||
existing_summary: Option<&str>,
|
||||
messages_text: &str,
|
||||
custom_instructions: Option<&str>,
|
||||
max_words: usize,
|
||||
) -> String {
|
||||
let custom = custom_instructions
|
||||
.map(|ci| format!("\n\nAdditional instructions: {ci}"))
|
||||
.unwrap_or_default();
|
||||
|
||||
if let Some(summary) = existing_summary {
|
||||
format!(
|
||||
"## Previous Summary\n{summary}\n\n## New Messages\n{messages_text}\n\n\
|
||||
Combine the previous summary and the new messages into a concise, \
|
||||
single-paragraph summary of the conversation. Preserve facts, \
|
||||
decisions, code snippets, and anything essential for continuing \
|
||||
work. Target up to {max_words} words.{custom} \
|
||||
Output ONLY the summary text, no preamble.",
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"## Conversation\n{messages_text}\n\n\
|
||||
Summarise the conversation above into a concise, single-paragraph \
|
||||
summary. Preserve facts, decisions, code snippets, and anything \
|
||||
essential for continuing work. Target up to {max_words} words.{custom} \
|
||||
Output ONLY the summary text, no preamble.",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a prompt for generating a branch summary.
|
||||
///
|
||||
/// Used when the user forks a conversation from a different point in the
|
||||
/// session tree. Summarizes the divergent branch so context is preserved.
|
||||
pub fn build_branch_summary_prompt(
|
||||
branch_messages: &str,
|
||||
custom_instructions: Option<&str>,
|
||||
) -> String {
|
||||
let custom = custom_instructions
|
||||
.map(|ci| format!("\n\nAdditional instructions: {ci}"))
|
||||
.unwrap_or_default();
|
||||
|
||||
format!(
|
||||
"## Branch Conversation\n{branch_messages}\n\n\
|
||||
Summarize the conversation branch above. This summary will be used \
|
||||
to preserve context when the user navigates away from this branch. \
|
||||
Focus on key decisions, unresolved questions, and important context.{custom} \
|
||||
Output ONLY the summary text, no preamble.",
|
||||
)
|
||||
}
|
||||
|
||||
/// Calculate how many messages to truncate to reach the target token count.
|
||||
pub fn estimate_truncation(
|
||||
message_token_counts: &[i64],
|
||||
current_total: i64,
|
||||
target: i64,
|
||||
preserve_last: usize,
|
||||
) -> AiResult<(usize, i64)> {
|
||||
let n = message_token_counts.len();
|
||||
if n <= preserve_last {
|
||||
return Ok((0, 0));
|
||||
}
|
||||
|
||||
let excess = (current_total - target).max(0);
|
||||
|
||||
let mut truncated = 0;
|
||||
let mut saved = 0i64;
|
||||
let limit = n - preserve_last;
|
||||
|
||||
for i in 0..limit {
|
||||
if saved >= excess {
|
||||
break;
|
||||
}
|
||||
saved += message_token_counts[i];
|
||||
truncated += 1;
|
||||
}
|
||||
|
||||
Ok((truncated, saved.min(excess)))
|
||||
}
|
||||
|
||||
/// Calculate compaction parameters for a given set of messages.
|
||||
///
|
||||
/// Returns `(messages_to_compact, tokens_saved)` where `messages_to_compact`
|
||||
/// is the count of oldest messages to summarize, and `tokens_saved` is the
|
||||
/// estimated token savings.
|
||||
pub fn plan_compaction(
|
||||
strategy: &CompressionStrategy,
|
||||
message_token_counts: &[i64],
|
||||
current_total: i64,
|
||||
) -> AiResult<(usize, i64)> {
|
||||
if !strategy.should_compact(current_total) {
|
||||
return Ok((0, 0));
|
||||
}
|
||||
|
||||
estimate_truncation(
|
||||
message_token_counts,
|
||||
current_total,
|
||||
strategy.target_tokens,
|
||||
strategy.preserve_last_n_pairs * 2, // pairs → individual messages
|
||||
)
|
||||
}
|
||||
217
lib/ai/agent/config.rs
Normal file
217
lib/ai/agent/config.rs
Normal file
@ -0,0 +1,217 @@
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
pub const DEFAULT_SYSTEM_PROMPT: &str = r#"You are a precise autonomous agent that executes tasks through tool calls.
|
||||
|
||||
## Core Principles
|
||||
- Use tools when they can materially improve correctness or efficiency
|
||||
- After each action, verify results and adjust approach if needed
|
||||
- Keep reasoning concise and focus on actionable outcomes
|
||||
- Return only the final useful answer to the user
|
||||
|
||||
## Workflow
|
||||
1. Analyze the request and plan your approach
|
||||
2. Execute actions using appropriate tools
|
||||
3. Review observations and verify assumptions
|
||||
4. Iterate until the task is complete
|
||||
5. Provide a clear, concise final response
|
||||
|
||||
## Title Generation
|
||||
If this is the first user message in a new conversation with a default title, you SHOULD call `set_conversation_title` as your first action to create a short, descriptive title (max 100 chars). This helps keep the conversation organized. Only do this once at the very beginning."#;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AgentConfig {
|
||||
pub model: String,
|
||||
pub provider: String,
|
||||
pub api_mode: String,
|
||||
pub system_prompt: String,
|
||||
pub max_iterations: usize,
|
||||
pub iteration_budget: usize,
|
||||
pub temperature: Option<f64>,
|
||||
pub max_completion_tokens: Option<u64>,
|
||||
pub max_total_tokens_per_run: Option<i64>,
|
||||
pub enabled_toolsets: Vec<String>,
|
||||
pub disabled_toolsets: Vec<String>,
|
||||
pub allowed_tools: Vec<String>,
|
||||
pub denied_tools: Vec<String>,
|
||||
pub retry_max_attempts: usize,
|
||||
pub retry_base_delay_ms: u64,
|
||||
pub retry_jitter: bool,
|
||||
pub fallback_model: Option<String>,
|
||||
pub skip_memory: bool,
|
||||
pub skip_context_files: bool,
|
||||
pub skip_compression: bool,
|
||||
pub quiet_mode: bool,
|
||||
pub save_trajectories: bool,
|
||||
pub reasoning_effort: Option<String>,
|
||||
pub service_tier: Option<String>,
|
||||
pub platform: Option<String>,
|
||||
pub session_id: Option<uuid::Uuid>,
|
||||
}
|
||||
|
||||
impl AgentConfig {
|
||||
pub fn new(model: impl Into<String>) -> AiResult<Self> {
|
||||
let config = Self {
|
||||
model: model.into(),
|
||||
provider: String::new(),
|
||||
api_mode: String::from("chat_completions"),
|
||||
system_prompt: DEFAULT_SYSTEM_PROMPT.to_string(),
|
||||
max_iterations: 64,
|
||||
iteration_budget: 90,
|
||||
temperature: Some(0.2),
|
||||
max_completion_tokens: None,
|
||||
max_total_tokens_per_run: Some(128_000),
|
||||
enabled_toolsets: Vec::new(),
|
||||
disabled_toolsets: Vec::new(),
|
||||
allowed_tools: Vec::new(),
|
||||
denied_tools: Vec::new(),
|
||||
retry_max_attempts: 3,
|
||||
retry_base_delay_ms: 1_000,
|
||||
retry_jitter: true,
|
||||
fallback_model: None,
|
||||
skip_memory: false,
|
||||
skip_context_files: false,
|
||||
skip_compression: false,
|
||||
quiet_mode: false,
|
||||
save_trajectories: false,
|
||||
reasoning_effort: None,
|
||||
service_tier: None,
|
||||
platform: None,
|
||||
session_id: None,
|
||||
};
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> AiResult<()> {
|
||||
if self.model.trim().is_empty() {
|
||||
return Err(AiError::Config("agent model is required".to_string()));
|
||||
}
|
||||
if self.max_iterations == 0 {
|
||||
return Err(AiError::Config(
|
||||
"agent max_iterations must be greater than 0".to_string(),
|
||||
));
|
||||
}
|
||||
if let Some(tokens) = self.max_total_tokens_per_run
|
||||
&& tokens <= 0
|
||||
{
|
||||
return Err(AiError::Config(
|
||||
"agent max_total_tokens_per_run must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
|
||||
self.provider = provider.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_api_mode(mut self, mode: impl Into<String>) -> Self {
|
||||
self.api_mode = mode.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_iterations(mut self, max: usize) -> Self {
|
||||
self.max_iterations = max;
|
||||
self.iteration_budget = self.iteration_budget.max(max);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_iteration_budget(mut self, budget: usize) -> Self {
|
||||
self.iteration_budget = budget;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
|
||||
self.system_prompt = prompt.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temperature: Option<f64>) -> Self {
|
||||
self.temperature = temperature;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_completion_tokens(mut self, max_completion_tokens: Option<u64>) -> Self {
|
||||
self.max_completion_tokens = max_completion_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_total_tokens(mut self, limit: Option<i64>) -> Self {
|
||||
self.max_total_tokens_per_run = limit;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_toolset_policy(mut self, enabled: Vec<String>, disabled: Vec<String>) -> Self {
|
||||
self.enabled_toolsets = enabled;
|
||||
self.disabled_toolsets = disabled;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tool_policy(mut self, allowed_tools: Vec<String>, denied_tools: Vec<String>) -> Self {
|
||||
self.allowed_tools = allowed_tools;
|
||||
self.denied_tools = denied_tools;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_retry(mut self, max_attempts: usize, base_delay_ms: u64) -> Self {
|
||||
self.retry_max_attempts = max_attempts;
|
||||
self.retry_base_delay_ms = base_delay_ms;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_retry_jitter(mut self, jitter: bool) -> Self {
|
||||
self.retry_jitter = jitter;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_fallback_model(mut self, fallback_model: impl Into<String>) -> Self {
|
||||
self.fallback_model = Some(fallback_model.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_skip_memory(mut self, skip: bool) -> Self {
|
||||
self.skip_memory = skip;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_skip_compression(mut self, skip: bool) -> Self {
|
||||
self.skip_compression = skip;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_quiet_mode(mut self, quiet: bool) -> Self {
|
||||
self.quiet_mode = quiet;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_platform(mut self, platform: impl Into<String>) -> Self {
|
||||
self.platform = Some(platform.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_session_id(mut self, session_id: uuid::Uuid) -> Self {
|
||||
self.session_id = Some(session_id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_reasoning_effort(mut self, effort: impl Into<String>) -> Self {
|
||||
self.reasoning_effort = Some(effort.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn is_tool_exposed(&self, name: &str) -> bool {
|
||||
let denied = self.denied_tools.iter().any(|tool| tool == name);
|
||||
if denied {
|
||||
return false;
|
||||
}
|
||||
if self.allowed_tools.is_empty() {
|
||||
return true;
|
||||
}
|
||||
self.allowed_tools.iter().any(|tool| tool == name)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_system_prompt() -> &'static str {
|
||||
DEFAULT_SYSTEM_PROMPT
|
||||
}
|
||||
239
lib/ai/agent/error_classifier.rs
Normal file
239
lib/ai/agent/error_classifier.rs
Normal file
@ -0,0 +1,239 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::error::AiError;
|
||||
|
||||
/// Categorized error for deciding retry/fallback/fatal strategy.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ErrorCategory {
|
||||
/// Transient error, safe to retry with backoff.
|
||||
Retryable { reason: String },
|
||||
/// Authentication or quota error, switch to fallback model.
|
||||
FallbackModel { reason: String },
|
||||
/// Non-recoverable error, do not retry.
|
||||
Fatal { reason: String },
|
||||
/// Token budget exceeded for this run.
|
||||
TokenBudgetExceeded,
|
||||
/// Request timed out.
|
||||
Timeout,
|
||||
/// Request was cancelled by the caller.
|
||||
Cancelled,
|
||||
/// Provider is overloaded or at capacity, retry with longer delay.
|
||||
Overloaded { reason: String },
|
||||
/// Context window exceeded, needs compaction before retry.
|
||||
ContextWindowExceeded { reason: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryPolicy {
|
||||
pub max_attempts: usize,
|
||||
pub base_delay: Duration,
|
||||
pub jitter: bool,
|
||||
pub exponential: bool,
|
||||
pub switch_to_fallback: bool,
|
||||
}
|
||||
|
||||
impl RetryPolicy {
|
||||
pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
|
||||
let ms = if self.exponential {
|
||||
self.base_delay.as_millis() as u64 * (1u64 << attempt.min(6))
|
||||
} else {
|
||||
self.base_delay.as_millis() as u64
|
||||
};
|
||||
|
||||
let ms = if self.jitter {
|
||||
let half = (ms as f64 * 0.25) as u64;
|
||||
let lo = ms.saturating_sub(half);
|
||||
let hi = ms.saturating_add(half);
|
||||
let mix = ((attempt as u64).wrapping_mul(1_103_515_245)) % (hi - lo + 1);
|
||||
lo + mix
|
||||
} else {
|
||||
ms
|
||||
};
|
||||
|
||||
Duration::from_millis(ms.max(100))
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify an error into a category for retry/fallback decisions.
|
||||
///
|
||||
/// Inspects both the HTTP status code (when available) and the error message
|
||||
/// content to determine the most appropriate category.
|
||||
pub fn classify_error(error: &AiError, http_status: Option<u16>) -> ErrorCategory {
|
||||
// HTTP status-based classification takes precedence
|
||||
let from_status = match http_status {
|
||||
Some(429) => Some(ErrorCategory::Retryable {
|
||||
reason: "rate limited (HTTP 429)".to_string(),
|
||||
}),
|
||||
Some(401) | Some(403) => Some(ErrorCategory::FallbackModel {
|
||||
reason: format!("authentication failed (HTTP {})", http_status.unwrap()),
|
||||
}),
|
||||
Some(502) | Some(503) => Some(ErrorCategory::Overloaded {
|
||||
reason: format!("provider unavailable (HTTP {})", http_status.unwrap()),
|
||||
}),
|
||||
Some(504) => Some(ErrorCategory::Timeout),
|
||||
Some(413) => Some(ErrorCategory::ContextWindowExceeded {
|
||||
reason: "payload too large (HTTP 413)".to_string(),
|
||||
}),
|
||||
Some(s) if (400..500).contains(&s) => Some(ErrorCategory::Fatal {
|
||||
reason: format!("client error (HTTP {})", s),
|
||||
}),
|
||||
Some(s) if (500..600).contains(&s) => Some(ErrorCategory::Retryable {
|
||||
reason: format!("server error (HTTP {})", s),
|
||||
}),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some(cat) = from_status {
|
||||
return cat;
|
||||
}
|
||||
|
||||
// Message-based classification
|
||||
match error {
|
||||
AiError::Timeout { .. } => ErrorCategory::Timeout,
|
||||
AiError::TokenBudgetExceeded { .. } => ErrorCategory::TokenBudgetExceeded,
|
||||
AiError::Api(msg) => classify_api_message(msg),
|
||||
AiError::Response(msg) => classify_response_message(msg),
|
||||
AiError::ModelRetriesExhausted { .. } => ErrorCategory::Fatal {
|
||||
reason: error.to_string(),
|
||||
},
|
||||
_ => ErrorCategory::Fatal {
|
||||
reason: error.to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify API error messages by keyword patterns.
|
||||
fn classify_api_message(msg: &str) -> ErrorCategory {
|
||||
let lower = msg.to_lowercase();
|
||||
|
||||
// Rate limiting
|
||||
if lower.contains("rate") || lower.contains("too many requests") || lower.contains("throttl") {
|
||||
return ErrorCategory::Retryable {
|
||||
reason: msg.to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
// Overloaded / capacity
|
||||
if lower.contains("overloaded")
|
||||
|| lower.contains("capacity")
|
||||
|| lower.contains("too busy")
|
||||
|| lower.contains("service unavailable")
|
||||
{
|
||||
return ErrorCategory::Overloaded {
|
||||
reason: msg.to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
// Authentication / quota
|
||||
if lower.contains("unauthorized")
|
||||
|| lower.contains("invalid api key")
|
||||
|| lower.contains("api key")
|
||||
|| lower.contains("forbidden")
|
||||
|| lower.contains("quota exceeded")
|
||||
|| lower.contains("insufficient")
|
||||
|| lower.contains("billing")
|
||||
{
|
||||
return ErrorCategory::FallbackModel {
|
||||
reason: msg.to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
// Context window exceeded
|
||||
if lower.contains("context length")
|
||||
|| lower.contains("context window")
|
||||
|| lower.contains("maximum context")
|
||||
|| lower.contains("too many tokens")
|
||||
|| lower.contains("max_tokens")
|
||||
{
|
||||
return ErrorCategory::ContextWindowExceeded {
|
||||
reason: msg.to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
ErrorCategory::Fatal {
|
||||
reason: msg.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Classify response error messages by keyword patterns.
|
||||
fn classify_response_message(msg: &str) -> ErrorCategory {
|
||||
let lower = msg.to_lowercase();
|
||||
|
||||
if lower.contains("cancelled") || lower.contains("canceled") {
|
||||
return ErrorCategory::Cancelled;
|
||||
}
|
||||
if lower.contains("timeout") || lower.contains("timed out") {
|
||||
return ErrorCategory::Timeout;
|
||||
}
|
||||
|
||||
ErrorCategory::Fatal {
|
||||
reason: msg.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the recommended retry policy for an error category.
|
||||
pub fn retry_policy_for(
|
||||
category: &ErrorCategory,
|
||||
max_attempts: usize,
|
||||
base_delay_ms: u64,
|
||||
) -> RetryPolicy {
|
||||
match category {
|
||||
ErrorCategory::Retryable { .. } => RetryPolicy {
|
||||
max_attempts,
|
||||
base_delay: Duration::from_millis(base_delay_ms),
|
||||
jitter: true,
|
||||
exponential: true,
|
||||
switch_to_fallback: false,
|
||||
},
|
||||
ErrorCategory::Overloaded { .. } => RetryPolicy {
|
||||
max_attempts: max_attempts.min(5),
|
||||
base_delay: Duration::from_millis(base_delay_ms.max(5_000)),
|
||||
jitter: true,
|
||||
exponential: true,
|
||||
switch_to_fallback: true,
|
||||
},
|
||||
ErrorCategory::FallbackModel { .. } => RetryPolicy {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(500),
|
||||
jitter: false,
|
||||
exponential: false,
|
||||
switch_to_fallback: true,
|
||||
},
|
||||
ErrorCategory::ContextWindowExceeded { .. } => RetryPolicy {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(0),
|
||||
jitter: false,
|
||||
exponential: false,
|
||||
switch_to_fallback: false,
|
||||
},
|
||||
ErrorCategory::Timeout => RetryPolicy {
|
||||
max_attempts: max_attempts.min(2),
|
||||
base_delay: Duration::from_millis(base_delay_ms.max(2_000)),
|
||||
jitter: true,
|
||||
exponential: false,
|
||||
switch_to_fallback: false,
|
||||
},
|
||||
ErrorCategory::TokenBudgetExceeded | ErrorCategory::Cancelled | ErrorCategory::Fatal { .. } => {
|
||||
RetryPolicy {
|
||||
max_attempts: 0,
|
||||
base_delay: Duration::from_millis(0),
|
||||
jitter: false,
|
||||
exponential: false,
|
||||
switch_to_fallback: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine whether the error warrants switching to a fallback model.
|
||||
pub fn should_switch_to_fallback(category: &ErrorCategory) -> bool {
|
||||
matches!(
|
||||
category,
|
||||
ErrorCategory::FallbackModel { .. } | ErrorCategory::Overloaded { .. }
|
||||
)
|
||||
}
|
||||
|
||||
/// Determine whether compaction should be attempted before retry.
|
||||
pub fn should_compact_before_retry(category: &ErrorCategory) -> bool {
|
||||
matches!(category, ErrorCategory::ContextWindowExceeded { .. })
|
||||
}
|
||||
179
lib/ai/agent/events.rs
Normal file
179
lib/ai/agent/events.rs
Normal file
@ -0,0 +1,179 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Fine-grained agent lifecycle events, inspired by pi's event system.
|
||||
///
|
||||
/// Covers the full agent execution lifecycle with enough granularity
|
||||
/// for UI rendering, telemetry, and extension hooks.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum AgentEvent {
|
||||
// === Agent lifecycle ===
|
||||
AgentStart,
|
||||
AgentEnd {
|
||||
messages: Vec<AgentEventMessage>,
|
||||
total_input_tokens: u64,
|
||||
total_output_tokens: u64,
|
||||
},
|
||||
|
||||
// === Turn lifecycle ===
|
||||
TurnStart {
|
||||
turn_index: usize,
|
||||
},
|
||||
TurnEnd {
|
||||
turn_index: usize,
|
||||
assistant_text: Option<String>,
|
||||
tool_call_count: usize,
|
||||
},
|
||||
|
||||
// === Message lifecycle ===
|
||||
MessageStart {
|
||||
role: MessageRole,
|
||||
},
|
||||
MessageTextDelta {
|
||||
index: usize,
|
||||
delta: String,
|
||||
},
|
||||
MessageThinkingDelta {
|
||||
index: usize,
|
||||
delta: String,
|
||||
},
|
||||
MessageEnd {
|
||||
role: MessageRole,
|
||||
},
|
||||
|
||||
// === Tool execution lifecycle ===
|
||||
ToolExecutionStart {
|
||||
tool_call_id: String,
|
||||
tool_name: String,
|
||||
arguments: Value,
|
||||
},
|
||||
ToolExecutionUpdate {
|
||||
tool_call_id: String,
|
||||
tool_name: String,
|
||||
partial_output: String,
|
||||
},
|
||||
ToolExecutionEnd {
|
||||
tool_call_id: String,
|
||||
tool_name: String,
|
||||
output: Option<Value>,
|
||||
error: Option<String>,
|
||||
elapsed_ms: i64,
|
||||
},
|
||||
|
||||
// === Steering / follow-up ===
|
||||
SteeringMessagesInjected {
|
||||
count: usize,
|
||||
},
|
||||
FollowUpMessagesInjected {
|
||||
count: usize,
|
||||
},
|
||||
|
||||
// === Context management ===
|
||||
ContextCompacted {
|
||||
messages_compacted: usize,
|
||||
tokens_saved: i64,
|
||||
},
|
||||
BranchSummaryCreated {
|
||||
entry_count: usize,
|
||||
summary_length: usize,
|
||||
},
|
||||
|
||||
// === Model switching ===
|
||||
ModelSwitched {
|
||||
from_model: String,
|
||||
to_model: String,
|
||||
reason: String,
|
||||
},
|
||||
|
||||
// === Error and retry ===
|
||||
ErrorClassified {
|
||||
category: String,
|
||||
message: String,
|
||||
will_retry: bool,
|
||||
retry_delay_ms: Option<u64>,
|
||||
},
|
||||
RetryAttempt {
|
||||
attempt: usize,
|
||||
max_attempts: usize,
|
||||
delay_ms: u64,
|
||||
},
|
||||
}
|
||||
|
||||
/// Simplified message role for event display.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessageRole {
|
||||
User,
|
||||
Assistant,
|
||||
ToolResult,
|
||||
System,
|
||||
}
|
||||
|
||||
/// A simplified message representation for `AgentEnd` events.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AgentEventMessage {
|
||||
pub role: MessageRole,
|
||||
pub content: String,
|
||||
pub tool_calls: Vec<EventToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EventToolCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: Value,
|
||||
pub output: Option<Value>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// An async-friendly event sink that collects or broadcasts events.
|
||||
pub struct EventSink {
|
||||
senders: Vec<tokio::sync::mpsc::UnboundedSender<AgentEvent>>,
|
||||
}
|
||||
|
||||
impl EventSink {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
senders: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to events, returns a receiver.
|
||||
pub fn subscribe(&mut self) -> tokio::sync::mpsc::UnboundedReceiver<AgentEvent> {
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
self.senders.push(tx);
|
||||
rx
|
||||
}
|
||||
|
||||
/// Emit an event to all subscribers. Non-blocking; drops if receiver disconnected.
|
||||
pub fn emit(&self, event: AgentEvent) {
|
||||
for sender in &self.senders {
|
||||
let _ = sender.send(event.clone());
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if there are any active subscribers.
|
||||
pub fn has_subscribers(&self) -> bool {
|
||||
!self.senders.is_empty()
|
||||
}
|
||||
|
||||
/// Remove disconnected senders.
|
||||
pub fn cleanup(&mut self) {
|
||||
self.senders.retain(|s| !s.is_closed());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EventSink {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for EventSink {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
senders: self.senders.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
113
lib/ai/agent/helpers.rs
Normal file
113
lib/ai/agent/helpers.rs
Normal file
@ -0,0 +1,113 @@
|
||||
use std::future::Future;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::agent::request::AgentRequest;
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
pub fn build_input_string(request: &AgentRequest) -> String {
|
||||
let mut input = String::new();
|
||||
|
||||
if !request.context.is_empty() {
|
||||
input.push_str("<retrieved_context>\n");
|
||||
for chunk in &request.context {
|
||||
let source = chunk.source.as_deref().unwrap_or("unknown");
|
||||
let score = chunk
|
||||
.score
|
||||
.map(|s| format!("{s:.4}"))
|
||||
.unwrap_or_else(|| "n/a".to_string());
|
||||
input.push_str(&format!(
|
||||
"\n<chunk id=\"{}\" source=\"{}\" score=\"{}\">\n{}\n</chunk>\n",
|
||||
chunk.id, source, score, chunk.content
|
||||
));
|
||||
}
|
||||
input.push_str("</retrieved_context>\n\n");
|
||||
}
|
||||
|
||||
for message in &request.messages {
|
||||
match message {
|
||||
super::request::AgentMessage::User(content) => {
|
||||
input.push_str(&format!("User: {content}\n"));
|
||||
}
|
||||
super::request::AgentMessage::Assistant(content) => {
|
||||
input.push_str(&format!("Assistant: {content}\n"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
input.push_str(&format!("User: {}", request.input));
|
||||
|
||||
input
|
||||
}
|
||||
|
||||
pub fn estimate_tokens(text: &str) -> u64 {
|
||||
if text.is_empty() {
|
||||
return 0;
|
||||
}
|
||||
(text.chars().count() as f64 / 2.5).ceil() as u64
|
||||
}
|
||||
|
||||
pub fn check_token_budget(
|
||||
estimated_input_tokens: u64,
|
||||
accumulated_output_chars: usize,
|
||||
limit: i64,
|
||||
) -> bool {
|
||||
let output_estimate = (accumulated_output_chars as f64 / 2.5).ceil() as u64;
|
||||
estimated_input_tokens + output_estimate > limit as u64
|
||||
}
|
||||
|
||||
pub async fn with_retry<F, Fut, T>(
|
||||
max_attempts: usize,
|
||||
base_delay_ms: u64,
|
||||
f: F,
|
||||
) -> AiResult<T>
|
||||
where
|
||||
F: Fn() -> Fut,
|
||||
Fut: Future<Output = AiResult<T>>,
|
||||
{
|
||||
let mut last_error: Option<AiError> = None;
|
||||
for attempt in 0..max_attempts {
|
||||
match f().await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(e) if is_retryable(&e) && attempt + 1 < max_attempts => {
|
||||
let delay = Duration::from_millis(base_delay_ms * 2u64.pow(attempt as u32));
|
||||
tracing::warn!(
|
||||
error = %e,
|
||||
attempt = attempt + 1,
|
||||
max_attempts,
|
||||
delay_ms = delay.as_millis(),
|
||||
"retrying after transient error"
|
||||
);
|
||||
tokio::time::sleep(delay).await;
|
||||
last_error = Some(e);
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
Err(AiError::ModelRetriesExhausted {
|
||||
attempts: max_attempts,
|
||||
last_error: last_error
|
||||
.map(|e| e.to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
fn is_retryable(error: &AiError) -> bool {
|
||||
matches!(
|
||||
error,
|
||||
AiError::Api(_) | AiError::Response(_) | AiError::ModelRetriesExhausted { .. }
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tool_result_content_to_string(
|
||||
content: &rig::one_or_many::OneOrMany<rig::completion::message::ToolResultContent>,
|
||||
) -> String {
|
||||
use rig::completion::message::ToolResultContent;
|
||||
content
|
||||
.iter()
|
||||
.filter_map(|item| match item {
|
||||
ToolResultContent::Text(t) => Some(t.text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
145
lib/ai/agent/hooks.rs
Normal file
145
lib/ai/agent/hooks.rs
Normal file
@ -0,0 +1,145 @@
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::agent::persistence::AgentRunContext;
|
||||
use crate::error::AiResult;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ToolGuardrailDecision {
|
||||
Allow,
|
||||
Block { reason: String },
|
||||
RequireApproval { message: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCallOutcome {
|
||||
pub name: String,
|
||||
pub arguments: Value,
|
||||
pub output: Option<Value>,
|
||||
pub error: Option<String>,
|
||||
pub elapsed_ms: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookMessage {
|
||||
pub role: String,
|
||||
pub content: Option<String>,
|
||||
pub tool_calls: Option<Value>,
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookLlmResponse {
|
||||
pub content: Option<String>,
|
||||
pub tool_calls: Option<Value>,
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HookToolDef {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait AgentHook: Send + Sync {
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
async fn on_session_start(&self, _ctx: &AgentRunContext) -> AiResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn on_session_end(&self, _ctx: &AgentRunContext, _success: bool) -> AiResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn pre_llm_call(&self, _messages: &[HookMessage], _tools: &[HookToolDef]) -> AiResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn post_llm_call(&self, _response: &HookLlmResponse) -> AiResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn pre_tool_call(
|
||||
&self,
|
||||
_tool_name: &str,
|
||||
_arguments: &Value,
|
||||
) -> AiResult<Option<ToolGuardrailDecision>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn post_tool_call(&self, _outcome: &ToolCallOutcome) -> AiResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct HookChain {
|
||||
hooks: Vec<Box<dyn AgentHook>>,
|
||||
}
|
||||
|
||||
impl HookChain {
|
||||
pub fn new(hooks: Vec<Box<dyn AgentHook>>) -> Self {
|
||||
Self { hooks }
|
||||
}
|
||||
|
||||
pub fn empty() -> Self {
|
||||
Self { hooks: Vec::new() }
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.hooks.is_empty()
|
||||
}
|
||||
|
||||
pub async fn run_session_start(&self, ctx: &AgentRunContext) -> AiResult<()> {
|
||||
for hook in &self.hooks {
|
||||
hook.on_session_start(ctx).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_session_end(&self, ctx: &AgentRunContext, success: bool) -> AiResult<()> {
|
||||
for hook in &self.hooks {
|
||||
hook.on_session_end(ctx, success).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_pre_llm_call(&self, messages: &[HookMessage], tools: &[HookToolDef]) -> AiResult<()> {
|
||||
for hook in &self.hooks {
|
||||
hook.pre_llm_call(messages, tools).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_post_llm_call(&self, response: &HookLlmResponse) -> AiResult<()> {
|
||||
for hook in &self.hooks {
|
||||
hook.post_llm_call(response).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_pre_tool_call(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
arguments: &Value,
|
||||
) -> AiResult<Option<ToolGuardrailDecision>> {
|
||||
for hook in &self.hooks {
|
||||
if let Some(decision) = hook.pre_tool_call(tool_name, arguments).await? {
|
||||
if !matches!(decision, ToolGuardrailDecision::Allow) {
|
||||
return Ok(Some(decision));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn run_post_tool_call(&self, outcome: &ToolCallOutcome) -> AiResult<()> {
|
||||
for hook in &self.hooks {
|
||||
hook.post_tool_call(outcome).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
45
lib/ai/agent/iteration_budget.rs
Normal file
45
lib/ai/agent/iteration_budget.rs
Normal file
@ -0,0 +1,45 @@
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct IterationBudget {
|
||||
pub remaining: usize,
|
||||
pub hard_limit: usize,
|
||||
pub grace_call: bool,
|
||||
pub consumed: usize,
|
||||
}
|
||||
|
||||
impl IterationBudget {
|
||||
pub fn new(limit: usize) -> Self {
|
||||
Self {
|
||||
remaining: limit,
|
||||
hard_limit: limit,
|
||||
grace_call: true,
|
||||
consumed: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn can_continue(&self) -> bool {
|
||||
self.remaining > 0 || (self.remaining == 0 && self.grace_call)
|
||||
}
|
||||
|
||||
pub fn consume(&mut self) -> bool {
|
||||
if self.remaining > 0 {
|
||||
self.remaining -= 1;
|
||||
self.consumed += 1;
|
||||
true
|
||||
} else if self.grace_call {
|
||||
self.grace_call = false;
|
||||
self.consumed += 1;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn exhaust(&mut self) {
|
||||
self.remaining = 0;
|
||||
self.grace_call = false;
|
||||
}
|
||||
|
||||
pub const fn total_consumed(&self) -> usize {
|
||||
self.consumed
|
||||
}
|
||||
}
|
||||
876
lib/ai/agent/loop.rs
Normal file
876
lib/ai/agent/loop.rs
Normal file
@ -0,0 +1,876 @@
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
use rig::agent::AgentBuilder;
|
||||
use rig::client::CompletionClient;
|
||||
use rig::streaming::StreamingPrompt;
|
||||
use rig::tool::ToolDyn;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::config::AgentConfig;
|
||||
use super::error_classifier::{
|
||||
classify_error, retry_policy_for, should_switch_to_fallback,
|
||||
};
|
||||
use super::events::{AgentEvent, EventSink};
|
||||
use super::helpers::{build_input_string, estimate_tokens};
|
||||
use super::hooks::{HookChain, HookLlmResponse, HookMessage, ToolCallOutcome, ToolGuardrailDecision};
|
||||
use super::iteration_budget::IterationBudget;
|
||||
use super::request::{AgentRequest, AgentResult, AgentStep, ToolCallRecord};
|
||||
use super::RigStreamChunk;
|
||||
use crate::client::AiClient;
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
/// How tool calls from a single assistant turn are executed.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ToolExecutionMode {
|
||||
/// Execute tool calls one at a time.
|
||||
Sequential,
|
||||
/// Execute tool calls concurrently (after sequential preflight).
|
||||
Parallel,
|
||||
}
|
||||
|
||||
impl Default for ToolExecutionMode {
|
||||
fn default() -> Self {
|
||||
Self::Parallel
|
||||
}
|
||||
}
|
||||
|
||||
/// Callback type for steering messages (injected mid-run).
|
||||
pub type SteeringFn = Arc<
|
||||
dyn Fn() -> Pin<Box<dyn Future<Output = Vec<String>> + Send>> + Send + Sync,
|
||||
>;
|
||||
|
||||
/// Callback type for follow-up messages (injected after agent would stop).
|
||||
pub type FollowUpFn = Arc<
|
||||
dyn Fn() -> Pin<Box<dyn Future<Output = Vec<String>> + Send>> + Send + Sync,
|
||||
>;
|
||||
|
||||
/// Callback to decide whether the agent should stop after a turn.
|
||||
pub type ShouldStopFn = Arc<
|
||||
dyn Fn(&TurnContext) -> bool + Send + Sync,
|
||||
>;
|
||||
|
||||
/// Callback to prepare/modify state before the next turn.
|
||||
pub type PrepareNextTurnFn = Arc<
|
||||
dyn Fn(&TurnContext) -> Pin<Box<dyn Future<Output = Option<TurnUpdate>> + Send>>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>;
|
||||
|
||||
/// Context passed to `should_stop` and `prepare_next_turn` callbacks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TurnContext {
|
||||
pub turn_index: usize,
|
||||
pub assistant_text: String,
|
||||
pub tool_call_count: usize,
|
||||
pub total_input_tokens: u64,
|
||||
pub total_output_tokens: u64,
|
||||
pub model_name: String,
|
||||
}
|
||||
|
||||
/// Replacement state for the next turn (returned by `prepare_next_turn`).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TurnUpdate {
|
||||
pub model: Option<String>,
|
||||
pub temperature: Option<f64>,
|
||||
pub max_completion_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
/// Extended agent loop configuration, adding steering/follow-up/lifecycle
|
||||
/// hooks on top of the base `AgentConfig`.
|
||||
pub struct AgentLoopConfig {
|
||||
pub config: AgentConfig,
|
||||
pub tool_execution_mode: ToolExecutionMode,
|
||||
pub get_steering_messages: Option<SteeringFn>,
|
||||
pub get_follow_up_messages: Option<FollowUpFn>,
|
||||
pub should_stop_after_turn: Option<ShouldStopFn>,
|
||||
pub prepare_next_turn: Option<PrepareNextTurnFn>,
|
||||
pub event_sink: Option<EventSink>,
|
||||
}
|
||||
|
||||
impl AgentLoopConfig {
|
||||
pub fn new(config: AgentConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
tool_execution_mode: ToolExecutionMode::default(),
|
||||
get_steering_messages: None,
|
||||
get_follow_up_messages: None,
|
||||
should_stop_after_turn: None,
|
||||
prepare_next_turn: None,
|
||||
event_sink: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_tool_execution_mode(mut self, mode: ToolExecutionMode) -> Self {
|
||||
self.tool_execution_mode = mode;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_steering_messages(mut self, f: SteeringFn) -> Self {
|
||||
self.get_steering_messages = Some(f);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_follow_up_messages(mut self, f: FollowUpFn) -> Self {
|
||||
self.get_follow_up_messages = Some(f);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_should_stop(mut self, f: ShouldStopFn) -> Self {
|
||||
self.should_stop_after_turn = Some(f);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_prepare_next_turn(mut self, f: PrepareNextTurnFn) -> Self {
|
||||
self.prepare_next_turn = Some(f);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_event_sink(mut self, sink: EventSink) -> Self {
|
||||
self.event_sink = Some(sink);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Enhanced agent with loop controls (steering, follow-up, model switching).
|
||||
pub struct EnhancedAgent {
|
||||
pub client: AiClient,
|
||||
pub loop_config: AgentLoopConfig,
|
||||
pub hooks: HookChain,
|
||||
}
|
||||
|
||||
impl EnhancedAgent {
|
||||
pub fn new(client: AiClient, loop_config: AgentLoopConfig) -> AiResult<Self> {
|
||||
loop_config.config.validate()?;
|
||||
Ok(Self {
|
||||
client,
|
||||
loop_config,
|
||||
hooks: HookChain::empty(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_hooks(mut self, hooks: HookChain) -> Self {
|
||||
self.hooks = hooks;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &AgentConfig {
|
||||
&self.loop_config.config
|
||||
}
|
||||
|
||||
/// Run the enhanced agent loop, returning a chunk receiver and a join handle.
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn run(
|
||||
&self,
|
||||
request: AgentRequest,
|
||||
tools: Vec<Box<dyn ToolDyn>>,
|
||||
) -> (
|
||||
mpsc::Receiver<RigStreamChunk>,
|
||||
tokio::task::JoinHandle<AiResult<AgentResult>>,
|
||||
) {
|
||||
let (tx, rx) = mpsc::channel::<RigStreamChunk>(256);
|
||||
|
||||
let config = self.loop_config.config.clone();
|
||||
let tool_execution_mode = self.loop_config.tool_execution_mode;
|
||||
let steering_fn = self.loop_config.get_steering_messages.clone();
|
||||
let follow_up_fn = self.loop_config.get_follow_up_messages.clone();
|
||||
let should_stop = self.loop_config.should_stop_after_turn.clone();
|
||||
let prepare_next = self.loop_config.prepare_next_turn.clone();
|
||||
let event_sink = self.loop_config.event_sink.clone();
|
||||
let client = self.client.llm_client().clone();
|
||||
let hooks = self.hooks.clone();
|
||||
|
||||
let filtered_tools: Vec<Box<dyn ToolDyn>> = tools
|
||||
.into_iter()
|
||||
.filter(|tool| config.is_tool_exposed(&tool.name()))
|
||||
.collect();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
run_enhanced_loop(
|
||||
client,
|
||||
config,
|
||||
request,
|
||||
filtered_tools,
|
||||
tool_execution_mode,
|
||||
steering_fn,
|
||||
follow_up_fn,
|
||||
should_stop,
|
||||
prepare_next,
|
||||
event_sink,
|
||||
hooks,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
(rx, handle)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines, clippy::too_many_arguments)]
|
||||
async fn run_enhanced_loop(
|
||||
client: rig::providers::openai::Client,
|
||||
mut config: AgentConfig,
|
||||
request: AgentRequest,
|
||||
tools: Vec<Box<dyn ToolDyn>>,
|
||||
_tool_execution_mode: ToolExecutionMode,
|
||||
steering_fn: Option<SteeringFn>,
|
||||
follow_up_fn: Option<FollowUpFn>,
|
||||
should_stop: Option<ShouldStopFn>,
|
||||
prepare_next: Option<PrepareNextTurnFn>,
|
||||
event_sink: Option<EventSink>,
|
||||
hooks: HookChain,
|
||||
tx: mpsc::Sender<RigStreamChunk>,
|
||||
) -> AiResult<AgentResult> {
|
||||
let cancellation = request.cancellation_token.clone();
|
||||
let timeout = request.timeout;
|
||||
let mut budget = IterationBudget::new(config.iteration_budget);
|
||||
let mut all_steps: Vec<AgentStep> = Vec::new();
|
||||
let mut total_input_tokens: u64 = 0;
|
||||
let mut total_output_tokens: u64 = 0;
|
||||
let mut turn_index: usize = 0;
|
||||
|
||||
// Session start hook
|
||||
if let Some(ctx) = &request.run_context {
|
||||
let _ = hooks.run_session_start(ctx).await;
|
||||
}
|
||||
|
||||
// Emit agent start event
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::AgentStart);
|
||||
}
|
||||
|
||||
// Build the initial input
|
||||
let input = build_input_string(&request);
|
||||
let mut current_input = input.clone();
|
||||
let estimated_input_tokens = estimate_tokens(¤t_input);
|
||||
|
||||
if let Some(limit) = config.max_total_tokens_per_run
|
||||
&& estimated_input_tokens > limit as u64
|
||||
{
|
||||
return Err(AiError::TokenBudgetExceeded {
|
||||
estimated: estimated_input_tokens,
|
||||
limit,
|
||||
});
|
||||
}
|
||||
|
||||
// Outer loop: handles follow-up messages after agent would stop
|
||||
loop {
|
||||
// Inner loop: tool call turns + steering messages
|
||||
let mut pending_steering: Vec<String> = if let Some(f) = &steering_fn {
|
||||
f().await
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
loop {
|
||||
// Check cancellation
|
||||
if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
|
||||
let _ = tx.send(RigStreamChunk::Failed { error: "cancelled".to_string() }).await;
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::ErrorClassified {
|
||||
category: "cancelled".to_string(),
|
||||
message: "cancelled by caller".to_string(),
|
||||
will_retry: false,
|
||||
retry_delay_ms: None,
|
||||
});
|
||||
}
|
||||
return Err(AiError::Response("agent run cancelled".to_string()));
|
||||
}
|
||||
|
||||
// Inject steering messages if any
|
||||
if !pending_steering.is_empty() {
|
||||
let count = pending_steering.len();
|
||||
for msg in &pending_steering {
|
||||
current_input.push_str(&format!("\nUser: {msg}\n"));
|
||||
}
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::SteeringMessagesInjected { count });
|
||||
}
|
||||
pending_steering.clear();
|
||||
}
|
||||
|
||||
// Emit turn start
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::TurnStart { turn_index });
|
||||
}
|
||||
let _ = tx.send(RigStreamChunk::TextDelta {
|
||||
index: 0,
|
||||
content: String::new(), // placeholder for turn boundary detection
|
||||
}).await;
|
||||
|
||||
// Run one LLM turn with retry
|
||||
let turn_result = run_single_turn(
|
||||
&client,
|
||||
&config,
|
||||
¤t_input,
|
||||
&tools,
|
||||
&mut budget,
|
||||
&cancellation,
|
||||
timeout,
|
||||
&hooks,
|
||||
&event_sink,
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
|
||||
match turn_result {
|
||||
Ok(turn_output) => {
|
||||
total_input_tokens += turn_output.input_tokens;
|
||||
total_output_tokens += turn_output.output_tokens;
|
||||
|
||||
// Collect step
|
||||
let tool_call_count = turn_output.tool_calls.len();
|
||||
if !turn_output.tool_calls.is_empty() || !turn_output.assistant_text.is_empty() {
|
||||
all_steps.push(AgentStep {
|
||||
index: all_steps.len(),
|
||||
assistant: (!turn_output.assistant_text.is_empty())
|
||||
.then_some(turn_output.assistant_text.clone()),
|
||||
reasoning_content: None,
|
||||
tool_calls: turn_output.tool_calls,
|
||||
reflection: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Emit turn end
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::TurnEnd {
|
||||
turn_index,
|
||||
assistant_text: Some(turn_output.assistant_text.clone()),
|
||||
tool_call_count,
|
||||
});
|
||||
}
|
||||
|
||||
// Check should_stop
|
||||
let turn_ctx = TurnContext {
|
||||
turn_index,
|
||||
assistant_text: turn_output.assistant_text.clone(),
|
||||
tool_call_count,
|
||||
total_input_tokens,
|
||||
total_output_tokens,
|
||||
model_name: config.model.clone(),
|
||||
};
|
||||
|
||||
if let Some(stop_fn) = &should_stop {
|
||||
if stop_fn(&turn_ctx) {
|
||||
info!(turn_index, "agent stopped by should_stop callback");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare next turn (may switch model)
|
||||
if let Some(prep_fn) = &prepare_next {
|
||||
if let Some(update) = prep_fn(&turn_ctx).await {
|
||||
if let Some(new_model) = update.model {
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::ModelSwitched {
|
||||
from_model: config.model.clone(),
|
||||
to_model: new_model.clone(),
|
||||
reason: "prepare_next_turn".to_string(),
|
||||
});
|
||||
}
|
||||
config.model = new_model;
|
||||
}
|
||||
if let Some(temp) = update.temperature {
|
||||
config.temperature = Some(temp);
|
||||
}
|
||||
if let Some(max_tok) = update.max_completion_tokens {
|
||||
config.max_completion_tokens = Some(max_tok);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
turn_index += 1;
|
||||
|
||||
// If no tool calls, this turn is done
|
||||
if tool_call_count == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
// Otherwise, continue with tool results as new input
|
||||
current_input = turn_output.assistant_text.clone();
|
||||
}
|
||||
Err(e) => {
|
||||
// Error classification and retry with fallback
|
||||
let category = classify_error(&e, None);
|
||||
let policy = retry_policy_for(&category, config.retry_max_attempts, config.retry_base_delay_ms);
|
||||
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::ErrorClassified {
|
||||
category: format!("{category:?}"),
|
||||
message: e.to_string(),
|
||||
will_retry: policy.switch_to_fallback || policy.max_attempts > 0,
|
||||
retry_delay_ms: Some(policy.base_delay.as_millis() as u64),
|
||||
});
|
||||
}
|
||||
|
||||
if should_switch_to_fallback(&category) {
|
||||
if let Some(fallback_model) = &config.fallback_model {
|
||||
info!(
|
||||
from_model = %config.model,
|
||||
to_model = %fallback_model,
|
||||
"switching to fallback model due to error"
|
||||
);
|
||||
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::ModelSwitched {
|
||||
from_model: config.model.clone(),
|
||||
to_model: fallback_model.clone(),
|
||||
reason: format!("fallback: {category:?}"),
|
||||
});
|
||||
}
|
||||
|
||||
config.model = fallback_model.clone();
|
||||
|
||||
// Retry with the fallback model
|
||||
let retry_result = run_single_turn(
|
||||
&client,
|
||||
&config,
|
||||
¤t_input,
|
||||
&tools,
|
||||
&mut budget,
|
||||
&cancellation,
|
||||
timeout,
|
||||
&hooks,
|
||||
&event_sink,
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
|
||||
match retry_result {
|
||||
Ok(turn_output) => {
|
||||
total_input_tokens += turn_output.input_tokens;
|
||||
total_output_tokens += turn_output.output_tokens;
|
||||
let tc_count = turn_output.tool_calls.len();
|
||||
let has_tools = tc_count > 0;
|
||||
let has_text = !turn_output.assistant_text.is_empty();
|
||||
let assistant = turn_output.assistant_text;
|
||||
if has_tools || has_text {
|
||||
all_steps.push(AgentStep {
|
||||
index: all_steps.len(),
|
||||
assistant: has_text.then_some(assistant.clone()),
|
||||
reasoning_content: None,
|
||||
tool_calls: turn_output.tool_calls,
|
||||
reflection: None,
|
||||
});
|
||||
}
|
||||
turn_index += 1;
|
||||
if !has_tools {
|
||||
break;
|
||||
}
|
||||
current_input = assistant;
|
||||
continue;
|
||||
}
|
||||
Err(retry_err) => {
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::Failed {
|
||||
error: retry_err.to_string(),
|
||||
})
|
||||
.await;
|
||||
if let Some(ctx) = &request.run_context {
|
||||
let _ = hooks.run_session_end(ctx, false).await;
|
||||
}
|
||||
return Err(retry_err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Non-retryable or no fallback
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::Failed {
|
||||
error: e.to_string(),
|
||||
})
|
||||
.await;
|
||||
if let Some(ctx) = &request.run_context {
|
||||
let _ = hooks.run_session_end(ctx, false).await;
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for follow-up messages
|
||||
let follow_ups: Vec<String> = if let Some(f) = &follow_up_fn {
|
||||
f().await
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
if follow_ups.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Inject follow-up messages and continue the outer loop
|
||||
let count = follow_ups.len();
|
||||
for msg in &follow_ups {
|
||||
current_input.push_str(&format!("\nUser: {msg}\n"));
|
||||
}
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::FollowUpMessagesInjected { count });
|
||||
}
|
||||
}
|
||||
|
||||
// Build final output
|
||||
let output = all_steps
|
||||
.last()
|
||||
.and_then(|s| s.assistant.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::AgentEnd {
|
||||
messages: Vec::new(),
|
||||
total_input_tokens,
|
||||
total_output_tokens,
|
||||
});
|
||||
}
|
||||
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::Final {
|
||||
content: output.clone(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Some(ctx) = &request.run_context {
|
||||
let _ = hooks.run_session_end(ctx, true).await;
|
||||
}
|
||||
|
||||
info!(
|
||||
turns = turn_index,
|
||||
steps = all_steps.len(),
|
||||
total_input_tokens,
|
||||
total_output_tokens,
|
||||
"enhanced agent loop completed"
|
||||
);
|
||||
|
||||
Ok(AgentResult {
|
||||
output,
|
||||
steps: all_steps,
|
||||
expert_outputs: Vec::new(),
|
||||
input_tokens: total_input_tokens as i64,
|
||||
output_tokens: total_output_tokens as i64,
|
||||
})
|
||||
}
|
||||
|
||||
/// Output from a single LLM turn (one assistant response + its tool calls).
|
||||
struct TurnOutput {
|
||||
assistant_text: String,
|
||||
tool_calls: Vec<ToolCallRecord>,
|
||||
input_tokens: u64,
|
||||
output_tokens: u64,
|
||||
}
|
||||
|
||||
/// Run a single LLM turn with streaming, handling the stream parsing and
|
||||
/// tool call collection.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn run_single_turn(
|
||||
client: &rig::providers::openai::Client,
|
||||
config: &AgentConfig,
|
||||
input: &str,
|
||||
_tools: &[Box<dyn ToolDyn>],
|
||||
budget: &mut IterationBudget,
|
||||
cancellation: &Option<CancellationToken>,
|
||||
timeout: Option<std::time::Duration>,
|
||||
hooks: &HookChain,
|
||||
event_sink: &Option<EventSink>,
|
||||
tx: &mpsc::Sender<RigStreamChunk>,
|
||||
) -> AiResult<TurnOutput> {
|
||||
if !budget.consume() {
|
||||
return Err(AiError::Response("iteration budget exhausted".to_string()));
|
||||
}
|
||||
|
||||
let model = client.completion_model(&config.model);
|
||||
let mut agent_builder = AgentBuilder::new(model)
|
||||
.preamble(&config.system_prompt)
|
||||
.default_max_turns(1); // Single turn, we manage the loop
|
||||
|
||||
// Note: we can't easily pass tools here for single-turn since
|
||||
// rig's multi_turn handles tool execution internally.
|
||||
// For the enhanced loop, we rely on rig's built-in tool execution
|
||||
// within a single turn. The parallel/sequential mode is controlled
|
||||
// by the event-level hooks.
|
||||
|
||||
if let Some(temp) = config.temperature {
|
||||
agent_builder = agent_builder.temperature(temp);
|
||||
}
|
||||
if let Some(mt) = config.max_completion_tokens {
|
||||
agent_builder = agent_builder.max_tokens(mt);
|
||||
}
|
||||
|
||||
let agent = agent_builder.build();
|
||||
|
||||
// Pre-LLM hook
|
||||
if !hooks.is_empty() {
|
||||
let hook_messages = vec![HookMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(input.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
let _ = hooks.run_pre_llm_call(&hook_messages, &[]).await;
|
||||
}
|
||||
|
||||
let stream_future = agent
|
||||
.stream_prompt(input)
|
||||
.with_history(Vec::<rig::completion::Message>::new())
|
||||
.multi_turn(config.max_iterations);
|
||||
|
||||
let stream = if let Some(dur) = timeout {
|
||||
match tokio::time::timeout(dur, stream_future).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_) => {
|
||||
return Err(AiError::Timeout {
|
||||
seconds: dur.as_secs(),
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
stream_future.await
|
||||
};
|
||||
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut assistant_text = String::new();
|
||||
let mut tool_calls: Vec<ToolCallRecord> = Vec::new();
|
||||
let mut delta_index = 0usize;
|
||||
let mut _accumulated_output_chars: usize = 0;
|
||||
let mut input_tokens: u64 = 0;
|
||||
let mut output_tokens: u64 = 0;
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
if cancellation.as_ref().is_some_and(|ct| ct.is_cancelled()) {
|
||||
return Err(AiError::Response("cancelled".to_string()));
|
||||
}
|
||||
|
||||
match item {
|
||||
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
||||
rig::streaming::StreamedAssistantContent::Text(text),
|
||||
)) => {
|
||||
_accumulated_output_chars += text.text.chars().count();
|
||||
assistant_text.push_str(&text.text);
|
||||
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::MessageTextDelta {
|
||||
index: delta_index,
|
||||
delta: text.text.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::TextDelta {
|
||||
index: delta_index,
|
||||
content: text.text.clone(),
|
||||
})
|
||||
.await;
|
||||
delta_index += 1;
|
||||
}
|
||||
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
||||
rig::streaming::StreamedAssistantContent::Reasoning(reasoning),
|
||||
)) => {
|
||||
for part in &reasoning.content {
|
||||
if let rig::completion::message::ReasoningContent::Text { text, .. } = part {
|
||||
_accumulated_output_chars += text.chars().count();
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::MessageThinkingDelta {
|
||||
index: delta_index,
|
||||
delta: text.clone(),
|
||||
});
|
||||
}
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::Thinking {
|
||||
index: delta_index,
|
||||
content: text.clone(),
|
||||
})
|
||||
.await;
|
||||
delta_index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
||||
rig::streaming::StreamedAssistantContent::ReasoningDelta { reasoning, .. },
|
||||
)) => {
|
||||
_accumulated_output_chars += reasoning.chars().count();
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::MessageThinkingDelta {
|
||||
index: delta_index,
|
||||
delta: reasoning.clone(),
|
||||
});
|
||||
}
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::Thinking {
|
||||
index: delta_index,
|
||||
content: reasoning.clone(),
|
||||
})
|
||||
.await;
|
||||
delta_index += 1;
|
||||
}
|
||||
Ok(rig::agent::MultiTurnStreamItem::StreamAssistantItem(
|
||||
rig::streaming::StreamedAssistantContent::ToolCall { tool_call, .. },
|
||||
)) => {
|
||||
let args = match &tool_call.function.arguments {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
v => serde_json::to_string(v).unwrap_or_default(),
|
||||
};
|
||||
_accumulated_output_chars += args.chars().count();
|
||||
|
||||
let tool_name = tool_call.function.name.clone();
|
||||
let tool_args: serde_json::Value =
|
||||
serde_json::from_str(&args).unwrap_or_default();
|
||||
|
||||
// Pre-tool-call guardrail hook
|
||||
if let Ok(Some(decision)) = hooks.run_pre_tool_call(&tool_name, &tool_args).await {
|
||||
match decision {
|
||||
ToolGuardrailDecision::Allow => {}
|
||||
ToolGuardrailDecision::Block { reason } => {
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::ToolExecutionEnd {
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_name.clone(),
|
||||
output: None,
|
||||
error: Some(reason.clone()),
|
||||
elapsed_ms: 0,
|
||||
});
|
||||
}
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::ToolCallFinished {
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_name.clone(),
|
||||
output: format!("blocked: {reason}"),
|
||||
error: Some(reason),
|
||||
})
|
||||
.await;
|
||||
tool_calls.push(ToolCallRecord {
|
||||
id: tool_call.id.clone(),
|
||||
name: tool_name,
|
||||
arguments: tool_args,
|
||||
output: None,
|
||||
error: Some("blocked by guardrail".to_string()),
|
||||
elapsed_ms: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
ToolGuardrailDecision::RequireApproval { message } => {
|
||||
tool_calls.push(ToolCallRecord {
|
||||
id: tool_call.id.clone(),
|
||||
name: tool_name.clone(),
|
||||
arguments: tool_args,
|
||||
output: None,
|
||||
error: Some(format!("requires approval: {message}")),
|
||||
elapsed_ms: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::ToolExecutionStart {
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_name.clone(),
|
||||
arguments: tool_args.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::ToolCallStarted {
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
tool_name: tool_name.clone(),
|
||||
arguments: args.clone(),
|
||||
})
|
||||
.await;
|
||||
tool_calls.push(ToolCallRecord {
|
||||
id: tool_call.id.clone(),
|
||||
name: tool_name,
|
||||
arguments: tool_args,
|
||||
output: None,
|
||||
error: None,
|
||||
elapsed_ms: None,
|
||||
});
|
||||
}
|
||||
Ok(rig::agent::MultiTurnStreamItem::StreamUserItem(
|
||||
rig::streaming::StreamedUserContent::ToolResult { tool_result, .. },
|
||||
)) => {
|
||||
let content =
|
||||
super::helpers::tool_result_content_to_string(&tool_result.content);
|
||||
_accumulated_output_chars += content.chars().count();
|
||||
|
||||
let tool_name = tool_calls
|
||||
.last()
|
||||
.map(|tc| tc.name.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
if let Some(last) = tool_calls.last_mut()
|
||||
&& last.id == tool_result.id
|
||||
{
|
||||
last.output = Some(serde_json::from_str(&content).unwrap_or_default());
|
||||
}
|
||||
|
||||
if let Some(sink) = &event_sink {
|
||||
sink.emit(AgentEvent::ToolExecutionEnd {
|
||||
tool_call_id: tool_result.id.clone(),
|
||||
tool_name: tool_name.clone(),
|
||||
output: Some(serde_json::Value::String(content.clone())),
|
||||
error: None,
|
||||
elapsed_ms: 0,
|
||||
});
|
||||
}
|
||||
|
||||
let _ = tx
|
||||
.send(RigStreamChunk::ToolCallFinished {
|
||||
tool_call_id: tool_result.id.clone(),
|
||||
tool_name,
|
||||
output: content.clone(),
|
||||
error: None,
|
||||
})
|
||||
.await;
|
||||
|
||||
if !hooks.is_empty() {
|
||||
let outcome = ToolCallOutcome {
|
||||
name: tool_result.id.clone(),
|
||||
arguments: serde_json::Value::Null,
|
||||
output: Some(serde_json::Value::String(content)),
|
||||
error: None,
|
||||
elapsed_ms: 0,
|
||||
};
|
||||
let _ = hooks.run_post_tool_call(&outcome).await;
|
||||
}
|
||||
}
|
||||
Ok(rig::agent::MultiTurnStreamItem::FinalResponse(resp)) => {
|
||||
let usage = resp.usage();
|
||||
input_tokens = usage.input_tokens;
|
||||
output_tokens = usage.output_tokens;
|
||||
|
||||
if !hooks.is_empty() {
|
||||
let hook_response = HookLlmResponse {
|
||||
content: Some(assistant_text.clone()),
|
||||
tool_calls: None,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
finish_reason: None,
|
||||
};
|
||||
let _ = hooks.run_post_llm_call(&hook_response).await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "turn stream error");
|
||||
return Err(AiError::Api(format!("{e}")));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(TurnOutput {
|
||||
assistant_text,
|
||||
tool_calls,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
98
lib/ai/agent/mod.rs
Normal file
98
lib/ai/agent/mod.rs
Normal file
@ -0,0 +1,98 @@
|
||||
pub mod agent;
|
||||
pub mod compression;
|
||||
pub mod config;
|
||||
pub mod error_classifier;
|
||||
pub mod events;
|
||||
pub mod helpers;
|
||||
pub mod hooks;
|
||||
pub mod iteration_budget;
|
||||
pub mod r#loop;
|
||||
pub mod persistence;
|
||||
pub mod prompt;
|
||||
pub mod prompt_builder;
|
||||
pub mod request;
|
||||
pub mod session;
|
||||
pub mod subagent;
|
||||
pub mod tool;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use agent::RigAgent;
|
||||
pub use compression::{
|
||||
CompactionResult, CompressionStrategy, build_branch_summary_prompt,
|
||||
build_compression_prompt, build_compression_prompt_with_options,
|
||||
estimate_truncation, plan_compaction,
|
||||
};
|
||||
pub use config::AgentConfig;
|
||||
pub use error_classifier::{
|
||||
ErrorCategory, RetryPolicy, classify_error, retry_policy_for,
|
||||
should_compact_before_retry, should_switch_to_fallback,
|
||||
};
|
||||
pub use events::{
|
||||
AgentEvent, AgentEventMessage, EventSink, EventToolCall, MessageRole,
|
||||
};
|
||||
pub use hooks::{AgentHook, HookChain, ToolCallOutcome, ToolGuardrailDecision};
|
||||
pub use iteration_budget::IterationBudget;
|
||||
pub use r#loop::{
|
||||
AgentLoopConfig, EnhancedAgent, PrepareNextTurnFn, ShouldStopFn,
|
||||
ToolExecutionMode, TurnContext, TurnUpdate,
|
||||
};
|
||||
pub use persistence::{
|
||||
AgentRealtime, AgentRunContext, AgentRuntime, AgentStreamEvent,
|
||||
};
|
||||
pub use prompt_builder::SystemPromptBuilder;
|
||||
pub use request::{
|
||||
AgentContextChunk, AgentExpert, AgentExpertOutput, AgentMessage,
|
||||
AgentRequest, AgentResult, AgentStep, ToolCallRecord,
|
||||
};
|
||||
pub use session::{
|
||||
CompactionOptions, Session, SessionEntry, SessionHeader,
|
||||
SessionMessageRole, SessionToolCall, SessionToolResult,
|
||||
};
|
||||
pub use tool::{RigTool, RigToolSet};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum RigStreamChunk {
|
||||
TextDelta {
|
||||
index: usize,
|
||||
content: String,
|
||||
},
|
||||
Thinking {
|
||||
index: usize,
|
||||
content: String,
|
||||
},
|
||||
ToolCallStarted {
|
||||
tool_call_id: String,
|
||||
tool_name: String,
|
||||
arguments: String,
|
||||
},
|
||||
ToolCallFinished {
|
||||
tool_call_id: String,
|
||||
tool_name: String,
|
||||
output: String,
|
||||
error: Option<String>,
|
||||
},
|
||||
SubagentStarted {
|
||||
subagent_id: String,
|
||||
role: String,
|
||||
task: String,
|
||||
},
|
||||
SubagentCompleted {
|
||||
subagent_id: String,
|
||||
role: String,
|
||||
task: String,
|
||||
output: String,
|
||||
},
|
||||
SubagentFailed {
|
||||
error: String,
|
||||
},
|
||||
Final {
|
||||
content: String,
|
||||
input_tokens: u64,
|
||||
output_tokens: u64,
|
||||
},
|
||||
Failed {
|
||||
error: String,
|
||||
},
|
||||
}
|
||||
33
lib/ai/agent/persistence/db.rs
Normal file
33
lib/ai/agent/persistence/db.rs
Normal file
@ -0,0 +1,33 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::agent::persistence::types::{ActiveAgentRun, AgentRunContext};
|
||||
use crate::error::AiResult;
|
||||
|
||||
impl super::types::AgentRuntime {
|
||||
pub async fn start_run(
|
||||
&self,
|
||||
run_context: Option<&AgentRunContext>,
|
||||
) -> AiResult<ActiveAgentRun> {
|
||||
let Some(run_context) = run_context else {
|
||||
return Ok(ActiveAgentRun {
|
||||
conversation_id: None,
|
||||
message_id: None,
|
||||
invocation_id: None,
|
||||
session_id: None,
|
||||
user_id: None,
|
||||
started_at: Instant::now(),
|
||||
current_step: 0,
|
||||
});
|
||||
};
|
||||
|
||||
Ok(ActiveAgentRun {
|
||||
conversation_id: run_context.conversation_id,
|
||||
message_id: None,
|
||||
invocation_id: run_context.invocation_id,
|
||||
session_id: run_context.session_id,
|
||||
user_id: run_context.user_id,
|
||||
started_at: Instant::now(),
|
||||
current_step: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
8
lib/ai/agent/persistence/mod.rs
Normal file
8
lib/ai/agent/persistence/mod.rs
Normal file
@ -0,0 +1,8 @@
|
||||
pub mod db;
|
||||
pub mod realtime;
|
||||
pub mod types;
|
||||
|
||||
pub use types::{
|
||||
ActiveAgentRun, AgentRealtime, AgentRunContext, AgentRuntime,
|
||||
AgentStreamEvent, estimate_output_tokens,
|
||||
};
|
||||
38
lib/ai/agent/persistence/realtime.rs
Normal file
38
lib/ai/agent/persistence/realtime.rs
Normal file
@ -0,0 +1,38 @@
|
||||
use crate::agent::persistence::types::{
|
||||
AgentRealtime, AgentRuntime, AgentStreamEvent,
|
||||
};
|
||||
use crate::error::AiResult;
|
||||
|
||||
pub async fn publish_event(
|
||||
runtime: &AgentRuntime,
|
||||
_realtime: Option<&AgentRealtime>,
|
||||
event: &AgentStreamEvent,
|
||||
) -> AiResult<()> {
|
||||
let Some(tx) = &runtime.tx else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let payload = match serde_json::to_string(event) {
|
||||
Ok(p) => p,
|
||||
Err(error) => {
|
||||
tracing::warn!(error = %error, "agent sse: serialize failed");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
if tx.send(payload).is_err() {
|
||||
tracing::debug!("agent sse: mpsc send failed, client disconnected");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl AgentRuntime {
|
||||
pub async fn publish(
|
||||
&self,
|
||||
realtime: Option<&AgentRealtime>,
|
||||
event: &AgentStreamEvent,
|
||||
) -> AiResult<()> {
|
||||
publish_event(self, realtime, event).await
|
||||
}
|
||||
}
|
||||
177
lib/ai/agent/persistence/types.rs
Normal file
177
lib/ai/agent/persistence/types.rs
Normal file
@ -0,0 +1,177 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AgentRealtime {
|
||||
pub channel: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AgentRunContext {
|
||||
pub conversation_id: Option<Uuid>,
|
||||
pub invocation_id: Option<Uuid>,
|
||||
pub session_id: Option<Uuid>,
|
||||
pub user_id: Option<Uuid>,
|
||||
pub realtime: Option<AgentRealtime>,
|
||||
}
|
||||
|
||||
impl AgentRunContext {
|
||||
pub fn new(user_id: Uuid) -> Self {
|
||||
Self {
|
||||
conversation_id: None,
|
||||
invocation_id: None,
|
||||
session_id: None,
|
||||
user_id: Some(user_id),
|
||||
realtime: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_conversation_id(mut self, id: Uuid) -> Self {
|
||||
self.conversation_id = Some(id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_invocation_id(mut self, id: Uuid) -> Self {
|
||||
self.invocation_id = Some(id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_session_id(mut self, id: Uuid) -> Self {
|
||||
self.session_id = Some(id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_realtime(mut self, realtime: AgentRealtime) -> Self {
|
||||
self.realtime = Some(realtime);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum AgentStreamEvent {
|
||||
Started {
|
||||
conversation_id: Option<Uuid>,
|
||||
message_id: Option<Uuid>,
|
||||
session_id: Option<Uuid>,
|
||||
model: String,
|
||||
},
|
||||
Delta {
|
||||
conversation_id: Option<Uuid>,
|
||||
message_id: Option<Uuid>,
|
||||
index: usize,
|
||||
content: String,
|
||||
},
|
||||
Thinking {
|
||||
conversation_id: Option<Uuid>,
|
||||
message_id: Option<Uuid>,
|
||||
index: usize,
|
||||
content: String,
|
||||
},
|
||||
ToolCallStarted {
|
||||
conversation_id: Option<Uuid>,
|
||||
message_id: Option<Uuid>,
|
||||
session_id: Option<Uuid>,
|
||||
tool_call_id: String,
|
||||
tool_name: String,
|
||||
arguments: Value,
|
||||
},
|
||||
ToolCallFinished {
|
||||
conversation_id: Option<Uuid>,
|
||||
message_id: Option<Uuid>,
|
||||
session_id: Option<Uuid>,
|
||||
tool_call_id: String,
|
||||
tool_name: String,
|
||||
output: Option<Value>,
|
||||
error: Option<String>,
|
||||
execution_time_ms: i64,
|
||||
},
|
||||
SubagentStarted {
|
||||
conversation_id: Option<Uuid>,
|
||||
message_id: Option<Uuid>,
|
||||
subagent_id: String,
|
||||
role: String,
|
||||
task: String,
|
||||
model: String,
|
||||
},
|
||||
SubagentDelta {
|
||||
conversation_id: Option<Uuid>,
|
||||
message_id: Option<Uuid>,
|
||||
subagent_id: String,
|
||||
index: usize,
|
||||
content: String,
|
||||
},
|
||||
SubagentCompleted {
|
||||
conversation_id: Option<Uuid>,
|
||||
message_id: Option<Uuid>,
|
||||
subagent_id: String,
|
||||
role: String,
|
||||
task: String,
|
||||
output: String,
|
||||
input_tokens: i64,
|
||||
output_tokens: i64,
|
||||
model: String,
|
||||
},
|
||||
SubagentFailed {
|
||||
conversation_id: Option<Uuid>,
|
||||
message_id: Option<Uuid>,
|
||||
subagent_id: String,
|
||||
error: String,
|
||||
},
|
||||
Completed {
|
||||
conversation_id: Option<Uuid>,
|
||||
message_id: Option<Uuid>,
|
||||
session_id: Option<Uuid>,
|
||||
output: String,
|
||||
input_tokens: i64,
|
||||
output_tokens: i64,
|
||||
latency_ms: i32,
|
||||
stop_reason: Option<String>,
|
||||
},
|
||||
Failed {
|
||||
conversation_id: Option<Uuid>,
|
||||
message_id: Option<Uuid>,
|
||||
session_id: Option<Uuid>,
|
||||
error: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AgentRuntime {
|
||||
pub tx: Option<mpsc::UnboundedSender<String>>,
|
||||
}
|
||||
|
||||
impl AgentRuntime {
|
||||
pub fn new(tx: mpsc::UnboundedSender<String>) -> Self {
|
||||
Self { tx: Some(tx) }
|
||||
}
|
||||
|
||||
pub fn empty() -> Self {
|
||||
Self { tx: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AgentRuntime {
|
||||
fn default() -> Self {
|
||||
Self::empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ActiveAgentRun {
|
||||
pub conversation_id: Option<Uuid>,
|
||||
pub message_id: Option<Uuid>,
|
||||
pub invocation_id: Option<Uuid>,
|
||||
pub session_id: Option<Uuid>,
|
||||
pub user_id: Option<Uuid>,
|
||||
pub started_at: Instant,
|
||||
pub current_step: usize,
|
||||
}
|
||||
|
||||
pub fn estimate_output_tokens(output: &str) -> i64 {
|
||||
(output.chars().count() as f64 / 4.0).ceil() as i64
|
||||
}
|
||||
57
lib/ai/agent/prompt.rs
Normal file
57
lib/ai/agent/prompt.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use rig::agent::AgentBuilder;
|
||||
use rig::client::CompletionClient;
|
||||
use rig::completion::Prompt;
|
||||
|
||||
use super::agent::RigAgent;
|
||||
use super::helpers::with_retry;
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
impl RigAgent {
|
||||
pub async fn prompt(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
user_input: &str,
|
||||
) -> AiResult<(String, u64, u64)> {
|
||||
let model_name = self.config.model.clone();
|
||||
let client = self.client.llm_client().clone();
|
||||
let temperature = self.config.temperature;
|
||||
let max_completion_tokens = self.config.max_completion_tokens;
|
||||
let retry_attempts = self.config.retry_max_attempts;
|
||||
let retry_delay_ms = self.config.retry_base_delay_ms;
|
||||
let sp = system_prompt.to_string();
|
||||
let ui = user_input.to_string();
|
||||
|
||||
with_retry(retry_attempts, retry_delay_ms, || {
|
||||
let client = client.clone();
|
||||
let model_name = model_name.clone();
|
||||
let sp = sp.clone();
|
||||
let ui = ui.clone();
|
||||
async move {
|
||||
let model = client.completion_model(&model_name);
|
||||
let mut builder = AgentBuilder::new(model).preamble(&sp);
|
||||
if let Some(temp) = temperature {
|
||||
builder = builder.temperature(temp);
|
||||
}
|
||||
if let Some(mt) = max_completion_tokens {
|
||||
builder = builder.max_tokens(mt);
|
||||
}
|
||||
let agent = builder.build();
|
||||
|
||||
let response = agent
|
||||
.prompt(&ui)
|
||||
.extended_details()
|
||||
.await
|
||||
.map_err(|e: rig::completion::PromptError| {
|
||||
AiError::Api(e.to_string())
|
||||
})?;
|
||||
|
||||
Ok((
|
||||
response.output,
|
||||
response.usage.input_tokens,
|
||||
response.usage.output_tokens,
|
||||
))
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
}
|
||||
251
lib/ai/agent/prompt_builder.rs
Normal file
251
lib/ai/agent/prompt_builder.rs
Normal file
@ -0,0 +1,251 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Modular system prompt builder inspired by pi's `buildSystemPrompt`.
|
||||
///
|
||||
/// Supports:
|
||||
/// - Base prompt (replaceable or appendable)
|
||||
/// - Tool snippets injected into an "Available tools" section
|
||||
/// - Project context files (AGENTS.md, etc.)
|
||||
/// - Skills injection
|
||||
/// - Variable substitution ({{key}})
|
||||
/// - Metadata (date)
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use ai::agent::prompt_builder::SystemPromptBuilder;
|
||||
///
|
||||
/// let prompt = SystemPromptBuilder::new()
|
||||
/// .base_prompt("You are a helpful assistant.")
|
||||
/// .tool_snippet("bash", "Execute shell commands")
|
||||
/// .tool_snippet("read", "Read file contents")
|
||||
/// .project_context("AGENTS.md", "# Project Rules\n- Follow conventions")
|
||||
/// .variable("repo_name", "gitdataai")
|
||||
/// .build();
|
||||
/// ```
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SystemPromptBuilder {
|
||||
base_prompt: Option<String>,
|
||||
append_prompt: Option<String>,
|
||||
tool_snippets: Vec<(String, String)>,
|
||||
tool_guidelines: Vec<String>,
|
||||
project_contexts: Vec<(String, String)>,
|
||||
skills: Vec<String>,
|
||||
variables: HashMap<String, String>,
|
||||
date: Option<String>,
|
||||
custom_sections: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl SystemPromptBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
base_prompt: None,
|
||||
append_prompt: None,
|
||||
tool_snippets: Vec::new(),
|
||||
tool_guidelines: Vec::new(),
|
||||
project_contexts: Vec::new(),
|
||||
skills: Vec::new(),
|
||||
variables: HashMap::new(),
|
||||
date: None,
|
||||
custom_sections: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the base system prompt. Replaces the default prompt.
|
||||
pub fn base_prompt(mut self, prompt: impl Into<String>) -> Self {
|
||||
self.base_prompt = Some(prompt.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Append additional text to the system prompt after the base.
|
||||
pub fn append_prompt(mut self, text: impl Into<String>) -> Self {
|
||||
self.append_prompt = Some(text.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a one-line tool description snippet.
|
||||
pub fn tool_snippet(mut self, tool_name: impl Into<String>, description: impl Into<String>) -> Self {
|
||||
self.tool_snippets.push((tool_name.into(), description.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a guideline bullet for the tools section.
|
||||
pub fn tool_guideline(mut self, guideline: impl Into<String>) -> Self {
|
||||
self.tool_guidelines.push(guideline.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a project context file (e.g., AGENTS.md content).
|
||||
pub fn project_context(mut self, path: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
self.project_contexts.push((path.into(), content.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a skill description to inject into the prompt.
|
||||
pub fn skill(mut self, skill_description: impl Into<String>) -> Self {
|
||||
self.skills.push(skill_description.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a variable for {{key}} substitution.
|
||||
pub fn variable(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||
self.variables.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set multiple variables from an iterator.
|
||||
pub fn variables(mut self, vars: impl IntoIterator<Item = (String, String)>) -> Self {
|
||||
self.variables.extend(vars);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the date metadata (ISO format: YYYY-MM-DD).
|
||||
pub fn date(mut self, date: impl Into<String>) -> Self {
|
||||
self.date = Some(date.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a custom named section to the prompt.
|
||||
pub fn custom_section(mut self, name: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
self.custom_sections.push((name.into(), content.into()));
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the final system prompt string.
|
||||
pub fn build(self) -> String {
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
|
||||
// 1. Base prompt
|
||||
if let Some(base) = &self.base_prompt {
|
||||
parts.push(base.clone());
|
||||
}
|
||||
|
||||
// 2. Append prompt
|
||||
if let Some(append) = &self.append_prompt {
|
||||
parts.push(append.clone());
|
||||
}
|
||||
|
||||
// 3. Tool snippets section
|
||||
if !self.tool_snippets.is_empty() {
|
||||
let mut section = String::from("\n## Available Tools\n");
|
||||
for (name, desc) in &self.tool_snippets {
|
||||
section.push_str(&format!("- `{name}`: {desc}\n"));
|
||||
}
|
||||
if !self.tool_guidelines.is_empty() {
|
||||
section.push_str("\n### Tool Guidelines\n");
|
||||
for guideline in &self.tool_guidelines {
|
||||
section.push_str(&format!("- {guideline}\n"));
|
||||
}
|
||||
}
|
||||
parts.push(section);
|
||||
}
|
||||
|
||||
// 4. Project context files
|
||||
if !self.project_contexts.is_empty() {
|
||||
let mut section = String::from("\n<project_context>\n\n");
|
||||
section.push_str("Project-specific instructions and guidelines:\n\n");
|
||||
for (path, content) in &self.project_contexts {
|
||||
section.push_str(&format!("<project_instructions path=\"{path}\">\n{content}\n</project_instructions>\n\n"));
|
||||
}
|
||||
section.push_str("</project_context>");
|
||||
parts.push(section);
|
||||
}
|
||||
|
||||
// 5. Skills section
|
||||
if !self.skills.is_empty() {
|
||||
let mut section = String::from("\n## Available Skills\n");
|
||||
for skill in &self.skills {
|
||||
section.push_str(&format!("{skill}\n"));
|
||||
}
|
||||
parts.push(section);
|
||||
}
|
||||
|
||||
// 6. Custom sections
|
||||
for (name, content) in &self.custom_sections {
|
||||
parts.push(format!("\n## {name}\n{content}"));
|
||||
}
|
||||
|
||||
// 7. Metadata footer
|
||||
if let Some(date) = &self.date {
|
||||
parts.push(format!("\nCurrent date: {date}"));
|
||||
}
|
||||
|
||||
let mut result = parts.join("\n");
|
||||
|
||||
// 8. Variable substitution
|
||||
for (key, value) in &self.variables {
|
||||
let placeholder = format!("{{{{{}}}}}", key);
|
||||
result = result.replace(&placeholder, value);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SystemPromptBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_basic_build() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.base_prompt("You are a helpful assistant.")
|
||||
.date("2026-05-29")
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("You are a helpful assistant."));
|
||||
assert!(prompt.contains("Current date: 2026-05-29"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_variable_substitution() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.base_prompt("Repo: {{repo_name}}, User: {{user}}")
|
||||
.variable("repo_name", "gitdataai")
|
||||
.variable("user", "zhenyi")
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("Repo: gitdataai"));
|
||||
assert!(prompt.contains("User: zhenyi"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_snippets() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.base_prompt("Agent prompt.")
|
||||
.tool_snippet("bash", "Execute shell commands")
|
||||
.tool_snippet("read", "Read file contents")
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("## Available Tools"));
|
||||
assert!(prompt.contains("`bash`: Execute shell commands"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_project_context() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.base_prompt("Base.")
|
||||
.project_context("AGENTS.md", "# Rules\n- Follow conventions")
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("<project_context>"));
|
||||
assert!(prompt.contains("AGENTS.md"));
|
||||
assert!(prompt.contains("Follow conventions"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_section() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.base_prompt("Base.")
|
||||
.custom_section("Memory", "Remember: user prefers Rust")
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("## Memory"));
|
||||
assert!(prompt.contains("user prefers Rust"));
|
||||
}
|
||||
}
|
||||
240
lib/ai/agent/request.rs
Normal file
240
lib/ai/agent/request.rs
Normal file
@ -0,0 +1,240 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::persistence::AgentRunContext;
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AgentRequest {
|
||||
pub input: String,
|
||||
pub messages: Vec<AgentMessage>,
|
||||
pub context: Vec<AgentContextChunk>,
|
||||
pub experts: Vec<AgentExpert>,
|
||||
pub run_context: Option<AgentRunContext>,
|
||||
#[serde(skip)]
|
||||
pub prefill_messages: Vec<rig::completion::Message>,
|
||||
#[serde(skip)]
|
||||
pub cancellation_token: Option<CancellationToken>,
|
||||
#[serde(skip)]
|
||||
pub timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl AgentRequest {
|
||||
pub fn new(input: impl Into<String>) -> Self {
|
||||
Self {
|
||||
input: input.into(),
|
||||
messages: Vec::new(),
|
||||
context: Vec::new(),
|
||||
experts: Vec::new(),
|
||||
run_context: None,
|
||||
prefill_messages: Vec::new(),
|
||||
cancellation_token: None,
|
||||
timeout: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> AiResult<()> {
|
||||
if self.input.trim().is_empty() {
|
||||
return Err(AiError::Config("agent request input is required".to_string()));
|
||||
}
|
||||
if self.input.len() > 1_000_000 {
|
||||
return Err(AiError::Config(
|
||||
"agent request input exceeds maximum length (1MB)".to_string(),
|
||||
));
|
||||
}
|
||||
if self.experts.len() > 32 {
|
||||
return Err(AiError::Config(
|
||||
"agent request experts count exceeds maximum (32)".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn with_messages(mut self, messages: Vec<AgentMessage>) -> Self {
|
||||
self.messages = messages;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_context(mut self, context: Vec<AgentContextChunk>) -> Self {
|
||||
self.context = context;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn add_context(mut self, chunk: AgentContextChunk) -> Self {
|
||||
self.context.push(chunk);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_experts(mut self, experts: Vec<AgentExpert>) -> Self {
|
||||
self.experts = experts;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn add_expert(mut self, expert: AgentExpert) -> Self {
|
||||
self.experts.push(expert);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_run_context(mut self, run_context: AgentRunContext) -> Self {
|
||||
self.run_context = Some(run_context);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_prefill_messages(mut self, prefill_messages: Vec<rig::completion::Message>) -> Self {
|
||||
self.prefill_messages = prefill_messages;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_cancellation_token(mut self, cancellation_token: CancellationToken) -> Self {
|
||||
self.cancellation_token = Some(cancellation_token);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.timeout = Some(timeout);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum AgentMessage {
|
||||
User(String),
|
||||
Assistant(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AgentExpert {
|
||||
pub id: String,
|
||||
pub role: String,
|
||||
pub task: String,
|
||||
pub system_prompt: Option<String>,
|
||||
pub context: Vec<AgentContextChunk>,
|
||||
/// Override the master agent's temperature for this subagent.
|
||||
pub temperature: Option<f64>,
|
||||
/// Override the master agent's max_completion_tokens for this subagent.
|
||||
pub max_completion_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
impl AgentExpert {
|
||||
pub fn new(id: impl Into<String>, role: impl Into<String>, task: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
role: role.into(),
|
||||
task: task.into(),
|
||||
system_prompt: None,
|
||||
context: Vec::new(),
|
||||
temperature: None,
|
||||
max_completion_tokens: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
|
||||
self.system_prompt = Some(system_prompt.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_context(mut self, context: Vec<AgentContextChunk>) -> Self {
|
||||
self.context = context;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temperature: f64) -> Self {
|
||||
self.temperature = Some(temperature);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_completion_tokens(mut self, max_tokens: u64) -> Self {
|
||||
self.max_completion_tokens = Some(max_tokens);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AgentContextChunk {
|
||||
pub id: String,
|
||||
pub content: String,
|
||||
pub source: Option<String>,
|
||||
pub score: Option<f32>,
|
||||
pub metadata: Value,
|
||||
}
|
||||
|
||||
impl AgentContextChunk {
|
||||
pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
content: content.into(),
|
||||
source: None,
|
||||
score: None,
|
||||
metadata: Value::Null,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::rag::RagSearchHit> for AgentContextChunk {
|
||||
fn from(hit: crate::rag::RagSearchHit) -> Self {
|
||||
Self {
|
||||
id: hit.id,
|
||||
content: hit.content,
|
||||
source: Some(hit.session_id),
|
||||
score: Some(hit.score),
|
||||
metadata: Value::Object(hit.metadata.into_iter().collect()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&AgentExpertOutput> for AgentContextChunk {
|
||||
fn from(output: &AgentExpertOutput) -> Self {
|
||||
Self {
|
||||
id: format!("subagent:{}", output.id),
|
||||
content: output.output.clone(),
|
||||
source: Some(output.role.clone()),
|
||||
score: None,
|
||||
metadata: serde_json::json!({
|
||||
"kind": "subagent",
|
||||
"task": output.task,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AgentResult {
|
||||
pub output: String,
|
||||
pub steps: Vec<AgentStep>,
|
||||
pub expert_outputs: Vec<AgentExpertOutput>,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AgentStep {
|
||||
pub index: usize,
|
||||
pub assistant: Option<String>,
|
||||
pub reasoning_content: Option<String>,
|
||||
pub tool_calls: Vec<ToolCallRecord>,
|
||||
pub reflection: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ToolCallRecord {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: Value,
|
||||
pub output: Option<Value>,
|
||||
pub error: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub elapsed_ms: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AgentExpertOutput {
|
||||
pub id: String,
|
||||
pub role: String,
|
||||
pub task: String,
|
||||
pub output: String,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
}
|
||||
535
lib/ai/agent/session.rs
Normal file
535
lib/ai/agent/session.rs
Normal file
@ -0,0 +1,535 @@
|
||||
use std::time::SystemTime;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
/// Current session file format version.
|
||||
pub const SESSION_VERSION: u32 = 2;
|
||||
|
||||
/// Session metadata header.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionHeader {
|
||||
pub version: u32,
|
||||
pub id: Uuid,
|
||||
pub created_at: String,
|
||||
pub parent_session: Option<Uuid>,
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
impl SessionHeader {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
version: SESSION_VERSION,
|
||||
id: Uuid::new_v4(),
|
||||
created_at: iso_now(),
|
||||
parent_session: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_parent(mut self, parent: Uuid) -> Self {
|
||||
self.parent_session = Some(parent);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.name = Some(name.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Typed session entry — each entry in a session transcript is one of these variants.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum SessionEntry {
|
||||
/// A user or assistant message.
|
||||
Message {
|
||||
id: Uuid,
|
||||
parent_id: Option<Uuid>,
|
||||
timestamp: String,
|
||||
role: SessionMessageRole,
|
||||
content: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<SessionToolCall>>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
tool_result: Option<SessionToolResult>,
|
||||
},
|
||||
|
||||
/// A context compaction event (older messages summarized).
|
||||
Compaction {
|
||||
id: Uuid,
|
||||
parent_id: Option<Uuid>,
|
||||
timestamp: String,
|
||||
summary: String,
|
||||
first_kept_entry_id: Uuid,
|
||||
messages_compacted: usize,
|
||||
tokens_saved: i64,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
details: Option<Value>,
|
||||
},
|
||||
|
||||
/// A branch summary (created when forking from a different point in the tree).
|
||||
BranchSummary {
|
||||
id: Uuid,
|
||||
parent_id: Option<Uuid>,
|
||||
timestamp: String,
|
||||
from_entry_id: Uuid,
|
||||
summary: String,
|
||||
entries_summarized: usize,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
label: Option<String>,
|
||||
},
|
||||
|
||||
/// Model change during a session.
|
||||
ModelChange {
|
||||
id: Uuid,
|
||||
parent_id: Option<Uuid>,
|
||||
timestamp: String,
|
||||
provider: String,
|
||||
model_id: String,
|
||||
},
|
||||
|
||||
/// Thinking level change during a session.
|
||||
ThinkingLevelChange {
|
||||
id: Uuid,
|
||||
parent_id: Option<Uuid>,
|
||||
timestamp: String,
|
||||
level: String,
|
||||
},
|
||||
|
||||
/// Custom extension data (not sent to LLM).
|
||||
Custom {
|
||||
id: Uuid,
|
||||
parent_id: Option<Uuid>,
|
||||
timestamp: String,
|
||||
custom_type: String,
|
||||
data: Option<Value>,
|
||||
},
|
||||
}
|
||||
|
||||
impl SessionEntry {
|
||||
pub fn id(&self) -> Uuid {
|
||||
match self {
|
||||
Self::Message { id, .. }
|
||||
| Self::Compaction { id, .. }
|
||||
| Self::BranchSummary { id, .. }
|
||||
| Self::ModelChange { id, .. }
|
||||
| Self::ThinkingLevelChange { id, .. }
|
||||
| Self::Custom { id, .. } => *id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parent_id(&self) -> Option<Uuid> {
|
||||
match self {
|
||||
Self::Message { parent_id, .. }
|
||||
| Self::Compaction { parent_id, .. }
|
||||
| Self::BranchSummary { parent_id, .. }
|
||||
| Self::ModelChange { parent_id, .. }
|
||||
| Self::ThinkingLevelChange { parent_id, .. }
|
||||
| Self::Custom { parent_id, .. } => *parent_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn timestamp(&self) -> &str {
|
||||
match self {
|
||||
Self::Message { timestamp, .. }
|
||||
| Self::Compaction { timestamp, .. }
|
||||
| Self::BranchSummary { timestamp, .. }
|
||||
| Self::ModelChange { timestamp, .. }
|
||||
| Self::ThinkingLevelChange { timestamp, .. }
|
||||
| Self::Custom { timestamp, .. } => timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a user message entry.
|
||||
pub fn user_message(parent_id: Option<Uuid>, content: impl Into<String>) -> Self {
|
||||
Self::Message {
|
||||
id: Uuid::new_v4(),
|
||||
parent_id,
|
||||
timestamp: iso_now(),
|
||||
role: SessionMessageRole::User,
|
||||
content: content.into(),
|
||||
tool_calls: None,
|
||||
tool_result: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an assistant message entry.
|
||||
pub fn assistant_message(
|
||||
parent_id: Option<Uuid>,
|
||||
content: impl Into<String>,
|
||||
tool_calls: Option<Vec<SessionToolCall>>,
|
||||
) -> Self {
|
||||
Self::Message {
|
||||
id: Uuid::new_v4(),
|
||||
parent_id,
|
||||
timestamp: iso_now(),
|
||||
role: SessionMessageRole::Assistant,
|
||||
content: content.into(),
|
||||
tool_calls,
|
||||
tool_result: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a compaction entry.
|
||||
pub fn compaction(
|
||||
parent_id: Option<Uuid>,
|
||||
summary: impl Into<String>,
|
||||
first_kept_entry_id: Uuid,
|
||||
messages_compacted: usize,
|
||||
tokens_saved: i64,
|
||||
) -> Self {
|
||||
Self::Compaction {
|
||||
id: Uuid::new_v4(),
|
||||
parent_id,
|
||||
timestamp: iso_now(),
|
||||
summary: summary.into(),
|
||||
first_kept_entry_id,
|
||||
messages_compacted,
|
||||
tokens_saved,
|
||||
details: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a branch summary entry.
|
||||
pub fn branch_summary(
|
||||
parent_id: Option<Uuid>,
|
||||
from_entry_id: Uuid,
|
||||
summary: impl Into<String>,
|
||||
entries_summarized: usize,
|
||||
label: Option<String>,
|
||||
) -> Self {
|
||||
Self::BranchSummary {
|
||||
id: Uuid::new_v4(),
|
||||
parent_id,
|
||||
timestamp: iso_now(),
|
||||
from_entry_id,
|
||||
summary: summary.into(),
|
||||
entries_summarized,
|
||||
label,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a model change entry.
|
||||
pub fn model_change(
|
||||
parent_id: Option<Uuid>,
|
||||
provider: impl Into<String>,
|
||||
model_id: impl Into<String>,
|
||||
) -> Self {
|
||||
Self::ModelChange {
|
||||
id: Uuid::new_v4(),
|
||||
parent_id,
|
||||
timestamp: iso_now(),
|
||||
provider: provider.into(),
|
||||
model_id: model_id.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a custom extension entry.
|
||||
pub fn custom(
|
||||
parent_id: Option<Uuid>,
|
||||
custom_type: impl Into<String>,
|
||||
data: Option<Value>,
|
||||
) -> Self {
|
||||
Self::Custom {
|
||||
id: Uuid::new_v4(),
|
||||
parent_id,
|
||||
timestamp: iso_now(),
|
||||
custom_type: custom_type.into(),
|
||||
data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SessionMessageRole {
|
||||
User,
|
||||
Assistant,
|
||||
ToolResult,
|
||||
System,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionToolCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionToolResult {
|
||||
pub tool_call_id: String,
|
||||
pub tool_name: String,
|
||||
pub content: String,
|
||||
pub is_error: bool,
|
||||
}
|
||||
|
||||
/// A full session: header + ordered list of entries forming a tree.
|
||||
///
|
||||
/// The tree structure supports forking: entries share `parent_id` links,
|
||||
/// and the "active branch" is determined by following from the leaf
|
||||
/// back to the root.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Session {
|
||||
pub header: SessionHeader,
|
||||
pub entries: Vec<SessionEntry>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
header: SessionHeader::new(),
|
||||
entries: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_name(mut self, name: impl Into<String>) -> Self {
|
||||
self.header = self.header.with_name(name);
|
||||
self
|
||||
}
|
||||
|
||||
/// Append an entry to the session.
|
||||
pub fn push(&mut self, entry: SessionEntry) {
|
||||
self.entries.push(entry);
|
||||
}
|
||||
|
||||
/// Get the last entry's id (used as parent_id for the next entry).
|
||||
pub fn last_entry_id(&self) -> Option<Uuid> {
|
||||
self.entries.last().map(|e| e.id())
|
||||
}
|
||||
|
||||
/// Get all entries on the active branch (from root to leaf).
|
||||
pub fn active_branch(&self) -> Vec<&SessionEntry> {
|
||||
if self.entries.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut branch = Vec::new();
|
||||
let mut current_id = Some(self.entries.last().unwrap().id());
|
||||
|
||||
while let Some(id) = current_id {
|
||||
if let Some(entry) = self.entries.iter().find(|e| e.id() == id) {
|
||||
branch.push(entry);
|
||||
current_id = entry.parent_id();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
branch.reverse();
|
||||
branch
|
||||
}
|
||||
|
||||
/// Get all message entries on the active branch (for LLM context).
|
||||
pub fn active_messages(&self) -> Vec<&SessionEntry> {
|
||||
self.active_branch()
|
||||
.into_iter()
|
||||
.filter(|e| matches!(e, SessionEntry::Message { .. } | SessionEntry::Compaction { .. }))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Find all children of a given entry (for tree navigation).
|
||||
pub fn children_of(&self, parent_id: Uuid) -> Vec<&SessionEntry> {
|
||||
self.entries
|
||||
.iter()
|
||||
.filter(|e| e.parent_id() == Some(parent_id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all leaf entries (entries with no children).
|
||||
pub fn leaves(&self) -> Vec<&SessionEntry> {
|
||||
let parent_ids: std::collections::HashSet<Uuid> = self
|
||||
.entries
|
||||
.iter()
|
||||
.filter_map(|e| e.parent_id())
|
||||
.collect();
|
||||
|
||||
self.entries
|
||||
.iter()
|
||||
.filter(|e| !parent_ids.contains(&e.id()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Count total entries.
|
||||
pub fn entry_count(&self) -> usize {
|
||||
self.entries.len()
|
||||
}
|
||||
|
||||
/// Fork from a specific entry, creating entries that belong to a new branch.
|
||||
/// Returns the entries that should be in the new branch (from root to fork point).
|
||||
pub fn fork_from(&self, fork_entry_id: Uuid) -> AiResult<Session> {
|
||||
let fork_idx = self
|
||||
.entries
|
||||
.iter()
|
||||
.position(|e| e.id() == fork_entry_id)
|
||||
.ok_or_else(|| {
|
||||
AiError::Config(format!("fork entry {fork_entry_id} not found in session"))
|
||||
})?;
|
||||
|
||||
let mut new_session = Session::new();
|
||||
new_session.header = new_session.header.with_parent(self.header.id);
|
||||
|
||||
// Copy entries up to and including the fork point
|
||||
for entry in &self.entries[..=fork_idx] {
|
||||
new_session.entries.push(entry.clone());
|
||||
}
|
||||
|
||||
Ok(new_session)
|
||||
}
|
||||
|
||||
/// Find the common ancestor of two entries.
|
||||
pub fn common_ancestor(&self, id_a: Uuid, id_b: Uuid) -> Option<Uuid> {
|
||||
let ancestors_a = self.ancestor_chain(id_a);
|
||||
let ancestors_b: std::collections::HashSet<Uuid> =
|
||||
self.ancestor_chain(id_b).into_iter().collect();
|
||||
|
||||
for ancestor in ancestors_a {
|
||||
if ancestors_b.contains(&ancestor) {
|
||||
return Some(ancestor);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Get the chain of ancestor IDs from an entry back to the root.
|
||||
fn ancestor_chain(&self, entry_id: Uuid) -> Vec<Uuid> {
|
||||
let mut chain = Vec::new();
|
||||
let mut current_id = Some(entry_id);
|
||||
|
||||
while let Some(id) = current_id {
|
||||
chain.push(id);
|
||||
current_id = self
|
||||
.entries
|
||||
.iter()
|
||||
.find(|e| e.id() == id)
|
||||
.and_then(|e| e.parent_id());
|
||||
}
|
||||
|
||||
chain
|
||||
}
|
||||
}
|
||||
|
||||
/// Options for session compaction.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompactionOptions {
|
||||
/// Custom instructions for the compaction LLM call.
|
||||
pub custom_instructions: Option<String>,
|
||||
/// Reserve this many tokens for the prompt + LLM response.
|
||||
pub reserve_tokens: i64,
|
||||
/// Keep this many recent message pairs untouched.
|
||||
pub keep_recent_pairs: usize,
|
||||
/// Whether to generate branch summaries for forked branches.
|
||||
pub branch_summarization: bool,
|
||||
}
|
||||
|
||||
impl Default for CompactionOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
custom_instructions: None,
|
||||
reserve_tokens: 16_384,
|
||||
keep_recent_pairs: 4,
|
||||
branch_summarization: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn iso_now() -> String {
|
||||
SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.map(|d| {
|
||||
let secs = d.as_secs();
|
||||
// Simple ISO 8601 format (UTC)
|
||||
let days = secs / 86400;
|
||||
let years = (days * 400) / 146097;
|
||||
let remaining_days = days - (years * 365 + years / 4 - years / 100 + years / 400);
|
||||
let month_days = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
|
||||
let is_leap = (years % 4 == 0 && years % 100 != 0) || years % 400 == 0;
|
||||
let mut month = 0usize;
|
||||
let mut day_acc = remaining_days as i64;
|
||||
for (i, &md) in month_days.iter().enumerate() {
|
||||
let md = if i == 1 && is_leap { md + 1 } else { md };
|
||||
if day_acc < md as i64 {
|
||||
month = i;
|
||||
break;
|
||||
}
|
||||
day_acc -= md as i64;
|
||||
}
|
||||
let day = day_acc + 1;
|
||||
let hour = (secs % 86400) / 3600;
|
||||
let minute = (secs % 3600) / 60;
|
||||
let second = secs % 60;
|
||||
format!(
|
||||
"{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
|
||||
1970 + years,
|
||||
month + 1,
|
||||
day,
|
||||
hour,
|
||||
minute,
|
||||
second,
|
||||
)
|
||||
})
|
||||
.unwrap_or_else(|_| "1970-01-01T00:00:00Z".to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_session_basic() {
|
||||
let mut session = Session::new();
|
||||
|
||||
let msg1 = SessionEntry::user_message(None, "Hello");
|
||||
let msg1_id = msg1.id();
|
||||
session.push(msg1);
|
||||
|
||||
let msg2 = SessionEntry::assistant_message(Some(msg1_id), "Hi there!", None);
|
||||
session.push(msg2);
|
||||
|
||||
assert_eq!(session.entry_count(), 2);
|
||||
assert_eq!(session.active_branch().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_fork() {
|
||||
let mut session = Session::new();
|
||||
|
||||
let msg1 = SessionEntry::user_message(None, "First");
|
||||
let msg1_id = msg1.id();
|
||||
session.push(msg1);
|
||||
|
||||
let msg2 = SessionEntry::assistant_message(Some(msg1_id), "Reply 1", None);
|
||||
let msg2_id = msg2.id();
|
||||
session.push(msg2);
|
||||
|
||||
let msg3 = SessionEntry::user_message(Some(msg2_id), "Second");
|
||||
session.push(msg3);
|
||||
|
||||
// Fork from msg2
|
||||
let forked = session.fork_from(msg2_id).unwrap();
|
||||
assert_eq!(forked.entry_count(), 2);
|
||||
assert_eq!(forked.header.parent_session, Some(session.header.id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_leaves() {
|
||||
let mut session = Session::new();
|
||||
|
||||
let msg1 = SessionEntry::user_message(None, "Root");
|
||||
let msg1_id = msg1.id();
|
||||
session.push(msg1);
|
||||
|
||||
// Two children branching from root
|
||||
let msg2a = SessionEntry::assistant_message(Some(msg1_id), "Branch A", None);
|
||||
let msg2b = SessionEntry::assistant_message(Some(msg1_id), "Branch B", None);
|
||||
session.push(msg2a);
|
||||
session.push(msg2b);
|
||||
|
||||
let leaves = session.leaves();
|
||||
assert_eq!(leaves.len(), 2);
|
||||
}
|
||||
}
|
||||
203
lib/ai/agent/subagent.rs
Normal file
203
lib/ai/agent/subagent.rs
Normal file
@ -0,0 +1,203 @@
|
||||
use rig::agent::AgentBuilder;
|
||||
use rig::client::CompletionClient;
|
||||
use rig::completion::Prompt;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::config::AgentConfig;
|
||||
use super::helpers::with_retry;
|
||||
use super::persistence::{
|
||||
ActiveAgentRun, AgentRealtime, AgentRuntime, AgentStreamEvent,
|
||||
estimate_output_tokens,
|
||||
};
|
||||
use super::request::{AgentExpert, AgentExpertOutput};
|
||||
use crate::client::AiClient;
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
pub async fn run_experts(
|
||||
client: &AiClient,
|
||||
config: &AgentConfig,
|
||||
experts: &[AgentExpert],
|
||||
realtime: Option<&AgentRealtime>,
|
||||
run: &ActiveAgentRun,
|
||||
) -> AiResult<Vec<AgentExpertOutput>> {
|
||||
let mut outputs = Vec::with_capacity(experts.len());
|
||||
let mut failed_count = 0;
|
||||
|
||||
for expert in experts {
|
||||
match run_single(client, config, expert, realtime, run).await {
|
||||
Ok(output) => {
|
||||
debug!(subagent_id = %output.id, role = %output.role, "subagent completed");
|
||||
outputs.push(output);
|
||||
}
|
||||
Err(error) => {
|
||||
warn!(subagent_id = %expert.id, role = %expert.role, error = %error, "subagent failed");
|
||||
let _ = publish_subagent_failed(realtime, run, expert, &error.to_string()).await;
|
||||
failed_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!(total = experts.len(), ok = outputs.len(), failed = failed_count, "experts done");
|
||||
Ok(outputs)
|
||||
}
|
||||
|
||||
async fn run_single(
|
||||
client: &AiClient,
|
||||
config: &AgentConfig,
|
||||
expert: &AgentExpert,
|
||||
realtime: Option<&AgentRealtime>,
|
||||
run: &ActiveAgentRun,
|
||||
) -> AiResult<AgentExpertOutput> {
|
||||
publish_subagent_started(realtime, run, config, expert).await?;
|
||||
|
||||
let rig_client = client.llm_client().clone();
|
||||
let model_name = config.model.clone();
|
||||
let temperature = expert.temperature.or(config.temperature);
|
||||
let max_completion_tokens = expert.max_completion_tokens.or(config.max_completion_tokens);
|
||||
let retry_attempts = config.retry_max_attempts;
|
||||
let retry_delay_ms = config.retry_base_delay_ms;
|
||||
|
||||
let prompt = expert.system_prompt.clone().unwrap_or_else(|| {
|
||||
format!(
|
||||
"You are a specialist subagent. Role: {}. Produce a concise expert answer for the parent chat agent.",
|
||||
expert.role
|
||||
)
|
||||
});
|
||||
|
||||
let task = build_expert_task(expert);
|
||||
|
||||
let (output, input_tokens_usage, output_tokens_usage) = with_retry(
|
||||
retry_attempts,
|
||||
retry_delay_ms,
|
||||
|| {
|
||||
let rig_client = rig_client.clone();
|
||||
let model_name = model_name.clone();
|
||||
let prompt = prompt.clone();
|
||||
let task = task.clone();
|
||||
async move {
|
||||
let model = rig_client.completion_model(&model_name);
|
||||
let mut builder = AgentBuilder::new(model).preamble(&prompt);
|
||||
if let Some(temp) = temperature {
|
||||
builder = builder.temperature(temp);
|
||||
}
|
||||
if let Some(mt) = max_completion_tokens {
|
||||
builder = builder.max_tokens(mt);
|
||||
}
|
||||
let agent = builder.build();
|
||||
|
||||
let response = agent
|
||||
.prompt(&task)
|
||||
.extended_details()
|
||||
.await
|
||||
.map_err(|e: rig::completion::PromptError| {
|
||||
AiError::Api(e.to_string())
|
||||
})?;
|
||||
|
||||
Ok((
|
||||
response.output,
|
||||
response.usage.input_tokens,
|
||||
response.usage.output_tokens,
|
||||
))
|
||||
}
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let input_tokens = input_tokens_usage as i64;
|
||||
let output_tokens = if output_tokens_usage > 0 {
|
||||
output_tokens_usage as i64
|
||||
} else {
|
||||
estimate_output_tokens(&output)
|
||||
};
|
||||
|
||||
let result = AgentExpertOutput {
|
||||
id: expert.id.clone(),
|
||||
role: expert.role.clone(),
|
||||
task: expert.task.clone(),
|
||||
output,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
};
|
||||
|
||||
publish_subagent_completed(realtime, run, config, &result).await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn build_expert_task(expert: &AgentExpert) -> String {
|
||||
let mut task = String::new();
|
||||
|
||||
if !expert.context.is_empty() {
|
||||
task.push_str("Retrieved context for this specialist task:\n");
|
||||
for (index, chunk) in expert.context.iter().enumerate() {
|
||||
task.push_str(&format!(
|
||||
"\n[{}] id={} source={}\n{}\n",
|
||||
index + 1,
|
||||
chunk.id,
|
||||
chunk.source.as_deref().unwrap_or("unknown"),
|
||||
chunk.content
|
||||
));
|
||||
}
|
||||
task.push('\n');
|
||||
}
|
||||
|
||||
task.push_str(&expert.task);
|
||||
task
|
||||
}
|
||||
|
||||
async fn publish_subagent_started(
|
||||
realtime: Option<&AgentRealtime>,
|
||||
run: &ActiveAgentRun,
|
||||
config: &AgentConfig,
|
||||
expert: &AgentExpert,
|
||||
) -> AiResult<()> {
|
||||
AgentRuntime::default().publish(
|
||||
realtime,
|
||||
&AgentStreamEvent::SubagentStarted {
|
||||
conversation_id: run.conversation_id,
|
||||
message_id: run.message_id,
|
||||
subagent_id: expert.id.clone(),
|
||||
role: expert.role.clone(),
|
||||
task: expert.task.clone(),
|
||||
model: config.model.clone(),
|
||||
},
|
||||
).await
|
||||
}
|
||||
|
||||
async fn publish_subagent_completed(
|
||||
realtime: Option<&AgentRealtime>,
|
||||
run: &ActiveAgentRun,
|
||||
config: &AgentConfig,
|
||||
output: &AgentExpertOutput,
|
||||
) -> AiResult<()> {
|
||||
AgentRuntime::default().publish(
|
||||
realtime,
|
||||
&AgentStreamEvent::SubagentCompleted {
|
||||
conversation_id: run.conversation_id,
|
||||
message_id: run.message_id,
|
||||
subagent_id: output.id.clone(),
|
||||
role: output.role.clone(),
|
||||
task: output.task.clone(),
|
||||
output: output.output.clone(),
|
||||
input_tokens: output.input_tokens,
|
||||
output_tokens: output.output_tokens,
|
||||
model: config.model.clone(),
|
||||
},
|
||||
).await
|
||||
}
|
||||
|
||||
async fn publish_subagent_failed(
|
||||
realtime: Option<&AgentRealtime>,
|
||||
run: &ActiveAgentRun,
|
||||
expert: &AgentExpert,
|
||||
error: &str,
|
||||
) -> AiResult<()> {
|
||||
AgentRuntime::default().publish(
|
||||
realtime,
|
||||
&AgentStreamEvent::SubagentFailed {
|
||||
conversation_id: run.conversation_id,
|
||||
message_id: run.message_id,
|
||||
subagent_id: expert.id.clone(),
|
||||
error: error.to_string(),
|
||||
},
|
||||
).await
|
||||
}
|
||||
158
lib/ai/agent/tool.rs
Normal file
158
lib/ai/agent/tool.rs
Normal file
@ -0,0 +1,158 @@
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rig::completion::ToolDefinition as RigToolDefinition;
|
||||
use rig::tool::ToolDyn;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::tool::tools::FunctionCall;
|
||||
|
||||
pub struct RigTool<C>
|
||||
where
|
||||
C: Clone + Send + Sync + 'static,
|
||||
{
|
||||
context: Arc<Mutex<C>>,
|
||||
tool: Arc<dyn FunctionCall<Context = C>>,
|
||||
name: String,
|
||||
description: String,
|
||||
schema: Value,
|
||||
}
|
||||
|
||||
impl<C> RigTool<C>
|
||||
where
|
||||
C: Clone + Send + Sync + 'static,
|
||||
{
|
||||
pub fn new(tool: Arc<dyn FunctionCall<Context = C>>, context: Arc<Mutex<C>>) -> Self {
|
||||
let name = tool.name().to_string();
|
||||
let description = tool.description().to_string();
|
||||
let schema = tool.schema();
|
||||
|
||||
Self {
|
||||
context,
|
||||
tool,
|
||||
name,
|
||||
description,
|
||||
schema,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> ToolDyn for RigTool<C>
|
||||
where
|
||||
C: Clone + Send + Sync + 'static,
|
||||
{
|
||||
fn name(&self) -> String {
|
||||
self.name.clone()
|
||||
}
|
||||
|
||||
fn definition<'a>(
|
||||
&'a self,
|
||||
_prompt: String,
|
||||
) -> Pin<Box<dyn std::future::Future<Output = RigToolDefinition> + Send + 'a>> {
|
||||
let name = self.name.clone();
|
||||
let description = self.description.clone();
|
||||
let params = self.schema.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
RigToolDefinition {
|
||||
name,
|
||||
description,
|
||||
parameters: params,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn call<'a>(
|
||||
&'a self,
|
||||
args: String,
|
||||
) -> Pin<
|
||||
Box<dyn std::future::Future<Output = Result<String, rig::tool::ToolError>> + Send + 'a>,
|
||||
> {
|
||||
let tool = self.tool.clone();
|
||||
let context = self.context.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let args_value: Value =
|
||||
serde_json::from_str(&args).map_err(rig::tool::ToolError::JsonError)?;
|
||||
|
||||
let mut ctx = context.lock().await;
|
||||
|
||||
match tool.call(&mut *ctx, args_value).await {
|
||||
Ok(value) => serde_json::to_string(&value)
|
||||
.map_err(rig::tool::ToolError::JsonError),
|
||||
Err(ai_err) => Err(rig::tool::ToolError::ToolCallError(Box::new(
|
||||
std::io::Error::other(ai_err.to_string()),
|
||||
))),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RigToolSet<C>
|
||||
where
|
||||
C: Clone + Send + Sync + 'static,
|
||||
{
|
||||
tools: Vec<Box<dyn ToolDyn + 'static>>,
|
||||
context: Option<Arc<Mutex<C>>>,
|
||||
}
|
||||
|
||||
impl<C> RigToolSet<C>
|
||||
where
|
||||
C: Clone + Send + Sync + 'static,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tools: Vec::new(),
|
||||
context: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_register(
|
||||
register: &crate::tool::register::ToolRegister<C>,
|
||||
context: Arc<Mutex<C>>,
|
||||
) -> Self {
|
||||
let mut tools: Vec<Box<dyn ToolDyn + 'static>> = Vec::with_capacity(register.len());
|
||||
|
||||
for tool_arc in ®ister.tools {
|
||||
tools.push(Box::new(RigTool::new(tool_arc.clone(), context.clone())));
|
||||
}
|
||||
|
||||
Self {
|
||||
tools,
|
||||
context: Some(context),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.tools.is_empty()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.tools.len()
|
||||
}
|
||||
|
||||
pub fn context(&self) -> Option<&Arc<Mutex<C>>> {
|
||||
self.context.as_ref()
|
||||
}
|
||||
|
||||
pub fn take_tools(&mut self) -> Vec<Box<dyn ToolDyn + 'static>> {
|
||||
std::mem::take(&mut self.tools)
|
||||
}
|
||||
|
||||
pub fn into_context(mut self) -> C {
|
||||
self.context
|
||||
.take()
|
||||
.and_then(|arc| Arc::try_unwrap(arc).ok().map(|m| m.into_inner()))
|
||||
.unwrap_or_else(|| unreachable!("context must be available"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Default for RigToolSet<C>
|
||||
where
|
||||
C: Clone + Send + Sync + 'static,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
219
lib/ai/client.rs
Normal file
219
lib/ai/client.rs
Normal file
@ -0,0 +1,219 @@
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use config::AppConfig;
|
||||
use rig::providers::openai;
|
||||
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
fn validate_required(scope: &str, field: &str, value: &str) -> AiResult<()> {
|
||||
if value.trim().is_empty() {
|
||||
return Err(AiError::Config(format!("{scope} {field} is required")));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn config_error(error: impl fmt::Display) -> AiError {
|
||||
AiError::Config(error.to_string())
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EndpointConfig {
|
||||
pub base_url: String,
|
||||
pub api_key: String,
|
||||
}
|
||||
|
||||
impl EndpointConfig {
|
||||
pub fn new(base_url: impl Into<String>, api_key: impl Into<String>) -> AiResult<Self> {
|
||||
let config = Self {
|
||||
base_url: base_url.into(),
|
||||
api_key: api_key.into(),
|
||||
};
|
||||
config.validate("endpoint")?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn validate(&self, scope: &str) -> AiResult<()> {
|
||||
validate_required(scope, "base_url", &self.base_url)?;
|
||||
validate_required(scope, "api_key", &self.api_key)?;
|
||||
if !self.base_url.trim().starts_with("http://")
|
||||
&& !self.base_url.trim().starts_with("https://")
|
||||
{
|
||||
return Err(AiError::Config(format!(
|
||||
"{scope} base_url must start with http:// or https://"
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn build_client(&self) -> AiResult<openai::Client> {
|
||||
openai::Client::builder()
|
||||
.api_key(&self.api_key)
|
||||
.base_url(self.base_url.trim())
|
||||
.build()
|
||||
.map_err(|e| AiError::Config(format!("failed to build rig OpenAI client: {e}")))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for EndpointConfig {
|
||||
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
formatter
|
||||
.debug_struct("EndpointConfig")
|
||||
.field("base_url", &self.base_url)
|
||||
.field("api_key", &"<redacted>")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct EmbedConfig {
|
||||
pub endpoint: EndpointConfig,
|
||||
pub model: String,
|
||||
pub dimensions: u64,
|
||||
}
|
||||
|
||||
impl EmbedConfig {
|
||||
pub fn new(
|
||||
endpoint: EndpointConfig,
|
||||
model: impl Into<String>,
|
||||
dimensions: u64,
|
||||
) -> AiResult<Self> {
|
||||
let config = Self {
|
||||
endpoint,
|
||||
model: model.into(),
|
||||
dimensions,
|
||||
};
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn validate(&self) -> AiResult<()> {
|
||||
self.endpoint.validate("embed endpoint")?;
|
||||
validate_required("embed", "model", &self.model)?;
|
||||
if self.dimensions == 0 {
|
||||
return Err(AiError::Config(
|
||||
"embed dimensions must be greater than 0".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AiClientConfig {
|
||||
pub llm: EndpointConfig,
|
||||
pub embed: EmbedConfig,
|
||||
}
|
||||
|
||||
impl AiClientConfig {
|
||||
pub fn new(llm: EndpointConfig, embed: EmbedConfig) -> AiResult<Self> {
|
||||
let config = Self { llm, embed };
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> AiResult<()> {
|
||||
self.llm.validate("llm endpoint")?;
|
||||
self.embed.validate()?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&AppConfig> for AiClientConfig {
|
||||
type Error = AiError;
|
||||
|
||||
fn try_from(config: &AppConfig) -> Result<Self, Self::Error> {
|
||||
let llm = EndpointConfig::new(
|
||||
config.ai_basic_url().map_err(config_error)?,
|
||||
config.ai_api_key().map_err(config_error)?,
|
||||
)?;
|
||||
|
||||
let embed_endpoint = EndpointConfig::new(
|
||||
config.get_embed_model_base_url().map_err(config_error)?,
|
||||
config.get_embed_model_api_key().map_err(config_error)?,
|
||||
)?;
|
||||
|
||||
let embed = EmbedConfig::new(
|
||||
embed_endpoint,
|
||||
config.get_embed_model_name().map_err(config_error)?,
|
||||
config.get_embed_model_dimensions().map_err(config_error)?,
|
||||
)?;
|
||||
|
||||
Self::new(llm, embed)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AiClient {
|
||||
pub(super) llm_client: openai::Client,
|
||||
pub(super) embed_client: openai::Client,
|
||||
pub(super) config: Arc<AiClientConfig>,
|
||||
}
|
||||
|
||||
impl AiClient {
|
||||
pub fn new(config: AiClientConfig) -> AiResult<Self> {
|
||||
config.validate()?;
|
||||
|
||||
Ok(Self {
|
||||
llm_client: config.llm.build_client()?,
|
||||
embed_client: config.embed.endpoint.build_client()?,
|
||||
config: Arc::new(config),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_app_config(config: &AppConfig) -> AiResult<Self> {
|
||||
Self::new(AiClientConfig::try_from(config)?)
|
||||
}
|
||||
|
||||
pub fn llm_client(&self) -> &openai::Client {
|
||||
&self.llm_client
|
||||
}
|
||||
|
||||
pub fn embed_client(&self) -> &openai::Client {
|
||||
&self.embed_client
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &AiClientConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub fn llm_config(&self) -> &EndpointConfig {
|
||||
&self.config.llm
|
||||
}
|
||||
|
||||
pub fn embed_config(&self) -> &EmbedConfig {
|
||||
&self.config.embed
|
||||
}
|
||||
|
||||
pub fn embed_model(&self) -> &str {
|
||||
self.config.embed.model.as_str()
|
||||
}
|
||||
|
||||
pub fn embed_dimensions(&self) -> u64 {
|
||||
self.config.embed.dimensions
|
||||
}
|
||||
|
||||
pub fn embed_dimensions_u32(&self) -> u32 {
|
||||
u32::try_from(self.config.embed.dimensions).unwrap_or(u32::MAX)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_http_client() -> Result<reqwest::Client, AiError> {
|
||||
let mut builder = reqwest::Client::builder();
|
||||
|
||||
if let Ok(proxy_url) = std::env::var("HTTPS_PROXY")
|
||||
.or_else(|_| std::env::var("https_proxy"))
|
||||
.or_else(|_| std::env::var("HTTP_PROXY"))
|
||||
.or_else(|_| std::env::var("http_proxy"))
|
||||
{
|
||||
let proxy_url = proxy_url.trim().trim_matches('"').trim_matches('\'');
|
||||
let proxy = reqwest::Proxy::all(proxy_url).map_err(|e| {
|
||||
AiError::Config(format!("Invalid proxy URL '{}': {}", proxy_url, e))
|
||||
})?;
|
||||
builder = builder.proxy(proxy);
|
||||
}
|
||||
|
||||
builder.build().map_err(|e| {
|
||||
AiError::Config(format!("Failed to build HTTP client: {}", e))
|
||||
})
|
||||
}
|
||||
76
lib/ai/embed/client.rs
Normal file
76
lib/ai/embed/client.rs
Normal file
@ -0,0 +1,76 @@
|
||||
use rig::client::EmbeddingsClient;
|
||||
use rig::embeddings::EmbeddingModel;
|
||||
|
||||
use crate::{client::AiClient, error::{AiError, AiResult}};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct EmbedClient {
|
||||
model_name: String,
|
||||
client: rig::providers::openai::Client,
|
||||
}
|
||||
|
||||
impl EmbedClient {
|
||||
pub fn new(ai_client: &AiClient) -> AiResult<Self> {
|
||||
Ok(Self {
|
||||
model_name: ai_client.embed_model().to_string(),
|
||||
client: ai_client.embed_client().clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn embedding_model(&self) -> impl EmbeddingModel + '_ {
|
||||
self.client.embedding_model(&self.model_name)
|
||||
}
|
||||
|
||||
pub async fn embed_text(&self, text: String) -> AiResult<Vec<f32>> {
|
||||
let model = self.embedding_model();
|
||||
let mut embeddings = model.embed_texts(vec![text])
|
||||
.await
|
||||
.map_err(|e| AiError::Api(e.to_string()))?;
|
||||
embeddings.pop()
|
||||
.map(|e| e.vec.into_iter().map(|v| v as f32).collect())
|
||||
.ok_or_else(|| AiError::Response("no embedding returned".to_string()))
|
||||
}
|
||||
|
||||
pub async fn embed_texts(&self, texts: Vec<String>) -> AiResult<Vec<Vec<f32>>> {
|
||||
if texts.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let model = self.embedding_model();
|
||||
let embeddings = model.embed_texts(texts)
|
||||
.await
|
||||
.map_err(|e| AiError::Api(e.to_string()))?;
|
||||
Ok(embeddings.into_iter()
|
||||
.map(|e| e.vec.into_iter().map(|v| v as f32).collect())
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn embed_texts_chunked(
|
||||
&self,
|
||||
texts: Vec<String>,
|
||||
batch_size: usize,
|
||||
) -> AiResult<Vec<Vec<f32>>> {
|
||||
if batch_size == 0 {
|
||||
return Err(AiError::Config("batch_size must be > 0".to_string()));
|
||||
}
|
||||
let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
|
||||
for chunk in texts.chunks(batch_size) {
|
||||
let model = self.embedding_model();
|
||||
let chunk_embeddings = model.embed_texts(chunk.to_vec())
|
||||
.await
|
||||
.map_err(|e| AiError::Api(e.to_string()))?;
|
||||
embeddings.extend(chunk_embeddings.into_iter()
|
||||
.map(|e| e.vec.into_iter().map(|v| v as f32).collect()));
|
||||
}
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait AiClientEmbedExt {
|
||||
fn embedder(&self) -> AiResult<EmbedClient>;
|
||||
}
|
||||
|
||||
impl AiClientEmbedExt for AiClient {
|
||||
fn embedder(&self) -> AiResult<EmbedClient> {
|
||||
EmbedClient::new(self)
|
||||
}
|
||||
}
|
||||
3
lib/ai/embed/mod.rs
Normal file
3
lib/ai/embed/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
mod client;
|
||||
|
||||
pub use client::{AiClientEmbedExt, EmbedClient};
|
||||
52
lib/ai/error.rs
Normal file
52
lib/ai/error.rs
Normal file
@ -0,0 +1,52 @@
|
||||
pub type AiResult<T> = Result<T, AiError>;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum AiError {
|
||||
#[error("ai config error: {0}")]
|
||||
Config(String),
|
||||
|
||||
#[error("ai api error: {0}")]
|
||||
Api(String),
|
||||
|
||||
#[error("qdrant error: {0}")]
|
||||
Qdrant(Box<qdrant_client::QdrantError>),
|
||||
|
||||
#[error("database error: {0}")]
|
||||
Database(#[from] db::sqlx::Error),
|
||||
|
||||
#[error("cache error: {0}")]
|
||||
Cache(#[from] cache::CacheError),
|
||||
|
||||
#[error("redis error: {0}")]
|
||||
Redis(#[from] redis::RedisError),
|
||||
|
||||
#[error("ai response error: {0}")]
|
||||
Response(String),
|
||||
|
||||
#[error("model retries exhausted after {attempts} attempts: {last_error}")]
|
||||
ModelRetriesExhausted {
|
||||
attempts: usize,
|
||||
last_error: String,
|
||||
},
|
||||
|
||||
#[error("agent timeout after {seconds}s")]
|
||||
Timeout { seconds: u64 },
|
||||
|
||||
#[error("tool not found: {tool}")]
|
||||
ToolNotFound { tool: String },
|
||||
|
||||
#[error("tool execution failed: {cause}")]
|
||||
ToolExecutionFailed { cause: String },
|
||||
|
||||
#[error("invalid input in '{field}': {reason}")]
|
||||
InvalidInput { field: String, reason: String },
|
||||
|
||||
#[error("token budget exceeded: used ~{estimated} tokens, limit {limit}")]
|
||||
TokenBudgetExceeded { estimated: u64, limit: i64 },
|
||||
}
|
||||
|
||||
impl From<qdrant_client::QdrantError> for AiError {
|
||||
fn from(e: qdrant_client::QdrantError) -> Self {
|
||||
AiError::Qdrant(Box::new(e))
|
||||
}
|
||||
}
|
||||
8
lib/ai/lib.rs
Normal file
8
lib/ai/lib.rs
Normal file
@ -0,0 +1,8 @@
|
||||
pub mod agent;
|
||||
pub mod client;
|
||||
pub mod embed;
|
||||
pub mod error;
|
||||
pub mod memory;
|
||||
pub mod rag;
|
||||
pub mod sync;
|
||||
pub mod tool;
|
||||
46
lib/ai/memory/mod.rs
Normal file
46
lib/ai/memory/mod.rs
Normal file
@ -0,0 +1,46 @@
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::AiResult;
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryEntry {
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
pub importance: i32,
|
||||
pub last_used_at: Option<String>,
|
||||
}
|
||||
#[async_trait]
|
||||
pub trait MemoryProvider: Send + Sync {
|
||||
fn name(&self) -> &'static str;
|
||||
async fn save(
|
||||
&self,
|
||||
session_id: Uuid,
|
||||
key: &str,
|
||||
value: &str,
|
||||
importance: i32,
|
||||
) -> AiResult<()>;
|
||||
async fn recall(
|
||||
&self,
|
||||
session_id: Uuid,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
) -> AiResult<Vec<MemoryEntry>>;
|
||||
async fn forget(&self, session_id: Uuid, key: &str) -> AiResult<()>;
|
||||
async fn prefetch(
|
||||
&self,
|
||||
_session_id: Uuid,
|
||||
_query: &str,
|
||||
) -> AiResult<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
async fn build_context_block(
|
||||
&self,
|
||||
_session_id: Uuid,
|
||||
) -> AiResult<String> {
|
||||
Ok(String::new())
|
||||
}
|
||||
async fn setup(&self) -> AiResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
263
lib/ai/rag/client.rs
Normal file
263
lib/ai/rag/client.rs
Normal file
@ -0,0 +1,263 @@
|
||||
use config::AppConfig;
|
||||
use qdrant_client::qdrant::{
|
||||
CreateCollectionBuilder, CreateFieldIndexCollectionBuilder,
|
||||
DeletePointsBuilder, FieldType, PointStruct, QueryPointsBuilder,
|
||||
SearchParamsBuilder, UpsertPointsBuilder, VectorParamsBuilder,
|
||||
};
|
||||
use qdrant_client::{Qdrant, QdrantError};
|
||||
|
||||
use super::{
|
||||
config::RagConfig,
|
||||
document::{RagDocument, RagSearchHit},
|
||||
payload::{
|
||||
SESSION_ID_KEY, document_payload, hit_from_scored_point, point_id,
|
||||
},
|
||||
search::RagSearchOptions,
|
||||
session::{session_filter, validate_session_id},
|
||||
};
|
||||
use crate::{
|
||||
client::AiClient,
|
||||
embed::{AiClientEmbedExt, EmbedClient},
|
||||
error::{AiError, AiResult},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RagClient {
|
||||
qdrant: Qdrant,
|
||||
embedder: EmbedClient,
|
||||
config: RagConfig,
|
||||
}
|
||||
|
||||
impl RagClient {
|
||||
pub fn new(
|
||||
qdrant: Qdrant,
|
||||
embedder: EmbedClient,
|
||||
config: RagConfig,
|
||||
) -> AiResult<Self> {
|
||||
config.validate()?;
|
||||
Ok(Self {
|
||||
qdrant,
|
||||
embedder,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn connect(
|
||||
ai_client: &AiClient,
|
||||
config: RagConfig,
|
||||
) -> AiResult<Self> {
|
||||
config.validate()?;
|
||||
let mut builder =
|
||||
Qdrant::from_url(config.url.trim()).timeout(config.timeout);
|
||||
if let Some(api_key) = config
|
||||
.api_key
|
||||
.as_deref()
|
||||
.filter(|api_key| !api_key.trim().is_empty())
|
||||
{
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
Self::new(builder.build()?, ai_client.embedder()?, config)
|
||||
}
|
||||
|
||||
pub fn from_app_config(
|
||||
ai_client: &AiClient,
|
||||
config: &AppConfig,
|
||||
collection_name: impl Into<String>,
|
||||
) -> AiResult<Self> {
|
||||
Self::connect(
|
||||
ai_client,
|
||||
RagConfig::from_app_config(config, collection_name)?,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn qdrant(&self) -> &Qdrant {
|
||||
&self.qdrant
|
||||
}
|
||||
|
||||
pub fn embedder(&self) -> &EmbedClient {
|
||||
&self.embedder
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &RagConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub async fn ensure_collection(&self) -> AiResult<()> {
|
||||
if !self
|
||||
.qdrant
|
||||
.collection_exists(&self.config.collection_name)
|
||||
.await?
|
||||
{
|
||||
self.qdrant
|
||||
.create_collection(
|
||||
CreateCollectionBuilder::new(&self.config.collection_name)
|
||||
.vectors_config(VectorParamsBuilder::new(
|
||||
self.config.vector_size,
|
||||
self.config.distance,
|
||||
)),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
match self
|
||||
.qdrant
|
||||
.create_field_index(CreateFieldIndexCollectionBuilder::new(
|
||||
&self.config.collection_name,
|
||||
SESSION_ID_KEY,
|
||||
FieldType::Keyword,
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(()),
|
||||
Err(QdrantError::ResponseError { .. }) => Ok(()),
|
||||
Err(error) => Err(error.into()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn upsert_document(
|
||||
&self,
|
||||
session_id: impl AsRef<str>,
|
||||
document: RagDocument,
|
||||
) -> AiResult<()> {
|
||||
self.upsert_documents(session_id, vec![document]).await
|
||||
}
|
||||
|
||||
pub async fn upsert_documents(
|
||||
&self,
|
||||
session_id: impl AsRef<str>,
|
||||
documents: Vec<RagDocument>,
|
||||
) -> AiResult<()> {
|
||||
let session_id = session_id.as_ref();
|
||||
validate_session_id(session_id)?;
|
||||
validate_documents(&documents)?;
|
||||
|
||||
let texts: Vec<String> = documents
|
||||
.iter()
|
||||
.map(|d| d.content.clone())
|
||||
.collect();
|
||||
let vectors = self
|
||||
.embedder
|
||||
.embed_texts_chunked(texts, self.config.upsert_batch_size)
|
||||
.await?;
|
||||
|
||||
let points = documents
|
||||
.iter()
|
||||
.zip(vectors)
|
||||
.map(|(document, vector)| {
|
||||
Ok(PointStruct::new(
|
||||
point_id(session_id, &document.id),
|
||||
vector,
|
||||
document_payload(session_id, document)?,
|
||||
))
|
||||
})
|
||||
.collect::<AiResult<Vec<_>>>()?;
|
||||
|
||||
self.qdrant
|
||||
.upsert_points(
|
||||
UpsertPointsBuilder::new(&self.config.collection_name, points)
|
||||
.wait(true),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn search_session(
|
||||
&self,
|
||||
session_id: impl AsRef<str>,
|
||||
query: impl Into<String>,
|
||||
) -> AiResult<Vec<RagSearchHit>> {
|
||||
let options = RagSearchOptions {
|
||||
limit: self.config.default_search_limit,
|
||||
exact: self.config.exact_session_search,
|
||||
};
|
||||
self.search_session_with_options(session_id, query, options)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn search_session_with_options(
|
||||
&self,
|
||||
session_id: impl AsRef<str>,
|
||||
query: impl Into<String>,
|
||||
options: RagSearchOptions,
|
||||
) -> AiResult<Vec<RagSearchHit>> {
|
||||
let vector = self.embedder.embed_text(query.into()).await?;
|
||||
self.search_session_by_vector(session_id, vector, options)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn search_session_by_vector(
|
||||
&self,
|
||||
session_id: impl AsRef<str>,
|
||||
vector: Vec<f32>,
|
||||
options: RagSearchOptions,
|
||||
) -> AiResult<Vec<RagSearchHit>> {
|
||||
let session_id = session_id.as_ref();
|
||||
validate_session_id(session_id)?;
|
||||
if options.limit == 0 {
|
||||
return Err(AiError::Config(
|
||||
"rag search limit must be greater than 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let response = self
|
||||
.qdrant
|
||||
.query(
|
||||
QueryPointsBuilder::new(&self.config.collection_name)
|
||||
.query(vector)
|
||||
.limit(options.limit)
|
||||
.filter(session_filter(session_id))
|
||||
.with_payload(true)
|
||||
.params(
|
||||
SearchParamsBuilder::default().exact(options.exact),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(response
|
||||
.result
|
||||
.into_iter()
|
||||
.map(hit_from_scored_point)
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn clear_session(
|
||||
&self,
|
||||
session_id: impl AsRef<str>,
|
||||
) -> AiResult<()> {
|
||||
let session_id = session_id.as_ref();
|
||||
validate_session_id(session_id)?;
|
||||
|
||||
self.qdrant
|
||||
.delete_points(
|
||||
DeletePointsBuilder::new(&self.config.collection_name)
|
||||
.points(session_filter(session_id))
|
||||
.wait(true),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_documents(documents: &[RagDocument]) -> AiResult<()> {
|
||||
if documents.is_empty() {
|
||||
return Err(AiError::Config("rag documents are required".to_string()));
|
||||
}
|
||||
|
||||
for document in documents {
|
||||
if document.id.trim().is_empty() {
|
||||
return Err(AiError::Config(
|
||||
"rag document id is required".to_string(),
|
||||
));
|
||||
}
|
||||
if document.content.trim().is_empty() {
|
||||
return Err(AiError::Config(
|
||||
"rag document content is required".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
134
lib/ai/rag/config.rs
Normal file
134
lib/ai/rag/config.rs
Normal file
@ -0,0 +1,134 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use config::AppConfig;
|
||||
use qdrant_client::qdrant::Distance;
|
||||
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RagConfig {
|
||||
pub url: String,
|
||||
pub api_key: Option<String>,
|
||||
pub collection_name: String,
|
||||
pub vector_size: u64,
|
||||
pub distance: Distance,
|
||||
pub timeout: Duration,
|
||||
pub upsert_batch_size: usize,
|
||||
pub default_search_limit: u64,
|
||||
pub exact_session_search: bool,
|
||||
}
|
||||
|
||||
impl RagConfig {
|
||||
pub fn new(
|
||||
url: impl Into<String>,
|
||||
collection_name: impl Into<String>,
|
||||
vector_size: u64,
|
||||
) -> AiResult<Self> {
|
||||
let config = Self {
|
||||
url: url.into(),
|
||||
api_key: None,
|
||||
collection_name: collection_name.into(),
|
||||
vector_size,
|
||||
distance: Distance::Cosine,
|
||||
timeout: Duration::from_secs(10),
|
||||
upsert_batch_size: 64,
|
||||
default_search_limit: 8,
|
||||
exact_session_search: true,
|
||||
};
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn with_api_key(mut self, api_key: Option<String>) -> Self {
|
||||
self.api_key = api_key;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_distance(mut self, distance: Distance) -> Self {
|
||||
self.distance = distance;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
||||
self.timeout = timeout;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_upsert_batch_size(mut self, upsert_batch_size: usize) -> Self {
|
||||
self.upsert_batch_size = upsert_batch_size;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_default_search_limit(
|
||||
mut self,
|
||||
default_search_limit: u64,
|
||||
) -> Self {
|
||||
self.default_search_limit = default_search_limit;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_exact_session_search(
|
||||
mut self,
|
||||
exact_session_search: bool,
|
||||
) -> Self {
|
||||
self.exact_session_search = exact_session_search;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> AiResult<()> {
|
||||
if self.url.trim().is_empty() {
|
||||
return Err(AiError::Config("qdrant url is required".to_string()));
|
||||
}
|
||||
if !self.url.trim().starts_with("http://")
|
||||
&& !self.url.trim().starts_with("https://")
|
||||
{
|
||||
return Err(AiError::Config(
|
||||
"qdrant url must start with http:// or https://".to_string(),
|
||||
));
|
||||
}
|
||||
if self.collection_name.trim().is_empty() {
|
||||
return Err(AiError::Config(
|
||||
"qdrant collection_name is required".to_string(),
|
||||
));
|
||||
}
|
||||
if self.vector_size == 0 {
|
||||
return Err(AiError::Config(
|
||||
"qdrant vector_size must be greater than 0".to_string(),
|
||||
));
|
||||
}
|
||||
if self.upsert_batch_size == 0 {
|
||||
return Err(AiError::Config(
|
||||
"qdrant upsert_batch_size must be greater than 0".to_string(),
|
||||
));
|
||||
}
|
||||
if self.default_search_limit == 0 {
|
||||
return Err(AiError::Config(
|
||||
"qdrant default_search_limit must be greater than 0"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl RagConfig {
|
||||
pub fn from_app_config(
|
||||
config: &AppConfig,
|
||||
collection_name: impl Into<String>,
|
||||
) -> AiResult<Self> {
|
||||
Ok(Self::new(
|
||||
config
|
||||
.qdrant_url()
|
||||
.map_err(|error| AiError::Config(error.to_string()))?,
|
||||
collection_name,
|
||||
config
|
||||
.get_embed_model_dimensions()
|
||||
.map_err(|error| AiError::Config(error.to_string()))?,
|
||||
)?
|
||||
.with_api_key(
|
||||
config
|
||||
.qdrant_api_key()
|
||||
.map_err(|error| AiError::Config(error.to_string()))?,
|
||||
))
|
||||
}
|
||||
}
|
||||
44
lib/ai/rag/document.rs
Normal file
44
lib/ai/rag/document.rs
Normal file
@ -0,0 +1,44 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct RagDocument {
|
||||
pub id: String,
|
||||
pub content: String,
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
impl RagDocument {
|
||||
pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
content: content.into(),
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
|
||||
self.metadata = metadata;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn metadata_value(
|
||||
mut self,
|
||||
key: impl Into<String>,
|
||||
value: impl Into<Value>,
|
||||
) -> Self {
|
||||
self.metadata.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct RagSearchHit {
|
||||
pub id: String,
|
||||
pub session_id: String,
|
||||
pub score: f32,
|
||||
pub content: String,
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
11
lib/ai/rag/mod.rs
Normal file
11
lib/ai/rag/mod.rs
Normal file
@ -0,0 +1,11 @@
|
||||
mod client;
|
||||
mod config;
|
||||
mod document;
|
||||
mod payload;
|
||||
mod search;
|
||||
mod session;
|
||||
|
||||
pub use client::RagClient;
|
||||
pub use config::RagConfig;
|
||||
pub use document::{RagDocument, RagSearchHit};
|
||||
pub use search::RagSearchOptions;
|
||||
110
lib/ai/rag/payload.rs
Normal file
110
lib/ai/rag/payload.rs
Normal file
@ -0,0 +1,110 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use qdrant_client::Payload;
|
||||
use qdrant_client::qdrant::{
|
||||
PointId, ScoredPoint, point_id::PointIdOptions, value::Kind,
|
||||
};
|
||||
use serde_json::{Map, Value, json};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::document::{RagDocument, RagSearchHit};
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
pub(super) const SESSION_ID_KEY: &str = "session_id";
|
||||
pub(super) const DOCUMENT_ID_KEY: &str = "document_id";
|
||||
pub(super) const CONTENT_KEY: &str = "content";
|
||||
pub(super) const METADATA_KEY: &str = "metadata";
|
||||
pub(super) fn point_id(session_id: &str, document_id: &str) -> u64 {
|
||||
let ns = Uuid::NAMESPACE_DNS;
|
||||
let key = format!("{session_id}:{document_id}");
|
||||
let uuid = Uuid::new_v5(&ns, key.as_bytes());
|
||||
let bytes = uuid.as_bytes();
|
||||
u64::from_be_bytes([
|
||||
bytes[0], bytes[1], bytes[2], bytes[3],
|
||||
bytes[4], bytes[5], bytes[6], bytes[7],
|
||||
])
|
||||
}
|
||||
|
||||
pub(super) fn document_payload(
|
||||
session_id: &str,
|
||||
document: &RagDocument,
|
||||
) -> AiResult<Payload> {
|
||||
Payload::try_from(json!({
|
||||
SESSION_ID_KEY: session_id,
|
||||
DOCUMENT_ID_KEY: document.id,
|
||||
CONTENT_KEY: document.content,
|
||||
METADATA_KEY: document.metadata,
|
||||
}))
|
||||
.map_err(|error| AiError::Config(error.to_string()))
|
||||
}
|
||||
|
||||
pub(super) fn hit_from_scored_point(point: ScoredPoint) -> RagSearchHit {
|
||||
let id = point_id_to_string(point.id);
|
||||
let mut payload = qdrant_payload_to_json(point.payload);
|
||||
let session_id = take_string(&mut payload, SESSION_ID_KEY);
|
||||
let document_id = take_string(&mut payload, DOCUMENT_ID_KEY);
|
||||
let content = take_string(&mut payload, CONTENT_KEY);
|
||||
let metadata = payload
|
||||
.remove(METADATA_KEY)
|
||||
.and_then(|value| match value {
|
||||
Value::Object(object) => Some(object.into_iter().collect()),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
RagSearchHit {
|
||||
id: if document_id.is_empty() {
|
||||
id
|
||||
} else {
|
||||
document_id
|
||||
},
|
||||
session_id,
|
||||
score: point.score,
|
||||
content,
|
||||
metadata,
|
||||
}
|
||||
}
|
||||
|
||||
fn point_id_to_string(id: Option<PointId>) -> String {
|
||||
match id.and_then(|id| id.point_id_options) {
|
||||
Some(PointIdOptions::Num(id)) => id.to_string(),
|
||||
Some(PointIdOptions::Uuid(id)) => id,
|
||||
None => String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn qdrant_payload_to_json(
|
||||
payload: HashMap<String, qdrant_client::qdrant::Value>,
|
||||
) -> Map<String, Value> {
|
||||
payload
|
||||
.into_iter()
|
||||
.map(|(key, value)| (key, value_to_json(value)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn value_to_json(value: qdrant_client::qdrant::Value) -> Value {
|
||||
match value.kind {
|
||||
Some(Kind::NullValue(_)) | None => Value::Null,
|
||||
Some(Kind::DoubleValue(value)) => json!(value),
|
||||
Some(Kind::IntegerValue(value)) => json!(value),
|
||||
Some(Kind::StringValue(value)) => json!(value),
|
||||
Some(Kind::BoolValue(value)) => json!(value),
|
||||
Some(Kind::StructValue(value)) => Value::Object(
|
||||
value
|
||||
.fields
|
||||
.into_iter()
|
||||
.map(|(key, value)| (key, value_to_json(value)))
|
||||
.collect(),
|
||||
),
|
||||
Some(Kind::ListValue(value)) => {
|
||||
Value::Array(value.values.into_iter().map(value_to_json).collect())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn take_string(payload: &mut Map<String, Value>, key: &str) -> String {
|
||||
payload
|
||||
.remove(key)
|
||||
.and_then(|value| value.as_str().map(ToOwned::to_owned))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
16
lib/ai/rag/search.rs
Normal file
16
lib/ai/rag/search.rs
Normal file
@ -0,0 +1,16 @@
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RagSearchOptions {
|
||||
pub limit: u64,
|
||||
pub exact: bool,
|
||||
}
|
||||
|
||||
impl RagSearchOptions {
|
||||
pub fn new(limit: u64) -> Self {
|
||||
Self { limit, exact: true }
|
||||
}
|
||||
|
||||
pub fn with_exact(mut self, exact: bool) -> Self {
|
||||
self.exact = exact;
|
||||
self
|
||||
}
|
||||
}
|
||||
15
lib/ai/rag/session.rs
Normal file
15
lib/ai/rag/session.rs
Normal file
@ -0,0 +1,15 @@
|
||||
use qdrant_client::qdrant::{Condition, Filter};
|
||||
|
||||
use super::payload::SESSION_ID_KEY;
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
pub(super) fn validate_session_id(session_id: &str) -> AiResult<()> {
|
||||
if session_id.trim().is_empty() {
|
||||
return Err(AiError::Config("rag session_id is required".to_string()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn session_filter(session_id: &str) -> Filter {
|
||||
Filter::all([Condition::matches(SESSION_ID_KEY, session_id.to_string())])
|
||||
}
|
||||
126
lib/ai/sync.rs
Normal file
126
lib/ai/sync.rs
Normal file
@ -0,0 +1,126 @@
|
||||
use std::error::Error;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::{
|
||||
client::EndpointConfig,
|
||||
error::{AiError, AiResult},
|
||||
};
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct ModelsListResponse {
|
||||
data: Vec<UpstreamModel>,
|
||||
}
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct UpstreamModel {
|
||||
pub id: String,
|
||||
#[serde(default)]
|
||||
pub name: Option<String>,
|
||||
#[serde(default)]
|
||||
pub owned_by: Option<String>,
|
||||
#[serde(default)]
|
||||
pub context_length: Option<i32>,
|
||||
#[serde(default)]
|
||||
pub max_output_tokens: Option<i32>,
|
||||
#[serde(default)]
|
||||
pub capabilities: Option<UpstreamCapabilities>,
|
||||
#[serde(default)]
|
||||
pub pricing: Option<UpstreamPricing>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct UpstreamCapabilities {
|
||||
#[serde(default)]
|
||||
pub vision: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub tool_call: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub reasoning: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct UpstreamPricing {
|
||||
#[serde(default)]
|
||||
pub prompt: Option<String>,
|
||||
#[serde(default)]
|
||||
pub completion: Option<String>,
|
||||
#[serde(default)]
|
||||
pub input: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub output: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub cache_read: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub unit: Option<String>,
|
||||
#[serde(default)]
|
||||
pub currency: Option<String>,
|
||||
}
|
||||
static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
let mut builder = reqwest::Client::builder();
|
||||
let proxy_url = std::env::var("HTTPS_PROXY")
|
||||
.or_else(|_| std::env::var("https_proxy"))
|
||||
.or_else(|_| std::env::var("HTTP_PROXY"))
|
||||
.or_else(|_| std::env::var("http_proxy"))
|
||||
.ok();
|
||||
if let Some(raw) = &proxy_url {
|
||||
let url = raw.trim().trim_matches('"').trim_matches('\'');
|
||||
match reqwest::Proxy::all(url) {
|
||||
Ok(proxy) => {
|
||||
debug!(proxy_url = %url, "sync: using proxy");
|
||||
builder = builder.proxy(proxy);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(proxy_url = %url, error = %e, "sync: invalid proxy URL, skipping");
|
||||
}
|
||||
}
|
||||
}
|
||||
#[allow(clippy::expect_used)]
|
||||
builder.build().expect("failed to build reqwest HTTP client — check system TLS configuration")
|
||||
});
|
||||
pub async fn list_models(
|
||||
config: &EndpointConfig,
|
||||
) -> AiResult<Vec<UpstreamModel>> {
|
||||
let base = config.base_url.trim_end_matches('/');
|
||||
let url = if base.ends_with("/v1") {
|
||||
format!("{}/models", base)
|
||||
} else {
|
||||
format!("{}/v1/models", base)
|
||||
};
|
||||
|
||||
debug!(url = %url, "listing models from upstream");
|
||||
let resp = HTTP_CLIENT
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", config.api_key.trim()))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
source = ?e.source(),
|
||||
"list_models: request failed with full cause chain"
|
||||
);
|
||||
AiError::Response(format!("failed to list models: {}", e))
|
||||
})?;
|
||||
|
||||
let body = resp
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| AiError::Response(format!("failed to read models body: {}", e)))?;
|
||||
if let Ok(parsed) = serde_json::from_str::<ModelsListResponse>(&body) {
|
||||
debug!(count = parsed.data.len(), "parsed models in standard format");
|
||||
return Ok(parsed.data);
|
||||
}
|
||||
if let Ok(parsed) = serde_json::from_str::<Vec<UpstreamModel>>(&body) {
|
||||
debug!(count = parsed.len(), "parsed models in array format");
|
||||
return Ok(parsed);
|
||||
}
|
||||
|
||||
warn!(
|
||||
body = %body.chars().take(500).collect::<String>(),
|
||||
"list_models: unknown response format"
|
||||
);
|
||||
Err(AiError::Response(format!(
|
||||
"unexpected /v1/models response format (first 200 chars): {}",
|
||||
body.chars().take(200).collect::<String>()
|
||||
)))
|
||||
}
|
||||
5
lib/ai/tool/mod.rs
Normal file
5
lib/ai/tool/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
pub mod register;
|
||||
pub mod tools;
|
||||
pub mod toolset;
|
||||
|
||||
pub use toolset::{Toolset, ToolsetRegistry, toolset_names};
|
||||
65
lib/ai/tool/register.rs
Normal file
65
lib/ai/tool/register.rs
Normal file
@ -0,0 +1,65 @@
|
||||
use crate::tool::tools::FunctionCall;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolRegister<C>
|
||||
where
|
||||
C: Clone + Send + Sync + 'static,
|
||||
{
|
||||
pub tools: Vec<Arc<dyn FunctionCall<Context = C>>>,
|
||||
index: HashMap<String, usize>,
|
||||
}
|
||||
|
||||
impl<C> ToolRegister<C>
|
||||
where
|
||||
C: Clone + Send + Sync + 'static,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
ToolRegister {
|
||||
tools: Vec::new(),
|
||||
index: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register<T>(&mut self, tool: T)
|
||||
where
|
||||
T: FunctionCall<Context = C> + 'static,
|
||||
{
|
||||
let idx = self.tools.len();
|
||||
self.index.insert(tool.name().to_string(), idx);
|
||||
self.tools.push(Arc::new(tool));
|
||||
}
|
||||
|
||||
pub fn with_tool<T>(mut self, tool: T) -> Self
|
||||
where
|
||||
T: FunctionCall<Context = C> + 'static,
|
||||
{
|
||||
self.register(tool);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn get(
|
||||
&self,
|
||||
name: &str,
|
||||
) -> Option<Arc<dyn FunctionCall<Context = C>>> {
|
||||
self.index.get(name).map(|&idx| self.tools[idx].clone())
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.tools.is_empty()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.tools.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> Default for ToolRegister<C>
|
||||
where
|
||||
C: Clone + Send + Sync + 'static,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
18
lib/ai/tool/tools.rs
Normal file
18
lib/ai/tool/tools.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use crate::error::AiResult;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
|
||||
#[async_trait]
|
||||
pub trait FunctionCall: Send + Sync {
|
||||
type Context;
|
||||
fn name(&self) -> &'static str;
|
||||
fn description(&self) -> &'static str {
|
||||
""
|
||||
}
|
||||
fn schema(&self) -> Value;
|
||||
async fn call(
|
||||
&self,
|
||||
context: &mut Self::Context,
|
||||
args: Value,
|
||||
) -> AiResult<Value>;
|
||||
}
|
||||
146
lib/ai/tool/toolset.rs
Normal file
146
lib/ai/tool/toolset.rs
Normal file
@ -0,0 +1,146 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Toolset {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub tools: Vec<String>,
|
||||
pub requires_env: Vec<String>,
|
||||
}
|
||||
|
||||
impl Toolset {
|
||||
pub fn new(
|
||||
name: impl Into<String>,
|
||||
description: impl Into<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
description: description.into(),
|
||||
tools: Vec::new(),
|
||||
requires_env: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
|
||||
self.tools.push(tool_name.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tools(mut self, tool_names: impl IntoIterator<Item = impl Into<String>>) -> Self {
|
||||
self.tools.extend(tool_names.into_iter().map(Into::into));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_required_env(
|
||||
mut self,
|
||||
env_vars: impl IntoIterator<Item = impl Into<String>>,
|
||||
) -> Self {
|
||||
self.requires_env.extend(env_vars.into_iter().map(Into::into));
|
||||
self
|
||||
}
|
||||
pub fn is_available(&self) -> bool {
|
||||
for env_var in &self.requires_env {
|
||||
if std::env::var(env_var).is_err() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub fn contains(&self, tool_name: &str) -> bool {
|
||||
self.tools.iter().any(|t| t == tool_name)
|
||||
}
|
||||
}
|
||||
pub mod toolset_names {
|
||||
pub const CORE: &str = "core";
|
||||
pub const TERMINAL: &str = "terminal";
|
||||
pub const WEB: &str = "web";
|
||||
pub const FILE: &str = "file";
|
||||
pub const MEMORY: &str = "memory";
|
||||
pub const VISION: &str = "vision";
|
||||
pub const SEARCH: &str = "search";
|
||||
pub const BROWSER: &str = "browser";
|
||||
pub const CODE_EXECUTION: &str = "code_execution";
|
||||
pub const DELEGATION: &str = "delegation";
|
||||
}
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ToolsetRegistry {
|
||||
toolsets: HashMap<String, Toolset>,
|
||||
tool_index: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl ToolsetRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
pub fn register(&mut self, toolset: Toolset) {
|
||||
let name = toolset.name.clone();
|
||||
for tool in &toolset.tools {
|
||||
self.tool_index.insert(tool.clone(), name.clone());
|
||||
}
|
||||
self.toolsets.insert(name, toolset);
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<&Toolset> {
|
||||
self.toolsets.get(name)
|
||||
}
|
||||
pub fn toolset_for(&self, tool_name: &str) -> Option<&str> {
|
||||
self.tool_index.get(tool_name).map(String::as_str)
|
||||
}
|
||||
pub fn resolve_tool_names(
|
||||
&self,
|
||||
enabled: &[String],
|
||||
disabled: &[String],
|
||||
default_all: bool,
|
||||
) -> Vec<String> {
|
||||
let mut names = HashSet::new();
|
||||
let mut denied = HashSet::new();
|
||||
|
||||
for ts_name in disabled {
|
||||
if let Some(ts) = self.toolsets.get(ts_name) {
|
||||
for tool in &ts.tools {
|
||||
denied.insert(tool.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if enabled.is_empty() && default_all {
|
||||
for ts in self.toolsets.values() {
|
||||
if !disabled.contains(&ts.name) && ts.is_available() {
|
||||
for tool in &ts.tools {
|
||||
if !denied.contains(tool) {
|
||||
names.insert(tool.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for ts_name in enabled {
|
||||
if let Some(ts) = self.toolsets.get(ts_name) {
|
||||
if ts.is_available() {
|
||||
for tool in &ts.tools {
|
||||
if !denied.contains(tool) {
|
||||
names.insert(tool.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut sorted: Vec<String> = names.into_iter().collect();
|
||||
sorted.sort();
|
||||
sorted
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = &Toolset> {
|
||||
self.toolsets.values()
|
||||
}
|
||||
|
||||
pub fn all_tool_names(&self) -> Vec<String> {
|
||||
let mut names: Vec<String> = self.tool_index.keys().cloned().collect();
|
||||
names.sort();
|
||||
names
|
||||
}
|
||||
}
|
||||
44
lib/api/Cargo.toml
Normal file
44
lib/api/Cargo.toml
Normal file
@ -0,0 +1,44 @@
|
||||
[package]
|
||||
name = "api"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
readme.workspace = true
|
||||
homepage.workspace = true
|
||||
license.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
documentation.workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
name = "api"
|
||||
|
||||
[dependencies]
|
||||
service = { workspace = true }
|
||||
session = { workspace = true }
|
||||
config = { workspace = true }
|
||||
db = { workspace = true }
|
||||
model = { workspace = true }
|
||||
git = { workspace = true }
|
||||
channel = { workspace = true }
|
||||
socketio = { workspace = true }
|
||||
|
||||
actix-web = { workspace = true }
|
||||
actix-ws = { workspace = true }
|
||||
utoipa = { workspace = true, features = ["chrono", "uuid", "actix_extras"] }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
uuid = { workspace = true, features = ["v4", "v7", "serde"] }
|
||||
chrono = { workspace = true }
|
||||
async-stream = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
comrak = { workspace = true }
|
||||
redis = { workspace = true, features = ["cluster-async", "aio", "tokio-comp"] }
|
||||
storage = { workspace = true }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
296
lib/api/src/agent/conversation.rs
Normal file
296
lib/api/src/agent/conversation.rs
Normal file
@ -0,0 +1,296 @@
|
||||
use actix_web::{HttpResponse, web, web::ServiceConfig};
|
||||
use service::AppService;
|
||||
use service::agent::conversation::{
|
||||
ConversationResponse, ConversationWithSessionResponse, CreateConversation, MessageResponse, UpdateConversation,
|
||||
};
|
||||
use service::agent::types::{AgentRunRequest, AgentRunResponse};
|
||||
use session::Session;
|
||||
use tokio_stream::StreamExt;
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::{ApiError, ok_json};
|
||||
|
||||
pub fn configure(cfg: &mut ServiceConfig) {
|
||||
cfg.service(
|
||||
web::resource("/sessions/{session_id}/conversations")
|
||||
.route(web::get().to(list_conversations))
|
||||
.route(web::post().to(create_conversation)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/conversations")
|
||||
.route(web::get().to(list_all_conversations)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/conversations/{id}")
|
||||
.route(web::get().to(get_conversation))
|
||||
.route(web::patch().to(update_conversation))
|
||||
.route(web::delete().to(delete_conversation)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/conversations/{id}/messages")
|
||||
.route(web::get().to(list_messages))
|
||||
.route(web::post().to(send_message)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/conversations/{id}/stream")
|
||||
.route(web::post().to(stream_agent)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/conversations/{id}/fork")
|
||||
.route(web::post().to(fork_conversation)),
|
||||
);
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/agent/sessions/{session_id}/conversations",
|
||||
params(("session_id" = Uuid, Path)),
|
||||
responses((status = 200, body = Vec<ConversationResponse>)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_conversations(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(service.agent_conversation_list(user_id, path.into_inner()).await?)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post, path = "/api/v1/agent/sessions/{session_id}/conversations",
|
||||
params(("session_id" = Uuid, Path)),
|
||||
request_body = CreateConversation,
|
||||
responses((status = 200, body = ConversationResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn create_conversation(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
body: web::Json<CreateConversation>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(
|
||||
service
|
||||
.agent_conversation_create(user_id, path.into_inner(), body.into_inner())
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, utoipa::IntoParams)]
|
||||
pub struct ListAllConversationsQuery {
|
||||
pub wk: Option<String>,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/agent/conversations",
|
||||
params(("wk" = Option<String>, Query, description = "Filter by workspace name")),
|
||||
responses((status = 200, body = Vec<ConversationWithSessionResponse>)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_all_conversations(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
query: web::Query<ListAllConversationsQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(
|
||||
service
|
||||
.agent_conversation_list_all(user_id, query.wk.as_deref())
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/agent/conversations/{id}",
|
||||
params(("id" = Uuid, Path)),
|
||||
responses((status = 200, body = ConversationResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn get_conversation(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(service.agent_conversation_get(user_id, path.into_inner()).await?)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
patch, path = "/api/v1/agent/conversations/{id}",
|
||||
params(("id" = Uuid, Path)),
|
||||
request_body = UpdateConversation,
|
||||
responses((status = 200, body = ConversationResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn update_conversation(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
body: web::Json<UpdateConversation>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(
|
||||
service
|
||||
.agent_conversation_update(user_id, path.into_inner(), body.into_inner())
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete, path = "/api/v1/agent/conversations/{id}",
|
||||
params(("id" = Uuid, Path)),
|
||||
responses((status = 200)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn delete_conversation(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
service.agent_conversation_delete(user_id, path.into_inner()).await?;
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true })))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post, path = "/api/v1/agent/conversations/{id}/archive",
|
||||
params(("id" = Uuid, Path)),
|
||||
responses((status = 200, body = ConversationResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn archive_conversation(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(service.agent_conversation_archive(user_id, path.into_inner()).await?)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post, path = "/api/v1/agent/conversations/{id}/unarchive",
|
||||
params(("id" = Uuid, Path)),
|
||||
responses((status = 200, body = ConversationResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn unarchive_conversation(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(service.agent_conversation_unarchive(user_id, path.into_inner()).await?)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/agent/conversations/{id}/messages",
|
||||
params(("id" = Uuid, Path), ("before" = Option<Uuid>, Query), ("limit" = Option<u32>, Query)),
|
||||
responses((status = 200, body = Vec<MessageResponse>)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_messages(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
query: web::Query<MessageListQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(
|
||||
service
|
||||
.agent_message_list(user_id, path.into_inner(), query.limit, query.before)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, utoipa::IntoParams)]
|
||||
pub struct MessageListQuery {
|
||||
pub limit: Option<u32>,
|
||||
pub before: Option<Uuid>,
|
||||
}
|
||||
#[utoipa::path(
|
||||
post, path = "/api/v1/agent/conversations/{id}/messages",
|
||||
params(("id" = Uuid, Path)),
|
||||
request_body = AgentRunRequest,
|
||||
responses((status = 200, body = AgentRunResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn send_message(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
body: web::Json<AgentRunRequest>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
let conversation_id = path.into_inner();
|
||||
let mut req = body.into_inner();
|
||||
req.conversation_id = Some(conversation_id);
|
||||
ok_json(service.agent_run(user_id, req).await?)
|
||||
}
|
||||
#[utoipa::path(
|
||||
post, path = "/api/v1/agent/conversations/{id}/stream",
|
||||
params(("id" = Uuid, Path)),
|
||||
request_body = AgentRunRequest,
|
||||
responses((status = 200, description = "SSE stream")),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn stream_agent(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
body: web::Json<AgentRunRequest>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
let conversation_id = path.into_inner();
|
||||
let mut req = body.into_inner();
|
||||
req.conversation_id = Some(conversation_id);
|
||||
|
||||
let rx = service.agent_run_streaming(user_id, req).await?;
|
||||
|
||||
let stream = UnboundedReceiverStream::new(rx).map(|payload| {
|
||||
let frame = if payload.starts_with("data:") {
|
||||
payload
|
||||
} else {
|
||||
format!("data: {}\n\n", payload)
|
||||
};
|
||||
Ok::<_, actix_web::Error>(actix_web::web::Bytes::from(frame))
|
||||
});
|
||||
|
||||
Ok(HttpResponse::Ok()
|
||||
.content_type("text/event-stream")
|
||||
.insert_header(("Cache-Control", "no-cache"))
|
||||
.insert_header(("Connection", "keep-alive"))
|
||||
.insert_header(("X-Accel-Buffering", "no"))
|
||||
.streaming(stream))
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, utoipa::ToSchema)]
|
||||
pub struct ForkConversationRequest {
|
||||
pub message_id: Option<Uuid>,
|
||||
pub title: Option<String>,
|
||||
}
|
||||
#[utoipa::path(
|
||||
post, path = "/api/v1/agent/conversations/{id}/fork",
|
||||
params(("id" = Uuid, Path)),
|
||||
request_body = ForkConversationRequest,
|
||||
responses((status = 200, body = ConversationResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn fork_conversation(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
body: web::Json<ForkConversationRequest>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(
|
||||
service
|
||||
.agent_conversation_fork(
|
||||
user_id,
|
||||
path.into_inner(),
|
||||
body.message_id,
|
||||
body.title.as_deref(),
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
12
lib/api/src/agent/mod.rs
Normal file
12
lib/api/src/agent/mod.rs
Normal file
@ -0,0 +1,12 @@
|
||||
pub mod conversation;
|
||||
pub mod session;
|
||||
|
||||
use actix_web::{web, web::ServiceConfig};
|
||||
|
||||
pub fn configure(cfg: &mut ServiceConfig) {
|
||||
cfg.service(
|
||||
web::scope("/agent")
|
||||
.configure(session::configure)
|
||||
.configure(conversation::configure),
|
||||
);
|
||||
}
|
||||
162
lib/api/src/agent/session.rs
Normal file
162
lib/api/src/agent/session.rs
Normal file
@ -0,0 +1,162 @@
|
||||
use actix_web::{HttpResponse, web, web::ServiceConfig};
|
||||
use service::AppService;
|
||||
use service::agent::session::{
|
||||
AgentSessionResponse, CreateAgentSession, UpdateAgentSession,
|
||||
};
|
||||
use session::Session;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::{ApiError, ok_json};
|
||||
|
||||
pub fn configure(cfg: &mut ServiceConfig) {
|
||||
cfg.service(
|
||||
web::resource("/sessions")
|
||||
.route(web::get().to(list_sessions))
|
||||
.route(web::post().to(create_session)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/sessions/search")
|
||||
.route(web::get().to(search_sessions)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/sessions/{id}")
|
||||
.route(web::get().to(get_session))
|
||||
.route(web::patch().to(update_session))
|
||||
.route(web::delete().to(delete_session)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/sessions/{id}/toolsets")
|
||||
.route(web::patch().to(update_session_toolsets)),
|
||||
);
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/agent/sessions",
|
||||
responses((status = 200, body = Vec<AgentSessionResponse>)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_sessions(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(service.agent_session_list(user_id).await?)
|
||||
}
|
||||
#[utoipa::path(
|
||||
post, path = "/api/v1/agent/sessions",
|
||||
request_body = CreateAgentSession,
|
||||
responses((status = 200, body = AgentSessionResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn create_session(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
body: web::Json<CreateAgentSession>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(service.agent_session_create(user_id, body.into_inner()).await?)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/agent/sessions/{id}",
|
||||
params(("id" = Uuid, Path)),
|
||||
responses((status = 200, body = AgentSessionResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn get_session(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(service.agent_session_get(user_id, path.into_inner()).await?)
|
||||
}
|
||||
#[utoipa::path(
|
||||
patch, path = "/api/v1/agent/sessions/{id}",
|
||||
params(("id" = Uuid, Path)),
|
||||
request_body = UpdateAgentSession,
|
||||
responses((status = 200, body = AgentSessionResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn update_session(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
body: web::Json<UpdateAgentSession>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(service.agent_session_update(user_id, path.into_inner(), body.into_inner()).await?)
|
||||
}
|
||||
#[utoipa::path(
|
||||
delete, path = "/api/v1/agent/sessions/{id}",
|
||||
params(("id" = Uuid, Path)),
|
||||
responses((status = 200)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn delete_session(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
service.agent_session_delete(user_id, path.into_inner()).await?;
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({ "deleted": true })))
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, utoipa::IntoParams)]
|
||||
pub struct SearchQuery {
|
||||
pub q: String,
|
||||
#[serde(default = "default_limit")]
|
||||
pub limit: u32,
|
||||
}
|
||||
|
||||
const fn default_limit() -> u32 {
|
||||
20
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/agent/sessions/search",
|
||||
params(("q" = String, Query), ("limit" = Option<u32>, Query)),
|
||||
responses((status = 200, body = Vec<AgentSessionResponse>)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn search_sessions(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
query: web::Query<SearchQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(
|
||||
service
|
||||
.agent_session_search(user_id, &query.q, query.limit)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, utoipa::ToSchema)]
|
||||
pub struct UpdateToolsetsRequest {
|
||||
pub enabled: Option<Vec<String>>,
|
||||
pub disabled: Option<Vec<String>>,
|
||||
}
|
||||
#[utoipa::path(
|
||||
patch, path = "/api/v1/agent/sessions/{id}/toolsets",
|
||||
params(("id" = Uuid, Path)),
|
||||
request_body = UpdateToolsetsRequest,
|
||||
responses((status = 200, body = AgentSessionResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn update_session_toolsets(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<Uuid>,
|
||||
body: web::Json<UpdateToolsetsRequest>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = session.user().ok_or(ApiError(service::error::AppError::Unauthorized))?;
|
||||
ok_json(
|
||||
service
|
||||
.agent_session_update_toolsets(
|
||||
user_id,
|
||||
path.into_inner(),
|
||||
body.enabled.clone(),
|
||||
body.disabled.clone(),
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
47
lib/api/src/ai/mod.rs
Normal file
47
lib/api/src/ai/mod.rs
Normal file
@ -0,0 +1,47 @@
|
||||
pub mod model;
|
||||
pub mod provider;
|
||||
|
||||
use actix_web::web;
|
||||
use actix_web::web::ServiceConfig;
|
||||
|
||||
pub fn configure(cfg: &mut ServiceConfig) {
|
||||
cfg.service(
|
||||
web::scope("/ai")
|
||||
.service(
|
||||
web::resource("/providers")
|
||||
.route(web::get().to(provider::list_providers)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/providers/{id}")
|
||||
.route(web::get().to(provider::get_provider)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/models")
|
||||
.route(web::get().to(model::list_models)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/models/{id}")
|
||||
.route(web::get().to(model::get_model)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/models/{id}/versions")
|
||||
.route(web::get().to(model::list_versions)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/models/{id}/card")
|
||||
.route(web::get().to(model::get_card)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/models/{id}/tags")
|
||||
.route(web::get().to(model::list_tags)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/models/{id}/discussions")
|
||||
.route(web::get().to(model::list_discussions)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/models/{id}/likes")
|
||||
.route(web::get().to(model::list_likes)),
|
||||
),
|
||||
);
|
||||
}
|
||||
139
lib/api/src/ai/model.rs
Normal file
139
lib/api/src/ai/model.rs
Normal file
@ -0,0 +1,139 @@
|
||||
use crate::error::{ApiError, ok_json};
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::Deserialize;
|
||||
use service::AppService;
|
||||
use service::Pagination;
|
||||
use service::ai::types::{
|
||||
AiDiscussionResponse, AiLikeResponse, AiModelCardResponse, AiModelFilter,
|
||||
AiModelListItem, AiModelResponse, AiModelVersionResponse,
|
||||
};
|
||||
use session::Session;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct ModelIdPath {
|
||||
pub id: Uuid,
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/ai/models",
|
||||
params(AiModelFilter, Pagination),
|
||||
responses((status = 200, body = Vec<AiModelListItem>), (status = 401, description = "Unauthorized")),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_models(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
filter: web::Query<AiModelFilter>,
|
||||
pagination: web::Query<Pagination>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
ok_json(
|
||||
service
|
||||
.ai_model_list(
|
||||
&session,
|
||||
filter.into_inner(),
|
||||
pagination.into_inner(),
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/ai/models/{id}",
|
||||
params(("id" = String, Path)), responses((status = 200, body = AiModelResponse),
|
||||
(status = 401, description = "Unauthorized"), (status = 404, description = "Not found")),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn get_model(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<ModelIdPath>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
ok_json(service.ai_model_get(&session, path.into_inner().id).await?)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/ai/models/{id}/versions",
|
||||
params(("id" = String, Path)), responses((status = 200, body = Vec<AiModelVersionResponse>),
|
||||
(status = 401, description = "Unauthorized")),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_versions(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<ModelIdPath>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
ok_json(
|
||||
service
|
||||
.ai_model_versions(&session, path.into_inner().id)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/ai/models/{id}/card",
|
||||
params(("id" = String, Path)), responses((status = 200, body = Option<AiModelCardResponse>),
|
||||
(status = 401, description = "Unauthorized")),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn get_card(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<ModelIdPath>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
ok_json(
|
||||
service
|
||||
.ai_model_card(&session, path.into_inner().id)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/ai/models/{id}/tags",
|
||||
params(("id" = String, Path)), responses((status = 200, body = Vec<String>)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_tags(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<ModelIdPath>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
ok_json(
|
||||
service
|
||||
.ai_model_tags(&session, path.into_inner().id)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/ai/models/{id}/discussions",
|
||||
params(("id" = String, Path), Pagination),
|
||||
responses((status = 200, body = Vec<AiDiscussionResponse>)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_discussions(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<ModelIdPath>,
|
||||
pagination: web::Query<Pagination>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
ok_json(
|
||||
service
|
||||
.ai_model_discussions(
|
||||
&session,
|
||||
path.into_inner().id,
|
||||
pagination.into_inner(),
|
||||
)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/ai/models/{id}/likes",
|
||||
params(("id" = String, Path)), responses((status = 200, body = Vec<AiLikeResponse>)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_likes(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<ModelIdPath>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
ok_json(
|
||||
service
|
||||
.ai_model_likes(&session, path.into_inner().id)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
39
lib/api/src/ai/provider.rs
Normal file
39
lib/api/src/ai/provider.rs
Normal file
@ -0,0 +1,39 @@
|
||||
use crate::error::{ApiError, ok_json};
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::Deserialize;
|
||||
use service::AppService;
|
||||
use service::ai::types::AiProviderResponse;
|
||||
use session::Session;
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct ProviderIdPath {
|
||||
pub id: uuid::Uuid,
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/ai/providers",
|
||||
responses((status = 200, body = Vec<AiProviderResponse>), (status = 401, description = "Unauthorized")),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_providers(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
ok_json(service.ai_provider_list(&session).await?)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/ai/providers/{id}",
|
||||
params(("id" = String, Path)), responses((status = 200, body = AiProviderResponse),
|
||||
(status = 401, description = "Unauthorized"), (status = 404, description = "Not found")),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn get_provider(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<ProviderIdPath>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
ok_json(
|
||||
service
|
||||
.ai_provider_get(&session, path.into_inner().id)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
29
lib/api/src/auth/captcha.rs
Normal file
29
lib/api/src/auth/captcha.rs
Normal file
@ -0,0 +1,29 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::Serialize;
|
||||
use service::{
|
||||
AppService,
|
||||
auth::captcha::{CaptchaQuery, CaptchaResponse},
|
||||
};
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/auth/captcha",
|
||||
params(CaptchaQuery),
|
||||
responses((status = 200, body = CaptchaResponse)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn captcha(
|
||||
session: Session,
|
||||
query: web::Query<CaptchaQuery>,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let result = service.auth_captcha(&session, query.into_inner()).await?;
|
||||
ok_json(result)
|
||||
}
|
||||
64
lib/api/src/auth/email.rs
Normal file
64
lib/api/src/auth/email.rs
Normal file
@ -0,0 +1,64 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::Serialize;
|
||||
use service::{
|
||||
AppService,
|
||||
auth::email::{EmailChangeRequest, EmailResponse, EmailVerifyRequest},
|
||||
};
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
fn ok() -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().finish())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/auth/email",
|
||||
responses((status = 200, body = EmailResponse)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn get_email(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let result = service.auth_get_email(&session).await?;
|
||||
ok_json(result)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/email",
|
||||
request_body = EmailChangeRequest,
|
||||
responses((status = 200)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn email_change_request(
|
||||
session: Session,
|
||||
params: web::Json<EmailChangeRequest>,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
service
|
||||
.auth_email_change_request(&session, params.into_inner())
|
||||
.await?;
|
||||
ok()
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/email/verify",
|
||||
request_body = EmailVerifyRequest,
|
||||
responses((status = 200)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn email_verify(
|
||||
params: web::Json<EmailVerifyRequest>,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
service.auth_email_verify(params.into_inner()).await?;
|
||||
ok()
|
||||
}
|
||||
25
lib/api/src/auth/login.rs
Normal file
25
lib/api/src/auth/login.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use service::{AppService, auth::login::LoginParams};
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
fn ok() -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().finish())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/login",
|
||||
request_body = LoginParams,
|
||||
responses((status = 200)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn login(
|
||||
session: Session,
|
||||
params: web::Json<LoginParams>,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
service.auth_login(params.into_inner(), session).await?;
|
||||
ok()
|
||||
}
|
||||
23
lib/api/src/auth/logout.rs
Normal file
23
lib/api/src/auth/logout.rs
Normal file
@ -0,0 +1,23 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use service::AppService;
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
fn ok() -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().finish())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/logout",
|
||||
responses((status = 200)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn logout(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
service.auth_logout(&session).await?;
|
||||
ok()
|
||||
}
|
||||
24
lib/api/src/auth/me.rs
Normal file
24
lib/api/src/auth/me.rs
Normal file
@ -0,0 +1,24 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::Serialize;
|
||||
use service::{AppService, auth::me::ContextMe};
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/auth/me",
|
||||
responses((status = 200, body = ContextMe)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn me(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let result = service.auth_me(session).await?;
|
||||
ok_json(result)
|
||||
}
|
||||
75
lib/api/src/auth/mod.rs
Normal file
75
lib/api/src/auth/mod.rs
Normal file
@ -0,0 +1,75 @@
|
||||
pub mod captcha;
|
||||
pub mod email;
|
||||
pub mod login;
|
||||
pub mod logout;
|
||||
pub mod me;
|
||||
pub mod register;
|
||||
pub mod reset_pass;
|
||||
pub mod rsa;
|
||||
pub mod totp;
|
||||
|
||||
use actix_web::{web, web::ServiceConfig};
|
||||
|
||||
pub fn configure(cfg: &mut ServiceConfig) {
|
||||
cfg.service(
|
||||
web::scope("/auth")
|
||||
.service(
|
||||
web::resource("/captcha")
|
||||
.route(web::get().to(captcha::captcha)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/login").route(web::post().to(login::login)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/logout").route(web::post().to(logout::logout)),
|
||||
)
|
||||
.service(web::resource("/me").route(web::get().to(me::me)))
|
||||
.service(
|
||||
web::resource("/register")
|
||||
.route(web::post().to(register::register)),
|
||||
)
|
||||
.service(
|
||||
web::scope("/reset-password")
|
||||
.service(web::resource("/request").route(
|
||||
web::post().to(reset_pass::reset_password_request),
|
||||
))
|
||||
.service(web::resource("/verify").route(
|
||||
web::post().to(reset_pass::reset_password_verify),
|
||||
)),
|
||||
)
|
||||
.service(web::resource("/public-key").route(web::get().to(rsa::rsa)))
|
||||
.service(
|
||||
web::scope("/2fa")
|
||||
.service(
|
||||
web::resource("/enable")
|
||||
.route(web::post().to(totp::enable_2fa)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/verify")
|
||||
.route(web::post().to(totp::verify_2fa)),
|
||||
)
|
||||
.service(
|
||||
web::resource("")
|
||||
.route(web::get().to(totp::status_2fa))
|
||||
.route(web::delete().to(totp::disable_2fa)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/backup-codes").route(
|
||||
web::post().to(totp::regenerate_backup_codes),
|
||||
),
|
||||
),
|
||||
)
|
||||
.service(
|
||||
web::scope("/email")
|
||||
.service(
|
||||
web::resource("")
|
||||
.route(web::get().to(email::get_email))
|
||||
.route(web::put().to(email::email_change_request)),
|
||||
)
|
||||
.service(
|
||||
web::resource("/verify")
|
||||
.route(web::post().to(email::email_verify)),
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
26
lib/api/src/auth/register.rs
Normal file
26
lib/api/src/auth/register.rs
Normal file
@ -0,0 +1,26 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::Serialize;
|
||||
use service::{AppService, auth::register::RegisterParams};
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/register",
|
||||
request_body = RegisterParams,
|
||||
responses((status = 200)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn register(
|
||||
session: Session,
|
||||
params: web::Json<RegisterParams>,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let result = service.auth_register(params.into_inner(), &session).await?;
|
||||
ok_json(result)
|
||||
}
|
||||
47
lib/api/src/auth/reset_pass.rs
Normal file
47
lib/api/src/auth/reset_pass.rs
Normal file
@ -0,0 +1,47 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use service::{
|
||||
AppService,
|
||||
auth::reset_pass::{ResetPasswordRequest, ResetPasswordVerifyParams},
|
||||
};
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
fn ok() -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().finish())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/reset-password/request",
|
||||
request_body = ResetPasswordRequest,
|
||||
responses((status = 200)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn reset_password_request(
|
||||
params: web::Json<ResetPasswordRequest>,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
service
|
||||
.auth_reset_password_request(params.into_inner())
|
||||
.await?;
|
||||
ok()
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/reset-password/verify",
|
||||
request_body = ResetPasswordVerifyParams,
|
||||
responses((status = 200)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn reset_password_verify(
|
||||
session: Session,
|
||||
params: web::Json<ResetPasswordVerifyParams>,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
service
|
||||
.auth_reset_password_verify(&session, params.into_inner())
|
||||
.await?;
|
||||
ok()
|
||||
}
|
||||
24
lib/api/src/auth/rsa.rs
Normal file
24
lib/api/src/auth/rsa.rs
Normal file
@ -0,0 +1,24 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::Serialize;
|
||||
use service::{AppService, auth::rsa::RsaResponse};
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/auth/public-key",
|
||||
responses((status = 200, body = RsaResponse)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn rsa(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let result = service.auth_rsa(&session).await?;
|
||||
ok_json(result)
|
||||
}
|
||||
102
lib/api/src/auth/totp.rs
Normal file
102
lib/api/src/auth/totp.rs
Normal file
@ -0,0 +1,102 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::Serialize;
|
||||
use service::{
|
||||
AppService,
|
||||
auth::totp::{
|
||||
Disable2FAParams, Enable2FAResponse, Get2FAStatusResponse,
|
||||
Verify2FAParams,
|
||||
},
|
||||
};
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
fn ok() -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().finish())
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/2fa/enable",
|
||||
responses((status = 200, body = Enable2FAResponse)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn enable_2fa(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let result = service.auth_2fa_enable(&session).await?;
|
||||
ok_json(result)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/2fa/verify",
|
||||
request_body = Verify2FAParams,
|
||||
responses((status = 200)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn verify_2fa(
|
||||
session: Session,
|
||||
params: web::Json<Verify2FAParams>,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
service
|
||||
.auth_2fa_verify_and_enable(&session, params.into_inner())
|
||||
.await?;
|
||||
ok()
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/auth/2fa",
|
||||
request_body = Disable2FAParams,
|
||||
responses((status = 200)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn disable_2fa(
|
||||
session: Session,
|
||||
params: web::Json<Disable2FAParams>,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
service
|
||||
.auth_2fa_disable(&session, params.into_inner())
|
||||
.await?;
|
||||
ok()
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/auth/2fa",
|
||||
responses((status = 200, body = Get2FAStatusResponse)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn status_2fa(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let result = service.auth_2fa_status(&session).await?;
|
||||
ok_json(result)
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/auth/2fa/backup-codes",
|
||||
request_body = String,
|
||||
responses((status = 200)),
|
||||
tag = "auth"
|
||||
)]
|
||||
pub async fn regenerate_backup_codes(
|
||||
session: Session,
|
||||
params: web::Json<String>,
|
||||
service: web::Data<AppService>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let result = service
|
||||
.auth_2fa_regenerate_backup_codes(&session, params.into_inner())
|
||||
.await?;
|
||||
ok_json(result)
|
||||
}
|
||||
207
lib/api/src/channel/mod.rs
Normal file
207
lib/api/src/channel/mod.rs
Normal file
@ -0,0 +1,207 @@
|
||||
pub mod rest;
|
||||
pub mod rest_ai;
|
||||
pub mod rest_interact;
|
||||
pub mod rest_member;
|
||||
pub mod rest_message;
|
||||
pub mod rest_room;
|
||||
pub mod rest_voice;
|
||||
pub mod token;
|
||||
|
||||
pub use channel::ChannelBus;
|
||||
|
||||
use actix_web::web::ServiceConfig;
|
||||
|
||||
pub fn configure(cfg: &mut ServiceConfig, bus: ChannelBus) {
|
||||
socketio::configure_at(cfg, "/socket.io", bus.io().clone());
|
||||
socketio::configure_at(cfg, "/socket.io/", bus.io().clone());
|
||||
cfg.service(
|
||||
actix_web::web::resource("/ping")
|
||||
.route(actix_web::web::get().to(rest::ping)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/csrf")
|
||||
.route(actix_web::web::get().to(rest::csrf_token)),
|
||||
);
|
||||
cfg.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/messages")
|
||||
.route(actix_web::web::get().to(rest_message::list_messages))
|
||||
.route(actix_web::web::post().to(rest_message::create_message)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/messages/around")
|
||||
.route(actix_web::web::get().to(rest_message::messages_around)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/messages/missed")
|
||||
.route(actix_web::web::get().to(rest_message::missed_messages)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/messages/{message_id}")
|
||||
.route(actix_web::web::patch().to(rest_message::update_message))
|
||||
.route(actix_web::web::delete().to(rest_message::revoke_message)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/search")
|
||||
.route(actix_web::web::get().to(rest_message::search)),
|
||||
);
|
||||
cfg.service(
|
||||
actix_web::web::resource("/rooms")
|
||||
.route(actix_web::web::get().to(rest_room::list_rooms))
|
||||
.route(actix_web::web::post().to(rest_room::room_create)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}")
|
||||
.route(actix_web::web::get().to(rest_room::room_get))
|
||||
.route(actix_web::web::patch().to(rest_room::room_update))
|
||||
.route(actix_web::web::delete().to(rest_room::room_delete)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/subscribe")
|
||||
.route(actix_web::web::post().to(rest_room::subscribe))
|
||||
.route(actix_web::web::delete().to(rest_room::unsubscribe)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/members")
|
||||
.route(actix_web::web::post().to(rest_room::access_grant)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/workspaces/{workspace_id}/members")
|
||||
.route(actix_web::web::get().to(rest_member::list_workspace_members)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/members/{user_id}")
|
||||
.route(actix_web::web::delete().to(rest_room::access_revoke)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/workspaces/{workspace_id}/categories")
|
||||
.route(actix_web::web::post().to(rest_room::category_create)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/categories/{category_id}")
|
||||
.route(actix_web::web::patch().to(rest_room::category_update))
|
||||
.route(actix_web::web::delete().to(rest_room::category_delete)),
|
||||
);
|
||||
cfg.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/reactions")
|
||||
.route(actix_web::web::post().to(rest_interact::reaction_add))
|
||||
.route(actix_web::web::delete().to(rest_interact::reaction_remove)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/threads")
|
||||
.route(actix_web::web::post().to(rest_interact::thread_create)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/threads/{thread_id}/resolve")
|
||||
.route(actix_web::web::patch().to(rest_interact::thread_resolve)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/threads/{thread_id}/archive")
|
||||
.route(actix_web::web::patch().to(rest_interact::thread_archive)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/pins")
|
||||
.route(actix_web::web::post().to(rest_interact::pin_add))
|
||||
.route(actix_web::web::delete().to(rest_interact::pin_remove)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/drafts")
|
||||
.route(actix_web::web::put().to(rest_interact::draft_save))
|
||||
.route(actix_web::web::delete().to(rest_interact::draft_clear)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/typing")
|
||||
.route(actix_web::web::post().to(rest_interact::typing)),
|
||||
);
|
||||
cfg.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/read-receipt")
|
||||
.route(actix_web::web::post().to(rest_member::read_receipt)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/dnd")
|
||||
.route(actix_web::web::patch().to(rest_member::dnd_update)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/notifications/{id}/read").route(
|
||||
actix_web::web::patch().to(rest_member::notification_mark_read),
|
||||
),
|
||||
)
|
||||
.service(actix_web::web::resource("/notifications/read-all").route(
|
||||
actix_web::web::post().to(rest_member::notification_mark_all_read),
|
||||
))
|
||||
.service(
|
||||
actix_web::web::resource("/notifications/{id}").route(
|
||||
actix_web::web::delete().to(rest_member::notification_archive),
|
||||
),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/presence")
|
||||
.route(actix_web::web::post().to(rest_member::presence_update)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/custom-status").route(
|
||||
actix_web::web::post().to(rest_member::custom_status_update),
|
||||
),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/invites")
|
||||
.route(actix_web::web::post().to(rest_member::invite_create)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/invites/accept")
|
||||
.route(actix_web::web::post().to(rest_member::invite_accept)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/invites/{id}")
|
||||
.route(actix_web::web::delete().to(rest_member::invite_revoke)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/workspaces/{workspace_id}/bans")
|
||||
.route(actix_web::web::post().to(rest_member::ban_create)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/workspaces/{workspace_id}/bans/{user_id}")
|
||||
.route(actix_web::web::delete().to(rest_member::ban_remove)),
|
||||
);
|
||||
cfg.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/voice/join")
|
||||
.route(actix_web::web::post().to(rest_voice::voice_join)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/voice/leave")
|
||||
.route(actix_web::web::post().to(rest_voice::voice_leave)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/voice/mute")
|
||||
.route(actix_web::web::post().to(rest_voice::voice_mute)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/voice/deaf")
|
||||
.route(actix_web::web::post().to(rest_voice::voice_deaf)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/screen-share")
|
||||
.route(actix_web::web::post().to(rest_voice::screen_share)),
|
||||
);
|
||||
cfg.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/ai/stop")
|
||||
.route(actix_web::web::post().to(rest_ai::ai_stop)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/ai")
|
||||
.route(actix_web::web::get().to(rest_ai::ai_list))
|
||||
.route(actix_web::web::post().to(rest_ai::ai_add)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/rooms/{room_id}/ai/{agent_session_id}")
|
||||
.route(actix_web::web::delete().to(rest_ai::ai_remove)),
|
||||
)
|
||||
.service(
|
||||
actix_web::web::resource("/users/summary/{username}")
|
||||
.route(actix_web::web::get().to(rest_ai::user_summary)),
|
||||
);
|
||||
cfg.service(
|
||||
actix_web::web::resource("/token")
|
||||
.route(actix_web::web::post().to(token::generate_token)),
|
||||
);
|
||||
cfg.app_data(actix_web::web::Data::new(bus));
|
||||
}
|
||||
110
lib/api/src/channel/rest.rs
Normal file
110
lib/api/src/channel/rest.rs
Normal file
@ -0,0 +1,110 @@
|
||||
use actix_web::{HttpRequest, HttpResponse, web};
|
||||
use channel::http::{WsHandler, WsInMessage, WsOutEvent};
|
||||
use channel::{ChannelBus, ChannelError};
|
||||
use session::SessionExt;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
pub(crate) fn extract_user(req: &HttpRequest) -> Result<Uuid, ApiError> {
|
||||
req.get_session()
|
||||
.user()
|
||||
.ok_or_else(|| ApiError(service::error::AppError::Unauthorized))
|
||||
}
|
||||
|
||||
pub(crate) fn channel_err(e: ChannelError) -> ApiError {
|
||||
ApiError(match e {
|
||||
ChannelError::Unauthorized | ChannelError::TokenInvalidOrExpired => {
|
||||
service::error::AppError::Unauthorized
|
||||
}
|
||||
ChannelError::AccessDenied => {
|
||||
service::error::AppError::PermissionDenied
|
||||
}
|
||||
ChannelError::Validation(msg) => {
|
||||
service::error::AppError::BadRequest(msg)
|
||||
}
|
||||
ChannelError::RateLimitExceeded => {
|
||||
service::error::AppError::BadRequest("rate limit exceeded".into())
|
||||
}
|
||||
ChannelError::RenewalLimitExceeded => {
|
||||
service::error::AppError::BadRequest(
|
||||
"renewal limit exceeded".into(),
|
||||
)
|
||||
}
|
||||
ChannelError::RoomNotFound => {
|
||||
service::error::AppError::NotFound("room not found".into())
|
||||
}
|
||||
ChannelError::UserNotFound => {
|
||||
service::error::AppError::NotFound("user not found".into())
|
||||
}
|
||||
ChannelError::Internal(msg) => {
|
||||
service::error::AppError::InternalServerError(msg)
|
||||
}
|
||||
ChannelError::Database(e) => {
|
||||
service::error::AppError::InternalServerError(e.to_string())
|
||||
}
|
||||
ChannelError::Cache(e) => {
|
||||
service::error::AppError::InternalServerError(e.to_string())
|
||||
}
|
||||
ChannelError::SocketIo(e) => {
|
||||
service::error::AppError::InternalServerError(e.to_string())
|
||||
}
|
||||
ChannelError::Serialization(e) => {
|
||||
service::error::AppError::InternalServerError(e.to_string())
|
||||
}
|
||||
ChannelError::Redis(e) => {
|
||||
service::error::AppError::InternalServerError(e.to_string())
|
||||
}
|
||||
ChannelError::Storage(e) => {
|
||||
service::error::AppError::InternalServerError(e.to_string())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn ok_json(event: Option<WsOutEvent>) -> HttpResponse {
|
||||
match event {
|
||||
Some(e) => HttpResponse::Ok().json(e),
|
||||
None => HttpResponse::NoContent().finish(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn created_json(event: Option<WsOutEvent>) -> HttpResponse {
|
||||
match event {
|
||||
Some(e) => HttpResponse::Created().json(e),
|
||||
None => HttpResponse::NoContent().finish(),
|
||||
}
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/ws/ping",
|
||||
responses((status = 200, description = "Pong with protocol version")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn ping(
|
||||
req: HttpRequest,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let result = WsHandler::handle(&bus, user_id, WsInMessage::Ping)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/ws/csrf",
|
||||
responses((status = 200, description = "CSRF token")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn csrf_token(
|
||||
req: HttpRequest,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let result = WsHandler::handle(&bus, user_id, WsInMessage::CsrfToken)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
120
lib/api/src/channel/rest_ai.rs
Normal file
120
lib/api/src/channel/rest_ai.rs
Normal file
@ -0,0 +1,120 @@
|
||||
use actix_web::{HttpRequest, HttpResponse, web};
|
||||
use channel::ChannelBus;
|
||||
use channel::http::{WsHandler, WsInMessage};
|
||||
use serde::Deserialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::rest::{channel_err, created_json, extract_user, ok_json};
|
||||
use crate::error::ApiError;
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct AiAddRequest {
|
||||
pub agent_session: Uuid,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/ws/rooms/{room_id}/ai",
|
||||
responses((status = 200, description = "AI agents in room")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn ai_list(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::AiList {
|
||||
room: room_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/ai",
|
||||
request_body = AiAddRequest,
|
||||
responses((status = 201, description = "AI agent added to room")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn ai_add(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<AiAddRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::AiUpsert {
|
||||
room: room_id.into_inner(),
|
||||
model: body.agent_session,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(created_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/rooms/{room_id}/ai/{agent_session_id}",
|
||||
responses((status = 200, description = "AI agent removed from room")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn ai_remove(
|
||||
req: HttpRequest,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let (room, agent_id) = path.into_inner();
|
||||
let msg = WsInMessage::AiDelete { room, agent_id };
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/ai/stop",
|
||||
responses((status = 204, description = "AI agent stopped")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn ai_stop(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::AiStop {
|
||||
room: room_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/ws/users/summary/{username}",
|
||||
responses((status = 200, description = "User summary")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn user_summary(
|
||||
req: HttpRequest,
|
||||
username: web::Path<String>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::UserSummary {
|
||||
username: username.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
274
lib/api/src/channel/rest_interact.rs
Normal file
274
lib/api/src/channel/rest_interact.rs
Normal file
@ -0,0 +1,274 @@
|
||||
use actix_web::{HttpRequest, HttpResponse, web};
|
||||
use channel::ChannelBus;
|
||||
use channel::http::{WsHandler, WsInMessage};
|
||||
use serde::Deserialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::rest::{channel_err, created_json, extract_user, ok_json};
|
||||
use crate::error::ApiError;
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct ReactionRequest {
|
||||
pub message: Uuid,
|
||||
pub emoji: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct ThreadCreateRequest {
|
||||
pub parent: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct TypingRequest {
|
||||
pub action: TypingAction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub enum TypingAction {
|
||||
Start,
|
||||
Stop,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/reactions",
|
||||
request_body = ReactionRequest,
|
||||
responses((status = 204, description = "Reaction added")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn reaction_add(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<ReactionRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::ReactionAdd {
|
||||
room: room_id.into_inner(),
|
||||
message: body.message,
|
||||
emoji: body.emoji.clone(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/rooms/{room_id}/reactions",
|
||||
request_body = ReactionRequest,
|
||||
responses((status = 204, description = "Reaction removed")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn reaction_remove(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<ReactionRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::ReactionRemove {
|
||||
room: room_id.into_inner(),
|
||||
message: body.message,
|
||||
emoji: body.emoji.clone(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/threads",
|
||||
request_body = ThreadCreateRequest,
|
||||
responses((status = 201, description = "Thread created")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn thread_create(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<ThreadCreateRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::ThreadCreate {
|
||||
room: room_id.into_inner(),
|
||||
parent: body.parent,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(created_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
patch,
|
||||
path = "/api/v1/ws/threads/{thread_id}/resolve",
|
||||
responses((status = 200, description = "Thread resolved")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn thread_resolve(
|
||||
req: HttpRequest,
|
||||
thread_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::ThreadResolve {
|
||||
thread_id: thread_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
patch,
|
||||
path = "/api/v1/ws/threads/{thread_id}/archive",
|
||||
responses((status = 200, description = "Thread archived")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn thread_archive(
|
||||
req: HttpRequest,
|
||||
thread_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::ThreadArchive {
|
||||
thread_id: thread_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct PinRequest {
|
||||
pub message: Uuid,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/pins",
|
||||
request_body = PinRequest,
|
||||
responses((status = 204, description = "Message pinned")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn pin_add(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<PinRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::PinAdd {
|
||||
room: room_id.into_inner(),
|
||||
message: body.message,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/rooms/{room_id}/pins",
|
||||
request_body = PinRequest,
|
||||
responses((status = 204, description = "Pin removed")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn pin_remove(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<PinRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::PinRemove {
|
||||
room: room_id.into_inner(),
|
||||
message: body.message,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct DraftSaveRequest {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
put,
|
||||
path = "/api/v1/ws/rooms/{room_id}/drafts",
|
||||
request_body = DraftSaveRequest,
|
||||
responses((status = 204, description = "Draft saved")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn draft_save(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<DraftSaveRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::DraftSave {
|
||||
room: room_id.into_inner(),
|
||||
content: body.content.clone(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/rooms/{room_id}/drafts",
|
||||
responses((status = 204, description = "Draft cleared")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn draft_clear(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::DraftClear {
|
||||
room: room_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/typing",
|
||||
request_body = TypingRequest,
|
||||
responses((status = 204, description = "Typing indicator broadcasted")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn typing(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<TypingRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let room = room_id.into_inner();
|
||||
let msg = match body.action {
|
||||
TypingAction::Start => WsInMessage::TypingStart { room },
|
||||
TypingAction::Stop => WsInMessage::TypingStop { room },
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
375
lib/api/src/channel/rest_member.rs
Normal file
375
lib/api/src/channel/rest_member.rs
Normal file
@ -0,0 +1,375 @@
|
||||
use actix_web::{HttpRequest, HttpResponse, web};
|
||||
use channel::ChannelBus;
|
||||
use channel::http::{WsHandler, WsInMessage};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::rest::{channel_err, created_json, extract_user, ok_json};
|
||||
use crate::error::ApiError;
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct ReadReceiptRequest {
|
||||
pub last_read_seq: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct DndRequest {
|
||||
pub do_not_disturb: Option<bool>,
|
||||
pub dnd_start_hour: Option<i16>,
|
||||
pub dnd_end_hour: Option<i16>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct NotificationMarkAllReadRequest {
|
||||
pub workspace_id: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct PresenceUpdateRequest {
|
||||
#[schema(example = "online")]
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct CustomStatusRequest {
|
||||
pub emoji: Option<String>,
|
||||
pub text: Option<String>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct InviteCreateRequest {
|
||||
pub workspace: Uuid,
|
||||
pub room: Option<Uuid>,
|
||||
pub max_uses: Option<i32>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct InviteAcceptRequest {
|
||||
pub code: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct BanCreateRequest {
|
||||
pub user: Uuid,
|
||||
pub reason: Option<String>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/read-receipt",
|
||||
request_body = ReadReceiptRequest,
|
||||
responses((status = 200, description = "Read receipt saved")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn read_receipt(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<ReadReceiptRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::ReadReceipt {
|
||||
room: room_id.into_inner(),
|
||||
last_read_seq: body.last_read_seq,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
patch,
|
||||
path = "/api/v1/ws/rooms/{room_id}/dnd",
|
||||
request_body = DndRequest,
|
||||
responses((status = 204, description = "DND updated")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn dnd_update(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<DndRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::StateUpdateDnd {
|
||||
room: room_id.into_inner(),
|
||||
do_not_disturb: body.do_not_disturb,
|
||||
dnd_start_hour: body.dnd_start_hour,
|
||||
dnd_end_hour: body.dnd_end_hour,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
patch,
|
||||
path = "/api/v1/ws/notifications/{id}/read",
|
||||
responses((status = 204, description = "Notification marked read")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn notification_mark_read(
|
||||
req: HttpRequest,
|
||||
id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::NotificationMarkRead {
|
||||
id: id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/notifications/read-all",
|
||||
request_body = NotificationMarkAllReadRequest,
|
||||
responses((status = 204, description = "All notifications marked read")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn notification_mark_all_read(
|
||||
req: HttpRequest,
|
||||
body: web::Json<NotificationMarkAllReadRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::NotificationMarkAllRead {
|
||||
workspace_id: body.workspace_id,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/notifications/{id}",
|
||||
responses((status = 204, description = "Notification archived")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn notification_archive(
|
||||
req: HttpRequest,
|
||||
id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::NotificationArchive {
|
||||
id: id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/presence",
|
||||
request_body = PresenceUpdateRequest,
|
||||
responses((status = 204, description = "Presence updated")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn presence_update(
|
||||
req: HttpRequest,
|
||||
body: web::Json<PresenceUpdateRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let status: channel::event::presence::UserPresenceStatus =
|
||||
serde_json::from_value(serde_json::Value::String(body.status.clone()))
|
||||
.map_err(|e| {
|
||||
ApiError(service::error::AppError::BadRequest(e.to_string()))
|
||||
})?;
|
||||
let msg = WsInMessage::PresenceUpdate { status };
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/custom-status",
|
||||
request_body = CustomStatusRequest,
|
||||
responses((status = 204, description = "Custom status updated")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn custom_status_update(
|
||||
req: HttpRequest,
|
||||
body: web::Json<CustomStatusRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::CustomStatusUpdate {
|
||||
emoji: body.emoji.clone(),
|
||||
text: body.text.clone(),
|
||||
expires_at: body.expires_at,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/invites",
|
||||
request_body = InviteCreateRequest,
|
||||
responses((status = 201, description = "Invite created")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn invite_create(
|
||||
req: HttpRequest,
|
||||
body: web::Json<InviteCreateRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::InviteCreate {
|
||||
workspace: body.workspace,
|
||||
room: body.room,
|
||||
max_uses: body.max_uses,
|
||||
expires_at: body.expires_at,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(created_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/invites/accept",
|
||||
request_body = InviteAcceptRequest,
|
||||
responses((status = 200, description = "Invite accepted")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn invite_accept(
|
||||
req: HttpRequest,
|
||||
body: web::Json<InviteAcceptRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::InviteAccept {
|
||||
code: body.code.clone(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/invites/{id}",
|
||||
responses((status = 204, description = "Invite revoked")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn invite_revoke(
|
||||
req: HttpRequest,
|
||||
id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::InviteRevoke {
|
||||
id: id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/workspaces/{workspace_id}/bans",
|
||||
request_body = BanCreateRequest,
|
||||
responses((status = 201, description = "User banned")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn ban_create(
|
||||
req: HttpRequest,
|
||||
workspace_id: web::Path<Uuid>,
|
||||
body: web::Json<BanCreateRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::BanCreate {
|
||||
workspace: workspace_id.into_inner(),
|
||||
user: body.user,
|
||||
reason: body.reason.clone(),
|
||||
expires_at: body.expires_at,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/workspaces/{workspace_id}/bans/{user_id}",
|
||||
responses((status = 204, description = "User unbanned")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn ban_remove(
|
||||
req: HttpRequest,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let (workspace, target_user) = path.into_inner();
|
||||
let msg = WsInMessage::BanRemove {
|
||||
workspace,
|
||||
user: target_user,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||
pub struct RoomMember {
|
||||
pub id: Uuid,
|
||||
pub username: String,
|
||||
pub display_name: String,
|
||||
pub avatar_url: String,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/ws/workspaces/{workspace_id}/members",
|
||||
responses((status = 200, description = "Workspace members list")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn list_workspace_members(
|
||||
req: HttpRequest,
|
||||
workspace_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let _user_id = extract_user(&req)?;
|
||||
let workspace = workspace_id.into_inner();
|
||||
|
||||
let members = bus.list_workspace_members(workspace).await.map_err(channel_err)?;
|
||||
let result: Vec<RoomMember> = members
|
||||
.into_iter()
|
||||
.map(|(id, username, display_name, avatar_url)| RoomMember {
|
||||
id,
|
||||
username,
|
||||
display_name,
|
||||
avatar_url,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(HttpResponse::Ok().json(result))
|
||||
}
|
||||
225
lib/api/src/channel/rest_message.rs
Normal file
225
lib/api/src/channel/rest_message.rs
Normal file
@ -0,0 +1,225 @@
|
||||
use actix_web::{HttpRequest, HttpResponse, web};
|
||||
use channel::ChannelBus;
|
||||
use channel::http::{WsHandler, WsInMessage};
|
||||
use serde::Deserialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::rest::{channel_err, created_json, extract_user, ok_json};
|
||||
use crate::error::ApiError;
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct CreateMessageRequest {
|
||||
pub content: String,
|
||||
pub content_type: Option<String>,
|
||||
pub thread: Option<Uuid>,
|
||||
pub in_reply_to: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct UpdateMessageRequest {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::IntoParams)]
|
||||
pub struct MessageListParams {
|
||||
pub before_seq: Option<i64>,
|
||||
pub after_seq: Option<i64>,
|
||||
pub limit: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::IntoParams)]
|
||||
pub struct MessageAroundParams {
|
||||
pub seq: i64,
|
||||
pub limit: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::IntoParams)]
|
||||
pub struct MissedMessagesParams {
|
||||
pub after_seq: i64,
|
||||
pub limit: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::IntoParams)]
|
||||
pub struct SearchParams {
|
||||
pub q: String,
|
||||
pub room: Option<Uuid>,
|
||||
pub limit: Option<u64>,
|
||||
pub offset: Option<u64>,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/messages",
|
||||
request_body = CreateMessageRequest,
|
||||
responses((status = 201, description = "Message created")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn create_message(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<CreateMessageRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::MessageCreate {
|
||||
room: room_id.into_inner(),
|
||||
content: body.content.clone(),
|
||||
content_type: body.content_type.clone(),
|
||||
thread: body.thread,
|
||||
in_reply_to: body.in_reply_to,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(created_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
patch,
|
||||
path = "/api/v1/ws/messages/{message_id}",
|
||||
request_body = UpdateMessageRequest,
|
||||
responses((status = 200, description = "Message updated")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn update_message(
|
||||
req: HttpRequest,
|
||||
message_id: web::Path<Uuid>,
|
||||
body: web::Json<UpdateMessageRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::MessageUpdate {
|
||||
message: message_id.into_inner(),
|
||||
content: body.content.clone(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/messages/{message_id}",
|
||||
responses((status = 200, description = "Message revoked")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn revoke_message(
|
||||
req: HttpRequest,
|
||||
message_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::MessageRevoke {
|
||||
message: message_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/ws/rooms/{room_id}/messages",
|
||||
params(MessageListParams),
|
||||
responses((status = 200, description = "Message list")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn list_messages(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
params: web::Query<MessageListParams>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::MessageList {
|
||||
room: room_id.into_inner(),
|
||||
before_seq: params.before_seq,
|
||||
after_seq: params.after_seq,
|
||||
limit: params.limit,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/ws/rooms/{room_id}/messages/around",
|
||||
params(MessageAroundParams),
|
||||
responses((status = 200, description = "Messages around seq")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn messages_around(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
params: web::Query<MessageAroundParams>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::MessageAround {
|
||||
room: room_id.into_inner(),
|
||||
seq: params.seq,
|
||||
limit: params.limit,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/ws/rooms/{room_id}/messages/missed",
|
||||
params(MissedMessagesParams),
|
||||
responses((status = 200, description = "Missed messages")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn missed_messages(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
params: web::Query<MissedMessagesParams>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::MissedMessages {
|
||||
room: room_id.into_inner(),
|
||||
after_seq: params.after_seq,
|
||||
limit: params.limit,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/ws/search",
|
||||
params(SearchParams),
|
||||
responses((status = 200, description = "Search results")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn search(
|
||||
req: HttpRequest,
|
||||
params: web::Query<SearchParams>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::Search {
|
||||
q: params.q.clone(),
|
||||
room: params.room,
|
||||
start_time: None,
|
||||
end_time: None,
|
||||
sender_id: None,
|
||||
content_type: None,
|
||||
limit: params.limit,
|
||||
offset: params.offset,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
321
lib/api/src/channel/rest_room.rs
Normal file
321
lib/api/src/channel/rest_room.rs
Normal file
@ -0,0 +1,321 @@
|
||||
use actix_web::{HttpRequest, HttpResponse, web};
|
||||
use channel::ChannelBus;
|
||||
use channel::http::{WsHandler, WsInMessage};
|
||||
use serde::Deserialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::rest::{channel_err, created_json, extract_user, ok_json};
|
||||
use crate::error::ApiError;
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct RoomCreateRequest {
|
||||
pub workspace: Uuid,
|
||||
pub room_name: String,
|
||||
pub public: bool,
|
||||
pub category: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct RoomUpdateRequest {
|
||||
pub room_name: Option<String>,
|
||||
pub public: Option<bool>,
|
||||
pub category: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct AccessRequest {
|
||||
pub user: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct CategoryCreateRequest {
|
||||
pub name: String,
|
||||
pub position: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct CategoryUpdateRequest {
|
||||
pub name: Option<String>,
|
||||
pub position: Option<i32>,
|
||||
}
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/ws/rooms",
|
||||
responses((status = 200, description = "List of rooms")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn list_rooms(
|
||||
req: HttpRequest,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let rooms = bus.list_user_rooms(user_id)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
let categories = bus.list_user_categories(user_id)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
let workspace_id = if let Some(r) = rooms.first() {
|
||||
Some(r.workspace_id)
|
||||
} else {
|
||||
bus.first_workspace_id(user_id).await.unwrap_or(None)
|
||||
};
|
||||
Ok(HttpResponse::Ok().json(serde_json::json!({
|
||||
"rooms": rooms,
|
||||
"categories": categories,
|
||||
"workspace_id": workspace_id,
|
||||
})))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/subscribe",
|
||||
responses((status = 204, description = "Subscribed, user room cache refreshed")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn subscribe(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::Subscribe {
|
||||
room: room_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/rooms/{room_id}/subscribe",
|
||||
responses((status = 204, description = "Unsubscribed")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn unsubscribe(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::Unsubscribe {
|
||||
room: room_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/api/v1/ws/rooms/{room_id}",
|
||||
responses((status = 200, description = "Room info")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn room_get(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::RoomGet {
|
||||
room: room_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms",
|
||||
request_body = RoomCreateRequest,
|
||||
responses((status = 201, description = "Room created")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn room_create(
|
||||
req: HttpRequest,
|
||||
body: web::Json<RoomCreateRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::RoomCreate {
|
||||
workspace: body.workspace,
|
||||
room_name: body.room_name.clone(),
|
||||
public: body.public,
|
||||
category: body.category,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(created_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
patch,
|
||||
path = "/api/v1/ws/rooms/{room_id}",
|
||||
request_body = RoomUpdateRequest,
|
||||
responses((status = 200, description = "Room updated")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn room_update(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<RoomUpdateRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::RoomUpdate {
|
||||
room: room_id.into_inner(),
|
||||
room_name: body.room_name.clone(),
|
||||
public: body.public,
|
||||
category: body.category,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/rooms/{room_id}",
|
||||
responses((status = 204, description = "Room deleted")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn room_delete(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::RoomDelete {
|
||||
room: room_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/members",
|
||||
request_body = AccessRequest,
|
||||
responses((status = 204, description = "Access granted")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn access_grant(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<AccessRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::AccessGrant {
|
||||
room: room_id.into_inner(),
|
||||
user: body.user,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/rooms/{room_id}/members/{user_id}",
|
||||
responses((status = 204, description = "Access revoked")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn access_revoke(
|
||||
req: HttpRequest,
|
||||
path: web::Path<(Uuid, Uuid)>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let (room, target_user) = path.into_inner();
|
||||
let msg = WsInMessage::AccessRevoke {
|
||||
room,
|
||||
user: target_user,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/workspaces/{workspace_id}/categories",
|
||||
request_body = CategoryCreateRequest,
|
||||
responses((status = 201, description = "Category created")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn category_create(
|
||||
req: HttpRequest,
|
||||
workspace_id: web::Path<Uuid>,
|
||||
body: web::Json<CategoryCreateRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::CategoryCreate {
|
||||
workspace: workspace_id.into_inner(),
|
||||
name: body.name.clone(),
|
||||
position: body.position,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(created_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
patch,
|
||||
path = "/api/v1/ws/categories/{category_id}",
|
||||
request_body = CategoryUpdateRequest,
|
||||
responses((status = 200, description = "Category updated")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn category_update(
|
||||
req: HttpRequest,
|
||||
category_id: web::Path<Uuid>,
|
||||
body: web::Json<CategoryUpdateRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::CategoryUpdate {
|
||||
id: category_id.into_inner(),
|
||||
name: body.name.clone(),
|
||||
position: body.position,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
delete,
|
||||
path = "/api/v1/ws/categories/{category_id}",
|
||||
responses((status = 204, description = "Category deleted")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn category_delete(
|
||||
req: HttpRequest,
|
||||
category_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::CategoryDelete {
|
||||
id: category_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
137
lib/api/src/channel/rest_voice.rs
Normal file
137
lib/api/src/channel/rest_voice.rs
Normal file
@ -0,0 +1,137 @@
|
||||
use actix_web::{HttpRequest, HttpResponse, web};
|
||||
use channel::ChannelBus;
|
||||
use channel::http::{WsHandler, WsInMessage};
|
||||
use serde::Deserialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::rest::{channel_err, extract_user, ok_json};
|
||||
use crate::error::ApiError;
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct VoiceMuteRequest {
|
||||
pub muted: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct VoiceDeafRequest {
|
||||
pub deafened: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct ScreenShareRequest {
|
||||
pub start: bool,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/voice/join",
|
||||
responses((status = 204, description = "Joined voice channel")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn voice_join(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::VoiceJoin {
|
||||
room: room_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/voice/leave",
|
||||
responses((status = 204, description = "Left voice channel")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn voice_leave(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::VoiceLeave {
|
||||
room: room_id.into_inner(),
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/voice/mute",
|
||||
request_body = VoiceMuteRequest,
|
||||
responses((status = 204, description = "Mute toggled")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn voice_mute(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<VoiceMuteRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::VoiceMute {
|
||||
room: room_id.into_inner(),
|
||||
muted: body.muted,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/voice/deaf",
|
||||
request_body = VoiceDeafRequest,
|
||||
responses((status = 204, description = "Deaf toggled")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn voice_deaf(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<VoiceDeafRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::VoiceDeaf {
|
||||
room: room_id.into_inner(),
|
||||
deafened: body.deafened,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/rooms/{room_id}/screen-share",
|
||||
request_body = ScreenShareRequest,
|
||||
responses((status = 204, description = "Screen share toggled")),
|
||||
tag = "channel",
|
||||
)]
|
||||
pub async fn screen_share(
|
||||
req: HttpRequest,
|
||||
room_id: web::Path<Uuid>,
|
||||
body: web::Json<ScreenShareRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = extract_user(&req)?;
|
||||
let msg = WsInMessage::ScreenShare {
|
||||
room: room_id.into_inner(),
|
||||
start: body.start,
|
||||
};
|
||||
let result = WsHandler::handle(&bus, user_id, msg)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
Ok(ok_json(result))
|
||||
}
|
||||
53
lib/api/src/channel/token.rs
Normal file
53
lib/api/src/channel/token.rs
Normal file
@ -0,0 +1,53 @@
|
||||
use actix_web::{HttpRequest, HttpResponse, web};
|
||||
use channel::{ChannelBus, ChannelTokenApply, TOKEN_TTL_SECS};
|
||||
use serde::Deserialize;
|
||||
use session::SessionExt;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
use super::rest::channel_err;
|
||||
|
||||
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||
pub struct TokenRequest {
|
||||
pub device_id: String,
|
||||
pub client_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize, utoipa::ToSchema)]
|
||||
pub struct TokenResponse {
|
||||
pub access_token: String,
|
||||
pub expires_in_secs: u64,
|
||||
}
|
||||
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/api/v1/ws/token",
|
||||
request_body = TokenRequest,
|
||||
responses((status = 200, body = TokenResponse)),
|
||||
tag = "channel"
|
||||
)]
|
||||
pub async fn generate_token(
|
||||
req: HttpRequest,
|
||||
body: web::Json<TokenRequest>,
|
||||
bus: web::Data<ChannelBus>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let user_id = req
|
||||
.get_session()
|
||||
.user()
|
||||
.ok_or_else(|| ApiError(service::error::AppError::Unauthorized))?;
|
||||
|
||||
let apply = ChannelTokenApply {
|
||||
device_id: body.device_id.clone(),
|
||||
client_id: body.client_id.clone(),
|
||||
};
|
||||
|
||||
let token = bus
|
||||
.apply_access_token(user_id, apply)
|
||||
.await
|
||||
.map_err(channel_err)?;
|
||||
|
||||
Ok(HttpResponse::Ok().json(TokenResponse {
|
||||
access_token: token.access_token,
|
||||
expires_in_secs: TOKEN_TTL_SECS,
|
||||
}))
|
||||
}
|
||||
91
lib/api/src/error.rs
Normal file
91
lib/api/src/error.rs
Normal file
@ -0,0 +1,91 @@
|
||||
use actix_web::{HttpResponse, error::ResponseError, http::StatusCode};
|
||||
use serde::Serialize;
|
||||
use service::error::AppError;
|
||||
|
||||
pub fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
pub struct ApiError(pub AppError);
|
||||
|
||||
impl From<AppError> for ApiError {
|
||||
fn from(err: AppError) -> Self {
|
||||
ApiError(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ApiError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ApiError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseError for ApiError {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match &self.0 {
|
||||
AppError::Unauthorized => StatusCode::UNAUTHORIZED,
|
||||
AppError::UserNotFound => StatusCode::NOT_FOUND,
|
||||
AppError::InvalidPassword => StatusCode::UNAUTHORIZED,
|
||||
AppError::PasswordTooWeak => StatusCode::BAD_REQUEST,
|
||||
AppError::CaptchaError => StatusCode::BAD_REQUEST,
|
||||
AppError::TwoFactorRequired => {
|
||||
StatusCode::from_u16(402).unwrap_or(StatusCode::BAD_REQUEST)
|
||||
}
|
||||
AppError::TwoFactorAlreadyEnabled => StatusCode::CONFLICT,
|
||||
AppError::TwoFactorNotSetup => StatusCode::NOT_FOUND,
|
||||
AppError::InvalidTwoFactorCode => StatusCode::BAD_REQUEST,
|
||||
AppError::TwoFactorNotEnabled => StatusCode::NOT_FOUND,
|
||||
AppError::RsaGenerationError => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
AppError::RsaDecodeError => StatusCode::BAD_REQUEST,
|
||||
AppError::UserNameExists => StatusCode::CONFLICT,
|
||||
AppError::EmailExists => StatusCode::CONFLICT,
|
||||
AppError::AccountAlreadyExists => StatusCode::CONFLICT,
|
||||
AppError::TxnError => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
AppError::PasswordHashError(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
AppError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
AppError::DoMainNotSet => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
AppError::InternalError => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
AppError::InternalServerError(_) => {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
AppError::PermissionDenied => StatusCode::FORBIDDEN,
|
||||
AppError::ProjectNotFound => StatusCode::NOT_FOUND,
|
||||
AppError::NoPower => StatusCode::FORBIDDEN,
|
||||
AppError::RoleParseError => StatusCode::BAD_REQUEST,
|
||||
AppError::ProjectNameAlreadyExists => StatusCode::CONFLICT,
|
||||
AppError::RepoNameAlreadyExists => StatusCode::CONFLICT,
|
||||
AppError::AvatarUploadError(_) => StatusCode::BAD_REQUEST,
|
||||
AppError::RepoNotFound => StatusCode::NOT_FOUND,
|
||||
AppError::RepoForBidAccess => StatusCode::FORBIDDEN,
|
||||
AppError::SerdeError(_) => StatusCode::BAD_REQUEST,
|
||||
AppError::Io(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
AppError::BadRequest(_) => StatusCode::BAD_REQUEST,
|
||||
AppError::Forbidden(_) => StatusCode::FORBIDDEN,
|
||||
AppError::Conflict(_) => StatusCode::CONFLICT,
|
||||
AppError::NotFound(_) => StatusCode::NOT_FOUND,
|
||||
AppError::InvalidResetToken => StatusCode::BAD_REQUEST,
|
||||
AppError::ResetTokenExpired => StatusCode::BAD_REQUEST,
|
||||
AppError::ResetTokenUsed => StatusCode::BAD_REQUEST,
|
||||
AppError::IssueNotFound => StatusCode::NOT_FOUND,
|
||||
AppError::LabelNotFound => StatusCode::NOT_FOUND,
|
||||
AppError::MilestoneNotFound => StatusCode::NOT_FOUND,
|
||||
AppError::PullRequestNotFound => StatusCode::NOT_FOUND,
|
||||
AppError::CommentNotFound => StatusCode::NOT_FOUND,
|
||||
AppError::GitRpcError(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
AppError::AiError(_) => StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
|
||||
fn error_response(&self) -> HttpResponse {
|
||||
let status = self.status_code();
|
||||
let message = self.0.to_string();
|
||||
HttpResponse::build(status)
|
||||
.json(serde_json::json!({ "error": message }))
|
||||
}
|
||||
}
|
||||
47
lib/api/src/git/archive.rs
Normal file
47
lib/api/src/git/archive.rs
Normal file
@ -0,0 +1,47 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use service::AppService;
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
|
||||
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct WkRepoPath {
|
||||
pub wk: String,
|
||||
pub repo: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct ArchiveQuery {
|
||||
#[serde(default = "default_format")]
|
||||
pub format: String,
|
||||
pub tree: Option<String>,
|
||||
pub prefix: Option<String>,
|
||||
pub pathspec: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn default_format() -> String {
|
||||
"tar".to_string()
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/archive",
|
||||
params(WkRepoPath, ArchiveQuery),
|
||||
responses((status = 200, description = "Archive download")),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn archive(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoPath>,
|
||||
query: web::Query<ArchiveQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoPath { wk, repo } = path.into_inner();
|
||||
match query.format.as_str() {
|
||||
"zip" => ok_json(service.git_archive_zip(&session, &wk, &repo, None).await?),
|
||||
_ => ok_json(service.git_archive_tar(&session, &wk, &repo, None).await?),
|
||||
}
|
||||
}
|
||||
62
lib/api/src/git/blame.rs
Normal file
62
lib/api/src/git/blame.rs
Normal file
@ -0,0 +1,62 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use service::AppService;
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::git::dto;
|
||||
|
||||
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct WkRepoPath {
|
||||
pub wk: String,
|
||||
pub repo: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct BlameQuery {
|
||||
pub path: String,
|
||||
pub rev: Option<String>,
|
||||
pub start_line: Option<u32>,
|
||||
pub end_line: Option<u32>,
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/blame",
|
||||
params(WkRepoPath, BlameQuery),
|
||||
responses((status = 200, description = "Blame result", body = dto::BlameFileResponseDto)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn blame_file(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoPath>,
|
||||
query: web::Query<BlameQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoPath { wk, repo } = path.into_inner();
|
||||
|
||||
match (query.start_line, query.end_line) {
|
||||
(Some(start), Some(end)) => {
|
||||
let data: dto::BlameFileResponseDto = service
|
||||
.git_blame_hunk(
|
||||
&session, &wk, &repo, query.path.clone(),
|
||||
query.rev.clone(), start, end,
|
||||
)
|
||||
.await?
|
||||
.into();
|
||||
ok_json(data)
|
||||
}
|
||||
_ => {
|
||||
let data: dto::BlameFileResponseDto = service
|
||||
.git_blame_file(
|
||||
&session, &wk, &repo, query.path.clone(),
|
||||
query.rev.clone(), None,
|
||||
)
|
||||
.await?
|
||||
.into();
|
||||
ok_json(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
97
lib/api/src/git/blob.rs
Normal file
97
lib/api/src/git/blob.rs
Normal file
@ -0,0 +1,97 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use service::AppService;
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::git::dto;
|
||||
|
||||
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct WkRepoPath {
|
||||
pub wk: String,
|
||||
pub repo: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct WkRepoBlobPath {
|
||||
pub wk: String,
|
||||
pub repo: String,
|
||||
pub oid: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct BlobPathQuery {
|
||||
pub path: Option<String>,
|
||||
}
|
||||
#[derive(Serialize, utoipa::ToSchema)]
|
||||
pub struct BlobInfoResponse {
|
||||
#[serde(flatten)]
|
||||
pub load: dto::BlobLoadResponseDto,
|
||||
pub size: u64,
|
||||
pub is_binary: bool,
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/blobs/{oid}",
|
||||
params(WkRepoBlobPath, BlobPathQuery),
|
||||
responses((status = 200, description = "Blob info", body = BlobInfoResponse)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn blob_info(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoBlobPath>,
|
||||
query: web::Query<BlobPathQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoBlobPath { wk, repo, oid } = path.into_inner();
|
||||
let path_opt = query.path.clone().unwrap_or_default();
|
||||
|
||||
let load: dto::BlobLoadResponseDto = service
|
||||
.git_blob_load(&session, &wk, &repo, oid.clone(), path_opt.clone())
|
||||
.await?
|
||||
.into();
|
||||
let size_resp: dto::BlobSizeResponseDto = service
|
||||
.git_blob_size(&session, &wk, &repo, oid.clone(), path_opt)
|
||||
.await?
|
||||
.into();
|
||||
let binary_resp: dto::BlobIsBinaryResponseDto = service
|
||||
.git_blob_is_binary(&session, &wk, &repo, oid)
|
||||
.await?
|
||||
.into();
|
||||
|
||||
ok_json(BlobInfoResponse {
|
||||
load,
|
||||
size: size_resp.size,
|
||||
is_binary: binary_resp.is_binary,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::ToSchema)]
|
||||
pub struct BlobUploadBody {
|
||||
pub path: String,
|
||||
pub blob: Vec<u8>,
|
||||
}
|
||||
#[utoipa::path(
|
||||
post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/blobs",
|
||||
params(WkRepoPath),
|
||||
request_body = BlobUploadBody,
|
||||
responses((status = 200, description = "Upload result", body = dto::BlobUploadResponseDto)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn blob_upload(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoPath>,
|
||||
params: web::Json<BlobUploadBody>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoPath { wk, repo } = path.into_inner();
|
||||
let p = params.into_inner();
|
||||
let data: dto::BlobUploadResponseDto = service
|
||||
.git_blob_upload(&session, &wk, &repo, p.path, p.blob)
|
||||
.await?
|
||||
.into();
|
||||
ok_json(data)
|
||||
}
|
||||
205
lib/api/src/git/branch.rs
Normal file
205
lib/api/src/git/branch.rs
Normal file
@ -0,0 +1,205 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use service::{AppService, Pagination};
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::git::dto;
|
||||
|
||||
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct WkRepoPath {
|
||||
pub wk: String,
|
||||
pub repo: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct WkRepoBranchPath {
|
||||
pub wk: String,
|
||||
pub repo: String,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct BranchDeleteQuery {
|
||||
#[serde(default)]
|
||||
pub force: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct BranchListQuery {
|
||||
#[serde(default)]
|
||||
pub summary: bool,
|
||||
#[serde(default)]
|
||||
pub default_only: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::ToSchema)]
|
||||
pub struct RenameBranchBody {
|
||||
pub new_branch: String,
|
||||
#[serde(default)]
|
||||
pub force: bool,
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches",
|
||||
params(WkRepoPath, Pagination, BranchListQuery),
|
||||
responses((status = 200, description = "Branch list or summary", body = dto::BranchListResponseDto)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_branches(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoPath>,
|
||||
pagination: web::Query<Pagination>,
|
||||
query: web::Query<BranchListQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoPath { wk, repo } = path.into_inner();
|
||||
if query.summary {
|
||||
let data: dto::BranchSummaryResponseDto = service
|
||||
.git_branch_summary(&session, &wk, &repo)
|
||||
.await?
|
||||
.into();
|
||||
return ok_json(data);
|
||||
}
|
||||
if query.default_only {
|
||||
let data: dto::BranchHeadResponseDto = service
|
||||
.git_branch_head(&session, &wk, &repo)
|
||||
.await?
|
||||
.into();
|
||||
return ok_json(data);
|
||||
}
|
||||
let data: dto::BranchListResponseDto = service
|
||||
.git_branch_list(&session, &wk, &repo, pagination.into_inner())
|
||||
.await?
|
||||
.into();
|
||||
ok_json(data)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}",
|
||||
params(WkRepoBranchPath),
|
||||
responses((status = 200, description = "Branch info", body = dto::BranchInfoResponseDto)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn branch_info(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoBranchPath>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoBranchPath { wk, repo, name } = path.into_inner();
|
||||
let data: dto::BranchInfoResponseDto = service
|
||||
.git_branch_info(&session, &wk, &repo, name)
|
||||
.await?
|
||||
.into();
|
||||
ok_json(data)
|
||||
}
|
||||
#[utoipa::path(
|
||||
post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches",
|
||||
params(WkRepoPath),
|
||||
request_body = Object, description = "BranchForkParams { name, oid, force }",
|
||||
responses((status = 200, description = "Branch created")),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn fork_branch(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoPath>,
|
||||
params: web::Json<git::rpc::proto::BranchForkParams>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoPath { wk, repo } = path.into_inner();
|
||||
let data = service
|
||||
.git_branch_fork(&session, &wk, &repo, params.into_inner())
|
||||
.await?;
|
||||
ok_json(data)
|
||||
}
|
||||
#[utoipa::path(
|
||||
patch, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}",
|
||||
params(WkRepoBranchPath),
|
||||
request_body = RenameBranchBody,
|
||||
responses((status = 200, description = "Branch renamed")),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn rename_branch(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoBranchPath>,
|
||||
body: web::Json<RenameBranchBody>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoBranchPath { wk, repo, name } = path.into_inner();
|
||||
let body = body.into_inner();
|
||||
let params = git::rpc::proto::BranchReNameParams {
|
||||
old_branch: name,
|
||||
new_branch: body.new_branch,
|
||||
force: body.force,
|
||||
};
|
||||
let data = service
|
||||
.git_branch_rename(&session, &wk, &repo, params)
|
||||
.await?;
|
||||
ok_json(data)
|
||||
}
|
||||
#[utoipa::path(
|
||||
delete, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}",
|
||||
params(WkRepoBranchPath, BranchDeleteQuery),
|
||||
responses((status = 200, description = "Branch deleted")),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn delete_branch(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoBranchPath>,
|
||||
query: web::Query<BranchDeleteQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoBranchPath { wk, repo, name } = path.into_inner();
|
||||
let params = git::rpc::proto::BranchDeleteParams {
|
||||
name,
|
||||
force: query.force,
|
||||
};
|
||||
let data = service
|
||||
.git_branch_delete(&session, &wk, &repo, params)
|
||||
.await?;
|
||||
ok_json(data)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct AheadBehindQuery {
|
||||
pub remote_branch: String,
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}/ahead-behind",
|
||||
params(WkRepoBranchPath, AheadBehindQuery),
|
||||
responses((status = 200, description = "Ahead/behind counts", body = dto::BranchAheadBehindResponseDto)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn ahead_behind(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoBranchPath>,
|
||||
query: web::Query<AheadBehindQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoBranchPath { wk, repo, name } = path.into_inner();
|
||||
let data: dto::BranchAheadBehindResponseDto = service
|
||||
.git_branch_ahead_behind(&session, &wk, &repo, name, query.remote_branch.clone())
|
||||
.await?
|
||||
.into();
|
||||
ok_json(data)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/branches/{name}/upstream",
|
||||
params(WkRepoBranchPath),
|
||||
responses((status = 200, description = "Upstream branch", body = dto::BranchUpstreamResponseDto)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn branch_upstream(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoBranchPath>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoBranchPath { wk, repo, name } = path.into_inner();
|
||||
let data: dto::BranchUpstreamResponseDto = service
|
||||
.git_branch_upstream(&session, &wk, &repo, name)
|
||||
.await?
|
||||
.into();
|
||||
ok_json(data)
|
||||
}
|
||||
169
lib/api/src/git/commit.rs
Normal file
169
lib/api/src/git/commit.rs
Normal file
@ -0,0 +1,169 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
use git::rpc::proto as p;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use service::AppService;
|
||||
use session::Session;
|
||||
|
||||
use crate::error::ApiError;
|
||||
use crate::git::dto;
|
||||
|
||||
fn ok_json<T: Serialize>(data: T) -> Result<HttpResponse, ApiError> {
|
||||
Ok(HttpResponse::Ok().json(data))
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct WkRepoPath {
|
||||
pub wk: String,
|
||||
pub repo: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct WkRepoCommitPath {
|
||||
pub wk: String,
|
||||
pub repo: String,
|
||||
pub oid: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct HistoryQuery {
|
||||
pub limit: Option<u64>,
|
||||
pub skip: Option<u64>,
|
||||
pub sort: Option<i32>,
|
||||
pub branch: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, utoipa::IntoParams)]
|
||||
pub struct CommitListQuery {
|
||||
#[serde(default)]
|
||||
pub summary: bool,
|
||||
#[serde(default)]
|
||||
pub refs: bool,
|
||||
pub prefix: Option<String>,
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits",
|
||||
params(WkRepoPath, CommitListQuery),
|
||||
responses(
|
||||
(status = 200, description = "Commit list / summary / refs / prefix", body = dto::CommitHistoryResponseDto),
|
||||
),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn list_commits(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoPath>,
|
||||
query: web::Query<CommitListQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoPath { wk, repo } = path.into_inner();
|
||||
|
||||
if let Some(prefix) = &query.prefix {
|
||||
let data: dto::CommitPrefixResponseDto = service
|
||||
.git_commit_prefix(&session, &wk, &repo, prefix.clone())
|
||||
.await?
|
||||
.into();
|
||||
return ok_json(data);
|
||||
}
|
||||
if query.refs {
|
||||
let data: dto::CommitRefsResponseDto = service
|
||||
.git_commit_refs(&session, &wk, &repo)
|
||||
.await?
|
||||
.into();
|
||||
return ok_json(data);
|
||||
}
|
||||
if query.summary {
|
||||
let data: dto::CommitSummaryResponseDto = service
|
||||
.git_commit_summary(&session, &wk, &repo)
|
||||
.await?
|
||||
.into();
|
||||
return ok_json(data);
|
||||
}
|
||||
let data: dto::CommitHistoryResponseDto = service
|
||||
.git_commit_history(&session, &wk, &repo, 20, 0, 0, None)
|
||||
.await?
|
||||
.into();
|
||||
ok_json(data)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/history",
|
||||
params(WkRepoPath, HistoryQuery),
|
||||
responses((status = 200, description = "Commit history", body = dto::CommitHistoryResponseDto)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn commit_history(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoPath>,
|
||||
query: web::Query<HistoryQuery>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoPath { wk, repo } = path.into_inner();
|
||||
let data: dto::CommitHistoryResponseDto = service
|
||||
.git_commit_history(
|
||||
&session, &wk, &repo,
|
||||
query.limit.unwrap_or(20),
|
||||
query.skip.unwrap_or(0),
|
||||
query.sort.unwrap_or(0),
|
||||
query.branch.clone(),
|
||||
)
|
||||
.await?
|
||||
.into();
|
||||
ok_json(data)
|
||||
}
|
||||
#[utoipa::path(
|
||||
get, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/{oid}",
|
||||
params(WkRepoCommitPath),
|
||||
responses((status = 200, description = "Commit info", body = dto::CommitInfoResponseDto)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn commit_info(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoCommitPath>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoCommitPath { wk, repo, oid } = path.into_inner();
|
||||
let data: dto::CommitInfoResponseDto = service
|
||||
.git_commit_info(&session, &wk, &repo, oid)
|
||||
.await?
|
||||
.into();
|
||||
ok_json(data)
|
||||
}
|
||||
#[utoipa::path(
|
||||
post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/walk",
|
||||
params(WkRepoPath),
|
||||
request_body = Object, description = "CommitWalkParams",
|
||||
responses((status = 200, description = "Walk result", body = dto::CommitHistoryResponseDto)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn commit_walk(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoPath>,
|
||||
params: web::Json<p::CommitWalkParams>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoPath { wk, repo } = path.into_inner();
|
||||
let proto_resp = service
|
||||
.git_commit_walk(&session, &wk, &repo, params.into_inner())
|
||||
.await?;
|
||||
ok_json(dto::CommitHistoryResponseDto {
|
||||
commits: proto_resp.commits.into_iter().map(Into::into).collect(),
|
||||
})
|
||||
}
|
||||
#[utoipa::path(
|
||||
post, path = "/api/v1/workspace/{wk}/repos/{repo}/git/commits/cherry-pick",
|
||||
params(WkRepoPath),
|
||||
request_body = Object, description = "CommitCherryPickParams",
|
||||
responses((status = 200, description = "Cherry-pick result", body = dto::CherryPickResponseDto)),
|
||||
security(("session" = []))
|
||||
)]
|
||||
pub async fn cherry_pick(
|
||||
session: Session,
|
||||
service: web::Data<AppService>,
|
||||
path: web::Path<WkRepoPath>,
|
||||
params: web::Json<p::CommitCherryPickParams>,
|
||||
) -> Result<HttpResponse, ApiError> {
|
||||
let WkRepoPath { wk, repo } = path.into_inner();
|
||||
let data: dto::CherryPickResponseDto = service
|
||||
.git_cherry_pick(&session, &wk, &repo, params.into_inner())
|
||||
.await?
|
||||
.into();
|
||||
ok_json(data)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user