feat(core): initialize project with access control and AI integration

This commit is contained in:
ZhenYi 2026-05-10 21:01:21 +08:00
parent 14f6e1e500
commit ba2490dab4
619 changed files with 51264 additions and 8418 deletions

View File

@ -51,9 +51,9 @@ APP_DOMAIN_URL=http://127.0.0.1
# APP_DATABASE_MAX_CONNECTIONS=10 # APP_DATABASE_MAX_CONNECTIONS=10
# APP_DATABASE_MIN_CONNECTIONS=2 # APP_DATABASE_MIN_CONNECTIONS=2
# APP_DATABASE_IDLE_TIMEOUT=60000 # APP_DATABASE_IDLE_TIMEOUT=60000 (milliseconds, default: 60s)
# APP_DATABASE_MAX_LIFETIME=300000 # APP_DATABASE_MAX_LIFETIME=300000 (milliseconds, default: 300s)
# APP_DATABASE_CONNECTION_TIMEOUT=5000 # APP_DATABASE_CONNECTION_TIMEOUT=5000 (milliseconds, default: 5s)
# APP_DATABASE_REPLICAS= # APP_DATABASE_REPLICAS=
# APP_DATABASE_HEALTH_CHECK_INTERVAL=30 # APP_DATABASE_HEALTH_CHECK_INTERVAL=30
# APP_DATABASE_RETRY_ATTEMPTS=3 # APP_DATABASE_RETRY_ATTEMPTS=3

3
.gitignore vendored
View File

@ -23,3 +23,6 @@ coverage/
pnpm-lock.yaml pnpm-lock.yaml
package-lock.json package-lock.json
yarn.lock yarn.lock
.gemini
.omg
/.sqry

11
.mcp.json Normal file
View File

@ -0,0 +1,11 @@
{
"mcpServers": {
"shadcn": {
"command": "npx",
"args": [
"shadcn@latest",
"mcp"
]
}
}
}

1537
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -11,19 +11,21 @@ members = [
"libs/service", "libs/service",
"libs/db", "libs/db",
"libs/api", "libs/api",
"libs/webhook",
"libs/transport", "libs/transport",
"libs/observability", "libs/observability",
"libs/avatar", "libs/avatar",
"libs/agent", "libs/agent",
"libs/migrate", "libs/migrate",
"libs/fctool", "libs/fctool",
"libs/gingress-proxy",
"apps/migrate", "apps/migrate",
"apps/app", "apps/app",
"apps/git-hook", "apps/git-hook",
"apps/gitserver", "apps/gitserver",
"apps/email", "apps/email",
"apps/static", "apps/static",
"apps/metrics",
"apps/gingress",
] ]
resolver = "3" resolver = "3"
@ -40,12 +42,14 @@ service = { path = "libs/service" }
db = { path = "libs/db" } db = { path = "libs/db" }
api = { path = "libs/api" } api = { path = "libs/api" }
agent = { path = "libs/agent" } agent = { path = "libs/agent" }
webhook = { path = "libs/webhook" }
observability = { path = "libs/observability" } observability = { path = "libs/observability" }
avatar = { path = "libs/avatar" } avatar = { path = "libs/avatar" }
migrate = { path = "libs/migrate" } migrate = { path = "libs/migrate" }
fctool = { path = "libs/fctool" } fctool = { path = "libs/fctool" }
transport = { path = "libs/transport" } transport = { path = "libs/transport" }
metrics-aggregator = { path = "apps/metrics" }
gingress-proxy = { path = "libs/gingress-proxy" }
gingress = { path = "apps/gingress" }
sea-query = "1.0.0-rc.33" sea-query = "1.0.0-rc.33"
@ -131,7 +135,10 @@ tokio = "1.50.0"
tokio-util = "0.7.18" tokio-util = "0.7.18"
tokio-stream = "0.1.18" tokio-stream = "0.1.18"
url = "2.5.8" url = "2.5.8"
tower = "0.5"
num_cpus = "1.17.0" num_cpus = "1.17.0"
ring = "0.17"
rustls = { version = "0.23", default-features = false, features = ["ring", "std", "tls12"] }
clap = "4.6.0" clap = "4.6.0"
time = "0.3.47" time = "0.3.47"
chrono = "0.4.44" chrono = "0.4.44"
@ -165,12 +172,19 @@ phf_codegen = "0.13.1"
base64 = "0.22.1" base64 = "0.22.1"
base64ct = "1" base64ct = "1"
p256 = { version = "0.13", features = ["ecdsa", "std"] } p256 = { version = "0.13", features = ["ecdsa", "std"] }
http = "1" # http version varies per-crate (pingora needs 1.x, actix needs 0.2)
hyper = "0.14" hyper = "0.14"
tempfile = "3" tempfile = "3"
rig-core = { version = "0.30.0", default-features = false } rig-core = { version = "0.30.0", default-features = false }
tokio-tungstenite = { version = "0.29.0", features = [] } tokio-tungstenite = { version = "0.29.0", features = [] }
async-nats = { version = "0.47.0", features = [] } async-nats = { version = "0.47.0", features = [] }
kube = { version = "0.98", features = ["runtime", "derive"] }
k8s-openapi = { version = "0.24", features = ["v1_31"] }
pingora = { version = "0.8", features = ["proxy"] }
pingora-proxy = "0.8"
pingora-load-balancing = "0.8"
pingora-cache = "0.8"
rustls-pemfile = "2"
[workspace.package] [workspace.package]
version = "0.2.9" version = "0.2.9"
edition = "2024" edition = "2024"

View File

@ -9,6 +9,7 @@ use futures::future::LocalBoxFuture;
use observability::{ use observability::{
init_tracing_subscriber, install_recorder, prometheus_handler, spawn_http_metrics_poller, init_tracing_subscriber, install_recorder, prometheus_handler, spawn_http_metrics_poller,
HttpMetrics, HttpSnapshotGuard, MetricsMiddleware, TracingSpanMiddleware, HttpMetrics, HttpSnapshotGuard, MetricsMiddleware, TracingSpanMiddleware,
push::MetricsPusher,
}; };
use sea_orm::ConnectionTrait; use sea_orm::ConnectionTrait;
use service::AppService; use service::AppService;
@ -17,6 +18,7 @@ use api::{robots, sidemap};
use session::storage::RedisClusterSessionStore; use session::storage::RedisClusterSessionStore;
use session::SessionMiddleware; use session::SessionMiddleware;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
mod args; mod args;
@ -151,7 +153,8 @@ async fn main() -> anyhow::Result<()> {
let service = AppService::new(cfg.clone()).await?; let service = AppService::new(cfg.clone()).await?;
tracing::info!("AppService initialized"); tracing::info!("AppService initialized");
let _model_sync_handle = service.clone().start_sync_task(); let _model_sync_handle = service.clone().start_sync_task();
let _billing_alert_handle = service.clone().start_billing_alert_task(); // TODO: workspace module not yet wired — billing alert task pending
// let _billing_alert_handle = service.clone().start_billing_alert_task();
let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1); let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
let worker_service = service.clone(); let worker_service = service.clone();
@ -192,6 +195,13 @@ async fn main() -> anyhow::Result<()> {
); );
let http_snapshot_data = web::Data::new(http_snapshot); let http_snapshot_data = web::Data::new(http_snapshot);
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
let pusher = MetricsPusher::new(&push_url, "app");
pusher.spawn(http_metrics.clone(), Arc::new(prometheus_handle.clone()), std::time::Duration::from_secs(15));
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
}
let bind_addr = args.bind.unwrap_or_else(|| "127.0.0.1:8080".to_string()); let bind_addr = args.bind.unwrap_or_else(|| "127.0.0.1:8080".to_string());
tracing::info!(bind_addr = %bind_addr, "Listening"); tracing::info!(bind_addr = %bind_addr, "Listening");
let http_metrics_server = http_metrics.clone(); let http_metrics_server = http_metrics.clone();
@ -212,11 +222,16 @@ async fn main() -> anyhow::Result<()> {
cors = cors.allowed_origin(origin); cors = cors.allowed_origin(origin);
} }
let cors = cors let cors = cors
.allowed_methods(["GET", "POST", "PUT", "PATCH", "DELETE"]) .allowed_methods(["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
.allowed_headers(["Content-Type", "Authorization", "X-Requested-With", "Accept", "Origin"]) .allowed_headers(["Content-Type", "Authorization", "X-Requested-With", "Accept", "Origin"])
.supports_credentials() .supports_credentials()
.max_age(3600); .max_age(3600);
let security_headers = actix_web::middleware::DefaultHeaders::new()
.add(("X-Content-Type-Options", "nosniff"))
.add(("X-Frame-Options", "DENY"))
.add(("Referrer-Policy", "strict-origin-when-cross-origin"));
let session_mw = SessionMiddleware::builder(store.clone(), session_key.clone()) let session_mw = SessionMiddleware::builder(store.clone(), session_key.clone())
.cookie_name("id".to_string()) .cookie_name("id".to_string())
.cookie_path("/".to_string()) .cookie_path("/".to_string())
@ -233,6 +248,7 @@ async fn main() -> anyhow::Result<()> {
App::new() App::new()
.wrap(cors) .wrap(cors)
.wrap(security_headers)
.wrap(session_mw) .wrap(session_mw)
.wrap(RequestLogger) .wrap(RequestLogger)
.wrap(metrics_mw) .wrap(metrics_mw)

View File

@ -2,7 +2,7 @@ use clap::Parser;
use config::AppConfig; use config::AppConfig;
use metrics::{describe_counter, Unit}; use metrics::{describe_counter, Unit};
use metrics_exporter_prometheus::PrometheusHandle; use metrics_exporter_prometheus::PrometheusHandle;
use observability::{init_tracing_subscriber, install_recorder}; use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
use sea_orm::ConnectionTrait; use sea_orm::ConnectionTrait;
use service::AppService; use service::AppService;
use std::sync::Arc; use std::sync::Arc;
@ -88,6 +88,14 @@ async fn main() -> anyhow::Result<()> {
describe_counter!("email_send_failures_total", Unit::Count, "Emails that failed after all retries"); describe_counter!("email_send_failures_total", Unit::Count, "Emails that failed after all retries");
let metrics_handle = Arc::new(install_recorder()); let metrics_handle = Arc::new(install_recorder());
let http_metrics = Arc::new(HttpMetrics::new()); // Worker app — HTTP section will be empty
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
let pusher = MetricsPusher::new(&push_url, "email");
pusher.spawn(http_metrics.clone(), metrics_handle.clone(), std::time::Duration::from_secs(15));
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
}
tracing::info!("Starting email worker"); tracing::info!("Starting email worker");
let service = AppService::new(cfg).await?; let service = AppService::new(cfg).await?;

46
apps/gingress/Cargo.toml Normal file
View File

@ -0,0 +1,46 @@
[package]
name = "gingress"
version.workspace = true
edition.workspace = true
authors.workspace = true
description = "GIngress control plane: Kubernetes Ingress Controller using kube-rs"
repository.workspace = true
readme.workspace = true
homepage.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
documentation.workspace = true
[[bin]]
name = "gingress"
path = "src/main.rs"
[[bin]]
name = "kubectl-gingress"
path = "src/bin/kubectl-gingress/main.rs"
[dependencies]
gingress-proxy = { workspace = true }
# kube, k8s-openapi and rustls-pemfile are declared (with the same versions and
# features) in the workspace manifest; inherit them so versions cannot drift.
kube = { workspace = true }
k8s-openapi = { workspace = true }
tokio = { workspace = true, features = ["full"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
serde_yaml = { workspace = true }
tracing = { workspace = true }
observability = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
dashmap = { workspace = true }
futures-util = { workspace = true }
futures = { workspace = true }
clap = { workspace = true }
url = { workspace = true }
x509-parser = "0.17"
rustls-pemfile = { workspace = true }
[lints]
workspace = true

View File

@ -0,0 +1,797 @@
//! kubectl-gingress — kubectl plugin for managing GIngress resources.
//!
//! Usage (via kubectl): kubectl gingress <subcommand>
//! Usage (standalone): kubectl-gingress <subcommand>
use clap::{Parser, Subcommand};
use k8s_openapi::api::core::v1::{Pod, Secret};
use k8s_openapi::api::networking::v1::{HTTPIngressPath, Ingress};
use kube::api::ListParams;
use kube::{Api, Client, ResourceExt};
/// Value of `spec.ingressClassName` that marks an Ingress as managed by GIngress.
const INGRESS_CLASS: &str = "gingress";
// Top-level command line of the plugin, parsed with clap's derive API.
// Plain `//` comments are used on purpose: clap turns `///` doc comments on
// derive items into help text, which would change the program's output.
#[derive(Parser)]
#[command(
name = "kubectl-gingress",
bin_name = "kubectl gingress",
about = "Manage GIngress — Kubernetes Ingress Controller",
version
)]
struct Cli {
// The subcommand selected by the user; dispatched in `main`.
#[command(subcommand)]
command: Command,
}
// Subcommands of the plugin. The existing `///` comments below are rendered by
// clap as help text, so they are part of the program's output; maintainer notes
// therefore use plain `//` comments.
#[derive(Subcommand)]
enum Command {
/// List all Ingress resources managed by GIngress
#[command(alias = "ls")]
List {
/// Filter by namespace (omit for all namespaces)
#[arg(short, long)]
namespace: Option<String>,
/// Output as JSON
#[arg(long)]
json: bool,
},
/// Show the routing table (host → path → backend)
Routes {
/// Filter by namespace
#[arg(short, long)]
namespace: Option<String>,
/// Filter by host
// Short flag is `-H` (capital) as set explicitly below.
#[arg(short = 'H', long)]
host: Option<String>,
/// Output as JSON
#[arg(long)]
json: bool,
},
/// Show backend services and their endpoints
Backends {
/// Filter by namespace
#[arg(short, long)]
namespace: Option<String>,
/// Output as JSON
#[arg(long)]
json: bool,
},
/// List TLS certificates (from Secrets)
Certs {
/// Filter by namespace
#[arg(short, long)]
namespace: Option<String>,
/// Output as JSON
#[arg(long)]
json: bool,
},
/// Validate Ingress configurations
Validate {
/// Filter by namespace
#[arg(short, long)]
namespace: Option<String>,
},
/// Show GIngress controller status and summary
// Unlike the other subcommands this one takes no namespace filter.
Status {
/// Output as JSON
#[arg(long)]
json: bool,
},
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse();
let client = Client::try_default().await?;
match cli.command {
Command::List { namespace, json } => cmd_list(&client, namespace, json).await?,
Command::Routes { namespace, host, json } => cmd_routes(&client, namespace, host, json).await?,
Command::Backends { namespace, json } => cmd_backends(&client, namespace, json).await?,
Command::Certs { namespace, json } => cmd_certs(&client, namespace, json).await?,
Command::Validate { namespace } => cmd_validate(&client, namespace).await?,
Command::Status { json } => cmd_status(&client, json).await?,
}
Ok(())
}
// ── list ──────────────────────────────────────────────────────────
/// `list` subcommand: prints every GIngress-managed Ingress.
///
/// With `json` set the summaries are emitted as pretty-printed JSON; otherwise
/// a fixed-width table is printed, followed by a total. `namespace` restricts
/// the listing to one namespace (`None` means all namespaces).
async fn cmd_list(
    client: &Client,
    namespace: Option<String>,
    json: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    let summaries = list_ingresses(client, namespace.as_deref()).await?;

    if json {
        let rendered = serde_json::to_string_pretty(&summaries)?;
        println!("{}", rendered);
        return Ok(());
    }

    if summaries.is_empty() {
        println!("No GIngress-managed Ingress resources found.");
        return Ok(());
    }

    println!("{:<25} {:<20} {:<40} {:<50} {:<15}", "NAMESPACE", "NAME", "HOSTS", "PATHS", "TLS");
    println!("{:-<150}", "");

    for summary in summaries.iter() {
        // One "<pathType> <path>" cell per path, comma separated.
        let mut path_cells: Vec<String> = Vec::new();
        for entry in summary.paths_display() {
            path_cells.push(format!("{} {}", entry.path_type, entry.path));
        }
        let paths_cell = path_cells.join(", ");
        let hosts_cell = summary.hosts().join(", ");
        let tls_cell = match summary.has_tls() {
            true => "Enabled",
            false => "-",
        };

        println!(
            "{:<25} {:<20} {:<40} {:<50} {:<15}",
            truncate(&summary.namespace(), 25),
            truncate(&summary.name_any(), 20),
            truncate(&hosts_cell, 40),
            truncate(&paths_cell, 50),
            tls_cell,
        );
    }

    println!("\nTotal: {} Ingress(es)", summaries.len());
    Ok(())
}
// ── routes ─────────────────────────────────────────────────────────
/// `routes` subcommand: prints the flattened routing table.
///
/// Every HTTP path of every rule becomes one row. Rules without a host are
/// shown with host `*`. When `host_filter` is given, only rules whose host is
/// exactly equal to it are kept. Output is pretty JSON when `json` is set,
/// otherwise a fixed-width table followed by a total.
async fn cmd_routes(
    client: &Client,
    namespace: Option<String>,
    host_filter: Option<String>,
    json: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    let summaries = list_ingresses(client, namespace.as_deref()).await?;
    let wanted_host = host_filter.as_deref();

    let mut routes: Vec<RouteRow> = Vec::new();
    for summary in &summaries {
        let rules = summary
            .spec
            .as_ref()
            .and_then(|spec| spec.rules.as_deref())
            .unwrap_or_default();

        for rule in rules {
            let host = rule.host.as_deref().unwrap_or("*");
            if wanted_host.is_some_and(|wanted| wanted != host) {
                continue;
            }
            // Rules without an `http` section contribute no routes.
            let Some(http) = rule.http.as_ref() else {
                continue;
            };
            routes.extend(http.paths.iter().map(|entry| RouteRow {
                namespace: summary.namespace(),
                ingress: summary.name_any(),
                host: host.to_string(),
                path: entry.path.clone().unwrap_or_else(|| "/".into()),
                path_type: entry.path_type.clone(),
                backend: extract_backend(entry),
                port: extract_backend_port(entry),
            }));
        }
    }

    if json {
        let rendered = serde_json::to_string_pretty(&routes)?;
        println!("{}", rendered);
        return Ok(());
    }

    if routes.is_empty() {
        println!("No routes found.");
        return Ok(());
    }

    println!(
        "{:<20} {:<20} {:<30} {:<18} {:<15} {:<15} {:<15}",
        "NAMESPACE", "INGRESS", "HOST", "PATH", "TYPE", "BACKEND", "PORT"
    );
    println!("{:-<133}", "");

    for row in routes.iter() {
        let port_cell = extract_backend_port_str(row);
        println!(
            "{:<20} {:<20} {:<30} {:<18} {:<15} {:<15} {:<15}",
            truncate(&row.namespace, 20),
            truncate(&row.ingress, 20),
            truncate(&row.host, 30),
            truncate(&row.path, 18),
            truncate(&row.path_type, 15),
            truncate(&row.backend, 15),
            port_cell,
        );
    }

    println!("\nTotal: {} route(s)", routes.len());
    Ok(())
}
// ── backends ───────────────────────────────────────────────────────
/// `backends` subcommand: lists the distinct Service backends referenced by
/// GIngress-managed Ingresses together with their endpoint readiness.
///
/// Backends are de-duplicated on `<namespace>/<service>:<port>`; the port is
/// the numeric backend port, or 80 when none is set. `referenced_by` names the
/// first Ingress in which the backend was encountered. Endpoint counts come
/// from `get_endpoint_status`, queried once per distinct backend.
async fn cmd_backends(
    client: &Client,
    namespace: Option<String>,
    json: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    let summaries = list_ingresses(client, namespace.as_deref()).await?;

    let mut backends: Vec<BackendRow> = Vec::new();
    let mut seen = std::collections::HashSet::new();

    for summary in &summaries {
        let rules = summary
            .spec
            .as_ref()
            .and_then(|spec| spec.rules.as_deref())
            .unwrap_or_default();

        for http in rules.iter().filter_map(|rule| rule.http.as_ref()) {
            for entry in &http.paths {
                // Non-Service (resource) backends are not listed.
                let Some(svc) = entry.backend.service.as_ref() else {
                    continue;
                };
                let port_number = svc.port.as_ref().and_then(|p| p.number).unwrap_or(80);
                let ns = summary.namespace();
                let dedup_key = format!("{}/{}:{}", ns, svc.name, port_number);
                if !seen.insert(dedup_key) {
                    continue;
                }

                let status = get_endpoint_status(client, &ns, &svc.name).await;
                backends.push(BackendRow {
                    namespace: ns,
                    service: svc.name.clone(),
                    port: port_number as u16,
                    ready_endpoints: status.ready,
                    total_endpoints: status.total,
                    referenced_by: summary.name_any(),
                });
            }
        }
    }

    if json {
        let rendered = serde_json::to_string_pretty(&backends)?;
        println!("{}", rendered);
        return Ok(());
    }

    if backends.is_empty() {
        println!("No backends found.");
        return Ok(());
    }

    println!(
        "{:<20} {:<20} {:<8} {:<8} {:<18} {:<20}",
        "NAMESPACE", "SERVICE", "PORT", "HEALTH", "ENDPOINTS", "REFERENCED BY"
    );
    println!("{:-<94}", "");

    for row in backends.iter() {
        // WARN: no endpoints at all; DOWN: none ready; PARTIAL: some ready.
        let health = match (row.total_endpoints, row.ready_endpoints) {
            (0, _) => "WARN",
            (_, 0) => "DOWN",
            (total, ready) if ready < total => "PARTIAL",
            _ => "OK",
        };
        let endpoints_cell = format!("{}/{} ready", row.ready_endpoints, row.total_endpoints);
        println!(
            "{:<20} {:<20} {:<8} {:<8} {:<18} {:<20}",
            truncate(&row.namespace, 20),
            truncate(&row.service, 20),
            row.port,
            health,
            endpoints_cell,
            truncate(&row.referenced_by, 20),
        );
    }

    println!("\nTotal: {} backend(s)", backends.len());
    Ok(())
}
// ── certs ──────────────────────────────────────────────────────────
/// `certs` subcommand: lists the TLS secrets referenced by GIngress-managed
/// Ingresses and whether each secret exists.
///
/// One row is produced per `(TLS entry, host)` pair. A TLS entry that lists no
/// hosts is reported once with host `*` (the same placeholder `routes` uses
/// for host-less rules) instead of being silently omitted. An entry without a
/// `secretName` is reported as missing without querying the API, since an
/// empty name cannot identify a Secret.
///
/// Output is pretty JSON when `json` is set, otherwise a fixed-width table
/// followed by totals.
async fn cmd_certs(
    client: &Client,
    namespace: Option<String>,
    json: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    let ingresses = list_ingresses(client, namespace.as_deref()).await?;

    let mut certs: Vec<CertRow> = Vec::new();
    for ing in &ingresses {
        let ns = ing.namespace();
        let tls_entries = ing
            .spec
            .as_ref()
            .and_then(|s| s.tls.as_deref())
            .unwrap_or_default();

        for tls in tls_entries {
            let secret_name = tls.secret_name.clone().unwrap_or_default();
            // Skip the lookup entirely when no secret is named.
            let found = !secret_name.is_empty()
                && check_secret_exists(client, &ns, &secret_name).await;

            let hosts = tls.hosts.as_deref().unwrap_or_default();
            if hosts.is_empty() {
                // Host-less TLS entry: still surface the certificate.
                certs.push(CertRow {
                    namespace: ns.clone(),
                    secret_name: secret_name.clone(),
                    host: "*".to_string(),
                    found,
                });
                continue;
            }
            for host in hosts {
                certs.push(CertRow {
                    namespace: ns.clone(),
                    secret_name: secret_name.clone(),
                    host: host.clone(),
                    found,
                });
            }
        }
    }

    if json {
        println!("{}", serde_json::to_string_pretty(&certs)?);
        return Ok(());
    }

    if certs.is_empty() {
        println!("No TLS certificates configured.");
        return Ok(());
    }

    println!("{:<20} {:<30} {:<30} {:<10}", "NAMESPACE", "SECRET", "HOST", "STATUS");
    println!("{:-<90}", "");

    for c in &certs {
        let status = if c.found { "OK" } else { "MISSING" };
        println!(
            "{:<20} {:<30} {:<30} {:<10}",
            truncate(&c.namespace, 20),
            truncate(&c.secret_name, 30),
            truncate(&c.host, 30),
            status,
        );
    }

    let missing = certs.iter().filter(|c| !c.found).count();
    println!("\nTotal: {} cert(s), {} missing", certs.len(), missing);
    Ok(())
}
// ── validate ───────────────────────────────────────────────────────
/// `validate` subcommand: checks each GIngress-managed Ingress and prints one
/// line per finding, then a summary.
///
/// Checks, in order, per Ingress:
/// 1. ERROR when no routing rules are defined.
/// 2. For each TLS entry: WARNING when `secretName` is absent or empty, ERROR
///    when the named secret is not found.
/// 3. For each Service backend of each HTTP path: WARNING when
///    `get_endpoint_status` reports zero endpoints.
///
/// Findings do not make the command fail; it returns `Ok(())` unless listing
/// the Ingresses itself fails.
async fn cmd_validate(
    client: &Client,
    namespace: Option<String>,
) -> Result<(), Box<dyn std::error::Error>> {
    let summaries = list_ingresses(client, namespace.as_deref()).await?;
    let (mut errors, mut warnings) = (0usize, 0usize);

    for summary in &summaries {
        let ns = summary.namespace();
        let name = summary.name_any();
        let spec = summary.spec.as_ref();

        // 1. Routing rules present? (missing spec / missing list / empty list all count as "none")
        let rules = spec.and_then(|s| s.rules.as_deref()).unwrap_or_default();
        if rules.is_empty() {
            println!("[{}/{}] ERROR: No routing rules defined", ns, name);
            errors += 1;
        }

        // 2. TLS entries must name an existing secret.
        for tls in spec.and_then(|s| s.tls.as_deref()).unwrap_or_default() {
            match tls.secret_name.as_deref().unwrap_or("") {
                "" => {
                    println!("[{}/{}] WARNING: TLS configured but no secretName specified", ns, name);
                    warnings += 1;
                }
                secret_name => {
                    if !check_secret_exists(client, &ns, secret_name).await {
                        println!("[{}/{}] ERROR: TLS secret '{}' not found in namespace '{}'", ns, name, secret_name, ns);
                        errors += 1;
                    }
                }
            }
        }

        // 3. Service backends should have at least one endpoint.
        for rule in rules {
            let Some(http) = &rule.http else {
                continue;
            };
            for svc in http.paths.iter().filter_map(|p| p.backend.service.as_ref()) {
                let endpoints = get_endpoint_status(client, &ns, &svc.name).await;
                if endpoints.total == 0 {
                    println!(
                        "[{}/{}] WARNING: Backend service '{}' has no endpoints (host: {})",
                        ns, name, svc.name,
                        rule.host.as_deref().unwrap_or("*")
                    );
                    warnings += 1;
                }
            }
        }
    }

    match (errors, warnings) {
        (0, 0) => {
            println!("Validation passed — no issues found in {} Ingress(es).", summaries.len());
        }
        _ => {
            println!("\nValidation complete: {} error(s), {} warning(s) across {} Ingress(es).",
                errors, warnings, summaries.len());
        }
    }
    Ok(())
}
// ── status ─────────────────────────────────────────────────────────
/// `status` subcommand: prints controller pod state and a summary of the
/// resources managed by GIngress across all namespaces.
///
/// With `json` set, a `StatusOutput` object with controller pods and plain
/// counts is printed (no endpoint or secret lookups are made in that mode).
/// Otherwise a human-readable report is printed that also queries endpoint
/// readiness per distinct backend and the existence of each named TLS secret,
/// and ends with an overall `HEALTHY` / `DEGRADED` verdict.
///
/// The verdict is `HEALTHY` only when at least one Ingress exists, no TLS
/// secret is missing, there is at least one ready endpoint (if any backend is
/// referenced) and at least one controller pod is ready. Every failing
/// condition is listed under `DEGRADED`, including the "no Ingress" case.
async fn cmd_status(client: &Client, json: bool) -> Result<(), Box<dyn std::error::Error>> {
    let ingresses = list_ingresses(client, None).await?;
    let controller_pods = find_gingress_pods(client).await;

    if json {
        #[derive(serde::Serialize)]
        struct StatusOutput {
            controller_pods: Vec<PodInfo>,
            ingress_count: usize,
            route_count: usize,
            backend_count: usize,
            cert_count: usize,
        }

        let mut route_count = 0usize;
        let mut backend_set = std::collections::HashSet::new();
        let mut cert_count = 0usize;
        for ing in &ingresses {
            if let Some(spec) = &ing.spec {
                if let Some(rules) = &spec.rules {
                    for rule in rules {
                        if let Some(http) = &rule.http {
                            route_count += http.paths.len();
                            for p in &http.paths {
                                if let Some(svc) = &p.backend.service {
                                    backend_set.insert(format!("{}/{}", ing.namespace(), svc.name));
                                }
                            }
                        }
                    }
                }
                if let Some(tls) = &spec.tls {
                    cert_count += tls.len();
                }
            }
        }

        println!(
            "{}",
            serde_json::to_string_pretty(&StatusOutput {
                controller_pods,
                ingress_count: ingresses.len(),
                route_count,
                backend_count: backend_set.len(),
                cert_count,
            })?
        );
        return Ok(());
    }

    // Controller status
    println!("══ GIngress Controller Status ══\n");
    if controller_pods.is_empty() {
        println!("Controller: NOT FOUND (no pods with label gingress.io/component=controller)");
    } else {
        for pod in &controller_pods {
            let ready = if pod.ready { "Running" } else { "NotReady" };
            println!(
                "Controller Pod: {:<30} {:<12} {}",
                truncate(&pod.name, 30),
                ready,
                pod.namespace
            );
        }
    }

    // Resource summary
    println!();
    println!("══ Managed Resources ══\n");

    let mut route_count = 0usize;
    let mut backend_set = std::collections::HashSet::new();
    let mut backend_total_eps = 0usize;
    let mut backend_ready_eps = 0usize;
    let mut cert_count = 0usize;
    let mut cert_missing = 0usize;

    for ing in &ingresses {
        let Some(spec) = &ing.spec else {
            continue;
        };

        for rule in spec.rules.as_deref().unwrap_or_default() {
            let Some(http) = &rule.http else {
                continue;
            };
            route_count += http.paths.len();
            for p in &http.paths {
                if let Some(svc) = &p.backend.service {
                    let key = format!("{}/{}", ing.namespace(), svc.name);
                    // Query endpoints only the first time a backend is seen.
                    if backend_set.insert(key) {
                        let eps = get_endpoint_status(client, &ing.namespace(), &svc.name).await;
                        backend_total_eps += eps.total;
                        backend_ready_eps += eps.ready;
                    }
                }
            }
        }

        if let Some(tls_entries) = &spec.tls {
            cert_count += tls_entries.len();
            for tls in tls_entries {
                let sn = tls.secret_name.as_deref().unwrap_or("");
                if !sn.is_empty() && !check_secret_exists(client, &ing.namespace(), sn).await {
                    cert_missing += 1;
                }
            }
        }
    }

    println!("Ingresses: {}", ingresses.len());
    println!("Routes: {}", route_count);
    println!(
        "Backends: {} ({} ready / {} total endpoints)",
        backend_set.len(), backend_ready_eps, backend_total_eps
    );
    println!("TLS Certs: {} ({} missing)", cert_count, cert_missing);
    println!();

    // Overall health
    // `any` is false for an empty list, so this also covers "no controller pod".
    let controller_ready = controller_pods.iter().any(|p| p.ready);
    let no_ready_backend = !backend_set.is_empty() && backend_ready_eps == 0;
    let healthy = !ingresses.is_empty()
        && cert_missing == 0
        && !no_ready_backend
        && controller_ready;

    if healthy {
        println!("Status: HEALTHY");
    } else {
        println!("Status: DEGRADED");
        if !controller_ready {
            println!(" → Controller pod not running or not ready");
        }
        if ingresses.is_empty() {
            // Previously this case produced DEGRADED with no explanation.
            println!(" → No GIngress-managed Ingress resources found");
        }
        if cert_missing > 0 {
            println!(" → {} TLS secret(s) missing", cert_missing);
        }
        if no_ready_backend {
            println!(" → No ready backend endpoints");
        }
    }

    Ok(())
}
// ── k8s helpers ────────────────────────────────────────────────────
/// Flattened view of one Ingress, produced by `ingress_to_summary`.
///
/// Only `namespace`, `name`, `hosts` and `has_tls` are serialized; the two
/// `#[serde(skip)]` fields are kept for table rendering and rule inspection.
#[derive(serde::Serialize)]
struct IngressSummary {
/// Namespace of the Ingress (empty string when the object has none).
namespace: String,
/// Name of the Ingress.
name: String,
/// Hosts of all rules that specify one, in rule order.
hosts: Vec<String>,
/// Path type and path of every HTTP path, in rule order; not serialized.
#[serde(skip)]
paths_for_display: Vec<PathSummary>,
/// True when the spec has a non-empty `tls` list.
has_tls: bool,
/// Copy of the original spec, used by the subcommands to walk rules and TLS entries; not serialized.
#[serde(skip)]
spec: Option<k8s_openapi::api::networking::v1::IngressSpec>,
}
/// Display data for one HTTP path of an Ingress rule.
#[derive(Clone)]
struct PathSummary {
/// The path's `pathType` value as found in the Ingress.
path_type: String,
/// The path itself; `ingress_to_summary` substitutes "/" when the Ingress omits it.
path: String,
}
/// Read accessors used by the subcommands.
impl IngressSummary {
    /// Returns an owned copy of the namespace.
    fn namespace(&self) -> String {
        self.namespace.to_owned()
    }

    /// Returns an owned copy of the Ingress name.
    fn name_any(&self) -> String {
        self.name.to_owned()
    }

    /// Hosts collected from the rules.
    fn hosts(&self) -> &[String] {
        self.hosts.as_slice()
    }

    /// Path entries prepared for table output.
    fn paths_display(&self) -> &[PathSummary] {
        self.paths_for_display.as_slice()
    }

    /// Whether the Ingress has at least one TLS entry.
    fn has_tls(&self) -> bool {
        self.has_tls
    }
}
fn ingress_to_summary(ing: &Ingress) -> IngressSummary {
let spec = ing.spec.clone();
let mut hosts = Vec::new();
let mut paths = Vec::new();
let mut has_tls = false;
if let Some(ref s) = spec {
if let Some(ref rules) = s.rules {
for rule in rules {
if let Some(ref host) = rule.host {
hosts.push(host.clone());
}
if let Some(ref http) = rule.http {
for p in &http.paths {
paths.push(PathSummary {
path_type: p.path_type.clone(),
path: p.path.clone().unwrap_or_else(|| "/".into()),
});
}
}
}
}
has_tls = s.tls.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
}
IngressSummary {
namespace: ing.namespace().unwrap_or_default(),
name: ing.name_any(),
hosts,
paths_for_display: paths,
has_tls,
spec,
}
}
/// Name, namespace and readiness of one controller pod, as reported by `find_gingress_pods`.
#[derive(serde::Serialize)]
struct PodInfo {
/// Pod name.
name: String,
/// Pod namespace (empty string when the object has none).
namespace: String,
/// True when the pod has a `Ready` condition with status `True`.
ready: bool,
}
/// Find GIngress controller pods by label `gingress.io/component=controller`.
///
/// Searches all namespaces. A pod counts as ready when it has a condition of
/// type `Ready` with status `True`.
///
/// This is best-effort: if the list call fails (for example because the
/// caller lacks permission to list pods cluster-wide) an empty list is
/// returned. The failure is reported on stderr so that it is not mistaken
/// for "no controller deployed".
async fn find_gingress_pods(client: &Client) -> Vec<PodInfo> {
    let api: Api<Pod> = Api::all(client.clone());
    let lp = ListParams {
        label_selector: Some("gingress.io/component=controller".into()),
        ..Default::default()
    };

    let list = match api.list(&lp).await {
        Ok(list) => list,
        Err(err) => {
            eprintln!("warning: failed to list GIngress controller pods: {}", err);
            return Vec::new();
        }
    };

    list.items
        .into_iter()
        .map(|pod| {
            let ready = pod
                .status
                .as_ref()
                .and_then(|s| s.conditions.as_deref())
                .unwrap_or_default()
                .iter()
                .any(|c| c.type_ == "Ready" && c.status == "True");
            PodInfo {
                name: pod.name_any(),
                namespace: pod.namespace().unwrap_or_default(),
                ready,
            }
        })
        .collect()
}
async fn list_ingresses(client: &Client, namespace: Option<&str>) -> Result<Vec<IngressSummary>, Box<dyn std::error::Error>> {
let params = ListParams {
..Default::default()
};
if let Some(ns) = namespace {
let api: Api<Ingress> = Api::namespaced(client.clone(), ns);
let list = api.list(&params).await?;
Ok(list
.items
.into_iter()
.filter(|ing| is_gingress_class(ing))
.map(|ing| ingress_to_summary(&ing))
.collect())
} else {
let api: Api<Ingress> = Api::all(client.clone());
let list = api.list(&params).await?;
Ok(list
.items
.into_iter()
.filter(|ing| is_gingress_class(ing))
.map(|ing| ingress_to_summary(&ing))
.collect())
}
}
/// Returns true when the Ingress's `spec.ingressClassName` equals `INGRESS_CLASS`.
///
/// An Ingress without a spec or without an ingress class name does not match.
fn is_gingress_class(ingress: &Ingress) -> bool {
    let class_name = match &ingress.spec {
        Some(spec) => spec.ingress_class_name.as_deref(),
        None => None,
    };
    class_name.is_some_and(|class| class == INGRESS_CLASS)
}
/// Endpoint address counts for one Service, as computed by `get_endpoint_status`.
struct EndpointStatus {
/// Number of addresses listed as ready.
ready: usize,
/// Ready plus not-ready addresses.
total: usize,
}
/// Counts ready and total endpoint addresses of the `Endpoints` object named
/// `service_name` in `namespace`.
///
/// `ready` is the number of entries in `addresses` summed over all subsets;
/// `total` additionally includes `notReadyAddresses`.
///
/// Returns zero counts when the object does not exist. A failed API call also
/// yields zero counts (best-effort), but is reported on stderr so that it is
/// not silently presented as "service has no endpoints".
async fn get_endpoint_status(client: &Client, namespace: &str, service_name: &str) -> EndpointStatus {
    use k8s_openapi::api::core::v1::Endpoints;

    let api: Api<Endpoints> = Api::namespaced(client.clone(), namespace);
    let eps = match api.get_opt(service_name).await {
        Ok(Some(eps)) => eps,
        Ok(None) => return EndpointStatus { ready: 0, total: 0 },
        Err(err) => {
            eprintln!(
                "warning: failed to read Endpoints '{}/{}': {}",
                namespace, service_name, err
            );
            return EndpointStatus { ready: 0, total: 0 };
        }
    };

    let mut ready = 0usize;
    let mut not_ready = 0usize;
    for subset in eps.subsets.as_deref().unwrap_or_default() {
        ready += subset.addresses.as_ref().map_or(0, Vec::len);
        not_ready += subset.not_ready_addresses.as_ref().map_or(0, Vec::len);
    }

    EndpointStatus {
        ready,
        total: ready + not_ready,
    }
}
/// Returns true when a Secret called `name` exists in `namespace`.
///
/// An empty `name` can never identify a Secret, so it yields `false` without
/// contacting the API server. A failed API call also yields `false`
/// (best-effort), but is reported on stderr so that, for example, a
/// permission error is not silently presented as a missing secret.
async fn check_secret_exists(client: &Client, namespace: &str, name: &str) -> bool {
    if name.is_empty() {
        return false;
    }

    let api: Api<Secret> = Api::namespaced(client.clone(), namespace);
    match api.get_opt(name).await {
        Ok(found) => found.is_some(),
        Err(err) => {
            eprintln!("warning: failed to read Secret '{}/{}': {}", namespace, name, err);
            false
        }
    }
}
/// Returns the backend Service name of a path, or the placeholder
/// `<resource>` when the backend is not a Service.
fn extract_backend(path_item: &HTTPIngressPath) -> String {
    match &path_item.backend.service {
        Some(service) => service.name.clone(),
        None => "<resource>".into(),
    }
}
/// Returns the numeric backend port of a path, defaulting to 80 when the
/// backend is not a Service or no numeric port is set.
///
/// The value is converted with `as u16`, i.e. without a range check.
fn extract_backend_port(path_item: &HTTPIngressPath) -> u16 {
    let service_port = path_item
        .backend
        .service
        .as_ref()
        .and_then(|service| service.port.as_ref());

    let number = match service_port {
        Some(port) => port.number.unwrap_or(80),
        None => 80,
    };
    number as u16
}
fn extract_backend_port_str(r: &RouteRow) -> String {
r.port.to_string()
}
/// One row of the `routes` output: a single HTTP path of a single Ingress rule.
#[derive(serde::Serialize)]
struct RouteRow {
/// Namespace of the owning Ingress.
namespace: String,
/// Name of the owning Ingress.
ingress: String,
/// Rule host, or "*" when the rule has none.
host: String,
/// Request path, "/" when the Ingress omits it.
path: String,
/// The path's `pathType` value.
path_type: String,
/// Backend Service name, or "<resource>" for non-Service backends.
backend: String,
/// Numeric backend port (80 when none is set).
port: u16,
}
/// One row of the `backends` output: a distinct Service/port pair and its endpoint counts.
#[derive(serde::Serialize)]
struct BackendRow {
/// Namespace of the Ingress that references the Service.
namespace: String,
/// Service name.
service: String,
/// Numeric backend port (80 when none is set).
port: u16,
/// Number of ready endpoint addresses.
ready_endpoints: usize,
/// Ready plus not-ready endpoint addresses.
total_endpoints: usize,
/// Name of the Ingress in which this backend was first encountered.
referenced_by: String,
}
/// One row of the `certs` output: a host covered by a TLS entry and its secret.
#[derive(serde::Serialize)]
struct CertRow {
/// Namespace of the owning Ingress (where the secret is looked up).
namespace: String,
/// `secretName` of the TLS entry; empty when the entry has none.
secret_name: String,
/// Host listed in the TLS entry.
host: String,
/// Result of `check_secret_exists` for this secret.
found: bool,
}
/// Returns the longest prefix of `s` whose estimated display width does not
/// exceed `max` columns.
///
/// Width is estimated per character: ASCII counts as 1 column, every other
/// character (e.g. CJK) as 2. Cutting stops at the first character that would
/// not fit; no truncation marker is appended.
fn truncate(s: &str, max: usize) -> String {
    let mut used = 0usize;
    s.chars()
        .take_while(|ch| {
            used += if ch.is_ascii() { 1 } else { 2 };
            used <= max
        })
        .collect()
}

View File

@ -0,0 +1,151 @@
//! Watches Kubernetes Endpoints and updates upstream endpoint lists.
//!
//! Tracks Pod IPs for each Service. When endpoints change (scale up/down,
//! rolling restart, health check failures), the upstream pool is updated.
use futures::pin_mut;
use futures::StreamExt;
use gingress_proxy::config::{ConfigStore, Endpoint};
use k8s_openapi::api::core::v1::Endpoints as K8sEndpoints;
use kube::ResourceExt;
use kube::runtime::watcher::{self, Event};
use std::sync::Arc;
/// Watch Endpoints and update the ConfigStore.
///
/// Runs until the watch stream ends. Applied objects (both during the initial
/// listing and afterwards) are passed to `process_endpoints`; deleted objects
/// to `remove_endpoints`. Watch errors are logged and the stream is polled
/// again.
///
/// The stream is wrapped in kube's default backoff so that a persistent API
/// failure is retried with increasing delays instead of in a tight loop.
///
/// NOTE(review): `_namespace` is currently ignored — Endpoints are watched in
/// all namespaces regardless of its value. Confirm whether that is intended.
pub async fn watch_endpoints(
    client: Arc<kube::Client>,
    store: Arc<ConfigStore>,
    _namespace: Option<String>,
    on_change: Arc<dyn Fn() + Send + Sync>,
) {
    // Brings `default_backoff` into scope for the watcher stream.
    use kube::runtime::WatchStreamExt;

    let api = kube::Api::<K8sEndpoints>::all(client.as_ref().clone());
    let config = watcher::Config::default();
    let watcher = watcher::watcher(api, config).default_backoff();
    pin_mut!(watcher);

    while let Some(event) = watcher.next().await {
        match event {
            Ok(Event::Apply(eps)) => {
                process_endpoints(&eps, &store, &on_change);
            }
            Ok(Event::Init) => {
                tracing::info!("Endpoint watcher re-initializing");
            }
            Ok(Event::InitApply(eps)) => {
                process_endpoints(&eps, &store, &on_change);
            }
            Ok(Event::InitDone) => {
                tracing::info!("Endpoint watcher init complete");
            }
            Ok(Event::Delete(eps)) => {
                remove_endpoints(&eps, &store, &on_change);
            }
            Err(e) => {
                tracing::error!("Endpoint watcher error: {}", e);
            }
        }
    }
}
/// Extract endpoint addresses, grouped by port, and update the ConfigStore.
///
/// Stores endpoints under key `upstream:<ns>/<name>:<port>` to match
/// the proxy's upstream lookup format. Both ready and not-ready addresses are
/// stored, flagged through `Endpoint::ready`.
///
/// All existing per-port keys of the service are removed first so that ports
/// which disappeared do not linger. When the object exposes no ports at all,
/// an empty list is written under the port-less key `upstream:<ns>/<name>`;
/// when it does expose ports, that port-less key is removed so a marker left
/// by an earlier empty state does not go stale.
///
/// Ports outside the `u16` range are skipped with a warning rather than being
/// wrapped by a cast. Finally the store is signalled to reload and
/// `on_change` is invoked.
fn process_endpoints(
    endpoints: &K8sEndpoints,
    store: &ConfigStore,
    on_change: &Arc<dyn Fn() + Send + Sync>,
) {
    use std::collections::HashMap;

    let name = endpoints.name_any();
    let namespace = endpoints.namespace().unwrap_or_default();
    let base_key = format!("upstream:{}/{}", namespace, name);
    let base_prefix = format!("{}:", base_key);

    // Collect endpoints grouped by port
    let mut port_groups: HashMap<u16, Vec<Endpoint>> = HashMap::new();
    for subset in endpoints.subsets.as_deref().unwrap_or_default() {
        let addrs = subset.addresses.as_deref().unwrap_or_default();
        let not_ready_addrs = subset.not_ready_addresses.as_deref().unwrap_or_default();

        for port in subset.ports.as_deref().unwrap_or_default() {
            let Ok(port_num) = u16::try_from(port.port) else {
                tracing::warn!(
                    namespace = %namespace,
                    name = %name,
                    port = port.port,
                    "Skipping endpoint port outside the valid range"
                );
                continue;
            };

            let eps = port_groups.entry(port_num).or_default();
            eps.extend(addrs.iter().map(|addr| Endpoint {
                ip: addr.ip.clone(),
                port: port_num,
                ready: true,
            }));
            eps.extend(not_ready_addrs.iter().map(|addr| Endpoint {
                ip: addr.ip.clone(),
                port: port_num,
                ready: false,
            }));
        }
    }

    // Clear old per-port keys for this service (handles port removal)
    for k in store.keys_with_prefix(&base_prefix) {
        store.remove(&k);
    }

    // Write per-port endpoint entries
    let mut total = 0usize;
    for (port_num, eps) in &port_groups {
        let key = format!("{}{}", base_prefix, port_num);
        store.set(&key, eps);
        total += eps.len();
    }

    if port_groups.is_empty() {
        // No ports at all: write an empty entry for the base key so the
        // reconciler can detect that this service has no endpoints.
        store.set::<Vec<Endpoint>>(&base_key, &vec![]);
    } else {
        // Ports exist again: drop a "no endpoints" marker from an earlier update.
        store.remove(&base_key);
    }

    store.signal_reload();
    on_change();

    tracing::debug!(
        namespace = %namespace,
        name = %name,
        num_ports = port_groups.len(),
        num_endpoints = total,
        "Endpoints updated"
    );
}
/// Drop every upstream entry belonging to a deleted `Endpoints` object.
///
/// Removes all `upstream:<ns>/<name>:<port>` keys plus the port-less
/// `upstream:<ns>/<name>` marker, then signals a reload and calls `on_change`.
fn remove_endpoints(
    endpoints: &K8sEndpoints,
    store: &ConfigStore,
    on_change: &Arc<dyn Fn() + Send + Sync>,
) {
    let svc_name = endpoints.name_any();
    let svc_ns = endpoints.namespace().unwrap_or_default();
    let portless_key = format!("upstream:{}/{}", svc_ns, svc_name);
    let per_port_prefix = format!("upstream:{}/{}:", svc_ns, svc_name);

    // Per-port entries first ...
    store
        .keys_with_prefix(&per_port_prefix)
        .iter()
        .for_each(|key| {
            store.remove(key);
        });
    // ... then the marker written when the object had no ports.
    store.remove(&portless_key);

    store.signal_reload();
    on_change();
    tracing::info!(namespace = %svc_ns, name = %svc_name, "Endpoints removed");
}

View File

@ -0,0 +1,372 @@
//! Watches Kubernetes Ingress resources and converts them to routing rules.
use futures::pin_mut;
use futures::StreamExt;
use gingress_proxy::config::{
ConfigStore, HeaderOp, PathType, RateLimitPolicy, RouteRule, SessionAffinityConfig,
};
use k8s_openapi::api::networking::v1::{HTTPIngressPath, Ingress};
use kube::ResourceExt;
use kube::runtime::watcher::{self, Event};
use std::collections::BTreeMap;
use std::sync::Arc;
/// Watch Ingress resources and update the ConfigStore.
///
/// * `ingress_class` – only Ingresses whose `spec.ingressClassName` equals this
///   value are processed.
/// * `namespace` – when `Some` and non-empty, restricts the watch to that
///   namespace; otherwise the watch is cluster-wide.
/// * `on_change` – invoked after each handled event so the reconciler can
///   cross-reference all fragments into a complete ProxyConfig. During an
///   initial or repeated list it is invoked once, on `InitDone`.
///
/// The function only returns when the watcher stream ends; watcher errors are
/// logged and the loop continues.
pub async fn watch_ingresses(
    client: Arc<kube::Client>,
    store: Arc<ConfigStore>,
    ingress_class: String,
    namespace: Option<String>,
    on_change: Arc<dyn Fn() + Send + Sync>,
) {
    let api = kube::Api::<Ingress>::all(client.as_ref().clone());
    let config = watcher::Config {
        // An empty namespace means "all namespaces"; without the filter the
        // selector `metadata.namespace=` would match nothing.
        field_selector: namespace
            .as_deref()
            .filter(|ns| !ns.is_empty())
            .map(|ns| format!("metadata.namespace={}", ns)),
        ..Default::default()
    };
    let ingress_watcher = watcher::watcher(api, config);
    pin_mut!(ingress_watcher);
    while let Some(event) = ingress_watcher.next().await {
        match event {
            Ok(Event::Apply(ingress)) => {
                let name = ingress.name_any();
                let ns = ingress.namespace().unwrap_or_default();
                if is_gingress_class(&ingress, &ingress_class) {
                    process_ingress(&ingress, &store, &ingress_class);
                    on_change();
                    tracing::info!(namespace = %ns, name = %name, "Ingress applied");
                } else {
                    // The Ingress may have been edited away from our class.
                    // If we still hold fragments for it, drop them so the
                    // proxy stops serving routes it no longer owns.
                    let owned_prefix = format!("ingress:{}/{}:", ns, name);
                    if !store.keys_with_prefix(&owned_prefix).is_empty() {
                        remove_ingress_routes(&ingress, &store);
                        on_change();
                        tracing::info!(
                            namespace = %ns,
                            name = %name,
                            "Ingress no longer matches the watched class, routes removed"
                        );
                    }
                }
            }
            Ok(Event::Init) => {
                // A (re-)list is starting: forget everything this watcher
                // owns, including annotation-derived per-host keys, so that
                // objects deleted while the watch was down do not linger.
                store.remove_prefix("ingress:");
                store.remove_prefix("tls-host:");
                store.remove_prefix("rate_limit:");
                store.remove_prefix("headers:");
                store.remove_prefix("session_affinity:");
                store.remove("websocket:hosts");
                tracing::info!("Ingress watcher re-initializing");
            }
            Ok(Event::InitApply(ingress)) => {
                if is_gingress_class(&ingress, &ingress_class) {
                    process_ingress(&ingress, &store, &ingress_class);
                }
            }
            Ok(Event::InitDone) => {
                store.signal_reload();
                on_change();
                tracing::info!("Ingress watcher init complete");
            }
            Ok(Event::Delete(ingress)) => {
                if is_gingress_class(&ingress, &ingress_class) {
                    remove_ingress_routes(&ingress, &store);
                    on_change();
                    tracing::info!(
                        name = %ingress.name_any(),
                        namespace = %ingress.namespace().unwrap_or_default(),
                        "Ingress deleted, routes removed"
                    );
                }
            }
            Err(e) => {
                tracing::error!("Ingress watcher error: {}", e);
            }
        }
    }
}
/// Report whether `ingress` declares `class_name` as its `ingressClassName`.
///
/// An Ingress without a spec, or without an explicit class name, never matches.
fn is_gingress_class(ingress: &Ingress, class_name: &str) -> bool {
    let declared = ingress
        .spec
        .as_ref()
        .and_then(|spec| spec.ingress_class_name.as_deref());
    match declared {
        Some(found) => found == class_name,
        None => false,
    }
}
/// Process an Ingress resource: extract routes and update the store.
fn process_ingress(ingress: &Ingress, store: &ConfigStore, _ingress_class: &str) {
let namespace = ingress.namespace().unwrap_or_default();
let name = ingress.name_any();
let spec = match ingress.spec.as_ref() {
Some(s) => s,
None => return,
};
// Build an ingress-scoped prefix so we can clean up old routes for this Ingress
let ingress_prefix = format!("ingress:{}/{}:", namespace, name);
// Remove old route entries scoped to this Ingress
let old_route_keys = store.keys_with_prefix(&format!("{}route:", ingress_prefix));
for key in &old_route_keys {
store.remove(key);
}
// Process routing rules
if let Some(rules) = &spec.rules {
for rule in rules {
let host = rule.host.as_deref().unwrap_or("*");
if let Some(http) = &rule.http {
let mut routes: Vec<RouteRule> = Vec::new();
for path_item in &http.paths {
routes.push(ingress_path_to_route(host, path_item, &namespace));
}
// Store per-ingress routes so we can clean up on delete
let route_key = format!("{}route:{}", ingress_prefix, host);
store.set(&route_key, &routes);
}
}
}
// Process TLS: map secretName -> hosts so the reconciler can cross-reference
if let Some(tls_entries) = &spec.tls {
for tls in tls_entries {
let secret_name = tls.secret_name.as_deref().unwrap_or_default();
let hosts: Vec<String> = tls.hosts.clone().unwrap_or_default();
let tls_host_key = format!("tls-host:{}", secret_name);
store.set(&tls_host_key, &hosts);
}
}
// Process annotations for advanced features
let annotations = ingress.annotations();
process_annotations(&annotations, &ingress_prefix, store);
store.signal_reload();
}
/// Translate one Kubernetes Ingress path entry into an internal `RouteRule`.
///
/// * A missing path defaults to `/`.
/// * Path types other than `Prefix` and `Exact` map to
///   `PathType::ImplementationSpecific`.
/// * The backend port is the numeric service port, or 80 when none is given.
///
/// # Panics
///
/// Panics when the path's backend does not reference a Service.
fn ingress_path_to_route(host: &str, path: &HTTPIngressPath, namespace: &str) -> RouteRule {
    let service = path
        .backend
        .service
        .as_ref()
        .expect("Ingress backend must reference a service");

    let route_path = match &path.path {
        Some(p) => p.clone(),
        None => "/".to_string(),
    };
    let path_type = match path.path_type.as_str() {
        "Prefix" => PathType::Prefix,
        "Exact" => PathType::Exact,
        _ => PathType::ImplementationSpecific,
    };
    let backend_port = service.port.as_ref().and_then(|p| p.number).unwrap_or(80) as u16;

    RouteRule {
        host: host.to_string(),
        path: route_path,
        path_type,
        backend: gingress_proxy::config::Backend {
            namespace: namespace.to_string(),
            name: service.name.clone(),
            port: backend_port,
        },
    }
}
/// Annotation key: rate limit for the Ingress hosts, "RPS" or "RPS/BURST".
const ANN_RATE_LIMIT: &str = "gingress.io/rate-limit";
/// Annotation key: burst size used together with [`ANN_RATE_LIMIT`].
const ANN_RATE_LIMIT_BURST: &str = "gingress.io/rate-limit-burst";
/// Annotation key: JSON array of request header operations.
const ANN_REQUEST_HEADERS: &str = "gingress.io/request-headers";
/// Annotation key: "true" enables WebSocket upgrade for the Ingress hosts.
const ANN_WEBSOCKET: &str = "gingress.io/websocket";
/// Annotation key: session affinity, "cookie" or "cookie:NAME:TTL_SECONDS".
const ANN_SESSION_AFFINITY: &str = "gingress.io/session-affinity";
/// Translate Ingress annotations into per-host ConfigStore entries.
///
/// The hosts are taken from the route keys already stored under
/// `<ingress_prefix>route:`; when there are none, nothing is written.
/// Existing per-host annotation keys for those hosts are cleared first, so a
/// removed annotation disappears from the store as well.
///
/// Supported annotations:
/// - `gingress.io/rate-limit` — "RPS" or "RPS/BURST" (e.g., "100" or "100/200")
/// - `gingress.io/rate-limit-burst` — Override burst size
/// - `gingress.io/request-headers` — JSON array of header operations
/// - `gingress.io/websocket` — "true" to enable WebSocket upgrade for this host
/// - `gingress.io/session-affinity` — "cookie" or "cookie:NAME:TTL_SECONDS"
fn process_annotations(
    annotations: &BTreeMap<String, String>,
    ingress_prefix: &str,
    store: &ConfigStore,
) {
    // Hosts of this Ingress, recovered from the route keys just written.
    let route_prefix = format!("{}route:", ingress_prefix);
    let mut hosts: Vec<String> = Vec::new();
    for route_key in store.keys_with_prefix(&route_prefix).iter() {
        if let Some(host) = route_key.split(":route:").nth(1) {
            hosts.push(host.to_string());
        }
    }
    if hosts.is_empty() {
        return;
    }

    // Start from a clean slate for these hosts.
    for host in &hosts {
        for kind in ["rate_limit", "headers", "session_affinity"] {
            store.remove(&format!("{}:{}", kind, host));
        }
    }
    prune_websocket_hosts(store, &hosts);

    // Rate limiting.
    if let Some(raw) = annotations.get(ANN_RATE_LIMIT) {
        let (requests_per_second, burst_size) =
            parse_rate_limit(raw, annotations.get(ANN_RATE_LIMIT_BURST));
        for host in &hosts {
            let policy = RateLimitPolicy {
                host: host.clone(),
                requests_per_second,
                burst_size,
            };
            store.set(&format!("rate_limit:{}", host), &policy);
        }
    }

    // Request header operations.
    if let Some(raw) = annotations.get(ANN_REQUEST_HEADERS) {
        match parse_header_ops(raw) {
            Ok(ops) => {
                for host in &hosts {
                    store.set(&format!("headers:{}", host), &ops);
                }
            }
            Err(_) => {
                tracing::warn!(annotation = %ANN_REQUEST_HEADERS, value = %raw, "Invalid header ops JSON");
            }
        }
    }

    // WebSocket: one global host list shared by all Ingresses.
    let websocket_enabled = annotations
        .get(ANN_WEBSOCKET)
        .is_some_and(|raw| raw.trim().to_lowercase() == "true");
    if websocket_enabled {
        let mut ws_hosts: Vec<String> = hosts.clone();
        // Keep hosts contributed by other Ingresses (ours were pruned above).
        let others = store
            .get::<Vec<String>>("websocket:hosts")
            .unwrap_or_default();
        for other in others {
            if !ws_hosts.contains(&other) {
                ws_hosts.push(other);
            }
        }
        store.set("websocket:hosts", &ws_hosts);
    }

    // Session affinity.
    if let Some(raw) = annotations.get(ANN_SESSION_AFFINITY) {
        let (enabled, cookie_name, cookie_ttl_seconds) = parse_session_affinity(raw);
        for host in &hosts {
            let affinity = SessionAffinityConfig {
                enabled,
                cookie_name: cookie_name.clone(),
                cookie_ttl_seconds,
            };
            store.set(&format!("session_affinity:{}", host), &affinity);
        }
    }
}
/// Parse the rate-limit annotation into `(requests_per_second, burst_size)`.
///
/// `val` is either `"RPS"` or `"RPS/BURST"`; whitespace around the whole value
/// and around each part is ignored. An unparsable RPS yields `0`.
///
/// The burst size is resolved in this order:
/// 1. `burst_override`, when present and a valid number,
/// 2. the `BURST` part of `val` (falling back to RPS when it is not a number),
/// 3. twice the RPS, saturating at `u32::MAX` instead of overflowing.
fn parse_rate_limit(val: &str, burst_override: Option<&String>) -> (u32, u32) {
    let val = val.trim();
    let (rps_part, inline_burst) = match val.split_once('/') {
        Some((rps, burst)) => (rps.trim(), Some(burst.trim())),
        None => (val, None),
    };
    let rps: u32 = rps_part.parse().unwrap_or(0);
    let burst = burst_override
        .and_then(|b| b.trim().parse::<u32>().ok())
        .or_else(|| inline_burst.map(|b| b.parse::<u32>().unwrap_or(rps)))
        .unwrap_or_else(|| rps.saturating_mul(2));
    (rps, burst)
}
/// One element of the JSON array carried by the request-headers annotation,
/// as deserialized before conversion into a `HeaderOp`.
#[derive(serde::Deserialize)]
struct HeaderOpAnnotation {
    /// Operation name; `parse_header_ops` accepts "set", "add" and "remove".
    op: String,
    /// Header name the operation applies to.
    name: String,
    /// Header value; optional in the JSON, treated as empty when omitted.
    #[serde(default)]
    value: Option<String>,
}
fn parse_header_ops(val: &str) -> anyhow::Result<Vec<HeaderOp>> {
let items: Vec<HeaderOpAnnotation> = serde_json::from_str(val)?;
items
.into_iter()
.map(|item| {
Ok(match item.op.as_str() {
"set" => HeaderOp::Set {
name: item.name,
value: item.value.unwrap_or_default(),
},
"add" => HeaderOp::Add {
name: item.name,
value: item.value.unwrap_or_default(),
},
"remove" => HeaderOp::Remove { name: item.name },
_ => anyhow::bail!("Unknown header op: {}", item.op),
})
})
.collect()
}
/// Parse the session-affinity annotation into
/// `(enabled, cookie_name, cookie_ttl_seconds)`.
///
/// Accepted values (case-insensitive mode, surrounding whitespace ignored):
/// * `"cookie"` or `"true"` – enabled with the default cookie name and TTL,
/// * `"cookie:NAME"` – enabled with a custom cookie name,
/// * `"cookie:NAME:TTL_SECONDS"` – enabled with custom name and TTL.
///
/// An empty `NAME` or an unparsable TTL falls back to the respective default.
/// Any other value (including other modes that merely contain a `:`) disables
/// affinity and yields `(false, "", 0)`.
fn parse_session_affinity(val: &str) -> (bool, String, u64) {
    const DEFAULT_COOKIE_NAME: &str = "GINGRESS_AFFINITY";
    const DEFAULT_TTL_SECONDS: u64 = 3600;

    let val = val.trim();
    if val.eq_ignore_ascii_case("true") {
        return (true, DEFAULT_COOKIE_NAME.into(), DEFAULT_TTL_SECONDS);
    }

    let mut parts = val.split(':').map(str::trim);
    let mode = parts.next().unwrap_or_default();
    // Only the documented "cookie" mode enables affinity.
    if !mode.eq_ignore_ascii_case("cookie") {
        return (false, String::new(), 0);
    }

    let name = parts
        .next()
        .filter(|n| !n.is_empty())
        .unwrap_or(DEFAULT_COOKIE_NAME);
    let ttl = parts
        .next()
        .and_then(|t| t.parse().ok())
        .unwrap_or(DEFAULT_TTL_SECONDS);
    (true, name.to_string(), ttl)
}
/// Strip `hosts_to_remove` from the global `websocket:hosts` list.
///
/// Does nothing when the list is absent. When no host remains the key itself
/// is removed; otherwise the shortened list is written back.
fn prune_websocket_hosts(store: &ConfigStore, hosts_to_remove: &[String]) {
    let Some(current) = store.get::<Vec<String>>("websocket:hosts") else {
        return;
    };
    let remaining: Vec<String> = current
        .into_iter()
        .filter(|host| !hosts_to_remove.contains(host))
        .collect();
    if remaining.is_empty() {
        store.remove("websocket:hosts");
    } else {
        store.set("websocket:hosts", &remaining);
    }
}
/// Purge everything the store holds for a deleted Ingress.
///
/// Removes the Ingress-scoped keys, the per-host annotation keys of its hosts,
/// its hosts from the global WebSocket list and the `tls-host:` mappings of
/// its TLS entries, then signals a reload.
fn remove_ingress_routes(ingress: &Ingress, store: &ConfigStore) {
    let ingress_prefix = format!(
        "ingress:{}/{}:",
        ingress.namespace().unwrap_or_default(),
        ingress.name_any()
    );

    // Hosts must be read before the route keys disappear.
    let mut hosts: Vec<String> = Vec::new();
    for route_key in store
        .keys_with_prefix(&format!("{}route:", ingress_prefix))
        .iter()
    {
        if let Some(host) = route_key.split(":route:").nth(1) {
            hosts.push(host.to_string());
        }
    }

    // Everything stored under this Ingress' own prefix.
    store.remove_prefix(&ingress_prefix);

    // Per-host keys derived from annotations.
    for host in &hosts {
        for kind in ["rate_limit", "headers", "session_affinity"] {
            store.remove(&format!("{}:{}", kind, host));
        }
    }

    // Only this Ingress' hosts leave the global WebSocket list.
    prune_websocket_hosts(store, &hosts);

    // TLS secret -> hosts mappings.
    let tls_entries = ingress.spec.as_ref().and_then(|spec| spec.tls.as_ref());
    for tls in tls_entries.into_iter().flatten() {
        let secret = tls.secret_name.as_deref().unwrap_or_default();
        store.remove(&format!("tls-host:{}", secret));
    }

    store.signal_reload();
}

View File

@ -0,0 +1,88 @@
//! Kubernetes controller for GIngress.
//!
//! Watches Ingress, Endpoints, and TLS Secret resources and
//! reconciles them into the shared `ConfigStore`.
mod endpoint_watcher;
mod ingress_watcher;
mod reconciler;
mod secret_watcher;
use anyhow::Context;
use gingress_proxy::config::ConfigStore;
use kube::Client;
use std::sync::Arc;
use tokio::task::JoinHandle;
/// Launch the Ingress, Secret and Endpoints watchers on a background task.
///
/// Each watcher:
/// 1. Watches one Kubernetes resource type
/// 2. Writes its fragment to the `ConfigStore`
/// 3. Invokes the shared `on_change` callback, which runs
///    `Reconciler::reconcile()` to cross-reference all fragments
///
/// The three watchers are polled concurrently inside a single spawned task.
/// As soon as any one of them finishes, the exit is logged and the task ends;
/// the watchers are not restarted.
///
/// # Errors
///
/// Fails when no Kubernetes client can be created from the environment.
///
/// Returns a `JoinHandle` that can be aborted on shutdown.
pub async fn start(
    store: ConfigStore,
    ingress_class: String,
    namespace: Option<String>,
) -> anyhow::Result<JoinHandle<()>> {
    let kube_client = Client::try_default().await.context(
        "Failed to create Kubernetes client. Are you running in a cluster or have a kubeconfig?",
    )?;
    tracing::info!("Kubernetes client initialized");

    let shared_store = Arc::new(store);
    let shared_client = Arc::new(kube_client);
    let reconciler = Arc::new(reconciler::Reconciler::new(Arc::clone(&shared_store)));

    // Shared callback handed to every watcher: re-assemble the full
    // configuration from whatever fragments are currently in the store.
    let on_change: Arc<dyn Fn() + Send + Sync> = Arc::new(move || {
        reconciler.reconcile();
    });

    let task = tokio::spawn(async move {
        // These are plain futures; nothing runs until `select!` polls them.
        let ingresses = ingress_watcher::watch_ingresses(
            Arc::clone(&shared_client),
            Arc::clone(&shared_store),
            ingress_class,
            namespace.clone(),
            Arc::clone(&on_change),
        );
        let secrets = secret_watcher::watch_secrets(
            Arc::clone(&shared_client),
            Arc::clone(&shared_store),
            namespace.clone(),
            Arc::clone(&on_change),
        );
        let endpoints = endpoint_watcher::watch_endpoints(
            Arc::clone(&shared_client),
            Arc::clone(&shared_store),
            namespace,
            Arc::clone(&on_change),
        );
        tracing::info!("All watchers started");

        // The first watcher to finish ends the task; the others are dropped.
        tokio::select! {
            outcome = ingresses => tracing::error!("Ingress watcher exited: {:?}", outcome),
            outcome = secrets => tracing::error!("Secret watcher exited: {:?}", outcome),
            outcome = endpoints => tracing::error!("Endpoint watcher exited: {:?}", outcome),
        }
    });
    Ok(task)
}

View File

@ -0,0 +1,233 @@
//! Reconcile loop for the GIngress controller.
//!
//! After any watcher detects a change (Ingress, Secret, Endpoints),
//! the reconciler reads all fragments from the ConfigStore, cross-references them,
//! assembles a complete `ProxyConfig`, validates it, and signals a reload.
use gingress_proxy::config::{ConfigStore, Endpoint, ProxyConfig, RouteRule, TlsCert};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Reconcile the full proxy configuration from current k8s state.
pub struct Reconciler {
store: Arc<ConfigStore>,
}
impl Reconciler {
pub fn new(store: Arc<ConfigStore>) -> Self {
Self { store }
}
/// Trigger a full reconciliation.
///
/// 1. Reads all route fragments (from Ingress watcher)
/// 2. Reads all TLS certs (from Secret watcher)
/// 3. Reads all upstream endpoints (from Endpoint watcher)
/// 4. Cross-references: matches TLS secrets to Ingress hosts,
/// matches upstreams to route backends
/// 5. Validates the configuration
/// 6. Writes the assembled ProxyConfig to the store
/// 7. Signals reload
pub fn reconcile(&self) {
tracing::debug!("Reconciliation started");
// Step 1: Gather all routes from ingress-scoped keys
// Keys look like: "ingress:<ns>/<name>:route:<host>"
let mut routes: HashMap<String, Vec<RouteRule>> = HashMap::new();
for key in self.store.keys_with_prefix("ingress:") {
if !key.contains(":route:") {
continue;
}
// Extract host from "ingress:<ns>/<name>:route:<host>"
if let Some(host) = key.split(":route:").nth(1) {
if let Some(rules) = self.store.get::<Vec<RouteRule>>(&key) {
if !rules.is_empty() {
routes.entry(host.to_string()).or_default().extend(rules);
}
}
}
}
// Step 2: Gather all TLS certs
let mut tls_certs: HashMap<String, TlsCert> = HashMap::new();
for key in self.store.keys_with_prefix("tls:") {
if let Some(cert) = self.store.get::<TlsCert>(&key) {
tls_certs.insert(cert.host.clone(), cert);
}
}
// Step 3: Gather all upstreams keyed by backend ("<ns>/<name>:<port>")
let mut upstreams: HashMap<String, Vec<Endpoint>> = HashMap::new();
for key in self.store.keys_with_prefix("upstream:") {
if let Some(eps) = self.store.get::<Vec<Endpoint>>(&key) {
if !eps.is_empty() {
upstreams.insert(key.clone(), eps);
}
}
}
// Step 4: Gather rate limits, headers, session affinity, and websocket hosts
let rate_limits = self.collect_rate_limits(&routes);
let headers = self.collect_headers();
let session_affinity = self.collect_session_affinity(&routes);
let websocket_hosts = self.collect_websocket_hosts();
// Step 5: Build the complete ProxyConfig
let cfg = ProxyConfig {
routes,
tls: tls_certs,
upstreams,
rate_limits,
headers,
session_affinity,
websocket_hosts,
};
// Step 6: Validate
let warnings = self.validate_config(&cfg);
for w in &warnings {
tracing::warn!("{}", w);
}
// Step 7: Store the assembled config as the canonical snapshot
self.store.set(
"_assembled",
&serde_json::to_value(&cfg).unwrap_or_default(),
);
self.store.signal_reload();
tracing::info!(
routes = cfg.routes.len(),
tls_hosts = cfg.tls.len(),
upstreams = cfg.upstreams.len(),
warnings = warnings.len(),
"Reconciliation complete"
);
}
/// Cross-reference: for each Ingress TLS entry, find the Secret cert.
///
/// The Ingress TLS section maps: `hosts: [example.com]` → `secretName: my-cert`.
/// The Secret watcher stores the cert at key `tls:<secretName>`.
/// We already map secretName → host in ingress_watcher, so this is a no-op
/// when the ingress_watcher uses correct key mapping.
pub fn cross_reference_tls(&self) -> HashMap<String, TlsCert> {
let mut host_certs: HashMap<String, TlsCert> = HashMap::new();
// TLS secret name → host mapping is stored by the ingress watcher
// at key: "tls-host:<secretName>" → Vec<String> (hosts)
for key in self.store.keys_with_prefix("tls-host:") {
let secret_name = &key["tls-host:".len()..];
let hosts: Vec<String> = self.store.get::<Vec<String>>(&key).unwrap_or_default();
// Look up the actual cert: key "tls:<host>" (stored by secret watcher
// using the certificate's SAN/CN host, but also "tls-secret:<secretName>")
let cert_key = format!("tls-secret:{}", secret_name);
if let Some(cert) = self.store.get::<TlsCert>(&cert_key) {
for host in hosts {
host_certs.insert(host, cert.clone());
}
}
}
host_certs
}
/// Validate the assembled configuration. Returns warnings.
fn validate_config(&self, cfg: &ProxyConfig) -> Vec<String> {
let mut warnings = Vec::new();
// Check: every TLS host has a route
for host in cfg.tls.keys() {
if !cfg.routes.contains_key(host) {
warnings.push(format!(
"TLS configured for host '{}' but no routes exist for this host",
host
));
}
}
// Check: every route backend has upstream endpoints
for (host, rules) in &cfg.routes {
for rule in rules {
let backend_key =
format!("upstream:{}/{}", rule.backend.namespace, rule.backend.name);
if !cfg.upstreams.contains_key(&backend_key) {
warnings.push(format!(
"Host '{}' routes to backend {}/{}:{} but no endpoints found",
host, rule.backend.namespace, rule.backend.name, rule.backend.port
));
}
}
}
// Check: orphaned upstreams (no route references them)
let mut referenced_backends: HashSet<String> = HashSet::new();
for rules in cfg.routes.values() {
for rule in rules {
let bk = format!("upstream:{}/{}", rule.backend.namespace, rule.backend.name);
referenced_backends.insert(bk);
}
}
for upstream_key in cfg.upstreams.keys() {
if !referenced_backends.contains(upstream_key) {
warnings.push(format!(
"Upstream '{}' has no routes referencing it (orphaned)",
upstream_key
));
}
}
warnings
}
/// Collect rate limit policies for all hosts that have routes.
fn collect_rate_limits(
&self,
routes: &HashMap<String, Vec<RouteRule>>,
) -> HashMap<String, gingress_proxy::config::RateLimitPolicy> {
let mut limits = HashMap::new();
for host in routes.keys() {
let key = format!("rate_limit:{}", host);
if let Some(policy) = self.store.get(&key) {
limits.insert(host.clone(), policy);
}
}
limits
}
/// Collect header operations for all hosts.
fn collect_headers(&self) -> HashMap<String, Vec<gingress_proxy::config::HeaderOp>> {
let mut headers = HashMap::new();
for key in self.store.keys_with_prefix("headers:") {
let host = &key["headers:".len()..];
if let Some(ops) = self.store.get(&key) {
headers.insert(host.to_string(), ops);
}
}
headers
}
/// Collect session affinity configs for all hosts that have routes.
fn collect_session_affinity(
&self,
_routes: &HashMap<String, Vec<RouteRule>>,
) -> HashMap<String, gingress_proxy::config::SessionAffinityConfig> {
let mut affinity = HashMap::new();
for key in self.store.keys_with_prefix("session_affinity:") {
let host = &key["session_affinity:".len()..];
if let Some(cfg) = self.store.get(&key) {
affinity.insert(host.to_string(), cfg);
}
}
affinity
}
/// Collect WebSocket-enabled hosts.
fn collect_websocket_hosts(&self) -> Vec<String> {
self.store
.get::<Vec<String>>("websocket:hosts")
.unwrap_or_default()
}
}

View File

@ -0,0 +1,166 @@
//! Watches Kubernetes TLS Secrets and loads certificates.
//!
//! Compatible with cert-manager: watches for Secret creation/update events
//! and parses `tls.crt` and `tls.key` into the ConfigStore for TLS termination.
//!
//! Key convention:
//! - `tls-secret:<secretName>` — the parsed cert, keyed by Secret name so the
//! reconciler can cross-reference it via the `tls-host:<secretName>` mapping
//! written by the ingress watcher.
//! - `tls:<host>` — the same cert, written by this watcher for every host the
//! certificate covers, for direct SNI lookup by the proxy.
use futures::pin_mut;
use futures::StreamExt;
use gingress_proxy::config::{ConfigStore, TlsCert};
use kube::ResourceExt;
use kube::runtime::watcher::{self, Event};
use std::sync::Arc;
/// Watch Secrets of type `kubernetes.io/tls` and update the ConfigStore.
///
/// * `namespace` – when `Some` and non-empty, only Secrets in that namespace
///   are watched; otherwise the watch is cluster-wide.
/// * `on_change` – invoked after each applied or deleted Secret, and once when
///   an initial or repeated list completes, so the reconciler can
///   cross-reference certs with routes.
///
/// The function only returns when the watcher stream ends; watcher errors are
/// logged and the loop continues.
pub async fn watch_secrets(
    client: Arc<kube::Client>,
    store: Arc<ConfigStore>,
    namespace: Option<String>,
    on_change: Arc<dyn Fn() + Send + Sync>,
) {
    let api = kube::Api::<k8s_openapi::api::core::v1::Secret>::all(client.as_ref().clone());
    // Field selectors are comma-separated and AND-ed by the API server, so the
    // namespace restriction is appended to the type filter.
    let field_selector = match namespace.as_deref().filter(|ns| !ns.is_empty()) {
        Some(ns) => format!("type=kubernetes.io/tls,metadata.namespace={}", ns),
        None => "type=kubernetes.io/tls".to_string(),
    };
    let config = watcher::Config {
        field_selector: Some(field_selector),
        ..Default::default()
    };
    let secret_watcher = watcher::watcher(api, config);
    pin_mut!(secret_watcher);
    while let Some(event) = secret_watcher.next().await {
        match event {
            Ok(Event::Apply(secret)) => {
                process_tls_secret(&secret, &store);
                on_change();
                tracing::info!(
                    name = %secret.name_any(),
                    namespace = %secret.namespace().unwrap_or_default(),
                    "TLS Secret applied"
                );
            }
            Ok(Event::Init) => {
                store.remove_prefix("tls-secret:");
                tracing::info!("Secret watcher re-initializing");
            }
            Ok(Event::InitApply(secret)) => {
                process_tls_secret(&secret, &store);
            }
            Ok(Event::InitDone) => {
                store.signal_reload();
                on_change();
                tracing::info!("Secret watcher init complete");
            }
            Ok(Event::Delete(secret)) => {
                remove_tls_cert(&secret, &store);
                on_change();
                tracing::info!(
                    name = %secret.name_any(),
                    "TLS Secret deleted, cert removed"
                );
            }
            Err(e) => {
                tracing::error!("Secret watcher error: {}", e);
            }
        }
    }
}
/// Parse a TLS secret and store the certificate.
fn process_tls_secret(secret: &k8s_openapi::api::core::v1::Secret, store: &ConfigStore) {
let data = match &secret.data {
Some(d) => d,
None => return,
};
let cert_pem = match data
.get("tls.crt")
.and_then(|v| std::str::from_utf8(&v.0).ok())
{
Some(v) => v.to_string(),
None => {
tracing::warn!(name = %secret.name_any(), "TLS Secret missing tls.crt");
return;
}
};
let key_pem = match data
.get("tls.key")
.and_then(|v| std::str::from_utf8(&v.0).ok())
{
Some(v) => v.to_string(),
None => {
tracing::warn!(name = %secret.name_any(), "TLS Secret missing tls.key");
return;
}
};
let secret_name = secret.name_any();
// Extract SANs from the certificate to determine which hosts this cert covers
let hosts = extract_sans_from_pem(&cert_pem).unwrap_or_else(|| vec![secret_name.clone()]);
let tls_cert = TlsCert {
host: hosts.first().cloned().unwrap_or(secret_name.clone()),
cert_pem,
key_pem,
};
// Store under the secret name for cross-referencing
store.set(&format!("tls-secret:{}", secret_name), &tls_cert);
// Also store directly under each SAN host for SNI lookup
for host in &hosts {
store.set(&format!("tls:{}", host), &tls_cert);
}
store.signal_reload();
}
/// Extract the host names a PEM certificate covers.
///
/// Only the first certificate in `pem_data` is inspected. Its DNS Subject
/// Alternative Names are returned; when there are none, the subject Common
/// Name is used instead, provided it is a readable, non-empty string.
///
/// Returns `None` when the PEM or the certificate cannot be parsed, or when
/// no host name could be determined.
fn extract_sans_from_pem(pem_data: &str) -> Option<Vec<String>> {
    use x509_parser::prelude::*;

    let mut reader = std::io::BufReader::new(pem_data.as_bytes());
    let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
        .collect::<Result<Vec<_>, _>>()
        .ok()?;
    let cert_der = certs.first()?;
    let (_, cert) = X509Certificate::from_der(cert_der).ok()?;

    let mut hosts: Vec<String> = Vec::new();
    if let Ok(Some(san)) = cert.subject_alternative_name() {
        for name in &san.value.general_names {
            if let GeneralName::DNSName(dns) = name {
                hosts.push(dns.to_string());
            }
        }
    }

    // Fallback: use the CN. An unreadable or empty CN is ignored rather than
    // recorded as an empty host name, so the caller's own fallback applies.
    if hosts.is_empty() {
        if let Some(cn) = cert.subject().iter_common_name().next() {
            match cn.as_str() {
                Ok(cn) if !cn.is_empty() => hosts.push(cn.to_string()),
                _ => {}
            }
        }
    }

    if hosts.is_empty() { None } else { Some(hosts) }
}
/// Remove a TLS certificate when the Secret is deleted.
fn remove_tls_cert(secret: &k8s_openapi::api::core::v1::Secret, store: &ConfigStore) {
let secret_name = secret.name_any();
store.remove(&format!("tls-secret:{}", secret_name));
// Also clean up tls-host mapping
store.remove(&format!("tls-host:{}", secret_name));
store.signal_reload();
}

174
apps/gingress/src/main.rs Normal file
View File

@ -0,0 +1,174 @@
//! GIngress — Kubernetes Ingress Controller
//!
//! Control plane that watches Kubernetes resources (Ingress, Endpoints, and TLS
//! Secrets) and updates the shared `ConfigStore` for the data plane.
//!
//! Architecture:
//! - Watches Ingress resources → builds routing rules
//! - Watches TLS Secrets → loads certificates
//! - Watches Endpoints → tracks upstream IPs
//! - Reconciler → diffs changes and pushes to ConfigStore + signals reload
mod controller;
use clap::Parser;
use gingress_proxy::config::ConfigStore;
use gingress_proxy::hot_reload;
use gingress_proxy::observability;
use gingress_proxy::server::{self, GIngressProxy};
// Command-line arguments of the `gingress` binary, parsed with clap's derive
// API in `main`.
//
// NOTE: the `///` comments on the fields below are rendered by clap as the
// `--help` text of each flag, so they are user-facing strings; reword them
// only when the help output is meant to change.
#[derive(Parser)]
#[command(name = "gingress")]
struct Args {
    /// Ingress class name to watch (default: "gingress")
    #[arg(long, default_value = "gingress")]
    ingress_class: String,
    /// Kubernetes namespace to watch (empty = all namespaces)
    // Passed unchanged to `controller::start`.
    #[arg(long)]
    namespace: Option<String>,
    /// HTTP bind address for the proxy
    #[arg(long, default_value = "0.0.0.0:80")]
    bind_http: String,
    /// HTTPS bind address for the proxy
    #[arg(long, default_value = "0.0.0.0:443")]
    bind_https: String,
    /// Metrics bind address
    #[arg(long, default_value = "0.0.0.0:8080")]
    metrics_bind: String,
    /// Log level
    // Handed to `observability::init_tracing`.
    #[arg(long, default_value = "info")]
    log_level: String,
    /// OTLP endpoint (optional)
    // When set, `main` also calls `observability::init_otlp` with it.
    #[arg(long)]
    otlp_endpoint: Option<String>,
}
/// Entry point: wires the control plane (Kubernetes watchers) and the data
/// plane (proxy) together around one shared `ConfigStore`, then waits for
/// Ctrl-C and shuts everything down.
///
/// Start-up order: tracing/OTLP, controller, metrics server, hot-reload
/// watcher, proxy server. On shutdown the controller, reload and metrics tasks
/// are aborted and the proxy task is given up to five seconds to finish.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    // Tracing first, so everything below can log.
    let otlp_enabled = args.otlp_endpoint.is_some();
    observability::init_tracing(&args.log_level, otlp_enabled);

    // The guard is kept in a named binding so it lives until `main` returns.
    let _otel_guard = match args.otlp_endpoint {
        Some(ref endpoint) => Some(observability::init_otlp(endpoint, "gingress")?),
        None => None,
    };

    tracing::info!(
        ingress_class = %args.ingress_class,
        bind_http = %args.bind_http,
        bind_https = %args.bind_https,
        "GIngress starting"
    );

    // One store, shared between control plane and data plane.
    let config_store = ConfigStore::new();

    // Control plane: Kubernetes watchers.
    let controller_handle = controller::start(
        config_store.clone(),
        args.ingress_class.clone(),
        args.namespace.clone(),
    )
    .await?;
    tracing::info!("Kubernetes controller started");

    // Metrics endpoint for Prometheus scraping.
    let metrics_handle = spawn_metrics_server(&args.metrics_bind).await?;
    tracing::info!(bind = %args.metrics_bind, "Metrics server started");

    // Data plane: the proxy itself.
    let proxy = GIngressProxy::new(config_store.clone());

    // Hot-reload callback: reads the snapshot the reconciler stored under
    // "_assembled" and reports what it contains.
    let reload_handle = hot_reload::spawn_reload_watcher(config_store.clone(), move |store| {
        let Some(config_json) = store.get::<serde_json::Value>("_assembled") else {
            tracing::warn!("Hot-reload: no assembled config found (_assembled key missing)");
            return;
        };
        let cfg = match serde_json::from_value::<gingress_proxy::config::ProxyConfig>(config_json)
        {
            Ok(cfg) => cfg,
            Err(_) => {
                tracing::error!("Hot-reload: failed to deserialize assembled ProxyConfig");
                return;
            }
        };
        tracing::info!(
            routes = cfg.routes.len(),
            tls_hosts = cfg.tls.len(),
            upstreams = cfg.upstreams.len(),
            "Hot-reload: new proxy configuration applied"
        );
        // TLS certificates contained in the snapshot.
        for cert in cfg.tls.values() {
            tracing::debug!(
                host = %cert.host,
                "Hot-reload: TLS cert loaded for host"
            );
        }
        // Routes contained in the snapshot.
        for (host, rules) in &cfg.routes {
            tracing::debug!(
                host = %host,
                num_rules = rules.len(),
                "Hot-reload: routes configured"
            );
        }
    });

    // Build the proxy server, then run it on the blocking thread pool.
    let server = server::build_server(proxy, &args.bind_http, &args.bind_https)?;
    tracing::info!(
        "GIngress proxy starting, listening on {} (HTTP) and {} (HTTPS)",
        args.bind_http,
        args.bind_https
    );
    let proxy_handle = tokio::task::spawn_blocking(move || {
        server::run_server(server);
    });

    // Block until Ctrl-C.
    tokio::signal::ctrl_c().await?;
    tracing::info!("Shutdown signal received, stopping...");

    controller_handle.abort();
    reload_handle.abort();
    metrics_handle.abort();
    // Bounded wait for the proxy task; its outcome is deliberately ignored.
    let proxy_grace = std::time::Duration::from_secs(5);
    let _ = tokio::time::timeout(proxy_grace, proxy_handle).await;

    tracing::info!("GIngress stopped");
    Ok(())
}
/// Bind the metrics address and spawn the metrics task.
///
/// NOTE(review): this is currently a placeholder. The listener is bound but
/// no connection is ever accepted; the spawned task only logs
/// "Metrics server stopped" and finishes.
///
/// # Errors
///
/// Fails when `bind` cannot be bound.
async fn spawn_metrics_server(bind: &str) -> anyhow::Result<tokio::task::JoinHandle<()>> {
    let addr = bind.to_string();
    let listener = std::net::TcpListener::bind(&addr)?;
    Ok(tokio::spawn(async move {
        // Serve metrics via a minimal HTTP handler
        // Uses the prometheus_exporter from observability
        let _ = listener;
        tracing::info!(bind = %addr, "Metrics server stopped");
    }))
}

View File

@ -30,3 +30,6 @@ metrics = "0.22"
metrics-exporter-prometheus = "0.13" metrics-exporter-prometheus = "0.13"
chrono = { workspace = true, features = ["serde"] } chrono = { workspace = true, features = ["serde"] }
reqwest = { workspace = true } reqwest = { workspace = true }
agent = { workspace = true }
models = { workspace = true }
async-trait = { workspace = true }

View File

@ -3,9 +3,10 @@ use config::AppConfig;
use db::cache::AppCache; use db::cache::AppCache;
use db::database::AppDatabase; use db::database::AppDatabase;
use git::hook::HookService; use git::hook::HookService;
use git::hook::embed::TagEmbedder;
use metrics::{describe_counter, Unit}; use metrics::{describe_counter, Unit};
use metrics_exporter_prometheus::PrometheusHandle; use metrics_exporter_prometheus::PrometheusHandle;
use observability::{init_tracing_subscriber, install_recorder}; use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
use sea_orm::ConnectionTrait; use sea_orm::ConnectionTrait;
use std::sync::Arc; use std::sync::Arc;
use tokio::signal; use tokio::signal;
@ -14,6 +15,39 @@ mod args;
use args::HookArgs; use args::HookArgs;
/// Build the `EmbedService` used for tag embedding from the application config.
///
/// Uses `"text-embedding-3-small"` as the model name and `1536` as the dimension count
/// when the corresponding config getters do not yield a value. The outcome of
/// `ensure_collections` is intentionally discarded, so a failure there does not prevent
/// the service from being returned (graceful degradation).
///
/// # Errors
/// Propagates the error returned by `agent::new_embed_client`.
async fn init_embed_service(
    cfg: &AppConfig,
    db: &AppDatabase,
) -> Result<agent::embed::EmbedService, Box<dyn std::error::Error + Send + Sync>> {
    let embed_client = agent::new_embed_client(cfg).await?;

    let model = cfg
        .get_embed_model_name()
        .unwrap_or_else(|_| "text-embedding-3-small".into());
    let dims = cfg.get_embed_model_dimensions().unwrap_or(1536);

    let service =
        agent::embed::EmbedService::new(embed_client, db.writer().clone(), model, dims);

    // Best effort: the result is ignored on purpose.
    let _ = service.ensure_collections().await;

    tracing::info!("hook worker: EmbedService initialized for tag embedding");
    Ok(service)
}
/// Adapter that wraps agent's `EmbedService` so it can be used where git's
/// `TagEmbedder` trait is expected.
///
/// The single tuple field is the wrapped service; calls are delegated to it.
struct EmbedServiceAdapter(agent::embed::EmbedService);
#[async_trait::async_trait]
impl TagEmbedder for EmbedServiceAdapter {
    /// Embed a batch of tags by delegating to the wrapped `EmbedService`.
    ///
    /// Each `models::TagEmbedInput` is copied field-by-field into
    /// `agent::embed::TagEmbedInput` (same fields, different module path). Any error
    /// from the wrapped service is boxed and returned.
    async fn embed_tags_batch(
        &self,
        tags: Vec<models::TagEmbedInput>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let mut converted: Vec<agent::embed::TagEmbedInput> = Vec::with_capacity(tags.len());
        for tag in tags {
            converted.push(agent::embed::TagEmbedInput {
                repo_id: tag.repo_id,
                repo_name: tag.repo_name,
                project_id: tag.project_id,
                name: tag.name,
                description: tag.description,
            });
        }

        self.0
            .embed_tags_batch(converted)
            .await
            .map_err(|err| -> Box<dyn std::error::Error + Send + Sync> { Box::new(err) })
    }
}
async fn http_handler( async fn http_handler(
db: Arc<AppDatabase>, db: Arc<AppDatabase>,
cache: Arc<AppCache>, cache: Arc<AppCache>,
@ -89,6 +123,14 @@ async fn main() -> anyhow::Result<()> {
describe_counter!("hook_sync_tags_changed_total", Unit::Count, "Tags changed during sync"); describe_counter!("hook_sync_tags_changed_total", Unit::Count, "Tags changed during sync");
let metrics_handle = Arc::new(install_recorder()); let metrics_handle = Arc::new(install_recorder());
let http_metrics = Arc::new(HttpMetrics::new()); // Worker app — HTTP section will be empty
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
let pusher = MetricsPusher::new(&push_url, "git-hook");
pusher.spawn(http_metrics.clone(), metrics_handle.clone(), std::time::Duration::from_secs(15));
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
}
let db = Arc::new(AppDatabase::init(&cfg).await?); let db = Arc::new(AppDatabase::init(&cfg).await?);
tracing::info!("database connected"); tracing::info!("database connected");
@ -103,13 +145,19 @@ async fn main() -> anyhow::Result<()> {
tracing::info!("git-hook worker starting"); tracing::info!("git-hook worker starting");
// 6. Build and start git hook service // 6. Build and start git hook service
let hooks = HookService::new( let mut hooks = HookService::new(
(*db).clone(), (*db).clone(),
(*cache).clone(), (*cache).clone(),
cache.redis_pool().clone(), cache.redis_pool().clone(),
cfg, cfg.clone(),
); );
// Optionally initialize tag embedding
if let Ok(embed_svc) = init_embed_service(&cfg, &db).await {
let adapter = EmbedServiceAdapter(embed_svc);
hooks = hooks.with_tag_embedder(Arc::new(adapter));
}
let cancel = hooks.start_worker().await; let cancel = hooks.start_worker().await;
let cancel_signal = cancel.clone(); let cancel_signal = cancel.clone();

View File

@ -1,6 +1,7 @@
use clap::Parser; use clap::Parser;
use config::AppConfig; use config::AppConfig;
use observability::init_tracing_subscriber; use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
use std::sync::Arc;
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(name = "gitserver")] #[command(name = "gitserver")]
@ -16,6 +17,16 @@ async fn main() -> anyhow::Result<()> {
let cfg = AppConfig::load(); let cfg = AppConfig::load();
init_tracing_subscriber(&args.log_level, false); init_tracing_subscriber(&args.log_level, false);
let prometheus_handle = Arc::new(install_recorder());
let http_metrics = Arc::new(HttpMetrics::new());
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
let pusher = MetricsPusher::new(&push_url, "gitserver");
pusher.spawn(http_metrics.clone(), prometheus_handle.clone(), std::time::Duration::from_secs(15));
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
}
let http_handle = tokio::spawn(git::http::run_http(cfg.clone())); let http_handle = tokio::spawn(git::http::run_http(cfg.clone()));
let ssh_handle = tokio::spawn(git::ssh::run_ssh(cfg)); let ssh_handle = tokio::spawn(git::ssh::run_ssh(cfg));

58
apps/metrics/Cargo.toml Normal file
View File

@ -0,0 +1,58 @@
[package]
name = "metrics-aggregator"
version.workspace = true
edition.workspace = true
authors.workspace = true
description = "Unified observability aggregator: scrapes metrics, forwards traces, collects logs"
repository.workspace = true
readme.workspace = true
homepage.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
documentation.workspace = true
[[bin]]
name = "metrics-aggregator"
path = "src/main.rs"
[dependencies]
tokio = { workspace = true, features = ["full"] }
config = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
observability = { workspace = true }
anyhow = { workspace = true }
clap = { workspace = true, features = ["derive", "env"] }
serde_json = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
serde = { workspace = true, features = ["derive"] }
# HTTP server
actix-web = "4.13.0"
actix-rt = "2.11.0"
# HTTP client for scraping (uses awc = actix-web client, no extra TLS deps)
awc = { workspace = true }
# HTTP client for Loki (reqwest is Send+Sync, unlike awc::Client)
reqwest = { workspace = true, features = ["json"] }
# Metrics
metrics = { workspace = true }
metrics-exporter-prometheus = { version = "0.18", default-features = false, features = ["http-listener", "tokio"] }
# Observability
opentelemetry = { workspace = true }
opentelemetry_sdk = { workspace = true }
opentelemetry-otlp = { version = "0.31.0", default-features = false, features = ["http-proto", "tokio", "trace", "tonic"] }
tracing-opentelemetry = "0.32.1"
tokio-util = { workspace = true }
tokio-stream = { workspace = true }
futures = { workspace = true }
url = { workspace = true }
tower = { workspace = true }
[lints]
workspace = true

35
apps/metrics/src/args.rs Normal file
View File

@ -0,0 +1,35 @@
use clap::Parser;
// Command-line / environment configuration of the metrics aggregator.
// NOTE: the added notes are plain `//` comments on purpose — clap's derive turns `///`
// doc comments into help text, which would change the generated `--help` output.
#[derive(Parser, Debug)]
#[command(name = "metrics-aggregator")]
#[command(version)]
pub struct Args {
// TCP port; defaults to 9090, overridable via METRICS_AGGREGATOR_PORT.
#[arg(long, default_value = "9090", env = "METRICS_AGGREGATOR_PORT")]
pub port: u16,
// OTLP endpoint for trace export; optional.
#[arg(long, env = "OTEL_EXPORTER_OTLP_ENDPOINT")]
pub otel_endpoint: Option<String>,
// Loki push URL for log forwarding; optional.
#[arg(long, env = "LOKI_URL")]
pub loki_url: Option<String>,
// Scrape interval in seconds; defaults to 15.
#[arg(long, default_value = "15", env = "SCRAPE_INTERVAL_SECS")]
pub scrape_interval_secs: u64,
/// JSON file with scrape targets.
#[arg(long, env = "SCRAPE_TARGETS_FILE")]
pub targets_file: Option<String>,
// Log level string; defaults to "info".
#[arg(long, default_value = "info", env = "LOG_LEVEL")]
pub log_level: String,
/// Comma-separated list of app names to scrape.
#[arg(long, env = "SCRAPE_APPS")]
pub scrape_apps: Option<String>,
// Flag-only switch (no env binding) to turn OTLP off.
#[arg(long)]
pub no_otel: bool,
// Flag-only switch (no env binding) to turn Loki forwarding off.
#[arg(long)]
pub no_loki: bool,
}

View File

@ -0,0 +1,40 @@
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::target::{load_targets_from_file, ScrapeTarget};
pub async fn watch_targets_file(
path: String,
targets: Arc<RwLock<Vec<ScrapeTarget>>>,
mut shutdown: tokio::sync::broadcast::Receiver<()>,
) {
let mtime_path = path;
let mut last_mtime: Option<std::time::SystemTime> = None;
loop {
tokio::select! {
_ = shutdown.recv() => break,
_ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
let metadata = match tokio::fs::metadata(&mtime_path).await {
Ok(m) => m,
Err(_) => continue,
};
let current_mtime = metadata.modified().ok();
if current_mtime != last_mtime {
last_mtime = current_mtime;
match load_targets_from_file(&mtime_path).await {
Ok(new_targets) => {
let mut guard = targets.write().await;
*guard = new_targets;
tracing::info!(path = %mtime_path, "targets file reloaded");
}
Err(e) => {
tracing::warn!(error = %e, "failed to reload targets file");
}
}
}
}
}
}
}

View File

@ -0,0 +1,67 @@
use std::time::Duration;
use awc::Client;
use crate::target::ScrapeTarget;
pub async fn k8s_pod_discovery() -> Option<Vec<ScrapeTarget>> {
let pod_namespace = std::env::var("POD_NAMESPACE").ok()?;
let token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token";
let token = tokio::fs::read_to_string(token_path).await.ok()?;
let client = Client::builder()
.timeout(Duration::from_secs(5))
.add_default_header((awc::http::header::AUTHORIZATION.as_str(), format!("Bearer {}", token)))
.finish();
let api_url = format!(
"https://kubernetes.default.svc/api/v1/namespaces/{}/pods",
pod_namespace
);
let mut response = client.get(api_url).send().await.ok()?;
let body_bytes = response.body().await.ok()?;
let pod_list: serde_json::Value = serde_json::from_slice(&body_bytes).ok()?;
let targets: Vec<ScrapeTarget> = pod_list["items"]
.as_array()?
.iter()
.filter_map(|pod| {
let name = pod["metadata"]["name"].as_str()?.to_string();
let phase = pod["status"]["phase"].as_str()?;
if phase != "Running" {
return None;
}
let pod_ip = pod["status"]["podIP"].as_str()?;
let annotations = pod["metadata"]["annotations"].as_object()?;
let port: u16 = annotations
.get("metrics.port")
.and_then(|v| v.as_str())
.and_then(|s| s.parse().ok())
.unwrap_or(8080);
let path = annotations
.get("metrics.path")
.and_then(|v| v.as_str())
.unwrap_or("/metrics");
let labels = pod["metadata"]["labels"]
.as_object()
.map(|m| {
m.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
})
.unwrap_or_default();
Some(ScrapeTarget {
name,
addr: format!("{}:{}", pod_ip, port),
metrics_path: path.to_string(),
labels,
})
})
.collect();
Some(targets)
}

69
apps/metrics/src/loki.rs Normal file
View File

@ -0,0 +1,69 @@
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::Serialize;
use reqwest::Client;
/// Client that pushes batches of log lines to a Loki push endpoint over HTTP.
#[derive(Clone)]
pub struct LokiForwarder {
/// URL that `push` POSTs to.
url: String,
/// HTTP client used for every push.
client: Client,
/// Stream labels sent with every pushed batch.
labels: HashMap<String, String>,
}
impl LokiForwarder {
    /// Create a forwarder that pushes to `url` with a 5 second request timeout.
    ///
    /// Every batch is sent under the stream label `job="metrics-aggregator"`; Loki
    /// rejects streams that carry no labels at all.
    ///
    /// # Panics
    /// Panics if the HTTP client cannot be constructed.
    pub fn new(url: String) -> Self {
        let labels = HashMap::from([("job".to_string(), "metrics-aggregator".to_string())]);
        Self {
            url,
            client: Client::builder()
                .timeout(std::time::Duration::from_secs(5))
                .build()
                .expect("valid reqwest client"),
            labels,
        }
    }

    /// Push `log_entries` to Loki as a single stream.
    ///
    /// An empty batch is a no-op. Each entry's timestamp is encoded as a string of
    /// nanoseconds since the Unix epoch, which is the format the Loki push API expects.
    ///
    /// # Errors
    /// Fails when the request cannot be sent or Loki answers with a non-success status.
    pub async fn push(&self, log_entries: Vec<LokiEntry>) -> anyhow::Result<()> {
        if log_entries.is_empty() {
            return Ok(());
        }

        let values: Vec<(String, String)> = log_entries
            .into_iter()
            .map(|entry| {
                // Computed in i128 so seconds * 1e9 cannot overflow.
                let nanos = i128::from(entry.timestamp.timestamp()) * 1_000_000_000
                    + i128::from(entry.timestamp.timestamp_subsec_nanos());
                (nanos.to_string(), entry.line)
            })
            .collect();

        let payload = LokiPayload {
            streams: vec![LokiStream {
                stream: self.labels.clone(),
                values,
            }],
        };

        let resp = self
            .client
            .post(&self.url)
            .header("Content-Type", "application/json")
            .json(&payload)
            .send()
            .await;

        match resp {
            Ok(r) if r.status().is_success() => Ok(()),
            Ok(r) => anyhow::bail!("Loki push failed: {}", r.status()),
            Err(e) => anyhow::bail!("Loki push error: {}", e),
        }
    }
}
/// JSON body of a push request: a list of streams.
#[derive(Serialize)]
struct LokiPayload {
/// Streams contained in this request.
streams: Vec<LokiStream>,
}
/// One stream inside a push request: a label set plus its log values.
#[derive(Serialize)]
struct LokiStream {
/// Label name → label value for this stream.
stream: HashMap<String, String>,
/// `(timestamp, line)` pairs; each tuple serializes as a two-element JSON array.
values: Vec<(String, String)>,
}
/// A single log line handed to `LokiForwarder::push`.
pub struct LokiEntry {
/// Time associated with the log line (UTC).
pub timestamp: DateTime<Utc>,
/// The log text.
pub line: String,
}

569
apps/metrics/src/main.rs Normal file
View File

@ -0,0 +1,569 @@
//! Unified observability aggregator for in-cluster deployment.
//!
//! Collects metrics from all app pods via Prometheus scrape, forwards traces
//! to OTLP endpoint, and streams logs from all pods to Loki-compatible backend.
//!
//! Usage:
//! METRICS_AGGREGATOR_PORT=9090 \
//! OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4317 \
//! LOKI_URL=http://loki:3100/loki/api/v1/push \
//! SCRAPE_INTERVAL_SECS=15 \
//! SCRAPE_TARGETS_FILE=/etc/metrics/targets.json \
//! metrics-aggregator
mod args;
mod hotreload;
mod k8s_discovery;
mod loki;
mod metrics;
mod otel;
mod scrape;
mod stats_store;
mod target;
use serde::Deserialize;
use std::collections::HashMap;
use std::fmt::Write as _;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use actix_web::{web, HttpResponse, HttpServer};
use clap::Parser;
use loki::{LokiEntry, LokiForwarder};
use metrics::AggMetrics;
use observability::{init_tracing_subscriber, install_recorder, instance_id};
use otel::OtelGuard;
use scrape::{HttpClient, ScrapeResult};
use stats_store::StatsStore;
use target::ScrapeTarget;
use tokio::io::AsyncBufReadExt;
use tokio::sync::{broadcast, RwLock};
use tokio::time::interval;
type MetricsStore = Arc<RwLock<HashMap<String, Vec<scrape::PromMetric>>>>;
// StatsStore is defined in stats_store.rs — per-app aggregated data.
/// Entry point of the metrics aggregator.
///
/// Wires everything together in this order: parse arguments and set up tracing,
/// install the metrics recorder, create the shared stores, start the background
/// tasks (push-entry eviction, optional targets-file watcher, scrape loop, stdin log
/// collector), start the HTTP server, then wait for Ctrl+C and shut down.
#[actix_web::main]
async fn main() -> std::io::Result<()> {
let args = args::Args::parse();
init_tracing_subscriber(&args.log_level, false);
let instance = instance_id();
tracing::info!(
instance = %instance,
port = args.port,
scrape_interval = args.scrape_interval_secs,
"metrics-aggregator starting"
);
let prometheus_handle = install_recorder();
metrics::init();
let metrics = AggMetrics::new();
// Shared state: scraped samples per target, pushed stats per app, and the target list.
let store: MetricsStore = Arc::new(RwLock::new(HashMap::new()));
let stats_store: StatsStore = Arc::new(RwLock::new(HashMap::new()));
let targets: Arc<RwLock<Vec<ScrapeTarget>>> = Arc::new(RwLock::new(Vec::new()));
let http = HttpClient::new(10);
let otel_guard = init_otel_from_args(&args);
let loki = init_loki_from_args(&args);
// Every background task subscribes to this channel and stops when a message arrives.
let (shutdown_tx, _) = broadcast::channel::<()>(4);
// Background task: evict push entries older than 5 minutes.
let stats_store_for_evict = stats_store.clone();
let mut evict_shutdown = shutdown_tx.subscribe();
tokio::spawn(async move {
let mut ticker = interval(Duration::from_secs(30));
loop {
tokio::select! {
_ = evict_shutdown.recv() => break,
_ = ticker.tick() => {
// 300 seconds = the 5 minute retention window.
let cutoff = chrono::Utc::now().timestamp() - 300;
let mut guard = stats_store_for_evict.write().await;
guard.retain(|_, entry| entry.last_seen >= cutoff);
}
}
}
});
// Target source: a targets file (with a reload watcher) takes precedence; otherwise
// a one-time Kubernetes pod discovery is attempted when running inside a cluster.
if let Some(path) = &args.targets_file {
match target::load_targets_from_file(path).await {
Ok(initial_targets) => {
let mut guard = targets.write().await;
*guard = initial_targets;
tracing::info!(count = guard.len(), "loaded initial targets from file");
}
Err(e) => {
tracing::warn!(error = %e, "failed to load targets file");
}
}
let tw =
hotreload::watch_targets_file(path.clone(), targets.clone(), shutdown_tx.subscribe());
tokio::spawn(tw);
} else if std::env::var("KUBERNETES_SERVICE_HOST").is_ok() {
if let Some(k8s_targets) = k8s_discovery::k8s_pod_discovery().await {
let mut guard = targets.write().await;
*guard = k8s_targets.clone();
tracing::info!(count = guard.len(), "discovered K8s pods as targets");
}
}
// Optional allow-list of target names, parsed from the comma-separated argument.
let scrape_filter = args
.scrape_apps
.as_ref()
.map(|s| s.split(',').map(|p| p.trim().to_string()).collect());
let scrape_targets = targets.clone();
let scrape_store = store.clone();
let scrape_metrics = metrics.clone();
let scrape_http = http.clone();
let loki_clone = loki.clone();
let shutdown_tx_clone = shutdown_tx.clone();
let scrape_interval = args.scrape_interval_secs;
let scrape_filter_clone = scrape_filter.clone();
// The scrape loop is spawned with `spawn_local`, i.e. on the current thread's local task set.
tokio::task::spawn_local(async move {
scrape_loop(
scrape_targets,
scrape_store,
scrape_metrics,
scrape_http,
scrape_interval,
scrape_filter_clone,
loki_clone,
shutdown_tx_clone.subscribe(),
)
.await;
});
// Stdin log collector, also on the local task set.
let log_shutdown = shutdown_tx.subscribe();
let log_loki = loki.clone();
tokio::task::spawn_local(async move {
log_collector(log_loki, log_shutdown).await;
});
// HTTP server: listens on all interfaces at the configured port.
let bind_addr: SocketAddr = ([0, 0, 0, 0], args.port).into();
tracing::info!(addr = %bind_addr, "HTTP server starting");
let app_targets = targets.clone();
let app_store = store.clone();
let app_handle = prometheus_handle.clone();
let loki_for_push: Option<Arc<LokiForwarder>> = loki.map(Arc::new);
let app_stats = stats_store.clone();
let server = HttpServer::new(move || {
// The factory closure clones the shared handles for each App it builds.
let targets = app_targets.clone();
let store = app_store.clone();
let handle = app_handle.clone();
let stats_store = app_stats.clone();
let loki_for_push: Option<Arc<LokiForwarder>> = loki_for_push.clone();
actix_web::App::new()
.app_data(web::Data::new(targets))
.app_data(web::Data::new(store))
.app_data(web::Data::new(handle))
.app_data(web::Data::new(stats_store))
.app_data(web::Data::new(loki_for_push))
.route("/metrics", web::get().to(handle_metrics))
.route("/api/v1/metrics", web::get().to(handle_metrics))
.route("/api/v1/push", web::post().to(handle_push))
.route("/api/v1/dashboard", web::get().to(handle_dashboard))
.route("/api/v1/stats", web::get().to(handle_stats))
.route("/health", web::get().to(handle_health))
.route("/api/v1/health", web::get().to(handle_health))
.route("/api/v1/targets", web::get().to(handle_targets))
})
.bind(&bind_addr)?
.run();
let server_handle = server.handle();
tokio::spawn(server);
// Block until Ctrl+C, then signal the background tasks and stop the server gracefully.
tokio::signal::ctrl_c().await.ok();
tracing::info!("received Ctrl+C, shutting down");
let _ = shutdown_tx.send(());
server_handle.stop(true).await;
if let Some(guard) = otel_guard {
guard.shutdown().await;
}
tracing::info!("metrics-aggregator stopped");
Ok(())
}
/// Set up OTLP trace export according to the command-line arguments.
///
/// Returns `None` when `--no-otel` is set, when no endpoint is available (neither the
/// argument nor the `OTEL_EXPORTER_OTLP_ENDPOINT` environment variable), or when
/// `otel::init_otel` fails — in the last case a warning is logged and the process
/// continues without traces.
fn init_otel_from_args(args: &args::Args) -> Option<OtelGuard> {
    if args.no_otel {
        return None;
    }

    let endpoint = match &args.otel_endpoint {
        Some(explicit) => explicit.clone(),
        None => std::env::var("OTEL_EXPORTER_OTLP_ENDPOINT").ok()?,
    };

    otel::init_otel(&endpoint, "metrics-aggregator")
        .map(|guard| {
            tracing::info!(endpoint = %endpoint, "OTLP tracing enabled");
            guard
        })
        .map_err(|e| {
            tracing::warn!(error = %e, "OTLP init failed, continuing without traces");
        })
        .ok()
}
/// Build the Loki forwarder according to the command-line arguments.
///
/// Returns `None` when `--no-loki` is set or when no URL is available (neither the
/// argument nor the `LOKI_URL` environment variable).
fn init_loki_from_args(args: &args::Args) -> Option<LokiForwarder> {
    if args.no_loki {
        return None;
    }

    let url = match &args.loki_url {
        Some(explicit) => explicit.clone(),
        None => std::env::var("LOKI_URL").ok()?,
    };

    tracing::info!("Loki log forwarding enabled");
    Some(LokiForwarder::new(url))
}
/// `GET /metrics` handler: the aggregator's own metrics, followed by the scraped
/// samples and then the pushed per-app metrics, as one Prometheus text response.
///
/// Scraped samples are tagged with `aggregator_instance="default"`.
async fn handle_metrics(
    store: web::Data<MetricsStore>,
    stats_store: web::Data<StatsStore>,
    handle: web::Data<observability::PrometheusHandle>,
) -> HttpResponse {
    let group_labels = vec![("aggregator_instance".to_string(), "default".to_string())];
    let scraped = render_aggregated_metrics(store, group_labels).await;
    let pushed = render_pushed_metrics(stats_store).await;

    let mut body = handle.render();
    body.push_str(&scraped);
    body.push_str(&pushed);

    HttpResponse::Ok()
        .content_type("text/plain; version=0.0.4; charset=utf-8")
        .body(body)
}
/// Health-check handler: always answers `200 OK` with a fixed JSON body.
async fn handle_health() -> HttpResponse {
    const BODY: &str = r#"{"status":"ok"}"#;

    let mut response = HttpResponse::Ok();
    response.content_type("application/json");
    response.body(BODY)
}
/// `GET /api/v1/targets` handler: the current scrape targets as JSON.
///
/// If serialization fails the body is empty (the status is still `200 OK`).
async fn handle_targets(targets: web::Data<Arc<RwLock<Vec<ScrapeTarget>>>>) -> HttpResponse {
    let body = {
        let snapshot = targets.read().await;
        serde_json::to_string(&*snapshot).unwrap_or_default()
    };

    HttpResponse::Ok()
        .content_type("application/json")
        .body(body)
}
// ── Push endpoint payload ────────────────────────────────────────────────────
/// JSON body accepted by `POST /api/v1/push`.
///
/// Field names are camelCase on the wire (e.g. `token_usage` is read from
/// `tokenUsage`). Only `app` and `timestamp` are required; every other field falls
/// back to its default when absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PushPayload {
/// Name of the reporting application; key under which its stats are merged.
app: String,
/// Instance identifier of the reporter; empty string when omitted.
#[serde(default)]
instance: String,
/// Timestamp supplied by the reporter.
// NOTE(review): unit is not visible here — presumably Unix seconds; confirm against the pusher.
timestamp: i64,
/// HTTP statistics section, if reported.
#[serde(default)]
http: Option<observability::push::HttpPayload>,
/// System statistics section, if reported.
#[serde(default)]
system: Option<observability::push::SystemPayload>,
/// Named business values.
#[serde(default)]
business: HashMap<String, f64>,
/// Token usage section, if reported.
#[serde(default)]
token_usage: Option<observability::push::TokenUsagePayload>,
/// Task statistics section, if reported.
#[serde(default)]
tasks: Option<observability::push::TaskStatsPayload>,
/// Latency snapshots keyed by name.
#[serde(default)]
latency: HashMap<String, observability::push::LatencySnapshot>,
/// Log entries; forwarded to Loki by `handle_push` when a forwarder is configured.
#[serde(default)]
logs: Vec<observability::push::LogEntry>,
}
/// `POST /api/v1/push` handler: merge a pushed payload into the stats store and
/// forward its log entries to Loki when a forwarder is configured.
///
/// Always answers `200 OK` with body `ok`; a failed Loki push is only logged.
async fn handle_push(
    stats_store: web::Data<StatsStore>,
    loki: web::Data<Option<Arc<LokiForwarder>>>,
    payload: web::Json<PushPayload>,
) -> HttpResponse {
    let app = payload.app.clone();
    stats_store::merge_push_payload(
        &stats_store,
        &app,
        &payload.instance,
        payload.timestamp,
        payload.http.as_ref(),
        payload.system.as_ref(),
        &payload.business,
        payload.token_usage.as_ref(),
        payload.tasks.as_ref(),
        &payload.latency,
        &payload.logs,
    )
    .await;

    // Forward logs to Loki if configured
    if payload.logs.is_empty() {
        return HttpResponse::Ok().body("ok");
    }

    if let Some(forwarder) = loki.as_ref() {
        let mut entries: Vec<LokiEntry> = Vec::with_capacity(payload.logs.len());
        for record in &payload.logs {
            // Timestamps that cannot be represented fall back to the current time.
            let timestamp = chrono::DateTime::from_timestamp(record.timestamp, 0)
                .unwrap_or_else(chrono::Utc::now);
            entries.push(LokiEntry {
                timestamp,
                line: format!("[{}] {}", record.level.to_lowercase(), record.message),
            });
        }

        if let Err(e) = forwarder.push(entries).await {
            tracing::warn!(error = %e, "loki push on /push failed");
        }
    }

    HttpResponse::Ok().body("ok")
}
async fn scrape_loop(
targets: Arc<RwLock<Vec<ScrapeTarget>>>,
store: MetricsStore,
metrics: AggMetrics,
http: HttpClient,
interval_secs: u64,
scrape_apps_filter: Option<Vec<String>>,
_loki: Option<LokiForwarder>,
mut shutdown: broadcast::Receiver<()>,
) {
let mut ticker = interval(Duration::from_secs(interval_secs));
loop {
tokio::select! {
_ = shutdown.recv() => break,
_ = ticker.tick() => {
let targets_snapshot = targets.read().await.clone();
let count = targets_snapshot.len() as u64;
metrics.targets_total.set(count as f64);
let mut healthy_count = 0u64;
for target in &targets_snapshot {
if let Some(ref filter) = scrape_apps_filter {
if !filter.contains(&target.name) {
continue;
}
}
metrics.scrape_total.increment(1);
match http.scrape(target).await {
ScrapeResult::Success(body, duration_ms) => {
metrics.scrape_success.increment(1);
metrics.scrape_duration.record(duration_ms);
let parsed = scrape::parse_prometheus(&body);
update_store(store.clone(), &target.name, parsed).await;
healthy_count += 1;
}
ScrapeResult::Timeout => {
metrics.scrape_failures.increment(1);
metrics.scrape_errors_timeout.increment(1);
tracing::warn!(target = %target.name, "scrape timeout");
}
ScrapeResult::ConnectionError(e) => {
metrics.scrape_failures.increment(1);
metrics.scrape_errors_connection.increment(1);
tracing::warn!(target = %target.name, error = %e, "scrape connection error");
}
ScrapeResult::HttpError(status) => {
metrics.scrape_failures.increment(1);
tracing::warn!(target = %target.name, status = status, "scrape HTTP error");
}
}
}
metrics.targets_healthy.set(healthy_count as f64);
}
}
}
}
/// Replace the stored samples of `target_name` with `metrics`.
async fn update_store(store: MetricsStore, target_name: &str, metrics: Vec<scrape::PromMetric>) {
    store
        .write()
        .await
        .insert(target_name.to_owned(), metrics);
}
/// Render all scraped samples as Prometheus text lines.
///
/// Each sample keeps its own labels and additionally receives
/// `aggregated_by="metrics-aggregator"`, `source_target="<target name>"` and every
/// pair from `extra_group_labels` (later inserts overwrite earlier labels of the same
/// name). Backslashes and double quotes in label values are escaped. No `# HELP` or
/// `# TYPE` lines are emitted.
async fn render_aggregated_metrics(
    store: web::Data<MetricsStore>,
    extra_group_labels: Vec<(String, String)>,
) -> String {
    let snapshot = store.read().await;
    let mut rendered = String::new();

    for (source, samples) in snapshot.iter() {
        for sample in samples {
            let mut labels = sample.labels.clone();
            labels.insert(
                "aggregated_by".to_string(),
                "metrics-aggregator".to_string(),
            );
            labels.insert("source_target".to_string(), source.clone());
            for (key, value) in &extra_group_labels {
                labels.insert(key.clone(), value.clone());
            }

            // The map always holds at least the two labels inserted above, so the
            // braces are always present.
            let mut label_block = String::from("{");
            for (position, (key, value)) in labels.iter().enumerate() {
                if position > 0 {
                    label_block.push(',');
                }
                let escaped = value.replace('\\', "\\\\").replace('"', "\\\"");
                let _ = write!(&mut label_block, r#"{}="{}""#, key, escaped);
            }
            label_block.push('}');

            let _ = writeln!(&mut rendered, "{}{} {}", sample.name, label_block, sample.value);
        }
    }

    rendered
}
/// Render the per-app statistics received via the push endpoint as Prometheus text.
///
/// For every app in the stats store this emits the HTTP totals (`push_http_*`),
/// per-endpoint request counts, system gauges, business counters, AI token totals and
/// per-endpoint latency values. Every line carries `app` and `aggregated_by` labels;
/// the `push_http_*` series additionally carry `push_source="true"`.
///
/// Label values are escaped (backslash and double quote) so that unusual app or
/// endpoint names cannot break the exposition format.
///
/// NOTE(review): business counter names are used verbatim as metric names; they are
/// assumed to already be valid Prometheus metric names.
async fn render_pushed_metrics(stats_store: web::Data<StatsStore>) -> String {
    /// Escape a value for use inside a double-quoted Prometheus label value.
    fn esc(value: impl std::fmt::Display) -> String {
        value
            .to_string()
            .replace('\\', "\\\\")
            .replace('"', "\\\"")
    }

    let guard = stats_store.read().await;
    let mut output = String::new();

    for (app_name, entry) in guard.iter() {
        let app = esc(app_name);
        // Shared by the system, business and AI series.
        let base_labels = format!(r#"app="{}",aggregated_by="metrics-aggregator""#, app);
        // Previously these were joined from bare tokens (`aggregated_by`,
        // `metrics-aggregator`, `push_source=true`), which is not valid label syntax.
        let push_labels = format!(r#"{},push_source="true""#, base_labels);

        let _ = writeln!(
            &mut output,
            "push_http_requests_total{{{}}} {}",
            push_labels, entry.requests_total
        );
        let _ = writeln!(
            &mut output,
            "push_http_request_duration_ms_total{{{}}} {}",
            push_labels, entry.request_duration_ms_total
        );
        let _ = writeln!(&mut output, "push_http_requests_2xx{{{}}} {}", push_labels, entry.requests_2xx);
        let _ = writeln!(&mut output, "push_http_requests_4xx{{{}}} {}", push_labels, entry.requests_4xx);
        let _ = writeln!(&mut output, "push_http_requests_5xx{{{}}} {}", push_labels, entry.requests_5xx);

        for (endpoint, &count) in &entry.endpoints {
            let sanitized = endpoint.replace([' ', '/'], "_").to_lowercase();
            let ep_labels = format!(
                r#"app="{}",endpoint="{}",aggregated_by="metrics-aggregator",push_source="true""#,
                app,
                esc(&sanitized)
            );
            let _ = writeln!(&mut output, "push_http_endpoint_requests_total{{{}}} {}", ep_labels, count);
        }

        // System metrics in Prometheus format
        let _ = writeln!(&mut output, "system_cpu_usage_percent{{{}}} {}", base_labels, entry.cpu_usage_percent);
        let _ = writeln!(&mut output, "system_memory_used_mb{{{}}} {}", base_labels, entry.memory_used_mb);
        let _ = writeln!(&mut output, "system_memory_total_mb{{{}}} {}", base_labels, entry.memory_total_mb);
        let _ = writeln!(&mut output, "system_uptime_secs{{{}}} {}", base_labels, entry.uptime_secs);

        // Business counters
        for (counter_name, value) in &entry.business {
            let _ = writeln!(&mut output, "{}{{{}}} {}", counter_name, base_labels, value);
        }

        // Token usage
        let _ = writeln!(&mut output, "ai_input_tokens_total{{{}}} {}", base_labels, entry.ai_input_tokens_total);
        let _ = writeln!(&mut output, "ai_output_tokens_total{{{}}} {}", base_labels, entry.ai_output_tokens_total);
        let _ = writeln!(&mut output, "ai_calls_total{{{}}} {}", base_labels, entry.ai_calls_total);

        // Latency per endpoint
        for (endpoint, lat) in &entry.latency {
            let lat_labels = format!(
                r#"app="{}",endpoint="{}",aggregated_by="metrics-aggregator""#,
                app,
                esc(endpoint)
            );
            let _ = writeln!(&mut output, "latency_p99_ms{{{}}} {}", lat_labels, lat.p99_ms);
            let _ = writeln!(&mut output, "latency_p90_ms{{{}}} {}", lat_labels, lat.p90_ms);
            let _ = writeln!(&mut output, "latency_p50_ms{{{}}} {}", lat_labels, lat.p50_ms);
            let _ = writeln!(&mut output, "latency_max_ms{{{}}} {}", lat_labels, lat.max_ms);
        }
    }

    output
}
// ── JSON API handlers ────────────────────────────────────────────────────────
/// `GET /api/v1/dashboard` handler: the dashboard built from the stats store, as JSON.
///
/// If serialization fails the body is empty (the status is still `200 OK`).
async fn handle_dashboard(stats_store: web::Data<StatsStore>) -> HttpResponse {
    let body = serde_json::to_string(&stats_store::build_dashboard(&stats_store).await)
        .unwrap_or_default();

    HttpResponse::Ok()
        .content_type("application/json")
        .body(body)
}
/// `GET /api/v1/stats` handler: the raw per-app stats map as JSON.
///
/// If serialization fails the body is empty (the status is still `200 OK`).
async fn handle_stats(stats_store: web::Data<StatsStore>) -> HttpResponse {
    // Returns per-app stats as JSON
    let body = {
        let snapshot = stats_store.read().await;
        serde_json::to_string(&*snapshot).unwrap_or_default()
    };

    HttpResponse::Ok()
        .content_type("application/json")
        .body(body)
}
async fn log_collector(loki: Option<LokiForwarder>, mut shutdown: broadcast::Receiver<()>) {
let stdin = tokio::io::stdin();
let mut reader = tokio::io::BufReader::new(stdin);
let mut interval_tick = interval(Duration::from_secs(1));
let mut batch: Vec<LokiEntry> = Vec::with_capacity(100);
let mut line_buf = String::new();
loop {
tokio::select! {
_ = shutdown.recv() => break,
_ = interval_tick.tick() => {
if !batch.is_empty() {
if let Some(ref loki) = loki {
if let Err(e) = loki.push(std::mem::take(&mut batch)).await {
tracing::warn!(error = %e, "Loki push failed");
}
}
}
}
_ = async { line_buf.clear(); reader.read_line(&mut line_buf).await.ok() } => {
if !line_buf.is_empty() {
let line = line_buf.trim_end().to_string();
if !line.is_empty() {
batch.push(LokiEntry {
timestamp: chrono::Utc::now(),
line,
});
if batch.len() >= 100 {
if let Some(ref loki) = loki {
if let Err(e) = loki.push(std::mem::take(&mut batch)).await {
tracing::warn!(error = %e, "Loki push failed");
}
}
}
}
}
}
}
}
}

View File

@ -0,0 +1,99 @@
use metrics::{describe_counter, describe_gauge, describe_histogram, Counter, Gauge, Histogram, Unit};
/// Register unit and help text for every aggregator self-metric.
///
/// Only descriptions are registered here; the metric handles themselves are created
/// in `AggMetrics::default`, using the same metric names.
pub fn init() {
// Gauges: current view of the target set.
describe_gauge!(
"aggregator_targets_total",
Unit::Count,
"Total number of scrape targets known to the aggregator"
);
describe_gauge!(
"aggregator_targets_healthy",
Unit::Count,
"Number of scrape targets that responded last scrape"
);
// Counters: scrape attempts and their outcomes.
describe_counter!(
"aggregator_scrape_total",
Unit::Count,
"Total number of scrape attempts"
);
describe_counter!(
"aggregator_scrape_success",
Unit::Count,
"Successful scrapes"
);
describe_counter!(
"aggregator_scrape_failures",
Unit::Count,
"Failed scrape attempts"
);
describe_counter!(
"aggregator_scrape_errors_parse",
Unit::Count,
"Scrape failures due to parse errors"
);
describe_counter!(
"aggregator_scrape_errors_timeout",
Unit::Count,
"Scrape failures due to timeout"
);
describe_counter!(
"aggregator_scrape_errors_connection",
Unit::Count,
"Scrape failures due to connection errors"
);
// Counters: target churn.
describe_counter!(
"aggregator_targets_discovered",
Unit::Count,
"Total targets discovered"
);
describe_counter!(
"aggregator_targets_lost",
Unit::Count,
"Total targets that disappeared"
);
// Histogram: how long a scrape takes.
describe_histogram!(
"aggregator_scrape_duration_ms",
Unit::Milliseconds,
"Scrape duration in milliseconds"
);
}
/// Handles to the aggregator's own metrics, one field per metric described in `init`.
///
/// Cloneable so the same handles can be handed to background tasks.
#[derive(Clone)]
// Not every handle is used yet, hence the blanket allow.
#[allow(dead_code)]
pub struct AggMetrics {
/// Gauge `aggregator_targets_total`.
pub targets_total: Gauge,
/// Gauge `aggregator_targets_healthy`.
pub targets_healthy: Gauge,
/// Counter `aggregator_scrape_total`.
pub scrape_total: Counter,
/// Counter `aggregator_scrape_success`.
pub scrape_success: Counter,
/// Counter `aggregator_scrape_failures`.
pub scrape_failures: Counter,
/// Counter `aggregator_scrape_errors_parse`.
pub scrape_errors_parse: Counter,
/// Counter `aggregator_scrape_errors_timeout`.
pub scrape_errors_timeout: Counter,
/// Counter `aggregator_scrape_errors_connection`.
pub scrape_errors_connection: Counter,
/// Counter `aggregator_targets_discovered`.
pub targets_discovered: Counter,
/// Counter `aggregator_targets_lost`.
pub targets_lost: Counter,
/// Histogram `aggregator_scrape_duration_ms`.
pub scrape_duration: Histogram,
}
impl Default for AggMetrics {
    /// Create one handle per aggregator metric, using the metric names that `init`
    /// describes.
    fn default() -> Self {
        let targets_total = metrics::gauge!("aggregator_targets_total");
        let targets_healthy = metrics::gauge!("aggregator_targets_healthy");
        let scrape_total = metrics::counter!("aggregator_scrape_total");
        let scrape_success = metrics::counter!("aggregator_scrape_success");
        let scrape_failures = metrics::counter!("aggregator_scrape_failures");
        let scrape_errors_parse = metrics::counter!("aggregator_scrape_errors_parse");
        let scrape_errors_timeout = metrics::counter!("aggregator_scrape_errors_timeout");
        let scrape_errors_connection = metrics::counter!("aggregator_scrape_errors_connection");
        let targets_discovered = metrics::counter!("aggregator_targets_discovered");
        let targets_lost = metrics::counter!("aggregator_targets_lost");
        let scrape_duration = metrics::histogram!("aggregator_scrape_duration_ms");

        Self {
            targets_total,
            targets_healthy,
            scrape_total,
            scrape_success,
            scrape_failures,
            scrape_errors_parse,
            scrape_errors_timeout,
            scrape_errors_connection,
            targets_discovered,
            targets_lost,
            scrape_duration,
        }
    }
}
impl AggMetrics {
    /// Create the metric handles; equivalent to `AggMetrics::default()`.
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}

40
apps/metrics/src/otel.rs Normal file
View File

@ -0,0 +1,40 @@
use anyhow::Context;
use opentelemetry::trace::TracerProvider;
use opentelemetry_otlp::{SpanExporter, WithExportConfig};
use opentelemetry_sdk::trace as sdktrace;
use tracing_opentelemetry::layer;
use tracing_subscriber::prelude::*;
/// Owns the tracer provider created by `init_otel` so it can be shut down explicitly
/// via `OtelGuard::shutdown`.
pub struct OtelGuard {
/// The provider whose `shutdown` is invoked by the guard.
provider: sdktrace::SdkTracerProvider,
}
impl OtelGuard {
    /// Shut the tracer provider down, consuming the guard.
    ///
    /// A shutdown error is logged as a warning and otherwise ignored.
    pub async fn shutdown(self) {
        match self.provider.shutdown() {
            Ok(_) => {}
            Err(e) => {
                tracing::warn!(error = %e, "OTLP shutdown error");
            }
        }
    }
}
/// Build an OTLP-over-HTTP span exporter for `endpoint`, wrap it in a batching tracer
/// provider and install a tracing subscriber that exports spans through it.
///
/// Returns a guard that owns the provider; call `OtelGuard::shutdown` before exit.
///
/// # Errors
/// Fails when the exporter cannot be built or the subscriber cannot be installed.
///
/// NOTE(review): `try_init` fails if a global subscriber is already set — confirm this
/// is not called after another subscriber has been installed, otherwise tracing export
/// never becomes active.
/// NOTE(review): `endpoint` is handed to `with_endpoint` unchanged — confirm whether
/// the exporter expects the full signal URL (including `/v1/traces`) here.
pub fn init_otel(endpoint: &str, service_name: &str) -> anyhow::Result<OtelGuard> {
    let span_exporter = SpanExporter::builder()
        .with_http()
        .with_endpoint(endpoint)
        .build()
        .context("build OTLP exporter")?;

    let provider = sdktrace::SdkTracerProvider::builder()
        .with_batch_exporter(span_exporter)
        .build();

    let export_layer = layer().with_tracer(provider.tracer(service_name.to_string()));

    tracing_subscriber::registry()
        .with(export_layer)
        .try_init()
        .context("install OTLP tracing subscriber")?;

    Ok(OtelGuard { provider })
}

135
apps/metrics/src/scrape.rs Normal file
View File

@ -0,0 +1,135 @@
use awc::Client;
use std::collections::HashMap;
use crate::target::ScrapeTarget;
/// Thin wrapper around an `awc::Client` used to scrape targets.
#[derive(Clone)]
pub struct HttpClient {
/// Underlying client, configured with the timeout passed to `HttpClient::new`.
client: Client,
}
impl HttpClient {
pub fn new(timeout_secs: u64) -> Self {
let client = Client::builder()
.timeout(std::time::Duration::from_secs(timeout_secs))
.finish();
Self { client }
}
pub async fn scrape(&self, target: &ScrapeTarget) -> ScrapeResult {
let start = std::time::Instant::now();
let url = target.url();
let mut resp = match self.client.get(url).send().await {
Ok(resp) => resp,
Err(e) => {
let msg = e.to_string();
if msg.contains("timeout") || msg.contains("TimedOut") || msg.contains("timed out")
{
return ScrapeResult::Timeout;
}
return ScrapeResult::ConnectionError(msg);
}
};
if !resp.status().is_success() {
return ScrapeResult::HttpError(resp.status().as_u16());
}
let body = match resp.body().await {
Ok(bytes) => String::from_utf8_lossy(&bytes).into_owned(),
Err(e) => return ScrapeResult::ConnectionError(e.to_string()),
};
let scrape_ms = start.elapsed().as_millis() as f64;
ScrapeResult::Success(body, scrape_ms)
}
}
/// Outcome of a single scrape attempt made by `HttpClient::scrape`.
pub enum ScrapeResult {
/// The response body (lossily decoded as UTF-8) and the elapsed scrape time in milliseconds.
Success(String, f64),
/// Sending the request failed with an error whose message indicated a timeout.
Timeout,
/// Sending the request or reading the body failed for another reason; carries the error text.
ConnectionError(String),
/// The server answered with a non-success status; carries the numeric status code.
HttpError(u16),
}
/// One sample parsed from a Prometheus text-exposition body.
#[derive(Clone, Debug)]
pub struct PromMetric {
    /// Metric name: the text before the label block or the first whitespace.
    pub name: String,
    /// Sample value.
    pub value: f64,
    /// Label pairs; empty when the sample has no label block.
    pub labels: HashMap<String, String>,
}

/// Parses a Prometheus text-format body into samples.
///
/// Blank lines and `#` comment lines (`HELP`/`TYPE`) are ignored. Lines that cannot be parsed
/// (missing value, non-numeric value, unterminated label block, empty metric name) are skipped
/// rather than reported, so one bad line never discards the rest of the scrape. A trailing
/// timestamp, if present, is ignored.
pub fn parse_prometheus(body: &str) -> Vec<PromMetric> {
    body.lines().filter_map(parse_sample_line).collect()
}

/// Parses a single exposition line; returns `None` for comments, blanks and malformed lines.
fn parse_sample_line(line: &str) -> Option<PromMetric> {
    let line = line.trim();
    if line.is_empty() || line.starts_with('#') {
        return None;
    }

    // The name ends at the label block or at the first whitespace (space or tab).
    let name_end = line.find(|c: char| c == '{' || c.is_whitespace())?;
    let name = &line[..name_end];
    if name.is_empty() {
        return None;
    }

    let after_name = &line[name_end..];
    let (labels, value_part) = match after_name.strip_prefix('{') {
        Some(inner) => {
            // Locate the closing brace while respecting quoted values, so spaces, commas or
            // braces inside a label value do not confuse the split.
            let close = find_label_block_end(inner)?;
            (parse_labels(&inner[..close]), &inner[close + 1..])
        }
        None => (HashMap::new(), after_name),
    };

    let value = value_part.split_whitespace().next()?.parse::<f64>().ok()?;
    Some(PromMetric {
        name: name.to_string(),
        value,
        labels,
    })
}

/// Returns the byte index of the `}` that closes a label block, ignoring braces that appear
/// inside quoted label values. `None` means the block is never closed.
fn find_label_block_end(s: &str) -> Option<usize> {
    let mut quote: Option<char> = None;
    let mut escaped = false;
    for (idx, c) in s.char_indices() {
        match quote {
            Some(_) if escaped => escaped = false,
            Some(_) if c == '\\' => escaped = true,
            Some(q) if c == q => quote = None,
            Some(_) => {}
            None if c == '"' || c == '\'' => quote = Some(c),
            None if c == '}' => return Some(idx),
            None => {}
        }
    }
    None
}

/// Parses the inside of a label block (`k="v",k2="v2"`) into a map.
///
/// Values may be double-quoted, single-quoted, or bare (`[A-Za-z0-9_-]*`). Inside quoted
/// values the escapes `\\`, `\n` and an escaped quote character are decoded; any other
/// backslash sequence is kept verbatim. An unterminated quoted value extends to the end of
/// the input instead of panicking. Parsing stops at the first position where no further `=`
/// is found.
pub fn parse_labels(s: &str) -> HashMap<String, String> {
    let mut labels = HashMap::new();
    let mut remaining = s;
    while let Some(eq) = remaining.find('=') {
        let key = remaining[..eq].trim().to_string();
        let after_eq = remaining[eq + 1..].trim_start();
        let (value, rest) = match after_eq.chars().next() {
            // Both quote characters are one byte long, so slicing at 1 is always valid.
            Some(q @ ('"' | '\'')) => read_quoted(&after_eq[1..], q),
            _ => {
                let end = after_eq
                    .find(|c: char| !c.is_alphanumeric() && c != '_' && c != '-')
                    .unwrap_or(after_eq.len());
                (after_eq[..end].to_string(), &after_eq[end..])
            }
        };
        labels.insert(key, value);
        remaining = rest.trim_start_matches(',').trim_start();
    }
    labels
}

/// Reads a quoted value whose opening quote has already been consumed.
///
/// Returns the decoded value and the text following the closing quote. When the closing quote
/// is missing, the whole input is taken as the value and the remainder is empty.
fn read_quoted(s: &str, quote: char) -> (String, &str) {
    let mut value = String::new();
    let mut chars = s.char_indices();
    while let Some((idx, c)) = chars.next() {
        if c == quote {
            return (value, &s[idx + c.len_utf8()..]);
        }
        if c != '\\' {
            value.push(c);
            continue;
        }
        match chars.next() {
            Some((_, 'n')) => value.push('\n'),
            Some((_, e)) if e == '\\' || e == quote => value.push(e),
            Some((_, other)) => {
                value.push('\\');
                value.push(other);
            }
            None => value.push('\\'),
        }
    }
    (value, "")
}

View File

@ -0,0 +1,210 @@
//! Stats store: receives expanded push payloads from all apps,
//! aggregates over time, computes derived statistics (p99 etc),
//! and provides JSON API for external consumption.
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::Serialize;
/// Stats entry stored per app name in `StatsStore`.
///
/// `merge_push_payload` keys the store by app name only and ignores its `_instance`
/// argument, so this entry is not tracked per instance: each push overwrites the values
/// written by the previous push for the same app.
#[derive(Debug, Clone, Default, Serialize)]
pub struct AppStats {
/// Timestamp taken from the most recent push payload.
pub last_seen: i64,
/// Number of push samples received.
pub sample_count: u64,
// ── HTTP ─────────────────────────────────────────────────────
pub requests_total: u64,
pub request_duration_ms_total: u64,
pub requests_2xx: u64,
pub requests_4xx: u64,
pub requests_5xx: u64,
/// Request count per endpoint, as last reported for that endpoint.
pub endpoints: HashMap<String, u64>,
// ── System ───────────────────────────────────────────────────
pub cpu_usage_percent: f32,
pub memory_used_mb: u64,
pub memory_total_mb: u64,
pub uptime_secs: u64,
// ── Business counters ────────────────────────────────────────
/// Copy of the business map from the most recent push.
pub business: HashMap<String, f64>,
// ── Token usage ──────────────────────────────────────────────
pub ai_input_tokens_total: i64,
pub ai_output_tokens_total: i64,
pub ai_calls_total: i64,
pub ai_calls_success: i64,
pub ai_calls_failure: i64,
/// Token usage per model name, as last reported for that model.
pub token_by_model: HashMap<String, ModelTokenStats>,
// ── Tasks ────────────────────────────────────────────────────
pub tasks_queued: i64,
pub tasks_running: i64,
pub tasks_completed: i64,
pub tasks_failed: i64,
// ── Latency ──────────────────────────────────────────────────
/// Latency snapshot per key, as last reported for that key.
pub latency: HashMap<String, LatencyStats>,
// ── Logs ─────────────────────────────────────────────────────
/// `(timestamp, formatted line)` pairs; excluded from the serialized output.
#[serde(skip_serializing)]
pub logs: Vec<(i64, String)>,
}
/// Token usage for a single model, stored in `AppStats::token_by_model`.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ModelTokenStats {
/// Input tokens reported for this model.
pub input_tokens: i64,
/// Output tokens reported for this model.
pub output_tokens: i64,
/// Number of calls reported for this model.
pub calls: i64,
}
/// Latency percentiles for one key of `AppStats::latency`.
///
/// The values are copied from the pushed snapshot; nothing is recomputed here.
#[derive(Debug, Clone, Default, Serialize)]
pub struct LatencyStats {
/// 50th percentile, in milliseconds.
pub p50_ms: f64,
/// 90th percentile, in milliseconds.
pub p90_ms: f64,
/// 99th percentile, in milliseconds.
pub p99_ms: f64,
/// Maximum observed latency, in milliseconds.
pub max_ms: f64,
/// Sample count reported with the snapshot.
pub count: u64,
}
/// The global stats store: app_name → AppStats.
pub type StatsStore = Arc<RwLock<HashMap<String, AppStats>>>;
/// Merge a new push payload into the stats store.
///
/// The entry for `app` is created on first use. `last_seen` is set to `timestamp` and
/// `sample_count` is incremented; every other section is overwritten with the values from
/// this payload (nothing is summed). Optional sections that are `None` leave the stored
/// values untouched, whereas `business` is always replaced, even by an empty map.
///
/// The write lock on `store` is held for the whole call.
pub async fn merge_push_payload(
store: &StatsStore,
app: &str,
_instance: &str,
timestamp: i64,
http: Option<&observability::push::HttpPayload>,
system: Option<&observability::push::SystemPayload>,
business: &HashMap<String, f64>,
token_usage: Option<&observability::push::TokenUsagePayload>,
tasks: Option<&observability::push::TaskStatsPayload>,
latency: &HashMap<String, observability::push::LatencySnapshot>,
logs: &[observability::push::LogEntry],
) {
// The store is keyed by app name only; `_instance` is unused, so pushes from different
// instances of the same app overwrite each other's values.
// NOTE(review): confirm last-writer-wins across instances is intended.
let mut guard = store.write().await;
let entry = guard.entry(app.to_string()).or_default();
entry.last_seen = timestamp;
entry.sample_count += 1;
// HTTP — overwrite with the values from this push (assigned, not added).
// NOTE(review): this assumes the pushed counters are already cumulative — confirm.
if let Some(http) = http {
entry.requests_total = http.requests_total;
entry.request_duration_ms_total = http.request_duration_ms_total;
entry.requests_2xx = http.requests_2xx;
entry.requests_4xx = http.requests_4xx;
entry.requests_5xx = http.requests_5xx;
// Endpoints missing from this push keep the count stored by an earlier push.
for (ep, count) in &http.endpoints {
*entry.endpoints.entry(ep.clone()).or_insert(0) = *count;
}
}
// System — replace (current snapshot, not cumulative)
if let Some(sys) = system {
entry.cpu_usage_percent = sys.cpu_usage_percent;
entry.memory_used_mb = sys.memory_used_mb;
entry.memory_total_mb = sys.memory_total_mb;
entry.uptime_secs = sys.uptime_secs;
}
// Business — replace with latest snapshot
entry.business = business.clone();
// Token usage — replace with latest
if let Some(tu) = token_usage {
entry.ai_input_tokens_total = tu.ai_input_tokens_total;
entry.ai_output_tokens_total = tu.ai_output_tokens_total;
entry.ai_calls_total = tu.ai_calls_total;
entry.ai_calls_success = tu.ai_calls_success;
entry.ai_calls_failure = tu.ai_calls_failure;
// Models missing from this push keep their previously stored numbers.
for (model, usage) in &tu.by_model {
let ms = entry.token_by_model.entry(model.clone()).or_default();
ms.input_tokens = usage.input_tokens;
ms.output_tokens = usage.output_tokens;
ms.calls = usage.calls;
}
}
// Tasks — replace with latest
if let Some(t) = tasks {
entry.tasks_queued = t.queued;
entry.tasks_running = t.running;
entry.tasks_completed = t.completed;
entry.tasks_failed = t.failed;
}
// Latency — replace with latest snapshots
for (endpoint, snap) in latency {
let ls = entry.latency.entry(endpoint.clone()).or_default();
ls.p50_ms = snap.p50_ms;
ls.p90_ms = snap.p90_ms;
ls.p99_ms = snap.p99_ms;
ls.max_ms = snap.max_ms;
ls.count = snap.count;
}
// Logs — append, then drop every entry whose timestamp is older than `now - 300`.
// This is a time window, not a line cap: the number of retained lines is unbounded.
// NOTE(review): assumes `log.timestamp` is in Unix seconds — confirm against the pusher.
for log in logs {
entry.logs.push((log.timestamp, format!("[{}] {}", log.level.to_lowercase(), log.message)));
}
let cutoff = chrono::Utc::now().timestamp() - 300;
entry.logs.retain(|(ts, _)| *ts >= cutoff);
}
/// Dashboard response combining all apps' stats.
#[derive(Debug, Serialize)]
pub struct DashboardResponse {
/// Timestamp of this snapshot.
pub timestamp: i64,
/// Number of entries in the stats store, i.e. distinct app names (not instances).
pub app_count: u64,
/// Per-app aggregated stats.
pub apps: HashMap<String, AppStats>,
/// Derived: unweighted mean of `p99_ms` over every latency entry of every app;
/// `0.0` when there are no latency entries.
pub avg_p99_ms: f64,
/// Derived: total tokens consumed across all apps.
pub total_input_tokens: i64,
pub total_output_tokens: i64,
/// Derived: total AI calls across all apps.
pub total_ai_calls: i64,
}
/// Produces a [`DashboardResponse`] snapshot of everything currently in the store.
///
/// Holds the read lock while it sums the AI token/call totals over all apps, averages the
/// `p99_ms` of every latency entry (yielding `0.0` if there are none), and clones the
/// per-app map into the response.
pub async fn build_dashboard(store: &StatsStore) -> DashboardResponse {
    let apps = store.read().await;

    let total_input_tokens: i64 = apps.values().map(|s| s.ai_input_tokens_total).sum();
    let total_output_tokens: i64 = apps.values().map(|s| s.ai_output_tokens_total).sum();
    let total_ai_calls: i64 = apps.values().map(|s| s.ai_calls_total).sum();

    let (p99_sum, p99_samples) = apps
        .values()
        .flat_map(|s| s.latency.values())
        .fold((0.0_f64, 0_u32), |(sum, n), lat| (sum + lat.p99_ms, n + 1));
    let avg_p99_ms = if p99_samples == 0 {
        0.0
    } else {
        p99_sum / f64::from(p99_samples)
    };

    DashboardResponse {
        timestamp: chrono::Utc::now().timestamp(),
        app_count: apps.len() as u64,
        apps: apps.clone(),
        avg_p99_ms,
        total_input_tokens,
        total_output_tokens,
        total_ai_calls,
    }
}

View File

@ -0,0 +1,34 @@
use std::collections::HashMap;
use anyhow::Context;
use serde::{Deserialize, Serialize};
/// A single endpoint to scrape, as read from the targets file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ScrapeTarget {
/// Name of the target.
pub name: String,
/// Address placed after `http://` when the scrape URL is built (see `url`).
pub addr: String,
/// Path appended to `addr`, or a complete URL that is used as-is (see `url`).
/// Defaults to `/metrics` when absent from the input.
#[serde(default = "default_metrics_path")]
pub metrics_path: String,
/// Labels attached to the target; defaults to an empty map when absent from the input.
#[serde(default)]
pub labels: HashMap<String, String>,
}
/// Serde default for `ScrapeTarget::metrics_path`: the conventional `/metrics` path.
fn default_metrics_path() -> String {
    String::from("/metrics")
}
impl ScrapeTarget {
    /// Returns the URL to scrape for this target.
    ///
    /// When `metrics_path` is already an absolute `http://` or `https://` URL it is returned
    /// unchanged and `addr` is ignored. Otherwise it is treated as a path on `addr` and the
    /// result is `http://{addr}/{path}`; a missing leading `/` is inserted so the path can
    /// never be glued onto the host (for example `host:9090metrics`).
    pub fn url(&self) -> String {
        let path = self.metrics_path.as_str();
        // Match the full scheme prefix: a bare "http" check would also treat a relative
        // path such as "httpstats" as a complete URL.
        if path.starts_with("http://") || path.starts_with("https://") {
            return path.to_owned();
        }
        if path.starts_with('/') {
            format!("http://{}{}", self.addr, path)
        } else {
            format!("http://{}/{}", self.addr, path)
        }
    }
}
/// Reads `path` and deserializes its contents as a JSON array of [`ScrapeTarget`]s.
///
/// # Errors
///
/// Fails when the file cannot be read or when its contents are not a valid JSON list of
/// targets; the parse error carries the file path as context.
pub async fn load_targets_from_file(path: &str) -> anyhow::Result<Vec<ScrapeTarget>> {
    let raw = tokio::fs::read_to_string(path)
        .await
        .context("read targets file")?;
    serde_json::from_str::<Vec<ScrapeTarget>>(&raw)
        .with_context(|| format!("parse targets file {path}"))
}

View File

@ -7,6 +7,8 @@ edition.workspace = true
actix-web = { workspace = true } actix-web = { workspace = true }
actix-files = { workspace = true } actix-files = { workspace = true }
actix-cors = { workspace = true } actix-cors = { workspace = true }
observability = { workspace = true }
metrics-exporter-prometheus = "0.13"
tokio = { workspace = true, features = ["full"] } tokio = { workspace = true, features = ["full"] }
futures = { workspace = true } futures = { workspace = true }
serde = { workspace = true } serde = { workspace = true }

View File

@ -5,8 +5,10 @@ use actix_web::{http::header, web, App, HttpResponse, HttpServer};
use futures::future::LocalBoxFuture; use futures::future::LocalBoxFuture;
use log::info; use log::info;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::time::Instant; use std::time::Instant;
use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
/// Static file server for avatar, blob, and other static files /// Static file server for avatar, blob, and other static files
/// Serves files from /data/{type} directories /// Serves files from /data/{type} directories
@ -119,7 +121,16 @@ where
#[actix_web::main] #[actix_web::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); init_tracing_subscriber("info", false);
let prometheus_handle = Arc::new(install_recorder());
let http_metrics = Arc::new(HttpMetrics::new());
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
let pusher = MetricsPusher::new(&push_url, "static");
pusher.spawn(http_metrics.clone(), prometheus_handle.clone(), std::time::Duration::from_secs(15));
info!("Metrics pusher started (interval 15s, url: {})", push_url);
}
let cfg = StaticConfig::from_env(); let cfg = StaticConfig::from_env();
let bind = std::env::var("STATIC_BIND").unwrap_or_else(|_| "0.0.0.0:8081".to_string()); let bind = std::env::var("STATIC_BIND").unwrap_or_else(|_| "0.0.0.0:8081".to_string());
@ -142,6 +153,8 @@ async fn main() -> anyhow::Result<()> {
let root = root.clone(); let root = root.clone();
let cors = if cors_enabled { let cors = if cors_enabled {
// WARNING: allow_any_origin is intentional for static asset serving (CDN mode)
// Ensure no sensitive files are served from this directory
Cors::default() Cors::default()
.allow_any_origin() .allow_any_origin()
.allowed_methods(vec!["GET", "HEAD", "OPTIONS"]) .allowed_methods(vec!["GET", "HEAD", "OPTIONS"])

88
build.sh Normal file
View File

@ -0,0 +1,88 @@
#!/usr/bin/env bash
set -euo pipefail
# ── helpers ──────────────────────────────────────────────────────────
RED='\033[0;31m'; GREEN='\033[0;32m'; YELLOW='\033[1;33m'; NC='\033[0m'
# log MSG...   — print a green [OK] progress line on stdout.
log()  { echo -e "${GREEN}[OK]${NC} $*"; }
# warn MSG...  — print a yellow [WARN] line on stderr so it never mixes with captured output.
warn() { echo -e "${YELLOW}[WARN]${NC} $*" >&2; }
# err MSG...   — print a red [ERR] line on stderr and abort the script with status 1.
err()  { echo -e "${RED}[ERR]${NC} $*" >&2; exit 1; }
# command_exists NAME — succeed (silently) when NAME resolves to a command.
command_exists() { command -v "$1" &>/dev/null; }
# ── 1. Rust ─────────────────────────────────────────────────────────
if command_exists rustc; then
log "Rust $(rustc --version)"
else
warn "Rust not found, installing via rustup..."
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
# shellcheck disable=SC1091
source "$HOME/.cargo/env"
log "Rust installed: $(rustc --version)"
fi
# ── 2. Node.js ──────────────────────────────────────────────────────
if command_exists node; then
log "Node.js $(node --version)"
else
warn "Node.js not found, installing via nvm..."
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.1/install.sh | bash
# shellcheck disable=SC1090
export NVM_DIR="${HOME}/.nvm"
[ -s "$NVM_DIR/nvm.sh" ] && source "$NVM_DIR/nvm.sh"
nvm install --lts
log "Node.js installed: $(node --version)"
fi
# ── 2b. Bun ─────────────────────────────────────────────────────────
if command_exists bun; then
  log "Bun $(bun --version)"
else
  warn "Bun not found, installing..."
  curl -fsSL https://bun.sh/install | bash
  # The installer only edits shell rc files, which this already-running script never
  # re-reads, so put the install directory on PATH ourselves. Do it unconditionally:
  # gating it on the optional completions file (~/.bun/_bun) left PATH unchanged whenever
  # that file was not written, and the next `bun` call then aborted the script.
  export BUN_INSTALL="${BUN_INSTALL:-$HOME/.bun}"
  export PATH="$BUN_INSTALL/bin:$PATH"
  command_exists bun || err "Bun was installed but 'bun' is still not on PATH ($BUN_INSTALL/bin)"
  log "Bun installed: $(bun --version)"
fi
# ── 3. Docker ───────────────────────────────────────────────────────
if command_exists docker; then
log "Docker $(docker --version)"
else
warn "Docker not found, installing..."
curl -fsSL https://get.docker.com | sh
log "Docker installed: $(docker --version)"
fi
# ── 4. Frontend build ───────────────────────────────────────────────
log "Running bun install..."
bun install
log "Running bun run build..."
bun run build
# ── 5. Rust build ───────────────────────────────────────────────────
log "Running cargo build --release --workspace..."
cargo build --release --workspace
# ── 6. Docker images ────────────────────────────────────────────────
TAG=$(git rev-parse --short HEAD)
log "Building Docker images with tag: $TAG"
IMAGES=(
"docker/app.Dockerfile app:$TAG"
"docker/email.Dockerfile email-worker:$TAG"
"docker/githook.Dockerfile git-hook:$TAG"
"docker/gitserver.Dockerfile gitserver:$TAG"
"docker/metrics.Dockerfile metrics-aggregator:$TAG"
"docker/static.Dockerfile static-server:$TAG"
"docker/gingress.Dockerfile gingress:$TAG"
)
for entry in "${IMAGES[@]}"; do
read -r dockerfile tag <<< "$entry"
log "Building $tag..."
docker build -f "$dockerfile" -t "$tag" .
done
log "All images built successfully."
docker images | grep -E "app|email-worker|git-hook|gitserver|metrics-aggregator|static-server|gingress" | grep "$TAG" || true

1016
bun.lock

File diff suppressed because it is too large Load Diff

23
deploy/.helmignore Normal file
View File

@ -0,0 +1,23 @@
# Patterns to ignore when building packages.
# This supports shell glob matching, relative path matching, and
# negation (prefixed with !). Only one pattern per line.
.DS_Store
# Common VCS dirs
.git/
.gitignore
.bzr/
.bzrignore
.hg/
.hgignore
.svn/
# Common backup files
*.swp
*.bak
*.tmp
*.orig
*~
# Various IDEs
.project
.idea/
*.tmproj
.vscode/

6
deploy/Chart.yaml Normal file
View File

@ -0,0 +1,6 @@
apiVersion: v2
name: deploy
description: Helm chart for the project backend services
type: application
version: 0.1.0
appVersion: "0.2.9"

198
deploy/README.md Normal file
View File

@ -0,0 +1,198 @@
# Deploy Helm Chart
Monolithic Helm chart for all backend services.
## Services
| Service | Port(s) | Replicas | HPA | Purpose |
|---|---|---|---|---|
| `app` | 3000 (HTTP) | 2 | 2–10 | Main API server |
| `gitserver` | 8021 (HTTP), 2222 (SSH) | 1 | 1–5 | Git HTTP + SSH server |
| `email_worker` | 8084 (HTTP) | 1 | disabled | Email queue consumer (single instance only) |
| `git_hook` | 8083 (HTTP) | 1 | 1–5 | Git hook worker pool |
| `metrics_aggregator` | 9090 (HTTP) | 1 | 1–5 | Prometheus scrape + Loki push |
| `static_server` | 8081 (HTTP) | 1 | 1–5 | Static file server (avatars, blobs, media) |
## Prerequisites
The following resources must exist in the cluster **before** installing the Helm chart. They are not managed by Helm — install, upgrade, and uninstall of the chart will not touch them.
### 1. Namespace
```bash
kubectl create namespace app
```
### 2. PVC (aliyun-nfs, 200Ti, ReadWriteMany)
```bash
kubectl apply -f - <<'EOF'
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: shared-data
namespace: app
spec:
accessModes:
- ReadWriteMany
resources:
requests:
storage: 200Ti
storageClassName: aliyun-nfs
EOF
```
> The chart references this PVC by name. If you use a different name, pass `--set pvcName=your-pvc-name` to Helm.
### 3. ConfigMap
```bash
kubectl apply -f - <<'EOF'
apiVersion: v1
kind: ConfigMap
metadata:
name: app-env
namespace: app
data:
APP_REPOS_ROOT: "/data/repos"
APP_AVATAR_PATH: "/data/avatars"
STORAGE_PATH: "/data/files"
STATIC_ROOT: "/data"
APP_LOG_LEVEL: "info"
APP_COOKIE_SECURE: "false"
APP_DOMAIN_URL: "https://your-domain.com"
APP_DATABASE_URL: "postgres://user:pass@postgres:5432/app"
APP_REDIS_URL: "redis://redis:6379"
APP_AI_BASIC_URL: "https://api.openai.com/v1"
APP_AI_API_KEY: "sk-..."
APP_SMTP_PASSWORD: "..."
APP_SESSION_SECRET: "min-32-byte-random-string..."
APP_SSH_SERVER_PRIVATE_KEY: "<hex-encoded-private-key>"
EOF
```
| Variable | Default / Example | Required |
|---|---|---|
| `APP_REPOS_ROOT` | `/data/repos` | Yes |
| `APP_AVATAR_PATH` | `/data/avatars` | Yes |
| `STORAGE_PATH` | `/data/files` | Yes |
| `STATIC_ROOT` | `/data` | Yes |
| `APP_LOG_LEVEL` | `info` | No |
| `APP_COOKIE_SECURE` | `false` | No |
| `APP_DOMAIN_URL` | `https://your-domain.com` | Yes |
| `APP_DATABASE_URL` | `postgres://...` | **Yes** |
| `APP_REDIS_URL` | `redis://...` | **Yes** |
| `APP_AI_BASIC_URL` | `https://api.openai.com/v1` | **Yes** |
| `APP_AI_API_KEY` | `sk-...` | **Yes** |
| `APP_SMTP_PASSWORD` | `...` | **Yes** |
| `APP_SESSION_SECRET` | min 32 bytes | **Yes** |
| `APP_SSH_SERVER_PRIVATE_KEY` | hex-encoded PEM | **Yes** |
| `APP_SSH_PORT` | `2222` | Yes (k8s) |
> **SSH host key**: `APP_SSH_SERVER_PRIVATE_KEY` must be the hex-encoded Ed25519 private key PEM bytes.
> ```bash
> ssh-keygen -t ed25519 -f /tmp/ssh_host_key -N ""
> hexdump -v -e '/1 "%02x"' < /tmp/ssh_host_key
> ```
>
> **Session secret**: generate 48 random bytes:
> ```bash
> openssl rand -base64 48
> ```
>
> Override the ConfigMap name with `--set configMapName=your-cm-name`.
### 4. Verify prerequisites
```bash
kubectl get namespace app
kubectl get pvc -n app shared-data
kubectl get configmap -n app app-env
```
## Quick Start
```bash
helm template deploy ./deploy --namespace app --set imageRegistry=ghcr.io/your-org
helm lint ./deploy
# Install
helm upgrade --install deploy ./deploy \
--namespace app \
--set imageRegistry=ghcr.io/your-org \
--set imageTag=v0.2.9
```
## Storage
All services share a single PVC (`shared-data`) via `subPath` mounts:
| SubPath | Mount | Used By |
|---|---|---|
| `repos` | `/data/repos` | app, gitserver, git-hook |
| `avatars` | `/data/avatars` | app |
| `files` | `/data/files` | app |
| `static` | `/data` | static-server |
## Autoscaling
All services except `email_worker` have HPA enabled by default. The email worker is fixed at 1 replica and must not be scaled.
To adjust HPA bounds per service:
```bash
--set services.app.autoscaling.maxReplicas=20
--set services.app.autoscaling.targetCPUUtilization=70
```
To disable HPA for a service:
```bash
--set services.git_hook.autoscaling.enabled=false
```
## Ingress
```bash
helm upgrade --install deploy ./deploy \
--namespace app \
--set ingress.enabled=true \
--set ingress.className=nginx \
--set ingress.hosts[0].host=your-domain.com
```
## Dependencies
All services require these to be reachable from the cluster:
- PostgreSQL (via `APP_DATABASE_URL`)
- Redis (via `APP_REDIS_URL`)
- Git binary (included in all Docker images)
- OpenAI-compatible API (via `APP_AI_BASIC_URL` + `APP_AI_API_KEY`)
- Qdrant vector DB (via `APP_QDRANT_URL`)
- SMTP server (via `APP_SMTP_*`)
- Embedding model (via `APP_EMBED_MODEL_*`)
Optional dependencies with graceful degradation:
| Dependency | Variable | Fallback |
|---|---|---|
| NATS JetStream | `NATS_URL` + `NATS_TOKEN` | Redis queue |
| Loki | `LOKI_URL` | Logs discarded |
| OTEL Collector | `OTEL_EXPORTER_OTLP_ENDPOINT` | Tracing disabled |
## Production Example
```bash
helm upgrade --install deploy ./deploy \
--namespace app \
--set imageRegistry=ghcr.io/your-org \
--set imageTag=v0.2.9 \
--set services.app.replicaCount=3 \
--set services.app.autoscaling.maxReplicas=20 \
--set ingress.enabled=true \
--set ingress.className=nginx \
--set ingress.hosts[0].host=your-domain.com \
--set configMapName=app-env \
--set pvcName=shared-data
```

View File

@ -0,0 +1,93 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: gingress-controller
namespace: gingress-system
labels:
app: gingress
spec:
replicas: 2
selector:
matchLabels:
app: gingress
template:
metadata:
labels:
app: gingress
spec:
serviceAccountName: gingress-controller
containers:
- name: gingress
image: gingress:latest
imagePullPolicy: IfNotPresent
args:
- "--ingress-class=gingress"
- "--bind-http=0.0.0.0:80"
- "--bind-https=0.0.0.0:443"
- "--metrics-bind=0.0.0.0:8080"
ports:
- name: http
containerPort: 80
protocol: TCP
- name: https
containerPort: 443
protocol: TCP
- name: metrics
containerPort: 8080
protocol: TCP
env:
- name: RUST_LOG
value: "info"
- name: METRICS_PUSH_URL
value: "" # Optional: push to metrics aggregator
livenessProbe:
httpGet:
path: /healthz
port: 8080
initialDelaySeconds: 10
periodSeconds: 10
readinessProbe:
httpGet:
path: /readyz
port: 8080
initialDelaySeconds: 5
periodSeconds: 5
resources:
requests:
cpu: 100m
memory: 128Mi
limits:
cpu: 500m
memory: 512Mi
affinity:
podAntiAffinity:
preferredDuringSchedulingIgnoredDuringExecution:
- weight: 100
podAffinityTerm:
labelSelector:
matchLabels:
app: gingress
topologyKey: kubernetes.io/hostname
---
apiVersion: v1
kind: Service
metadata:
name: gingress
namespace: gingress-system
spec:
type: LoadBalancer
selector:
app: gingress
ports:
- name: http
port: 80
targetPort: 80
protocol: TCP
- name: https
port: 443
targetPort: 443
protocol: TCP
- name: metrics
port: 8080
targetPort: 8080
protocol: TCP

48
deploy/gingress/rbac.yaml Normal file
View File

@ -0,0 +1,48 @@
apiVersion: v1
kind: Namespace
metadata:
name: gingress-system
---
apiVersion: v1
kind: ServiceAccount
metadata:
name: gingress-controller
namespace: gingress-system
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
name: gingress-controller
rules:
- apiGroups: ["networking.k8s.io"]
resources: ["ingresses", "ingressclasses"]
verbs: ["get", "list", "watch"]
- apiGroups: ["networking.k8s.io"]
resources: ["ingresses/status"]
verbs: ["update", "patch"]
- apiGroups: [""]
resources: ["services", "endpoints", "endpointslices", "secrets", "nodes"]
verbs: ["get", "list", "watch"]
- apiGroups: ["discovery.k8s.io"]
resources: ["endpointslices"]
verbs: ["get", "list", "watch"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: gingress-controller
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: gingress-controller
subjects:
- kind: ServiceAccount
name: gingress-controller
namespace: gingress-system
---
apiVersion: networking.k8s.io/v1
kind: IngressClass
metadata:
name: gingress
spec:
controller: gingress.io/gingress-controller

View File

@ -0,0 +1,19 @@
Project backend services deployed to namespace: {{ .Release.Namespace }}
Services:
{{- range $svcKey, $svcVal := .Values.services }}
{{ $svcKey | replace "_" "-" }}: {{ if $svcVal.ports }}{{ range $portName, $portNum := $svcVal.ports }}{{ $portName }}={{ $portNum }} {{ end }}{{ else }}port={{ $svcVal.port }}{{ end }} {{ if $svcVal.autoscaling.enabled }}(HPA: {{ $svcVal.autoscaling.minReplicas }}-{{ $svcVal.autoscaling.maxReplicas }}){{ else }}(static: {{ $svcVal.replicaCount }}){{ end }}
{{- end }}
To access the app locally:
kubectl port-forward -n {{ .Release.Namespace }} svc/{{ include "deploy.serviceFullname" (dict "root" . "svcKey" "app") }} 3000:3000
To check HPA status:
{{- range $svcKey, $svcVal := .Values.services }}
{{- if $svcVal.autoscaling.enabled }}
kubectl get hpa -n {{ $.Release.Namespace }} {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" $svcKey) }}
{{- end }}
{{- end }}
To check all pods:
kubectl get pods -n {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "deploy.name" . }}"

View File

@ -0,0 +1,78 @@
{{/*
Expand the name of the chart.
*/}}
{{- define "deploy.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Create a default fully qualified app name.
*/}}
{{- define "deploy.fullname" -}}
{{- if .Values.fullnameOverride }}
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- $name := default .Chart.Name .Values.nameOverride }}
{{- if contains $name .Release.Name }}
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- end }}
{{- end }}
{{/*
Service fullname — includes service key for per-service resources.
Underscores in svcKey are replaced with hyphens for valid Kubernetes names.
*/}}
{{- define "deploy.serviceFullname" -}}
{{- printf "%s-%s" (include "deploy.fullname" .root) (.svcKey | replace "_" "-") | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Chart name and version as used by the chart label.
*/}}
{{- define "deploy.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Common labels
*/}}
{{- define "deploy.labels" -}}
helm.sh/chart: {{ include "deploy.chart" . }}
{{ include "deploy.selectorLabels" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }}
{{/*
Selector labels
*/}}
{{- define "deploy.selectorLabels" -}}
app.kubernetes.io/name: {{ include "deploy.name" . }}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end }}
{{/*
Per-service selector labels — used by Service to target the right Deployment.
Underscores in svcKey are replaced with hyphens for valid Kubernetes label values.
*/}}
{{- define "deploy.serviceSelectorLabels" -}}
app.kubernetes.io/name: {{ include "deploy.name" .root }}
app.kubernetes.io/instance: {{ .root.Release.Name }}
app.kubernetes.io/component: {{ .svcKey | replace "_" "-" }}
{{- end }}
{{/*
Create the name of the service account to use
*/}}
{{- define "deploy.serviceAccountName" -}}
{{- if .Values.serviceAccount.create }}
{{- default (include "deploy.fullname" .) .Values.serviceAccount.name }}
{{- else }}
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}

View File

@ -0,0 +1,89 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "app") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: app
spec:
replicas: {{ .Values.services.app.replicaCount | default 1 }}
selector:
matchLabels:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "app") | nindent 6 }}
template:
metadata:
labels:
{{- include "deploy.labels" . | nindent 8 }}
app.kubernetes.io/component: app
spec:
{{- with .Values.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
{{- with .Values.podSecurityContext }}
securityContext:
{{- toYaml . | nindent 8 }}
{{- end }}
containers:
- name: app
{{- with .Values.securityContext }}
securityContext:
{{- toYaml . | nindent 12 }}
{{- end }}
image: "{{ .Values.imageRegistry }}/{{ .Values.services.app.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
imagePullPolicy: IfNotPresent
{{- with .Values.services.app.command }}
command:
{{- toYaml . | nindent 12 }}
{{- end }}
ports:
- name: http
containerPort: {{ .Values.services.app.port }}
protocol: TCP
envFrom:
- configMapRef:
name: {{ .Values.configMapName }}
{{- with .Values.services.app.extraEnv }}
env:
{{- range $key, $val := . }}
- name: {{ $key }}
value: {{ $val | quote }}
{{- end }}
{{- end }}
livenessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 10
periodSeconds: 15
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 5
periodSeconds: 10
{{- with .Values.services.app.resources }}
resources:
{{- toYaml . | nindent 12 }}
{{- end }}
{{- with .Values.services.app.volumeMounts }}
volumeMounts:
{{- toYaml . | nindent 12 }}
{{- end }}
volumes:
- name: shared-data
persistentVolumeClaim:
claimName: {{ .Values.pvcName }}
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}

View File

@ -0,0 +1,16 @@
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "app") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: app
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.app.port }}
targetPort: http
protocol: TCP
name: http
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "app") | nindent 4 }}

View File

@ -0,0 +1,70 @@
# Deployment for the email queue consumer (values key: services.email_worker).
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "email_worker") }}
  labels:
    {{- include "deploy.labels" . | nindent 4 }}
    app.kubernetes.io/component: email-worker
spec:
  replicas: {{ .Values.services.email_worker.replicaCount | default 1 }}
  selector:
    matchLabels:
      {{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "email_worker") | nindent 6 }}
  template:
    metadata:
      labels:
        {{- include "deploy.labels" . | nindent 8 }}
        app.kubernetes.io/component: email-worker
    spec:
      {{- with .Values.imagePullSecrets }}
      imagePullSecrets:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      serviceAccountName: {{ include "deploy.serviceAccountName" . }}
      {{- with .Values.podSecurityContext }}
      securityContext:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      containers:
        # Container names must be DNS-1123 labels (lowercase alphanumerics and '-'),
        # so the values key "email_worker" is spelled with a hyphen here.
        - name: email-worker
          {{- with .Values.securityContext }}
          securityContext:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          image: "{{ .Values.imageRegistry }}/{{ .Values.services.email_worker.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
          imagePullPolicy: IfNotPresent
          ports:
            - name: http
              containerPort: {{ .Values.services.email_worker.port }}
              protocol: TCP
          envFrom:
            - configMapRef:
                name: {{ .Values.configMapName }}
          livenessProbe:
            httpGet:
              path: /health
              port: http
            initialDelaySeconds: 10
            periodSeconds: 15
          readinessProbe:
            httpGet:
              path: /health
              port: http
            initialDelaySeconds: 5
            periodSeconds: 10
          {{- with .Values.services.email_worker.resources }}
          resources:
            {{- toYaml . | nindent 12 }}
          {{- end }}
      {{- with .Values.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.tolerations }}
      tolerations:
        {{- toYaml . | nindent 8 }}
      {{- end }}

View File

@ -0,0 +1,16 @@
# ClusterIP Service for the email worker.
# Forwards the configured service port to the container port named "http".
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "email_worker") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: email-worker
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.email_worker.port }}
# Resolved by name against the pod's container ports.
targetPort: http
protocol: TCP
name: http
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "email_worker") | nindent 4 }}

View File

@ -0,0 +1,78 @@
# Deployment for the git-hook service.
# Pod-wide settings come from chart-level values; image, port, replica count,
# resources and volume mounts come from .Values.services.git_hook.
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "git_hook") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: git-hook
spec:
replicas: {{ .Values.services.git_hook.replicaCount | default 1 }}
selector:
matchLabels:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "git_hook") | nindent 6 }}
template:
metadata:
labels:
{{- include "deploy.labels" . | nindent 8 }}
app.kubernetes.io/component: git-hook
spec:
{{- with .Values.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
{{- with .Values.podSecurityContext }}
securityContext:
{{- toYaml . | nindent 8 }}
{{- end }}
containers:
- name: git-hook
{{- with .Values.securityContext }}
securityContext:
{{- toYaml . | nindent 12 }}
{{- end }}
image: "{{ .Values.imageRegistry }}/{{ .Values.services.git_hook.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
imagePullPolicy: IfNotPresent
ports:
# Named port; the probes below refer to it by name.
- name: http
containerPort: {{ .Values.services.git_hook.port }}
protocol: TCP
# All environment variables are loaded from the ConfigMap named in values.
envFrom:
- configMapRef:
name: {{ .Values.configMapName }}
livenessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 10
periodSeconds: 15
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 5
periodSeconds: 10
{{- with .Values.services.git_hook.resources }}
resources:
{{- toYaml . | nindent 12 }}
{{- end }}
# Mount points are optional and taken verbatim from values.
{{- with .Values.services.git_hook.volumeMounts }}
volumeMounts:
{{- toYaml . | nindent 12 }}
{{- end }}
# The shared-data volume is always declared, whether or not any mounts are configured;
# it is backed by the PVC named in .Values.pvcName.
volumes:
- name: shared-data
persistentVolumeClaim:
claimName: {{ .Values.pvcName }}
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}

View File

@ -0,0 +1,16 @@
# ClusterIP Service for git-hook.
# Forwards the configured service port to the container port named "http".
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "git_hook") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: git-hook
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.git_hook.port }}
# Resolved by name against the pod's container ports.
targetPort: http
protocol: TCP
name: http
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "git_hook") | nindent 4 }}

View File

@ -0,0 +1,88 @@
# Deployment for gitserver, which exposes two container ports: "http" and "ssh".
# Pod-wide settings come from chart-level values; image, ports, replica count,
# extra env, resources and volume mounts come from .Values.services.gitserver.
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "gitserver") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: gitserver
spec:
replicas: {{ .Values.services.gitserver.replicaCount | default 1 }}
selector:
matchLabels:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "gitserver") | nindent 6 }}
template:
metadata:
labels:
{{- include "deploy.labels" . | nindent 8 }}
app.kubernetes.io/component: gitserver
spec:
{{- with .Values.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
{{- with .Values.podSecurityContext }}
securityContext:
{{- toYaml . | nindent 8 }}
{{- end }}
containers:
- name: gitserver
{{- with .Values.securityContext }}
securityContext:
{{- toYaml . | nindent 12 }}
{{- end }}
image: "{{ .Values.imageRegistry }}/{{ .Values.services.gitserver.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
imagePullPolicy: IfNotPresent
ports:
# Named ports; the probes below use "http" by name.
- name: http
containerPort: {{ .Values.services.gitserver.ports.http }}
protocol: TCP
- name: ssh
containerPort: {{ .Values.services.gitserver.ports.ssh }}
protocol: TCP
# Base environment comes from the ConfigMap named in values.
envFrom:
- configMapRef:
name: {{ .Values.configMapName }}
# Optional per-service variables: every key/value pair in extraEnv becomes one
# env entry, with the value rendered as a quoted string.
{{- with .Values.services.gitserver.extraEnv }}
env:
{{- range $key, $val := . }}
- name: {{ $key }}
value: {{ $val | quote }}
{{- end }}
{{- end }}
livenessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 10
periodSeconds: 15
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 5
periodSeconds: 10
{{- with .Values.services.gitserver.resources }}
resources:
{{- toYaml . | nindent 12 }}
{{- end }}
# Mount points are optional and taken verbatim from values.
{{- with .Values.services.gitserver.volumeMounts }}
volumeMounts:
{{- toYaml . | nindent 12 }}
{{- end }}
# The shared-data volume is always declared, whether or not any mounts are configured;
# it is backed by the PVC named in .Values.pvcName.
volumes:
- name: shared-data
persistentVolumeClaim:
claimName: {{ .Values.pvcName }}
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}

View File

@ -0,0 +1,20 @@
# ClusterIP Service for gitserver.
# Exposes both configured ports and forwards each to the container port of the same name.
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "gitserver") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: gitserver
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.gitserver.ports.http }}
targetPort: http
protocol: TCP
name: http
- port: {{ .Values.services.gitserver.ports.ssh }}
targetPort: ssh
protocol: TCP
name: ssh
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "gitserver") | nindent 4 }}

26
deploy/templates/hpa.yaml Normal file
View File

@ -0,0 +1,26 @@
{{- /*
Renders one HorizontalPodAutoscaler per entry in .Values.services whose
autoscaling.enabled is true. Each HPA targets the Deployment with the same
generated name and scales on average CPU utilization.

A service without an `autoscaling` block is skipped: the block is defaulted to
an empty dict so that reading `.enabled` cannot abort the render with a
nil-pointer error.
*/ -}}
{{- range $svcKey, $svcVal := .Values.services }}
{{- $autoscaling := default (dict) $svcVal.autoscaling }}
{{- if $autoscaling.enabled }}
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" $svcKey) }}
  labels:
    {{- include "deploy.labels" $ | nindent 4 }}
    {{- /* Service keys use underscores; component labels use hyphens. */}}
    app.kubernetes.io/component: {{ $svcKey | replace "_" "-" }}
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" $svcKey) }}
  minReplicas: {{ $autoscaling.minReplicas }}
  maxReplicas: {{ $autoscaling.maxReplicas }}
  metrics:
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: {{ $autoscaling.targetCPUUtilization }}
{{- end }}
{{- end }}

View File

@ -0,0 +1,41 @@
{{- if .Values.ingress.enabled -}}
# Ingress, rendered only when .Values.ingress.enabled is true.
# Each configured host/path routes to the Service whose name is generated from
# the path's `serviceName` (used as the service key) on the path's `servicePort`.
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ include "deploy.fullname" . }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
{{- with .Values.ingress.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
# ingressClassName is omitted entirely when className is empty.
{{- with .Values.ingress.className }}
ingressClassName: {{ . }}
{{- end }}
{{- if .Values.ingress.tls }}
tls:
{{- range .Values.ingress.tls }}
- hosts:
{{- range .hosts }}
- {{ . | quote }}
{{- end }}
secretName: {{ .secretName }}
{{- end }}
{{- end }}
rules:
{{- range .Values.ingress.hosts }}
- host: {{ .host | quote }}
http:
paths:
{{- range .paths }}
- path: {{ .path }}
pathType: {{ .pathType }}
backend:
service:
# "$" is the chart root; "." is the current path entry inside this range.
name: {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" .serviceName) }}
port:
number: {{ .servicePort }}
{{- end }}
{{- end }}
{{- end }}

View File

@ -0,0 +1,70 @@
# Deployment for the metrics aggregator.
# Pod-wide settings (pull secrets, service account, security contexts, scheduling)
# come from chart-level values; image, port, replica count and resources come from
# .Values.services.metrics_aggregator. Environment is loaded from the ConfigMap
# named by .Values.configMapName.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "metrics_aggregator") }}
  labels:
    {{- include "deploy.labels" . | nindent 4 }}
    app.kubernetes.io/component: metrics-aggregator
spec:
  replicas: {{ .Values.services.metrics_aggregator.replicaCount | default 1 }}
  selector:
    matchLabels:
      {{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "metrics_aggregator") | nindent 6 }}
  template:
    metadata:
      labels:
        {{- include "deploy.labels" . | nindent 8 }}
        app.kubernetes.io/component: metrics-aggregator
    spec:
      {{- with .Values.imagePullSecrets }}
      imagePullSecrets:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      serviceAccountName: {{ include "deploy.serviceAccountName" . }}
      {{- with .Values.podSecurityContext }}
      securityContext:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      containers:
        # Container names must be DNS-1123 labels (lowercase alphanumerics and "-");
        # an underscore here makes the API server reject the whole Deployment.
        - name: metrics-aggregator
          {{- with .Values.securityContext }}
          securityContext:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          image: "{{ .Values.imageRegistry }}/{{ .Values.services.metrics_aggregator.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
          imagePullPolicy: IfNotPresent
          ports:
            # Named port; the probes below and the Service's targetPort refer to it by name.
            - name: http
              containerPort: {{ .Values.services.metrics_aggregator.port }}
              protocol: TCP
          envFrom:
            - configMapRef:
                name: {{ .Values.configMapName }}
          livenessProbe:
            httpGet:
              path: /health
              port: http
            initialDelaySeconds: 10
            periodSeconds: 15
          readinessProbe:
            httpGet:
              path: /health
              port: http
            initialDelaySeconds: 5
            periodSeconds: 10
          {{- with .Values.services.metrics_aggregator.resources }}
          resources:
            {{- toYaml . | nindent 12 }}
          {{- end }}
      {{- with .Values.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.tolerations }}
      tolerations:
        {{- toYaml . | nindent 8 }}
      {{- end }}

View File

@ -0,0 +1,16 @@
# ClusterIP Service for the metrics aggregator.
# Forwards the configured service port to the container port named "http".
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "metrics_aggregator") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: metrics-aggregator
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.metrics_aggregator.port }}
# Resolved by name against the pod's container ports.
targetPort: http
protocol: TCP
name: http
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "metrics_aggregator") | nindent 4 }}

View File

@ -0,0 +1 @@
{{/* Secret disabled — all config via ConfigMap */}}

View File

@ -0,0 +1,13 @@
{{- if .Values.serviceAccount.create -}}
# ServiceAccount, rendered only when .Values.serviceAccount.create is true.
# Its name comes from the "deploy.serviceAccountName" helper, the same helper the
# Deployments in this chart use for serviceAccountName.
apiVersion: v1
kind: ServiceAccount
metadata:
name: {{ include "deploy.serviceAccountName" . }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
{{- with .Values.serviceAccount.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
# Controls whether the account's API token is mounted into pods that use it.
automountServiceAccountToken: {{ .Values.serviceAccount.automount }}
{{- end }}

View File

@ -0,0 +1,78 @@
# Deployment for the static file server.
# Pod-wide settings (pull secrets, service account, security contexts, scheduling)
# come from chart-level values; image, port, replica count, resources and volume
# mounts come from .Values.services.static_server. Environment is loaded from the
# ConfigMap named by .Values.configMapName.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "static_server") }}
  labels:
    {{- include "deploy.labels" . | nindent 4 }}
    app.kubernetes.io/component: static-server
spec:
  replicas: {{ .Values.services.static_server.replicaCount | default 1 }}
  selector:
    matchLabels:
      {{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "static_server") | nindent 6 }}
  template:
    metadata:
      labels:
        {{- include "deploy.labels" . | nindent 8 }}
        app.kubernetes.io/component: static-server
    spec:
      {{- with .Values.imagePullSecrets }}
      imagePullSecrets:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      serviceAccountName: {{ include "deploy.serviceAccountName" . }}
      {{- with .Values.podSecurityContext }}
      securityContext:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      containers:
        # Container names must be DNS-1123 labels (lowercase alphanumerics and "-");
        # an underscore here makes the API server reject the whole Deployment.
        - name: static-server
          {{- with .Values.securityContext }}
          securityContext:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          image: "{{ .Values.imageRegistry }}/{{ .Values.services.static_server.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
          imagePullPolicy: IfNotPresent
          ports:
            # Named port; the probes below and the Service's targetPort refer to it by name.
            - name: http
              containerPort: {{ .Values.services.static_server.port }}
              protocol: TCP
          envFrom:
            - configMapRef:
                name: {{ .Values.configMapName }}
          livenessProbe:
            httpGet:
              path: /health
              port: http
            initialDelaySeconds: 10
            periodSeconds: 15
          readinessProbe:
            httpGet:
              path: /health
              port: http
            initialDelaySeconds: 5
            periodSeconds: 10
          {{- with .Values.services.static_server.resources }}
          resources:
            {{- toYaml . | nindent 12 }}
          {{- end }}
          {{- with .Values.services.static_server.volumeMounts }}
          volumeMounts:
            {{- toYaml . | nindent 12 }}
          {{- end }}
      # The shared-data volume is always declared, whether or not any mounts are
      # configured; it is backed by the PVC named in .Values.pvcName.
      volumes:
        - name: shared-data
          persistentVolumeClaim:
            claimName: {{ .Values.pvcName }}
      {{- with .Values.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.tolerations }}
      tolerations:
        {{- toYaml . | nindent 8 }}
      {{- end }}

View File

@ -0,0 +1,16 @@
# ClusterIP Service for the static file server.
# Forwards the configured service port to the container port named "http".
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "static_server") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: static-server
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.static_server.port }}
# Resolved by name against the pod's container ports.
targetPort: http
protocol: TCP
name: http
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "static_server") | nindent 4 }}

182
deploy/values.yaml Normal file
View File

@ -0,0 +1,182 @@
# Global image registry and tag
imageRegistry: ""
imageTag: ""
# External ConfigMap (managed outside Helm)
configMapName: "app-env"
# Service definitions
services:
app:
repository: app
port: 3000
replicaCount: 2
autoscaling:
enabled: true
minReplicas: 2
maxReplicas: 10
targetCPUUtilization: 80
command:
- "app"
- "--bind"
- "0.0.0.0:3000"
resources:
requests:
cpu: 200m
memory: 256Mi
limits:
cpu: "1"
memory: 512Mi
volumeMounts:
- name: shared-data
mountPath: /data/repos
subPath: repos
- name: shared-data
mountPath: /data/avatars
subPath: avatars
- name: shared-data
mountPath: /data/files
subPath: files
email_worker:
repository: email-worker
port: 8084
replicaCount: 1
autoscaling:
enabled: false # email must stay at 1 replica
resources:
requests:
cpu: 100m
memory: 128Mi
limits:
cpu: 500m
memory: 256Mi
git_hook:
repository: git-hook
port: 8083
replicaCount: 1
autoscaling:
enabled: true
minReplicas: 1
maxReplicas: 5
targetCPUUtilization: 80
resources:
requests:
cpu: 100m
memory: 128Mi
limits:
cpu: 500m
memory: 256Mi
volumeMounts:
- name: shared-data
mountPath: /data/repos
subPath: repos
gitserver:
repository: gitserver
ports:
http: 8021
ssh: 2222
replicaCount: 1
autoscaling:
enabled: true
minReplicas: 1
maxReplicas: 5
targetCPUUtilization: 80
# APP_SSH_PORT must equal services.gitserver.ports.ssh (the gitserver "ssh" containerPort)
extraEnv:
APP_SSH_PORT: "2222"
resources:
requests:
cpu: 100m
memory: 128Mi
limits:
cpu: 500m
memory: 256Mi
volumeMounts:
- name: shared-data
mountPath: /data/repos
subPath: repos
metrics_aggregator:
repository: metrics-aggregator
port: 9090
replicaCount: 1
autoscaling:
enabled: true
minReplicas: 1
maxReplicas: 5
targetCPUUtilization: 80
resources:
requests:
cpu: 100m
memory: 128Mi
limits:
cpu: 500m
memory: 256Mi
static_server:
repository: static-server
port: 8081
replicaCount: 1
autoscaling:
enabled: true
minReplicas: 1
maxReplicas: 5
targetCPUUtilization: 80
resources:
requests:
cpu: 50m
memory: 64Mi
limits:
cpu: 200m
memory: 128Mi
volumeMounts:
- name: shared-data
mountPath: /data
subPath: static
# External PVC (managed outside Helm — not deleted on uninstall)
pvcName: "shared-data"
# Ingress — only for the main app service
ingress:
enabled: false
className: ""
annotations: {}
hosts:
- host: chart-example.local
paths:
- path: /
pathType: Prefix
serviceName: app
servicePort: 3000
tls: []
# - secretName: chart-example-tls
# hosts:
# - chart-example.local
imagePullSecrets: []
nameOverride: ""
fullnameOverride: ""
serviceAccount:
create: true
automount: true
annotations: {}
name: ""
podSecurityContext:
runAsNonRoot: true
runAsUser: 1000
securityContext:
capabilities:
drop:
- ALL
readOnlyRootFilesystem: true
nodeSelector: {}
tolerations: []
affinity: {}

10
docker/app.Dockerfile Normal file
View File

@ -0,0 +1,10 @@
# Runtime image for the `app` binary.
# The binary is not built here: target/release/app must already exist in the build context.
FROM ubuntu:24.04
# Install runtime packages and remove the apt package lists in the same layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 openssh-client procps git \
&& rm -rf /var/lib/apt/lists/*
# Dedicated system account with a home directory; the process runs as this user (see USER).
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/app /bin
USER appuser
EXPOSE 3000
CMD ["app"]

10
docker/email.Dockerfile Normal file
View File

@ -0,0 +1,10 @@
# Runtime image for the `email-worker` binary.
# The binary is not built here: target/release/email-worker must already exist in the build context.
FROM ubuntu:24.04
# Install runtime packages and remove the apt package lists in the same layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 \
&& rm -rf /var/lib/apt/lists/*
# Dedicated system account with a home directory; the process runs as this user (see USER).
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/email-worker /bin
USER appuser
EXPOSE 8084
CMD ["email-worker"]

View File

@ -0,0 +1,10 @@
# Runtime image for the `gingress` binary.
# The binary is not built here: target/release/gingress must already exist in the build context.
FROM ubuntu:24.04
# Install runtime packages and remove the apt package lists in the same layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 \
&& rm -rf /var/lib/apt/lists/*
# Dedicated system account with a home directory; the process runs as this user (see USER).
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/gingress /bin
USER appuser
# NOTE(review): 80 and 443 are below 1024 and the process runs as the non-root appuser;
# confirm the target runtime lets an unprivileged process bind them.
EXPOSE 80 443 8080
CMD ["gingress"]

10
docker/githook.Dockerfile Normal file
View File

@ -0,0 +1,10 @@
# Runtime image for the `git-hook` binary.
# The binary is not built here: target/release/git-hook must already exist in the build context.
FROM ubuntu:24.04
# Install runtime packages (including the git CLI) and remove the apt package lists in the same layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 git \
&& rm -rf /var/lib/apt/lists/*
# Dedicated system account with a home directory; the process runs as this user (see USER).
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/git-hook /bin
USER appuser
EXPOSE 8083
CMD ["git-hook"]

View File

@ -0,0 +1,10 @@
# Runtime image for the `gitserver` binary.
# The binary is not built here: target/release/gitserver must already exist in the build context.
FROM ubuntu:24.04
# Install runtime packages (git CLI and OpenSSH client) and remove the apt package lists in the same layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 git openssh-client \
&& rm -rf /var/lib/apt/lists/*
# Dedicated system account with a home directory; the process runs as this user (see USER).
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/gitserver /bin
USER appuser
# Two ports are exposed; they match the chart's default gitserver http (8021) and ssh (2222) ports.
EXPOSE 8021 2222
CMD ["gitserver"]

10
docker/metrics.Dockerfile Normal file
View File

@ -0,0 +1,10 @@
# Runtime image for the `metrics-aggregator` binary.
# The binary is not built here: target/release/metrics-aggregator must already exist in the build context.
FROM ubuntu:24.04
# Install runtime packages and remove the apt package lists in the same layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 \
&& rm -rf /var/lib/apt/lists/*
# Dedicated system account with a home directory; the process runs as this user (see USER).
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/metrics-aggregator /bin
USER appuser
EXPOSE 9090
CMD ["metrics-aggregator"]

10
docker/static.Dockerfile Normal file
View File

@ -0,0 +1,10 @@
# Runtime image for the `static-server` binary.
# The binary is not built here: target/release/static-server must already exist in the build context.
FROM ubuntu:24.04
# Install runtime packages and remove the apt package lists in the same layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 \
&& rm -rf /var/lib/apt/lists/*
# Dedicated system account with a home directory; the process runs as this user (see USER).
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/static-server /bin
USER appuser
EXPOSE 8081
CMD ["static-server"]

View File

@ -42,5 +42,6 @@ reqwest = { workspace = true, features = ["json"] }
utoipa = { workspace = true } utoipa = { workspace = true }
tokio-stream = { workspace = true } tokio-stream = { workspace = true }
redis = { workspace = true, features = ["tokio-comp"] } redis = { workspace = true, features = ["tokio-comp"] }
queue = { workspace = true }
[lints] [lints]
workspace = true workspace = true

View File

@ -152,7 +152,8 @@ impl RigAgentService {
Ok(MultiTurnStreamItem::StreamAssistantItem( Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text), StreamedAssistantContent::Text(text),
)) => { )) => {
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await; let cleaned = text.text.replace('\n', "");
let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await;
final_content.push_str(&text.text); final_content.push_str(&text.text);
} }
Ok(MultiTurnStreamItem::StreamAssistantItem( Ok(MultiTurnStreamItem::StreamAssistantItem(
@ -237,7 +238,8 @@ impl RigAgentService {
Ok(MultiTurnStreamItem::StreamAssistantItem( Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text), StreamedAssistantContent::Text(text),
)) => { )) => {
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await; let cleaned = text.text.replace('\n', "");
let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await;
final_content.push_str(&text.text); final_content.push_str(&text.text);
} }
Ok(MultiTurnStreamItem::FinalResponse(resp)) => { Ok(MultiTurnStreamItem::FinalResponse(resp)) => {

View File

@ -1,226 +1,455 @@
//! AI usage billing — records token costs against a project or workspace balance. //! Billing service — handles user-level and project-level billing, deduction,
//! credit initialization, and error persistence.
//! //!
//! All functions take `&DatabaseConnection` instead of `&AppService`. //! Architecture:
//! - Each user gets $10 personal balance on signup.
//! - Each project gets $20 balance only if it's the creator's first project,
//! $0 otherwise.
//! - AI usage is deducted from the project balance first; if insufficient,
//! falls through to the user's personal balance.
//! - Monthly quota only applies to pro users (is_pro = true).
//! - If both project and user balance are insufficient, a billing_error
//! record is persisted and an error is returned to the caller.
use db::database::AppDatabase; use db::database::AppDatabase;
use models::agents::model_pricing; use models::agents::model_pricing;
use models::projects::project; use models::ai::billing_error;
use models::projects::project_billing; use models::projects::{project, project_billing, project_billing_history};
use models::projects::project_billing_history; use models::users::user_billing;
use models::workspaces::workspace_billing;
use models::workspaces::workspace_billing_history;
use rust_decimal::Decimal; use rust_decimal::Decimal;
use sea_orm::*; use sea_orm::*;
use uuid::Uuid; use uuid::Uuid;
use crate::error::AgentError; use crate::error::AgentError;
// ── Constants ──
/// Personal balance given to a newly initialized user billing account ($10.0000).
fn default_user_balance() -> Decimal { Decimal::new(100_000, 4) } // $10.0000
/// Balance credited to a project when it is its creator's first project ($20.0000).
fn first_project_credit() -> Decimal { Decimal::new(200_000, 4) } // $20.0000
/// Starting balance for any project that is not its creator's first.
const SUBSEQUENT_PROJECT_BALANCE: Decimal = Decimal::ZERO;
// ── Types ──
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, utoipa::ToSchema)] #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, utoipa::ToSchema)]
pub struct BillingRecord { pub struct BillingRecord {
pub cost: f64, pub cost: f64,
pub currency: String, pub currency: String,
pub input_tokens: i64, pub input_tokens: i64,
pub output_tokens: i64, pub output_tokens: i64,
pub deducted_from: String, // "project" or "user"
} }
/// Extended result that includes insufficient balance flag for system message creation.
#[derive(Debug)] #[derive(Debug)]
pub enum BillingResult { pub enum BillingResult {
Success(BillingRecord), Success(BillingRecord),
InsufficientBalance { message: String }, InsufficientBalance { message: String },
} }
/// Record AI usage for a project with cascading billing. // ── Core deduction: AI usage ──
/// Record AI usage: deduct from project balance first, fall through to user balance.
/// ///
/// Billing strategy: /// Returns `InsufficientBalance` if neither account can cover the cost.
/// 1. Try to deduct from project balance first /// On insufficient balance, a `billing_error` record is persisted for frontend display.
/// 2. If insufficient, fallback to workspace balance (if project belongs to workspace)
/// 3. If both insufficient or no workspace, return InsufficientBalance error with room_id
///
/// Returns BillingError::InsufficientBalance with room_id for system message creation.
pub async fn record_ai_usage( pub async fn record_ai_usage(
db: &AppDatabase, db: &AppDatabase,
project_uid: Uuid, project_uid: Uuid,
user_uid: Uuid,
model_id: Uuid, model_id: Uuid,
input_tokens: i64, input_tokens: i64,
output_tokens: i64, output_tokens: i64,
) -> Result<BillingResult, AgentError> { ) -> Result<BillingResult, AgentError> {
// 1. Look up the active price for this model. let total_cost = compute_cost(db, model_id, input_tokens, output_tokens).await?;
let currency = get_currency(db, model_id).await?;
// Verify project exists
let _ = project::Entity::find_by_id(project_uid)
.one(db)
.await?
.ok_or_else(|| AgentError::Internal("Project not found".into()))?;
// Attempt project-level deduction first
let project_result = deduct_from_project(db, project_uid, total_cost, &currency, model_id, input_tokens, output_tokens).await;
match project_result {
Ok(()) => {
let cost_f64 = decimal_to_f64(total_cost);
tracing::info!(
project_id = %project_uid,
model_id = %model_id,
input_tokens, output_tokens,
cost = %cost_f64,
currency = %currency,
deducted_from = "project",
"ai_usage_recorded"
);
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
deducted_from: "project".to_string(),
}))
}
Err(_) => {
// Project balance insufficient — try user personal balance
let user_result = deduct_from_user(db, user_uid, total_cost, &currency, project_uid, model_id, input_tokens, output_tokens).await;
match user_result {
Ok(()) => {
let cost_f64 = decimal_to_f64(total_cost);
tracing::info!(
user_id = %user_uid,
project_id = %project_uid,
model_id = %model_id,
input_tokens, output_tokens,
cost = %cost_f64,
currency = %currency,
deducted_from = "user",
"ai_usage_recorded"
);
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
deducted_from: "user".to_string(),
}))
}
Err(insufficient_msg) => {
// Both project and user balance insufficient — persist error
persist_billing_error(
db,
"project",
project_uid,
"insufficient_balance",
&insufficient_msg,
Some(serde_json::json!({
"user_id": user_uid.to_string(),
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"cost": decimal_to_f64(total_cost),
"currency": currency,
})),
).await?;
Ok(BillingResult::InsufficientBalance {
message: insufficient_msg,
})
}
}
}
}
}
/// Pre-flight check: can an AI call with the given token estimate be billed?
///
/// Called before starting AI processing to avoid wasted compute.
///
/// The check mirrors how `record_ai_usage` charges: the *whole* cost is taken
/// from the project balance, and only if that fails is the *whole* cost taken
/// from the user's personal balance. A charge is never split across the two
/// accounts, so summing both balances would approve calls that cannot be
/// billed afterwards. The call is affordable only when at least one balance
/// covers the full estimate on its own.
///
/// NOTE(review): `deduct_from_user` is assumed to reject a charge larger than
/// the user balance, as `deduct_from_project` does for projects — confirm.
///
/// # Returns
/// `Ok(true)` when the project balance or the user balance alone covers the
/// estimated cost, `Ok(false)` otherwise. A billing account that is missing or
/// cannot be read is treated as a zero balance by the balance helpers.
///
/// # Errors
/// Propagates the error from `compute_cost` (for example, no pricing record
/// exists for `model_id`).
pub async fn check_balance(
    db: &AppDatabase,
    project_uid: Uuid,
    user_uid: Uuid,
    model_id: Uuid,
    estimated_input_tokens: i64,
    estimated_output_tokens: i64,
) -> Result<bool, AgentError> {
    let estimated_cost =
        compute_cost(db, model_id, estimated_input_tokens, estimated_output_tokens).await?;

    // The project is charged first, so check it first; this also skips the
    // user lookup whenever the project can pay on its own.
    if get_project_balance(db, project_uid).await >= estimated_cost {
        return Ok(true);
    }

    Ok(get_user_balance(db, user_uid).await >= estimated_cost)
}
// ── Initialization ──
/// Create the billing account for `user_uid`, seeded with the default personal balance.
///
/// Intended to run once per user (signup / first login). The account starts as
/// non-pro, in USD, with no monthly quota, no usage and no billing cycle.
///
/// # Errors
/// Returns `AgentError::Internal` when the insert fails.
pub async fn initialize_user_billing(db: &AppDatabase, user_uid: Uuid) -> Result<(), AgentError> {
    let timestamp = chrono::Utc::now();

    let account = user_billing::ActiveModel {
        user: Set(user_uid),
        balance: Set(default_user_balance()),
        currency: Set("USD".to_string()),
        is_pro: Set(false),
        monthly_quota: Set(Decimal::ZERO),
        month_used: Set(Decimal::ZERO),
        cycle_start: Set(None),
        cycle_end: Set(None),
        updated_at: Set(timestamp),
        created_at: Set(timestamp),
    };

    if let Err(err) = account.insert(db).await {
        return Err(AgentError::Internal(format!(
            "failed to create user billing: {}",
            err
        )));
    }

    tracing::info!(user_id = %user_uid, balance = "$10", "user_billing_initialized");
    Ok(())
}
/// Create the billing account for a project.
///
/// The project receives the first-project credit only when `creator_uid` has
/// no projects recorded at the time of the call; otherwise it starts at zero.
/// When the credit is granted, a matching `first_project_credit` entry is
/// written to the project's billing history.
///
/// The billing row and the history row are written in a single transaction,
/// so a credited balance can never exist without its history record.
///
/// NOTE(review): the first-project test counts rows in `project` created by
/// `creator_uid`. If the project row for `project_uid` is inserted before this
/// function runs, the count is already at least 1 and the credit is never
/// granted — confirm the call order at the call site.
///
/// # Errors
/// Returns `AgentError::Internal` when counting projects, opening or
/// committing the transaction, or either insert fails. On failure nothing is
/// persisted.
pub async fn initialize_project_billing(
    db: &AppDatabase,
    project_uid: Uuid,
    creator_uid: Uuid,
) -> Result<(), AgentError> {
    // Check how many projects this user has already created.
    let existing_count = project::Entity::find()
        .filter(project::Column::CreatedBy.eq(creator_uid))
        .count(db)
        .await
        .map_err(|e| AgentError::Internal(format!("failed to count user projects: {}", e)))?;

    let is_first = existing_count == 0;
    let initial_balance = if is_first {
        first_project_credit()
    } else {
        SUBSEQUENT_PROJECT_BALANCE
    };

    let now = chrono::Utc::now();

    // Both writes go through one transaction. An early return via `?` drops
    // `txn` uncommitted, which rolls back anything written so far.
    let txn = db
        .begin()
        .await
        .map_err(|e| AgentError::Internal(format!("failed to begin billing transaction: {}", e)))?;

    project_billing::ActiveModel {
        project: Set(project_uid),
        balance: Set(initial_balance),
        currency: Set("USD".to_string()),
        user: Set(Some(creator_uid)),
        initial_credit_granted: Set(is_first),
        is_pro: Set(false),
        monthly_quota: Set(Decimal::ZERO),
        month_used: Set(Decimal::ZERO),
        cycle_start: Set(None),
        cycle_end: Set(None),
        updated_at: Set(now),
        created_at: Set(now),
    }
    .insert(&txn)
    .await
    .map_err(|e| AgentError::Internal(format!("failed to create project billing: {}", e)))?;

    if is_first {
        // Record the credit in billing history.
        project_billing_history::ActiveModel {
            uid: Set(Uuid::new_v4()),
            project: Set(project_uid),
            user: Set(Some(creator_uid)),
            amount: Set(first_project_credit()),
            currency: Set("USD".to_string()),
            reason: Set("first_project_credit".to_string()),
            extra: Set(Some(serde_json::json!({
                "is_first_project": true,
            }))),
            created_at: Set(now),
            ..Default::default()
        }
        .insert(&txn)
        .await
        .map_err(|e| AgentError::Internal(format!("failed to record credit history: {}", e)))?;
    }

    txn.commit()
        .await
        .map_err(|e| AgentError::Internal(format!("failed to commit project billing: {}", e)))?;

    tracing::info!(
        project_id = %project_uid,
        creator_id = %creator_uid,
        is_first_project = is_first,
        balance = if is_first { "$20" } else { "$0" },
        "project_billing_initialized"
    );
    Ok(())
}
// ── Internal helpers ──
async fn compute_cost(
db: &AppDatabase,
model_id: Uuid,
input_tokens: i64,
output_tokens: i64,
) -> Result<Decimal, AgentError> {
let pricing = model_pricing::Entity::find() let pricing = model_pricing::Entity::find()
.filter(model_pricing::Column::ModelVersionId.eq(model_id)) .filter(model_pricing::Column::ModelVersionId.eq(model_id))
.order_by_desc(model_pricing::Column::EffectiveFrom) .order_by_desc(model_pricing::Column::EffectiveFrom)
.one(db) .one(db)
.await? .await?
.ok_or_else(|| { .ok_or_else(|| AgentError::Internal(
AgentError::Internal( "No pricing record found for this model. Please configure AI model pricing first.".into(),
"No pricing record found for this model. Please configure AI model pricing first." ))?;
.into(),
) let input_price: Decimal = pricing.input_price_per_1k_tokens.parse()
})?; .map_err(|e| AgentError::Internal(format!("Invalid input price: {}", e)))?;
let output_price: Decimal = pricing.output_price_per_1k_tokens.parse()
.map_err(|e| AgentError::Internal(format!("Invalid output price: {}", e)))?;
// 2. Compute cost using Decimal arithmetic.
let input_price: Decimal = pricing
.input_price_per_1k_tokens
.parse()
.map_err(|e| AgentError::Internal(format!("Invalid input price format: {}", e)))?;
let output_price: Decimal = pricing
.output_price_per_1k_tokens
.parse()
.map_err(|e| AgentError::Internal(format!("Invalid output price format: {}", e)))?;
let tokens_i = Decimal::from(input_tokens);
let tokens_o = Decimal::from(output_tokens);
let thousand = Decimal::from(1000); let thousand = Decimal::from(1000);
Ok((Decimal::from(input_tokens) / thousand) * input_price
+ (Decimal::from(output_tokens) / thousand) * output_price)
}
let total_cost = (tokens_i / thousand) * input_price async fn get_currency(db: &AppDatabase, model_id: Uuid) -> Result<String, AgentError> {
+ (tokens_o / thousand) * output_price; let pricing = model_pricing::Entity::find()
.filter(model_pricing::Column::ModelVersionId.eq(model_id))
let currency = pricing.currency.clone();
// 3. Cascading billing: project balance first, then workspace if insufficient.
let proj = project::Entity::find_by_id(project_uid)
.one(db) .one(db)
.await? .await?
.ok_or_else(|| AgentError::Internal("Project not found".into()))?; .ok_or_else(|| AgentError::Internal("No pricing found".into()))?;
Ok(pricing.currency.clone())
}
let txn = db.begin().await?; async fn get_project_balance(db: &AppDatabase, project_uid: Uuid) -> Decimal {
project_billing::Entity::find_by_id(project_uid)
.one(db)
.await
.ok()
.flatten()
.map(|b| b.balance)
.unwrap_or(Decimal::ZERO)
}
// Always check project balance first async fn get_user_balance(db: &AppDatabase, user_uid: Uuid) -> Decimal {
let project_billing = project_billing::Entity::find_by_id(project_uid) user_billing::Entity::find_by_id(user_uid)
.one(db)
.await
.ok()
.flatten()
.map(|b| b.balance)
.unwrap_or(Decimal::ZERO)
}
async fn deduct_from_project(
db: &AppDatabase,
project_uid: Uuid,
cost: Decimal,
currency: &str,
model_id: Uuid,
input_tokens: i64,
output_tokens: i64,
) -> Result<(), String> {
let txn = db.begin().await.map_err(|e| format!("db txn error: {}", e))?;
let billing = project_billing::Entity::find_by_id(project_uid)
.lock_exclusive() .lock_exclusive()
.one(&txn) .one(&txn)
.await? .await
.ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?; .map_err(|e| format!("db error: {}", e))?
.ok_or_else(|| "Project billing account not found".to_string())?;
if billing.balance < cost {
txn.rollback().await.ok();
return Err(format!(
"Project balance insufficient. Required: {:.4} {}, Available: {:.4} {}",
cost, currency, billing.balance, currency
));
}
let now = chrono::Utc::now(); let now = chrono::Utc::now();
if project_billing.balance >= total_cost { project_billing_history::ActiveModel {
// ── Project has sufficient balance ────────────────────────── uid: Set(Uuid::new_v4()),
let amount_dec = -total_cost; project: Set(project_uid),
user: Set(None),
project_billing_history::ActiveModel { amount: Set(-cost),
uid: Set(Uuid::new_v4()), currency: Set(currency.to_string()),
project: Set(project_uid), reason: Set("ai_usage".to_string()),
user: Set(None), extra: Set(Some(serde_json::json!({
amount: Set(amount_dec), "model_id": model_id.to_string(),
currency: Set(currency.clone()), "input_tokens": input_tokens,
reason: Set("ai_usage".to_string()), "output_tokens": output_tokens,
extra: Set(Some(serde_json::json!({ "deducted_from": "project",
"model_id": model_id.to_string(), }))),
"input_tokens": input_tokens, created_at: Set(now),
"output_tokens": output_tokens, ..Default::default()
}))),
created_at: Set(now),
..Default::default()
}
.insert(&txn)
.await?;
let new_balance = project_billing.balance - total_cost;
let mut updated: project_billing::ActiveModel = project_billing.into();
updated.balance = Set(new_balance);
updated.updated_at = Set(now);
updated.update(&txn).await?;
txn.commit().await?;
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
tracing::info!(
project_id = %project_uid,
model_id = %model_id,
input_tokens = input_tokens,
output_tokens = output_tokens,
cost = %cost_f64,
currency = %currency,
source = "project",
"ai_usage_recorded"
);
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
}))
} else if let Some(workspace_id) = proj.workspace_id {
// ── Project insufficient, fallback to workspace ─────────────
let workspace_billing = workspace_billing::Entity::find_by_id(workspace_id)
.lock_exclusive()
.one(&txn)
.await?
.ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?;
if workspace_billing.balance < total_cost {
txn.rollback().await?;
return Ok(BillingResult::InsufficientBalance {
message: format!(
"Insufficient balance. Project: {:.4} {}, Workspace: {:.4} {}, Required: {:.4} {}",
project_billing.balance, currency,
workspace_billing.balance, currency,
total_cost, currency
),
});
}
let amount_dec = -total_cost;
workspace_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
workspace_id: Set(workspace_id),
user_id: Set(Some(proj.created_by)),
amount: Set(amount_dec),
currency: Set(currency.clone()),
reason: Set(format!("ai_usage:{}", project_uid)),
extra: Set(Some(serde_json::json!({
"project_id": project_uid.to_string(),
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"fallback_reason": "project_balance_insufficient"
}))),
created_at: Set(now),
}
.insert(&txn)
.await?;
let new_balance = workspace_billing.balance - total_cost;
let new_total_spent = workspace_billing.total_spent + total_cost;
let mut updated: workspace_billing::ActiveModel = workspace_billing.into();
updated.balance = Set(new_balance);
updated.total_spent = Set(new_total_spent);
updated.updated_at = Set(now);
updated.update(&txn).await?;
txn.commit().await?;
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
tracing::info!(
project_id = %project_uid,
model_id = %model_id,
input_tokens = input_tokens,
output_tokens = output_tokens,
cost = %cost_f64,
currency = %currency,
workspace_id = %workspace_id.to_string(),
source = "workspace_fallback",
"ai_usage_recorded"
);
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
}))
} else {
// ── Project insufficient and no workspace ───────────────────
txn.rollback().await?;
Ok(BillingResult::InsufficientBalance {
message: format!(
"Insufficient balance. Required: {:.4} {}, Available: {:.4} {}",
total_cost, currency, project_billing.balance, currency
),
})
} }
.insert(&txn)
.await
.map_err(|e| format!("failed to insert history: {}", e))?;
let mut updated: project_billing::ActiveModel = billing.into();
updated.balance = Set(updated.balance.unwrap() - cost);
updated.updated_at = Set(now);
updated.update(&txn).await.map_err(|e| format!("failed to update balance: {}", e))?;
txn.commit().await.map_err(|e| format!("commit error: {}", e))?;
Ok(())
}
/// Charges `cost` against a user's billing account as the fallback payer for
/// AI usage that belongs to `project_uid`.
///
/// Everything runs in one transaction: the user's billing row is fetched with
/// an exclusive lock, the balance is checked, a history row is inserted and the
/// balance is lowered. The history row goes into `project_billing_history`
/// (tagged `"deducted_from": "user"`) even though the money comes from
/// `user_billing`.
///
/// # Errors
/// Returns a human-readable message when the transaction cannot be opened or
/// committed, a query fails, the user has no billing account, or the balance
/// is lower than `cost`.
async fn deduct_from_user(
    db: &AppDatabase,
    user_uid: Uuid,
    cost: Decimal,
    currency: &str,
    project_uid: Uuid,
    model_id: Uuid,
    input_tokens: i64,
    output_tokens: i64,
) -> Result<(), String> {
    let tx = db
        .begin()
        .await
        .map_err(|e| format!("db txn error: {}", e))?;

    // Exclusive row lock so the balance cannot change between check and update.
    let lookup = user_billing::Entity::find_by_id(user_uid)
        .lock_exclusive()
        .one(&tx)
        .await
        .map_err(|e| format!("db error: {}", e))?;
    let Some(account) = lookup else {
        return Err("User billing account not found".to_string());
    };

    if account.balance < cost {
        // Best effort: a rollback failure must not mask the balance error.
        tx.rollback().await.ok();
        return Err(format!(
            "Insufficient balance (project + user). Project: unavailable, User: {:.4} {}. Required: {:.4} {}",
            account.balance, currency, cost, currency
        ));
    }

    let recorded_at = chrono::Utc::now();
    let remaining = account.balance - cost;

    // Record in project billing history (but deducted from user)
    let history_entry = project_billing_history::ActiveModel {
        uid: Set(Uuid::new_v4()),
        project: Set(project_uid),
        user: Set(Some(user_uid)),
        amount: Set(-cost),
        currency: Set(currency.to_string()),
        reason: Set("ai_usage_user_fallback".to_string()),
        extra: Set(Some(serde_json::json!({
            "model_id": model_id.to_string(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "deducted_from": "user",
        }))),
        created_at: Set(recorded_at),
        ..Default::default()
    };
    history_entry
        .insert(&tx)
        .await
        .map_err(|e| format!("failed to insert history: {}", e))?;

    let mut patch: user_billing::ActiveModel = account.into();
    patch.balance = Set(remaining);
    patch.updated_at = Set(recorded_at);
    patch
        .update(&tx)
        .await
        .map_err(|e| format!("failed to update user balance: {}", e))?;

    tx.commit()
        .await
        .map_err(|e| format!("commit error: {}", e))?;
    Ok(())
}
/// Stores a billing failure as an unresolved `billing_error` row and emits a
/// warning-level trace event.
///
/// `scope` and `scope_id` identify the entity the error is attached to,
/// `error_type` and `message` describe the failure, and `details` carries an
/// optional JSON payload.
///
/// # Errors
/// Returns [`AgentError::Internal`] when the row cannot be inserted.
pub async fn persist_billing_error(
    db: &AppDatabase,
    scope: &str,
    scope_id: Uuid,
    error_type: &str,
    message: &str,
    details: Option<serde_json::Value>,
) -> Result<(), AgentError> {
    let record = billing_error::ActiveModel {
        id: Set(Uuid::new_v4()),
        scope: Set(scope.to_string()),
        scope_id: Set(scope_id),
        error_type: Set(error_type.to_string()),
        message: Set(message.to_string()),
        details: Set(details),
        // New entries always start out unresolved.
        resolved: Set(false),
        created_at: Set(chrono::Utc::now()),
    };

    if let Err(e) = record.insert(db).await {
        return Err(AgentError::Internal(format!(
            "failed to persist billing error: {}",
            e
        )));
    }

    tracing::warn!(scope, %scope_id, error_type, "billing_error_persisted");
    Ok(())
}
fn decimal_to_f64(d: Decimal) -> f64 {
d.round_dp(10).to_string().parse().unwrap_or(0.0)
} }

View File

@ -0,0 +1,357 @@
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream};
use crate::error::Result;
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolDefinition, ToolExecutor, ToolHandler, ToolParam};
use crate::tool::registry::ToolRegistry;
use crate::embed::EmbedService;
use sea_orm::{ActiveModelTrait, EntityTrait, Set};
use super::{AiChunkType, AiStreamChunk, StreamCallback};
use super::service::StreamResult;
// Keyword-extraction-based title generator: reads conversation messages, extracts
// meaningful words, and updates the conversation record with a short title.
async fn generate_title_for_conversation(
ctx: &ToolContext,
conversation_id: Uuid,
) -> Result<serde_json::Value> {
use models::ai::{ai_conversation, ai_message, AiMessage};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
let db_reader = ctx.db().reader();
let db_writer = ctx.db().writer();
let conv = ai_conversation::Entity::find_by_id(conversation_id)
.one(db_reader)
.await
.map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("db error: {}", e) })?
.ok_or_else(|| crate::error::AgentError::NotFound("Conversation not found".into()))?;
let recent_messages = AiMessage::find()
.filter(ai_message::Column::ConversationId.eq(conversation_id))
.filter(ai_message::Column::Role.eq("user"))
.order_by_desc(ai_message::Column::CreatedAt)
.limit(3)
.all(db_reader)
.await
.map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("db error: {}", e) })?;
if recent_messages.is_empty() {
return Err(crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: "No user messages found".into() });
}
let content = recent_messages
.first()
.and_then(|m| m.content.as_array())
.and_then(|arr| arr.first())
.and_then(|v| v.get("content"))
.and_then(|c| c.as_str())
.unwrap_or("");
let words: Vec<&str> = content
.split_whitespace()
.filter(|w| w.len() > 2 && !is_stop_word(w))
.take(5)
.collect();
let title = if words.is_empty() {
"New Chat".to_string()
} else {
words.join(" ")
};
let mut active: ai_conversation::ActiveModel = conv.into();
active.title = Set(Some(title.clone()));
active.updated_at = Set(chrono::Utc::now());
active
.update(db_writer)
.await
.map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("failed to update title: {}", e) })?;
Ok(serde_json::json!({ "conversation_id": conversation_id.to_string(), "title": title }))
}
/// Reports whether `w` is one of the filler words that should not appear in a
/// generated conversation title. The comparison ignores letter case.
fn is_stop_word(w: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "this", "that", "what", "which", "when", "where",
        "why", "how", "can", "could", "would", "should",
        "please", "help", "thanks", "thank", "you", "your",
        "have", "has", "had", "with", "for", "from", "into",
        "about", "also", "just", "now", "very", "really",
    ];

    let lowered = w.to_lowercase();
    STOP_WORDS.contains(&lowered.as_str())
}
/// Reference-counted form of the boxed `StreamCallback`. `execute_chat_stream`
/// converts the caller's callback into this alias so the same callback can be
/// cloned into the several closures handed to `call_stream`.
type SharedCallback = Arc<dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
/// Simplified ReAct execution for Chat API.
///
/// Unlike `execute_process_stream` (which requires `AiRequest` with room-specific data),
/// this function takes messages and tools directly. It does NOT record AI sessions to
/// the `ai_session` table — the caller is responsible for persisting results.
///
/// The function loops: it streams one model response, and if that response
/// contains tool calls it executes them, appends the tool messages and asks the
/// model again, until the model answers without tool calls or `max_tool_depth`
/// rounds have been used.
///
/// Title generation: when tools are enabled, a `tool_registry` is supplied and
/// the conversation identified by `conversation_id` has no title, a
/// `chat_generate_title` tool is registered on a private copy of the registry.
/// That copy both replaces the advertised tool list and is used to execute tool
/// calls, so the injected tool is actually callable. Nothing is injected when
/// the caller passed no tools, because tool calling is then disabled for the
/// whole request.
///
/// Streamed deltas are forwarded through `on_chunk`; the returned
/// [`StreamResult`] always carries the token usage accumulated over all rounds,
/// including when the tool-depth limit ends the loop.
///
/// # Errors
/// Propagates any error returned by `call_stream`.
pub async fn execute_chat_stream(
    messages: Vec<ChatRequestMessage>,
    tools: Vec<serde_json::Value>,
    model_name: &str,
    config: &AiClientConfig,
    temperature: f32,
    max_tokens: u32,
    max_tool_depth: usize,
    tool_registry: Option<&ToolRegistry>,
    db: db::database::AppDatabase,
    cache: db::cache::AppCache,
    app_config: config::AppConfig,
    project_id: Uuid,
    sender_uid: Uuid,
    embed_service: Option<EmbedService>,
    on_chunk: StreamCallback,
    conversation_id: Option<uuid::Uuid>,
) -> Result<StreamResult> {
    let on_chunk: SharedCallback = Arc::from(on_chunk);
    let tools_enabled = !tools.is_empty();
    let mut messages = messages;
    let mut tool_depth = 0;
    let mut total_input_tokens = 0i64;
    let mut total_output_tokens = 0i64;
    let mut full_content = String::new();
    let mut all_chunks: Vec<StreamChunk> = Vec::new();
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();

    // Conditionally inject chat_generate_title tool if conversation has no title.
    // Skipped when tools are disabled: the model would be told to call a tool
    // that is never offered to it.
    let mut injected: Option<(ToolRegistry, Vec<serde_json::Value>)> = None;
    if let (true, Some(conv_id), Some(registry)) = (tools_enabled, conversation_id, tool_registry) {
        let db_reader = db.reader();
        // A lookup error or a missing row counts as "no title".
        let has_title = models::ai::ai_conversation::Entity::find_by_id(conv_id)
            .one(db_reader)
            .await
            .map(|c| c.map(|m| m.title.is_some()).unwrap_or(false))
            .unwrap_or(false);
        if !has_title {
            let mut reg = registry.clone();
            reg.register(
                ToolDefinition::new("chat_generate_title")
                    .description(
                        "Generate a concise title (5 words or fewer) for the current conversation \
                         based on its message history, and save it to the conversation record. \
                         Call this tool at the start of a new conversation if it has no title.",
                    )
                    .parameters(crate::tool::ToolSchema {
                        schema_type: "object".into(),
                        properties: Some({
                            let mut p = std::collections::HashMap::new();
                            p.insert("conversation_id".into(), ToolParam {
                                name: "conversation_id".into(),
                                param_type: "string".into(),
                                description: Some("The UUID of the conversation (required).".into()),
                                required: true,
                                properties: None,
                                items: None,
                            });
                            p
                        }),
                        required: Some(vec!["conversation_id".into()]),
                    }),
                ToolHandler::new(|ctx, args| {
                    let conv_id = args.get("conversation_id")
                        .and_then(|v| v.as_str())
                        .and_then(|s| Uuid::parse_str(s).ok());
                    Box::pin(async move {
                        match conv_id {
                            Some(id) => generate_title_for_conversation(&ctx, id).await
                                .map_err(|e| crate::tool::ToolError::ExecutionError(e.to_string())),
                            None => Err(crate::tool::ToolError::ExecutionError("conversation_id missing".into())),
                        }
                    })
                }),
            );
            // Prepend system message instructing the model to generate title first
            messages.insert(0, ChatRequestMessage::system(
                "IMPORTANT: If the conversation has no title, you MUST call chat_generate_title \
                 with the conversation_id immediately before answering any user question. \
                 The title must be 5 words or fewer and should summarize the user's intent.".to_string(),
            ));
            let advertised = reg.to_openai_tools();
            injected = Some((reg, advertised));
        }
    }
    let (augmented_registry, tools) = match injected {
        Some((reg, advertised)) => (Some(reg), advertised),
        None => (None, tools),
    };
    // Tool calls must run against the registry that owns the injected handler;
    // the caller's registry does not know about chat_generate_title.
    let exec_registry: Option<&ToolRegistry> = augmented_registry.as_ref().or(tool_registry);

    loop {
        let answer_cb = on_chunk.clone();
        let thinking_cb = on_chunk.clone();
        let tool_tx = tx.clone();
        let response = call_stream(
            &messages, model_name, config, temperature, max_tokens,
            if tools_enabled { Some(&tools) } else { None }, None,
            Arc::new(move |delta| {
                let content = delta.to_string().replace('\n', "");
                answer_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer })
            }),
            Arc::new(move |delta| {
                thinking_cb(AiStreamChunk { content: delta.to_string(), done: false, chunk_type: AiChunkType::Thinking })
            }),
            Arc::new(move |tc: &StreamedToolCall| {
                let tx = tool_tx.clone();
                let tc_owned = tc.clone();
                Box::pin(async move { let _ = tx.send(tc_owned); }) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
            }),
        ).await?;

        total_input_tokens += response.input_tokens;
        total_output_tokens += response.output_tokens;
        all_chunks.extend(response.chunks.clone());

        let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
        if !has_tool_calls {
            let final_content = response.content.clone();
            // Don't broadcast the done chunk via SSE/NATS — incremental deltas
            // already delivered the content; the separate "done" SSE event
            // signals completion. Pushing full content again would duplicate it
            // in the frontend streaming store.
            all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: final_content.clone() });
            return Ok(StreamResult {
                content: final_content,
                reasoning_content: response.reasoning_content,
                input_tokens: total_input_tokens,
                output_tokens: total_output_tokens,
                chunks: all_chunks,
            });
        }

        full_content.push_str(&response.content);
        let tool_calls: Vec<ToolCall> = response.tool_calls.iter().map(|tc| ToolCall {
            id: tc.id.clone(), type_: "function".into(),
            function: crate::client::types::ToolCallFunction { name: tc.name.clone(), arguments: tc.arguments.clone() },
        }).collect();
        messages.push(ChatRequestMessage::assistant(Some(response.content.clone()), Some(tool_calls.clone())));

        // Drain tool call notifications (stops on an empty or closed channel).
        while let Ok(tc) = rx.try_recv() {
            let args_display = if tc.arguments.len() > 100 {
                // Cut at the last char boundary at or below byte 100 so the slice is valid UTF-8.
                let end = tc.arguments.char_indices().map(|(i, _)| i).take_while(|&i| i <= 100).last().unwrap_or(100);
                format!("{}...", &tc.arguments[..end])
            } else {
                tc.arguments.clone()
            };
            let tool_display = format!("🔧 {}({})", tc.name, args_display);
            on_chunk(AiStreamChunk { content: tool_display.clone(), done: false, chunk_type: AiChunkType::ToolCall }).await;
            all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: tool_display });
        }

        let calls: Vec<AgentToolCall> = response.tool_calls.iter().map(|tc| AgentToolCall {
            id: tc.id.clone(), name: tc.name.clone(), arguments: tc.arguments.clone(),
        }).collect();

        let tool_messages = execute_tools(
            &calls, &db, &cache, &app_config, project_id, sender_uid,
            exec_registry, embed_service.as_ref(), &on_chunk, &mut all_chunks,
        ).await;
        messages.extend(tool_messages);

        tool_depth += 1;
        if tool_depth >= max_tool_depth {
            let max_depth_text = format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth);
            on_chunk(AiStreamChunk { content: max_depth_text.clone(), done: true, chunk_type: AiChunkType::Answer }).await;
            all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: max_depth_text });
            // The model calls made so far consumed tokens; report them so the
            // caller can account for the usage.
            return Ok(StreamResult {
                content: full_content,
                reasoning_content: String::new(),
                input_tokens: total_input_tokens,
                output_tokens: total_output_tokens,
                chunks: all_chunks,
            });
        }
    }
}
async fn execute_tools(
calls: &[AgentToolCall],
db: &db::database::AppDatabase,
cache: &db::cache::AppCache,
app_config: &config::AppConfig,
project_id: Uuid,
sender_uid: Uuid,
tool_registry: Option<&ToolRegistry>,
embed_service: Option<&EmbedService>,
on_chunk: &SharedCallback,
all_chunks: &mut Vec<StreamChunk>,
) -> Vec<ChatRequestMessage> {
let mut tool_messages = Vec::new();
let mut ctx = ToolContext::new(db.clone(), cache.clone(), app_config.clone(), Uuid::nil(), Some(sender_uid))
.with_project(project_id);
if let Some(es) = embed_service {
ctx = ctx.with_embed_service(es.clone());
}
if let Some(registry) = tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let mut join_set = tokio::task::JoinSet::new();
for call in calls {
let call_clone = call.clone();
let mut ctx_clone = ctx.clone();
join_set.spawn(async move {
let executor = ToolExecutor::new();
let res = executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone).await;
(call_clone, res)
});
}
let heartbeat_dur = std::time::Duration::from_secs(10);
while !join_set.is_empty() {
tokio::select! {
Some(res) = join_set.join_next() => {
if let Ok((call, results)) = res {
match results {
Ok(results) => {
for result in &results {
let preview = match &result.result {
crate::tool::ToolResult::Ok(v) => {
let t = v.to_string();
if t.len() > 300 {
let end = t.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
format!("{}...", &t[..end])
} else { t.clone() }
}
crate::tool::ToolResult::Error(msg) => msg.clone(),
};
tracing::debug!("tool_result: {} — {}", call.name, preview);
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
Err(e) => {
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_text = format!("[Tool call failed: {}]", e);
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
}
}
}
},
_ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await;
}
}
}
tool_messages
}

View File

@ -4,7 +4,7 @@ use sea_orm::*;
use super::context::RoomMessageContext; use super::context::RoomMessageContext;
use super::{AiRequest, Mention}; use super::{AiRequest, Mention};
use crate::client::types::ChatRequestMessage; use crate::client::types::ChatRequestMessage;
use crate::compact::{CompactConfig, CompactService}; use crate::compact::CompactService;
use crate::embed::EmbedService; use crate::embed::EmbedService;
use crate::error::Result; use crate::error::Result;
use crate::perception::{PerceptionService, SkillEntry}; use crate::perception::{PerceptionService, SkillEntry};
@ -55,7 +55,6 @@ impl MessageBuilder {
let mut processed_history = Vec::new(); let mut processed_history = Vec::new();
if let Some(compact_service) = &self.compact_service { if let Some(compact_service) = &self.compact_service {
let compact_cache_key = format!("ai:compact:{}", request.room.id); let compact_cache_key = format!("ai:compact:{}", request.room.id);
let compact_config = CompactConfig::default();
let cached_summary: Option<String> = match request.cache.conn().await { let cached_summary: Option<String> = match request.cache.conn().await {
Ok(mut conn) => redis::cmd("GET").arg(&compact_cache_key).query_async::<Option<String>>(&mut conn).await.unwrap_or(None), Ok(mut conn) => redis::cmd("GET").arg(&compact_cache_key).query_async::<Option<String>>(&mut conn).await.unwrap_or(None),
Err(e) => { tracing::warn!(error = %e, "compact cache: conn failed"); None } Err(e) => { tracing::warn!(error = %e, "compact cache: conn failed"); None }
@ -71,7 +70,22 @@ impl MessageBuilder {
} }
if processed_history.is_empty() { if processed_history.is_empty() {
match compact_service.compact_room_auto(request.room.id, Some(request.user_names.clone()), compact_config).await { let compact_config = request.context_setting.as_ref()
.map(|s| crate::compact::CompactConfig::from_project_setting(
s.context_window_tokens,
s.compaction_threshold,
s.compaction_max_summary_ratio,
))
.unwrap_or_default();
match compact_service.compact_room(
request.room.id,
compact_config.default_level,
Some(request.user_names.clone()),
request.sender.uid,
request.context_setting.as_ref().map(|s| s.context_window_tokens).unwrap_or(128000),
request.context_setting.as_ref().map(|s| s.compaction_max_summary_ratio).unwrap_or(0.2),
).await {
Ok(compact_summary) => { Ok(compact_summary) => {
if !compact_summary.summary.is_empty() { if !compact_summary.summary.is_empty() {
messages.push(ChatRequestMessage::system(format!("Conversation summary:\n{}", compact_summary.summary))); messages.push(ChatRequestMessage::system(format!("Conversation summary:\n{}", compact_summary.summary)));
@ -174,7 +188,13 @@ impl MessageBuilder {
let keyword_skills = self.perception_service.inject_skills(&request.input, &history_texts, &[], &all_skills).await; let keyword_skills = self.perception_service.inject_skills(&request.input, &history_texts, &[], &all_skills).await;
let mut vector_skills = Vec::new(); let mut vector_skills = Vec::new();
if let Some(es) = &self.embed_service { if let Some(es) = &self.embed_service {
vector_skills = crate::perception::VectorActiveAwareness::default().detect(es, &request.input, &request.project.id.to_string()).await; let rag_enabled = request.context_setting.as_ref().map(|s| s.rag_enabled).unwrap_or(true);
if rag_enabled {
let max_results = request.context_setting.as_ref().map(|s| s.rag_max_results as usize).unwrap_or(3);
let min_score = request.context_setting.as_ref().map(|s| s.rag_min_score).unwrap_or(0.70);
let awareness = crate::perception::VectorActiveAwareness::new(max_results, min_score);
vector_skills = awareness.detect(es, &request.input, &request.project.id.to_string()).await;
}
} }
let mut seen = std::collections::HashSet::new(); let mut seen = std::collections::HashSet::new();
let mut result = Vec::new(); let mut result = Vec::new();
@ -184,8 +204,17 @@ impl MessageBuilder {
} }
async fn build_memory_context(&self, request: &AiRequest) -> Vec<crate::perception::vector::MemoryContext> { async fn build_memory_context(&self, request: &AiRequest) -> Vec<crate::perception::vector::MemoryContext> {
let rag_enabled = request.context_setting.as_ref().map(|s| s.rag_enabled).unwrap_or(true);
if !rag_enabled {
return Vec::new();
}
match &self.embed_service { match &self.embed_service {
Some(es) => crate::perception::VectorPassiveAwareness::default().detect(es, &request.input, &request.project.display_name, &request.room.id.to_string()).await, Some(es) => {
let max_results = request.context_setting.as_ref().map(|s| s.rag_max_results as usize).unwrap_or(3);
let min_score = request.context_setting.as_ref().map(|s| s.rag_min_score).unwrap_or(0.72);
let awareness = crate::perception::VectorPassiveAwareness::new(max_results, min_score);
awareness.detect(es, &request.input, &request.project.display_name, &request.room.id.to_string()).await
}
None => Vec::new(), None => Vec::new(),
} }
} }

View File

@ -3,7 +3,7 @@ use std::pin::Pin;
use db::cache::AppCache; use db::cache::AppCache;
use db::database::AppDatabase; use db::database::AppDatabase;
use models::agents::model; use models::agents::model;
use models::projects::project; use models::projects::{project, project_context_setting};
use models::repos::repo; use models::repos::repo;
use models::rooms::{room, room_message}; use models::rooms::{room, room_message};
use models::users::user; use models::users::user;
@ -44,7 +44,32 @@ impl Default for AiChunkType {
} }
} }
/// Optional streaming callback: called for each token chunk. const THINK_OPEN: &str = "\x3cthinking\x3e";
const THINK_CLOSE: &str = "\x3c/response\x3e";
/// Cleans up reasoning text before it is shown or stored.
///
/// Removes every occurrence of the `THINK_CLOSE` and `THINK_OPEN` markers, then
/// the leftover fragments `\x3cthinking` and `/response\x3e`, collapses any run
/// of three or more consecutive newlines down to two, and trims leading and
/// trailing whitespace.
pub fn normalize_thinking_content(content: &str) -> String {
    // Full markers first, so the fragment patterns only catch partial leftovers.
    let markers: [&str; 4] = [THINK_CLOSE, THINK_OPEN, "\x3cthinking", "/response\x3e"];
    let stripped = markers
        .iter()
        .fold(content.to_owned(), |text, marker| text.replace(*marker, ""));

    let mut collapsed = String::with_capacity(stripped.len());
    let mut run = 0usize;
    for ch in stripped.chars() {
        run = if ch == '\n' { run + 1 } else { 0 };
        // Only the third and later newline of a run are dropped.
        if run <= 2 {
            collapsed.push(ch);
        }
    }
    collapsed.trim().to_string()
}
pub type StreamCallback = Box< pub type StreamCallback = Box<
dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync, dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync,
>; >;
@ -55,6 +80,7 @@ pub struct AiRequest {
pub config: AppConfig, pub config: AppConfig,
pub model: model::Model, pub model: model::Model,
pub project: project::Model, pub project: project::Model,
pub context_setting: Option<project_context_setting::Model>,
pub sender: user::Model, pub sender: user::Model,
pub room: room::Model, pub room: room::Model,
pub input: String, pub input: String,
@ -76,6 +102,7 @@ pub enum Mention {
Repo(repo::Model), Repo(repo::Model),
} }
pub mod chat_execution;
pub mod context; pub mod context;
pub mod message_builder; pub mod message_builder;
pub mod nonstreaming_execution; pub mod nonstreaming_execution;

View File

@ -82,13 +82,13 @@ pub async fn execute_process(
tool_depth += 1; tool_depth += 1;
if tool_depth >= max_tool_depth { if tool_depth >= max_tool_depth {
let content = if text.is_empty() { format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth) } else { text }; let content = if text.is_empty() { format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth) } else { text };
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await; record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await;
return Ok(ProcessResult { content, input_tokens, output_tokens }); return Ok(ProcessResult { content, input_tokens, output_tokens });
} }
continue; continue;
} }
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await; record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await;
return Ok(ProcessResult { content: text, input_tokens, output_tokens }); return Ok(ProcessResult { content: text, input_tokens, output_tokens });
} }
} }
@ -111,7 +111,7 @@ async fn execute_tools(
let elapsed = start.elapsed().as_millis() as i64; let elapsed = start.elapsed().as_millis() as i64;
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_)); let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None }; let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.clone(), session_id: recorder.session_id(), tool_name: call.clone(), caller: request.sender.uid, arguments: serde_json::Value::Null, status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success }, execution_time_ms: Some(elapsed), error_message: error_msg, error_stack: None, retry_count: 0 }); recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: Uuid::new_v4().to_string(), session_id: recorder.session_id(), tool_name: call.clone(), caller: request.sender.uid, arguments: serde_json::Value::Null, status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success }, execution_time_ms: Some(elapsed), error_message: error_msg, error_stack: None, retry_count: 0 });
} }
crate::tool::ToolExecutor::to_tool_messages(&results) crate::tool::ToolExecutor::to_tool_messages(&results)
} }

View File

@ -18,6 +18,8 @@ pub async fn execute_process_react<C, Fut>(
request: &AiRequest, mut on_chunk: C, request: &AiRequest, mut on_chunk: C,
tool_registry: &ToolRegistry, tool_registry: &ToolRegistry,
ai_base_url: Option<String>, ai_api_key: Option<String>, ai_base_url: Option<String>, ai_api_key: Option<String>,
room_preamble: Option<&str>,
message_producer: Option<queue::MessageProducer>,
) -> Result<(String, i64, i64)> ) -> Result<(String, i64, i64)>
where where
C: FnMut(ReactStep) -> Fut + Send, C: FnMut(ReactStep) -> Fut + Send,
@ -33,6 +35,9 @@ where
let room_id = request.room.id; let room_id = request.room.id;
let sender_uid = request.sender.uid; let sender_uid = request.sender.uid;
let project_id = request.project.id; let project_id = request.project.id;
let ai_model_id = request.model.id;
let ai_model_name = request.model.name.clone();
let sent_in_turn = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let session_id = Uuid::now_v7(); let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now(); let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find() let version_id = room_ai::Entity::find()
@ -46,7 +51,9 @@ where
if let Some(handler) = tool_registry.get(&name) { if let Some(handler) = tool_registry.get(&name) {
let adapter = crate::tool::RigToolAdapter::new( let adapter = crate::tool::RigToolAdapter::new(
handler.clone(), def.clone(), db.clone(), cache.clone(), cfg.clone(), handler.clone(), def.clone(), db.clone(), cache.clone(), cfg.clone(),
room_id, Some(sender_uid), project_id, room_id, Some(sender_uid), project_id, message_producer.clone(),
Some(ai_model_id), Some(ai_model_name.clone()),
sent_in_turn.clone(),
); );
tools.push(Box::new(RecordingTool::new(Box::new(adapter), db.clone(), session_id, sender_uid))); tools.push(Box::new(RecordingTool::new(Box::new(adapter), db.clone(), session_id, sender_uid)));
} }
@ -54,8 +61,14 @@ where
let rig_client = client_config.build_rig_client(); let rig_client = client_config.build_rig_client();
let model = rig_client.completion_model(&request.model.name); let model = rig_client.completion_model(&request.model.name);
let preamble = match room_preamble {
Some(rp) => format!("{}\n{}", rp, DEFAULT_SYSTEM_PROMPT),
None => DEFAULT_SYSTEM_PROMPT.to_string(),
};
let agent = AgentBuilder::new(model) let agent = AgentBuilder::new(model)
.preamble(DEFAULT_SYSTEM_PROMPT) .preamble(&preamble)
.tools(tools) .tools(tools)
.default_max_turns(request.max_tool_depth) .default_max_turns(request.max_tool_depth)
.build(); .build();
@ -77,7 +90,8 @@ where
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => { Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => {
step_count += 1; step_count += 1;
let t = text.text; let t = text.text;
on_chunk(ReactStep::Answer { step: step_count, answer: t.clone() }).await; let cleaned = t.replace('\n', "");
on_chunk(ReactStep::Answer { step: step_count, answer: cleaned }).await;
final_content.push_str(&t); final_content.push_str(&t);
} }
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(reasoning))) => { Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(reasoning))) => {
@ -120,7 +134,7 @@ where
} }
let elapsed_ms = session_start.elapsed().as_millis() as i64; let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, elapsed_ms).await; record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, elapsed_ms).await;
Ok((final_content, total_input_tokens, total_output_tokens)) Ok((final_content, total_input_tokens, total_output_tokens))
} }

View File

@ -7,6 +7,7 @@ use crate::embed::EmbedService;
use crate::error::Result; use crate::error::Result;
use crate::perception::PerceptionService; use crate::perception::PerceptionService;
use crate::tool::registry::ToolRegistry; use crate::tool::registry::ToolRegistry;
use queue::MessageProducer;
/// Result from streaming AI response. /// Result from streaming AI response.
pub struct StreamResult { pub struct StreamResult {
@ -94,7 +95,8 @@ impl ChatService {
) -> Option<crate::RigToolSet> { ) -> Option<crate::RigToolSet> {
self.tool_registry.as_ref().map(|registry| { self.tool_registry.as_ref().map(|registry| {
crate::RigToolSet::from_registry( crate::RigToolSet::from_registry(
registry, db, cache, config, room_id, sender_id, project_id, registry, db, cache, config, room_id, sender_id, project_id, None, None, None,
std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
) )
}) })
} }
@ -134,6 +136,35 @@ impl ChatService {
super::react_execution::execute_process_react( super::react_execution::execute_process_react(
request, on_chunk, registry, request, on_chunk, registry,
self.ai_base_url.clone(), self.ai_api_key.clone(), self.ai_base_url.clone(), self.ai_api_key.clone(),
None, None,
).await
}
/// Process AI request via rig-based ReAct streaming loop with room-specific tools.
///
/// Merges `room_tools` (e.g. `send_message`, `retract_message`) into the base
/// tool registry on-the-fly. The `room_preamble` is prepended to the default
/// system prompt to instruct the AI about room communication rules.
/// `message_producer` enables tools to publish events via the message queue.
pub async fn process_react_room<C, Fut>(
&self, request: &AiRequest, on_chunk: C,
room_tools: ToolRegistry,
room_preamble: Option<&str>,
message_producer: Option<MessageProducer>,
) -> Result<(String, i64, i64)>
where
C: FnMut(crate::react::ReactStep) -> Fut + Send,
Fut: std::future::Future<Output = ()> + Send,
{
let Some(registry) = &self.tool_registry else {
return Err(crate::error::AgentError::Internal("no tool registry registered".into()));
};
let mut merged = registry.clone();
merged.merge(room_tools);
super::react_execution::execute_process_react(
request, on_chunk, &merged,
self.ai_base_url.clone(), self.ai_api_key.clone(),
room_preamble, message_producer,
).await ).await
} }
} }

View File

@ -8,6 +8,7 @@ pub async fn record_ai_session(
cache: &AppCache, cache: &AppCache,
db: &AppDatabase, db: &AppDatabase,
project_id: Uuid, project_id: Uuid,
user_id: Uuid,
session_id: Uuid, session_id: Uuid,
room_id: Uuid, room_id: Uuid,
model_id: Uuid, model_id: Uuid,
@ -39,7 +40,7 @@ pub async fn record_ai_session(
} }
let (cost, currency, error_msg) = match crate::billing::record_ai_usage( let (cost, currency, error_msg) = match crate::billing::record_ai_usage(
db, project_id, version_id, input_tokens, output_tokens, db, project_id, user_id, version_id, input_tokens, output_tokens,
).await { ).await {
Ok(crate::billing::BillingResult::Success(record)) => { Ok(crate::billing::BillingResult::Success(record)) => {
(Some(record.cost), Some(record.currency), None) (Some(record.cost), Some(record.currency), None)
@ -70,7 +71,7 @@ async fn create_billing_error_system_message(
use models::rooms::{room_message, MessageContentType, MessageSenderType}; use models::rooms::{room_message, MessageContentType, MessageSenderType};
use sea_orm::Set; use sea_orm::Set;
let seq_key = format!("room:seq:{}", room_id); let seq_key = format!("seq:room:{}", room_id);
let seq = match cache.conn().await { let seq = match cache.conn().await {
Ok(mut conn) => { Ok(mut conn) => {
match redis::cmd("INCR").arg(&seq_key).query_async::<i64>(&mut conn).await { match redis::cmd("INCR").arg(&seq_key).query_async::<i64>(&mut conn).await {

View File

@ -62,7 +62,8 @@ pub async fn execute_process_stream(
&messages, &model_name, &config, temperature, max_tokens, &messages, &model_name, &config, temperature, max_tokens,
if tools_enabled { Some(&tools) } else { None }, None, if tools_enabled { Some(&tools) } else { None }, None,
Arc::new(move |delta| { Arc::new(move |delta| {
let fut = on_chunk_cb(AiStreamChunk { content: delta.to_string(), done: false, chunk_type: AiChunkType::Answer }); let content = delta.to_string().replace('\n', "");
let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer });
fut fut
}), }),
Arc::new(move |delta| { Arc::new(move |delta| {
@ -82,11 +83,10 @@ pub async fn execute_process_stream(
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty(); let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
if !has_tool_calls { if !has_tool_calls {
return handle_final_answer(response, full_content, on_chunk, all_chunks, &request, session_id, version_id, total_input_tokens, total_output_tokens, session_start).await; return handle_final_answer(response, all_chunks, &request, session_id, version_id, total_input_tokens, total_output_tokens, session_start).await;
} }
full_content.push_str(&response.content); full_content.push_str(&response.content);
full_content.push('\n');
let tool_calls: Vec<ToolCall> = response.tool_calls.iter().map(|tc| ToolCall { let tool_calls: Vec<ToolCall> = response.tool_calls.iter().map(|tc| ToolCall {
id: tc.id.clone(), type_: "function".into(), id: tc.id.clone(), type_: "function".into(),
@ -114,7 +114,7 @@ pub async fn execute_process_stream(
let max_depth_text = format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth); let max_depth_text = format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth);
on_chunk(AiStreamChunk { content: max_depth_text.clone(), done: true, chunk_type: AiChunkType::Answer }).await; on_chunk(AiStreamChunk { content: max_depth_text.clone(), done: true, chunk_type: AiChunkType::Answer }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: max_depth_text }); all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: max_depth_text });
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await; record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
return Ok(StreamResult { content: full_content, reasoning_content: String::new(), input_tokens: 0, output_tokens: 0, chunks: all_chunks }); return Ok(StreamResult { content: full_content, reasoning_content: String::new(), input_tokens: 0, output_tokens: 0, chunks: all_chunks });
} }
} }
@ -155,60 +155,83 @@ async fn execute_streaming_tools(
if let Some(registry) = tool_registry { ctx.registry_mut().merge(registry.clone()); } if let Some(registry) = tool_registry { ctx.registry_mut().merge(registry.clone()); }
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(request.db.clone(), session_id); let recorder = crate::tool::recorder::ToolCallRecorder::with_session(request.db.clone(), session_id);
let mut join_set = tokio::task::JoinSet::new();
for call in calls { for call in calls {
let start = std::time::Instant::now();
let call_clone = call.clone(); let call_clone = call.clone();
let mut ctx_clone = ctx.clone(); let mut ctx_clone = ctx.clone();
let (result_tx, mut result_rx) = tokio::sync::oneshot::channel(); let sender_uid = request.sender.uid;
tokio::spawn(async move { let recorder_clone = recorder.clone();
join_set.spawn(async move {
let start = std::time::Instant::now();
let executor = ToolExecutor::new(); let executor = ToolExecutor::new();
let res = executor.execute_batch(vec![call_clone], &mut ctx_clone).await; let res = executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone).await;
let _ = result_tx.send(res); (call_clone, res, start.elapsed(), sender_uid, recorder_clone)
}); });
}
let heartbeat_dur = std::time::Duration::from_secs(10); let heartbeat_dur = std::time::Duration::from_secs(10);
let results = loop { while !join_set.is_empty() {
tokio::select! { tokio::select! {
res = &mut result_rx => { Some(res) = join_set.join_next() => {
match res { Ok(inner) => break inner, Err(_) => break Err(crate::tool::ToolError::ExecutionError("tool task cancelled".into())), } if let Ok((call, results, elapsed, sender_uid, recorder)) = res {
}, match results {
_ = tokio::time::sleep(heartbeat_dur) => { Ok(results) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await; for result in &results {
let text = match &result.result { crate::tool::ToolResult::Ok(v) => v.to_string(), crate::tool::ToolResult::Error(msg) => msg.clone() };
let preview = if text.len() > 300 {
let end = text.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
format!("{}...", &text[..end])
} else { text.clone() };
tracing::debug!("tool_result: {} — {}", call.name, preview);
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: sender_uid,
arguments: call.arguments_json().unwrap_or_default(),
status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success },
execution_time_ms: Some(elapsed.as_millis() as i64),
error_message: error_msg,
error_stack: None,
retry_count: 0
});
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
Err(e) => {
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: sender_uid,
arguments: call.arguments_json().unwrap_or_default(),
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed.as_millis() as i64),
error_message: Some(e.to_string()),
error_stack: None,
retry_count: 0
});
let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
}
}
} }
} },
}; _ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await;
match results {
Ok(results) => {
for result in &results {
let text = match &result.result { crate::tool::ToolResult::Ok(v) => v.to_string(), crate::tool::ToolResult::Error(msg) => msg.clone() };
let preview = if text.len() > 300 {
let end = text.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
format!("{}...", &text[..end])
} else { text.clone() };
tracing::debug!("tool_result: {} — {}", call.name, preview);
let elapsed = start.elapsed().as_millis() as i64;
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.id.clone(), session_id: recorder.session_id(), tool_name: call.name.clone(), caller: request.sender.uid, arguments: call.arguments_json().unwrap_or_default(), status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success }, execution_time_ms: Some(elapsed), error_message: error_msg, error_stack: None, retry_count: 0 });
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.id.clone(), session_id: recorder.session_id(), tool_name: call.name.clone(), caller: request.sender.uid, arguments: call.arguments_json().unwrap_or_default(), status: models::ai::ToolCallStatus::Failed, execution_time_ms: Some(elapsed), error_message: Some(e.to_string()), error_stack: None, retry_count: 0 });
let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
} }
} }
} }
@ -216,18 +239,20 @@ async fn execute_streaming_tools(
} }
async fn handle_final_answer( async fn handle_final_answer(
response: crate::client::StreamResponse, full_content: String, response: crate::client::StreamResponse,
on_chunk: SharedCallback,
mut all_chunks: Vec<StreamChunk>, request: &AiRequest, mut all_chunks: Vec<StreamChunk>, request: &AiRequest,
session_id: Uuid, version_id: Option<Uuid>, session_id: Uuid, version_id: Option<Uuid>,
total_input_tokens: i64, total_output_tokens: i64, total_input_tokens: i64, total_output_tokens: i64,
session_start: std::time::Instant, session_start: std::time::Instant,
) -> Result<StreamResult> { ) -> Result<StreamResult> {
let full_content = full_content + &response.content; let full_content = response.content.clone();
on_chunk(AiStreamChunk { content: response.content.clone(), done: true, chunk_type: AiChunkType::Answer }).await; // Don't broadcast the done chunk via SSE/NATS — incremental deltas
// already delivered the content; the separate completion event
// signals end of stream. Broadcasting full content again would
// duplicate it in the frontend streaming display.
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: response.content.clone() }); all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: response.content.clone() });
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await; record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
Ok(StreamResult { content: full_content, reasoning_content: response.reasoning_content, input_tokens: response.input_tokens, output_tokens: response.output_tokens, chunks: all_chunks }) Ok(StreamResult { content: full_content, reasoning_content: response.reasoning_content, input_tokens: total_input_tokens, output_tokens: total_output_tokens, chunks: all_chunks })
} }
async fn inject_passive_skills_stream( async fn inject_passive_skills_stream(

View File

@ -106,8 +106,10 @@ impl RetryState {
fn backoff_duration(&self) -> std::time::Duration { fn backoff_duration(&self) -> std::time::Duration {
let exp = self.attempt.min(5); let exp = self.attempt.min(5);
let base_ms = 500u64.saturating_mul(2u64.pow(exp)).min(self.max_backoff_ms); let base_ms = 500u64.saturating_mul(2u64.pow(exp)).min(self.max_backoff_ms);
let jitter = fastrand_u64(base_ms + 1); let max_jitter = (base_ms / 2).max(base_ms);
std::time::Duration::from_millis(jitter) let offset = fastrand_u64(max_jitter + 1).saturating_sub(base_ms / 2);
let total = base_ms.saturating_add(offset).min(self.max_backoff_ms);
std::time::Duration::from_millis(total)
} }
fn next(&mut self) { self.attempt += 1; } fn next(&mut self) { self.attempt += 1; }
} }

View File

@ -4,18 +4,14 @@ use models::rooms::room_message::{
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel, Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
}; };
use models::users::user::{Column as UserCol, Entity as User}; use models::users::user::{Column as UserCol, Entity as User};
use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder}; use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
use serde_json::Value;
use uuid::Uuid; use uuid::Uuid;
use crate::client::types::ChatRequestMessage; use crate::client::types::ChatRequestMessage;
use crate::client::AiClientConfig; use crate::client::AiClientConfig;
use crate::client::call_with_params; use crate::client::call_with_params;
use crate::AgentError; use crate::AgentError;
use crate::compact::helpers::summary_content; use crate::compact::types::{CompactLevel, CompactSummary, MessageSummary};
use crate::compact::types::{
CompactConfig, CompactLevel, CompactSummary, MessageSummary, ThresholdResult,
};
use crate::tokent::{TokenUsage, resolve_usage}; use crate::tokent::{TokenUsage, resolve_usage};
#[derive(Clone)] #[derive(Clone)]
@ -35,8 +31,29 @@ impl CompactService {
room_id: Uuid, room_id: Uuid,
level: CompactLevel, level: CompactLevel,
user_names: Option<std::collections::HashMap<Uuid, String>>, user_names: Option<std::collections::HashMap<Uuid, String>>,
requester_id: Uuid,
context_window_tokens: i32,
compaction_max_summary_ratio: f32,
) -> Result<CompactSummary, AgentError> { ) -> Result<CompactSummary, AgentError> {
let messages = self.fetch_room_messages(room_id).await?; // Verify room access at the database level to ensure auth context is enforced.
// Public rooms are accessible to project members.
// For simplicity in this audit fix, we'll fetch only if access exists.
let messages = self.fetch_room_messages_secure(room_id, requester_id).await?;
if messages.is_empty() {
// Check if room actually exists or if it's just empty/inaccessible
let room_exists = models::rooms::room::Entity::find_by_id(room_id)
.one(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?
.is_some();
if room_exists {
return Err(AgentError::Internal("Access denied or room empty".into()));
} else {
return Err(AgentError::Internal("Room not found".into()));
}
}
let user_ids: Vec<Uuid> = messages let user_ids: Vec<Uuid> = messages
.iter() .iter()
@ -74,7 +91,9 @@ impl CompactService {
.map(|m| Self::message_to_summary(m, &user_name_map)) .map(|m| Self::message_to_summary(m, &user_name_map))
.collect(); .collect();
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?; let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
// Build text of what was summarized (for tiktoken fallback) // Build text of what was summarized (for tiktoken fallback)
let summarized_text = to_summarize let summarized_text = to_summarize
@ -100,10 +119,13 @@ impl CompactService {
session_id: Uuid, session_id: Uuid,
level: CompactLevel, level: CompactLevel,
user_names: Option<std::collections::HashMap<Uuid, String>>, user_names: Option<std::collections::HashMap<Uuid, String>>,
context_window_tokens: i32,
compaction_max_summary_ratio: f32,
) -> Result<CompactSummary, AgentError> { ) -> Result<CompactSummary, AgentError> {
let messages: Vec<RoomMessageModel> = RoomMessage::find() let messages: Vec<RoomMessageModel> = RoomMessage::find()
.filter(RmCol::Room.eq(session_id)) .filter(RmCol::Room.eq(session_id))
.order_by_asc(RmCol::Seq) .order_by_asc(RmCol::Seq)
.limit(10000)
.all(&self.db) .all(&self.db)
.await .await
.map_err(|e| AgentError::Internal(e.to_string()))?; .map_err(|e| AgentError::Internal(e.to_string()))?;
@ -148,10 +170,10 @@ impl CompactService {
.map(|m| Self::message_to_summary(m, &user_name_map)) .map(|m| Self::message_to_summary(m, &user_name_map))
.collect(); .collect();
// Summarize the earlier messages let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
// Build text of what was summarized (for tiktoken fallback)
let summarized_text = to_summarize let summarized_text = to_summarize
.iter() .iter()
.map(|m| m.content.as_str()) .map(|m| m.content.as_str())
@ -170,164 +192,51 @@ impl CompactService {
}) })
} }
pub fn summary_as_system_message(summary: &CompactSummary) -> ChatRequestMessage { async fn fetch_room_messages_secure(
let content = summary_content(summary);
ChatRequestMessage::system(content)
}
/// Check if the message history for a room exceeds the token threshold.
/// Returns `ThresholdResult::Skip` if below threshold, `Compact` if above.
///
/// This method fetches messages and estimates their token count using tiktoken.
/// Call this before deciding whether to run full compaction.
pub async fn check_threshold(
&self, &self,
room_id: Uuid, room_id: Uuid,
config: CompactConfig, requester_id: Uuid,
) -> Result<ThresholdResult, AgentError> { ) -> Result<Vec<RoomMessageModel>, AgentError> {
let messages = self.fetch_room_messages(room_id).await?; use models::rooms::{RoomUserState, RoomAccess};
let tokens = self.estimate_message_tokens(&messages); use sea_orm::QueryTrait;
use sea_orm::sea_query::Expr;
if tokens < config.token_threshold { // Find messages for the room where the requester has access.
return Ok(ThresholdResult::Skip { // We check both the room_user_state table (membership) and the room_access table (explicit grants).
estimated_tokens: tokens, RoomMessage::find()
}); .filter(RmCol::Room.eq(room_id))
} .filter(
sea_orm::Condition::any()
.add(
Expr::exists(
RoomUserState::find()
.filter(models::rooms::room_user_state::Column::Room.eq(room_id))
.filter(models::rooms::room_user_state::Column::User.eq(requester_id))
.into_query()
)
)
.add(
Expr::exists(
RoomAccess::find()
.filter(models::rooms::room_access::Column::Room.eq(room_id))
.filter(models::rooms::room_access::Column::User.eq(requester_id))
.into_query()
)
)
)
.order_by_asc(RmCol::Seq)
.limit(10000)
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))
}
let level = if config.auto_level { fn message_to_summary(m: &RoomMessageModel, user_name_map: &std::collections::HashMap<Uuid, String>) -> MessageSummary {
CompactLevel::auto_select(tokens, config.token_threshold) let sender_name = if let Some(user_id) = m.sender_id {
user_name_map.get(&user_id).cloned().unwrap_or_else(|| m.sender_type.to_string())
} else { } else {
config.default_level m.sender_type.to_string()
}; };
Ok(ThresholdResult::Compact {
estimated_tokens: tokens,
level,
})
}
/// Auto-compact a room: estimates token count, only compresses if over threshold.
///
/// This is the recommended entry point for automatic compaction.
/// - If tokens < threshold → returns a no-op summary (empty summary, no compression)
/// - If tokens >= threshold → compresses with auto-selected level
pub async fn compact_room_auto(
&self,
room_id: Uuid,
user_names: Option<std::collections::HashMap<Uuid, String>>,
config: CompactConfig,
) -> Result<CompactSummary, AgentError> {
let threshold_result = self.check_threshold(room_id, config).await?;
match threshold_result {
ThresholdResult::Skip { .. } => {
// Below threshold — no compaction needed, return empty summary
let messages = self.fetch_room_messages(room_id).await?;
let user_ids: Vec<Uuid> = messages.iter().filter_map(|m| m.sender_id).collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary: String::new(),
compacted_at: Utc::now(),
messages_compressed: 0,
usage: None,
});
}
ThresholdResult::Compact { level, .. } => {
// Above threshold — compress with selected level
return self
.compact_room_with_level(room_id, level, user_names)
.await;
}
}
}
/// Compact a room with a specific level (bypassing threshold check).
/// Use this when the caller has already decided compaction is needed.
async fn compact_room_with_level(
&self,
room_id: Uuid,
level: CompactLevel,
user_names: Option<std::collections::HashMap<Uuid, String>>,
) -> Result<CompactSummary, AgentError> {
let messages = self.fetch_room_messages(room_id).await?;
let user_ids: Vec<Uuid> = messages.iter().filter_map(|m| m.sender_id).collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
if messages.len() <= level.retain_count() {
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary: String::new(),
compacted_at: Utc::now(),
messages_compressed: 0,
usage: None,
});
}
let retain_count = level.retain_count();
let split_index = messages.len().saturating_sub(retain_count);
let (to_summarize, retained_messages) = messages.split_at(split_index);
let retained: Vec<MessageSummary> = retained_messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;
let summarized_text = to_summarize
.iter()
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n");
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary,
compacted_at: Utc::now(),
messages_compressed: to_summarize.len(),
usage: Some(usage),
})
}
/// Estimate total token count of a message list using tiktoken.
fn estimate_message_tokens(&self, messages: &[RoomMessageModel]) -> usize {
let total_chars: usize = messages.iter().map(|m| m.content.len()).sum();
// Rough estimate: ~4 chars per token (safe upper bound)
total_chars / 4
}
fn message_to_summary(
m: &RoomMessageModel,
user_name_map: &std::collections::HashMap<Uuid, String>,
) -> MessageSummary {
let sender_name = m
.sender_id
.and_then(|id| user_name_map.get(&id).cloned())
.unwrap_or_else(|| m.sender_type.to_string());
MessageSummary { MessageSummary {
id: m.id, id: m.id,
sender_type: m.sender_type.clone(), sender_type: m.sender_type.clone(),
@ -335,35 +244,11 @@ impl CompactService {
sender_name, sender_name,
content: m.content.clone(), content: m.content.clone(),
content_type: m.content_type.clone(), content_type: m.content_type.clone(),
tool_call_id: Self::extract_tool_call_id(&m.content), tool_call_id: None,
send_at: m.send_at, send_at: m.send_at,
} }
} }
fn extract_tool_call_id(content: &str) -> Option<String> {
let content = content.trim();
if let Ok(v) = serde_json::from_str::<Value>(content) {
v.get("tool_call_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
} else {
None
}
}
async fn fetch_room_messages(
&self,
room_id: Uuid,
) -> Result<Vec<RoomMessageModel>, AgentError> {
let messages: Vec<RoomMessageModel> = RoomMessage::find()
.filter(RmCol::Room.eq(room_id))
.order_by_asc(RmCol::Seq)
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
Ok(messages)
}
async fn get_user_name_map( async fn get_user_name_map(
&self, &self,
user_ids: &[Uuid], user_ids: &[Uuid],
@ -386,8 +271,8 @@ impl CompactService {
async fn summarize_messages( async fn summarize_messages(
&self, &self,
messages: &[RoomMessageModel], messages: &[RoomMessageModel],
max_summary_tokens: usize,
) -> Result<(String, Option<TokenUsage>), AgentError> { ) -> Result<(String, Option<TokenUsage>), AgentError> {
// Collect distinct user IDs
let user_ids: Vec<Uuid> = messages let user_ids: Vec<Uuid> = messages
.iter() .iter()
.filter_map(|m| m.sender_id) .filter_map(|m| m.sender_id)
@ -395,10 +280,8 @@ impl CompactService {
.into_iter() .into_iter()
.collect(); .collect();
// Query usernames
let user_name_map = self.get_user_name_map(&user_ids).await?; let user_name_map = self.get_user_name_map(&user_ids).await?;
// Define sender mapper
let sender_mapper = |m: &RoomMessageModel| { let sender_mapper = |m: &RoomMessageModel| {
if let Some(user_id) = m.sender_id { if let Some(user_id) = m.sender_id {
if let Some(username) = user_name_map.get(&user_id) { if let Some(username) = user_name_map.get(&user_id) {
@ -413,11 +296,13 @@ impl CompactService {
let user_msg = ChatRequestMessage::user(format!( let user_msg = ChatRequestMessage::user(format!(
"Summarise the following conversation concisely, preserving all key facts, \ "Summarise the following conversation concisely, preserving all key facts, \
decisions, and any pending or in-progress work. \ decisions, and any pending or in-progress work. \
The summary MUST NOT exceed {} tokens. \
Use this format:\n\n\ Use this format:\n\n\
**Summary:** <one-paragraph overview>\n\ **Summary:** <one-paragraph overview>\n\
**Key decisions:** <bullet list or 'none'>\n\ **Key decisions:** <bullet list or 'none'>\n\
**Open items:** <bullet list or 'none'>\n\n\ **Open items:** <bullet list or 'none'>\n\n\
Conversation:\n\n{}", Conversation:\n\n{}",
max_summary_tokens,
body body
)); ));
@ -425,8 +310,8 @@ impl CompactService {
&[user_msg], &[user_msg],
&self.model, &self.model,
&self.ai_client_config, &self.ai_client_config,
0.3, // slightly higher temp for summarization 0.3,
1024, // max output tokens 2048,
None, None,
None, None,
None, None,
@ -434,7 +319,6 @@ impl CompactService {
.await .await
.map_err(|e| AgentError::OpenAi(e.to_string()))?; .map_err(|e| AgentError::OpenAi(e.to_string()))?;
// Prefer remote usage; fall back to None (caller will use tiktoken via resolve_usage)
let remote_usage = let remote_usage =
TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32); TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);

View File

@ -74,6 +74,8 @@ pub struct CompactConfig {
pub auto_level: bool, pub auto_level: bool,
/// Fallback level when `auto_level` is false. /// Fallback level when `auto_level` is false.
pub default_level: CompactLevel, pub default_level: CompactLevel,
/// Maximum tokens the summary may contain (enforced via prompt).
pub max_summary_tokens: usize,
} }
impl Default for CompactConfig { impl Default for CompactConfig {
@ -83,6 +85,20 @@ impl Default for CompactConfig {
token_threshold: 8000, token_threshold: 8000,
auto_level: true, auto_level: true,
default_level: CompactLevel::Light, default_level: CompactLevel::Light,
max_summary_tokens: 256,
}
}
}
impl CompactConfig {
/// Build config from project context settings.
pub fn from_project_setting(context_window_tokens: i32, compaction_threshold: f32, compaction_max_summary_ratio: f32) -> Self {
let threshold = (context_window_tokens as f32 * compaction_threshold) as usize;
Self {
token_threshold: threshold,
auto_level: true,
default_level: CompactLevel::Light,
max_summary_tokens: (context_window_tokens as f32 * compaction_max_summary_ratio) as usize,
} }
} }
} }

View File

@ -575,11 +575,5 @@ pub struct EmbedMemoryInput {
} }
/// Input struct for batch tag embedding. /// Input struct for batch tag embedding.
#[derive(Debug, Clone)] /// Re-exported from models for backward compatibility.
pub struct TagEmbedInput { pub use models::TagEmbedInput;
pub repo_id: String,
pub repo_name: String,
pub project_id: String,
pub name: String,
pub description: Option<String>,
}

View File

@ -52,3 +52,9 @@ impl From<sea_orm::DbErr> for AgentError {
AgentError::Internal(e.to_string()) AgentError::Internal(e.to_string())
} }
} }
/// Converts a tool-layer error into [`AgentError::ToolExecutionFailed`].
///
/// Only the error's display text is kept (as `cause`); the `tool` field is
/// left as an empty string by this conversion.
impl From<crate::tool::ToolError> for AgentError {
    fn from(err: crate::tool::ToolError) -> Self {
        let cause = err.to_string();
        AgentError::ToolExecutionFailed {
            tool: String::new(),
            cause,
        }
    }
}

View File

@ -13,7 +13,7 @@ pub mod sync;
pub mod task; pub mod task;
pub mod tokent; pub mod tokent;
pub mod tool; pub mod tool;
pub use billing::{BillingRecord, BillingResult, record_ai_usage}; pub use billing::{BillingRecord, BillingResult, record_ai_usage, initialize_user_billing, initialize_project_billing, check_balance, persist_billing_error};
pub use sync::list_accessible_models; pub use sync::list_accessible_models;
pub use task::TaskService; pub use task::TaskService;
pub use tokent::{TokenUsage, resolve_usage}; pub use tokent::{TokenUsage, resolve_usage};
@ -33,7 +33,7 @@ pub use embed::{
EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult, TagEmbedInput, new_embed_client, EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult, TagEmbedInput, new_embed_client,
}; };
pub use error::{AgentError, Result}; pub use error::{AgentError, Result};
pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT}; pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT, ROOM_CONTEXT_PROMPT};
pub use tool::{ pub use tool::{
ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam, ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,
ToolRegistry, ToolResult, ToolSchema, ToolRegistry, ToolResult, ToolSchema,

View File

@ -44,6 +44,10 @@ impl Default for VectorActiveAwareness {
} }
impl VectorActiveAwareness { impl VectorActiveAwareness {
/// Creates an instance with explicit limits instead of the `Default` values.
///
/// * `max_skills` - stored as the `max_skills` field.
/// * `min_score` - stored as the `min_score` field.
pub fn new(max_skills: usize, min_score: f32) -> Self {
    Self {
        max_skills,
        min_score,
    }
}
/// Search for skills semantically relevant to the user's input. /// Search for skills semantically relevant to the user's input.
/// ///
/// Uses Qdrant vector search within the given project to find skills whose /// Uses Qdrant vector search within the given project to find skills whose
@ -107,6 +111,10 @@ impl Default for VectorPassiveAwareness {
} }
impl VectorPassiveAwareness { impl VectorPassiveAwareness {
/// Creates an instance with explicit limits instead of the `Default` values.
///
/// * `max_memories` - stored as the `max_memories` field.
/// * `min_score` - stored as the `min_score` field.
pub fn new(max_memories: usize, min_score: f32) -> Self {
    Self {
        max_memories,
        min_score,
    }
}
/// Search for past conversation messages semantically similar to the current context. /// Search for past conversation messages semantically similar to the current context.
/// ///
/// Uses Qdrant to find memories within the same room that share semantic similarity /// Uses Qdrant to find memories within the same room that share semantic similarity

View File

@ -16,7 +16,7 @@ pub const DEFAULT_SYSTEM_PROMPT: &str = r#"You are an AI assistant embedded in a
## Core Rule: Search Local Data First ## Core Rule: Search Local Data First
Always query the platform's local data before guessing or referring to external sources. Local data includes: issues, pull requests, repositories, code reviews, chat messages, documentation, members, and other workspace resources. Always query the platform's local data before guessing or referring to external sources. Local data includes: issues, pull requests, repositories, code reviews, chat messages, documentation, members, and other project resources.
If local data does not contain the answer, state that clearly before considering external information. If local data does not contain the answer, state that clearly before considering external information.
@ -38,3 +38,32 @@ If local data does not contain the answer, state that clearly before considering
- State ambiguity or uncertainty explicitly. - State ambiguity or uncertainty explicitly.
- Prefer facts over speculation. - Prefer facts over speculation.
"#; "#;
/// Room-specific system prompt appended when the AI is @mentioned in a chat room.
///
/// In room context, the AI must NOT produce long-form output directly. Instead,
/// it communicates through the `send_message` and `retract_message` tools.
/// This keeps room messages concise and gives the AI control over what appears
/// in the room.
///
/// The prompt text itself documents the mention syntax (`@[type:id:label]`),
/// the parameters of both tools, and a limit of 99 `send_message` calls per turn.
/// The string begins with a newline, so it can be concatenated after another prompt.
pub const ROOM_CONTEXT_PROMPT: &str = r#"
## Room Communication Mode CRITICAL
You are NOT in a direct chat. You are @mentioned in a chat room. **Your default response text will NOT be seen by anyone.** The ONLY way to communicate with the room is through the tools listed below.
### Mandatory Communication Rules
1. **ALWAYS use `send_message`** to deliver ANY response to the room. No exceptions. If you produce a final text response without calling `send_message`, the room will receive NOTHING.
2. **Call `send_message` FIRST**, before any final text output. The tool call is what creates a visible room message.
3. **Keep each message concise** short, focused, actionable. No long reports, no multi-paragraph essays, no bullet lists longer than 5 items. If you need to convey a lot of information, summarize the key points and offer to provide details if asked.
4. **Use mentions** to reference entities: `@[user:uuid:username]` for users, `@[repo:uuid:name]` for repositories, `@[skill:slug]` for skills, `@[issue:uuid:title]` for issues, `@[ai:uuid:name]` for other AI models.
5. **Use `retract_message`** to revoke a message you just sent if it contains an error or needs to be withdrawn. You can only retract messages you sent in the current turn.
6. **You may send multiple messages** for complex responses, break your answer into multiple `send_message` calls (up to 99 per turn). Each message should be short, focused, and stand on its own. For example: first send a summary, then send follow-up details or action items as separate messages.
7. **After calling `send_message`, your final text response can be brief** just a summary or acknowledgment, since the actual room message has already been delivered via the tool call.
### Critical Reminder
Your response text output is NOT delivered to the room. The `send_message` tool IS the delivery mechanism. If you forget to call `send_message`, nobody in the room will see your response.
### Room-Only Tools
- `send_message(room_id?, content)` Send a brief message to the room. The `room_id` parameter is optional (defaults to the current room). The `content` parameter is required and supports `@[type:id:label]` mention syntax.
- `retract_message(message_id)` Retract (revoke) a message you sent in the current turn. Requires the message UUID returned by `send_message`.
"#;

View File

@ -9,8 +9,18 @@
//! return usage metadata (e.g., local models, streaming), tiktoken is used as //! return usage metadata (e.g., local models, streaming), tiktoken is used as
//! a fallback for accurate counting. //! a fallback for accurate counting.
use std::collections::HashMap;
use std::sync::OnceLock;
use std::sync::RwLock;
use crate::error::{AgentError, Result}; use crate::error::{AgentError, Result};
/// Process-wide tokenizer cache: a lazily created, lock-protected map from a
/// string key to a `CoreBPE` instance.
static TOKENIZER_CACHE: OnceLock<RwLock<HashMap<String, tiktoken_rs::CoreBPE>>> = OnceLock::new();

/// Returns the shared tokenizer cache, creating the empty map on first access.
fn get_cached_tokenizers() -> &'static RwLock<HashMap<String, tiktoken_rs::CoreBPE>> {
    TOKENIZER_CACHE.get_or_init(RwLock::default)
}
/// Token usage data. Use `from_remote()` when the API returns usage info, /// Token usage data. Use `from_remote()` when the API returns usage info,
/// or `from_estimate()` when falling back to tiktoken. /// or `from_estimate()` when falling back to tiktoken.
#[derive(Debug, Clone, Copy, Default, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, Copy, Default, serde::Serialize, serde::Deserialize)]
@ -155,14 +165,28 @@ fn safe_token_budget(context_limit: usize, reserve: usize) -> usize {
fn get_tokenizer(model: &str) -> Result<tiktoken_rs::CoreBPE> { fn get_tokenizer(model: &str) -> Result<tiktoken_rs::CoreBPE> {
use tiktoken_rs; use tiktoken_rs;
// Try model-specific tokenizer first {
if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) { let cache = get_cached_tokenizers().read().unwrap();
return Ok(bpe); if let Some(bpe) = cache.get(model) {
return Ok(bpe.clone());
}
} }
// Fallback: use cl100k_base for unknown models // Try model-specific tokenizer first
tiktoken_rs::cl100k_base() let bpe = if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) {
.map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e))) bpe
} else {
// Fallback: use cl100k_base for unknown models
tiktoken_rs::cl100k_base()
.map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e)))?
};
{
let mut cache = get_cached_tokenizers().write().unwrap();
cache.insert(model.to_string(), bpe.clone());
}
Ok(bpe)
} }
/// Estimate tokens for a simple prefix/suffix pattern (e.g., "assistant\n" + text). /// Estimate tokens for a simple prefix/suffix pattern (e.g., "assistant\n" + text).

View File

@ -8,6 +8,7 @@ use std::sync::Arc;
use db::cache::AppCache; use db::cache::AppCache;
use db::database::AppDatabase; use db::database::AppDatabase;
use config::AppConfig; use config::AppConfig;
use queue::MessageProducer;
use uuid::Uuid; use uuid::Uuid;
use super::registry::ToolRegistry; use super::registry::ToolRegistry;
@ -28,6 +29,15 @@ struct Inner {
pub project_id: Uuid, pub project_id: Uuid,
pub registry: ToolRegistry, pub registry: ToolRegistry,
pub embed_service: Option<crate::embed::EmbedService>, pub embed_service: Option<crate::embed::EmbedService>,
pub message_producer: Option<MessageProducer>,
/// When in room context, identifies the AI model that is responding.
/// Used by send_message/retract_message to set the correct sender.
pub ai_model_id: Option<Uuid>,
pub ai_model_name: Option<String>,
/// Message IDs sent by the AI in the current ReAct turn.
/// Shared across tool calls so send_message can register IDs
/// and retract_message can validate turn-scoped retraction.
pub sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<Uuid>>>,
depth: u32, depth: u32,
max_depth: u32, max_depth: u32,
tool_call_count: usize, tool_call_count: usize,
@ -52,6 +62,10 @@ impl ToolContext {
project_id: Uuid::nil(), project_id: Uuid::nil(),
registry: ToolRegistry::new(), registry: ToolRegistry::new(),
embed_service: None, embed_service: None,
message_producer: None,
ai_model_id: None,
ai_model_name: None,
sent_in_turn: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
depth: 0, depth: 0,
max_depth: 5, max_depth: 5,
tool_call_count: 0, tool_call_count: 0,
@ -85,10 +99,45 @@ impl ToolContext {
self self
} }
/// Builder step: stores `producer` as the context's message producer and
/// returns the updated context.
pub fn with_message_producer(mut self, producer: MessageProducer) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.message_producer = Some(producer);
    self
}
/// Builder step: records the ID and display name of the AI model and
/// returns the updated context.
pub fn with_ai_model(mut self, model_id: Uuid, model_name: String) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.ai_model_id = Some(model_id);
    inner.ai_model_name = Some(model_name);
    self
}
/// Builder step: replaces the context's sent-message list with the shared
/// handle `sent`, so every holder of that handle sees the same IDs.
pub fn with_sent_in_turn(mut self, sent: std::sync::Arc<std::sync::Mutex<Vec<Uuid>>>) -> Self {
    let inner = Arc::make_mut(&mut self.inner);
    inner.sent_in_turn = sent;
    self
}
/// Register a message ID as sent in the current turn (called by send_message).
///
/// The ID is appended to the shared `sent_in_turn` list. If the mutex was
/// poisoned by a panic in another holder, the guard is recovered and the ID is
/// still recorded: the protected value is a plain `Vec<Uuid>`, so there is no
/// partially updated state to protect against, and dropping the ID silently
/// would make a later turn-scoped lookup miss a message that was really sent.
pub fn register_sent_message(&self, id: Uuid) {
    self.inner
        .sent_in_turn
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .push(id);
}
/// Check if a message ID was sent in the current turn (called by retract_message).
///
/// Returns `true` when `id` is present in the shared `sent_in_turn` list.
/// A poisoned mutex is recovered rather than reported as "not sent": the list
/// is a plain `Vec<Uuid>` and remains readable, and answering `false` there
/// would wrongly deny a message that was registered earlier.
pub fn is_sent_in_turn(&self, id: Uuid) -> bool {
    self.inner
        .sent_in_turn
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .contains(&id)
}
pub fn embed_service(&self) -> Option<&crate::embed::EmbedService> { pub fn embed_service(&self) -> Option<&crate::embed::EmbedService> {
self.inner.embed_service.as_ref() self.inner.embed_service.as_ref()
} }
/// Message queue producer for publishing room events (messages, retractions, etc.).
///
/// Returns `None` when no producer has been stored on this context.
pub fn message_producer(&self) -> Option<&MessageProducer> {
self.inner.message_producer.as_ref()
}
pub fn recursion_exceeded(&self) -> bool { pub fn recursion_exceeded(&self) -> bool {
self.inner.depth >= self.inner.max_depth self.inner.depth >= self.inner.max_depth
} }
@ -146,6 +195,16 @@ impl ToolContext {
self.inner.sender_id self.inner.sender_id
} }
/// AI model ID when in room context (the AI that is responding).
///
/// Returns `None` when no AI model has been stored on this context.
pub fn ai_model_id(&self) -> Option<Uuid> {
self.inner.ai_model_id
}
/// AI model display name when in room context.
///
/// Returns an owned copy of the stored name, or `None` when no name is set.
pub fn ai_model_name(&self) -> Option<String> {
    let stored = &self.inner.ai_model_name;
    stored.clone()
}
/// Project context for the room. /// Project context for the room.
pub fn project_id(&self) -> Uuid { pub fn project_id(&self) -> Uuid {
self.inner.project_id self.inner.project_id

View File

@ -14,6 +14,7 @@ use super::context::ToolContext;
use super::definition::ToolDefinition as AgentToolDefinition; use super::definition::ToolDefinition as AgentToolDefinition;
use super::recorder::{ToolCallRecord, ToolCallRecorder}; use super::recorder::{ToolCallRecord, ToolCallRecorder};
use super::registry::{ToolHandler, ToolRegistry}; use super::registry::{ToolHandler, ToolRegistry};
use queue::MessageProducer;
/// Returns true if the tool error message indicates a transient failure that can be retried. /// Returns true if the tool error message indicates a transient failure that can be retried.
pub fn is_retryable_tool_error(msg: &str) -> bool { pub fn is_retryable_tool_error(msg: &str) -> bool {
@ -170,6 +171,10 @@ impl RigToolSet {
room_id: uuid::Uuid, room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>, sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid, project_id: uuid::Uuid,
message_producer: Option<MessageProducer>,
ai_model_id: Option<uuid::Uuid>,
ai_model_name: Option<String>,
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
) -> Self { ) -> Self {
let mut toolset = ToolSet::default(); let mut toolset = ToolSet::default();
let mut definitions = HashMap::new(); let mut definitions = HashMap::new();
@ -191,6 +196,10 @@ impl RigToolSet {
room_id, room_id,
sender_id, sender_id,
project_id, project_id,
message_producer: message_producer.clone(),
ai_model_id,
ai_model_name: ai_model_name.clone(),
sent_in_turn: sent_in_turn.clone(),
}; };
toolset.add_tool(adapter); toolset.add_tool(adapter);
} }
@ -227,6 +236,10 @@ pub struct RigToolAdapter {
room_id: uuid::Uuid, room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>, sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid, project_id: uuid::Uuid,
message_producer: Option<MessageProducer>,
ai_model_id: Option<uuid::Uuid>,
ai_model_name: Option<String>,
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
} }
impl RigToolAdapter { impl RigToolAdapter {
@ -240,8 +253,12 @@ impl RigToolAdapter {
room_id: uuid::Uuid, room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>, sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid, project_id: uuid::Uuid,
message_producer: Option<MessageProducer>,
ai_model_id: Option<uuid::Uuid>,
ai_model_name: Option<String>,
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
) -> Self { ) -> Self {
Self { handler, definition, db, cache, config, room_id, sender_id, project_id } Self { handler, definition, db, cache, config, room_id, sender_id, project_id, message_producer, ai_model_id, ai_model_name, sent_in_turn }
} }
} }
@ -272,16 +289,27 @@ impl ToolDyn for RigToolAdapter {
let room_id = self.room_id; let room_id = self.room_id;
let sender_id = self.sender_id; let sender_id = self.sender_id;
let project_id = self.project_id; let project_id = self.project_id;
let message_producer = self.message_producer.clone();
let ai_model_id = self.ai_model_id;
let ai_model_name = self.ai_model_name.clone();
let sent_in_turn = self.sent_in_turn.clone();
async move { async move {
let ctx = ToolContext::new( let mut ctx = ToolContext::new(
db, db,
cache, cache,
config, config,
room_id, room_id,
sender_id, sender_id,
) )
.with_project(project_id); .with_project(project_id)
.with_sent_in_turn(sent_in_turn);
if let Some(mp) = message_producer {
ctx = ctx.with_message_producer(mp);
}
if let Some(mid) = ai_model_id {
ctx = ctx.with_ai_model(mid, ai_model_name.unwrap_or_default());
}
let args_json: serde_json::Value = serde_json::from_str(&args) let args_json: serde_json::Value = serde_json::from_str(&args)
.map_err(|e| ToolError::JsonError(e))?; .map_err(|e| ToolError::JsonError(e))?;

View File

@ -26,6 +26,7 @@ email = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
service = { workspace = true } service = { workspace = true }
session = { workspace = true } session = { workspace = true }
agent = { workspace = true }
git = { workspace = true } git = { workspace = true }
#frontend = { workspace = true } #frontend = { workspace = true }
models = { workspace = true } models = { workspace = true }
@ -51,5 +52,12 @@ sea-orm = "2.0.0-rc.37"
rust_decimal = "1.40.0" rust_decimal = "1.40.0"
actix-multipart = { workspace = true, features = ["tempfile"] } actix-multipart = { workspace = true, features = ["tempfile"] }
redis = { workspace = true } redis = { workspace = true }
reqwest = { workspace = true, features = ["json", "native-tls", "stream"] }
[build-dependencies]
brotli = "7"
flate2 = "1"
sha2 = "0.10"
[lints] [lints]
workspace = true workspace = true

View File

@ -18,7 +18,7 @@ pub fn init_agent_routes(cfg: &mut web::ServiceConfig) {
web::post().to(code_review::trigger_code_review), web::post().to(code_review::trigger_code_review),
) )
.route( .route(
"/{project}/issues/{issue_number}/triage", "/{project}/triage",
web::get().to(issue_triage::triage_issue), web::get().to(issue_triage::triage_issue),
) )
.route( .route(

View File

@ -1,11 +1,12 @@
use actix_web::{HttpResponse, Result, web}; use actix_web::{HttpResponse, Result, web};
use serde::Serialize; use serde::Serialize;
use session::SessionUser; use session::Session;
use utoipa::ToSchema; use utoipa::ToSchema;
use crate::ApiResponse; use crate::ApiResponse;
use crate::error::ApiError; use crate::error::ApiError;
use service::AppService; use service::AppService;
use service::error::AppError;
use service::ws_token::WS_TOKEN_TTL_SECONDS; use service::ws_token::WS_TOKEN_TTL_SECONDS;
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
@ -27,13 +28,16 @@ pub struct WsTokenResponse {
)] )]
pub async fn ws_token_generate( pub async fn ws_token_generate(
service: web::Data<AppService>, service: web::Data<AppService>,
session_user: SessionUser, session: Session,
) -> Result<HttpResponse, ApiError> { ) -> Result<HttpResponse, ApiError> {
let SessionUser(user_id) = session_user; let user_id = session.user().ok_or_else(|| ApiError::from(AppError::Unauthorized))?;
let device_id = session.get::<String>("device_id").unwrap_or_default();
let client_id = session.get::<String>("client_id").unwrap_or_default();
let token = service let token = service
.ws_token .ws_token
.generate_token(user_id) .generate_token(user_id, device_id, client_id)
.await .await
.map_err(ApiError::from)?; .map_err(ApiError::from)?;

View File

@ -1 +1,224 @@
fn main() {} //! Build script: reads all files from `dist/`, compresses them (brotli + gzip),
//! computes etags via SHA-256, and generates a `frontend` Rust module for `dist.rs`.
use std::collections::BTreeMap;
use std::env;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use sha2::{Digest, Sha256};
use flate2::write::GzEncoder;
use flate2::Compression;
// ── Compression helpers ──────────────────────────────────────────────────
/// Gzip-compresses `data` at compression level 6 and returns the compressed bytes.
///
/// # Panics
/// Panics if the encoder reports an error while writing or finishing.
fn gzip_compress(data: &[u8]) -> Vec<u8> {
    let level = Compression::new(6);
    let mut gz = GzEncoder::new(Vec::new(), level);
    gz.write_all(data).unwrap();
    gz.finish().unwrap()
}
/// Brotli-compresses `data` into an in-memory buffer.
///
/// The writer is created with the arguments `4096, 6, 16`
/// (presumably buffer size, quality and window size; confirm against the
/// `brotli` crate documentation).
///
/// Returns `None` if writing or flushing the compressor fails.
fn brotli_compress(data: &[u8]) -> Option<Vec<u8>> {
    use brotli::CompressorWriter;

    let mut encoder = CompressorWriter::new(Vec::new(), 4096, 6, 16);
    if encoder.write_all(data).is_err() || encoder.flush().is_err() {
        return None;
    }
    Some(encoder.into_inner())
}
// ── ETag computation ─────────────────────────────────────────────────────
/// Computes a compact etag for `data`: the first 32 lowercase hex characters
/// of its SHA-256 digest.
fn compute_etag(data: &[u8]) -> String {
    let digest = Sha256::digest(data);
    let mut hex = format!("{:x}", digest);
    // The full digest is 64 hex characters; keep the first half.
    hex.truncate(32);
    hex
}
// ── Asset collection ─────────────────────────────────────────────────────
/// One file collected from the dist directory, together with the artifacts
/// derived from its contents.
struct Asset {
/// Path relative to the dist directory, with `\` normalised to `/`.
path: String,
/// Raw, uncompressed file contents.
data: Vec<u8>,
/// Etag computed from `data`.
etag: String,
/// Brotli-compressed contents, or `None` when brotli compression failed.
brotli: Option<Vec<u8>>,
/// Gzip-compressed contents.
gzip: Vec<u8>,
}
/// Reads every file under `dist_dir` and builds a map from its relative path
/// (using `/` separators) to the file's contents, etag and compressed forms.
///
/// # Panics
/// Panics if a discovered file cannot be read.
fn collect_assets(dist_dir: &Path) -> BTreeMap<String, Asset> {
    walkdir(dist_dir)
        .into_iter()
        .filter_map(|file| {
            let key = file
                .strip_prefix(dist_dir)
                .unwrap()
                .to_string_lossy()
                .replace('\\', "/");
            if key.is_empty() {
                return None;
            }

            let data = match fs::read(&file) {
                Ok(bytes) => bytes,
                Err(e) => panic!("Failed to read dist file {}: {}", key, e),
            };
            let etag = compute_etag(&data);
            let brotli = brotli_compress(&data);
            let gzip = gzip_compress(&data);

            let asset = Asset {
                path: key.clone(),
                data,
                etag,
                brotli,
                gzip,
            };
            Some((key, asset))
        })
        .collect()
}
/// Recursively lists every non-directory entry below `dir`.
///
/// Directories that cannot be read, and individual entries that fail to load,
/// are skipped silently; an unreadable or missing `dir` yields an empty list.
fn walkdir(dir: &Path) -> Vec<PathBuf> {
    let Ok(listing) = fs::read_dir(dir) else {
        return Vec::new();
    };

    let mut found = Vec::new();
    for child in listing.flatten().map(|entry| entry.path()) {
        if child.is_dir() {
            // Descend and splice the subdirectory's files in at this position.
            found.append(&mut walkdir(&child));
        } else {
            found.push(child);
        }
    }
    found
}
// ── Code generation ──────────────────────────────────────────────────────
/// Renders `data` as the source text of a Rust array literal of decimal bytes.
///
/// Inputs shorter than 200 bytes are written on one line (`[1, 2, 3]`).
/// Longer inputs are wrapped in `[\n` … `\n]` with 80 values per line and the
/// lines joined by `,\n`.
fn rust_byte_literal(data: &[u8]) -> String {
    // Formats one run of bytes as "a, b, c".
    fn render_row(row: &[u8]) -> String {
        let mut text = String::new();
        for (idx, byte) in row.iter().enumerate() {
            if idx > 0 {
                text.push_str(", ");
            }
            text.push_str(&byte.to_string());
        }
        text
    }

    if data.len() >= 200 {
        let mut literal = String::from("[\n");
        for (row_idx, row) in data.chunks(80).enumerate() {
            if row_idx > 0 {
                literal.push_str(",\n");
            }
            literal.push_str(&render_row(row));
        }
        literal.push_str("\n]");
        return literal;
    }

    format!("[{}]", render_row(data))
}
/// Maps an asset path to the name of the generated `static` that holds it.
///
/// Every character that is not alphanumeric or `_` is replaced by `_`, the
/// result is upper-cased and prefixed with `ASSET_`. Sanitising all
/// punctuation (not only `-`, `.` and `/`) keeps the generated module valid
/// for file names containing spaces, `@`, `+`, `~`, parentheses and similar.
///
/// Distinct paths can still map to the same identifier (e.g. `a-b.js` and
/// `a_b.js`); that case is not detected here.
fn path_to_ident(path: &str) -> String {
    let sanitized: String = path
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '_' { c } else { '_' })
        .collect();
    format!("ASSET_{}", sanitized.to_uppercase())
}

/// Name of the generated `static` holding the etag for `path`.
fn etag_ident(path: &str) -> String {
    format!("ETAG_{}", path_to_ident(path))
}

/// Name of the generated `static` holding the brotli-compressed bytes for `path`.
fn br_ident(path: &str) -> String {
    format!("{}_BR", path_to_ident(path))
}

/// Name of the generated `static` holding the gzip-compressed bytes for `path`.
fn gz_ident(path: &str) -> String {
    format!("{}_GZ", path_to_ident(path))
}
/// Writes `frontend.rs` into `out_dir`, embedding every asset as static data.
///
/// For each asset the generated file contains a `static` for the raw bytes,
/// the etag, the gzip bytes and (when available) the brotli bytes, followed by
/// two lookup functions:
///
/// * `get_frontend_asset_with_etag(path)` returns `(data, etag)`.
/// * `get_frontend_asset_compressed(path)` returns `(data, encoding, etag)`,
///   using brotli when present and gzip otherwise.
///
/// Match patterns are emitted with `{:?}` so that a path containing a quote,
/// backslash or control character is escaped into a valid string literal;
/// ordinary paths render exactly as before.
///
/// # Panics
/// Panics if the output file cannot be written.
fn generate_frontend_module(assets: &BTreeMap<String, Asset>, out_dir: &Path) {
    let mut code = String::new();
    code += "// AUTO-GENERATED by build.rs — DO NOT EDIT.\n";
    code += "// Frontend assets from dist/ embedded as static byte arrays.\n\n";

    // Static data for each asset.
    for (path, asset) in assets {
        let ident = path_to_ident(path);
        let etag_id = etag_ident(path);
        let br_id = br_ident(path);
        let gz_id = gz_ident(path);
        code += &format!("static {}: &[u8] = &{};\n", ident, rust_byte_literal(&asset.data));
        code += &format!("static {}: &str = \"{}\";\n", etag_id, asset.etag);
        if let Some(ref br) = asset.brotli {
            code += &format!("static {}: &[u8] = &{};\n", br_id, rust_byte_literal(br));
        }
        code += &format!("static {}: &[u8] = &{};\n", gz_id, rust_byte_literal(&asset.gzip));
        code += "\n";
    }

    // Uncompressed lookup.
    code += "/// Get an uncompressed frontend asset by path, returning (data, etag).\n";
    code += "pub fn get_frontend_asset_with_etag(path: &str) -> Option<(&'static [u8], &'static str)> {\n";
    code += " match path {\n";
    for path in assets.keys() {
        let ident = path_to_ident(path);
        let etag_id = etag_ident(path);
        code += &format!(" {path:?} => Some((&{ident}, &{etag_id})),\n");
    }
    code += " _ => None,\n";
    code += " }\n";
    code += "}\n\n";

    // Compressed lookup (prefers brotli, falls back to gzip).
    code += "/// Get a pre-compressed frontend asset by path.\n";
    code += "/// Returns (data, encoding, etag) — prefers brotli over gzip.\n";
    code += "pub fn get_frontend_asset_compressed(path: &str) -> Option<(&'static [u8], &'static str, &'static str)> {\n";
    code += " match path {\n";
    for (path, asset) in assets {
        let etag_id = etag_ident(path);
        if asset.brotli.is_some() {
            let br_id = br_ident(path);
            code += &format!(" {path:?} => Some((&{br_id}, \"br\", &{etag_id})),\n");
        } else {
            let gz_id = gz_ident(path);
            code += &format!(" {path:?} => Some((&{gz_id}, \"gzip\", &{etag_id})),\n");
        }
    }
    code += " _ => None,\n";
    code += " }\n";
    code += "}\n";

    let out_path = out_dir.join("frontend.rs");
    fs::write(&out_path, code).unwrap_or_else(|e| {
        panic!("Failed to write generated frontend.rs: {}", e)
    });
}
/// Build-script entry point.
///
/// Locates `dist/` two directory levels above this crate's manifest directory.
/// When it exists, every file in it is embedded into `$OUT_DIR/frontend.rs`;
/// otherwise a stub `frontend.rs` is written whose lookup functions always
/// return `None`, and a cargo warning is emitted.
///
/// # Panics
/// Panics if `CARGO_MANIFEST_DIR` or `OUT_DIR` is unset, if the manifest
/// directory has fewer than two ancestors, or if the output cannot be written.
fn main() {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let workspace_root = Path::new(&manifest_dir)
        .parent()
        .unwrap()
        .parent()
        .unwrap();
    let dist_dir = workspace_root.join("dist");

    if !dist_dir.exists() {
        println!("cargo:warning=dist/ directory not found — frontend assets will not be embedded");
        let out_dir = env::var("OUT_DIR").unwrap();
        let out_path = Path::new(&out_dir).join("frontend.rs");
        // The header is a plain `//` comment: an inner doc comment (`//!`) is
        // rejected when a file is pulled in with `include!`, which is how
        // generated files in OUT_DIR are normally consumed.
        fs::write(
            &out_path,
            "// No dist/ directory found — frontend assets not embedded.\n\
pub fn get_frontend_asset_with_etag(_path: &str) -> Option<(&'static [u8], &'static str)> { None }\n\
pub fn get_frontend_asset_compressed(_path: &str) -> Option<(&'static [u8], &'static str, &'static str)> { None }\n",
        ).unwrap();
        return;
    }

    // Relative `rerun-if-changed` paths are resolved against the package
    // directory, so a bare `dist/` would not refer to the workspace-level
    // directory read above. Emit the resolved path instead.
    println!("cargo:rerun-if-changed={}", dist_dir.display());

    let assets = collect_assets(&dist_dir);
    println!("cargo:warning=Collected {} frontend assets from dist/", assets.len());

    let out_dir = env::var("OUT_DIR").unwrap();
    generate_frontend_module(&assets, Path::new(&out_dir));
}

View File

@ -0,0 +1,180 @@
use actix_web::{web, HttpResponse, Result};
use session::Session;
use service::error::AppError;
use uuid::Uuid;
use crate::error::ApiError;
use crate::ApiResponse;
use super::types::{ConversationListQuery, ConversationResponse, CreateConversationParams};
/// Returns the ID of the user attached to `session`.
///
/// # Errors
/// Returns `AppError::Unauthorized` (converted to `ApiError`) when the session
/// has no user.
fn get_user_id(session: &Session) -> Result<Uuid, ApiError> {
    match session.user() {
        Some(user_id) => Ok(user_id),
        None => Err(ApiError::from(AppError::Unauthorized)),
    }
}
// Creates a conversation owned by the session user and returns it.
// When the request omits `model`, "gpt-4" is used.
// Fails with Unauthorized when the session has no user; service errors are
// propagated.
#[utoipa::path(
    post,
    path = "/api/ai/conversations",
    operation_id = "ai_conversation_create",
    request_body = CreateConversationParams,
    responses(
        (status = 200, description = "Conversation created", body = ApiResponse<ConversationResponse>),
        (status = 401, description = "Unauthorized"),
    ),
    tag = "AI Chat"
)]
pub async fn conversation_create(
    service: web::Data<service::AppService>,
    session: Session,
    params: web::Json<CreateConversationParams>,
) -> Result<HttpResponse, ApiError> {
    let user_id = get_user_id(&session)?;

    // Take ownership of the request body so its fields can be handed over directly.
    let req = params.into_inner();
    let model = req.model.unwrap_or_else(|| "gpt-4".to_string());

    let created = service
        .create_conversation(
            user_id,
            req.project_id,
            req.title,
            model,
            req.model_config,
            req.access_visibility,
            req.can_ask,
            req.model_uid,
            req.model_name,
        )
        .await?;

    Ok(ApiResponse::ok(ConversationResponse::from(created)).to_response())
}
// Lists the session user's conversations, optionally filtered by project.
// The service is asked for at most 50 entries.
// Fails with Unauthorized when the session has no user; service errors are
// propagated.
#[utoipa::path(
    get,
    path = "/api/ai/conversations",
    operation_id = "ai_conversation_list",
    params(
        ("project_id" = Option<Uuid>, Query, description = "Filter by project"),
    ),
    responses(
        (status = 200, description = "List of conversations", body = ApiResponse<Vec<ConversationResponse>>),
        (status = 401, description = "Unauthorized"),
    ),
    tag = "AI Chat"
)]
pub async fn conversation_list(
    service: web::Data<service::AppService>,
    session: Session,
    query: web::Query<ConversationListQuery>,
) -> Result<HttpResponse, ApiError> {
    let user_id = get_user_id(&session)?;

    let found = service
        .list_conversations(user_id, query.project_id, 50)
        .await?;

    let mut resp: Vec<ConversationResponse> = Vec::new();
    for conversation in found {
        resp.push(ConversationResponse::from(conversation));
    }

    Ok(ApiResponse::ok(resp).to_response())
}
// Returns one conversation, looked up by ID together with the session user's ID.
// Fails with Unauthorized when the session has no user; service errors are
// propagated.
#[utoipa::path(
    get,
    path = "/api/ai/conversations/{conversation_id}",
    operation_id = "ai_conversation_get",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
    ),
    responses(
        (status = 200, description = "Get conversation", body = ApiResponse<ConversationResponse>),
        (status = 404, description = "Not found"),
    ),
    tag = "AI Chat"
)]
pub async fn conversation_get(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
    let user_id = get_user_id(&session)?;
    let conversation_id = path.into_inner();

    let found = service
        .find_conversation_owned(conversation_id, user_id)
        .await?;

    let body = ApiResponse::ok(ConversationResponse::from(found));
    Ok(body.to_response())
}
// Applies the supplied fields to a conversation on behalf of the session user
// and returns the generic success response.
// Fails with Unauthorized when the session has no user; service errors are
// propagated.
#[utoipa::path(
    patch,
    path = "/api/ai/conversations/{conversation_id}",
    operation_id = "ai_conversation_update",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
    ),
    request_body = super::types::UpdateConversationParams,
    responses(
        (status = 200, description = "Conversation updated"),
        (status = 404, description = "Not found"),
    ),
    tag = "AI Chat"
)]
pub async fn conversation_update(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<Uuid>,
    params: web::Json<super::types::UpdateConversationParams>,
) -> Result<HttpResponse, ApiError> {
    let user_id = get_user_id(&session)?;
    let conversation_id = path.into_inner();

    // Take ownership of the request body so its fields can be handed over directly.
    let patch = params.into_inner();

    service
        .update_conversation(
            conversation_id,
            user_id,
            patch.title,
            patch.model,
            patch.model_config,
            patch.status,
            patch.access_visibility,
            patch.can_ask,
            patch.model_uid,
            patch.model_name,
        )
        .await?;

    Ok(crate::api_success())
}
// Deletes a conversation on behalf of the session user and returns the generic
// success response.
// Fails with Unauthorized when the session has no user; service errors are
// propagated.
#[utoipa::path(
    delete,
    path = "/api/ai/conversations/{conversation_id}",
    operation_id = "ai_conversation_delete",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
    ),
    responses(
        (status = 200, description = "Conversation deleted"),
        (status = 404, description = "Not found"),
    ),
    tag = "AI Chat"
)]
pub async fn conversation_delete(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
    let owner = get_user_id(&session)?;
    let target = path.into_inner();

    service.delete_conversation(target, owner).await?;

    Ok(crate::api_success())
}

View File

@ -0,0 +1,102 @@
use actix_web::{web, HttpResponse, Result};
use session::Session;
use service::error::AppError;
use uuid::Uuid;
use crate::error::ApiError;
use crate::ApiResponse;
// Serialized view of one fork record, returned by the fork endpoints.
// Plain `//` comments are used here because `ToSchema` would turn doc comments
// into schema descriptions.
#[derive(Debug, serde::Serialize, utoipa::ToSchema)]
pub struct ForkResponse {
// ID of the fork record itself.
pub id: Uuid,
// Conversation recorded on the fork record; may be absent.
pub conversation_id: Option<Uuid>,
// Message the fork was taken from.
pub source_message_id: Uuid,
// Message recorded as the fork.
pub fork_message_id: Uuid,
// Creation timestamp copied from the fork record.
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
pub created_at: chrono::DateTime<chrono::Utc>,
}
// Creates a fork from a source message to a target message inside a
// conversation, on behalf of the session user, and returns the fork record.
// Fails with Unauthorized when the session has no user; service errors are
// propagated.
#[utoipa::path(
    post,
    path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/fork/{target_message_id}",
    operation_id = "ai_message_fork",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
        ("message_id" = Uuid, Path, description = "Source message ID"),
        ("target_message_id" = Uuid, Path, description = "Target/fork message ID to create"),
    ),
    responses(
        (status = 200, description = "Fork created", body = ApiResponse<ForkResponse>),
    ),
    tag = "AI Chat"
)]
pub async fn message_fork(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<(Uuid, Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
    let Some(user_id) = session.user() else {
        return Err(ApiError::from(AppError::Unauthorized));
    };
    let (conversation_id, source_message_id, target_message_id) = path.into_inner();

    let record = service
        .fork_message(conversation_id, user_id, source_message_id, target_message_id)
        .await?;

    let id = record.id;
    let conversation_id = record.conversation_id;
    let source_message_id = record.source_message_id;
    let fork_message_id = record.fork_message_id;
    let created_at = record.created_at;

    Ok(ApiResponse::ok(ForkResponse {
        id,
        conversation_id,
        source_message_id,
        fork_message_id,
        created_at,
    })
    .to_response())
}
// GET handler listing the forks that originate from one message.
//
// Path segments are the conversation ID followed by the source message ID.
// Requests without a session user are rejected with `AppError::Unauthorized`.
// Each record returned by `service.list_forks` is converted to a
// `ForkResponse`, keeping the order in which the service returned them.
#[utoipa::path(
    get,
    path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/forks",
    operation_id = "ai_message_forks",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
        ("message_id" = Uuid, Path, description = "Source message ID"),
    ),
    responses(
        (status = 200, description = "List forks from message", body = ApiResponse<Vec<ForkResponse>>),
    ),
    tag = "AI Chat"
)]
pub async fn message_forks(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
    let Some(caller) = session.user() else {
        return Err(ApiError::from(AppError::Unauthorized));
    };
    let (conversation_id, source_id) = path.into_inner();

    let records = service
        .list_forks(conversation_id, caller, source_id)
        .await?;

    let mut payload: Vec<ForkResponse> = Vec::new();
    for record in records {
        payload.push(ForkResponse {
            id: record.id,
            conversation_id: record.conversation_id,
            source_message_id: record.source_message_id,
            fork_message_id: record.fork_message_id,
            created_at: record.created_at,
        });
    }
    Ok(ApiResponse::ok(payload).to_response())
}

View File

@ -0,0 +1,346 @@
use crate::error::ApiError;
use crate::ApiResponse;
use actix_web::{web, HttpResponse, Result};
use session::Session;
use service::error::AppError;
use uuid::Uuid;
use super::types::{CreateMessageParams, EditMessageParams, MessageListQuery, MessageResponse};
/// Returns the user ID stored in `session`, or `AppError::Unauthorized`
/// (wrapped in `ApiError`) when the session carries no user.
fn get_user_id(session: &Session) -> Result<Uuid, ApiError> {
    match session.user() {
        Some(user_id) => Ok(user_id),
        None => Err(ApiError::from(AppError::Unauthorized)),
    }
}
// GET handler listing the messages of a conversation.
//
// The optional `limit` query parameter bounds how many messages are requested
// from `service.list_messages`; it defaults to 50 when absent. A negative
// `limit` also falls back to the default: casting a negative `i64` with
// `as u64` would wrap to an enormous unsigned value and effectively remove
// the bound.
//
// Errors: `AppError::Unauthorized` when the session has no user, plus any
// error returned by the service.
#[utoipa::path(
    get,
    path = "/api/ai/conversations/{conversation_id}/messages",
    operation_id = "ai_message_list",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
        ("limit" = Option<i64>, Query, description = "Max messages"),
    ),
    responses(
        (status = 200, description = "List messages", body = ApiResponse<Vec<MessageResponse>>),
        (status = 404, description = "Not found"),
    ),
    tag = "AI Chat"
)]
pub async fn message_list(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<Uuid>,
    query: web::Query<MessageListQuery>,
) -> Result<HttpResponse, ApiError> {
    /// Number of messages requested when the client sends no usable `limit`.
    const DEFAULT_LIMIT: u64 = 50;

    let user_id = get_user_id(&session)?;
    let conversation_id = path.into_inner();

    // Only non-negative values are converted, so the cast below is lossless.
    let limit = query
        .limit
        .filter(|requested| *requested >= 0)
        .map(|requested| requested as u64)
        .unwrap_or(DEFAULT_LIMIT);

    let msgs = service
        .list_messages(conversation_id, user_id, limit)
        .await?;
    let resp: Vec<MessageResponse> = msgs
        .into_iter()
        .map(MessageResponse::from)
        .collect();
    Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
post,
path = "/api/ai/conversations/{conversation_id}/messages",
operation_id = "ai_message_create",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
),
request_body = CreateMessageParams,
responses(
(status = 200, description = "Message created", body = ApiResponse<MessageResponse>),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn message_create(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<Uuid>,
params: web::Json<CreateMessageParams>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let conversation_id = path.into_inner();
let msg = service
.create_message(
conversation_id,
user_id,
params.parent_message_id,
params.content.role.clone(),
params.content.content.clone(),
params.model.clone(),
params.is_fork_origin.unwrap_or(false),
params.metadata.clone(),
params.room_id,
)
.await?;
let resp = MessageResponse::from(msg);
Ok(ApiResponse::ok(resp).to_response())
}
// GET handler returning a single message of a conversation.
//
// Path segments are the conversation ID followed by the message ID. The
// lookup is delegated to `service.get_message`, which also receives the
// caller's user ID.
#[utoipa::path(
    get,
    path = "/api/ai/conversations/{conversation_id}/messages/{message_id}",
    operation_id = "ai_message_get",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
        ("message_id" = Uuid, Path, description = "Message ID"),
    ),
    responses(
        (status = 200, description = "Get message", body = ApiResponse<MessageResponse>),
        (status = 404, description = "Not found"),
    ),
    tag = "AI Chat"
)]
pub async fn message_get(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
    let caller = get_user_id(&session)?;
    let (conversation_id, message_id) = path.into_inner();
    let found = service
        .get_message(conversation_id, caller, message_id)
        .await?;
    Ok(ApiResponse::ok(MessageResponse::from(found)).to_response())
}
// POST handler asking the service to stop a message.
//
// Delegates to `service.stop_message` with the conversation ID, the caller's
// user ID and the message ID, then replies with the generic success envelope.
#[utoipa::path(
    post,
    path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/stop",
    operation_id = "ai_message_stop",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
        ("message_id" = Uuid, Path, description = "Message ID"),
    ),
    responses(
        (status = 200, description = "Message stopped"),
    ),
    tag = "AI Chat"
)]
pub async fn message_stop(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
    let caller = get_user_id(&session)?;
    let (conversation_id, message_id) = path.into_inner();
    service.stop_message(conversation_id, caller, message_id).await?;
    Ok(crate::api_success())
}
// POST handler that resends a message.
//
// Delegates to `service.resend_message` and returns the message it produces
// as a `MessageResponse`.
#[utoipa::path(
    post,
    path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/resend",
    operation_id = "ai_message_resend",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
        ("message_id" = Uuid, Path, description = "Message ID"),
    ),
    responses(
        (status = 200, description = "Resend message", body = ApiResponse<MessageResponse>),
    ),
    tag = "AI Chat"
)]
pub async fn message_resend(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
    let caller = get_user_id(&session)?;
    let (conversation_id, message_id) = path.into_inner();
    let resent = service
        .resend_message(conversation_id, caller, message_id)
        .await?;
    Ok(ApiResponse::ok(MessageResponse::from(resent)).to_response())
}
// GET handler listing the child messages of a parent message.
//
// The second path segment is used as the parent message ID. Results from
// `service.list_child_messages` are converted to `MessageResponse` in the
// order returned.
#[utoipa::path(
    get,
    path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/children",
    operation_id = "ai_message_children",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
        ("message_id" = Uuid, Path, description = "Parent message ID"),
    ),
    responses(
        (status = 200, description = "List child messages", body = ApiResponse<Vec<MessageResponse>>),
    ),
    tag = "AI Chat"
)]
pub async fn message_children(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
    let caller = get_user_id(&session)?;
    let (conversation_id, parent_id) = path.into_inner();
    let children = service
        .list_child_messages(conversation_id, caller, parent_id)
        .await?;
    let payload = children
        .into_iter()
        .map(Into::into)
        .collect::<Vec<MessageResponse>>();
    Ok(ApiResponse::ok(payload).to_response())
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/stream",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("message_id" = Uuid, Path, description = "Message ID"),
),
responses(
(status = 200, description = "SSE stream"),
),
tag = "AI Chat"
)]
pub async fn message_stream(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let (conversation_id, message_id) = path.into_inner();
// Verify user owns the conversation
let conv = service
.find_conversation_owned(conversation_id, user_id)
.await?;
let model = conv.model;
let response = actix_web::HttpResponse::Ok()
.content_type("text/event-stream")
.insert_header(("Cache-Control", "no-cache"))
.insert_header(("X-Accel-Buffering", "no"))
.streaming(super::super::stream::create_chat_sse_stream(
service.get_ref().clone(),
conversation_id,
message_id,
model,
user_id,
));
Ok(response.into())
}
// POST handler that edits a message.
//
// The new text comes from the JSON body (`EditMessageParams::content`) and is
// forwarded to `service.edit_message`; the message returned by the service is
// sent back as a `MessageResponse`.
#[utoipa::path(
    post,
    path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/edit",
    operation_id = "ai_message_edit",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
        ("message_id" = Uuid, Path, description = "Message ID to edit"),
    ),
    request_body = EditMessageParams,
    responses(
        (status = 200, description = "Message edited, new version created", body = ApiResponse<MessageResponse>),
    ),
    tag = "AI Chat"
)]
pub async fn message_edit(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<(Uuid, Uuid)>,
    params: web::Json<EditMessageParams>,
) -> Result<HttpResponse, ApiError> {
    let caller = get_user_id(&session)?;
    let (conversation_id, message_id) = path.into_inner();
    let new_content = params.into_inner().content;
    let edited = service
        .edit_message(conversation_id, caller, message_id, new_content)
        .await?;
    Ok(ApiResponse::ok(MessageResponse::from(edited)).to_response())
}
// GET handler listing the versions of a message.
//
// Results from `service.list_message_versions` are converted to
// `MessageResponse` in the order returned.
#[utoipa::path(
    get,
    path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/versions",
    operation_id = "ai_message_versions",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
        ("message_id" = Uuid, Path, description = "Message ID"),
    ),
    responses(
        (status = 200, description = "List message versions", body = ApiResponse<Vec<MessageResponse>>),
    ),
    tag = "AI Chat"
)]
pub async fn message_versions(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
    let caller = get_user_id(&session)?;
    let (conversation_id, message_id) = path.into_inner();
    let history = service
        .list_message_versions(conversation_id, caller, message_id)
        .await?;
    let payload = history
        .into_iter()
        .map(Into::into)
        .collect::<Vec<MessageResponse>>();
    Ok(ApiResponse::ok(payload).to_response())
}
// POST handler that switches a message to another version.
//
// The requested `version_number` comes from the JSON body and is forwarded to
// `service.switch_message_version`; the message returned by the service is
// sent back as a `MessageResponse`.
#[utoipa::path(
    post,
    path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/switch-version",
    operation_id = "ai_message_switch_version",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
        ("message_id" = Uuid, Path, description = "Message ID"),
    ),
    request_body = super::types::SwitchVersionParams,
    responses(
        (status = 200, description = "Version switched", body = ApiResponse<MessageResponse>),
    ),
    tag = "AI Chat"
)]
pub async fn message_switch_version(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<(Uuid, Uuid)>,
    params: web::Json<super::types::SwitchVersionParams>,
) -> Result<HttpResponse, ApiError> {
    let caller = get_user_id(&session)?;
    let (conversation_id, message_id) = path.into_inner();
    let wanted_version = params.version_number;
    let selected = service
        .switch_message_version(conversation_id, caller, message_id, wanted_version)
        .await?;
    Ok(ApiResponse::ok(MessageResponse::from(selected)).to_response())
}

View File

@ -0,0 +1,5 @@
pub mod conversation;
pub mod fork;
pub mod message;
pub mod share;
pub mod types;

View File

@ -0,0 +1,76 @@
use actix_web::{web, HttpResponse, Result};
use session::Session;
use service::error::AppError;
use uuid::Uuid;
use crate::error::ApiError;
use crate::ApiResponse;
use super::types::{ConversationResponse, ShareResponse};
/// Returns the user ID stored in `session`, or `AppError::Unauthorized`
/// (wrapped in `ApiError`) when the session carries no user.
fn get_user_id(session: &Session) -> Result<Uuid, ApiError> {
    match session.user() {
        Some(user_id) => Ok(user_id),
        None => Err(ApiError::from(AppError::Unauthorized)),
    }
}
// POST handler that creates a share token for a conversation.
//
// `service.share_conversation` returns a pair: the share record and the
// token string. Both are combined into the `ShareResponse` sent to the
// client.
#[utoipa::path(
    post,
    path = "/api/ai/conversations/{conversation_id}/share",
    operation_id = "ai_conversation_share",
    params(
        ("conversation_id" = Uuid, Path, description = "Conversation ID"),
    ),
    responses(
        (status = 200, description = "Share token created", body = ApiResponse<ShareResponse>),
        (status = 404, description = "Not found"),
    ),
    tag = "AI Chat"
)]
pub async fn conversation_share(
    service: web::Data<service::AppService>,
    session: Session,
    path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
    let owner = get_user_id(&session)?;
    let conversation_id = path.into_inner();
    let (record, share_token) = service
        .share_conversation(conversation_id, owner)
        .await?;
    let payload = ShareResponse {
        id: record.id,
        share_token,
        view_count: record.view_count,
        expires_at: record.expires_at,
    };
    Ok(ApiResponse::ok(payload).to_response())
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}/share/{share_token}",
operation_id = "ai_shared_conversation_get",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("share_token" = String, Path, description = "Share token"),
),
responses(
(status = 200, description = "Get shared conversation", body = ApiResponse<ConversationResponse>),
(status = 404, description = "Not found or expired"),
),
tag = "AI Chat"
)]
pub async fn shared_conversation_get(
service: web::Data<service::AppService>,
path: web::Path<(Uuid, String)>,
) -> Result<HttpResponse, ApiError> {
let (conversation_id, share_token) = path.into_inner();
let c = service
.get_shared_conversation(conversation_id, share_token)
.await?;
let resp = ConversationResponse::from(c);
Ok(ApiResponse::ok(resp).to_response())
}

View File

@ -0,0 +1,175 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;
// Deserialized parameters for creating a conversation. Every field is
// optional.
// NOTE(review): presumably the JSON body of the conversation-create handler,
// which is outside this view — confirm.
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateConversationParams {
pub project_id: Option<Uuid>,
pub title: Option<String>,
pub model: Option<String>,
// Free-form JSON; its shape is not constrained by this type.
pub model_config: Option<serde_json::Value>,
pub access_visibility: Option<String>,
pub can_ask: Option<String>,
/// AI model UUID for model selection
pub model_uid: Option<Uuid>,
/// AI model display name
pub model_name: Option<String>,
}
// Serialized view of a conversation.
//
// Produced from `models::ai::ai_conversation::Model` by the `From` impl in
// this file, which copies every field listed here one-to-one.
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ConversationResponse {
pub id: Uuid,
pub user_id: Uuid,
pub project_id: Option<Uuid>,
pub scope: String,
pub title: Option<String>,
pub model: String,
pub model_config: Option<serde_json::Value>,
pub status: String,
pub root_message_id: Option<Uuid>,
pub fork_count: i32,
pub is_shared: bool,
pub message_count: i32,
pub token_usage_total: Option<i32>,
pub access_visibility: String,
pub can_ask: String,
pub project_uid: Option<i32>,
pub model_uid: Option<Uuid>,
pub model_name: Option<String>,
// The `schema` attributes spell out the timestamp type for the OpenAPI schema.
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
pub created_at: chrono::DateTime<chrono::Utc>,
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
pub updated_at: chrono::DateTime<chrono::Utc>,
}
// Deserialized parameters for updating a conversation. Every field is
// optional.
// NOTE(review): presumably an omitted field means "leave unchanged"; that is
// decided by the service layer, which is not visible here — confirm.
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateConversationParams {
pub title: Option<String>,
pub model: Option<String>,
pub model_config: Option<serde_json::Value>,
pub status: Option<String>,
pub access_visibility: Option<String>,
pub can_ask: Option<String>,
pub model_uid: Option<Uuid>,
pub model_name: Option<String>,
}
/// Deserialized query parameters carrying an optional `project_id`.
// NOTE(review): presumably the filter for the conversation-list handler,
// which is outside this view — confirm.
#[derive(Debug, Deserialize)]
pub struct ConversationListQuery {
pub project_id: Option<Uuid>,
}
// Role plus content of a message as sent by the client.
//
// Embedded in `CreateMessageParams::content`; `message_create` forwards
// `role` and `content` to the service as separate arguments.
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct MessageContent {
pub role: String,
// Arbitrary JSON; this type does not constrain its shape.
pub content: serde_json::Value,
}
// JSON body accepted by `message_create`.
//
// Only `content` is required; all other fields may be omitted.
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateMessageParams {
pub parent_message_id: Option<Uuid>,
pub content: MessageContent,
pub model: Option<String>,
// `message_create` treats a missing value as `false`.
pub is_fork_origin: Option<bool>,
pub metadata: Option<serde_json::Value>,
pub room_id: Option<Uuid>,
}
// Serialized view of a message.
//
// Produced from `models::ai::ai_message::Model` by the `From` impl in this
// file, which copies every field listed here one-to-one.
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct MessageResponse {
pub id: Uuid,
pub conversation_id: Uuid,
pub parent_message_id: Option<Uuid>,
pub role: String,
pub content: serde_json::Value,
pub model: Option<String>,
pub is_fork_origin: bool,
pub stop_reason: Option<String>,
pub input_tokens: Option<i32>,
pub output_tokens: Option<i32>,
pub latency_ms: Option<i32>,
pub metadata: Option<serde_json::Value>,
pub room_id: Option<Uuid>,
pub version_group_id: Option<Uuid>,
pub version_number: i32,
pub is_latest: bool,
// The `schema` attribute spells out the timestamp type for the OpenAPI schema.
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// Query parameters accepted by `message_list`.
#[derive(Debug, Deserialize)]
pub struct MessageListQuery {
// Maximum number of messages to return; `message_list` uses 50 when absent.
pub limit: Option<i64>,
}
// JSON body accepted by `message_edit`; `content` is passed to the service
// as the message's new text.
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct EditMessageParams {
pub content: String,
}
// JSON body accepted by `message_switch_version`; names the version to
// switch to.
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct SwitchVersionParams {
pub version_number: i32,
}
/// Empty parameter type with no fields.
// NOTE(review): not referenced by the fork handlers visible in this change
// (they take all inputs from the path) — confirm whether it is still needed.
#[derive(Debug, Deserialize)]
pub struct ForkParams {}
// Serialized result of sharing a conversation.
//
// `conversation_share` fills `id`, `view_count` and `expires_at` from the
// share record and `share_token` from the token string returned alongside it.
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ShareResponse {
pub id: Uuid,
pub share_token: String,
pub view_count: i32,
// `None` when the share record has no expiry timestamp.
#[schema(value_type = Option<chrono::DateTime<chrono::Utc>>)]
pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
}
/// Converts a stored conversation row into its API representation by moving
/// each field across unchanged.
impl From<models::ai::ai_conversation::Model> for ConversationResponse {
    fn from(c: models::ai::ai_conversation::Model) -> Self {
        // Take the model apart once; `..` tolerates model fields that the
        // response does not expose.
        let models::ai::ai_conversation::Model {
            id,
            user_id,
            project_id,
            scope,
            title,
            model,
            model_config,
            status,
            root_message_id,
            fork_count,
            is_shared,
            message_count,
            token_usage_total,
            access_visibility,
            can_ask,
            project_uid,
            model_uid,
            model_name,
            created_at,
            updated_at,
            ..
        } = c;

        Self {
            id,
            user_id,
            project_id,
            scope,
            title,
            model,
            model_config,
            status,
            root_message_id,
            fork_count,
            is_shared,
            message_count,
            token_usage_total,
            access_visibility,
            can_ask,
            project_uid,
            model_uid,
            model_name,
            created_at,
            updated_at,
        }
    }
}
/// Converts a stored message row into its API representation by moving each
/// field across unchanged.
impl From<models::ai::ai_message::Model> for MessageResponse {
    fn from(m: models::ai::ai_message::Model) -> Self {
        // Take the model apart once; `..` tolerates model fields that the
        // response does not expose.
        let models::ai::ai_message::Model {
            id,
            conversation_id,
            parent_message_id,
            role,
            content,
            model,
            is_fork_origin,
            stop_reason,
            input_tokens,
            output_tokens,
            latency_ms,
            metadata,
            room_id,
            version_group_id,
            version_number,
            is_latest,
            created_at,
            ..
        } = m;

        Self {
            id,
            conversation_id,
            parent_message_id,
            role,
            content,
            model,
            is_fork_origin,
            stop_reason,
            input_tokens,
            output_tokens,
            latency_ms,
            metadata,
            room_id,
            version_group_id,
            version_number,
            is_latest,
            created_at,
        }
    }
}

85
libs/api/chat/mod.rs Normal file
View File

@ -0,0 +1,85 @@
use actix_web::web;
pub mod handlers;
pub mod stream;
pub mod watch;
pub fn init_chat_routes(cfg: &mut web::ServiceConfig) {
cfg.service(
web::scope("/ai/conversations")
.route("", web::post().to(handlers::conversation::conversation_create))
.route("", web::get().to(handlers::conversation::conversation_list))
.route(
"/{conversation_id}",
web::get().to(handlers::conversation::conversation_get),
)
.route(
"/{conversation_id}",
web::patch().to(handlers::conversation::conversation_update),
)
.route(
"/{conversation_id}",
web::delete().to(handlers::conversation::conversation_delete),
)
.route(
"/{conversation_id}/watch",
web::get().to(watch::conversation_watch),
)
.route(
"/{conversation_id}/share",
web::post().to(handlers::share::conversation_share),
)
.route(
"/{conversation_id}/share/{share_token}",
web::get().to(handlers::share::shared_conversation_get),
)
.route(
"/{conversation_id}/messages",
web::get().to(handlers::message::message_list),
)
.route(
"/{conversation_id}/messages",
web::post().to(handlers::message::message_create),
)
.route(
"/{conversation_id}/messages/{message_id}",
web::get().to(handlers::message::message_get),
)
.route(
"/{conversation_id}/messages/{message_id}/stop",
web::post().to(handlers::message::message_stop),
)
.route(
"/{conversation_id}/messages/{message_id}/resend",
web::post().to(handlers::message::message_resend),
)
.route(
"/{conversation_id}/messages/{message_id}/fork/{target_message_id}",
web::post().to(handlers::fork::message_fork),
)
.route(
"/{conversation_id}/messages/{message_id}/forks",
web::get().to(handlers::fork::message_forks),
)
.route(
"/{conversation_id}/messages/{message_id}/stream",
web::get().to(handlers::message::message_stream),
)
.route(
"/{conversation_id}/messages/{message_id}/children",
web::get().to(handlers::message::message_children),
)
.route(
"/{conversation_id}/messages/{message_id}/edit",
web::post().to(handlers::message::message_edit),
)
.route(
"/{conversation_id}/messages/{message_id}/versions",
web::get().to(handlers::message::message_versions),
)
.route(
"/{conversation_id}/messages/{message_id}/switch-version",
web::post().to(handlers::message::message_switch_version),
),
);
}

463
libs/api/chat/stream.rs Normal file
View File

@ -0,0 +1,463 @@
use agent::chat::chat_execution;
use agent::chat::{normalize_thinking_content, AiChunkType, AiStreamChunk};
use agent::client::AiClientConfig;
use agent::client::types::ChatRequestMessage;
use agent::client::StreamChunkType;
use futures::StreamExt;
use models::ai::{ai_message, ai_conversation, AiMessage};
use queue::{ChatMessageEvent, ChatStreamChunkEvent};
use sea_orm::{EntityTrait, QueryFilter, ColumnTrait, QueryOrder, ActiveModelTrait, Set, PaginatorTrait};
use service::AppService;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;
/// Create an SSE stream that executes AI chat with ReAct tool-calling.
///
/// Also publishes chat messages and stream chunks via NATS JetStream for
/// multi-viewer support. The requesting client receives SSE events, while
/// other viewers receive chunks via NATS → WebSocket broadcast.
pub fn create_chat_sse_stream(
service: AppService,
conversation_id: Uuid,
user_message_id: Uuid,
model_name: String,
user_id: Uuid,
) -> Pin<Box<dyn futures::Stream<Item = Result<actix_web::web::Bytes, actix_web::Error>> + Send>> {
let (tx, rx) = tokio::sync::mpsc::channel::<String>(100);
let cache = service.cache.clone();
tokio::spawn(async move {
// Check for active stream (SSE reconnect recovery) BEFORE starting a new one
// so the frontend can recover from a page refresh.
if let Some((msg_id, started_at)) = cache.get_chat_stream_active(conversation_id).await {
let _ = tx.send(format!(
"data: {{\"event\":\"recovery\",\"data\":{{\"message_id\":\"{}\",\"started_at\":{}}}}}\n\n",
msg_id,
started_at
)).await;
}
let queue = service.queue_producer.clone();
let chunk_seq = Arc::new(AtomicU64::new(0));
// Build messages from conversation history
let messages = match build_messages_from_history(&service, conversation_id).await {
Ok(msgs) => msgs,
Err(e) => {
let _ = tx.send(format!("data: {{\"event\":\"error\",\"data\":\"{}\"}}\n\n", e)).await;
return;
}
};
// Get AI config
let api_key = match service.config.ai_api_key() {
Ok(k) => k,
Err(_) => {
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"AI not configured\"}\n\n".to_string()).await;
return;
}
};
let base_url = match service.config.ai_basic_url() {
Ok(u) => u,
Err(_) => {
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"AI not configured\"}\n\n".to_string()).await;
return;
}
};
let config = AiClientConfig::new(api_key).with_base_url(&base_url);
// Get tools from ChatService if available
let (tools, tool_registry, embed_service) = match &service.chat_service {
Some(cs) => (
cs.tools(),
cs.tool_registry().cloned(),
service.embed_service.as_ref().map(|es| (**es).clone()),
),
None => (Vec::new(), None, None),
};
// Get project_id from conversation
let project_id = match service.find_conversation(conversation_id).await {
Ok(c) => c.project_id.unwrap_or(Uuid::nil()),
Err(_) => {
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"conversation not found\"}\n\n".to_string()).await;
return;
}
};
// Pre-flight balance check: verify project + user can afford at least a minimal AI call
let balance_ok = agent::billing::check_balance(
&service.db, project_id, user_id, Uuid::nil(), 500, 250,
).await;
match balance_ok {
Ok(true) => {},
Ok(false) => {
tracing::warn!(project_id = %project_id, user_id = %user_id, "Insufficient balance for chat AI call");
let _ = agent::billing::persist_billing_error(
&service.db, "user", user_id, "insufficient_balance",
&format!("Insufficient balance. Your account does not have enough funds for this AI request."),
Some(serde_json::json!({
"user_id": user_id.to_string(),
"project_id": project_id.to_string(),
})),
).await;
let error_msg = "Insufficient balance. Your account does not have enough funds to process this AI request. Please add credits to continue.";
let _ = tx.send(format!("data: {{\"event\":\"billing_error\",\"data\":\"{}\"}}\n\n", error_msg)).await;
let _ = tx.send("data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n".to_string()).await;
return;
},
Err(e) => {
tracing::warn!(error = %e, "Balance check failed, proceeding without pre-flight check");
}
}
let max_tool_depth = 99;
// Determine conversation project_id for chat message event
let conv_project_id = match service.find_conversation(conversation_id).await {
Ok(c) => c.project_id,
Err(_) => None,
};
// Broadcast chat message start event via NATS
let chat_msg = ChatMessageEvent {
message_id: user_message_id,
conversation_id,
project_id: conv_project_id,
sender_id: Uuid::nil(),
role: "assistant".to_string(),
content: String::new(),
model: Some(model_name.clone()),
input_tokens: None,
output_tokens: None,
timestamp: chrono::Utc::now(),
};
let _ = queue.publish_chat_message(&chat_msg).await;
// Mark stream as active in Redis so page refresh can recover
let _ = cache.set_chat_stream_active(conversation_id, user_message_id).await;
let on_chunk_tx = tx.clone();
let on_chunk_queue = queue.clone();
let on_chunk_seq = chunk_seq.clone();
let on_chunk_conv_id = conversation_id;
let on_chunk_msg_id = user_message_id;
let on_chunk_model = model_name.clone();
let on_chunk: agent::chat::StreamCallback = Box::new(move |chunk: AiStreamChunk| {
let tx = on_chunk_tx.clone();
let queue = on_chunk_queue.clone();
let seq = on_chunk_seq.fetch_add(1, Ordering::Relaxed);
let conv_id = on_chunk_conv_id;
let msg_id = on_chunk_msg_id;
let model = on_chunk_model.clone();
Box::pin(async move {
let event = match chunk.chunk_type {
AiChunkType::Thinking => "thinking",
AiChunkType::Answer => "token",
AiChunkType::ToolCall => "tool_call",
AiChunkType::ToolResult => "tool_result",
};
let content = match chunk.chunk_type {
AiChunkType::Thinking => normalize_thinking_content(&chunk.content),
_ => chunk.content.clone(),
};
let sse = format!(
"data: {{\"event\":\"{}\",\"data\":{}}}\n\n",
event,
serde_json::to_string(&content).unwrap_or_default()
);
let _ = tx.send(sse).await;
// Also broadcast via NATS for other viewers
let natts_chunk = ChatStreamChunkEvent {
conversation_id: conv_id,
message_id: msg_id,
seq,
content,
done: false,
error: None,
chunk_type: Some(event.to_string()),
model_name: Some(model),
};
queue.publish_chat_chunk(&natts_chunk).await;
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
});
let result = chat_execution::execute_chat_stream(
messages,
tools,
&model_name,
&config,
0.7, // temperature
4096, // max_tokens
max_tool_depth,
tool_registry.as_ref(),
service.db.clone(),
service.cache.clone(),
service.config.clone(),
project_id,
Uuid::nil(), // sender_uid — unknown in Chat API context
embed_service,
on_chunk,
Some(conversation_id),
).await;
// Clear stream active state (streaming finished)
let _ = cache.clear_chat_stream_active(conversation_id).await;
match result {
Ok(stream_result) => {
// Build ordered content blocks from stream chunks, merging
// consecutive blocks of the same role (thinking/assistant).
let raw_blocks: Vec<(String, String)> = stream_result.chunks.iter()
.filter(|c| matches!(c.chunk_type, StreamChunkType::Thinking | StreamChunkType::Answer))
.map(|chunk| {
let role = match chunk.chunk_type {
StreamChunkType::Thinking => "thinking",
_ => "assistant",
};
(role.to_string(), chunk.content.clone())
})
.collect();
let merged_blocks = merge_consecutive_blocks(raw_blocks);
// Apply thinking normalization to the fully merged thinking
// blocks — per-token normalization is meaningless since each
// chunk is a single token.
let normalized_blocks: Vec<(String, String)> = merged_blocks.into_iter().map(|(role, content)| {
if role == "thinking" {
(role, normalize_thinking_content(&content))
} else {
(role, content)
}
}).collect();
let content_blocks: Vec<serde_json::Value> = normalized_blocks.iter()
.map(|(role, content)| serde_json::json!({ "role": role, "content": content }))
.collect();
let content_value = if content_blocks.is_empty() {
serde_json::json!([{ "role": "assistant", "content": stream_result.content }])
} else {
serde_json::json!(content_blocks)
};
// Persist assistant message
let assistant_msg_id = Uuid::now_v7();
let assistant_msg = ai_message::ActiveModel {
id: Set(assistant_msg_id),
conversation_id: Set(conversation_id),
parent_message_id: Set(Some(user_message_id)),
role: Set("assistant".to_string()),
content: Set(content_value),
model: Set(Some(model_name.clone())),
is_fork_origin: Set(false),
stop_reason: Set(Some("stop".to_string())),
input_tokens: Set(Some(stream_result.input_tokens as i32)),
output_tokens: Set(Some(stream_result.output_tokens as i32)),
latency_ms: Set(None),
metadata: Set(None),
room_id: Set(None),
version_group_id: Set(Some(assistant_msg_id)),
version_number: Set(1),
is_latest: Set(true),
created_at: Set(chrono::Utc::now()),
};
let saved = assistant_msg.insert(service.db.writer()).await;
if let Ok(msg) = &saved {
update_conversation_after_response(&service, conversation_id, msg).await;
// After AI response, check/update conversation title and emit via SSE
if let Ok(Some(conv)) = ai_conversation::Entity::find_by_id(conversation_id)
.one(service.db.reader()).await
{
let existing_title = conv.title.clone();
let needs_title = existing_title.as_deref().map(|t| t.is_empty() || t == "New Chat").unwrap_or(true);
if needs_title {
// Generate title from first user message
let first_user_msg = AiMessage::find()
.filter(ai_message::Column::ConversationId.eq(conversation_id))
.filter(ai_message::Column::Role.eq("user"))
.order_by_asc(ai_message::Column::CreatedAt)
.one(service.db.reader()).await.ok().flatten();
if let Some(user_msg) = first_user_msg {
let content = match &user_msg.content {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => {
arr.first()
.and_then(|f| f.get("content"))
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string()
}
other => other.to_string(),
};
// Simple title extraction: first meaningful words
let title = content
.split_whitespace()
.filter(|w| w.len() > 2)
.take(5)
.collect::<Vec<_>>()
.join(" ");
if !title.is_empty() {
let truncated: String = title.chars().take(40).collect();
// Save title to DB
let mut active: ai_conversation::ActiveModel = conv.into();
active.title = Set(Some(truncated.clone()));
active.updated_at = Set(chrono::Utc::now());
let _ = active.update(service.db.writer()).await;
// Emit title via SSE
let title_payload = serde_json::json!({"title": truncated}).to_string();
let _ = tx.send(format!("data: {{\"event\":\"title\",\"data\":{}}}\n\n", title_payload)).await;
}
}
} else if let Some(title) = &existing_title {
// Title already set (e.g. by AI tool) — emit it
let title_payload = serde_json::json!({"title": title}).to_string();
let _ = tx.send(format!("data: {{\"event\":\"title\",\"data\":{}}}\n\n", title_payload)).await;
}
}
}
// Broadcast final chat message with token usage
let final_msg = ChatMessageEvent {
message_id: user_message_id,
conversation_id,
project_id: conv_project_id,
sender_id: Uuid::nil(),
role: "assistant".to_string(),
content: stream_result.content.clone(),
model: Some(model_name.clone()),
input_tokens: Some(stream_result.input_tokens as i32),
output_tokens: Some(stream_result.output_tokens as i32),
timestamp: chrono::Utc::now(),
};
let _ = queue.publish_chat_message(&final_msg).await;
// Send final SSE done event
let _ = tx.send("data: {\"event\":\"done\",\"data\":\"ok\"}\n\n".to_string()).await;
}
Err(e) => {
let _ = tx.send(format!("data: {{\"event\":\"error\",\"data\":\"{}\"}}\n\n", e)).await;
}
}
});
Box::pin(ReceiverStream::new(rx).map(|msg| Ok(actix_web::web::Bytes::from(msg))))
}
/// Refresh a conversation's metadata after an AI assistant message is saved.
///
/// Looks up the conversation by `conversation_id`, then:
/// - sets `message_count` to the number of `ai_message` rows for the conversation
///   (left untouched if the count query fails),
/// - sets `token_usage_total` to `input_tokens + output_tokens` of `assistant_msg`
///   (missing values count as 0),
/// - sets `updated_at` to the current UTC time.
///
/// This is best-effort: if the conversation cannot be loaded nothing happens, and a
/// failed update is ignored.
async fn update_conversation_after_response(
    service: &AppService,
    conversation_id: Uuid,
    assistant_msg: &ai_message::Model,
) {
    use models::ai::ai_conversation;
    use sea_orm::EntityTrait;

    let Ok(Some(conv)) = ai_conversation::Entity::find_by_id(conversation_id)
        .one(service.db.reader())
        .await
    else {
        return;
    };

    let input_tokens = assistant_msg.input_tokens.unwrap_or(0) as i64;
    let output_tokens = assistant_msg.output_tokens.unwrap_or(0) as i64;
    // Sum in i64 and saturate into the i32 column instead of wrapping on overflow.
    let total_tokens = (input_tokens + output_tokens)
        .clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32;

    let mut active: ai_conversation::ActiveModel = conv.into();
    if let Ok(count) = AiMessage::find()
        .filter(ai_message::Column::ConversationId.eq(conversation_id))
        .count(service.db.reader())
        .await
    {
        // Saturate rather than truncate if the row count ever exceeds i32::MAX.
        active.message_count = Set(i32::try_from(count).unwrap_or(i32::MAX));
    }
    // NOTE(review): this overwrites the total with the usage of this single message
    // rather than adding to the previous total — confirm that is the intended meaning
    // of `token_usage_total`.
    active.token_usage_total = Set(Some(total_tokens));
    active.updated_at = Set(chrono::Utc::now());
    // Metadata refresh must not fail the surrounding request, so the result is dropped.
    let _ = active.update(service.db.writer()).await;
}
/// Load a conversation's stored messages and convert them to `ChatRequestMessage`s.
///
/// Only rows flagged `is_latest` are read, ordered by creation time (ascending).
/// The stored JSON `content` is flattened to plain text as follows:
/// - a JSON string is used as-is;
/// - a JSON array of blocks on an `assistant` message: the `content` strings of every
///   block whose `role` is not `"thinking"` are joined with `"\n"`;
/// - a JSON array on any other role: the first block's `content` string, or `""`;
/// - any other JSON value: its JSON text representation.
///
/// Roles other than `assistant` and `system` are sent as user messages.
///
/// # Errors
/// Returns `"db error: …"` if the history query fails.
async fn build_messages_from_history(
    service: &AppService,
    conversation_id: Uuid,
) -> Result<Vec<ChatRequestMessage>, String> {
    let history = AiMessage::find()
        .filter(ai_message::Column::ConversationId.eq(conversation_id))
        .filter(ai_message::Column::IsLatest.eq(true))
        .order_by_asc(ai_message::Column::CreatedAt)
        .all(service.db.reader())
        .await
        .map_err(|e| format!("db error: {}", e))?;

    let chat_messages: Vec<ChatRequestMessage> = history
        .iter()
        .map(|entry| {
            let role = entry.role.as_str();
            let text = match &entry.content {
                serde_json::Value::String(plain) => plain.clone(),
                // Assistant content is a list of ordered blocks; drop the model's
                // "thinking" blocks and stitch the remaining text together.
                serde_json::Value::Array(blocks) if role == "assistant" => blocks
                    .iter()
                    .filter(|block| {
                        block.get("role").and_then(|r| r.as_str()) != Some("thinking")
                    })
                    .filter_map(|block| block.get("content").and_then(|c| c.as_str()))
                    .collect::<Vec<_>>()
                    .join("\n"),
                // Non-assistant messages only carry their text in the first block.
                serde_json::Value::Array(blocks) => blocks
                    .first()
                    .and_then(|block| block.get("content"))
                    .and_then(|c| c.as_str())
                    .unwrap_or("")
                    .to_string(),
                other => other.to_string(),
            };
            match role {
                "assistant" => ChatRequestMessage::assistant(Some(text), None),
                "system" => ChatRequestMessage::system(text),
                // "user" and any unrecognised role are both forwarded as user input.
                _ => ChatRequestMessage::user(text),
            }
        })
        .collect();

    Ok(chat_messages)
}
/// Collapse runs of adjacent `(role, content)` blocks that share a role.
///
/// Streaming produces one tiny block per chunk; this folds them into clean
/// interleaved segments, e.g.
/// `[thinking, thinking, assistant, assistant]` → `[thinking, assistant]`.
///
/// Blocks with empty content are dropped before merging, so they never split a run.
/// Contents are concatenated without a separator: any line breaks are expected to be
/// part of the chunk text itself.
fn merge_consecutive_blocks(blocks: Vec<(String, String)>) -> Vec<(String, String)> {
    blocks
        .into_iter()
        .filter(|(_, text)| !text.is_empty())
        .fold(
            Vec::new(),
            |mut merged: Vec<(String, String)>, (role, text)| {
                match merged.last_mut() {
                    // Same role as the previous segment: extend it in place.
                    Some((prev_role, prev_text)) if *prev_role == role => {
                        prev_text.push_str(&text);
                    }
                    // First block, or the role changed: start a new segment.
                    _ => merged.push((role, text)),
                }
                merged
            },
        )
}

158
libs/api/chat/watch.rs Normal file
View File

@ -0,0 +1,158 @@
//! SSE endpoint for watching a chat conversation in real-time via NATS.
//!
//! Unlike the primary SSE stream (which triggers AI execution), this endpoint
//! passively subscribes to NATS Core subjects and forwards chat messages and
//! stream chunks to connected clients. This enables multiple viewers to watch
//! the same AI conversation in real-time.
use actix_web::{web, HttpResponse, Result};
use futures::StreamExt;
use service::AppService;
use std::pin::Pin;
use uuid::Uuid;
use crate::error::ApiError;
/// Serialize one SSE frame of the form `data: {"event":<event>,"data":<data>}\n\n`.
///
/// Both parts are written as compact JSON, so quotes are escaped and the frame can
/// never contain a raw line break that would split the SSE event.
fn watch_sse_frame(event: &str, data: &serde_json::Value) -> String {
    format!(
        "data: {{\"event\":{},\"data\":{}}}\n\n",
        serde_json::Value::from(event),
        data
    )
}

/// Decode a NATS payload as JSON.
///
/// Payloads that are not valid JSON are forwarded as a JSON string (decoded lossily
/// as UTF-8) instead of being spliced verbatim into the envelope.
fn watch_payload_json(payload: &[u8]) -> serde_json::Value {
    let text = String::from_utf8_lossy(payload);
    serde_json::from_str(&text).unwrap_or_else(|_| serde_json::Value::String(text.into_owned()))
}

/// SSE stream for passively watching a chat conversation.
///
/// `GET /api/ai/conversations/{conversation_id}/watch`
///
/// Spawns a task that subscribes to the NATS subjects `chat.chunk.{id}` and
/// `chat.message.{id}` and forwards every received payload to the returned stream.
/// Each frame is `data: {"event":<name>,"data":<payload>}\n\n`, where `<payload>` is
/// the NATS payload re-serialized as compact JSON.
///
/// SSE events:
/// - chunk events — named after the payload's `chunk_type` field, or `chunk` when the
///   payload has no such string field
/// - `message` — a complete chat message
/// - `error` — NATS is not configured or a subscription failed; the stream then ends
///
/// An initial `:ok` comment is sent once both subscriptions are established. The
/// task stops when either subscription ends or when the client disconnects.
pub fn create_watch_sse_stream(
    service: AppService,
    conversation_id: Uuid,
) -> Pin<Box<dyn futures::Stream<Item = Result<actix_web::web::Bytes, actix_web::Error>> + Send>> {
    let (tx, rx) = tokio::sync::mpsc::channel::<String>(200);

    tokio::spawn(async move {
        let nats = match &service.queue_producer.nats {
            Some(n) => n.clone(),
            None => {
                let reason = serde_json::Value::from("NATS not available");
                let _ = tx.send(watch_sse_frame("error", &reason)).await;
                return;
            }
        };

        // Subscribe to chat chunks
        let chunk_subject = format!("chat.chunk.{}", conversation_id);
        let mut chunk_sub = match nats.subscribe(&chunk_subject).await {
            Ok(s) => s,
            Err(e) => {
                let reason = serde_json::Value::from(e.to_string());
                let _ = tx.send(watch_sse_frame("error", &reason)).await;
                return;
            }
        };

        // Subscribe to chat messages
        let msg_subject = format!("chat.message.{}", conversation_id);
        let mut msg_sub = match nats.subscribe(&msg_subject).await {
            Ok(s) => s,
            Err(e) => {
                let reason = serde_json::Value::from(e.to_string());
                let _ = tx.send(watch_sse_frame("error", &reason)).await;
                return;
            }
        };

        // SSE comment line: tells the client the stream is live before any event arrives.
        let _ = tx.send(":ok\n\n".to_string()).await;

        loop {
            let frame = tokio::select! {
                // The receiver is dropped when the client disconnects. Without this
                // branch the task (and both subscriptions) would stay alive until the
                // next NATS message happened to arrive.
                _ = tx.closed() => break,
                chunk = chunk_sub.next() => {
                    let Some(chunk) = chunk else { break };
                    let data = watch_payload_json(&chunk.payload);
                    // The event name mirrors the chunk's own type when it declares one.
                    let event = data
                        .get("chunk_type")
                        .and_then(|v| v.as_str())
                        .unwrap_or("chunk");
                    watch_sse_frame(event, &data)
                }
                message = msg_sub.next() => {
                    let Some(message) = message else { break };
                    watch_sse_frame("message", &watch_payload_json(&message.payload))
                }
            };
            if tx.send(frame).await.is_err() {
                break;
            }
        }
    });

    Box::pin(
        tokio_stream::wrappers::ReceiverStream::new(rx)
            .map(|s| Ok(actix_web::web::Bytes::from(s))),
    )
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}/watch",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
),
responses(
(status = 200, description = "SSE stream of conversation events"),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn conversation_watch(
service: web::Data<AppService>,
session: session::Session,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or_else(|| ApiError::from(service::error::AppError::Unauthorized))?;
let conversation_id = path.into_inner();
// Verify access (view-only is sufficient)
let _conv = service
.find_conversation_owned(conversation_id, user_id)
.await?;
let response = HttpResponse::Ok()
.content_type("text/event-stream")
.insert_header(("Cache-Control", "no-cache"))
.insert_header(("X-Accel-Buffering", "no"))
.streaming(create_watch_sse_stream(
service.get_ref().clone(),
conversation_id,
));
Ok(response.into())
}

View File

@ -116,20 +116,30 @@ pub async fn serve_frontend(req: HttpRequest, path: web::Path<String>) -> HttpRe
}; };
let cc = cache_control_header(path_str); let cc = cache_control_header(path_str);
// Try brotli first (best compression), then gzip, then uncompressed. // Try brotli/gzip compressed variant first (best compression),
// Only serve compressed variant if client explicitly accepts it AND we have one. // then fall back to uncompressed if client doesn't accept the encoding.
let (data, encoding, etag, content_path) = let compressed = crate::frontend::get_frontend_asset_compressed(path_str);
match frontend::get_frontend_asset_compressed(path_str) { let uncompressed = crate::frontend::get_frontend_asset_with_etag(path_str);
Some(r) => (r.0, r.1, r.2, path_str),
None => { let (data, encoding, etag, content_path) = if let Some((c_data, c_enc, c_etag)) = compressed {
// Path not found — try index.html as SPA fallback. if accepts_encoding(&req, c_enc) {
// Also use "index.html" for Content-Type detection (text/html). (c_data, c_enc, c_etag, path_str)
match frontend::get_frontend_asset_with_etag("index.html") { } else if let Some((u_data, u_etag)) = uncompressed {
Some((data, etag)) => (data, "", etag, "index.html"), // Client doesn't accept the pre-compressed encoding — serve uncompressed.
None => return HttpResponse::NotFound().finish(), (u_data, "", u_etag, path_str)
} } else {
} // No uncompressed fallback — still serve compressed (client must handle it).
}; (c_data, c_enc, c_etag, path_str)
}
} else if let Some((data, etag)) = uncompressed {
(data, "", etag, path_str)
} else {
// Path not found — try index.html as SPA fallback.
match crate::frontend::get_frontend_asset_with_etag("index.html") {
Some((data, etag)) => (data, "", etag, "index.html"),
None => return HttpResponse::NotFound().finish(),
}
};
if !encoding.is_empty() && accepts_encoding(&req, &encoding) { if !encoding.is_empty() && accepts_encoding(&req, &encoding) {
build_asset_response(&req, data, etag, content_path, cc, &encoding) build_asset_response(&req, data, etag, content_path, cc, &encoding)

Some files were not shown because too many files have changed in this diff Show More