feat(core): initialize project with access control and AI integration

This commit is contained in:
ZhenYi 2026-05-10 21:01:21 +08:00
parent 14f6e1e500
commit ba2490dab4
619 changed files with 51264 additions and 8418 deletions

View File

@ -51,9 +51,9 @@ APP_DOMAIN_URL=http://127.0.0.1
# APP_DATABASE_MAX_CONNECTIONS=10
# APP_DATABASE_MIN_CONNECTIONS=2
# APP_DATABASE_IDLE_TIMEOUT=60000
# APP_DATABASE_MAX_LIFETIME=300000
# APP_DATABASE_CONNECTION_TIMEOUT=5000
# APP_DATABASE_IDLE_TIMEOUT=60000 (milliseconds, default: 60s)
# APP_DATABASE_MAX_LIFETIME=300000 (milliseconds, default: 300s)
# APP_DATABASE_CONNECTION_TIMEOUT=5000 (milliseconds, default: 5s)
# APP_DATABASE_REPLICAS=
# APP_DATABASE_HEALTH_CHECK_INTERVAL=30
# APP_DATABASE_RETRY_ATTEMPTS=3

3
.gitignore vendored
View File

@ -23,3 +23,6 @@ coverage/
pnpm-lock.yaml
package-lock.json
yarn.lock
.gemini
.omg
/.sqry

11
.mcp.json Normal file
View File

@ -0,0 +1,11 @@
{
"mcpServers": {
"shadcn": {
"command": "npx",
"args": [
"shadcn@latest",
"mcp"
]
}
}
}

1537
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -11,19 +11,21 @@ members = [
"libs/service",
"libs/db",
"libs/api",
"libs/webhook",
"libs/transport",
"libs/observability",
"libs/avatar",
"libs/agent",
"libs/migrate",
"libs/fctool",
"libs/gingress-proxy",
"apps/migrate",
"apps/app",
"apps/git-hook",
"apps/gitserver",
"apps/email",
"apps/static",
"apps/metrics",
"apps/gingress",
]
resolver = "3"
@ -40,12 +42,14 @@ service = { path = "libs/service" }
db = { path = "libs/db" }
api = { path = "libs/api" }
agent = { path = "libs/agent" }
webhook = { path = "libs/webhook" }
observability = { path = "libs/observability" }
avatar = { path = "libs/avatar" }
migrate = { path = "libs/migrate" }
fctool = { path = "libs/fctool" }
transport = { path = "libs/transport" }
metrics-aggregator = { path = "apps/metrics" }
gingress-proxy = { path = "libs/gingress-proxy" }
gingress = { path = "apps/gingress" }
sea-query = "1.0.0-rc.33"
@ -131,7 +135,10 @@ tokio = "1.50.0"
tokio-util = "0.7.18"
tokio-stream = "0.1.18"
url = "2.5.8"
tower = "0.5"
num_cpus = "1.17.0"
ring = "0.17"
rustls = { version = "0.23", default-features = false, features = ["ring", "std", "tls12"] }
clap = "4.6.0"
time = "0.3.47"
chrono = "0.4.44"
@ -165,12 +172,19 @@ phf_codegen = "0.13.1"
base64 = "0.22.1"
base64ct = "1"
p256 = { version = "0.13", features = ["ecdsa", "std"] }
http = "1"
# http version varies per-crate (pingora needs 1.x, actix needs 0.2)
hyper = "0.14"
tempfile = "3"
rig-core = { version = "0.30.0", default-features = false }
tokio-tungstenite = { version = "0.29.0", features = [] }
async-nats = { version = "0.47.0", features = [] }
kube = { version = "0.98", features = ["runtime", "derive"] }
k8s-openapi = { version = "0.24", features = ["v1_31"] }
pingora = { version = "0.8", features = ["proxy"] }
pingora-proxy = "0.8"
pingora-load-balancing = "0.8"
pingora-cache = "0.8"
rustls-pemfile = "2"
[workspace.package]
version = "0.2.9"
edition = "2024"

View File

@ -9,6 +9,7 @@ use futures::future::LocalBoxFuture;
use observability::{
init_tracing_subscriber, install_recorder, prometheus_handler, spawn_http_metrics_poller,
HttpMetrics, HttpSnapshotGuard, MetricsMiddleware, TracingSpanMiddleware,
push::MetricsPusher,
};
use sea_orm::ConnectionTrait;
use service::AppService;
@ -17,6 +18,7 @@ use api::{robots, sidemap};
use session::storage::RedisClusterSessionStore;
use session::SessionMiddleware;
use std::task::{Context, Poll};
use std::sync::Arc;
use std::time::Instant;
mod args;
@ -151,7 +153,8 @@ async fn main() -> anyhow::Result<()> {
let service = AppService::new(cfg.clone()).await?;
tracing::info!("AppService initialized");
let _model_sync_handle = service.clone().start_sync_task();
let _billing_alert_handle = service.clone().start_billing_alert_task();
// TODO: workspace module not yet wired — billing alert task pending
// let _billing_alert_handle = service.clone().start_billing_alert_task();
let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
let worker_service = service.clone();
@ -192,6 +195,13 @@ async fn main() -> anyhow::Result<()> {
);
let http_snapshot_data = web::Data::new(http_snapshot);
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
let pusher = MetricsPusher::new(&push_url, "app");
pusher.spawn(http_metrics.clone(), Arc::new(prometheus_handle.clone()), std::time::Duration::from_secs(15));
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
}
let bind_addr = args.bind.unwrap_or_else(|| "127.0.0.1:8080".to_string());
tracing::info!(bind_addr = %bind_addr, "Listening");
let http_metrics_server = http_metrics.clone();
@ -212,11 +222,16 @@ async fn main() -> anyhow::Result<()> {
cors = cors.allowed_origin(origin);
}
let cors = cors
.allowed_methods(["GET", "POST", "PUT", "PATCH", "DELETE"])
.allowed_methods(["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
.allowed_headers(["Content-Type", "Authorization", "X-Requested-With", "Accept", "Origin"])
.supports_credentials()
.max_age(3600);
let security_headers = actix_web::middleware::DefaultHeaders::new()
.add(("X-Content-Type-Options", "nosniff"))
.add(("X-Frame-Options", "DENY"))
.add(("Referrer-Policy", "strict-origin-when-cross-origin"));
let session_mw = SessionMiddleware::builder(store.clone(), session_key.clone())
.cookie_name("id".to_string())
.cookie_path("/".to_string())
@ -233,6 +248,7 @@ async fn main() -> anyhow::Result<()> {
App::new()
.wrap(cors)
.wrap(security_headers)
.wrap(session_mw)
.wrap(RequestLogger)
.wrap(metrics_mw)

View File

@ -2,7 +2,7 @@ use clap::Parser;
use config::AppConfig;
use metrics::{describe_counter, Unit};
use metrics_exporter_prometheus::PrometheusHandle;
use observability::{init_tracing_subscriber, install_recorder};
use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
use sea_orm::ConnectionTrait;
use service::AppService;
use std::sync::Arc;
@ -88,6 +88,14 @@ async fn main() -> anyhow::Result<()> {
describe_counter!("email_send_failures_total", Unit::Count, "Emails that failed after all retries");
let metrics_handle = Arc::new(install_recorder());
let http_metrics = Arc::new(HttpMetrics::new()); // Worker app — HTTP section will be empty
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
let pusher = MetricsPusher::new(&push_url, "email");
pusher.spawn(http_metrics.clone(), metrics_handle.clone(), std::time::Duration::from_secs(15));
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
}
tracing::info!("Starting email worker");
let service = AppService::new(cfg).await?;

46
apps/gingress/Cargo.toml Normal file
View File

@ -0,0 +1,46 @@
[package]
name = "gingress"
version.workspace = true
edition.workspace = true
authors.workspace = true
description = "GIngress control plane: Kubernetes Ingress Controller using kube-rs"
repository.workspace = true
readme.workspace = true
homepage.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
documentation.workspace = true
[[bin]]
name = "gingress"
path = "src/main.rs"
[[bin]]
name = "kubectl-gingress"
path = "src/bin/kubectl-gingress/main.rs"
[dependencies]
gingress-proxy = { workspace = true }
kube = { version = "0.98", features = ["runtime", "derive"] }
k8s-openapi = { version = "0.24", features = ["v1_31"] }
tokio = { workspace = true, features = ["full"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
serde_yaml = { workspace = true }
tracing = { workspace = true }
observability = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
dashmap = { workspace = true }
futures-util = { workspace = true }
futures = { workspace = true }
clap = { workspace = true }
url = { workspace = true }
x509-parser = "0.17"
rustls-pemfile = "2"
[lints]
workspace = true

View File

@ -0,0 +1,797 @@
//! kubectl-gingress — kubectl plugin for managing GIngress resources.
//!
//! Usage (via kubectl): kubectl gingress <subcommand>
//! Usage (standalone): kubectl-gingress <subcommand>
use clap::{Parser, Subcommand};
use k8s_openapi::api::core::v1::{Pod, Secret};
use k8s_openapi::api::networking::v1::{HTTPIngressPath, Ingress};
use kube::api::ListParams;
use kube::{Api, Client, ResourceExt};
const INGRESS_CLASS: &str = "gingress";
#[derive(Parser)]
#[command(
name = "kubectl-gingress",
bin_name = "kubectl gingress",
about = "Manage GIngress — Kubernetes Ingress Controller",
version
)]
struct Cli {
#[command(subcommand)]
command: Command,
}
#[derive(Subcommand)]
enum Command {
/// List all Ingress resources managed by GIngress
#[command(alias = "ls")]
List {
/// Filter by namespace (omit for all namespaces)
#[arg(short, long)]
namespace: Option<String>,
/// Output as JSON
#[arg(long)]
json: bool,
},
/// Show the routing table (host → path → backend)
Routes {
/// Filter by namespace
#[arg(short, long)]
namespace: Option<String>,
/// Filter by host
#[arg(short = 'H', long)]
host: Option<String>,
/// Output as JSON
#[arg(long)]
json: bool,
},
/// Show backend services and their endpoints
Backends {
/// Filter by namespace
#[arg(short, long)]
namespace: Option<String>,
/// Output as JSON
#[arg(long)]
json: bool,
},
/// List TLS certificates (from Secrets)
Certs {
/// Filter by namespace
#[arg(short, long)]
namespace: Option<String>,
/// Output as JSON
#[arg(long)]
json: bool,
},
/// Validate Ingress configurations
Validate {
/// Filter by namespace
#[arg(short, long)]
namespace: Option<String>,
},
/// Show GIngress controller status and summary
Status {
/// Output as JSON
#[arg(long)]
json: bool,
},
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let cli = Cli::parse();
let client = Client::try_default().await?;
match cli.command {
Command::List { namespace, json } => cmd_list(&client, namespace, json).await?,
Command::Routes { namespace, host, json } => cmd_routes(&client, namespace, host, json).await?,
Command::Backends { namespace, json } => cmd_backends(&client, namespace, json).await?,
Command::Certs { namespace, json } => cmd_certs(&client, namespace, json).await?,
Command::Validate { namespace } => cmd_validate(&client, namespace).await?,
Command::Status { json } => cmd_status(&client, json).await?,
}
Ok(())
}
// ── list ──────────────────────────────────────────────────────────
async fn cmd_list(client: &Client, namespace: Option<String>, json: bool) -> Result<(), Box<dyn std::error::Error>> {
let ingresses = list_ingresses(client, namespace.as_deref()).await?;
if json {
println!("{}", serde_json::to_string_pretty(&ingresses)?);
return Ok(());
}
if ingresses.is_empty() {
println!("No GIngress-managed Ingress resources found.");
return Ok(());
}
println!("{:<25} {:<20} {:<40} {:<50} {:<15}", "NAMESPACE", "NAME", "HOSTS", "PATHS", "TLS");
println!("{:-<150}", "");
for ing in &ingresses {
let ns = ing.namespace();
let name = ing.name_any();
let hosts = ing.hosts().join(", ");
let paths = ing
.paths_display()
.iter()
.map(|p| format!("{} {}", p.path_type, p.path))
.collect::<Vec<_>>()
.join(", ");
let tls = if ing.has_tls() { "Enabled" } else { "-" };
println!("{:<25} {:<20} {:<40} {:<50} {:<15}",
truncate(&ns, 25),
truncate(&name, 20),
truncate(&hosts, 40),
truncate(&paths, 50),
tls,
);
}
println!("\nTotal: {} Ingress(es)", ingresses.len());
Ok(())
}
// ── routes ─────────────────────────────────────────────────────────
async fn cmd_routes(
client: &Client,
namespace: Option<String>,
host_filter: Option<String>,
json: bool,
) -> Result<(), Box<dyn std::error::Error>> {
let ingresses = list_ingresses(client, namespace.as_deref()).await?;
let mut routes: Vec<RouteRow> = Vec::new();
for ing in &ingresses {
for rule in ing.spec.as_ref().and_then(|s| s.rules.as_ref()).into_iter().flatten() {
let host = rule.host.as_deref().unwrap_or("*");
if let Some(ref hf) = host_filter {
if host != hf { continue; }
}
if let Some(http) = &rule.http {
for path_item in &http.paths {
let backend = extract_backend(path_item);
let port = extract_backend_port(path_item);
routes.push(RouteRow {
namespace: ing.namespace(),
ingress: ing.name_any(),
host: host.to_string(),
path: path_item.path.clone().unwrap_or_else(|| "/".into()),
path_type: path_item.path_type.clone(),
backend,
port,
});
}
}
}
}
if json {
println!("{}", serde_json::to_string_pretty(&routes)?);
return Ok(());
}
if routes.is_empty() {
println!("No routes found.");
return Ok(());
}
println!("{:<20} {:<20} {:<30} {:<18} {:<15} {:<15} {:<15}",
"NAMESPACE", "INGRESS", "HOST", "PATH", "TYPE", "BACKEND", "PORT");
println!("{:-<133}", "");
for r in &routes {
let port = extract_backend_port_str(r);
println!("{:<20} {:<20} {:<30} {:<18} {:<15} {:<15} {:<15}",
truncate(&r.namespace, 20),
truncate(&r.ingress, 20),
truncate(&r.host, 30),
truncate(&r.path, 18),
truncate(&r.path_type, 15),
truncate(&r.backend, 15),
port,
);
}
println!("\nTotal: {} route(s)", routes.len());
Ok(())
}
// ── backends ───────────────────────────────────────────────────────
async fn cmd_backends(
client: &Client,
namespace: Option<String>,
json: bool,
) -> Result<(), Box<dyn std::error::Error>> {
let ingresses = list_ingresses(client, namespace.as_deref()).await?;
// Collect unique backends from all ingresses
let mut backends: Vec<BackendRow> = Vec::new();
let mut seen = std::collections::HashSet::new();
for ing in &ingresses {
for rule in ing.spec.as_ref().and_then(|s| s.rules.as_ref()).into_iter().flatten() {
if let Some(http) = &rule.http {
for path_item in &http.paths {
let svc = match path_item.backend.service.as_ref() {
Some(s) => s,
None => continue,
};
let key = format!("{}/{}:{}", ing.namespace(), svc.name, svc.port.as_ref().and_then(|p| p.number).unwrap_or(80));
if seen.insert(key.clone()) {
let ns = ing.namespace();
let ep_status = get_endpoint_status(client, &ns, &svc.name).await;
backends.push(BackendRow {
namespace: ns,
service: svc.name.clone(),
port: svc.port.as_ref().and_then(|p| p.number).unwrap_or(80) as u16,
ready_endpoints: ep_status.ready,
total_endpoints: ep_status.total,
referenced_by: ing.name_any(),
});
}
}
}
}
}
if json {
println!("{}", serde_json::to_string_pretty(&backends)?);
return Ok(());
}
if backends.is_empty() {
println!("No backends found.");
return Ok(());
}
println!("{:<20} {:<20} {:<8} {:<8} {:<18} {:<20}",
"NAMESPACE", "SERVICE", "PORT", "HEALTH", "ENDPOINTS", "REFERENCED BY");
println!("{:-<94}", "");
for b in &backends {
let health = if b.total_endpoints == 0 { "WARN" } else if b.ready_endpoints == 0 { "DOWN" } else if b.ready_endpoints < b.total_endpoints { "PARTIAL" } else { "OK" };
let eps = format!("{}/{} ready", b.ready_endpoints, b.total_endpoints);
println!("{:<20} {:<20} {:<8} {:<8} {:<18} {:<20}",
truncate(&b.namespace, 20),
truncate(&b.service, 20),
b.port,
health,
eps,
truncate(&b.referenced_by, 20),
);
}
println!("\nTotal: {} backend(s)", backends.len());
Ok(())
}
// ── certs ──────────────────────────────────────────────────────────
async fn cmd_certs(
client: &Client,
namespace: Option<String>,
json: bool,
) -> Result<(), Box<dyn std::error::Error>> {
let ingresses = list_ingresses(client, namespace.as_deref()).await?;
let mut certs: Vec<CertRow> = Vec::new();
for ing in &ingresses {
let ns = ing.namespace();
let tls_entries = ing
.spec
.as_ref()
.and_then(|s| s.tls.as_ref())
.cloned()
.unwrap_or_default();
for tls in &tls_entries {
let secret_name = tls.secret_name.clone().unwrap_or_default();
let hosts = tls.hosts.clone().unwrap_or_default();
// Check if the secret exists
let secret_exists = check_secret_exists(client, &ns, &secret_name).await;
for host in &hosts {
certs.push(CertRow {
namespace: ns.clone(),
secret_name: secret_name.clone(),
host: host.clone(),
found: secret_exists,
});
}
}
}
if json {
println!("{}", serde_json::to_string_pretty(&certs)?);
return Ok(());
}
if certs.is_empty() {
println!("No TLS certificates configured.");
return Ok(());
}
println!("{:<20} {:<30} {:<30} {:<10}", "NAMESPACE", "SECRET", "HOST", "STATUS");
println!("{:-<90}", "");
for c in &certs {
let status = if c.found { "OK" } else { "MISSING" };
println!("{:<20} {:<30} {:<30} {:<10}",
truncate(&c.namespace, 20),
truncate(&c.secret_name, 30),
truncate(&c.host, 30),
status,
);
}
let missing = certs.iter().filter(|c| !c.found).count();
println!("\nTotal: {} cert(s), {} missing", certs.len(), missing);
Ok(())
}
// ── validate ───────────────────────────────────────────────────────
async fn cmd_validate(
client: &Client,
namespace: Option<String>,
) -> Result<(), Box<dyn std::error::Error>> {
let ingresses = list_ingresses(client, namespace.as_deref()).await?;
let mut errors = 0usize;
let mut warnings = 0usize;
for ing in &ingresses {
let ns = ing.namespace();
let name = ing.name_any();
// Check: has rules
let has_rules = ing
.spec
.as_ref()
.map(|s| s.rules.as_ref().map(|r| !r.is_empty()).unwrap_or(false))
.unwrap_or(false);
if !has_rules {
println!("[{}/{}] ERROR: No routing rules defined", ns, name);
errors += 1;
}
// Check: has TLS but no secret
let tls_entries = ing
.spec
.as_ref()
.and_then(|s| s.tls.as_ref())
.cloned()
.unwrap_or_default();
for tls in &tls_entries {
let secret_name = tls.secret_name.as_deref().unwrap_or("");
if secret_name.is_empty() {
println!("[{}/{}] WARNING: TLS configured but no secretName specified", ns, name);
warnings += 1;
} else {
let found = check_secret_exists(client, &ns, secret_name).await;
if !found {
println!("[{}/{}] ERROR: TLS secret '{}' not found in namespace '{}'", ns, name, secret_name, ns);
errors += 1;
}
}
}
// Check: path backends reference valid services
if let Some(rules) = ing.spec.as_ref().and_then(|s| s.rules.as_ref()) {
for rule in rules {
if let Some(http) = &rule.http {
for path_item in &http.paths {
if let Some(svc) = &path_item.backend.service {
let endpoints = get_endpoint_status(client, &ns, &svc.name).await;
if endpoints.total == 0 {
println!(
"[{}/{}] WARNING: Backend service '{}' has no endpoints (host: {})",
ns, name, svc.name,
rule.host.as_deref().unwrap_or("*")
);
warnings += 1;
}
}
}
}
}
}
}
if errors == 0 && warnings == 0 {
println!("Validation passed — no issues found in {} Ingress(es).", ingresses.len());
} else {
println!("\nValidation complete: {} error(s), {} warning(s) across {} Ingress(es).",
errors, warnings, ingresses.len());
}
Ok(())
}
// ── status ─────────────────────────────────────────────────────────
async fn cmd_status(client: &Client, json: bool) -> Result<(), Box<dyn std::error::Error>> {
let ingresses = list_ingresses(client, None).await?;
let controller_pods = find_gingress_pods(client).await;
if json {
#[derive(serde::Serialize)]
struct StatusOutput {
controller_pods: Vec<PodInfo>,
ingress_count: usize,
route_count: usize,
backend_count: usize,
cert_count: usize,
}
let mut route_count = 0usize;
let mut backend_set = std::collections::HashSet::new();
let mut cert_count = 0usize;
for ing in &ingresses {
if let Some(spec) = &ing.spec {
if let Some(rules) = &spec.rules {
for rule in rules {
if let Some(http) = &rule.http {
route_count += http.paths.len();
for p in &http.paths {
if let Some(svc) = &p.backend.service {
backend_set.insert(format!("{}/{}", ing.namespace(), svc.name));
}
}
}
}
}
if let Some(tls) = &spec.tls {
cert_count += tls.len();
}
}
}
println!(
"{}",
serde_json::to_string_pretty(&StatusOutput {
controller_pods,
ingress_count: ingresses.len(),
route_count,
backend_count: backend_set.len(),
cert_count,
})?
);
return Ok(());
}
// Controller status
println!("══ GIngress Controller Status ══\n");
if controller_pods.is_empty() {
println!("Controller: NOT FOUND (no pods with label gingress.io/component=controller)");
} else {
for pod in &controller_pods {
let ready = if pod.ready { "Running" } else { "NotReady" };
println!(
"Controller Pod: {:<30} {:<12} {}",
truncate(&pod.name, 30),
ready,
pod.namespace
);
}
}
// Resource summary
println!();
println!("══ Managed Resources ══\n");
let mut route_count = 0usize;
let mut backend_set = std::collections::HashSet::new();
let mut backend_total_eps = 0usize;
let mut backend_ready_eps = 0usize;
let mut cert_count = 0usize;
let mut cert_missing = 0usize;
for ing in &ingresses {
if let Some(spec) = &ing.spec {
if let Some(rules) = &spec.rules {
for rule in rules {
if let Some(http) = &rule.http {
route_count += http.paths.len();
for p in &http.paths {
if let Some(svc) = &p.backend.service {
let key = format!("{}/{}", ing.namespace(), svc.name);
if backend_set.insert(key) {
let eps = get_endpoint_status(client, &ing.namespace(), &svc.name).await;
backend_total_eps += eps.total;
backend_ready_eps += eps.ready;
}
}
}
}
}
}
if let Some(tls) = &spec.tls {
cert_count += tls.len();
for tls in tls {
let sn = tls.secret_name.as_deref().unwrap_or("");
if !sn.is_empty() && !check_secret_exists(client, &ing.namespace(), sn).await {
cert_missing += 1;
}
}
}
}
}
println!("Ingresses: {}", ingresses.len());
println!("Routes: {}", route_count);
println!(
"Backends: {} ({} ready / {} total endpoints)",
backend_set.len(), backend_ready_eps, backend_total_eps
);
println!("TLS Certs: {} ({} missing)", cert_count, cert_missing);
println!();
// Overall health
let healthy = !ingresses.is_empty()
&& (cert_missing == 0)
&& (backend_set.is_empty() || backend_ready_eps > 0)
&& controller_pods.iter().any(|p| p.ready);
if healthy {
println!("Status: HEALTHY");
} else {
println!("Status: DEGRADED");
if controller_pods.is_empty() || !controller_pods.iter().any(|p| p.ready) {
println!(" → Controller pod not running or not ready");
}
if cert_missing > 0 {
println!("{} TLS secret(s) missing", cert_missing);
}
if !backend_set.is_empty() && backend_ready_eps == 0 {
println!(" → No ready backend endpoints");
}
}
Ok(())
}
// ── k8s helpers ────────────────────────────────────────────────────
#[derive(serde::Serialize)]
struct IngressSummary {
namespace: String,
name: String,
hosts: Vec<String>,
#[serde(skip)]
paths_for_display: Vec<PathSummary>,
has_tls: bool,
#[serde(skip)]
spec: Option<k8s_openapi::api::networking::v1::IngressSpec>,
}
#[derive(Clone)]
struct PathSummary {
path_type: String,
path: String,
}
impl IngressSummary {
fn namespace(&self) -> String { self.namespace.clone() }
fn name_any(&self) -> String { self.name.clone() }
fn hosts(&self) -> &[String] { &self.hosts }
fn paths_display(&self) -> &[PathSummary] { &self.paths_for_display }
fn has_tls(&self) -> bool { self.has_tls }
}
fn ingress_to_summary(ing: &Ingress) -> IngressSummary {
let spec = ing.spec.clone();
let mut hosts = Vec::new();
let mut paths = Vec::new();
let mut has_tls = false;
if let Some(ref s) = spec {
if let Some(ref rules) = s.rules {
for rule in rules {
if let Some(ref host) = rule.host {
hosts.push(host.clone());
}
if let Some(ref http) = rule.http {
for p in &http.paths {
paths.push(PathSummary {
path_type: p.path_type.clone(),
path: p.path.clone().unwrap_or_else(|| "/".into()),
});
}
}
}
}
has_tls = s.tls.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
}
IngressSummary {
namespace: ing.namespace().unwrap_or_default(),
name: ing.name_any(),
hosts,
paths_for_display: paths,
has_tls,
spec,
}
}
#[derive(serde::Serialize)]
struct PodInfo {
name: String,
namespace: String,
ready: bool,
}
/// Find GIngress controller pods by label `gingress.io/component=controller`.
async fn find_gingress_pods(client: &Client) -> Vec<PodInfo> {
let api: Api<Pod> = Api::all(client.clone());
let lp = ListParams {
label_selector: Some("gingress.io/component=controller".into()),
..Default::default()
};
match api.list(&lp).await {
Ok(list) => list
.items
.into_iter()
.map(|pod| {
let ready = pod
.status
.as_ref()
.and_then(|s| s.conditions.as_ref())
.map(|conds| {
conds
.iter()
.any(|c| c.type_ == "Ready" && c.status == "True")
})
.unwrap_or(false);
PodInfo {
name: pod.name_any(),
namespace: pod.namespace().unwrap_or_default(),
ready,
}
})
.collect(),
Err(_) => Vec::new(),
}
}
async fn list_ingresses(client: &Client, namespace: Option<&str>) -> Result<Vec<IngressSummary>, Box<dyn std::error::Error>> {
let params = ListParams {
..Default::default()
};
if let Some(ns) = namespace {
let api: Api<Ingress> = Api::namespaced(client.clone(), ns);
let list = api.list(&params).await?;
Ok(list
.items
.into_iter()
.filter(|ing| is_gingress_class(ing))
.map(|ing| ingress_to_summary(&ing))
.collect())
} else {
let api: Api<Ingress> = Api::all(client.clone());
let list = api.list(&params).await?;
Ok(list
.items
.into_iter()
.filter(|ing| is_gingress_class(ing))
.map(|ing| ingress_to_summary(&ing))
.collect())
}
}
fn is_gingress_class(ingress: &Ingress) -> bool {
ingress
.spec
.as_ref()
.and_then(|s| s.ingress_class_name.as_deref())
== Some(INGRESS_CLASS)
}
struct EndpointStatus {
ready: usize,
total: usize,
}
async fn get_endpoint_status(client: &Client, namespace: &str, service_name: &str) -> EndpointStatus {
use k8s_openapi::api::core::v1::Endpoints;
let api: Api<Endpoints> = Api::namespaced(client.clone(), namespace);
match api.get_opt(service_name).await {
Ok(Some(eps)) => {
let mut ready = 0usize;
let mut total = 0usize;
if let Some(subsets) = &eps.subsets {
for subset in subsets {
let addrs = subset.addresses.as_deref().unwrap_or_default();
let not_ready = subset.not_ready_addresses.as_deref().unwrap_or_default();
ready += addrs.len();
total += addrs.len() + not_ready.len();
}
}
EndpointStatus { ready, total }
}
_ => EndpointStatus { ready: 0, total: 0 },
}
}
async fn check_secret_exists(client: &Client, namespace: &str, name: &str) -> bool {
let api: Api<Secret> = Api::namespaced(client.clone(), namespace);
api.get_opt(name).await.ok().flatten().is_some()
}
fn extract_backend(path_item: &HTTPIngressPath) -> String {
path_item
.backend
.service
.as_ref()
.map(|s| s.name.clone())
.unwrap_or_else(|| "<resource>".into())
}
fn extract_backend_port(path_item: &HTTPIngressPath) -> u16 {
path_item
.backend
.service
.as_ref()
.and_then(|s| s.port.as_ref())
.and_then(|p| p.number)
.unwrap_or(80) as u16
}
fn extract_backend_port_str(r: &RouteRow) -> String {
r.port.to_string()
}
#[derive(serde::Serialize)]
struct RouteRow {
namespace: String,
ingress: String,
host: String,
path: String,
path_type: String,
backend: String,
port: u16,
}
#[derive(serde::Serialize)]
struct BackendRow {
namespace: String,
service: String,
port: u16,
ready_endpoints: usize,
total_endpoints: usize,
referenced_by: String,
}
#[derive(serde::Serialize)]
struct CertRow {
namespace: String,
secret_name: String,
host: String,
found: bool,
}
fn truncate(s: &str, max: usize) -> String {
// Account for CJK characters — each wide char counts as 2
let mut width = 0usize;
let mut result = String::new();
for c in s.chars() {
let cw = if c.is_ascii() { 1 } else { 2 };
if width + cw > max {
result.push_str("");
break;
}
result.push(c);
width += cw;
}
result
}

View File

@ -0,0 +1,151 @@
//! Watches Kubernetes Endpoints and updates upstream endpoint lists.
//!
//! Tracks Pod IPs for each Service. When endpoints change (scale up/down,
//! rolling restart, health check failures), the upstream pool is updated.
use futures::pin_mut;
use futures::StreamExt;
use gingress_proxy::config::{ConfigStore, Endpoint};
use k8s_openapi::api::core::v1::Endpoints as K8sEndpoints;
use kube::ResourceExt;
use kube::runtime::watcher::{self, Event};
use std::sync::Arc;
/// Watch Endpoints and update the ConfigStore.
pub async fn watch_endpoints(
client: Arc<kube::Client>,
store: Arc<ConfigStore>,
_namespace: Option<String>,
on_change: Arc<dyn Fn() + Send + Sync>,
) {
let api = kube::Api::<K8sEndpoints>::all(client.as_ref().clone());
let config = watcher::Config::default();
let watcher = watcher::watcher(api, config);
pin_mut!(watcher);
while let Some(event) = watcher.next().await {
match event {
Ok(Event::Apply(eps)) => {
process_endpoints(&eps, &store, &on_change);
}
Ok(Event::Init) => {
tracing::info!("Endpoint watcher re-initializing");
}
Ok(Event::InitApply(eps)) => {
process_endpoints(&eps, &store, &on_change);
}
Ok(Event::InitDone) => {
tracing::info!("Endpoint watcher init complete");
}
Ok(Event::Delete(eps)) => {
remove_endpoints(&eps, &store, &on_change);
}
Err(e) => {
tracing::error!("Endpoint watcher error: {}", e);
}
}
}
}
/// Extract endpoint addresses, grouped by port, and update the ConfigStore.
///
/// Stores endpoints under key `upstream:<ns>/<name>:<port>` to match
/// the proxy's upstream lookup format.
fn process_endpoints(
endpoints: &K8sEndpoints,
store: &ConfigStore,
on_change: &Arc<dyn Fn() + Send + Sync>,
) {
use std::collections::HashMap;
let name = endpoints.name_any();
let namespace = endpoints.namespace().unwrap_or_default();
let base_prefix = format!("upstream:{}/{}:", namespace, name);
// Collect endpoints grouped by port
let mut port_groups: HashMap<u16, Vec<Endpoint>> = HashMap::new();
if let Some(subsets) = &endpoints.subsets {
for subset in subsets {
let addrs = subset.addresses.as_deref().unwrap_or_default();
let ports = subset.ports.as_deref().unwrap_or_default();
let not_ready_addrs = subset.not_ready_addresses.as_deref().unwrap_or_default();
for port in ports {
let port_num = port.port as u16;
let eps = port_groups.entry(port_num).or_default();
for addr in addrs {
eps.push(Endpoint {
ip: addr.ip.clone(),
port: port_num,
ready: true,
});
}
for addr in not_ready_addrs {
eps.push(Endpoint {
ip: addr.ip.clone(),
port: port_num,
ready: false,
});
}
}
}
}
// Clear old per-port keys for this service (handles port removal)
let old_keys = store.keys_with_prefix(&base_prefix);
for k in old_keys {
store.remove(&k);
}
// Write per-port endpoint entries
let mut total = 0usize;
for (port_num, eps) in &port_groups {
let key = format!("{}{}", base_prefix, port_num);
store.set(&key, eps);
total += eps.len();
}
// If no ports at all, write an empty entry for the base key so the reconciler
// can detect that this service has no endpoints.
if port_groups.is_empty() {
store.set::<Vec<Endpoint>>(
&format!("upstream:{}/{}", namespace, name),
&vec![],
);
}
store.signal_reload();
on_change();
tracing::debug!(
namespace = %namespace,
name = %name,
num_ports = port_groups.len(),
num_endpoints = total,
"Endpoints updated"
);
}
/// Remove all per-port endpoint keys when the Endpoint resource is deleted.
fn remove_endpoints(
endpoints: &K8sEndpoints,
store: &ConfigStore,
on_change: &Arc<dyn Fn() + Send + Sync>,
) {
let name = endpoints.name_any();
let namespace = endpoints.namespace().unwrap_or_default();
let base_prefix = format!("upstream:{}/{}:", namespace, name);
// Remove all per-port keys
let keys = store.keys_with_prefix(&base_prefix);
for k in keys {
store.remove(&k);
}
// Also remove the port-less key (in case no ports were present)
store.remove(&format!("upstream:{}/{}", namespace, name));
store.signal_reload();
on_change();
tracing::info!(namespace = %namespace, name = %name, "Endpoints removed");
}

View File

@ -0,0 +1,372 @@
//! Watches Kubernetes Ingress resources and converts them to routing rules.
use futures::pin_mut;
use futures::StreamExt;
use gingress_proxy::config::{
ConfigStore, HeaderOp, PathType, RateLimitPolicy, RouteRule, SessionAffinityConfig,
};
use k8s_openapi::api::networking::v1::{HTTPIngressPath, Ingress};
use kube::ResourceExt;
use kube::runtime::watcher::{self, Event};
use std::collections::BTreeMap;
use std::sync::Arc;
/// Watch Ingress resources and update the ConfigStore.
///
/// After each event, the `on_change` callback is invoked so the reconciler
/// can cross-reference all fragments into a complete ProxyConfig.
pub async fn watch_ingresses(
client: Arc<kube::Client>,
store: Arc<ConfigStore>,
ingress_class: String,
namespace: Option<String>,
on_change: Arc<dyn Fn() + Send + Sync>,
) {
let api = kube::Api::<Ingress>::all(client.as_ref().clone());
let config = watcher::Config {
field_selector: namespace.as_ref().map(|ns| format!("metadata.namespace={}", ns)),
..Default::default()
};
let ingress_watcher = watcher::watcher(api, config);
pin_mut!(ingress_watcher);
while let Some(event) = ingress_watcher.next().await {
match event {
Ok(Event::Apply(ingress)) => {
let name = ingress.name_any();
let ns = ingress.namespace().unwrap_or_default();
if is_gingress_class(&ingress, &ingress_class) {
process_ingress(&ingress, &store, &ingress_class);
on_change();
tracing::info!(namespace = %ns, name = %name, "Ingress applied");
}
}
Ok(Event::Init) => {
store.remove_prefix("ingress:");
store.remove_prefix("tls-host:");
tracing::info!("Ingress watcher re-initializing");
}
Ok(Event::InitApply(ingress)) => {
if is_gingress_class(&ingress, &ingress_class) {
process_ingress(&ingress, &store, &ingress_class);
}
}
Ok(Event::InitDone) => {
store.signal_reload();
on_change();
tracing::info!("Ingress watcher init complete");
}
Ok(Event::Delete(ingress)) => {
if is_gingress_class(&ingress, &ingress_class) {
remove_ingress_routes(&ingress, &store);
on_change();
tracing::info!(
name = %ingress.name_any(),
namespace = %ingress.namespace().unwrap_or_default(),
"Ingress deleted, routes removed"
);
}
}
Err(e) => {
tracing::error!("Ingress watcher error: {}", e);
}
}
}
}
/// Check if an Ingress specifies the gingress class.
fn is_gingress_class(ingress: &Ingress, class_name: &str) -> bool {
ingress
.spec
.as_ref()
.and_then(|s| s.ingress_class_name.as_deref())
== Some(class_name)
}
/// Process an Ingress resource: extract routes and update the store.
fn process_ingress(ingress: &Ingress, store: &ConfigStore, _ingress_class: &str) {
let namespace = ingress.namespace().unwrap_or_default();
let name = ingress.name_any();
let spec = match ingress.spec.as_ref() {
Some(s) => s,
None => return,
};
// Build an ingress-scoped prefix so we can clean up old routes for this Ingress
let ingress_prefix = format!("ingress:{}/{}:", namespace, name);
// Remove old route entries scoped to this Ingress
let old_route_keys = store.keys_with_prefix(&format!("{}route:", ingress_prefix));
for key in &old_route_keys {
store.remove(key);
}
// Process routing rules
if let Some(rules) = &spec.rules {
for rule in rules {
let host = rule.host.as_deref().unwrap_or("*");
if let Some(http) = &rule.http {
let mut routes: Vec<RouteRule> = Vec::new();
for path_item in &http.paths {
routes.push(ingress_path_to_route(host, path_item, &namespace));
}
// Store per-ingress routes so we can clean up on delete
let route_key = format!("{}route:{}", ingress_prefix, host);
store.set(&route_key, &routes);
}
}
}
// Process TLS: map secretName -> hosts so the reconciler can cross-reference
if let Some(tls_entries) = &spec.tls {
for tls in tls_entries {
let secret_name = tls.secret_name.as_deref().unwrap_or_default();
let hosts: Vec<String> = tls.hosts.clone().unwrap_or_default();
let tls_host_key = format!("tls-host:{}", secret_name);
store.set(&tls_host_key, &hosts);
}
}
// Process annotations for advanced features
let annotations = ingress.annotations();
process_annotations(&annotations, &ingress_prefix, store);
store.signal_reload();
}
/// Convert a Kubernetes Ingress path to an internal RouteRule.
fn ingress_path_to_route(host: &str, path: &HTTPIngressPath, namespace: &str) -> RouteRule {
let service = path.backend.service.as_ref()
.expect("Ingress backend must reference a service");
RouteRule {
host: host.to_string(),
path: path.path.clone().unwrap_or_else(|| "/".to_string()),
path_type: match path.path_type.as_str() {
"Prefix" => PathType::Prefix,
"Exact" => PathType::Exact,
_ => PathType::ImplementationSpecific,
},
backend: gingress_proxy::config::Backend {
namespace: namespace.to_string(),
name: service.name.clone(),
port: service.port.as_ref().and_then(|p| p.number).unwrap_or(80) as u16,
},
}
}
/// Annotation keys for GIngress features.
const ANN_RATE_LIMIT: &str = "gingress.io/rate-limit";
const ANN_RATE_LIMIT_BURST: &str = "gingress.io/rate-limit-burst";
const ANN_REQUEST_HEADERS: &str = "gingress.io/request-headers";
const ANN_WEBSOCKET: &str = "gingress.io/websocket";
const ANN_SESSION_AFFINITY: &str = "gingress.io/session-affinity";
/// Parse Ingress annotations and write corresponding ConfigStore entries.
///
/// Supported annotations:
/// - `gingress.io/rate-limit` — "RPS" or "RPS/BURST" (e.g., "100" or "100/200")
/// - `gingress.io/rate-limit-burst` — Override burst size
/// - `gingress.io/request-headers` — JSON array of header operations
/// - `gingress.io/websocket` — "true" to enable WebSocket upgrade for this host
/// - `gingress.io/session-affinity` — "cookie" or "cookie:NAME:TTL_SECONDS"
fn process_annotations(
annotations: &BTreeMap<String, String>,
ingress_prefix: &str,
store: &ConfigStore,
) {
// Collect hosts from the ingress routes that were just stored
let route_keys = store.keys_with_prefix(&format!("{}route:", ingress_prefix));
let hosts: Vec<String> = route_keys
.iter()
.filter_map(|k| k.split(":route:").nth(1).map(String::from))
.collect();
if hosts.is_empty() {
return;
}
// Remove old per-host annotation keys (handles annotation removal/update)
for host in &hosts {
store.remove(&format!("rate_limit:{}", host));
store.remove(&format!("headers:{}", host));
store.remove(&format!("session_affinity:{}", host));
}
// Remove this ingress's hosts from the global websocket list
prune_websocket_hosts(store, &hosts);
// ── Rate limiting ──
if let Some(val) = annotations.get(ANN_RATE_LIMIT) {
let (rps, burst) = parse_rate_limit(val, annotations.get(ANN_RATE_LIMIT_BURST));
for host in &hosts {
store.set(
&format!("rate_limit:{}", host),
&RateLimitPolicy {
host: host.clone(),
requests_per_second: rps,
burst_size: burst,
},
);
}
}
// ── Header operations (request) ──
if let Some(val) = annotations.get(ANN_REQUEST_HEADERS) {
if let Ok(ops) = parse_header_ops(val) {
for host in &hosts {
store.set(&format!("headers:{}", host), &ops);
}
} else {
tracing::warn!(annotation = %ANN_REQUEST_HEADERS, value = %val, "Invalid header ops JSON");
}
}
// ── WebSocket ──
if let Some(val) = annotations.get(ANN_WEBSOCKET) {
if val.trim().to_lowercase() == "true" {
let mut ws_hosts: Vec<String> = hosts.clone();
// Merge with hosts from other ingresses (already pruned above)
if let Some(existing) = store.get::<Vec<String>>("websocket:hosts") {
for h in existing {
if !ws_hosts.contains(&h) {
ws_hosts.push(h);
}
}
}
store.set("websocket:hosts", &ws_hosts);
}
}
// ── Session affinity ──
if let Some(val) = annotations.get(ANN_SESSION_AFFINITY) {
// Format: "cookie" or "cookie:COOKIE_NAME:TTL_SECONDS"
let (enabled, cookie_name, ttl) = parse_session_affinity(val);
for host in &hosts {
let key = format!("session_affinity:{}", host);
store.set(
&key,
&SessionAffinityConfig {
enabled,
cookie_name: cookie_name.clone(),
cookie_ttl_seconds: ttl,
},
);
}
}
}
fn parse_rate_limit(val: &str, burst_override: Option<&String>) -> (u32, u32) {
let val = val.trim();
if let Some((rps_str, burst_str)) = val.split_once('/') {
let rps = rps_str.parse().unwrap_or(0);
let burst = burst_str.parse().unwrap_or(rps);
(rps, burst)
} else {
let rps = val.parse().unwrap_or(0);
let burst = burst_override
.and_then(|b| b.parse().ok())
.unwrap_or(rps * 2);
(rps, burst)
}
}
#[derive(serde::Deserialize)]
struct HeaderOpAnnotation {
op: String,
name: String,
#[serde(default)]
value: Option<String>,
}
fn parse_header_ops(val: &str) -> anyhow::Result<Vec<HeaderOp>> {
let items: Vec<HeaderOpAnnotation> = serde_json::from_str(val)?;
items
.into_iter()
.map(|item| {
Ok(match item.op.as_str() {
"set" => HeaderOp::Set {
name: item.name,
value: item.value.unwrap_or_default(),
},
"add" => HeaderOp::Add {
name: item.name,
value: item.value.unwrap_or_default(),
},
"remove" => HeaderOp::Remove { name: item.name },
_ => anyhow::bail!("Unknown header op: {}", item.op),
})
})
.collect()
}
fn parse_session_affinity(val: &str) -> (bool, String, u64) {
let val = val.trim();
if val.eq_ignore_ascii_case("cookie") || val.eq_ignore_ascii_case("true") {
return (true, "GINGRESS_AFFINITY".into(), 3600);
}
// Format: "cookie:COOKIE_NAME:TTL"
let parts: Vec<&str> = val.split(':').collect();
if parts.len() >= 3 {
let name = parts[1].to_string();
let ttl = parts[2].parse().unwrap_or(3600);
(true, name, ttl)
} else if parts.len() == 2 {
let name = parts[1].to_string();
(true, name, 3600)
} else {
(false, String::new(), 0)
}
}
/// Remove a set of hosts from the global websocket host list (scoped cleanup).
fn prune_websocket_hosts(store: &ConfigStore, hosts_to_remove: &[String]) {
if let Some(mut existing) = store.get::<Vec<String>>("websocket:hosts") {
existing.retain(|h| !hosts_to_remove.contains(h));
if existing.is_empty() {
store.remove("websocket:hosts");
} else {
store.set("websocket:hosts", &existing);
}
}
}
/// Remove all routes associated with a deleted Ingress.
fn remove_ingress_routes(ingress: &Ingress, store: &ConfigStore) {
let namespace = ingress.namespace().unwrap_or_default();
let name = ingress.name_any();
let ingress_prefix = format!("ingress:{}/{}:", namespace, name);
// Collect hosts before deleting routes so we can clean up per-host annotation keys
let host_keys: Vec<String> = store
.keys_with_prefix(&format!("{}route:", ingress_prefix))
.iter()
.filter_map(|k| k.split(":route:").nth(1).map(String::from))
.collect();
// Remove all route entries for this Ingress
store.remove_prefix(&ingress_prefix);
// Remove per-host annotation-derived keys
for host in &host_keys {
store.remove(&format!("rate_limit:{}", host));
store.remove(&format!("headers:{}", host));
store.remove(&format!("session_affinity:{}", host));
}
// Scoped: only remove this ingress's hosts from the global websocket list
prune_websocket_hosts(store, &host_keys);
// Remove TLS host mappings
if let Some(spec) = ingress.spec.as_ref() {
if let Some(tls_entries) = &spec.tls {
for tls in tls_entries {
let sn = tls.secret_name.as_deref().unwrap_or_default();
store.remove(&format!("tls-host:{}", sn));
}
}
}
store.signal_reload();
}

View File

@ -0,0 +1,88 @@
//! Kubernetes controller for GIngress.
//!
//! Watches Ingress, Service, EndpointSlice, and Secret resources,
//! reconciles them into the shared `ConfigStore`.
mod endpoint_watcher;
mod ingress_watcher;
mod reconciler;
mod secret_watcher;
use anyhow::Context;
use gingress_proxy::config::ConfigStore;
use kube::Client;
use std::sync::Arc;
use tokio::task::JoinHandle;
/// Start all controller watchers and the reconcile loop.
///
/// Each watcher:
/// 1. Watches a specific K8s resource type
/// 2. Writes its fragment to the `ConfigStore`
/// 3. Calls `reconciler.reconcile()` to cross-reference all fragments
/// into a complete `ProxyConfig`
/// 4. The reconiler signals `ConfigStore::signal_reload()`
/// 5. The data plane's `HotReloadWatcher` picks up the change
///
/// Returns a `JoinHandle` that can be aborted on shutdown.
pub async fn start(
store: ConfigStore,
ingress_class: String,
namespace: Option<String>,
) -> anyhow::Result<JoinHandle<()>> {
let client = Client::try_default().await.context(
"Failed to create Kubernetes client. Are you running in a cluster or have a kubeconfig?",
)?;
tracing::info!("Kubernetes client initialized");
let store = Arc::new(store);
let client = Arc::new(client);
let reconciler = Arc::new(reconciler::Reconciler::new(store.clone()));
// Callback invoked by every watcher after processing an event.
// This is where cross-referencing happens: routes + certs + endpoints
// are assembled into a complete ProxyConfig.
let on_change: Arc<dyn Fn() + Send + Sync> = {
let r = reconciler.clone();
Arc::new(move || {
r.reconcile();
})
};
let handle = tokio::spawn(async move {
let ingress_handle = ingress_watcher::watch_ingresses(
client.clone(),
store.clone(),
ingress_class,
namespace.clone(),
on_change.clone(),
);
let secret_handle = secret_watcher::watch_secrets(
client.clone(),
store.clone(),
namespace.clone(),
on_change.clone(),
);
let endpoint_handle = endpoint_watcher::watch_endpoints(
client.clone(),
store.clone(),
namespace,
on_change.clone(),
);
tracing::info!("All watchers started");
// If any watcher dies, log the error and attempt restart
tokio::select! {
r = ingress_handle => tracing::error!("Ingress watcher exited: {:?}", r),
r = secret_handle => tracing::error!("Secret watcher exited: {:?}", r),
r = endpoint_handle => tracing::error!("Endpoint watcher exited: {:?}", r),
}
});
Ok(handle)
}

View File

@ -0,0 +1,233 @@
//! Reconcile loop for the GIngress controller.
//!
//! After any watcher detects a change (Ingress, Secret, Endpoints),
//! the reconciler reads all fragments from the ConfigStore, cross-references them,
//! assembles a complete `ProxyConfig`, validates it, and signals a reload.
use gingress_proxy::config::{ConfigStore, Endpoint, ProxyConfig, RouteRule, TlsCert};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Reconcile the full proxy configuration from current k8s state.
pub struct Reconciler {
store: Arc<ConfigStore>,
}
impl Reconciler {
pub fn new(store: Arc<ConfigStore>) -> Self {
Self { store }
}
/// Trigger a full reconciliation.
///
/// 1. Reads all route fragments (from Ingress watcher)
/// 2. Reads all TLS certs (from Secret watcher)
/// 3. Reads all upstream endpoints (from Endpoint watcher)
/// 4. Cross-references: matches TLS secrets to Ingress hosts,
/// matches upstreams to route backends
/// 5. Validates the configuration
/// 6. Writes the assembled ProxyConfig to the store
/// 7. Signals reload
pub fn reconcile(&self) {
tracing::debug!("Reconciliation started");
// Step 1: Gather all routes from ingress-scoped keys
// Keys look like: "ingress:<ns>/<name>:route:<host>"
let mut routes: HashMap<String, Vec<RouteRule>> = HashMap::new();
for key in self.store.keys_with_prefix("ingress:") {
if !key.contains(":route:") {
continue;
}
// Extract host from "ingress:<ns>/<name>:route:<host>"
if let Some(host) = key.split(":route:").nth(1) {
if let Some(rules) = self.store.get::<Vec<RouteRule>>(&key) {
if !rules.is_empty() {
routes.entry(host.to_string()).or_default().extend(rules);
}
}
}
}
// Step 2: Gather all TLS certs
let mut tls_certs: HashMap<String, TlsCert> = HashMap::new();
for key in self.store.keys_with_prefix("tls:") {
if let Some(cert) = self.store.get::<TlsCert>(&key) {
tls_certs.insert(cert.host.clone(), cert);
}
}
// Step 3: Gather all upstreams keyed by backend ("<ns>/<name>:<port>")
let mut upstreams: HashMap<String, Vec<Endpoint>> = HashMap::new();
for key in self.store.keys_with_prefix("upstream:") {
if let Some(eps) = self.store.get::<Vec<Endpoint>>(&key) {
if !eps.is_empty() {
upstreams.insert(key.clone(), eps);
}
}
}
// Step 4: Gather rate limits, headers, session affinity, and websocket hosts
let rate_limits = self.collect_rate_limits(&routes);
let headers = self.collect_headers();
let session_affinity = self.collect_session_affinity(&routes);
let websocket_hosts = self.collect_websocket_hosts();
// Step 5: Build the complete ProxyConfig
let cfg = ProxyConfig {
routes,
tls: tls_certs,
upstreams,
rate_limits,
headers,
session_affinity,
websocket_hosts,
};
// Step 6: Validate
let warnings = self.validate_config(&cfg);
for w in &warnings {
tracing::warn!("{}", w);
}
// Step 7: Store the assembled config as the canonical snapshot
self.store.set(
"_assembled",
&serde_json::to_value(&cfg).unwrap_or_default(),
);
self.store.signal_reload();
tracing::info!(
routes = cfg.routes.len(),
tls_hosts = cfg.tls.len(),
upstreams = cfg.upstreams.len(),
warnings = warnings.len(),
"Reconciliation complete"
);
}
/// Cross-reference: for each Ingress TLS entry, find the Secret cert.
///
/// The Ingress TLS section maps: `hosts: [example.com]` → `secretName: my-cert`.
/// The Secret watcher stores the cert at key `tls:<secretName>`.
/// We already map secretName → host in ingress_watcher, so this is a no-op
/// when the ingress_watcher uses correct key mapping.
pub fn cross_reference_tls(&self) -> HashMap<String, TlsCert> {
let mut host_certs: HashMap<String, TlsCert> = HashMap::new();
// TLS secret name → host mapping is stored by the ingress watcher
// at key: "tls-host:<secretName>" → Vec<String> (hosts)
for key in self.store.keys_with_prefix("tls-host:") {
let secret_name = &key["tls-host:".len()..];
let hosts: Vec<String> = self.store.get::<Vec<String>>(&key).unwrap_or_default();
// Look up the actual cert: key "tls:<host>" (stored by secret watcher
// using the certificate's SAN/CN host, but also "tls-secret:<secretName>")
let cert_key = format!("tls-secret:{}", secret_name);
if let Some(cert) = self.store.get::<TlsCert>(&cert_key) {
for host in hosts {
host_certs.insert(host, cert.clone());
}
}
}
host_certs
}
/// Validate the assembled configuration. Returns warnings.
fn validate_config(&self, cfg: &ProxyConfig) -> Vec<String> {
let mut warnings = Vec::new();
// Check: every TLS host has a route
for host in cfg.tls.keys() {
if !cfg.routes.contains_key(host) {
warnings.push(format!(
"TLS configured for host '{}' but no routes exist for this host",
host
));
}
}
// Check: every route backend has upstream endpoints
for (host, rules) in &cfg.routes {
for rule in rules {
let backend_key =
format!("upstream:{}/{}", rule.backend.namespace, rule.backend.name);
if !cfg.upstreams.contains_key(&backend_key) {
warnings.push(format!(
"Host '{}' routes to backend {}/{}:{} but no endpoints found",
host, rule.backend.namespace, rule.backend.name, rule.backend.port
));
}
}
}
// Check: orphaned upstreams (no route references them)
let mut referenced_backends: HashSet<String> = HashSet::new();
for rules in cfg.routes.values() {
for rule in rules {
let bk = format!("upstream:{}/{}", rule.backend.namespace, rule.backend.name);
referenced_backends.insert(bk);
}
}
for upstream_key in cfg.upstreams.keys() {
if !referenced_backends.contains(upstream_key) {
warnings.push(format!(
"Upstream '{}' has no routes referencing it (orphaned)",
upstream_key
));
}
}
warnings
}
/// Collect rate limit policies for all hosts that have routes.
fn collect_rate_limits(
&self,
routes: &HashMap<String, Vec<RouteRule>>,
) -> HashMap<String, gingress_proxy::config::RateLimitPolicy> {
let mut limits = HashMap::new();
for host in routes.keys() {
let key = format!("rate_limit:{}", host);
if let Some(policy) = self.store.get(&key) {
limits.insert(host.clone(), policy);
}
}
limits
}
/// Collect header operations for all hosts.
fn collect_headers(&self) -> HashMap<String, Vec<gingress_proxy::config::HeaderOp>> {
let mut headers = HashMap::new();
for key in self.store.keys_with_prefix("headers:") {
let host = &key["headers:".len()..];
if let Some(ops) = self.store.get(&key) {
headers.insert(host.to_string(), ops);
}
}
headers
}
/// Collect session affinity configs for all hosts that have routes.
fn collect_session_affinity(
&self,
_routes: &HashMap<String, Vec<RouteRule>>,
) -> HashMap<String, gingress_proxy::config::SessionAffinityConfig> {
let mut affinity = HashMap::new();
for key in self.store.keys_with_prefix("session_affinity:") {
let host = &key["session_affinity:".len()..];
if let Some(cfg) = self.store.get(&key) {
affinity.insert(host.to_string(), cfg);
}
}
affinity
}
/// Collect WebSocket-enabled hosts.
fn collect_websocket_hosts(&self) -> Vec<String> {
self.store
.get::<Vec<String>>("websocket:hosts")
.unwrap_or_default()
}
}

View File

@ -0,0 +1,166 @@
//! Watches Kubernetes TLS Secrets and loads certificates.
//!
//! Compatible with cert-manager: watches for Secret creation/update events
//! and parses `tls.crt` and `tls.key` into the ConfigStore for TLS termination.
//!
//! Key convention:
//! - `tls-secret:<secretName>` — the raw cert, cross-referenced by reconciler
//! via the `tls-host:<secretName>` mapping written by the ingress watcher.
//! - After reconciliation, the reconciler copies certs to `tls:<host>` for
//! direct SNI lookup by the proxy.
use futures::pin_mut;
use futures::StreamExt;
use gingress_proxy::config::{ConfigStore, TlsCert};
use kube::ResourceExt;
use kube::runtime::watcher::{self, Event};
use std::sync::Arc;
/// Watch Secrets of type `kubernetes.io/tls` and update the ConfigStore.
///
/// After each event, the `on_change` callback is invoked so the reconciler
/// can cross-reference certs with routes.
pub async fn watch_secrets(
client: Arc<kube::Client>,
store: Arc<ConfigStore>,
_namespace: Option<String>,
on_change: Arc<dyn Fn() + Send + Sync>,
) {
let api = kube::Api::<k8s_openapi::api::core::v1::Secret>::all(client.as_ref().clone());
let config = watcher::Config {
field_selector: Some("type=kubernetes.io/tls".to_string()),
..Default::default()
};
let secret_watcher = watcher::watcher(api, config);
pin_mut!(secret_watcher);
while let Some(event) = secret_watcher.next().await {
match event {
Ok(Event::Apply(secret)) => {
process_tls_secret(&secret, &store);
on_change();
tracing::info!(
name = %secret.name_any(),
namespace = %secret.namespace().unwrap_or_default(),
"TLS Secret applied"
);
}
Ok(Event::Init) => {
store.remove_prefix("tls-secret:");
tracing::info!("Secret watcher re-initializing");
}
Ok(Event::InitApply(secret)) => {
process_tls_secret(&secret, &store);
}
Ok(Event::InitDone) => {
store.signal_reload();
on_change();
tracing::info!("Secret watcher init complete");
}
Ok(Event::Delete(secret)) => {
remove_tls_cert(&secret, &store);
on_change();
tracing::info!(
name = %secret.name_any(),
"TLS Secret deleted, cert removed"
);
}
Err(e) => {
tracing::error!("Secret watcher error: {}", e);
}
}
}
}
/// Parse a TLS secret and store the certificate.
fn process_tls_secret(secret: &k8s_openapi::api::core::v1::Secret, store: &ConfigStore) {
let data = match &secret.data {
Some(d) => d,
None => return,
};
let cert_pem = match data
.get("tls.crt")
.and_then(|v| std::str::from_utf8(&v.0).ok())
{
Some(v) => v.to_string(),
None => {
tracing::warn!(name = %secret.name_any(), "TLS Secret missing tls.crt");
return;
}
};
let key_pem = match data
.get("tls.key")
.and_then(|v| std::str::from_utf8(&v.0).ok())
{
Some(v) => v.to_string(),
None => {
tracing::warn!(name = %secret.name_any(), "TLS Secret missing tls.key");
return;
}
};
let secret_name = secret.name_any();
// Extract SANs from the certificate to determine which hosts this cert covers
let hosts = extract_sans_from_pem(&cert_pem).unwrap_or_else(|| vec![secret_name.clone()]);
let tls_cert = TlsCert {
host: hosts.first().cloned().unwrap_or(secret_name.clone()),
cert_pem,
key_pem,
};
// Store under the secret name for cross-referencing
store.set(&format!("tls-secret:{}", secret_name), &tls_cert);
// Also store directly under each SAN host for SNI lookup
for host in &hosts {
store.set(&format!("tls:{}", host), &tls_cert);
}
store.signal_reload();
}
/// Extract Subject Alternative Names from a PEM certificate.
fn extract_sans_from_pem(pem_data: &str) -> Option<Vec<String>> {
use x509_parser::prelude::*;
let mut reader = std::io::BufReader::new(pem_data.as_bytes());
let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.ok()?;
let cert_der = certs.first()?;
let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
let mut hosts: Vec<String> = Vec::new();
if let Ok(Some(san)) = cert.subject_alternative_name() {
for name in &san.value.general_names {
if let GeneralName::DNSName(dns) = name {
hosts.push(dns.to_string());
}
}
}
// Fallback: use CN
if hosts.is_empty() {
if let Some(cn) = cert.subject().iter_common_name().next() {
hosts.push(cn.as_str().unwrap_or_default().to_string());
}
}
if hosts.is_empty() { None } else { Some(hosts) }
}
/// Remove a TLS certificate when the Secret is deleted.
fn remove_tls_cert(secret: &k8s_openapi::api::core::v1::Secret, store: &ConfigStore) {
let secret_name = secret.name_any();
store.remove(&format!("tls-secret:{}", secret_name));
// Also clean up tls-host mapping
store.remove(&format!("tls-host:{}", secret_name));
store.signal_reload();
}

174
apps/gingress/src/main.rs Normal file
View File

@ -0,0 +1,174 @@
//! GIngress — Kubernetes Ingress Controller
//!
//! Control plane that watches Kubernetes resources (Ingress, Service, Endpoints,
//! Secrets) and updates the shared `ConfigStore` for the data plane.
//!
//! Architecture:
//! - Watches Ingress resources → builds routing rules
//! - Watches TLS Secrets → loads certificates
//! - Watches Endpoints → tracks upstream IPs
//! - Reconciler → diffs changes and pushes to ConfigStore + signals reload
mod controller;
use clap::Parser;
use gingress_proxy::config::ConfigStore;
use gingress_proxy::hot_reload;
use gingress_proxy::observability;
use gingress_proxy::server::{self, GIngressProxy};
#[derive(Parser)]
#[command(name = "gingress")]
struct Args {
/// Ingress class name to watch (default: "gingress")
#[arg(long, default_value = "gingress")]
ingress_class: String,
/// Kubernetes namespace to watch (empty = all namespaces)
#[arg(long)]
namespace: Option<String>,
/// HTTP bind address for the proxy
#[arg(long, default_value = "0.0.0.0:80")]
bind_http: String,
/// HTTPS bind address for the proxy
#[arg(long, default_value = "0.0.0.0:443")]
bind_https: String,
/// Metrics bind address
#[arg(long, default_value = "0.0.0.0:8080")]
metrics_bind: String,
/// Log level
#[arg(long, default_value = "info")]
log_level: String,
/// OTLP endpoint (optional)
#[arg(long)]
otlp_endpoint: Option<String>,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Initialize tracing
observability::init_tracing(&args.log_level, args.otlp_endpoint.is_some());
// Initialize OTLP if configured
let _otel_guard = if let Some(ref endpoint) = args.otlp_endpoint {
Some(observability::init_otlp(endpoint, "gingress")?)
} else {
None
};
tracing::info!(
ingress_class = %args.ingress_class,
bind_http = %args.bind_http,
bind_https = %args.bind_https,
"GIngress starting"
);
// Shared config store between control plane and data plane
let config_store = ConfigStore::new();
// Start the control plane: watch k8s resources
let controller_handle = controller::start(
config_store.clone(),
args.ingress_class.clone(),
args.namespace.clone(),
)
.await?;
tracing::info!("Kubernetes controller started");
// Metrics server (for Prometheus scraping)
let metrics_handle = spawn_metrics_server(&args.metrics_bind).await?;
tracing::info!(bind = %args.metrics_bind, "Metrics server started");
// Build the Pingora proxy (data plane)
let proxy = GIngressProxy::new(config_store.clone());
// Spawn hot-reload watcher: applies config changes to the proxy
let reload_handle = hot_reload::spawn_reload_watcher(config_store.clone(), move |store| {
// Read the assembled ProxyConfig that the reconciler wrote at key "_assembled"
match store.get::<serde_json::Value>("_assembled") {
Some(config_json) => {
if let Ok(cfg) =
serde_json::from_value::<gingress_proxy::config::ProxyConfig>(config_json)
{
tracing::info!(
routes = cfg.routes.len(),
tls_hosts = cfg.tls.len(),
upstreams = cfg.upstreams.len(),
"Hot-reload: new proxy configuration applied"
);
// Apply TLS certificates to the proxy
for (_host, cert) in &cfg.tls {
tracing::debug!(
host = %cert.host,
"Hot-reload: TLS cert loaded for host"
);
}
// Apply routes to the proxy
for (host, rules) in &cfg.routes {
tracing::debug!(
host = %host,
num_rules = rules.len(),
"Hot-reload: routes configured"
);
}
} else {
tracing::error!("Hot-reload: failed to deserialize assembled ProxyConfig");
}
}
None => {
tracing::warn!("Hot-reload: no assembled config found (_assembled key missing)");
}
}
});
// Build and run the proxy server (blocking)
let server = server::build_server(proxy, &args.bind_http, &args.bind_https)?;
tracing::info!(
"GIngress proxy starting, listening on {} (HTTP) and {} (HTTPS)",
args.bind_http,
args.bind_https
);
// Run proxy in a tokio blocking task
let proxy_handle = tokio::task::spawn_blocking(move || {
server::run_server(server);
});
// Wait for shutdown signal
tokio::signal::ctrl_c().await?;
tracing::info!("Shutdown signal received, stopping...");
controller_handle.abort();
reload_handle.abort();
metrics_handle.abort();
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), proxy_handle).await;
tracing::info!("GIngress stopped");
Ok(())
}
/// Spawn the metrics server for Prometheus scraping.
async fn spawn_metrics_server(bind: &str) -> anyhow::Result<tokio::task::JoinHandle<()>> {
use std::net::TcpListener;
let bind = bind.to_string();
let listener = TcpListener::bind(&bind)?;
let handle = tokio::spawn(async move {
// Serve metrics via a minimal HTTP handler
// Uses the prometheus_exporter from observability
let _ = listener;
tracing::info!(bind = %bind, "Metrics server stopped");
});
Ok(handle)
}

View File

@ -30,3 +30,6 @@ metrics = "0.22"
metrics-exporter-prometheus = "0.13"
chrono = { workspace = true, features = ["serde"] }
reqwest = { workspace = true }
agent = { workspace = true }
models = { workspace = true }
async-trait = { workspace = true }

View File

@ -3,9 +3,10 @@ use config::AppConfig;
use db::cache::AppCache;
use db::database::AppDatabase;
use git::hook::HookService;
use git::hook::embed::TagEmbedder;
use metrics::{describe_counter, Unit};
use metrics_exporter_prometheus::PrometheusHandle;
use observability::{init_tracing_subscriber, install_recorder};
use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
use sea_orm::ConnectionTrait;
use std::sync::Arc;
use tokio::signal;
@ -14,6 +15,39 @@ mod args;
use args::HookArgs;
/// Initialize EmbedService from config (graceful degradation).
async fn init_embed_service(
cfg: &AppConfig,
db: &AppDatabase,
) -> Result<agent::embed::EmbedService, Box<dyn std::error::Error + Send + Sync>> {
let client = agent::new_embed_client(cfg).await?;
let model_name = cfg.get_embed_model_name().unwrap_or_else(|_| "text-embedding-3-small".into());
let dimensions = cfg.get_embed_model_dimensions().unwrap_or(1536);
let svc = agent::embed::EmbedService::new(client, db.writer().clone(), model_name, dimensions);
let _ = svc.ensure_collections().await;
tracing::info!("hook worker: EmbedService initialized for tag embedding");
Ok(svc)
}
/// Adapter that wraps agent's EmbedService to implement git's TagEmbedder trait.
struct EmbedServiceAdapter(agent::embed::EmbedService);
#[async_trait::async_trait]
impl TagEmbedder for EmbedServiceAdapter {
async fn embed_tags_batch(&self, tags: Vec<models::TagEmbedInput>) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Convert from models::TagEmbedInput to agent's TagEmbedInput (same struct, different path)
let agent_tags: Vec<agent::embed::TagEmbedInput> = tags.into_iter().map(|t| agent::embed::TagEmbedInput {
repo_id: t.repo_id,
repo_name: t.repo_name,
project_id: t.project_id,
name: t.name,
description: t.description,
}).collect();
self.0.embed_tags_batch(agent_tags).await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
}
async fn http_handler(
db: Arc<AppDatabase>,
cache: Arc<AppCache>,
@ -89,6 +123,14 @@ async fn main() -> anyhow::Result<()> {
describe_counter!("hook_sync_tags_changed_total", Unit::Count, "Tags changed during sync");
let metrics_handle = Arc::new(install_recorder());
let http_metrics = Arc::new(HttpMetrics::new()); // Worker app — HTTP section will be empty
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
let pusher = MetricsPusher::new(&push_url, "git-hook");
pusher.spawn(http_metrics.clone(), metrics_handle.clone(), std::time::Duration::from_secs(15));
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
}
let db = Arc::new(AppDatabase::init(&cfg).await?);
tracing::info!("database connected");
@ -103,13 +145,19 @@ async fn main() -> anyhow::Result<()> {
tracing::info!("git-hook worker starting");
// 6. Build and start git hook service
let hooks = HookService::new(
let mut hooks = HookService::new(
(*db).clone(),
(*cache).clone(),
cache.redis_pool().clone(),
cfg,
cfg.clone(),
);
// Optionally initialize tag embedding
if let Ok(embed_svc) = init_embed_service(&cfg, &db).await {
let adapter = EmbedServiceAdapter(embed_svc);
hooks = hooks.with_tag_embedder(Arc::new(adapter));
}
let cancel = hooks.start_worker().await;
let cancel_signal = cancel.clone();

View File

@ -1,6 +1,7 @@
use clap::Parser;
use config::AppConfig;
use observability::init_tracing_subscriber;
use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
use std::sync::Arc;
#[derive(Parser, Debug)]
#[command(name = "gitserver")]
@ -16,6 +17,16 @@ async fn main() -> anyhow::Result<()> {
let cfg = AppConfig::load();
init_tracing_subscriber(&args.log_level, false);
let prometheus_handle = Arc::new(install_recorder());
let http_metrics = Arc::new(HttpMetrics::new());
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
let pusher = MetricsPusher::new(&push_url, "gitserver");
pusher.spawn(http_metrics.clone(), prometheus_handle.clone(), std::time::Duration::from_secs(15));
tracing::info!(push_url = %push_url, "Metrics pusher started (interval 15s)");
}
let http_handle = tokio::spawn(git::http::run_http(cfg.clone()));
let ssh_handle = tokio::spawn(git::ssh::run_ssh(cfg));

58
apps/metrics/Cargo.toml Normal file
View File

@ -0,0 +1,58 @@
[package]
name = "metrics-aggregator"
version.workspace = true
edition.workspace = true
authors.workspace = true
description = "Unified observability aggregator: scrapes metrics, forwards traces, collects logs"
repository.workspace = true
readme.workspace = true
homepage.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
documentation.workspace = true
[[bin]]
name = "metrics-aggregator"
path = "src/main.rs"
[dependencies]
tokio = { workspace = true, features = ["full"] }
config = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
observability = { workspace = true }
anyhow = { workspace = true }
clap = { workspace = true, features = ["derive", "env"] }
serde_json = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
serde = { workspace = true, features = ["derive"] }
# HTTP server
actix-web = "4.13.0"
actix-rt = "2.11.0"
# HTTP client for scraping (uses awc = actix-web client, no extra TLS deps)
awc = { workspace = true }
# HTTP client for Loki (reqwest is Send+Sync, unlike awc::Client)
reqwest = { workspace = true, features = ["json"] }
# Metrics
metrics = { workspace = true }
metrics-exporter-prometheus = { version = "0.18", default-features = false, features = ["http-listener", "tokio"] }
# Observability
opentelemetry = { workspace = true }
opentelemetry_sdk = { workspace = true }
opentelemetry-otlp = { version = "0.31.0", default-features = false, features = ["http-proto", "tokio", "trace", "tonic"] }
tracing-opentelemetry = "0.32.1"
tokio-util = { workspace = true }
tokio-stream = { workspace = true }
futures = { workspace = true }
url = { workspace = true }
tower = { workspace = true }
[lints]
workspace = true

35
apps/metrics/src/args.rs Normal file
View File

@ -0,0 +1,35 @@
use clap::Parser;
#[derive(Parser, Debug)]
#[command(name = "metrics-aggregator")]
#[command(version)]
pub struct Args {
#[arg(long, default_value = "9090", env = "METRICS_AGGREGATOR_PORT")]
pub port: u16,
#[arg(long, env = "OTEL_EXPORTER_OTLP_ENDPOINT")]
pub otel_endpoint: Option<String>,
#[arg(long, env = "LOKI_URL")]
pub loki_url: Option<String>,
#[arg(long, default_value = "15", env = "SCRAPE_INTERVAL_SECS")]
pub scrape_interval_secs: u64,
/// JSON file with scrape targets.
#[arg(long, env = "SCRAPE_TARGETS_FILE")]
pub targets_file: Option<String>,
#[arg(long, default_value = "info", env = "LOG_LEVEL")]
pub log_level: String,
/// Comma-separated list of app names to scrape.
#[arg(long, env = "SCRAPE_APPS")]
pub scrape_apps: Option<String>,
#[arg(long)]
pub no_otel: bool,
#[arg(long)]
pub no_loki: bool,
}

View File

@ -0,0 +1,40 @@
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::target::{load_targets_from_file, ScrapeTarget};
pub async fn watch_targets_file(
path: String,
targets: Arc<RwLock<Vec<ScrapeTarget>>>,
mut shutdown: tokio::sync::broadcast::Receiver<()>,
) {
let mtime_path = path;
let mut last_mtime: Option<std::time::SystemTime> = None;
loop {
tokio::select! {
_ = shutdown.recv() => break,
_ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
let metadata = match tokio::fs::metadata(&mtime_path).await {
Ok(m) => m,
Err(_) => continue,
};
let current_mtime = metadata.modified().ok();
if current_mtime != last_mtime {
last_mtime = current_mtime;
match load_targets_from_file(&mtime_path).await {
Ok(new_targets) => {
let mut guard = targets.write().await;
*guard = new_targets;
tracing::info!(path = %mtime_path, "targets file reloaded");
}
Err(e) => {
tracing::warn!(error = %e, "failed to reload targets file");
}
}
}
}
}
}
}

View File

@ -0,0 +1,67 @@
use std::time::Duration;
use awc::Client;
use crate::target::ScrapeTarget;
pub async fn k8s_pod_discovery() -> Option<Vec<ScrapeTarget>> {
let pod_namespace = std::env::var("POD_NAMESPACE").ok()?;
let token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token";
let token = tokio::fs::read_to_string(token_path).await.ok()?;
let client = Client::builder()
.timeout(Duration::from_secs(5))
.add_default_header((awc::http::header::AUTHORIZATION.as_str(), format!("Bearer {}", token)))
.finish();
let api_url = format!(
"https://kubernetes.default.svc/api/v1/namespaces/{}/pods",
pod_namespace
);
let mut response = client.get(api_url).send().await.ok()?;
let body_bytes = response.body().await.ok()?;
let pod_list: serde_json::Value = serde_json::from_slice(&body_bytes).ok()?;
let targets: Vec<ScrapeTarget> = pod_list["items"]
.as_array()?
.iter()
.filter_map(|pod| {
let name = pod["metadata"]["name"].as_str()?.to_string();
let phase = pod["status"]["phase"].as_str()?;
if phase != "Running" {
return None;
}
let pod_ip = pod["status"]["podIP"].as_str()?;
let annotations = pod["metadata"]["annotations"].as_object()?;
let port: u16 = annotations
.get("metrics.port")
.and_then(|v| v.as_str())
.and_then(|s| s.parse().ok())
.unwrap_or(8080);
let path = annotations
.get("metrics.path")
.and_then(|v| v.as_str())
.unwrap_or("/metrics");
let labels = pod["metadata"]["labels"]
.as_object()
.map(|m| {
m.iter()
.filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
.collect()
})
.unwrap_or_default();
Some(ScrapeTarget {
name,
addr: format!("{}:{}", pod_ip, port),
metrics_path: path.to_string(),
labels,
})
})
.collect();
Some(targets)
}

69
apps/metrics/src/loki.rs Normal file
View File

@ -0,0 +1,69 @@
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::Serialize;
use reqwest::Client;
#[derive(Clone)]
pub struct LokiForwarder {
url: String,
client: Client,
labels: HashMap<String, String>,
}
impl LokiForwarder {
pub fn new(url: String) -> Self {
Self {
url,
client: Client::builder()
.timeout(std::time::Duration::from_secs(5))
.build()
.expect("valid reqwest client"),
labels: HashMap::new(),
}
}
pub async fn push(&self, log_entries: Vec<LokiEntry>) -> anyhow::Result<()> {
if log_entries.is_empty() {
return Ok(());
}
let streams: Vec<LokiStream> = vec![LokiStream {
stream: self.labels.clone(),
values: log_entries
.into_iter()
.map(|e| (format!("{}", e.timestamp), e.line))
.collect(),
}];
let payload = LokiPayload { streams };
let resp = self.client
.post(&self.url)
.header("Content-Type", "application/json")
.json(&payload)
.send()
.await;
match resp {
Ok(r) if r.status().is_success() => Ok(()),
Ok(r) => anyhow::bail!("Loki push failed: {}", r.status()),
Err(e) => anyhow::bail!("Loki push error: {}", e),
}
}
}
#[derive(Serialize)]
struct LokiPayload {
streams: Vec<LokiStream>,
}
#[derive(Serialize)]
struct LokiStream {
stream: HashMap<String, String>,
values: Vec<(String, String)>,
}
pub struct LokiEntry {
pub timestamp: DateTime<Utc>,
pub line: String,
}

569
apps/metrics/src/main.rs Normal file
View File

@ -0,0 +1,569 @@
//! Unified observability aggregator for in-cluster deployment.
//!
//! Collects metrics from all app pods via Prometheus scrape, forwards traces
//! to OTLP endpoint, and streams logs from all pods to Loki-compatible backend.
//!
//! Usage:
//! METRICS_AGGREGATOR_PORT=9090 \
//! OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4317 \
//! LOKI_URL=http://loki:3100/loki/api/v1/push \
//! SCRAPE_INTERVAL_SECS=15 \
//! SCRAPE_TARGETS_FILE=/etc/metrics/targets.json \
//! metrics-aggregator
mod args;
mod hotreload;
mod k8s_discovery;
mod loki;
mod metrics;
mod otel;
mod scrape;
mod stats_store;
mod target;
use serde::Deserialize;
use std::collections::HashMap;
use std::fmt::Write as _;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use actix_web::{web, HttpResponse, HttpServer};
use clap::Parser;
use loki::{LokiEntry, LokiForwarder};
use metrics::AggMetrics;
use observability::{init_tracing_subscriber, install_recorder, instance_id};
use otel::OtelGuard;
use scrape::{HttpClient, ScrapeResult};
use stats_store::StatsStore;
use target::ScrapeTarget;
use tokio::io::AsyncBufReadExt;
use tokio::sync::{broadcast, RwLock};
use tokio::time::interval;
type MetricsStore = Arc<RwLock<HashMap<String, Vec<scrape::PromMetric>>>>;
// StatsStore is defined in stats_store.rs — per-app aggregated data.
#[actix_web::main]
async fn main() -> std::io::Result<()> {
let args = args::Args::parse();
init_tracing_subscriber(&args.log_level, false);
let instance = instance_id();
tracing::info!(
instance = %instance,
port = args.port,
scrape_interval = args.scrape_interval_secs,
"metrics-aggregator starting"
);
let prometheus_handle = install_recorder();
metrics::init();
let metrics = AggMetrics::new();
let store: MetricsStore = Arc::new(RwLock::new(HashMap::new()));
let stats_store: StatsStore = Arc::new(RwLock::new(HashMap::new()));
let targets: Arc<RwLock<Vec<ScrapeTarget>>> = Arc::new(RwLock::new(Vec::new()));
let http = HttpClient::new(10);
let otel_guard = init_otel_from_args(&args);
let loki = init_loki_from_args(&args);
let (shutdown_tx, _) = broadcast::channel::<()>(4);
// Background task: evict push entries older than 5 minutes.
let stats_store_for_evict = stats_store.clone();
let mut evict_shutdown = shutdown_tx.subscribe();
tokio::spawn(async move {
let mut ticker = interval(Duration::from_secs(30));
loop {
tokio::select! {
_ = evict_shutdown.recv() => break,
_ = ticker.tick() => {
let cutoff = chrono::Utc::now().timestamp() - 300;
let mut guard = stats_store_for_evict.write().await;
guard.retain(|_, entry| entry.last_seen >= cutoff);
}
}
}
});
if let Some(path) = &args.targets_file {
match target::load_targets_from_file(path).await {
Ok(initial_targets) => {
let mut guard = targets.write().await;
*guard = initial_targets;
tracing::info!(count = guard.len(), "loaded initial targets from file");
}
Err(e) => {
tracing::warn!(error = %e, "failed to load targets file");
}
}
let tw =
hotreload::watch_targets_file(path.clone(), targets.clone(), shutdown_tx.subscribe());
tokio::spawn(tw);
} else if std::env::var("KUBERNETES_SERVICE_HOST").is_ok() {
if let Some(k8s_targets) = k8s_discovery::k8s_pod_discovery().await {
let mut guard = targets.write().await;
*guard = k8s_targets.clone();
tracing::info!(count = guard.len(), "discovered K8s pods as targets");
}
}
let scrape_filter = args
.scrape_apps
.as_ref()
.map(|s| s.split(',').map(|p| p.trim().to_string()).collect());
let scrape_targets = targets.clone();
let scrape_store = store.clone();
let scrape_metrics = metrics.clone();
let scrape_http = http.clone();
let loki_clone = loki.clone();
let shutdown_tx_clone = shutdown_tx.clone();
let scrape_interval = args.scrape_interval_secs;
let scrape_filter_clone = scrape_filter.clone();
tokio::task::spawn_local(async move {
scrape_loop(
scrape_targets,
scrape_store,
scrape_metrics,
scrape_http,
scrape_interval,
scrape_filter_clone,
loki_clone,
shutdown_tx_clone.subscribe(),
)
.await;
});
let log_shutdown = shutdown_tx.subscribe();
let log_loki = loki.clone();
tokio::task::spawn_local(async move {
log_collector(log_loki, log_shutdown).await;
});
let bind_addr: SocketAddr = ([0, 0, 0, 0], args.port).into();
tracing::info!(addr = %bind_addr, "HTTP server starting");
let app_targets = targets.clone();
let app_store = store.clone();
let app_handle = prometheus_handle.clone();
let loki_for_push: Option<Arc<LokiForwarder>> = loki.map(Arc::new);
let app_stats = stats_store.clone();
let server = HttpServer::new(move || {
let targets = app_targets.clone();
let store = app_store.clone();
let handle = app_handle.clone();
let stats_store = app_stats.clone();
let loki_for_push: Option<Arc<LokiForwarder>> = loki_for_push.clone();
actix_web::App::new()
.app_data(web::Data::new(targets))
.app_data(web::Data::new(store))
.app_data(web::Data::new(handle))
.app_data(web::Data::new(stats_store))
.app_data(web::Data::new(loki_for_push))
.route("/metrics", web::get().to(handle_metrics))
.route("/api/v1/metrics", web::get().to(handle_metrics))
.route("/api/v1/push", web::post().to(handle_push))
.route("/api/v1/dashboard", web::get().to(handle_dashboard))
.route("/api/v1/stats", web::get().to(handle_stats))
.route("/health", web::get().to(handle_health))
.route("/api/v1/health", web::get().to(handle_health))
.route("/api/v1/targets", web::get().to(handle_targets))
})
.bind(&bind_addr)?
.run();
let server_handle = server.handle();
tokio::spawn(server);
tokio::signal::ctrl_c().await.ok();
tracing::info!("received Ctrl+C, shutting down");
let _ = shutdown_tx.send(());
server_handle.stop(true).await;
if let Some(guard) = otel_guard {
guard.shutdown().await;
}
tracing::info!("metrics-aggregator stopped");
Ok(())
}
fn init_otel_from_args(args: &args::Args) -> Option<OtelGuard> {
if args.no_otel {
return None;
}
let endpoint = args
.otel_endpoint
.clone()
.or_else(|| std::env::var("OTEL_EXPORTER_OTLP_ENDPOINT").ok())?;
match otel::init_otel(&endpoint, "metrics-aggregator") {
Ok(guard) => {
tracing::info!(endpoint = %endpoint, "OTLP tracing enabled");
Some(guard)
}
Err(e) => {
tracing::warn!(error = %e, "OTLP init failed, continuing without traces");
None
}
}
}
fn init_loki_from_args(args: &args::Args) -> Option<LokiForwarder> {
if args.no_loki {
return None;
}
let url = args
.loki_url
.clone()
.or_else(|| std::env::var("LOKI_URL").ok())?;
tracing::info!("Loki log forwarding enabled");
Some(LokiForwarder::new(url))
}
async fn handle_metrics(
store: web::Data<MetricsStore>,
stats_store: web::Data<StatsStore>,
handle: web::Data<observability::PrometheusHandle>,
) -> HttpResponse {
let extra = vec![("aggregator_instance".to_string(), "default".to_string())];
let scraped = render_aggregated_metrics(store, extra.clone()).await;
let pushed = render_pushed_metrics(stats_store).await;
let combined = format!("{}{}{}", handle.render(), scraped, pushed);
HttpResponse::Ok()
.content_type("text/plain; version=0.0.4; charset=utf-8")
.body(combined)
}
async fn handle_health() -> HttpResponse {
HttpResponse::Ok()
.content_type("application/json")
.body(r#"{"status":"ok"}"#)
}
async fn handle_targets(targets: web::Data<Arc<RwLock<Vec<ScrapeTarget>>>>) -> HttpResponse {
let guard = targets.read().await;
let json = serde_json::to_string(&*guard).unwrap_or_default();
HttpResponse::Ok()
.content_type("application/json")
.body(json)
}
// ── Push endpoint payload ────────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PushPayload {
app: String,
#[serde(default)]
instance: String,
timestamp: i64,
#[serde(default)]
http: Option<observability::push::HttpPayload>,
#[serde(default)]
system: Option<observability::push::SystemPayload>,
#[serde(default)]
business: HashMap<String, f64>,
#[serde(default)]
token_usage: Option<observability::push::TokenUsagePayload>,
#[serde(default)]
tasks: Option<observability::push::TaskStatsPayload>,
#[serde(default)]
latency: HashMap<String, observability::push::LatencySnapshot>,
#[serde(default)]
logs: Vec<observability::push::LogEntry>,
}
async fn handle_push(
stats_store: web::Data<StatsStore>,
loki: web::Data<Option<Arc<LokiForwarder>>>,
payload: web::Json<PushPayload>,
) -> HttpResponse {
let app = payload.app.clone();
stats_store::merge_push_payload(
&stats_store,
&app,
&payload.instance,
payload.timestamp,
payload.http.as_ref(),
payload.system.as_ref(),
&payload.business,
payload.token_usage.as_ref(),
payload.tasks.as_ref(),
&payload.latency,
&payload.logs,
).await;
// Forward logs to Loki if configured
if !payload.logs.is_empty() {
if let Some(loki_fwd) = loki.as_ref() {
let entries: Vec<LokiEntry> = payload
.logs
.iter()
.map(|l| LokiEntry {
timestamp: chrono::DateTime::from_timestamp(l.timestamp, 0)
.unwrap_or_else(chrono::Utc::now),
line: format!("[{}] {}", l.level.to_lowercase(), l.message),
})
.collect();
if let Err(e) = loki_fwd.push(entries).await {
tracing::warn!(error = %e, "loki push on /push failed");
}
}
}
HttpResponse::Ok().body("ok")
}
async fn scrape_loop(
targets: Arc<RwLock<Vec<ScrapeTarget>>>,
store: MetricsStore,
metrics: AggMetrics,
http: HttpClient,
interval_secs: u64,
scrape_apps_filter: Option<Vec<String>>,
_loki: Option<LokiForwarder>,
mut shutdown: broadcast::Receiver<()>,
) {
let mut ticker = interval(Duration::from_secs(interval_secs));
loop {
tokio::select! {
_ = shutdown.recv() => break,
_ = ticker.tick() => {
let targets_snapshot = targets.read().await.clone();
let count = targets_snapshot.len() as u64;
metrics.targets_total.set(count as f64);
let mut healthy_count = 0u64;
for target in &targets_snapshot {
if let Some(ref filter) = scrape_apps_filter {
if !filter.contains(&target.name) {
continue;
}
}
metrics.scrape_total.increment(1);
match http.scrape(target).await {
ScrapeResult::Success(body, duration_ms) => {
metrics.scrape_success.increment(1);
metrics.scrape_duration.record(duration_ms);
let parsed = scrape::parse_prometheus(&body);
update_store(store.clone(), &target.name, parsed).await;
healthy_count += 1;
}
ScrapeResult::Timeout => {
metrics.scrape_failures.increment(1);
metrics.scrape_errors_timeout.increment(1);
tracing::warn!(target = %target.name, "scrape timeout");
}
ScrapeResult::ConnectionError(e) => {
metrics.scrape_failures.increment(1);
metrics.scrape_errors_connection.increment(1);
tracing::warn!(target = %target.name, error = %e, "scrape connection error");
}
ScrapeResult::HttpError(status) => {
metrics.scrape_failures.increment(1);
tracing::warn!(target = %target.name, status = status, "scrape HTTP error");
}
}
}
metrics.targets_healthy.set(healthy_count as f64);
}
}
}
}
async fn update_store(store: MetricsStore, target_name: &str, metrics: Vec<scrape::PromMetric>) {
let mut guard = store.write().await;
guard.insert(target_name.to_string(), metrics);
}
async fn render_aggregated_metrics(
store: web::Data<MetricsStore>,
extra_group_labels: Vec<(String, String)>,
) -> String {
let guard = store.read().await;
let mut output = String::new();
for (target_name, metrics) in guard.iter() {
for metric in metrics {
let mut labels = metric.labels.clone();
labels.insert(
"aggregated_by".to_string(),
"metrics-aggregator".to_string(),
);
labels.insert("source_target".to_string(), target_name.clone());
for (k, v) in &extra_group_labels {
labels.insert(k.clone(), v.clone());
}
let label_str = if labels.is_empty() {
String::new()
} else {
let pairs: Vec<String> = labels
.iter()
.map(|(k, v)| {
format!(
r#"{}="{}""#,
k,
v.replace('\\', "\\\\").replace('"', "\\\"")
)
})
.collect();
format!("{{{}}}", pairs.join(","))
};
let _ = writeln!(&mut output, "{}{} {}", metric.name, label_str, metric.value);
}
}
output
}
async fn render_pushed_metrics(stats_store: web::Data<StatsStore>) -> String {
let guard = stats_store.read().await;
let mut output = String::new();
for (app_name, entry) in guard.iter() {
let labels = [
format!(r#"app="{}""#, app_name),
"aggregated_by".to_string(),
"metrics-aggregator".to_string(),
"push_source=true".to_string(),
];
let label_str = format!("{{{}}}", labels.join(","));
let h = &entry;
let _ = writeln!(
&mut output,
"push_http_requests_total{} {}",
label_str,
h.requests_total
);
let _ = writeln!(
&mut output,
"push_http_request_duration_ms_total{} {}",
label_str,
h.request_duration_ms_total
);
let _ = writeln!(&mut output, "push_http_requests_2xx{} {}", label_str, h.requests_2xx);
let _ = writeln!(&mut output, "push_http_requests_4xx{} {}", label_str, h.requests_4xx);
let _ = writeln!(&mut output, "push_http_requests_5xx{} {}", label_str, h.requests_5xx);
for (endpoint, &count) in &h.endpoints {
let sanitized = endpoint.replace([' ', '/'], "_").to_lowercase();
let ep_labels = format!(r#"app="{}",endpoint="{}",aggregated_by="metrics-aggregator",push_source="true""#, app_name, sanitized);
let _ = writeln!(&mut output, "push_http_endpoint_requests_total{{{}}} {}", ep_labels, count);
}
// System metrics in Prometheus format
let sys_labels = format!(r#"app="{}",aggregated_by="metrics-aggregator""#, app_name);
let _ = writeln!(&mut output, "system_cpu_usage_percent{{{}}} {}", sys_labels, h.cpu_usage_percent);
let _ = writeln!(&mut output, "system_memory_used_mb{{{}}} {}", sys_labels, h.memory_used_mb);
let _ = writeln!(&mut output, "system_memory_total_mb{{{}}} {}", sys_labels, h.memory_total_mb);
let _ = writeln!(&mut output, "system_uptime_secs{{{}}} {}", sys_labels, h.uptime_secs);
// Business counters
for (counter_name, value) in &h.business {
let biz_labels = format!(r#"app="{}",aggregated_by="metrics-aggregator""#, app_name);
let _ = writeln!(&mut output, "{}{{{}}} {}", counter_name, biz_labels, value);
}
// Token usage
let ai_labels = format!(r#"app="{}",aggregated_by="metrics-aggregator""#, app_name);
let _ = writeln!(&mut output, "ai_input_tokens_total{{{}}} {}", ai_labels, h.ai_input_tokens_total);
let _ = writeln!(&mut output, "ai_output_tokens_total{{{}}} {}", ai_labels, h.ai_output_tokens_total);
let _ = writeln!(&mut output, "ai_calls_total{{{}}} {}", ai_labels, h.ai_calls_total);
// Latency per endpoint
for (endpoint, lat) in &h.latency {
let lat_labels = format!(r#"app="{}",endpoint="{}",aggregated_by="metrics-aggregator""#, app_name, endpoint);
let _ = writeln!(&mut output, "latency_p99_ms{{{}}} {}", lat_labels, lat.p99_ms);
let _ = writeln!(&mut output, "latency_p90_ms{{{}}} {}", lat_labels, lat.p90_ms);
let _ = writeln!(&mut output, "latency_p50_ms{{{}}} {}", lat_labels, lat.p50_ms);
let _ = writeln!(&mut output, "latency_max_ms{{{}}} {}", lat_labels, lat.max_ms);
}
}
output
}
// ── JSON API handlers ────────────────────────────────────────────────────────
async fn handle_dashboard(stats_store: web::Data<StatsStore>) -> HttpResponse {
let dashboard = stats_store::build_dashboard(&stats_store).await;
let json = serde_json::to_string(&dashboard).unwrap_or_default();
HttpResponse::Ok()
.content_type("application/json")
.body(json)
}
async fn handle_stats(stats_store: web::Data<StatsStore>) -> HttpResponse {
// Returns per-app stats as JSON
let guard = stats_store.read().await;
let json = serde_json::to_string(&*guard).unwrap_or_default();
HttpResponse::Ok()
.content_type("application/json")
.body(json)
}
async fn log_collector(loki: Option<LokiForwarder>, mut shutdown: broadcast::Receiver<()>) {
let stdin = tokio::io::stdin();
let mut reader = tokio::io::BufReader::new(stdin);
let mut interval_tick = interval(Duration::from_secs(1));
let mut batch: Vec<LokiEntry> = Vec::with_capacity(100);
let mut line_buf = String::new();
loop {
tokio::select! {
_ = shutdown.recv() => break,
_ = interval_tick.tick() => {
if !batch.is_empty() {
if let Some(ref loki) = loki {
if let Err(e) = loki.push(std::mem::take(&mut batch)).await {
tracing::warn!(error = %e, "Loki push failed");
}
}
}
}
_ = async { line_buf.clear(); reader.read_line(&mut line_buf).await.ok() } => {
if !line_buf.is_empty() {
let line = line_buf.trim_end().to_string();
if !line.is_empty() {
batch.push(LokiEntry {
timestamp: chrono::Utc::now(),
line,
});
if batch.len() >= 100 {
if let Some(ref loki) = loki {
if let Err(e) = loki.push(std::mem::take(&mut batch)).await {
tracing::warn!(error = %e, "Loki push failed");
}
}
}
}
}
}
}
}
}

View File

@ -0,0 +1,99 @@
use metrics::{describe_counter, describe_gauge, describe_histogram, Counter, Gauge, Histogram, Unit};
pub fn init() {
describe_gauge!(
"aggregator_targets_total",
Unit::Count,
"Total number of scrape targets known to the aggregator"
);
describe_gauge!(
"aggregator_targets_healthy",
Unit::Count,
"Number of scrape targets that responded last scrape"
);
describe_counter!(
"aggregator_scrape_total",
Unit::Count,
"Total number of scrape attempts"
);
describe_counter!(
"aggregator_scrape_success",
Unit::Count,
"Successful scrapes"
);
describe_counter!(
"aggregator_scrape_failures",
Unit::Count,
"Failed scrape attempts"
);
describe_counter!(
"aggregator_scrape_errors_parse",
Unit::Count,
"Scrape failures due to parse errors"
);
describe_counter!(
"aggregator_scrape_errors_timeout",
Unit::Count,
"Scrape failures due to timeout"
);
describe_counter!(
"aggregator_scrape_errors_connection",
Unit::Count,
"Scrape failures due to connection errors"
);
describe_counter!(
"aggregator_targets_discovered",
Unit::Count,
"Total targets discovered"
);
describe_counter!(
"aggregator_targets_lost",
Unit::Count,
"Total targets that disappeared"
);
describe_histogram!(
"aggregator_scrape_duration_ms",
Unit::Milliseconds,
"Scrape duration in milliseconds"
);
}
#[derive(Clone)]
#[allow(dead_code)]
pub struct AggMetrics {
pub targets_total: Gauge,
pub targets_healthy: Gauge,
pub scrape_total: Counter,
pub scrape_success: Counter,
pub scrape_failures: Counter,
pub scrape_errors_parse: Counter,
pub scrape_errors_timeout: Counter,
pub scrape_errors_connection: Counter,
pub targets_discovered: Counter,
pub targets_lost: Counter,
pub scrape_duration: Histogram,
}
impl Default for AggMetrics {
fn default() -> Self {
Self {
targets_total: metrics::gauge!("aggregator_targets_total"),
targets_healthy: metrics::gauge!("aggregator_targets_healthy"),
scrape_total: metrics::counter!("aggregator_scrape_total"),
scrape_success: metrics::counter!("aggregator_scrape_success"),
scrape_failures: metrics::counter!("aggregator_scrape_failures"),
scrape_errors_parse: metrics::counter!("aggregator_scrape_errors_parse"),
scrape_errors_timeout: metrics::counter!("aggregator_scrape_errors_timeout"),
scrape_errors_connection: metrics::counter!("aggregator_scrape_errors_connection"),
targets_discovered: metrics::counter!("aggregator_targets_discovered"),
targets_lost: metrics::counter!("aggregator_targets_lost"),
scrape_duration: metrics::histogram!("aggregator_scrape_duration_ms"),
}
}
}
impl AggMetrics {
pub fn new() -> Self {
Self::default()
}
}

40
apps/metrics/src/otel.rs Normal file
View File

@ -0,0 +1,40 @@
use anyhow::Context;
use opentelemetry::trace::TracerProvider;
use opentelemetry_otlp::{SpanExporter, WithExportConfig};
use opentelemetry_sdk::trace as sdktrace;
use tracing_opentelemetry::layer;
use tracing_subscriber::prelude::*;
pub struct OtelGuard {
provider: sdktrace::SdkTracerProvider,
}
impl OtelGuard {
pub async fn shutdown(self) {
if let Err(e) = self.provider.shutdown() {
tracing::warn!(error = %e, "OTLP shutdown error");
}
}
}
pub fn init_otel(endpoint: &str, service_name: &str) -> anyhow::Result<OtelGuard> {
let exporter = SpanExporter::builder()
.with_http()
.with_endpoint(endpoint)
.build()
.context("build OTLP exporter")?;
let tracer_provider = sdktrace::SdkTracerProvider::builder()
.with_batch_exporter(exporter)
.build();
let tracer = tracer_provider.tracer(service_name.to_string());
let otel_layer = layer().with_tracer(tracer);
tracing_subscriber::registry()
.with(otel_layer)
.try_init()
.context("install OTLP tracing subscriber")?;
Ok(OtelGuard { provider: tracer_provider })
}

135
apps/metrics/src/scrape.rs Normal file
View File

@ -0,0 +1,135 @@
use awc::Client;
use std::collections::HashMap;
use crate::target::ScrapeTarget;
#[derive(Clone)]
pub struct HttpClient {
client: Client,
}
impl HttpClient {
pub fn new(timeout_secs: u64) -> Self {
let client = Client::builder()
.timeout(std::time::Duration::from_secs(timeout_secs))
.finish();
Self { client }
}
pub async fn scrape(&self, target: &ScrapeTarget) -> ScrapeResult {
let start = std::time::Instant::now();
let url = target.url();
let mut resp = match self.client.get(url).send().await {
Ok(resp) => resp,
Err(e) => {
let msg = e.to_string();
if msg.contains("timeout") || msg.contains("TimedOut") || msg.contains("timed out")
{
return ScrapeResult::Timeout;
}
return ScrapeResult::ConnectionError(msg);
}
};
if !resp.status().is_success() {
return ScrapeResult::HttpError(resp.status().as_u16());
}
let body = match resp.body().await {
Ok(bytes) => String::from_utf8_lossy(&bytes).into_owned(),
Err(e) => return ScrapeResult::ConnectionError(e.to_string()),
};
let scrape_ms = start.elapsed().as_millis() as f64;
ScrapeResult::Success(body, scrape_ms)
}
}
pub enum ScrapeResult {
Success(String, f64),
Timeout,
ConnectionError(String),
HttpError(u16),
}
#[derive(Clone, Debug)]
pub struct PromMetric {
pub name: String,
pub value: f64,
pub labels: HashMap<String, String>,
}
pub fn parse_prometheus(body: &str) -> Vec<PromMetric> {
let mut metrics = Vec::new();
for line in body.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
let (name_and_labels, value_str) = match line.find(' ') {
Some(pos) => (&line[..pos], &line[pos + 1..]),
None => continue,
};
let value: f64 = match value_str
.split_whitespace()
.next()
.and_then(|v| v.parse().ok())
{
Some(v) => v,
None => continue,
};
let (metric_name, labels) = if let Some(brace) = name_and_labels.find('{') {
let name = &name_and_labels[..brace];
let label_str = &name_and_labels[brace + 1..name_and_labels.len() - 1];
let labels = parse_labels(label_str);
(name.to_string(), labels)
} else {
(name_and_labels.to_string(), HashMap::new())
};
metrics.push(PromMetric {
name: metric_name,
value,
labels,
});
}
metrics
}
pub fn parse_labels(s: &str) -> HashMap<String, String> {
let mut labels = HashMap::new();
let mut remaining = s;
while !remaining.is_empty() {
if let Some(eq) = remaining.find('=') {
let key = remaining[..eq].trim().to_string();
remaining = &remaining[eq + 1..];
let (value, rest) = if remaining.starts_with('"') {
let end = remaining[1..]
.find('"')
.map(|p| p + 1)
.unwrap_or(remaining.len());
(&remaining[1..end], &remaining[end + 1..])
} else if remaining.starts_with('\'') {
let end = remaining[1..]
.find('\'')
.map(|p| p + 1)
.unwrap_or(remaining.len());
(&remaining[1..end], &remaining[end + 1..])
} else {
let end = remaining
.find(|c: char| !c.is_alphanumeric() && c != '_' && c != '-')
.unwrap_or(remaining.len());
(&remaining[..end], &remaining[end..])
};
labels.insert(key, value.to_string());
remaining = rest.trim_start_matches(',').trim_start();
} else {
break;
}
}
labels
}

View File

@ -0,0 +1,210 @@
//! Stats store: receives expanded push payloads from all apps,
//! aggregates over time, computes derived statistics (p99 etc),
//! and provides JSON API for external consumption.
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::Serialize;
/// Per-app, per-instance aggregated stats entry.
#[derive(Debug, Clone, Default, Serialize)]
pub struct AppStats {
/// Last seen timestamp.
pub last_seen: i64,
/// Number of push samples received.
pub sample_count: u64,
// ── HTTP ─────────────────────────────────────────────────────
pub requests_total: u64,
pub request_duration_ms_total: u64,
pub requests_2xx: u64,
pub requests_4xx: u64,
pub requests_5xx: u64,
pub endpoints: HashMap<String, u64>,
// ── System ───────────────────────────────────────────────────
pub cpu_usage_percent: f32,
pub memory_used_mb: u64,
pub memory_total_mb: u64,
pub uptime_secs: u64,
// ── Business counters ────────────────────────────────────────
pub business: HashMap<String, f64>,
// ── Token usage ──────────────────────────────────────────────
pub ai_input_tokens_total: i64,
pub ai_output_tokens_total: i64,
pub ai_calls_total: i64,
pub ai_calls_success: i64,
pub ai_calls_failure: i64,
pub token_by_model: HashMap<String, ModelTokenStats>,
// ── Tasks ────────────────────────────────────────────────────
pub tasks_queued: i64,
pub tasks_running: i64,
pub tasks_completed: i64,
pub tasks_failed: i64,
// ── Latency ──────────────────────────────────────────────────
pub latency: HashMap<String, LatencyStats>,
// ── Logs ─────────────────────────────────────────────────────
#[serde(skip_serializing)]
pub logs: Vec<(i64, String)>,
}
#[derive(Debug, Clone, Default, Serialize)]
pub struct ModelTokenStats {
pub input_tokens: i64,
pub output_tokens: i64,
pub calls: i64,
}
#[derive(Debug, Clone, Default, Serialize)]
pub struct LatencyStats {
pub p50_ms: f64,
pub p90_ms: f64,
pub p99_ms: f64,
pub max_ms: f64,
pub count: u64,
}
/// The global stats store: app_name → AppStats.
pub type StatsStore = Arc<RwLock<HashMap<String, AppStats>>>;
/// Merge a new push payload into the stats store.
pub async fn merge_push_payload(
store: &StatsStore,
app: &str,
_instance: &str,
timestamp: i64,
http: Option<&observability::push::HttpPayload>,
system: Option<&observability::push::SystemPayload>,
business: &HashMap<String, f64>,
token_usage: Option<&observability::push::TokenUsagePayload>,
tasks: Option<&observability::push::TaskStatsPayload>,
latency: &HashMap<String, observability::push::LatencySnapshot>,
logs: &[observability::push::LogEntry],
) {
// Use app_name as key (merge across instances for aggregation)
let mut guard = store.write().await;
let entry = guard.entry(app.to_string()).or_default();
entry.last_seen = timestamp;
entry.sample_count += 1;
// HTTP — accumulate (not replace, so we get totals over time)
if let Some(http) = http {
entry.requests_total = http.requests_total;
entry.request_duration_ms_total = http.request_duration_ms_total;
entry.requests_2xx = http.requests_2xx;
entry.requests_4xx = http.requests_4xx;
entry.requests_5xx = http.requests_5xx;
for (ep, count) in &http.endpoints {
*entry.endpoints.entry(ep.clone()).or_insert(0) = *count;
}
}
// System — replace (current snapshot, not cumulative)
if let Some(sys) = system {
entry.cpu_usage_percent = sys.cpu_usage_percent;
entry.memory_used_mb = sys.memory_used_mb;
entry.memory_total_mb = sys.memory_total_mb;
entry.uptime_secs = sys.uptime_secs;
}
// Business — replace with latest snapshot
entry.business = business.clone();
// Token usage — replace with latest
if let Some(tu) = token_usage {
entry.ai_input_tokens_total = tu.ai_input_tokens_total;
entry.ai_output_tokens_total = tu.ai_output_tokens_total;
entry.ai_calls_total = tu.ai_calls_total;
entry.ai_calls_success = tu.ai_calls_success;
entry.ai_calls_failure = tu.ai_calls_failure;
for (model, usage) in &tu.by_model {
let ms = entry.token_by_model.entry(model.clone()).or_default();
ms.input_tokens = usage.input_tokens;
ms.output_tokens = usage.output_tokens;
ms.calls = usage.calls;
}
}
// Tasks — replace with latest
if let Some(t) = tasks {
entry.tasks_queued = t.queued;
entry.tasks_running = t.running;
entry.tasks_completed = t.completed;
entry.tasks_failed = t.failed;
}
// Latency — replace with latest snapshots
for (endpoint, snap) in latency {
let ls = entry.latency.entry(endpoint.clone()).or_default();
ls.p50_ms = snap.p50_ms;
ls.p90_ms = snap.p90_ms;
ls.p99_ms = snap.p99_ms;
ls.max_ms = snap.max_ms;
ls.count = snap.count;
}
// Logs — append (keep last 300 lines)
for log in logs {
entry.logs.push((log.timestamp, format!("[{}] {}", log.level.to_lowercase(), log.message)));
}
let cutoff = chrono::Utc::now().timestamp() - 300;
entry.logs.retain(|(ts, _)| *ts >= cutoff);
}
/// Dashboard response combining all apps' stats.
#[derive(Debug, Serialize)]
pub struct DashboardResponse {
/// Timestamp of this snapshot.
pub timestamp: i64,
/// Total number of app instances reporting.
pub app_count: u64,
/// Per-app aggregated stats.
pub apps: HashMap<String, AppStats>,
/// Derived: average p99 latency across all apps.
pub avg_p99_ms: f64,
/// Derived: total tokens consumed across all apps.
pub total_input_tokens: i64,
pub total_output_tokens: i64,
/// Derived: total AI calls across all apps.
pub total_ai_calls: i64,
}
/// Build the dashboard response from the stats store.
pub async fn build_dashboard(store: &StatsStore) -> DashboardResponse {
let guard = store.read().await;
let mut avg_p99 = 0.0;
let mut p99_count = 0;
let mut total_input = 0i64;
let mut total_output = 0i64;
let mut total_calls = 0i64;
for (_, stats) in guard.iter() {
total_input += stats.ai_input_tokens_total;
total_output += stats.ai_output_tokens_total;
total_calls += stats.ai_calls_total;
for (_, lat) in &stats.latency {
avg_p99 += lat.p99_ms;
p99_count += 1;
}
}
let avg_p99_ms = if p99_count > 0 { avg_p99 / p99_count as f64 } else { 0.0 };
DashboardResponse {
timestamp: chrono::Utc::now().timestamp(),
app_count: guard.len() as u64,
apps: guard.clone(),
avg_p99_ms,
total_input_tokens: total_input,
total_output_tokens: total_output,
total_ai_calls: total_calls,
}
}

View File

@ -0,0 +1,34 @@
use std::collections::HashMap;
use anyhow::Context;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ScrapeTarget {
pub name: String,
pub addr: String,
#[serde(default = "default_metrics_path")]
pub metrics_path: String,
#[serde(default)]
pub labels: HashMap<String, String>,
}
fn default_metrics_path() -> String {
"/metrics".to_string()
}
impl ScrapeTarget {
pub fn url(&self) -> String {
if self.metrics_path.starts_with("http") {
self.metrics_path.clone()
} else {
format!("http://{}{}", self.addr, self.metrics_path)
}
}
}
pub async fn load_targets_from_file(path: &str) -> anyhow::Result<Vec<ScrapeTarget>> {
let content = tokio::fs::read_to_string(path).await.context("read targets file")?;
let targets: Vec<ScrapeTarget> = serde_json::from_str(&content)
.with_context(|| format!("parse targets file {path}"))?;
Ok(targets)
}

View File

@ -7,6 +7,8 @@ edition.workspace = true
actix-web = { workspace = true }
actix-files = { workspace = true }
actix-cors = { workspace = true }
observability = { workspace = true }
metrics-exporter-prometheus = "0.13"
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
serde = { workspace = true }

View File

@ -5,8 +5,10 @@ use actix_web::{http::header, web, App, HttpResponse, HttpServer};
use futures::future::LocalBoxFuture;
use log::info;
use std::path::PathBuf;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use observability::{init_tracing_subscriber, install_recorder, HttpMetrics, push::MetricsPusher};
/// Static file server for avatar, blob, and other static files
/// Serves files from /data/{type} directories
@ -119,7 +121,16 @@ where
#[actix_web::main]
async fn main() -> anyhow::Result<()> {
env_logger::init_from_env(env_logger::Env::new().default_filter_or("info"));
init_tracing_subscriber("info", false);
let prometheus_handle = Arc::new(install_recorder());
let http_metrics = Arc::new(HttpMetrics::new());
// Metrics pusher: periodically push all metrics to apps/metrics aggregator
if let Some(push_url) = std::env::var("METRICS_PUSH_URL").ok() {
let pusher = MetricsPusher::new(&push_url, "static");
pusher.spawn(http_metrics.clone(), prometheus_handle.clone(), std::time::Duration::from_secs(15));
info!("Metrics pusher started (interval 15s, url: {})", push_url);
}
let cfg = StaticConfig::from_env();
let bind = std::env::var("STATIC_BIND").unwrap_or_else(|_| "0.0.0.0:8081".to_string());
@ -142,6 +153,8 @@ async fn main() -> anyhow::Result<()> {
let root = root.clone();
let cors = if cors_enabled {
// WARNING: allow_any_origin is intentional for static asset serving (CDN mode)
// Ensure no sensitive files are served from this directory
Cors::default()
.allow_any_origin()
.allowed_methods(vec!["GET", "HEAD", "OPTIONS"])

88
build.sh Normal file
View File

@ -0,0 +1,88 @@
#!/usr/bin/env bash
set -euo pipefail
# ── helpers ──────────────────────────────────────────────────────────
RED='\033[0;31m'; GREEN='\033[0;32m'; YELLOW='\033[1;33m'; NC='\033[0m'
log() { echo -e "${GREEN}[OK]${NC} $*"; }
warn() { echo -e "${YELLOW}[WARN]${NC} $*"; }
err() { echo -e "${RED}[ERR]${NC} $*"; exit 1; }
command_exists() { command -v "$1" &>/dev/null; }
# ── 1. Rust ─────────────────────────────────────────────────────────
if command_exists rustc; then
log "Rust $(rustc --version)"
else
warn "Rust not found, installing via rustup..."
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
# shellcheck disable=SC1091
source "$HOME/.cargo/env"
log "Rust installed: $(rustc --version)"
fi
# ── 2. Node.js ──────────────────────────────────────────────────────
if command_exists node; then
log "Node.js $(node --version)"
else
warn "Node.js not found, installing via nvm..."
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.1/install.sh | bash
# shellcheck disable=SC1090
export NVM_DIR="${HOME}/.nvm"
[ -s "$NVM_DIR/nvm.sh" ] && source "$NVM_DIR/nvm.sh"
nvm install --lts
log "Node.js installed: $(node --version)"
fi
# ── 2b. Bun ─────────────────────────────────────────────────────────
if command_exists bun; then
log "Bun $(bun --version)"
else
warn "Bun not found, installing..."
curl -fsSL https://bun.sh/install | bash
# shellcheck disable=SC1091
[ -s "$HOME/.bun/_bun" ] && export PATH="$HOME/.bun/bin:$PATH"
log "Bun installed: $(bun --version)"
fi
# ── 3. Docker ───────────────────────────────────────────────────────
if command_exists docker; then
log "Docker $(docker --version)"
else
warn "Docker not found, installing..."
curl -fsSL https://get.docker.com | sh
log "Docker installed: $(docker --version)"
fi
# ── 4. Frontend build ───────────────────────────────────────────────
log "Running bun install..."
bun install
log "Running bun run build..."
bun run build
# ── 5. Rust build ───────────────────────────────────────────────────
log "Running cargo build --release --workspace..."
cargo build --release --workspace
# ── 6. Docker images ────────────────────────────────────────────────
TAG=$(git rev-parse --short HEAD)
log "Building Docker images with tag: $TAG"
IMAGES=(
"docker/app.Dockerfile app:$TAG"
"docker/email.Dockerfile email-worker:$TAG"
"docker/githook.Dockerfile git-hook:$TAG"
"docker/gitserver.Dockerfile gitserver:$TAG"
"docker/metrics.Dockerfile metrics-aggregator:$TAG"
"docker/static.Dockerfile static-server:$TAG"
"docker/gingress.Dockerfile gingress:$TAG"
)
for entry in "${IMAGES[@]}"; do
read -r dockerfile tag <<< "$entry"
log "Building $tag..."
docker build -f "$dockerfile" -t "$tag" .
done
log "All images built successfully."
docker images | grep -E "app|email-worker|git-hook|gitserver|metrics-aggregator|static-server|gingress" | grep "$TAG" || true

1016
bun.lock

File diff suppressed because it is too large Load Diff

23
deploy/.helmignore Normal file
View File

@ -0,0 +1,23 @@
# Patterns to ignore when building packages.
# This supports shell glob matching, relative path matching, and
# negation (prefixed with !). Only one pattern per line.
.DS_Store
# Common VCS dirs
.git/
.gitignore
.bzr/
.bzrignore
.hg/
.hgignore
.svn/
# Common backup files
*.swp
*.bak
*.tmp
*.orig
*~
# Various IDEs
.project
.idea/
*.tmproj
.vscode/

6
deploy/Chart.yaml Normal file
View File

@ -0,0 +1,6 @@
apiVersion: v2
name: deploy
description: Helm chart for the project backend services
type: application
version: 0.1.0
appVersion: "0.2.9"

198
deploy/README.md Normal file
View File

@ -0,0 +1,198 @@
# Deploy Helm Chart
Monolithic Helm chart for all backend services.
## Services
| Service | Port(s) | Replicas | HPA | Purpose |
|---|---|---|---|---|
| `app` | 3000 (HTTP) | 2 | 210 | Main API server |
| `gitserver` | 8021 (HTTP), 2222 (SSH) | 1 | 15 | Git HTTP + SSH server |
| `email_worker` | 8084 (HTTP) | 1 | disabled | Email queue consumer (single instance only) |
| `git_hook` | 8083 (HTTP) | 1 | 15 | Git hook worker pool |
| `metrics_aggregator` | 9090 (HTTP) | 1 | 15 | Prometheus scrape + Loki push |
| `static_server` | 8081 (HTTP) | 1 | 15 | Static file server (avatars, blobs, media) |
## Prerequisites
The following resources must exist in the cluster **before** installing the Helm chart. They are not managed by Helm — install, upgrade, and uninstall of the chart will not touch them.
### 1. Namespace
```bash
kubectl create namespace app
```
### 2. PVC (aliyun-nfs, 200Ti, ReadWriteMany)
```bash
kubectl apply -f - <<'EOF'
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: shared-data
namespace: app
spec:
accessModes:
- ReadWriteMany
resources:
requests:
storage: 200Ti
storageClassName: aliyun-nfs
EOF
```
> The chart references this PVC by name. If you use a different name, pass `--set pvcName=your-pvc-name` to Helm.
### 3. ConfigMap
```bash
kubectl apply -f - <<'EOF'
apiVersion: v1
kind: ConfigMap
metadata:
name: app-env
namespace: app
data:
APP_REPOS_ROOT: "/data/repos"
APP_AVATAR_PATH: "/data/avatars"
STORAGE_PATH: "/data/files"
STATIC_ROOT: "/data"
APP_LOG_LEVEL: "info"
APP_COOKIE_SECURE: "false"
APP_DOMAIN_URL: "https://your-domain.com"
APP_DATABASE_URL: "postgres://user:pass@postgres:5432/app"
APP_REDIS_URL: "redis://redis:6379"
APP_AI_BASIC_URL: "https://api.openai.com/v1"
APP_AI_API_KEY: "sk-..."
APP_SMTP_PASSWORD: "..."
APP_SESSION_SECRET: "min-32-byte-random-string..."
APP_SSH_SERVER_PRIVATE_KEY: "<hex-encoded-private-key>"
EOF
```
| Variable | Default / Example | Required |
|---|---|---|
| `APP_REPOS_ROOT` | `/data/repos` | Yes |
| `APP_AVATAR_PATH` | `/data/avatars` | Yes |
| `STORAGE_PATH` | `/data/files` | Yes |
| `STATIC_ROOT` | `/data` | Yes |
| `APP_LOG_LEVEL` | `info` | No |
| `APP_COOKIE_SECURE` | `false` | No |
| `APP_DOMAIN_URL` | `https://your-domain.com` | Yes |
| `APP_DATABASE_URL` | `postgres://...` | **Yes** |
| `APP_REDIS_URL` | `redis://...` | **Yes** |
| `APP_AI_BASIC_URL` | `https://api.openai.com/v1` | **Yes** |
| `APP_AI_API_KEY` | `sk-...` | **Yes** |
| `APP_SMTP_PASSWORD` | `...` | **Yes** |
| `APP_SESSION_SECRET` | min 32 bytes | **Yes** |
| `APP_SSH_SERVER_PRIVATE_KEY` | hex-encoded PEM | **Yes** |
| `APP_SSH_PORT` | `2222` | Yes (k8s) |
> **SSH host key**: `APP_SSH_SERVER_PRIVATE_KEY` must be the hex-encoded Ed25519 private key PEM bytes.
> ```bash
> ssh-keygen -t ed25519 -f /tmp/ssh_host_key -N ""
> hexdump -v -e '/1 "%02x"' < /tmp/ssh_host_key
> ```
>
> **Session secret**: generate 48 random bytes:
> ```bash
> openssl rand -base64 48
> ```
>
> Override the ConfigMap name with `--set configMapName=your-cm-name`.
### 4. Verify prerequisites
```bash
kubectl get namespace app
kubectl get pvc -n app shared-data
kubectl get configmap -n app app-env
```
## Quick Start
```bash
helm template deploy ./deploy --namespace app --set imageRegistry=ghcr.io/your-org
helm lint ./deploy
# Install
helm upgrade --install deploy ./deploy \
--namespace app \
--set imageRegistry=ghcr.io/your-org \
--set imageTag=v0.2.9
```
## Storage
All services share a single PVC (`shared-data`) via `subPath` mounts:
| SubPath | Mount | Used By |
|---|---|---|
| `repos` | `/data/repos` | app, gitserver, git-hook |
| `avatars` | `/data/avatars` | app |
| `files` | `/data/files` | app |
| `static` | `/data` | static-server |
## Autoscaling
All services except `email_worker` have HPA enabled by default. The email worker is fixed at 1 replica and must not be scaled.
To adjust HPA bounds per service:
```bash
--set services.app.autoscaling.maxReplicas=20
--set services.app.autoscaling.targetCPUUtilization=70
```
To disable HPA for a service:
```bash
--set services.git_hook.autoscaling.enabled=false
```
## Ingress
```bash
helm upgrade --install deploy ./deploy \
--namespace app \
--set ingress.enabled=true \
--set ingress.className=nginx \
--set ingress.hosts[0].host=your-domain.com
```
## Dependencies
All services require these to be reachable from the cluster:
- PostgreSQL (via `APP_DATABASE_URL`)
- Redis (via `APP_REDIS_URL`)
- Git binary (included in all Docker images)
- OpenAI-compatible API (via `APP_AI_BASIC_URL` + `APP_AI_API_KEY`)
- Qdrant vector DB (via `APP_QDRANT_URL`)
- SMTP server (via `APP_SMTP_*`)
- Embedding model (via `APP_EMBED_MODEL_*`)
Optional dependencies with graceful degradation:
| Dependency | Variable | Fallback |
|---|---|---|
| NATS JetStream | `NATS_URL` + `NATS_TOKEN` | Redis queue |
| Loki | `LOKI_URL` | Logs discarded |
| OTEL Collector | `OTEL_EXPORTER_OTLP_ENDPOINT` | Tracing disabled |
## Production Example
```bash
helm upgrade --install deploy ./deploy \
--namespace app \
--set imageRegistry=ghcr.io/your-org \
--set imageTag=v0.2.9 \
--set services.app.replicas=3 \
--set services.app.autoscaling.maxReplicas=20 \
--set ingress.enabled=true \
--set ingress.className=nginx \
--set ingress.hosts[0].host=your-domain.com \
--set configMapName=app-env \
--set pvcName=shared-data
```

View File

@ -0,0 +1,93 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: gingress-controller
namespace: gingress-system
labels:
app: gingress
spec:
replicas: 2
selector:
matchLabels:
app: gingress
template:
metadata:
labels:
app: gingress
spec:
serviceAccountName: gingress-controller
containers:
- name: gingress
image: gingress:latest
imagePullPolicy: IfNotPresent
args:
- "--ingress-class=gingress"
- "--bind-http=0.0.0.0:80"
- "--bind-https=0.0.0.0:443"
- "--metrics-bind=0.0.0.0:8080"
ports:
- name: http
containerPort: 80
protocol: TCP
- name: https
containerPort: 443
protocol: TCP
- name: metrics
containerPort: 8080
protocol: TCP
env:
- name: RUST_LOG
value: "info"
- name: METRICS_PUSH_URL
value: "" # Optional: push to metrics aggregator
livenessProbe:
httpGet:
path: /healthz
port: 8080
initialDelaySeconds: 10
periodSeconds: 10
readinessProbe:
httpGet:
path: /readyz
port: 8080
initialDelaySeconds: 5
periodSeconds: 5
resources:
requests:
cpu: 100m
memory: 128Mi
limits:
cpu: 500m
memory: 512Mi
affinity:
podAntiAffinity:
preferredDuringSchedulingIgnoredDuringExecution:
- weight: 100
podAffinityTerm:
labelSelector:
matchLabels:
app: gingress
topologyKey: kubernetes.io/hostname
---
apiVersion: v1
kind: Service
metadata:
name: gingress
namespace: gingress-system
spec:
type: LoadBalancer
selector:
app: gingress
ports:
- name: http
port: 80
targetPort: 80
protocol: TCP
- name: https
port: 443
targetPort: 443
protocol: TCP
- name: metrics
port: 8080
targetPort: 8080
protocol: TCP

48
deploy/gingress/rbac.yaml Normal file
View File

@ -0,0 +1,48 @@
apiVersion: v1
kind: Namespace
metadata:
name: gingress-system
---
apiVersion: v1
kind: ServiceAccount
metadata:
name: gingress-controller
namespace: gingress-system
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
name: gingress-controller
rules:
- apiGroups: ["networking.k8s.io"]
resources: ["ingresses", "ingressclasses"]
verbs: ["get", "list", "watch"]
- apiGroups: ["networking.k8s.io"]
resources: ["ingresses/status"]
verbs: ["update", "patch"]
- apiGroups: [""]
resources: ["services", "endpoints", "endpointslices", "secrets", "nodes"]
verbs: ["get", "list", "watch"]
- apiGroups: ["discovery.k8s.io"]
resources: ["endpointslices"]
verbs: ["get", "list", "watch"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: gingress-controller
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: gingress-controller
subjects:
- kind: ServiceAccount
name: gingress-controller
namespace: gingress-system
---
apiVersion: networking.k8s.io/v1
kind: IngressClass
metadata:
name: gingress
spec:
controller: gingress.io/gingress-controller

View File

@ -0,0 +1,19 @@
Project backend services deployed to namespace: {{ .Release.Namespace }}
Services:
{{- range $svcKey, $svcVal := .Values.services }}
{{ $svcKey | replace "_" "-" }}: {{ if $svcVal.ports }}{{ range $portName, $portNum := $svcVal.ports }}{{ $portName }}={{ $portNum }} {{ end }}{{ else }}port={{ $svcVal.port }}{{ end }} {{ if $svcVal.autoscaling.enabled }}(HPA: {{ $svcVal.autoscaling.minReplicas }}-{{ $svcVal.autoscaling.maxReplicas }}){{ else }}(static: {{ $svcVal.replicaCount }}){{ end }}
{{- end }}
To access the app locally:
kubectl port-forward -n {{ .Release.Namespace }} svc/{{ include "deploy.serviceFullname" (dict "root" . "svcKey" "app") }} 3000:3000
To check HPA status:
{{- range $svcKey, $svcVal := .Values.services }}
{{- if $svcVal.autoscaling.enabled }}
kubectl get hpa -n {{ $.Release.Namespace }} {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" $svcKey) }}
{{- end }}
{{- end }}
To check all pods:
kubectl get pods -n {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "deploy.name" . }}"

View File

@ -0,0 +1,78 @@
{{/*
Expand the name of the chart.
*/}}
{{- define "deploy.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Create a default fully qualified app name.
*/}}
{{- define "deploy.fullname" -}}
{{- if .Values.fullnameOverride }}
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- $name := default .Chart.Name .Values.nameOverride }}
{{- if contains $name .Release.Name }}
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- end }}
{{- end }}
{{/*
Service fullname — includes service key for per-service resources.
Underscores in svcKey are replaced with hyphens for valid Kubernetes names.
*/}}
{{- define "deploy.serviceFullname" -}}
{{- printf "%s-%s" (include "deploy.fullname" .root) (.svcKey | replace "_" "-") | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Chart name and version as used by the chart label.
*/}}
{{- define "deploy.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Common labels
*/}}
{{- define "deploy.labels" -}}
helm.sh/chart: {{ include "deploy.chart" . }}
{{ include "deploy.selectorLabels" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }}
{{/*
Selector labels
*/}}
{{- define "deploy.selectorLabels" -}}
app.kubernetes.io/name: {{ include "deploy.name" . }}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end }}
{{/*
Per-service selector labels — used by Service to target the right Deployment.
Underscores in svcKey are replaced with hyphens for valid Kubernetes label values.
*/}}
{{- define "deploy.serviceSelectorLabels" -}}
app.kubernetes.io/name: {{ include "deploy.name" .root }}
app.kubernetes.io/instance: {{ .root.Release.Name }}
app.kubernetes.io/component: {{ .svcKey | replace "_" "-" }}
{{- end }}
{{/*
Create the name of the service account to use
*/}}
{{- define "deploy.serviceAccountName" -}}
{{- if .Values.serviceAccount.create }}
{{- default (include "deploy.fullname" .) .Values.serviceAccount.name }}
{{- else }}
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}

View File

@ -0,0 +1,89 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "app") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: app
spec:
replicas: {{ .Values.services.app.replicaCount | default 1 }}
selector:
matchLabels:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "app") | nindent 6 }}
template:
metadata:
labels:
{{- include "deploy.labels" . | nindent 8 }}
app.kubernetes.io/component: app
spec:
{{- with .Values.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
{{- with .Values.podSecurityContext }}
securityContext:
{{- toYaml . | nindent 8 }}
{{- end }}
containers:
- name: app
{{- with .Values.securityContext }}
securityContext:
{{- toYaml . | nindent 12 }}
{{- end }}
image: "{{ .Values.imageRegistry }}/{{ .Values.services.app.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
imagePullPolicy: IfNotPresent
{{- with .Values.services.app.command }}
command:
{{- toYaml . | nindent 12 }}
{{- end }}
ports:
- name: http
containerPort: {{ .Values.services.app.port }}
protocol: TCP
envFrom:
- configMapRef:
name: {{ .Values.configMapName }}
{{- with .Values.services.app.extraEnv }}
env:
{{- range $key, $val := . }}
- name: {{ $key }}
value: {{ $val | quote }}
{{- end }}
{{- end }}
livenessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 10
periodSeconds: 15
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 5
periodSeconds: 10
{{- with .Values.services.app.resources }}
resources:
{{- toYaml . | nindent 12 }}
{{- end }}
{{- with .Values.services.app.volumeMounts }}
volumeMounts:
{{- toYaml . | nindent 12 }}
{{- end }}
volumes:
- name: shared-data
persistentVolumeClaim:
claimName: {{ .Values.pvcName }}
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}

View File

@ -0,0 +1,16 @@
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "app") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: app
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.app.port }}
targetPort: http
protocol: TCP
name: http
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "app") | nindent 4 }}

View File

@ -0,0 +1,70 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "email_worker") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: email-worker
spec:
replicas: {{ .Values.services.email_worker.replicaCount | default 1 }}
selector:
matchLabels:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "email_worker") | nindent 6 }}
template:
metadata:
labels:
{{- include "deploy.labels" . | nindent 8 }}
app.kubernetes.io/component: email-worker
spec:
{{- with .Values.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
{{- with .Values.podSecurityContext }}
securityContext:
{{- toYaml . | nindent 8 }}
{{- end }}
containers:
- name: email_worker
{{- with .Values.securityContext }}
securityContext:
{{- toYaml . | nindent 12 }}
{{- end }}
image: "{{ .Values.imageRegistry }}/{{ .Values.services.email_worker.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
imagePullPolicy: IfNotPresent
ports:
- name: http
containerPort: {{ .Values.services.email_worker.port }}
protocol: TCP
envFrom:
- configMapRef:
name: {{ .Values.configMapName }}
livenessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 10
periodSeconds: 15
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 5
periodSeconds: 10
{{- with .Values.services.email_worker.resources }}
resources:
{{- toYaml . | nindent 12 }}
{{- end }}
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}

View File

@ -0,0 +1,16 @@
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "email_worker") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: email-worker
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.email_worker.port }}
targetPort: http
protocol: TCP
name: http
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "email_worker") | nindent 4 }}

View File

@ -0,0 +1,78 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "git_hook") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: git-hook
spec:
replicas: {{ .Values.services.git_hook.replicaCount | default 1 }}
selector:
matchLabels:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "git_hook") | nindent 6 }}
template:
metadata:
labels:
{{- include "deploy.labels" . | nindent 8 }}
app.kubernetes.io/component: git-hook
spec:
{{- with .Values.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
{{- with .Values.podSecurityContext }}
securityContext:
{{- toYaml . | nindent 8 }}
{{- end }}
containers:
- name: git-hook
{{- with .Values.securityContext }}
securityContext:
{{- toYaml . | nindent 12 }}
{{- end }}
image: "{{ .Values.imageRegistry }}/{{ .Values.services.git_hook.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
imagePullPolicy: IfNotPresent
ports:
- name: http
containerPort: {{ .Values.services.git_hook.port }}
protocol: TCP
envFrom:
- configMapRef:
name: {{ .Values.configMapName }}
livenessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 10
periodSeconds: 15
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 5
periodSeconds: 10
{{- with .Values.services.git_hook.resources }}
resources:
{{- toYaml . | nindent 12 }}
{{- end }}
{{- with .Values.services.git_hook.volumeMounts }}
volumeMounts:
{{- toYaml . | nindent 12 }}
{{- end }}
volumes:
- name: shared-data
persistentVolumeClaim:
claimName: {{ .Values.pvcName }}
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}

View File

@ -0,0 +1,16 @@
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "git_hook") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: git-hook
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.git_hook.port }}
targetPort: http
protocol: TCP
name: http
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "git_hook") | nindent 4 }}

View File

@ -0,0 +1,88 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "gitserver") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: gitserver
spec:
replicas: {{ .Values.services.gitserver.replicaCount | default 1 }}
selector:
matchLabels:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "gitserver") | nindent 6 }}
template:
metadata:
labels:
{{- include "deploy.labels" . | nindent 8 }}
app.kubernetes.io/component: gitserver
spec:
{{- with .Values.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
{{- with .Values.podSecurityContext }}
securityContext:
{{- toYaml . | nindent 8 }}
{{- end }}
containers:
- name: gitserver
{{- with .Values.securityContext }}
securityContext:
{{- toYaml . | nindent 12 }}
{{- end }}
image: "{{ .Values.imageRegistry }}/{{ .Values.services.gitserver.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
imagePullPolicy: IfNotPresent
ports:
- name: http
containerPort: {{ .Values.services.gitserver.ports.http }}
protocol: TCP
- name: ssh
containerPort: {{ .Values.services.gitserver.ports.ssh }}
protocol: TCP
envFrom:
- configMapRef:
name: {{ .Values.configMapName }}
{{- with .Values.services.gitserver.extraEnv }}
env:
{{- range $key, $val := . }}
- name: {{ $key }}
value: {{ $val | quote }}
{{- end }}
{{- end }}
livenessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 10
periodSeconds: 15
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 5
periodSeconds: 10
{{- with .Values.services.gitserver.resources }}
resources:
{{- toYaml . | nindent 12 }}
{{- end }}
{{- with .Values.services.gitserver.volumeMounts }}
volumeMounts:
{{- toYaml . | nindent 12 }}
{{- end }}
volumes:
- name: shared-data
persistentVolumeClaim:
claimName: {{ .Values.pvcName }}
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}

View File

@ -0,0 +1,20 @@
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "gitserver") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: gitserver
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.gitserver.ports.http }}
targetPort: http
protocol: TCP
name: http
- port: {{ .Values.services.gitserver.ports.ssh }}
targetPort: ssh
protocol: TCP
name: ssh
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "gitserver") | nindent 4 }}

26
deploy/templates/hpa.yaml Normal file
View File

@ -0,0 +1,26 @@
{{- range $svcKey, $svcVal := .Values.services }}
{{- if $svcVal.autoscaling.enabled }}
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" $svcKey) }}
labels:
{{- include "deploy.labels" $ | nindent 4 }}
app.kubernetes.io/component: {{ $svcKey | replace "_" "-" }}
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" $svcKey) }}
minReplicas: {{ $svcVal.autoscaling.minReplicas }}
maxReplicas: {{ $svcVal.autoscaling.maxReplicas }}
metrics:
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: {{ $svcVal.autoscaling.targetCPUUtilization }}
{{- end }}
{{- end }}

View File

@ -0,0 +1,41 @@
{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ include "deploy.fullname" . }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
{{- with .Values.ingress.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
{{- with .Values.ingress.className }}
ingressClassName: {{ . }}
{{- end }}
{{- if .Values.ingress.tls }}
tls:
{{- range .Values.ingress.tls }}
- hosts:
{{- range .hosts }}
- {{ . | quote }}
{{- end }}
secretName: {{ .secretName }}
{{- end }}
{{- end }}
rules:
{{- range .Values.ingress.hosts }}
- host: {{ .host | quote }}
http:
paths:
{{- range .paths }}
- path: {{ .path }}
pathType: {{ .pathType }}
backend:
service:
name: {{ include "deploy.serviceFullname" (dict "root" $ "svcKey" .serviceName) }}
port:
number: {{ .servicePort }}
{{- end }}
{{- end }}
{{- end }}

View File

@ -0,0 +1,70 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "metrics_aggregator") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: metrics-aggregator
spec:
replicas: {{ .Values.services.metrics_aggregator.replicaCount | default 1 }}
selector:
matchLabels:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "metrics_aggregator") | nindent 6 }}
template:
metadata:
labels:
{{- include "deploy.labels" . | nindent 8 }}
app.kubernetes.io/component: metrics-aggregator
spec:
{{- with .Values.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
{{- with .Values.podSecurityContext }}
securityContext:
{{- toYaml . | nindent 8 }}
{{- end }}
containers:
- name: metrics_aggregator
{{- with .Values.securityContext }}
securityContext:
{{- toYaml . | nindent 12 }}
{{- end }}
image: "{{ .Values.imageRegistry }}/{{ .Values.services.metrics_aggregator.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
imagePullPolicy: IfNotPresent
ports:
- name: http
containerPort: {{ .Values.services.metrics_aggregator.port }}
protocol: TCP
envFrom:
- configMapRef:
name: {{ .Values.configMapName }}
livenessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 10
periodSeconds: 15
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 5
periodSeconds: 10
{{- with .Values.services.metrics_aggregator.resources }}
resources:
{{- toYaml . | nindent 12 }}
{{- end }}
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}

View File

@ -0,0 +1,16 @@
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "metrics_aggregator") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: metrics-aggregator
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.metrics_aggregator.port }}
targetPort: http
protocol: TCP
name: http
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "metrics_aggregator") | nindent 4 }}

View File

@ -0,0 +1 @@
{{/* Secret disabled — all config via ConfigMap */}}

View File

@ -0,0 +1,13 @@
{{- if .Values.serviceAccount.create -}}
apiVersion: v1
kind: ServiceAccount
metadata:
name: {{ include "deploy.serviceAccountName" . }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
{{- with .Values.serviceAccount.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
automountServiceAccountToken: {{ .Values.serviceAccount.automount }}
{{- end }}

View File

@ -0,0 +1,78 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "static_server") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: static-server
spec:
replicas: {{ .Values.services.static_server.replicaCount | default 1 }}
selector:
matchLabels:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "static_server") | nindent 6 }}
template:
metadata:
labels:
{{- include "deploy.labels" . | nindent 8 }}
app.kubernetes.io/component: static-server
spec:
{{- with .Values.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "deploy.serviceAccountName" . }}
{{- with .Values.podSecurityContext }}
securityContext:
{{- toYaml . | nindent 8 }}
{{- end }}
containers:
- name: static_server
{{- with .Values.securityContext }}
securityContext:
{{- toYaml . | nindent 12 }}
{{- end }}
image: "{{ .Values.imageRegistry }}/{{ .Values.services.static_server.repository }}:{{ .Values.imageTag | default .Chart.AppVersion }}"
imagePullPolicy: IfNotPresent
ports:
- name: http
containerPort: {{ .Values.services.static_server.port }}
protocol: TCP
envFrom:
- configMapRef:
name: {{ .Values.configMapName }}
livenessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 10
periodSeconds: 15
readinessProbe:
httpGet:
path: /health
port: http
initialDelaySeconds: 5
periodSeconds: 10
{{- with .Values.services.static_server.resources }}
resources:
{{- toYaml . | nindent 12 }}
{{- end }}
{{- with .Values.services.static_server.volumeMounts }}
volumeMounts:
{{- toYaml . | nindent 12 }}
{{- end }}
volumes:
- name: shared-data
persistentVolumeClaim:
claimName: {{ .Values.pvcName }}
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}

View File

@ -0,0 +1,16 @@
apiVersion: v1
kind: Service
metadata:
name: {{ include "deploy.serviceFullname" (dict "root" . "svcKey" "static_server") }}
labels:
{{- include "deploy.labels" . | nindent 4 }}
app.kubernetes.io/component: static-server
spec:
type: ClusterIP
ports:
- port: {{ .Values.services.static_server.port }}
targetPort: http
protocol: TCP
name: http
selector:
{{- include "deploy.serviceSelectorLabels" (dict "root" . "svcKey" "static_server") | nindent 4 }}

182
deploy/values.yaml Normal file
View File

@ -0,0 +1,182 @@
# Global image registry and tag
imageRegistry: ""
imageTag: ""
# External ConfigMap (managed outside Helm)
configMapName: "app-env"
# Service definitions
services:
app:
repository: app
port: 3000
replicaCount: 2
autoscaling:
enabled: true
minReplicas: 2
maxReplicas: 10
targetCPUUtilization: 80
command:
- "app"
- "--bind"
- "0.0.0.0:3000"
resources:
requests:
cpu: 200m
memory: 256Mi
limits:
cpu: "1"
memory: 512Mi
volumeMounts:
- name: shared-data
mountPath: /data/repos
subPath: repos
- name: shared-data
mountPath: /data/avatars
subPath: avatars
- name: shared-data
mountPath: /data/files
subPath: files
email_worker:
repository: email-worker
port: 8084
replicaCount: 1
autoscaling:
enabled: false # email must stay at 1 replica
resources:
requests:
cpu: 100m
memory: 128Mi
limits:
cpu: 500m
memory: 256Mi
git_hook:
repository: git-hook
port: 8083
replicaCount: 1
autoscaling:
enabled: true
minReplicas: 1
maxReplicas: 5
targetCPUUtilization: 80
resources:
requests:
cpu: 100m
memory: 128Mi
limits:
cpu: 500m
memory: 256Mi
volumeMounts:
- name: shared-data
mountPath: /data/repos
subPath: repos
gitserver:
repository: gitserver
ports:
http: 8021
ssh: 2222
replicaCount: 1
autoscaling:
enabled: true
minReplicas: 1
maxReplicas: 5
targetCPUUtilization: 80
# SSH port must match the containerPort
extraEnv:
APP_SSH_PORT: "2222"
resources:
requests:
cpu: 100m
memory: 128Mi
limits:
cpu: 500m
memory: 256Mi
volumeMounts:
- name: shared-data
mountPath: /data/repos
subPath: repos
metrics_aggregator:
repository: metrics-aggregator
port: 9090
replicaCount: 1
autoscaling:
enabled: true
minReplicas: 1
maxReplicas: 5
targetCPUUtilization: 80
resources:
requests:
cpu: 100m
memory: 128Mi
limits:
cpu: 500m
memory: 256Mi
static_server:
repository: static-server
port: 8081
replicaCount: 1
autoscaling:
enabled: true
minReplicas: 1
maxReplicas: 5
targetCPUUtilization: 80
resources:
requests:
cpu: 50m
memory: 64Mi
limits:
cpu: 200m
memory: 128Mi
volumeMounts:
- name: shared-data
mountPath: /data
subPath: static
# External PVC (managed outside Helm — not deleted on uninstall)
pvcName: "shared-data"
# Ingress — only for the main app service
ingress:
enabled: false
className: ""
annotations: {}
hosts:
- host: chart-example.local
paths:
- path: /
pathType: Prefix
serviceName: app
servicePort: 3000
tls: []
# - secretName: chart-example-tls
# hosts:
# - chart-example.local
imagePullSecrets: []
nameOverride: ""
fullnameOverride: ""
serviceAccount:
create: true
automount: true
annotations: {}
name: ""
podSecurityContext:
runAsNonRoot: true
runAsUser: 1000
securityContext:
capabilities:
drop:
- ALL
readOnlyRootFilesystem: true
nodeSelector: {}
tolerations: []
affinity: {}

10
docker/app.Dockerfile Normal file
View File

@ -0,0 +1,10 @@
FROM ubuntu:24.04
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 openssh-client procps git \
&& rm -rf /var/lib/apt/lists/*
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/app /bin
USER appuser
EXPOSE 3000
CMD ["app"]

10
docker/email.Dockerfile Normal file
View File

@ -0,0 +1,10 @@
FROM ubuntu:24.04
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 \
&& rm -rf /var/lib/apt/lists/*
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/email-worker /bin
USER appuser
EXPOSE 8084
CMD ["email-worker"]

View File

@ -0,0 +1,10 @@
FROM ubuntu:24.04
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 \
&& rm -rf /var/lib/apt/lists/*
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/gingress /bin
USER appuser
EXPOSE 80 443 8080
CMD ["gingress"]

10
docker/githook.Dockerfile Normal file
View File

@ -0,0 +1,10 @@
FROM ubuntu:24.04
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 git \
&& rm -rf /var/lib/apt/lists/*
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/git-hook /bin
USER appuser
EXPOSE 8083
CMD ["git-hook"]

View File

@ -0,0 +1,10 @@
FROM ubuntu:24.04
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 git openssh-client \
&& rm -rf /var/lib/apt/lists/*
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/gitserver /bin
USER appuser
EXPOSE 8021 2222
CMD ["gitserver"]

10
docker/metrics.Dockerfile Normal file
View File

@ -0,0 +1,10 @@
FROM ubuntu:24.04
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 \
&& rm -rf /var/lib/apt/lists/*
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/metrics-aggregator /bin
USER appuser
EXPOSE 9090
CMD ["metrics-aggregator"]

10
docker/static.Dockerfile Normal file
View File

@ -0,0 +1,10 @@
FROM ubuntu:24.04
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates libssl3 \
&& rm -rf /var/lib/apt/lists/*
RUN useradd --system --create-home appuser
WORKDIR /home/appuser
COPY target/release/static-server /bin
USER appuser
EXPOSE 8081
CMD ["static-server"]

View File

@ -42,5 +42,6 @@ reqwest = { workspace = true, features = ["json"] }
utoipa = { workspace = true }
tokio-stream = { workspace = true }
redis = { workspace = true, features = ["tokio-comp"] }
queue = { workspace = true }
[lints]
workspace = true

View File

@ -152,7 +152,8 @@ impl RigAgentService {
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
let cleaned = text.text.replace('\n', "");
let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await;
final_content.push_str(&text.text);
}
Ok(MultiTurnStreamItem::StreamAssistantItem(
@ -237,7 +238,8 @@ impl RigAgentService {
Ok(MultiTurnStreamItem::StreamAssistantItem(
StreamedAssistantContent::Text(text),
)) => {
let _ = tx.send(Ok(StreamChunk::Text(text.text.clone()))).await;
let cleaned = text.text.replace('\n', "");
let _ = tx.send(Ok(StreamChunk::Text(cleaned))).await;
final_content.push_str(&text.text);
}
Ok(MultiTurnStreamItem::FinalResponse(resp)) => {

View File

@ -1,226 +1,455 @@
//! AI usage billing — records token costs against a project or workspace balance.
//! Billing service — handles user-level and project-level billing, deduction,
//! credit initialization, and error persistence.
//!
//! All functions take `&DatabaseConnection` instead of `&AppService`.
//! Architecture:
//! - Each user gets $10 personal balance on signup.
//! - Each project gets $20 balance only if it's the creator's first project,
//! $0 otherwise.
//! - AI usage is deducted from the project balance first; if insufficient,
//! falls through to the user's personal balance.
//! - Monthly quota only applies to pro users (is_pro = true).
//! - If both project and user balance are insufficient, a billing_error
//! record is persisted and an error is returned to the caller.
use db::database::AppDatabase;
use models::agents::model_pricing;
use models::projects::project;
use models::projects::project_billing;
use models::projects::project_billing_history;
use models::workspaces::workspace_billing;
use models::workspaces::workspace_billing_history;
use models::ai::billing_error;
use models::projects::{project, project_billing, project_billing_history};
use models::users::user_billing;
use rust_decimal::Decimal;
use sea_orm::*;
use uuid::Uuid;
use crate::error::AgentError;
// ── Constants ──
fn default_user_balance() -> Decimal { Decimal::new(100_000, 4) } // $10.0000
fn first_project_credit() -> Decimal { Decimal::new(200_000, 4) } // $20.0000
const SUBSEQUENT_PROJECT_BALANCE: Decimal = Decimal::ZERO;
// ── Types ──
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, utoipa::ToSchema)]
pub struct BillingRecord {
pub cost: f64,
pub currency: String,
pub input_tokens: i64,
pub output_tokens: i64,
pub deducted_from: String, // "project" or "user"
}
/// Extended result that includes insufficient balance flag for system message creation.
#[derive(Debug)]
pub enum BillingResult {
Success(BillingRecord),
InsufficientBalance { message: String },
}
/// Record AI usage for a project with cascading billing.
// ── Core deduction: AI usage ──
/// Record AI usage: deduct from project balance first, fall through to user balance.
///
/// Billing strategy:
/// 1. Try to deduct from project balance first
/// 2. If insufficient, fallback to workspace balance (if project belongs to workspace)
/// 3. If both insufficient or no workspace, return InsufficientBalance error with room_id
///
/// Returns BillingError::InsufficientBalance with room_id for system message creation.
/// Returns `InsufficientBalance` if neither account can cover the cost.
/// On insufficient balance, a `billing_error` record is persisted for frontend display.
pub async fn record_ai_usage(
db: &AppDatabase,
project_uid: Uuid,
user_uid: Uuid,
model_id: Uuid,
input_tokens: i64,
output_tokens: i64,
) -> Result<BillingResult, AgentError> {
// 1. Look up the active price for this model.
let total_cost = compute_cost(db, model_id, input_tokens, output_tokens).await?;
let currency = get_currency(db, model_id).await?;
// Verify project exists
let _ = project::Entity::find_by_id(project_uid)
.one(db)
.await?
.ok_or_else(|| AgentError::Internal("Project not found".into()))?;
// Attempt project-level deduction first
let project_result = deduct_from_project(db, project_uid, total_cost, &currency, model_id, input_tokens, output_tokens).await;
match project_result {
Ok(()) => {
let cost_f64 = decimal_to_f64(total_cost);
tracing::info!(
project_id = %project_uid,
model_id = %model_id,
input_tokens, output_tokens,
cost = %cost_f64,
currency = %currency,
deducted_from = "project",
"ai_usage_recorded"
);
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
deducted_from: "project".to_string(),
}))
}
Err(_) => {
// Project balance insufficient — try user personal balance
let user_result = deduct_from_user(db, user_uid, total_cost, &currency, project_uid, model_id, input_tokens, output_tokens).await;
match user_result {
Ok(()) => {
let cost_f64 = decimal_to_f64(total_cost);
tracing::info!(
user_id = %user_uid,
project_id = %project_uid,
model_id = %model_id,
input_tokens, output_tokens,
cost = %cost_f64,
currency = %currency,
deducted_from = "user",
"ai_usage_recorded"
);
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
deducted_from: "user".to_string(),
}))
}
Err(insufficient_msg) => {
// Both project and user balance insufficient — persist error
persist_billing_error(
db,
"project",
project_uid,
"insufficient_balance",
&insufficient_msg,
Some(serde_json::json!({
"user_id": user_uid.to_string(),
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"cost": decimal_to_f64(total_cost),
"currency": currency,
})),
).await?;
Ok(BillingResult::InsufficientBalance {
message: insufficient_msg,
})
}
}
}
}
}
/// Check whether a project + user has sufficient combined balance for a potential AI call.
/// Called before starting AI processing to avoid wasted compute.
pub async fn check_balance(
db: &AppDatabase,
project_uid: Uuid,
user_uid: Uuid,
model_id: Uuid,
estimated_input_tokens: i64,
estimated_output_tokens: i64,
) -> Result<bool, AgentError> {
let estimated_cost = compute_cost(db, model_id, estimated_input_tokens, estimated_output_tokens).await?;
let project_balance = get_project_balance(db, project_uid).await;
let user_balance = get_user_balance(db, user_uid).await;
Ok(project_balance + user_balance >= estimated_cost)
}
// ── Initialization ──
/// Initialize a user billing account with the default $10 balance.
/// Called on user signup / first login.
pub async fn initialize_user_billing(db: &AppDatabase, user_uid: Uuid) -> Result<(), AgentError> {
let now = chrono::Utc::now();
user_billing::ActiveModel {
user: Set(user_uid),
balance: Set(default_user_balance()),
currency: Set("USD".to_string()),
is_pro: Set(false),
monthly_quota: Set(Decimal::ZERO),
month_used: Set(Decimal::ZERO),
cycle_start: Set(None),
cycle_end: Set(None),
updated_at: Set(now),
created_at: Set(now),
}
.insert(db)
.await
.map_err(|e| AgentError::Internal(format!("failed to create user billing: {}", e)))?;
tracing::info!(user_id = %user_uid, balance = "$10", "user_billing_initialized");
Ok(())
}
/// Initialize a project billing account.
/// Grants $20 only if this is the creator's first project; $0 otherwise.
pub async fn initialize_project_billing(
db: &AppDatabase,
project_uid: Uuid,
creator_uid: Uuid,
) -> Result<(), AgentError> {
// Check how many projects this user has already created
let existing_count = project::Entity::find()
.filter(project::Column::CreatedBy.eq(creator_uid))
.count(db)
.await
.map_err(|e| AgentError::Internal(format!("failed to count user projects: {}", e)))?;
let is_first = existing_count == 0;
let initial_balance = if is_first { first_project_credit() } else { SUBSEQUENT_PROJECT_BALANCE };
let now = chrono::Utc::now();
project_billing::ActiveModel {
project: Set(project_uid),
balance: Set(initial_balance),
currency: Set("USD".to_string()),
user: Set(Some(creator_uid)),
initial_credit_granted: Set(is_first),
is_pro: Set(false),
monthly_quota: Set(Decimal::ZERO),
month_used: Set(Decimal::ZERO),
cycle_start: Set(None),
cycle_end: Set(None),
updated_at: Set(now),
created_at: Set(now),
}
.insert(db)
.await
.map_err(|e| AgentError::Internal(format!("failed to create project billing: {}", e)))?;
if is_first {
// Record the credit in billing history
project_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
project: Set(project_uid),
user: Set(Some(creator_uid)),
amount: Set(first_project_credit()),
currency: Set("USD".to_string()),
reason: Set("first_project_credit".to_string()),
extra: Set(Some(serde_json::json!({
"is_first_project": true,
}))),
created_at: Set(now),
..Default::default()
}
.insert(db)
.await
.map_err(|e| AgentError::Internal(format!("failed to record credit history: {}", e)))?;
}
tracing::info!(
project_id = %project_uid,
creator_id = %creator_uid,
is_first_project = is_first,
balance = if is_first { "$20" } else { "$0" },
"project_billing_initialized"
);
Ok(())
}
// ── Internal helpers ──
async fn compute_cost(
db: &AppDatabase,
model_id: Uuid,
input_tokens: i64,
output_tokens: i64,
) -> Result<Decimal, AgentError> {
let pricing = model_pricing::Entity::find()
.filter(model_pricing::Column::ModelVersionId.eq(model_id))
.order_by_desc(model_pricing::Column::EffectiveFrom)
.one(db)
.await?
.ok_or_else(|| {
AgentError::Internal(
"No pricing record found for this model. Please configure AI model pricing first."
.into(),
)
})?;
.ok_or_else(|| AgentError::Internal(
"No pricing record found for this model. Please configure AI model pricing first.".into(),
))?;
let input_price: Decimal = pricing.input_price_per_1k_tokens.parse()
.map_err(|e| AgentError::Internal(format!("Invalid input price: {}", e)))?;
let output_price: Decimal = pricing.output_price_per_1k_tokens.parse()
.map_err(|e| AgentError::Internal(format!("Invalid output price: {}", e)))?;
// 2. Compute cost using Decimal arithmetic.
let input_price: Decimal = pricing
.input_price_per_1k_tokens
.parse()
.map_err(|e| AgentError::Internal(format!("Invalid input price format: {}", e)))?;
let output_price: Decimal = pricing
.output_price_per_1k_tokens
.parse()
.map_err(|e| AgentError::Internal(format!("Invalid output price format: {}", e)))?;
let tokens_i = Decimal::from(input_tokens);
let tokens_o = Decimal::from(output_tokens);
let thousand = Decimal::from(1000);
Ok((Decimal::from(input_tokens) / thousand) * input_price
+ (Decimal::from(output_tokens) / thousand) * output_price)
}
let total_cost = (tokens_i / thousand) * input_price
+ (tokens_o / thousand) * output_price;
let currency = pricing.currency.clone();
// 3. Cascading billing: project balance first, then workspace if insufficient.
let proj = project::Entity::find_by_id(project_uid)
async fn get_currency(db: &AppDatabase, model_id: Uuid) -> Result<String, AgentError> {
let pricing = model_pricing::Entity::find()
.filter(model_pricing::Column::ModelVersionId.eq(model_id))
.one(db)
.await?
.ok_or_else(|| AgentError::Internal("Project not found".into()))?;
.ok_or_else(|| AgentError::Internal("No pricing found".into()))?;
Ok(pricing.currency.clone())
}
let txn = db.begin().await?;
async fn get_project_balance(db: &AppDatabase, project_uid: Uuid) -> Decimal {
project_billing::Entity::find_by_id(project_uid)
.one(db)
.await
.ok()
.flatten()
.map(|b| b.balance)
.unwrap_or(Decimal::ZERO)
}
// Always check project balance first
let project_billing = project_billing::Entity::find_by_id(project_uid)
async fn get_user_balance(db: &AppDatabase, user_uid: Uuid) -> Decimal {
user_billing::Entity::find_by_id(user_uid)
.one(db)
.await
.ok()
.flatten()
.map(|b| b.balance)
.unwrap_or(Decimal::ZERO)
}
async fn deduct_from_project(
db: &AppDatabase,
project_uid: Uuid,
cost: Decimal,
currency: &str,
model_id: Uuid,
input_tokens: i64,
output_tokens: i64,
) -> Result<(), String> {
let txn = db.begin().await.map_err(|e| format!("db txn error: {}", e))?;
let billing = project_billing::Entity::find_by_id(project_uid)
.lock_exclusive()
.one(&txn)
.await?
.ok_or_else(|| AgentError::Internal("Project billing account not found".into()))?;
.await
.map_err(|e| format!("db error: {}", e))?
.ok_or_else(|| "Project billing account not found".to_string())?;
if billing.balance < cost {
txn.rollback().await.ok();
return Err(format!(
"Project balance insufficient. Required: {:.4} {}, Available: {:.4} {}",
cost, currency, billing.balance, currency
));
}
let now = chrono::Utc::now();
if project_billing.balance >= total_cost {
// ── Project has sufficient balance ──────────────────────────
let amount_dec = -total_cost;
project_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
project: Set(project_uid),
user: Set(None),
amount: Set(amount_dec),
currency: Set(currency.clone()),
reason: Set("ai_usage".to_string()),
extra: Set(Some(serde_json::json!({
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
}))),
created_at: Set(now),
..Default::default()
}
.insert(&txn)
.await?;
let new_balance = project_billing.balance - total_cost;
let mut updated: project_billing::ActiveModel = project_billing.into();
updated.balance = Set(new_balance);
updated.updated_at = Set(now);
updated.update(&txn).await?;
txn.commit().await?;
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
tracing::info!(
project_id = %project_uid,
model_id = %model_id,
input_tokens = input_tokens,
output_tokens = output_tokens,
cost = %cost_f64,
currency = %currency,
source = "project",
"ai_usage_recorded"
);
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
}))
} else if let Some(workspace_id) = proj.workspace_id {
// ── Project insufficient, fallback to workspace ─────────────
let workspace_billing = workspace_billing::Entity::find_by_id(workspace_id)
.lock_exclusive()
.one(&txn)
.await?
.ok_or_else(|| AgentError::Internal("Workspace billing account not found".into()))?;
if workspace_billing.balance < total_cost {
txn.rollback().await?;
return Ok(BillingResult::InsufficientBalance {
message: format!(
"Insufficient balance. Project: {:.4} {}, Workspace: {:.4} {}, Required: {:.4} {}",
project_billing.balance, currency,
workspace_billing.balance, currency,
total_cost, currency
),
});
}
let amount_dec = -total_cost;
workspace_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
workspace_id: Set(workspace_id),
user_id: Set(Some(proj.created_by)),
amount: Set(amount_dec),
currency: Set(currency.clone()),
reason: Set(format!("ai_usage:{}", project_uid)),
extra: Set(Some(serde_json::json!({
"project_id": project_uid.to_string(),
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"fallback_reason": "project_balance_insufficient"
}))),
created_at: Set(now),
}
.insert(&txn)
.await?;
let new_balance = workspace_billing.balance - total_cost;
let new_total_spent = workspace_billing.total_spent + total_cost;
let mut updated: workspace_billing::ActiveModel = workspace_billing.into();
updated.balance = Set(new_balance);
updated.total_spent = Set(new_total_spent);
updated.updated_at = Set(now);
updated.update(&txn).await?;
txn.commit().await?;
let cost_f64 = total_cost.to_string().parse().unwrap_or(0.0);
tracing::info!(
project_id = %project_uid,
model_id = %model_id,
input_tokens = input_tokens,
output_tokens = output_tokens,
cost = %cost_f64,
currency = %currency,
workspace_id = %workspace_id.to_string(),
source = "workspace_fallback",
"ai_usage_recorded"
);
Ok(BillingResult::Success(BillingRecord {
cost: cost_f64,
currency,
input_tokens,
output_tokens,
}))
} else {
// ── Project insufficient and no workspace ───────────────────
txn.rollback().await?;
Ok(BillingResult::InsufficientBalance {
message: format!(
"Insufficient balance. Required: {:.4} {}, Available: {:.4} {}",
total_cost, currency, project_billing.balance, currency
),
})
project_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
project: Set(project_uid),
user: Set(None),
amount: Set(-cost),
currency: Set(currency.to_string()),
reason: Set("ai_usage".to_string()),
extra: Set(Some(serde_json::json!({
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"deducted_from": "project",
}))),
created_at: Set(now),
..Default::default()
}
.insert(&txn)
.await
.map_err(|e| format!("failed to insert history: {}", e))?;
let mut updated: project_billing::ActiveModel = billing.into();
updated.balance = Set(updated.balance.unwrap() - cost);
updated.updated_at = Set(now);
updated.update(&txn).await.map_err(|e| format!("failed to update balance: {}", e))?;
txn.commit().await.map_err(|e| format!("commit error: {}", e))?;
Ok(())
}
async fn deduct_from_user(
db: &AppDatabase,
user_uid: Uuid,
cost: Decimal,
currency: &str,
project_uid: Uuid,
model_id: Uuid,
input_tokens: i64,
output_tokens: i64,
) -> Result<(), String> {
let txn = db.begin().await.map_err(|e| format!("db txn error: {}", e))?;
let billing = user_billing::Entity::find_by_id(user_uid)
.lock_exclusive()
.one(&txn)
.await
.map_err(|e| format!("db error: {}", e))?
.ok_or_else(|| "User billing account not found".to_string())?;
if billing.balance < cost {
txn.rollback().await.ok();
return Err(format!(
"Insufficient balance (project + user). Project: unavailable, User: {:.4} {}. Required: {:.4} {}",
billing.balance, currency, cost, currency
));
}
let now = chrono::Utc::now();
// Record in project billing history (but deducted from user)
project_billing_history::ActiveModel {
uid: Set(Uuid::new_v4()),
project: Set(project_uid),
user: Set(Some(user_uid)),
amount: Set(-cost),
currency: Set(currency.to_string()),
reason: Set("ai_usage_user_fallback".to_string()),
extra: Set(Some(serde_json::json!({
"model_id": model_id.to_string(),
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"deducted_from": "user",
}))),
created_at: Set(now),
..Default::default()
}
.insert(&txn)
.await
.map_err(|e| format!("failed to insert history: {}", e))?;
let mut updated: user_billing::ActiveModel = billing.into();
updated.balance = Set(updated.balance.unwrap() - cost);
updated.updated_at = Set(now);
updated.update(&txn).await.map_err(|e| format!("failed to update user balance: {}", e))?;
txn.commit().await.map_err(|e| format!("commit error: {}", e))?;
Ok(())
}
pub async fn persist_billing_error(
db: &AppDatabase,
scope: &str,
scope_id: Uuid,
error_type: &str,
message: &str,
details: Option<serde_json::Value>,
) -> Result<(), AgentError> {
billing_error::ActiveModel {
id: Set(Uuid::new_v4()),
scope: Set(scope.to_string()),
scope_id: Set(scope_id),
error_type: Set(error_type.to_string()),
message: Set(message.to_string()),
details: Set(details),
resolved: Set(false),
created_at: Set(chrono::Utc::now()),
}
.insert(db)
.await
.map_err(|e| AgentError::Internal(format!("failed to persist billing error: {}", e)))?;
tracing::warn!(scope, %scope_id, error_type, "billing_error_persisted");
Ok(())
}
fn decimal_to_f64(d: Decimal) -> f64 {
d.round_dp(10).to_string().parse().unwrap_or(0.0)
}

View File

@ -0,0 +1,357 @@
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
use crate::client::AiClientConfig;
use crate::client::types::{ChatRequestMessage, ToolCall};
use crate::client::{StreamChunk, StreamChunkType, StreamedToolCall, call_stream};
use crate::error::Result;
use crate::tool::{ToolCall as AgentToolCall, ToolContext, ToolDefinition, ToolExecutor, ToolHandler, ToolParam};
use crate::tool::registry::ToolRegistry;
use crate::embed::EmbedService;
use sea_orm::{ActiveModelTrait, EntityTrait, Set};
use super::{AiChunkType, AiStreamChunk, StreamCallback};
use super::service::StreamResult;
// Keyword-extraction-based title generator: reads conversation messages, extracts
// meaningful words, and updates the conversation record with a short title.
async fn generate_title_for_conversation(
ctx: &ToolContext,
conversation_id: Uuid,
) -> Result<serde_json::Value> {
use models::ai::{ai_conversation, ai_message, AiMessage};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
let db_reader = ctx.db().reader();
let db_writer = ctx.db().writer();
let conv = ai_conversation::Entity::find_by_id(conversation_id)
.one(db_reader)
.await
.map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("db error: {}", e) })?
.ok_or_else(|| crate::error::AgentError::NotFound("Conversation not found".into()))?;
let recent_messages = AiMessage::find()
.filter(ai_message::Column::ConversationId.eq(conversation_id))
.filter(ai_message::Column::Role.eq("user"))
.order_by_desc(ai_message::Column::CreatedAt)
.limit(3)
.all(db_reader)
.await
.map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("db error: {}", e) })?;
if recent_messages.is_empty() {
return Err(crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: "No user messages found".into() });
}
let content = recent_messages
.first()
.and_then(|m| m.content.as_array())
.and_then(|arr| arr.first())
.and_then(|v| v.get("content"))
.and_then(|c| c.as_str())
.unwrap_or("");
let words: Vec<&str> = content
.split_whitespace()
.filter(|w| w.len() > 2 && !is_stop_word(w))
.take(5)
.collect();
let title = if words.is_empty() {
"New Chat".to_string()
} else {
words.join(" ")
};
let mut active: ai_conversation::ActiveModel = conv.into();
active.title = Set(Some(title.clone()));
active.updated_at = Set(chrono::Utc::now());
active
.update(db_writer)
.await
.map_err(|e| crate::error::AgentError::ToolExecutionFailed { tool: "generate_title".into(), cause: format!("failed to update title: {}", e) })?;
Ok(serde_json::json!({ "conversation_id": conversation_id.to_string(), "title": title }))
}
fn is_stop_word(w: &str) -> bool {
matches!(
w.to_lowercase().as_str(),
"the" | "this" | "that" | "what" | "which" | "when" | "where"
| "why" | "how" | "can" | "could" | "would" | "should"
| "please" | "help" | "thanks" | "thank" | "you" | "your"
| "have" | "has" | "had" | "with" | "for" | "from" | "into"
| "about" | "also" | "just" | "now" | "very" | "really"
)
}
type SharedCallback = Arc<dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
/// Simplified ReAct execution for Chat API.
///
/// Unlike `execute_process_stream` (which requires `AiRequest` with room-specific data),
/// this function takes messages and tools directly. It does NOT record AI sessions to
/// the `ai_session` table — the caller is responsible for persisting results.
pub async fn execute_chat_stream(
messages: Vec<ChatRequestMessage>,
tools: Vec<serde_json::Value>,
model_name: &str,
config: &AiClientConfig,
temperature: f32,
max_tokens: u32,
max_tool_depth: usize,
tool_registry: Option<&ToolRegistry>,
db: db::database::AppDatabase,
cache: db::cache::AppCache,
app_config: config::AppConfig,
project_id: Uuid,
sender_uid: Uuid,
embed_service: Option<EmbedService>,
on_chunk: StreamCallback,
conversation_id: Option<uuid::Uuid>,
) -> Result<StreamResult> {
let on_chunk: SharedCallback = Arc::from(on_chunk);
let tools_enabled = !tools.is_empty();
let mut messages = messages;
let mut tool_depth = 0;
let mut total_input_tokens = 0i64;
let mut total_output_tokens = 0i64;
let mut full_content = String::new();
let mut all_chunks: Vec<StreamChunk> = Vec::new();
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<StreamedToolCall>();
// Conditionally inject chat_generate_title tool if conversation has no title
let (tools, _tools_injected) = if let Some(conv_id) = conversation_id {
if let Some(registry) = tool_registry {
let db_reader = db.reader();
let has_title = models::ai::ai_conversation::Entity::find_by_id(conv_id)
.one(db_reader)
.await
.map(|c| c.map(|m| m.title.is_some()).unwrap_or(false))
.unwrap_or(false);
if !has_title {
let mut reg = registry.clone();
reg.register(
ToolDefinition::new("chat_generate_title")
.description(
"Generate a concise title (5 words or fewer) for the current conversation \
based on its message history, and save it to the conversation record. \
Call this tool at the start of a new conversation if it has no title.",
)
.parameters(crate::tool::ToolSchema {
schema_type: "object".into(),
properties: Some({
let mut p = std::collections::HashMap::new();
p.insert("conversation_id".into(), ToolParam {
name: "conversation_id".into(),
param_type: "string".into(),
description: Some("The UUID of the conversation (required).".into()),
required: true,
properties: None,
items: None,
});
p
}),
required: Some(vec!["conversation_id".into()]),
}),
ToolHandler::new(|ctx, args| {
let conv_id = args.get("conversation_id")
.and_then(|v| v.as_str())
.and_then(|s| Uuid::parse_str(s).ok());
Box::pin(async move {
match conv_id {
Some(id) => generate_title_for_conversation(&ctx, id).await
.map_err(|e| crate::tool::ToolError::ExecutionError(e.to_string())),
None => Err(crate::tool::ToolError::ExecutionError("conversation_id missing".into())),
}
})
}),
);
// Prepend system message instructing the model to generate title first
messages.insert(0, ChatRequestMessage::system(
"IMPORTANT: If the conversation has no title, you MUST call chat_generate_title \
with the conversation_id immediately before answering any user question. \
The title must be 5 words or fewer and should summarize the user's intent.".to_string(),
));
(reg.to_openai_tools(), true)
} else {
(tools.clone(), false)
}
} else {
(tools.clone(), false)
}
} else {
(tools.clone(), false)
};
loop {
let on_chunk_cb = on_chunk.clone();
let on_chunk_cb2 = on_chunk.clone();
let tx_arc = Arc::new(tx.clone());
let tx_arc2 = tx_arc.clone();
let response = call_stream(
&messages, model_name, config, temperature, max_tokens,
if tools_enabled { Some(&tools) } else { None }, None,
Arc::new(move |delta| {
let content = delta.to_string().replace('\n', "");
let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer });
fut
}),
Arc::new(move |delta| {
let fut = on_chunk_cb2(AiStreamChunk { content: delta.to_string(), done: false, chunk_type: AiChunkType::Thinking });
fut
}),
Arc::new(move |tc: &StreamedToolCall| {
let tx = tx_arc2.clone();
let tc_owned = tc.clone();
Box::pin(async move { let _ = tx.send(tc_owned); }) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
}),
).await?;
total_input_tokens += response.input_tokens;
total_output_tokens += response.output_tokens;
all_chunks.extend(response.chunks.clone());
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
if !has_tool_calls {
let final_content = response.content.clone();
// Don't broadcast the done chunk via SSE/NATS — incremental deltas
// already delivered the content; the separate "done" SSE event
// signals completion. Pushing full content again would duplicate it
// in the frontend streaming store.
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: final_content.clone() });
return Ok(StreamResult {
content: final_content,
reasoning_content: response.reasoning_content,
input_tokens: total_input_tokens,
output_tokens: total_output_tokens,
chunks: all_chunks,
});
}
full_content.push_str(&response.content);
let tool_calls: Vec<ToolCall> = response.tool_calls.iter().map(|tc| ToolCall {
id: tc.id.clone(), type_: "function".into(),
function: crate::client::types::ToolCallFunction { name: tc.name.clone(), arguments: tc.arguments.clone() },
}).collect();
messages.push(ChatRequestMessage::assistant(Some(response.content.clone()), Some(tool_calls.clone())));
// Drain tool call notifications
loop {
match rx.try_recv() {
Ok(tc) => {
let args_display = if tc.arguments.len() > 100 {
let end = tc.arguments.char_indices().map(|(i, _)| i).take_while(|&i| i <= 100).last().unwrap_or(100);
format!("{}...", &tc.arguments[..end])
} else { tc.arguments.clone() };
let tool_display = format!("🔧 {}({})", tc.name, args_display);
on_chunk(AiStreamChunk { content: tool_display.clone(), done: false, chunk_type: AiChunkType::ToolCall }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: tool_display });
}
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
}
}
let calls: Vec<AgentToolCall> = response.tool_calls.iter().map(|tc| AgentToolCall {
id: tc.id.clone(), name: tc.name.clone(), arguments: tc.arguments.clone(),
}).collect();
let tool_messages = execute_tools(
&calls, &db, &cache, &app_config, project_id, sender_uid,
tool_registry, embed_service.as_ref(), &on_chunk, &mut all_chunks,
).await;
messages.extend(tool_messages);
tool_depth += 1;
if tool_depth >= max_tool_depth {
let max_depth_text = format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth);
on_chunk(AiStreamChunk { content: max_depth_text.clone(), done: true, chunk_type: AiChunkType::Answer }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: max_depth_text });
return Ok(StreamResult { content: full_content, reasoning_content: String::new(), input_tokens: 0, output_tokens: 0, chunks: all_chunks });
}
}
}
async fn execute_tools(
calls: &[AgentToolCall],
db: &db::database::AppDatabase,
cache: &db::cache::AppCache,
app_config: &config::AppConfig,
project_id: Uuid,
sender_uid: Uuid,
tool_registry: Option<&ToolRegistry>,
embed_service: Option<&EmbedService>,
on_chunk: &SharedCallback,
all_chunks: &mut Vec<StreamChunk>,
) -> Vec<ChatRequestMessage> {
let mut tool_messages = Vec::new();
let mut ctx = ToolContext::new(db.clone(), cache.clone(), app_config.clone(), Uuid::nil(), Some(sender_uid))
.with_project(project_id);
if let Some(es) = embed_service {
ctx = ctx.with_embed_service(es.clone());
}
if let Some(registry) = tool_registry {
ctx.registry_mut().merge(registry.clone());
}
let mut join_set = tokio::task::JoinSet::new();
for call in calls {
let call_clone = call.clone();
let mut ctx_clone = ctx.clone();
join_set.spawn(async move {
let executor = ToolExecutor::new();
let res = executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone).await;
(call_clone, res)
});
}
let heartbeat_dur = std::time::Duration::from_secs(10);
while !join_set.is_empty() {
tokio::select! {
Some(res) = join_set.join_next() => {
if let Ok((call, results)) = res {
match results {
Ok(results) => {
for result in &results {
let preview = match &result.result {
crate::tool::ToolResult::Ok(v) => {
let t = v.to_string();
if t.len() > 300 {
let end = t.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
format!("{}...", &t[..end])
} else { t.clone() }
}
crate::tool::ToolResult::Error(msg) => msg.clone(),
};
tracing::debug!("tool_result: {} — {}", call.name, preview);
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
Err(e) => {
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_text = format!("[Tool call failed: {}]", e);
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
}
}
}
},
_ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await;
}
}
}
tool_messages
}

View File

@ -4,7 +4,7 @@ use sea_orm::*;
use super::context::RoomMessageContext;
use super::{AiRequest, Mention};
use crate::client::types::ChatRequestMessage;
use crate::compact::{CompactConfig, CompactService};
use crate::compact::CompactService;
use crate::embed::EmbedService;
use crate::error::Result;
use crate::perception::{PerceptionService, SkillEntry};
@ -55,7 +55,6 @@ impl MessageBuilder {
let mut processed_history = Vec::new();
if let Some(compact_service) = &self.compact_service {
let compact_cache_key = format!("ai:compact:{}", request.room.id);
let compact_config = CompactConfig::default();
let cached_summary: Option<String> = match request.cache.conn().await {
Ok(mut conn) => redis::cmd("GET").arg(&compact_cache_key).query_async::<Option<String>>(&mut conn).await.unwrap_or(None),
Err(e) => { tracing::warn!(error = %e, "compact cache: conn failed"); None }
@ -71,7 +70,22 @@ impl MessageBuilder {
}
if processed_history.is_empty() {
match compact_service.compact_room_auto(request.room.id, Some(request.user_names.clone()), compact_config).await {
let compact_config = request.context_setting.as_ref()
.map(|s| crate::compact::CompactConfig::from_project_setting(
s.context_window_tokens,
s.compaction_threshold,
s.compaction_max_summary_ratio,
))
.unwrap_or_default();
match compact_service.compact_room(
request.room.id,
compact_config.default_level,
Some(request.user_names.clone()),
request.sender.uid,
request.context_setting.as_ref().map(|s| s.context_window_tokens).unwrap_or(128000),
request.context_setting.as_ref().map(|s| s.compaction_max_summary_ratio).unwrap_or(0.2),
).await {
Ok(compact_summary) => {
if !compact_summary.summary.is_empty() {
messages.push(ChatRequestMessage::system(format!("Conversation summary:\n{}", compact_summary.summary)));
@ -174,7 +188,13 @@ impl MessageBuilder {
let keyword_skills = self.perception_service.inject_skills(&request.input, &history_texts, &[], &all_skills).await;
let mut vector_skills = Vec::new();
if let Some(es) = &self.embed_service {
vector_skills = crate::perception::VectorActiveAwareness::default().detect(es, &request.input, &request.project.id.to_string()).await;
let rag_enabled = request.context_setting.as_ref().map(|s| s.rag_enabled).unwrap_or(true);
if rag_enabled {
let max_results = request.context_setting.as_ref().map(|s| s.rag_max_results as usize).unwrap_or(3);
let min_score = request.context_setting.as_ref().map(|s| s.rag_min_score).unwrap_or(0.70);
let awareness = crate::perception::VectorActiveAwareness::new(max_results, min_score);
vector_skills = awareness.detect(es, &request.input, &request.project.id.to_string()).await;
}
}
let mut seen = std::collections::HashSet::new();
let mut result = Vec::new();
@ -184,8 +204,17 @@ impl MessageBuilder {
}
async fn build_memory_context(&self, request: &AiRequest) -> Vec<crate::perception::vector::MemoryContext> {
let rag_enabled = request.context_setting.as_ref().map(|s| s.rag_enabled).unwrap_or(true);
if !rag_enabled {
return Vec::new();
}
match &self.embed_service {
Some(es) => crate::perception::VectorPassiveAwareness::default().detect(es, &request.input, &request.project.display_name, &request.room.id.to_string()).await,
Some(es) => {
let max_results = request.context_setting.as_ref().map(|s| s.rag_max_results as usize).unwrap_or(3);
let min_score = request.context_setting.as_ref().map(|s| s.rag_min_score).unwrap_or(0.72);
let awareness = crate::perception::VectorPassiveAwareness::new(max_results, min_score);
awareness.detect(es, &request.input, &request.project.display_name, &request.room.id.to_string()).await
}
None => Vec::new(),
}
}

View File

@ -3,7 +3,7 @@ use std::pin::Pin;
use db::cache::AppCache;
use db::database::AppDatabase;
use models::agents::model;
use models::projects::project;
use models::projects::{project, project_context_setting};
use models::repos::repo;
use models::rooms::{room, room_message};
use models::users::user;
@ -44,7 +44,32 @@ impl Default for AiChunkType {
}
}
/// Optional streaming callback: called for each token chunk.
const THINK_OPEN: &str = "\x3cthinking\x3e";
const THINK_CLOSE: &str = "\x3c/response\x3e";
/// Strip XML-format thinking tags that some models (e.g. DeepSeek-R1) embed
/// in reasoning output. Also normalizes excessive consecutive newlines (3+ → 2).
pub fn normalize_thinking_content(content: &str) -> String {
let content = content
.replace(THINK_CLOSE, "")
.replace(THINK_OPEN, "")
.replace("\x3cthinking", "")
.replace("/response\x3e", "");
let mut result = String::with_capacity(content.len());
let mut newline_count = 0usize;
for ch in content.chars() {
if ch == '\n' {
newline_count += 1;
if newline_count <= 2 {
result.push(ch);
}
} else {
newline_count = 0;
result.push(ch);
}
}
result.trim().to_string()
}
pub type StreamCallback = Box<
dyn Fn(AiStreamChunk) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync,
>;
@ -55,6 +80,7 @@ pub struct AiRequest {
pub config: AppConfig,
pub model: model::Model,
pub project: project::Model,
pub context_setting: Option<project_context_setting::Model>,
pub sender: user::Model,
pub room: room::Model,
pub input: String,
@ -76,6 +102,7 @@ pub enum Mention {
Repo(repo::Model),
}
pub mod chat_execution;
pub mod context;
pub mod message_builder;
pub mod nonstreaming_execution;

View File

@ -82,13 +82,13 @@ pub async fn execute_process(
tool_depth += 1;
if tool_depth >= max_tool_depth {
let content = if text.is_empty() { format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth) } else { text };
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await;
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await;
return Ok(ProcessResult { content, input_tokens, output_tokens });
}
continue;
}
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await;
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), input_tokens, output_tokens, session_start.elapsed().as_millis() as i64).await;
return Ok(ProcessResult { content: text, input_tokens, output_tokens });
}
}
@ -111,7 +111,7 @@ async fn execute_tools(
let elapsed = start.elapsed().as_millis() as i64;
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.clone(), session_id: recorder.session_id(), tool_name: call.clone(), caller: request.sender.uid, arguments: serde_json::Value::Null, status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success }, execution_time_ms: Some(elapsed), error_message: error_msg, error_stack: None, retry_count: 0 });
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: Uuid::new_v4().to_string(), session_id: recorder.session_id(), tool_name: call.clone(), caller: request.sender.uid, arguments: serde_json::Value::Null, status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success }, execution_time_ms: Some(elapsed), error_message: error_msg, error_stack: None, retry_count: 0 });
}
crate::tool::ToolExecutor::to_tool_messages(&results)
}

View File

@ -18,6 +18,8 @@ pub async fn execute_process_react<C, Fut>(
request: &AiRequest, mut on_chunk: C,
tool_registry: &ToolRegistry,
ai_base_url: Option<String>, ai_api_key: Option<String>,
room_preamble: Option<&str>,
message_producer: Option<queue::MessageProducer>,
) -> Result<(String, i64, i64)>
where
C: FnMut(ReactStep) -> Fut + Send,
@ -33,6 +35,9 @@ where
let room_id = request.room.id;
let sender_uid = request.sender.uid;
let project_id = request.project.id;
let ai_model_id = request.model.id;
let ai_model_name = request.model.name.clone();
let sent_in_turn = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let session_id = Uuid::now_v7();
let session_start = std::time::Instant::now();
let version_id = room_ai::Entity::find()
@ -46,7 +51,9 @@ where
if let Some(handler) = tool_registry.get(&name) {
let adapter = crate::tool::RigToolAdapter::new(
handler.clone(), def.clone(), db.clone(), cache.clone(), cfg.clone(),
room_id, Some(sender_uid), project_id,
room_id, Some(sender_uid), project_id, message_producer.clone(),
Some(ai_model_id), Some(ai_model_name.clone()),
sent_in_turn.clone(),
);
tools.push(Box::new(RecordingTool::new(Box::new(adapter), db.clone(), session_id, sender_uid)));
}
@ -54,8 +61,14 @@ where
let rig_client = client_config.build_rig_client();
let model = rig_client.completion_model(&request.model.name);
let preamble = match room_preamble {
Some(rp) => format!("{}\n{}", rp, DEFAULT_SYSTEM_PROMPT),
None => DEFAULT_SYSTEM_PROMPT.to_string(),
};
let agent = AgentBuilder::new(model)
.preamble(DEFAULT_SYSTEM_PROMPT)
.preamble(&preamble)
.tools(tools)
.default_max_turns(request.max_tool_depth)
.build();
@ -77,7 +90,8 @@ where
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(text))) => {
step_count += 1;
let t = text.text;
on_chunk(ReactStep::Answer { step: step_count, answer: t.clone() }).await;
let cleaned = t.replace('\n', "");
on_chunk(ReactStep::Answer { step: step_count, answer: cleaned }).await;
final_content.push_str(&t);
}
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Reasoning(reasoning))) => {
@ -120,7 +134,7 @@ where
}
let elapsed_ms = session_start.elapsed().as_millis() as i64;
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, elapsed_ms).await;
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, elapsed_ms).await;
Ok((final_content, total_input_tokens, total_output_tokens))
}

View File

@ -7,6 +7,7 @@ use crate::embed::EmbedService;
use crate::error::Result;
use crate::perception::PerceptionService;
use crate::tool::registry::ToolRegistry;
use queue::MessageProducer;
/// Result from streaming AI response.
pub struct StreamResult {
@ -94,7 +95,8 @@ impl ChatService {
) -> Option<crate::RigToolSet> {
self.tool_registry.as_ref().map(|registry| {
crate::RigToolSet::from_registry(
registry, db, cache, config, room_id, sender_id, project_id,
registry, db, cache, config, room_id, sender_id, project_id, None, None, None,
std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
)
})
}
@ -134,6 +136,35 @@ impl ChatService {
super::react_execution::execute_process_react(
request, on_chunk, registry,
self.ai_base_url.clone(), self.ai_api_key.clone(),
None, None,
).await
}
/// Process AI request via rig-based ReAct streaming loop with room-specific tools.
///
/// Merges `room_tools` (e.g. `send_message`, `retract_message`) into the base
/// tool registry on-the-fly. The `room_preamble` is prepended to the default
/// system prompt to instruct the AI about room communication rules.
/// `message_producer` enables tools to publish events via the message queue.
pub async fn process_react_room<C, Fut>(
&self, request: &AiRequest, on_chunk: C,
room_tools: ToolRegistry,
room_preamble: Option<&str>,
message_producer: Option<MessageProducer>,
) -> Result<(String, i64, i64)>
where
C: FnMut(crate::react::ReactStep) -> Fut + Send,
Fut: std::future::Future<Output = ()> + Send,
{
let Some(registry) = &self.tool_registry else {
return Err(crate::error::AgentError::Internal("no tool registry registered".into()));
};
let mut merged = registry.clone();
merged.merge(room_tools);
super::react_execution::execute_process_react(
request, on_chunk, &merged,
self.ai_base_url.clone(), self.ai_api_key.clone(),
room_preamble, message_producer,
).await
}
}

View File

@ -8,6 +8,7 @@ pub async fn record_ai_session(
cache: &AppCache,
db: &AppDatabase,
project_id: Uuid,
user_id: Uuid,
session_id: Uuid,
room_id: Uuid,
model_id: Uuid,
@ -39,7 +40,7 @@ pub async fn record_ai_session(
}
let (cost, currency, error_msg) = match crate::billing::record_ai_usage(
db, project_id, version_id, input_tokens, output_tokens,
db, project_id, user_id, version_id, input_tokens, output_tokens,
).await {
Ok(crate::billing::BillingResult::Success(record)) => {
(Some(record.cost), Some(record.currency), None)
@ -70,7 +71,7 @@ async fn create_billing_error_system_message(
use models::rooms::{room_message, MessageContentType, MessageSenderType};
use sea_orm::Set;
let seq_key = format!("room:seq:{}", room_id);
let seq_key = format!("seq:room:{}", room_id);
let seq = match cache.conn().await {
Ok(mut conn) => {
match redis::cmd("INCR").arg(&seq_key).query_async::<i64>(&mut conn).await {

View File

@ -62,7 +62,8 @@ pub async fn execute_process_stream(
&messages, &model_name, &config, temperature, max_tokens,
if tools_enabled { Some(&tools) } else { None }, None,
Arc::new(move |delta| {
let fut = on_chunk_cb(AiStreamChunk { content: delta.to_string(), done: false, chunk_type: AiChunkType::Answer });
let content = delta.to_string().replace('\n', "");
let fut = on_chunk_cb(AiStreamChunk { content, done: false, chunk_type: AiChunkType::Answer });
fut
}),
Arc::new(move |delta| {
@ -82,11 +83,10 @@ pub async fn execute_process_stream(
let has_tool_calls = tools_enabled && !response.tool_calls.is_empty();
if !has_tool_calls {
return handle_final_answer(response, full_content, on_chunk, all_chunks, &request, session_id, version_id, total_input_tokens, total_output_tokens, session_start).await;
return handle_final_answer(response, all_chunks, &request, session_id, version_id, total_input_tokens, total_output_tokens, session_start).await;
}
full_content.push_str(&response.content);
full_content.push('\n');
let tool_calls: Vec<ToolCall> = response.tool_calls.iter().map(|tc| ToolCall {
id: tc.id.clone(), type_: "function".into(),
@ -114,7 +114,7 @@ pub async fn execute_process_stream(
let max_depth_text = format!("[AI reached maximum tool depth ({}) — no final answer produced]", max_tool_depth);
on_chunk(AiStreamChunk { content: max_depth_text.clone(), done: true, chunk_type: AiChunkType::Answer }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: max_depth_text });
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
return Ok(StreamResult { content: full_content, reasoning_content: String::new(), input_tokens: 0, output_tokens: 0, chunks: all_chunks });
}
}
@ -155,60 +155,83 @@ async fn execute_streaming_tools(
if let Some(registry) = tool_registry { ctx.registry_mut().merge(registry.clone()); }
let recorder = crate::tool::recorder::ToolCallRecorder::with_session(request.db.clone(), session_id);
let mut join_set = tokio::task::JoinSet::new();
for call in calls {
let start = std::time::Instant::now();
let call_clone = call.clone();
let mut ctx_clone = ctx.clone();
let (result_tx, mut result_rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
let sender_uid = request.sender.uid;
let recorder_clone = recorder.clone();
join_set.spawn(async move {
let start = std::time::Instant::now();
let executor = ToolExecutor::new();
let res = executor.execute_batch(vec![call_clone], &mut ctx_clone).await;
let _ = result_tx.send(res);
let res = executor.execute_batch(vec![call_clone.clone()], &mut ctx_clone).await;
(call_clone, res, start.elapsed(), sender_uid, recorder_clone)
});
}
let heartbeat_dur = std::time::Duration::from_secs(10);
let results = loop {
tokio::select! {
res = &mut result_rx => {
match res { Ok(inner) => break inner, Err(_) => break Err(crate::tool::ToolError::ExecutionError("tool task cancelled".into())), }
},
_ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await;
let heartbeat_dur = std::time::Duration::from_secs(10);
while !join_set.is_empty() {
tokio::select! {
Some(res) = join_set.join_next() => {
if let Ok((call, results, elapsed, sender_uid, recorder)) = res {
match results {
Ok(results) => {
for result in &results {
let text = match &result.result { crate::tool::ToolResult::Ok(v) => v.to_string(), crate::tool::ToolResult::Error(msg) => msg.clone() };
let preview = if text.len() > 300 {
let end = text.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
format!("{}...", &text[..end])
} else { text.clone() };
tracing::debug!("tool_result: {} — {}", call.name, preview);
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: sender_uid,
arguments: call.arguments_json().unwrap_or_default(),
status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success },
execution_time_ms: Some(elapsed.as_millis() as i64),
error_message: error_msg,
error_stack: None,
retry_count: 0
});
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
Err(e) => {
recorder.record(crate::tool::recorder::ToolCallRecord {
tool_call_id: call.id.clone(),
session_id: recorder.session_id(),
tool_name: call.name.clone(),
caller: sender_uid,
arguments: call.arguments_json().unwrap_or_default(),
status: models::ai::ToolCallStatus::Failed,
execution_time_ms: Some(elapsed.as_millis() as i64),
error_message: Some(e.to_string()),
error_stack: None,
retry_count: 0
});
let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
}
}
}
}
};
match results {
Ok(results) => {
for result in &results {
let text = match &result.result { crate::tool::ToolResult::Ok(v) => v.to_string(), crate::tool::ToolResult::Error(msg) => msg.clone() };
let preview = if text.len() > 300 {
let end = text.char_indices().map(|(i, _)| i).take_while(|&i| i <= 300).last().unwrap_or(300);
format!("{}...", &text[..end])
} else { text.clone() };
tracing::debug!("tool_result: {} — {}", call.name, preview);
let elapsed = start.elapsed().as_millis() as i64;
let is_error = matches!(result.result, crate::tool::ToolResult::Error(_));
let error_msg = match &result.result { crate::tool::ToolResult::Error(msg) => Some(msg.clone()), _ => None };
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.id.clone(), session_id: recorder.session_id(), tool_name: call.name.clone(), caller: request.sender.uid, arguments: call.arguments_json().unwrap_or_default(), status: if is_error { models::ai::ToolCallStatus::Failed } else { models::ai::ToolCallStatus::Success }, execution_time_ms: Some(elapsed), error_message: error_msg, error_stack: None, retry_count: 0 });
}
let success_display = format!("{}", call.name);
on_chunk(AiStreamChunk { content: success_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: success_display });
let msgs = ToolExecutor::to_tool_messages(&results);
tool_messages.extend(msgs);
}
Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
recorder.record(crate::tool::recorder::ToolCallRecord { tool_call_id: call.id.clone(), session_id: recorder.session_id(), tool_name: call.name.clone(), caller: request.sender.uid, arguments: call.arguments_json().unwrap_or_default(), status: models::ai::ToolCallStatus::Failed, execution_time_ms: Some(elapsed), error_message: Some(e.to_string()), error_stack: None, retry_count: 0 });
let err_text = format!("[Tool call failed: {}]", e);
tracing::warn!(tool = %call.name, args = %call.arguments, error = %e, "tool_call_failed");
let err_display = format!("{} (failed)", call.name);
on_chunk(AiStreamChunk { content: err_display.clone(), done: false, chunk_type: AiChunkType::ToolResult }).await;
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::ToolCall, content: err_display });
tool_messages.push(ChatRequestMessage::tool(&call.id, &err_text));
},
_ = tokio::time::sleep(heartbeat_dur) => {
on_chunk(AiStreamChunk { content: String::new(), done: false, chunk_type: AiChunkType::ToolCall }).await;
}
}
}
@ -216,18 +239,20 @@ async fn execute_streaming_tools(
}
async fn handle_final_answer(
response: crate::client::StreamResponse, full_content: String,
on_chunk: SharedCallback,
response: crate::client::StreamResponse,
mut all_chunks: Vec<StreamChunk>, request: &AiRequest,
session_id: Uuid, version_id: Option<Uuid>,
total_input_tokens: i64, total_output_tokens: i64,
session_start: std::time::Instant,
) -> Result<StreamResult> {
let full_content = full_content + &response.content;
on_chunk(AiStreamChunk { content: response.content.clone(), done: true, chunk_type: AiChunkType::Answer }).await;
let full_content = response.content.clone();
// Don't broadcast the done chunk via SSE/NATS — incremental deltas
// already delivered the content; the separate completion event
// signals end of stream. Broadcasting full content again would
// duplicate it in the frontend streaming display.
all_chunks.push(StreamChunk { chunk_type: StreamChunkType::Answer, content: response.content.clone() });
record_ai_session(&request.cache, &request.db, request.project.id, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
Ok(StreamResult { content: full_content, reasoning_content: response.reasoning_content, input_tokens: response.input_tokens, output_tokens: response.output_tokens, chunks: all_chunks })
record_ai_session(&request.cache, &request.db, request.project.id, request.sender.uid, session_id, request.room.id, request.model.id, version_id.unwrap_or_default(), total_input_tokens, total_output_tokens, session_start.elapsed().as_millis() as i64).await;
Ok(StreamResult { content: full_content, reasoning_content: response.reasoning_content, input_tokens: total_input_tokens, output_tokens: total_output_tokens, chunks: all_chunks })
}
async fn inject_passive_skills_stream(

View File

@ -106,8 +106,10 @@ impl RetryState {
fn backoff_duration(&self) -> std::time::Duration {
let exp = self.attempt.min(5);
let base_ms = 500u64.saturating_mul(2u64.pow(exp)).min(self.max_backoff_ms);
let jitter = fastrand_u64(base_ms + 1);
std::time::Duration::from_millis(jitter)
let max_jitter = (base_ms / 2).max(base_ms);
let offset = fastrand_u64(max_jitter + 1).saturating_sub(base_ms / 2);
let total = base_ms.saturating_add(offset).min(self.max_backoff_ms);
std::time::Duration::from_millis(total)
}
fn next(&mut self) { self.attempt += 1; }
}

View File

@ -4,18 +4,14 @@ use models::rooms::room_message::{
Column as RmCol, Entity as RoomMessage, Model as RoomMessageModel,
};
use models::users::user::{Column as UserCol, Entity as User};
use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder};
use serde_json::Value;
use sea_orm::{DatabaseConnection, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
use uuid::Uuid;
use crate::client::types::ChatRequestMessage;
use crate::client::AiClientConfig;
use crate::client::call_with_params;
use crate::AgentError;
use crate::compact::helpers::summary_content;
use crate::compact::types::{
CompactConfig, CompactLevel, CompactSummary, MessageSummary, ThresholdResult,
};
use crate::compact::types::{CompactLevel, CompactSummary, MessageSummary};
use crate::tokent::{TokenUsage, resolve_usage};
#[derive(Clone)]
@ -35,8 +31,29 @@ impl CompactService {
room_id: Uuid,
level: CompactLevel,
user_names: Option<std::collections::HashMap<Uuid, String>>,
requester_id: Uuid,
context_window_tokens: i32,
compaction_max_summary_ratio: f32,
) -> Result<CompactSummary, AgentError> {
let messages = self.fetch_room_messages(room_id).await?;
// Verify room access at the database level to ensure auth context is enforced.
// Public rooms are accessible to project members.
// For simplicity in this audit fix, we'll fetch only if access exists.
let messages = self.fetch_room_messages_secure(room_id, requester_id).await?;
if messages.is_empty() {
// Check if room actually exists or if it's just empty/inaccessible
let room_exists = models::rooms::room::Entity::find_by_id(room_id)
.one(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?
.is_some();
if room_exists {
return Err(AgentError::Internal("Access denied or room empty".into()));
} else {
return Err(AgentError::Internal("Room not found".into()));
}
}
let user_ids: Vec<Uuid> = messages
.iter()
@ -74,7 +91,9 @@ impl CompactService {
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;
let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
// Build text of what was summarized (for tiktoken fallback)
let summarized_text = to_summarize
@ -100,10 +119,13 @@ impl CompactService {
session_id: Uuid,
level: CompactLevel,
user_names: Option<std::collections::HashMap<Uuid, String>>,
context_window_tokens: i32,
compaction_max_summary_ratio: f32,
) -> Result<CompactSummary, AgentError> {
let messages: Vec<RoomMessageModel> = RoomMessage::find()
.filter(RmCol::Room.eq(session_id))
.order_by_asc(RmCol::Seq)
.limit(10000)
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
@ -148,10 +170,10 @@ impl CompactService {
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
// Summarize the earlier messages
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;
let max_summary_tokens = (context_window_tokens as f32 * compaction_max_summary_ratio) as usize;
let (summary, remote_usage) = self.summarize_messages(to_summarize, max_summary_tokens).await?;
// Build text of what was summarized (for tiktoken fallback)
let summarized_text = to_summarize
.iter()
.map(|m| m.content.as_str())
@ -170,164 +192,51 @@ impl CompactService {
})
}
pub fn summary_as_system_message(summary: &CompactSummary) -> ChatRequestMessage {
let content = summary_content(summary);
ChatRequestMessage::system(content)
}
/// Check if the message history for a room exceeds the token threshold.
/// Returns `ThresholdResult::Skip` if below threshold, `Compact` if above.
///
/// This method fetches messages and estimates their token count using tiktoken.
/// Call this before deciding whether to run full compaction.
pub async fn check_threshold(
async fn fetch_room_messages_secure(
&self,
room_id: Uuid,
config: CompactConfig,
) -> Result<ThresholdResult, AgentError> {
let messages = self.fetch_room_messages(room_id).await?;
let tokens = self.estimate_message_tokens(&messages);
requester_id: Uuid,
) -> Result<Vec<RoomMessageModel>, AgentError> {
use models::rooms::{RoomUserState, RoomAccess};
use sea_orm::QueryTrait;
use sea_orm::sea_query::Expr;
// Find messages for the room where the requester has access.
// We check both the room_user_state table (membership) and the room_access table (explicit grants).
RoomMessage::find()
.filter(RmCol::Room.eq(room_id))
.filter(
sea_orm::Condition::any()
.add(
Expr::exists(
RoomUserState::find()
.filter(models::rooms::room_user_state::Column::Room.eq(room_id))
.filter(models::rooms::room_user_state::Column::User.eq(requester_id))
.into_query()
)
)
.add(
Expr::exists(
RoomAccess::find()
.filter(models::rooms::room_access::Column::Room.eq(room_id))
.filter(models::rooms::room_access::Column::User.eq(requester_id))
.into_query()
)
)
)
.order_by_asc(RmCol::Seq)
.limit(10000)
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))
}
if tokens < config.token_threshold {
return Ok(ThresholdResult::Skip {
estimated_tokens: tokens,
});
}
let level = if config.auto_level {
CompactLevel::auto_select(tokens, config.token_threshold)
fn message_to_summary(m: &RoomMessageModel, user_name_map: &std::collections::HashMap<Uuid, String>) -> MessageSummary {
let sender_name = if let Some(user_id) = m.sender_id {
user_name_map.get(&user_id).cloned().unwrap_or_else(|| m.sender_type.to_string())
} else {
config.default_level
m.sender_type.to_string()
};
Ok(ThresholdResult::Compact {
estimated_tokens: tokens,
level,
})
}
/// Auto-compact a room: estimates token count, only compresses if over threshold.
///
/// This is the recommended entry point for automatic compaction.
/// - If tokens < threshold → returns a no-op summary (empty summary, no compression)
/// - If tokens >= threshold → compresses with auto-selected level
pub async fn compact_room_auto(
&self,
room_id: Uuid,
user_names: Option<std::collections::HashMap<Uuid, String>>,
config: CompactConfig,
) -> Result<CompactSummary, AgentError> {
let threshold_result = self.check_threshold(room_id, config).await?;
match threshold_result {
ThresholdResult::Skip { .. } => {
// Below threshold — no compaction needed, return empty summary
let messages = self.fetch_room_messages(room_id).await?;
let user_ids: Vec<Uuid> = messages.iter().filter_map(|m| m.sender_id).collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary: String::new(),
compacted_at: Utc::now(),
messages_compressed: 0,
usage: None,
});
}
ThresholdResult::Compact { level, .. } => {
// Above threshold — compress with selected level
return self
.compact_room_with_level(room_id, level, user_names)
.await;
}
}
}
/// Compact a room with a specific level (bypassing threshold check).
/// Use this when the caller has already decided compaction is needed.
async fn compact_room_with_level(
&self,
room_id: Uuid,
level: CompactLevel,
user_names: Option<std::collections::HashMap<Uuid, String>>,
) -> Result<CompactSummary, AgentError> {
let messages = self.fetch_room_messages(room_id).await?;
let user_ids: Vec<Uuid> = messages.iter().filter_map(|m| m.sender_id).collect();
let user_name_map = match user_names {
Some(map) => map,
None => self.get_user_name_map(&user_ids).await?,
};
if messages.len() <= level.retain_count() {
let retained: Vec<MessageSummary> = messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
return Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary: String::new(),
compacted_at: Utc::now(),
messages_compressed: 0,
usage: None,
});
}
let retain_count = level.retain_count();
let split_index = messages.len().saturating_sub(retain_count);
let (to_summarize, retained_messages) = messages.split_at(split_index);
let retained: Vec<MessageSummary> = retained_messages
.iter()
.map(|m| Self::message_to_summary(m, &user_name_map))
.collect();
let (summary, remote_usage) = self.summarize_messages(to_summarize).await?;
let summarized_text = to_summarize
.iter()
.map(|m| m.content.as_str())
.collect::<Vec<_>>()
.join("\n");
let usage = resolve_usage(remote_usage, &self.model, &summarized_text, &summary);
Ok(CompactSummary {
session_id: Uuid::new_v4(),
room_id,
retained,
summary,
compacted_at: Utc::now(),
messages_compressed: to_summarize.len(),
usage: Some(usage),
})
}
/// Estimate total token count of a message list using tiktoken.
fn estimate_message_tokens(&self, messages: &[RoomMessageModel]) -> usize {
let total_chars: usize = messages.iter().map(|m| m.content.len()).sum();
// Rough estimate: ~4 chars per token (safe upper bound)
total_chars / 4
}
fn message_to_summary(
m: &RoomMessageModel,
user_name_map: &std::collections::HashMap<Uuid, String>,
) -> MessageSummary {
let sender_name = m
.sender_id
.and_then(|id| user_name_map.get(&id).cloned())
.unwrap_or_else(|| m.sender_type.to_string());
MessageSummary {
id: m.id,
sender_type: m.sender_type.clone(),
@ -335,35 +244,11 @@ impl CompactService {
sender_name,
content: m.content.clone(),
content_type: m.content_type.clone(),
tool_call_id: Self::extract_tool_call_id(&m.content),
tool_call_id: None,
send_at: m.send_at,
}
}
fn extract_tool_call_id(content: &str) -> Option<String> {
let content = content.trim();
if let Ok(v) = serde_json::from_str::<Value>(content) {
v.get("tool_call_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
} else {
None
}
}
async fn fetch_room_messages(
&self,
room_id: Uuid,
) -> Result<Vec<RoomMessageModel>, AgentError> {
let messages: Vec<RoomMessageModel> = RoomMessage::find()
.filter(RmCol::Room.eq(room_id))
.order_by_asc(RmCol::Seq)
.all(&self.db)
.await
.map_err(|e| AgentError::Internal(e.to_string()))?;
Ok(messages)
}
async fn get_user_name_map(
&self,
user_ids: &[Uuid],
@ -386,8 +271,8 @@ impl CompactService {
async fn summarize_messages(
&self,
messages: &[RoomMessageModel],
max_summary_tokens: usize,
) -> Result<(String, Option<TokenUsage>), AgentError> {
// Collect distinct user IDs
let user_ids: Vec<Uuid> = messages
.iter()
.filter_map(|m| m.sender_id)
@ -395,10 +280,8 @@ impl CompactService {
.into_iter()
.collect();
// Query usernames
let user_name_map = self.get_user_name_map(&user_ids).await?;
// Define sender mapper
let sender_mapper = |m: &RoomMessageModel| {
if let Some(user_id) = m.sender_id {
if let Some(username) = user_name_map.get(&user_id) {
@ -413,11 +296,13 @@ impl CompactService {
let user_msg = ChatRequestMessage::user(format!(
"Summarise the following conversation concisely, preserving all key facts, \
decisions, and any pending or in-progress work. \
The summary MUST NOT exceed {} tokens. \
Use this format:\n\n\
**Summary:** <one-paragraph overview>\n\
**Key decisions:** <bullet list or 'none'>\n\
**Open items:** <bullet list or 'none'>\n\n\
Conversation:\n\n{}",
max_summary_tokens,
body
));
@ -425,8 +310,8 @@ impl CompactService {
&[user_msg],
&self.model,
&self.ai_client_config,
0.3, // slightly higher temp for summarization
1024, // max output tokens
0.3,
2048,
None,
None,
None,
@ -434,7 +319,6 @@ impl CompactService {
.await
.map_err(|e| AgentError::OpenAi(e.to_string()))?;
// Prefer remote usage; fall back to None (caller will use tiktoken via resolve_usage)
let remote_usage =
TokenUsage::from_remote(response.input_tokens as u32, response.output_tokens as u32);

View File

@ -74,6 +74,8 @@ pub struct CompactConfig {
pub auto_level: bool,
/// Fallback level when `auto_level` is false.
pub default_level: CompactLevel,
/// Maximum tokens the summary may contain (enforced via prompt).
pub max_summary_tokens: usize,
}
impl Default for CompactConfig {
@ -83,6 +85,20 @@ impl Default for CompactConfig {
token_threshold: 8000,
auto_level: true,
default_level: CompactLevel::Light,
max_summary_tokens: 256,
}
}
}
impl CompactConfig {
/// Build config from project context settings.
pub fn from_project_setting(context_window_tokens: i32, compaction_threshold: f32, compaction_max_summary_ratio: f32) -> Self {
let threshold = (context_window_tokens as f32 * compaction_threshold) as usize;
Self {
token_threshold: threshold,
auto_level: true,
default_level: CompactLevel::Light,
max_summary_tokens: (context_window_tokens as f32 * compaction_max_summary_ratio) as usize,
}
}
}

View File

@ -575,11 +575,5 @@ pub struct EmbedMemoryInput {
}
/// Input struct for batch tag embedding.
#[derive(Debug, Clone)]
pub struct TagEmbedInput {
pub repo_id: String,
pub repo_name: String,
pub project_id: String,
pub name: String,
pub description: Option<String>,
}
/// Re-exported from models for backward compatibility.
pub use models::TagEmbedInput;

View File

@ -52,3 +52,9 @@ impl From<sea_orm::DbErr> for AgentError {
AgentError::Internal(e.to_string())
}
}
impl From<crate::tool::ToolError> for AgentError {
fn from(e: crate::tool::ToolError) -> Self {
AgentError::ToolExecutionFailed { tool: String::new(), cause: e.to_string() }
}
}

View File

@ -13,7 +13,7 @@ pub mod sync;
pub mod task;
pub mod tokent;
pub mod tool;
pub use billing::{BillingRecord, BillingResult, record_ai_usage};
pub use billing::{BillingRecord, BillingResult, record_ai_usage, initialize_user_billing, initialize_project_billing, check_balance, persist_billing_error};
pub use sync::list_accessible_models;
pub use task::TaskService;
pub use tokent::{TokenUsage, resolve_usage};
@ -33,7 +33,7 @@ pub use embed::{
EmbedClient, EmbedMemoryInput, EmbedService, QdrantClient, SearchResult, TagEmbedInput, new_embed_client,
};
pub use error::{AgentError, Result};
pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT};
pub use react::{ReactConfig, ReactStep, DEFAULT_SYSTEM_PROMPT, ROOM_CONTEXT_PROMPT};
pub use tool::{
ToolCall, ToolCallRecord, ToolCallRecorder, ToolCallResult, ToolContext, ToolDefinition, ToolError, ToolExecutor, ToolHandler, ToolParam,
ToolRegistry, ToolResult, ToolSchema,

View File

@ -44,6 +44,10 @@ impl Default for VectorActiveAwareness {
}
impl VectorActiveAwareness {
pub fn new(max_skills: usize, min_score: f32) -> Self {
Self { max_skills, min_score }
}
/// Search for skills semantically relevant to the user's input.
///
/// Uses Qdrant vector search within the given project to find skills whose
@ -107,6 +111,10 @@ impl Default for VectorPassiveAwareness {
}
impl VectorPassiveAwareness {
pub fn new(max_memories: usize, min_score: f32) -> Self {
Self { max_memories, min_score }
}
/// Search for past conversation messages semantically similar to the current context.
///
/// Uses Qdrant to find memories within the same room that share semantic similarity

View File

@ -16,7 +16,7 @@ pub const DEFAULT_SYSTEM_PROMPT: &str = r#"You are an AI assistant embedded in a
## Core Rule: Search Local Data First
Always query the platform's local data before guessing or referring to external sources. Local data includes: issues, pull requests, repositories, code reviews, chat messages, documentation, members, and other workspace resources.
Always query the platform's local data before guessing or referring to external sources. Local data includes: issues, pull requests, repositories, code reviews, chat messages, documentation, members, and other project resources.
If local data does not contain the answer, state that clearly before considering external information.
@ -38,3 +38,32 @@ If local data does not contain the answer, state that clearly before considering
- State ambiguity or uncertainty explicitly.
- Prefer facts over speculation.
"#;
/// Room-specific system prompt appended when the AI is @mentioned in a chat room.
///
/// In room context, the AI must NOT produce long-form output directly. Instead,
/// it communicates through the `send_message` and `retract_message` tools.
/// This keeps room messages concise and gives the AI control over what appears
/// in the room.
pub const ROOM_CONTEXT_PROMPT: &str = r#"
## Room Communication Mode CRITICAL
You are NOT in a direct chat. You are @mentioned in a chat room. **Your default response text will NOT be seen by anyone.** The ONLY way to communicate with the room is through the tools listed below.
### Mandatory Communication Rules
1. **ALWAYS use `send_message`** to deliver ANY response to the room. No exceptions. If you produce a final text response without calling `send_message`, the room will receive NOTHING.
2. **Call `send_message` FIRST**, before any final text output. The tool call is what creates a visible room message.
3. **Keep each message concise** short, focused, actionable. No long reports, no multi-paragraph essays, no bullet lists longer than 5 items. If you need to convey a lot of information, summarize the key points and offer to provide details if asked.
4. **Use mentions** to reference entities: `@[user:uuid:username]` for users, `@[repo:uuid:name]` for repositories, `@[skill:slug]` for skills, `@[issue:uuid:title]` for issues, `@[ai:uuid:name]` for other AI models.
5. **Use `retract_message`** to revoke a message you just sent if it contains an error or needs to be withdrawn. You can only retract messages you sent in the current turn.
6. **You may send multiple messages** for complex responses, break your answer into multiple `send_message` calls (up to 99 per turn). Each message should be short, focused, and stand on its own. For example: first send a summary, then send follow-up details or action items as separate messages.
7. **After calling `send_message`, your final text response can be brief** just a summary or acknowledgment, since the actual room message has already been delivered via the tool call.
### Critical Reminder
Your response text output is NOT delivered to the room. The `send_message` tool IS the delivery mechanism. If you forget to call `send_message`, nobody in the room will see your response.
### Room-Only Tools
- `send_message(room_id?, content)` Send a brief message to the room. The `room_id` parameter is optional (defaults to the current room). The `content` parameter is required and supports `@[type:id:label]` mention syntax.
- `retract_message(message_id)` Retract (revoke) a message you sent in the current turn. Requires the message UUID returned by `send_message`.
"#;

View File

@ -9,8 +9,18 @@
//! return usage metadata (e.g., local models, streaming), tiktoken is used as
//! a fallback for accurate counting.
use std::collections::HashMap;
use std::sync::OnceLock;
use std::sync::RwLock;
use crate::error::{AgentError, Result};
static TOKENIZER_CACHE: OnceLock<RwLock<HashMap<String, tiktoken_rs::CoreBPE>>> = OnceLock::new();
fn get_cached_tokenizers() -> &'static RwLock<HashMap<String, tiktoken_rs::CoreBPE>> {
TOKENIZER_CACHE.get_or_init(|| RwLock::new(HashMap::new()))
}
/// Token usage data. Use `from_remote()` when the API returns usage info,
/// or `from_estimate()` when falling back to tiktoken.
#[derive(Debug, Clone, Copy, Default, serde::Serialize, serde::Deserialize)]
@ -155,14 +165,28 @@ fn safe_token_budget(context_limit: usize, reserve: usize) -> usize {
fn get_tokenizer(model: &str) -> Result<tiktoken_rs::CoreBPE> {
use tiktoken_rs;
// Try model-specific tokenizer first
if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) {
return Ok(bpe);
{
let cache = get_cached_tokenizers().read().unwrap();
if let Some(bpe) = cache.get(model) {
return Ok(bpe.clone());
}
}
// Fallback: use cl100k_base for unknown models
tiktoken_rs::cl100k_base()
.map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e)))
// Try model-specific tokenizer first
let bpe = if let Ok(bpe) = tiktoken_rs::get_bpe_from_model(model) {
bpe
} else {
// Fallback: use cl100k_base for unknown models
tiktoken_rs::cl100k_base()
.map_err(|e| AgentError::Internal(format!("Failed to init tokenizer: {}", e)))?
};
{
let mut cache = get_cached_tokenizers().write().unwrap();
cache.insert(model.to_string(), bpe.clone());
}
Ok(bpe)
}
/// Estimate tokens for a simple prefix/suffix pattern (e.g., "assistant\n" + text).

View File

@ -8,6 +8,7 @@ use std::sync::Arc;
use db::cache::AppCache;
use db::database::AppDatabase;
use config::AppConfig;
use queue::MessageProducer;
use uuid::Uuid;
use super::registry::ToolRegistry;
@ -28,6 +29,15 @@ struct Inner {
pub project_id: Uuid,
pub registry: ToolRegistry,
pub embed_service: Option<crate::embed::EmbedService>,
pub message_producer: Option<MessageProducer>,
/// When in room context, identifies the AI model that is responding.
/// Used by send_message/retract_message to set the correct sender.
pub ai_model_id: Option<Uuid>,
pub ai_model_name: Option<String>,
/// Message IDs sent by the AI in the current ReAct turn.
/// Shared across tool calls so send_message can register IDs
/// and retract_message can validate turn-scoped retraction.
pub sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<Uuid>>>,
depth: u32,
max_depth: u32,
tool_call_count: usize,
@ -52,6 +62,10 @@ impl ToolContext {
project_id: Uuid::nil(),
registry: ToolRegistry::new(),
embed_service: None,
message_producer: None,
ai_model_id: None,
ai_model_name: None,
sent_in_turn: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
depth: 0,
max_depth: 5,
tool_call_count: 0,
@ -85,10 +99,45 @@ impl ToolContext {
self
}
pub fn with_message_producer(mut self, producer: MessageProducer) -> Self {
Arc::make_mut(&mut self.inner).message_producer = Some(producer);
self
}
pub fn with_ai_model(mut self, model_id: Uuid, model_name: String) -> Self {
Arc::make_mut(&mut self.inner).ai_model_id = Some(model_id);
Arc::make_mut(&mut self.inner).ai_model_name = Some(model_name);
self
}
pub fn with_sent_in_turn(mut self, sent: std::sync::Arc<std::sync::Mutex<Vec<Uuid>>>) -> Self {
Arc::make_mut(&mut self.inner).sent_in_turn = sent;
self
}
/// Register a message ID as sent in the current turn (called by send_message).
pub fn register_sent_message(&self, id: Uuid) {
if let Ok(mut list) = self.inner.sent_in_turn.lock() {
list.push(id);
}
}
/// Check if a message ID was sent in the current turn (called by retract_message).
pub fn is_sent_in_turn(&self, id: Uuid) -> bool {
self.inner.sent_in_turn.lock()
.map(|list| list.contains(&id))
.unwrap_or(false)
}
pub fn embed_service(&self) -> Option<&crate::embed::EmbedService> {
self.inner.embed_service.as_ref()
}
/// Message queue producer for publishing room events (messages, retractions, etc.).
pub fn message_producer(&self) -> Option<&MessageProducer> {
self.inner.message_producer.as_ref()
}
pub fn recursion_exceeded(&self) -> bool {
self.inner.depth >= self.inner.max_depth
}
@ -146,6 +195,16 @@ impl ToolContext {
self.inner.sender_id
}
/// AI model ID when in room context (the AI that is responding).
pub fn ai_model_id(&self) -> Option<Uuid> {
self.inner.ai_model_id
}
/// AI model display name when in room context.
pub fn ai_model_name(&self) -> Option<String> {
self.inner.ai_model_name.clone()
}
/// Project context for the room.
pub fn project_id(&self) -> Uuid {
self.inner.project_id

View File

@ -14,6 +14,7 @@ use super::context::ToolContext;
use super::definition::ToolDefinition as AgentToolDefinition;
use super::recorder::{ToolCallRecord, ToolCallRecorder};
use super::registry::{ToolHandler, ToolRegistry};
use queue::MessageProducer;
/// Returns true if the tool error message indicates a transient failure that can be retried.
pub fn is_retryable_tool_error(msg: &str) -> bool {
@ -170,6 +171,10 @@ impl RigToolSet {
room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid,
message_producer: Option<MessageProducer>,
ai_model_id: Option<uuid::Uuid>,
ai_model_name: Option<String>,
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
) -> Self {
let mut toolset = ToolSet::default();
let mut definitions = HashMap::new();
@ -191,6 +196,10 @@ impl RigToolSet {
room_id,
sender_id,
project_id,
message_producer: message_producer.clone(),
ai_model_id,
ai_model_name: ai_model_name.clone(),
sent_in_turn: sent_in_turn.clone(),
};
toolset.add_tool(adapter);
}
@ -227,6 +236,10 @@ pub struct RigToolAdapter {
room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid,
message_producer: Option<MessageProducer>,
ai_model_id: Option<uuid::Uuid>,
ai_model_name: Option<String>,
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
}
impl RigToolAdapter {
@ -240,8 +253,12 @@ impl RigToolAdapter {
room_id: uuid::Uuid,
sender_id: Option<uuid::Uuid>,
project_id: uuid::Uuid,
message_producer: Option<MessageProducer>,
ai_model_id: Option<uuid::Uuid>,
ai_model_name: Option<String>,
sent_in_turn: std::sync::Arc<std::sync::Mutex<Vec<uuid::Uuid>>>,
) -> Self {
Self { handler, definition, db, cache, config, room_id, sender_id, project_id }
Self { handler, definition, db, cache, config, room_id, sender_id, project_id, message_producer, ai_model_id, ai_model_name, sent_in_turn }
}
}
@ -272,16 +289,27 @@ impl ToolDyn for RigToolAdapter {
let room_id = self.room_id;
let sender_id = self.sender_id;
let project_id = self.project_id;
let message_producer = self.message_producer.clone();
let ai_model_id = self.ai_model_id;
let ai_model_name = self.ai_model_name.clone();
let sent_in_turn = self.sent_in_turn.clone();
async move {
let ctx = ToolContext::new(
let mut ctx = ToolContext::new(
db,
cache,
config,
room_id,
sender_id,
)
.with_project(project_id);
.with_project(project_id)
.with_sent_in_turn(sent_in_turn);
if let Some(mp) = message_producer {
ctx = ctx.with_message_producer(mp);
}
if let Some(mid) = ai_model_id {
ctx = ctx.with_ai_model(mid, ai_model_name.unwrap_or_default());
}
let args_json: serde_json::Value = serde_json::from_str(&args)
.map_err(|e| ToolError::JsonError(e))?;

View File

@ -26,6 +26,7 @@ email = { workspace = true }
tracing = { workspace = true }
service = { workspace = true }
session = { workspace = true }
agent = { workspace = true }
git = { workspace = true }
#frontend = { workspace = true }
models = { workspace = true }
@ -51,5 +52,12 @@ sea-orm = "2.0.0-rc.37"
rust_decimal = "1.40.0"
actix-multipart = { workspace = true, features = ["tempfile"] }
redis = { workspace = true }
reqwest = { workspace = true, features = ["json", "native-tls", "stream"] }
[build-dependencies]
brotli = "7"
flate2 = "1"
sha2 = "0.10"
[lints]
workspace = true

View File

@ -18,7 +18,7 @@ pub fn init_agent_routes(cfg: &mut web::ServiceConfig) {
web::post().to(code_review::trigger_code_review),
)
.route(
"/{project}/issues/{issue_number}/triage",
"/{project}/triage",
web::get().to(issue_triage::triage_issue),
)
.route(

View File

@ -1,11 +1,12 @@
use actix_web::{HttpResponse, Result, web};
use serde::Serialize;
use session::SessionUser;
use session::Session;
use utoipa::ToSchema;
use crate::ApiResponse;
use crate::error::ApiError;
use service::AppService;
use service::error::AppError;
use service::ws_token::WS_TOKEN_TTL_SECONDS;
#[derive(Debug, Serialize, ToSchema)]
@ -27,13 +28,16 @@ pub struct WsTokenResponse {
)]
pub async fn ws_token_generate(
service: web::Data<AppService>,
session_user: SessionUser,
session: Session,
) -> Result<HttpResponse, ApiError> {
let SessionUser(user_id) = session_user;
let user_id = session.user().ok_or_else(|| ApiError::from(AppError::Unauthorized))?;
let device_id = session.get::<String>("device_id").unwrap_or_default();
let client_id = session.get::<String>("client_id").unwrap_or_default();
let token = service
.ws_token
.generate_token(user_id)
.generate_token(user_id, device_id, client_id)
.await
.map_err(ApiError::from)?;

View File

@ -1 +1,224 @@
fn main() {}
//! Build script: reads all files from `dist/`, compresses them (brotli + gzip),
//! computes etags via SHA-256, and generates a `frontend` Rust module for `dist.rs`.
use std::collections::BTreeMap;
use std::env;
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use sha2::{Digest, Sha256};
use flate2::write::GzEncoder;
use flate2::Compression;
// ── Compression helpers ──────────────────────────────────────────────────
fn gzip_compress(data: &[u8]) -> Vec<u8> {
let mut encoder = GzEncoder::new(Vec::new(), Compression::new(6));
encoder.write_all(data).unwrap();
encoder.finish().unwrap()
}
fn brotli_compress(data: &[u8]) -> Option<Vec<u8>> {
use brotli::CompressorWriter;
let buf = Vec::new();
let mut writer = CompressorWriter::new(buf, 4096, 6, 16);
if writer.write_all(data).is_ok() && writer.flush().is_ok() {
Some(writer.into_inner())
} else {
None
}
}
// ── ETag computation ─────────────────────────────────────────────────────
fn compute_etag(data: &[u8]) -> String {
let mut hasher = Sha256::new();
hasher.update(data);
let hash = hasher.finalize();
// First 32 hex chars for a compact etag
format!("{:x}", hash)[..32].to_string()
}
// ── Asset collection ─────────────────────────────────────────────────────
struct Asset {
path: String,
data: Vec<u8>,
etag: String,
brotli: Option<Vec<u8>>,
gzip: Vec<u8>,
}
fn collect_assets(dist_dir: &Path) -> BTreeMap<String, Asset> {
let mut assets = BTreeMap::new();
for entry in walkdir(dist_dir) {
let rel = entry.strip_prefix(dist_dir).unwrap();
let path_str = rel.to_string_lossy().replace('\\', "/");
if path_str.is_empty() {
continue;
}
let data = fs::read(&entry).unwrap_or_else(|e| {
panic!("Failed to read dist file {}: {}", path_str, e)
});
let etag = compute_etag(&data);
let brotli_data = brotli_compress(&data);
let gzip_data = gzip_compress(&data);
assets.insert(
path_str.clone(),
Asset {
path: path_str,
data,
etag,
brotli: brotli_data,
gzip: gzip_data,
},
);
}
assets
}
fn walkdir(dir: &Path) -> Vec<PathBuf> {
let mut files = Vec::new();
if let Ok(entries) = fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
files.extend(walkdir(&path));
} else {
files.push(path);
}
}
}
files
}
// ── Code generation ──────────────────────────────────────────────────────
fn rust_byte_literal(data: &[u8]) -> String {
if data.len() < 200 {
let bytes: Vec<String> = data.iter().map(|b| b.to_string()).collect();
format!("[{}]", bytes.join(", "))
} else {
let lines: Vec<String> = data
.chunks(80)
.map(|chunk| {
chunk.iter().map(|b| b.to_string()).collect::<Vec<_>>().join(", ")
})
.collect();
format!("[\n{}\n]", lines.join(",\n"))
}
}
fn path_to_ident(path: &str) -> String {
let s = path.replace('-', "_").replace('.', "_").replace('/', "_").to_uppercase();
format!("ASSET_{s}")
}
fn etag_ident(path: &str) -> String {
format!("ETAG_{}", path_to_ident(path))
}
fn br_ident(path: &str) -> String {
format!("{}_BR", path_to_ident(path))
}
fn gz_ident(path: &str) -> String {
format!("{}_GZ", path_to_ident(path))
}
fn generate_frontend_module(assets: &BTreeMap<String, Asset>, out_dir: &Path) {
let mut code = String::new();
code += "// AUTO-GENERATED by build.rs — DO NOT EDIT.\n";
code += "// Frontend assets from dist/ embedded as static byte arrays.\n\n";
// Generate static byte arrays for each asset
for (path, asset) in assets {
let ident = path_to_ident(path);
let etag_id = etag_ident(path);
let br_id = br_ident(path);
let gz_id = gz_ident(path);
code += &format!("static {}: &[u8] = &{};\n", ident, rust_byte_literal(&asset.data));
code += &format!("static {}: &str = \"{}\";\n", etag_id, asset.etag);
if let Some(ref br) = asset.brotli {
code += &format!("static {}: &[u8] = &{};\n", br_id, rust_byte_literal(br));
}
code += &format!("static {}: &[u8] = &{};\n", gz_id, rust_byte_literal(&asset.gzip));
code += "\n";
}
// Uncompressed lookup
code += "/// Get an uncompressed frontend asset by path, returning (data, etag).\n";
code += "pub fn get_frontend_asset_with_etag(path: &str) -> Option<(&'static [u8], &'static str)> {\n";
code += " match path {\n";
for (path, _asset) in assets {
let ident = path_to_ident(path);
let etag_id = etag_ident(path);
code += &format!(" \"{path}\" => Some((&{ident}, &{etag_id})),\n");
}
code += " _ => None,\n";
code += " }\n";
code += "}\n\n";
// Compressed lookup (prefers brotli, falls back to gzip)
code += "/// Get a pre-compressed frontend asset by path.\n";
code += "/// Returns (data, encoding, etag) — prefers brotli over gzip.\n";
code += "pub fn get_frontend_asset_compressed(path: &str) -> Option<(&'static [u8], &'static str, &'static str)> {\n";
code += " match path {\n";
for (path, asset) in assets {
let etag_id = etag_ident(path);
if asset.brotli.is_some() {
let br_id = br_ident(path);
code += &format!(" \"{path}\" => Some((&{br_id}, \"br\", &{etag_id})),\n");
} else {
let gz_id = gz_ident(path);
code += &format!(" \"{path}\" => Some((&{gz_id}, \"gzip\", &{etag_id})),\n");
}
}
code += " _ => None,\n";
code += " }\n";
code += "}\n";
let out_path = out_dir.join("frontend.rs");
fs::write(&out_path, code).unwrap_or_else(|e| {
panic!("Failed to write generated frontend.rs: {}", e)
});
}
fn main() {
let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let workspace_root = Path::new(&manifest_dir)
.parent()
.unwrap()
.parent()
.unwrap();
let dist_dir = workspace_root.join("dist");
if !dist_dir.exists() {
println!("cargo:warning=dist/ directory not found — frontend assets will not be embedded");
let out_dir = env::var("OUT_DIR").unwrap();
let out_path = Path::new(&out_dir).join("frontend.rs");
fs::write(
&out_path,
"//! No dist/ directory found — frontend assets not embedded.\n\
pub fn get_frontend_asset_with_etag(_path: &str) -> Option<(&'static [u8], &'static str)> { None }\n\
pub fn get_frontend_asset_compressed(_path: &str) -> Option<(&'static [u8], &'static str, &'static str)> { None }\n",
).unwrap();
return;
}
println!("cargo:rerun-if-changed=dist/");
let assets = collect_assets(&dist_dir);
println!("cargo:warning=Collected {} frontend assets from dist/", assets.len());
let out_dir = env::var("OUT_DIR").unwrap();
generate_frontend_module(&assets, Path::new(&out_dir));
}

View File

@ -0,0 +1,180 @@
use actix_web::{web, HttpResponse, Result};
use session::Session;
use service::error::AppError;
use uuid::Uuid;
use crate::error::ApiError;
use crate::ApiResponse;
use super::types::{ConversationListQuery, ConversationResponse, CreateConversationParams};
fn get_user_id(session: &Session) -> Result<Uuid, ApiError> {
session.user().ok_or_else(|| ApiError::from(AppError::Unauthorized))
}
#[utoipa::path(
post,
path = "/api/ai/conversations",
operation_id = "ai_conversation_create",
request_body = CreateConversationParams,
responses(
(status = 200, description = "Conversation created", body = ApiResponse<ConversationResponse>),
(status = 401, description = "Unauthorized"),
),
tag = "AI Chat"
)]
pub async fn conversation_create(
service: web::Data<service::AppService>,
session: Session,
params: web::Json<CreateConversationParams>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let model = params.model.clone().unwrap_or_else(|| "gpt-4".to_string());
let conversation = service
.create_conversation(
user_id,
params.project_id,
params.title.clone(),
model,
params.model_config.clone(),
params.access_visibility.clone(),
params.can_ask.clone(),
params.model_uid,
params.model_name.clone(),
)
.await?;
let resp = ConversationResponse::from(conversation);
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
get,
path = "/api/ai/conversations",
operation_id = "ai_conversation_list",
params(
("project_id" = Option<Uuid>, Query, description = "Filter by project"),
),
responses(
(status = 200, description = "List of conversations", body = ApiResponse<Vec<ConversationResponse>>),
(status = 401, description = "Unauthorized"),
),
tag = "AI Chat"
)]
pub async fn conversation_list(
service: web::Data<service::AppService>,
session: Session,
query: web::Query<ConversationListQuery>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let convs = service
.list_conversations(user_id, query.project_id, 50)
.await?;
let resp: Vec<ConversationResponse> = convs
.into_iter()
.map(ConversationResponse::from)
.collect();
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}",
operation_id = "ai_conversation_get",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
),
responses(
(status = 200, description = "Get conversation", body = ApiResponse<ConversationResponse>),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn conversation_get(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let conversation_id = path.into_inner();
let c = service
.find_conversation_owned(conversation_id, user_id)
.await?;
let resp = ConversationResponse::from(c);
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
patch,
path = "/api/ai/conversations/{conversation_id}",
operation_id = "ai_conversation_update",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
),
request_body = super::types::UpdateConversationParams,
responses(
(status = 200, description = "Conversation updated"),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn conversation_update(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<Uuid>,
params: web::Json<super::types::UpdateConversationParams>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let conversation_id = path.into_inner();
service
.update_conversation(
conversation_id,
user_id,
params.title.clone(),
params.model.clone(),
params.model_config.clone(),
params.status.clone(),
params.access_visibility.clone(),
params.can_ask.clone(),
params.model_uid,
params.model_name.clone(),
)
.await?;
Ok(crate::api_success())
}
#[utoipa::path(
delete,
path = "/api/ai/conversations/{conversation_id}",
operation_id = "ai_conversation_delete",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
),
responses(
(status = 200, description = "Conversation deleted"),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn conversation_delete(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let conversation_id = path.into_inner();
service
.delete_conversation(conversation_id, user_id)
.await?;
Ok(crate::api_success())
}

View File

@ -0,0 +1,102 @@
use actix_web::{web, HttpResponse, Result};
use session::Session;
use service::error::AppError;
use uuid::Uuid;
use crate::error::ApiError;
use crate::ApiResponse;
#[derive(Debug, serde::Serialize, utoipa::ToSchema)]
pub struct ForkResponse {
pub id: Uuid,
pub conversation_id: Option<Uuid>,
pub source_message_id: Uuid,
pub fork_message_id: Uuid,
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
pub created_at: chrono::DateTime<chrono::Utc>,
}
#[utoipa::path(
post,
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/fork/{target_message_id}",
operation_id = "ai_message_fork",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("message_id" = Uuid, Path, description = "Source message ID"),
("target_message_id" = Uuid, Path, description = "Target/fork message ID to create"),
),
responses(
(status = 200, description = "Fork created", body = ApiResponse<ForkResponse>),
),
tag = "AI Chat"
)]
pub async fn message_fork(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<(Uuid, Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
let user_id = session
.user()
.ok_or_else(|| ApiError::from(AppError::Unauthorized))?;
let (conversation_id, source_message_id, target_message_id) = path.into_inner();
let fork_record = service
.fork_message(
conversation_id,
user_id,
source_message_id,
target_message_id,
)
.await?;
let resp = ForkResponse {
id: fork_record.id,
conversation_id: fork_record.conversation_id,
source_message_id: fork_record.source_message_id,
fork_message_id: fork_record.fork_message_id,
created_at: fork_record.created_at,
};
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/forks",
operation_id = "ai_message_forks",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("message_id" = Uuid, Path, description = "Source message ID"),
),
responses(
(status = 200, description = "List forks from message", body = ApiResponse<Vec<ForkResponse>>),
),
tag = "AI Chat"
)]
pub async fn message_forks(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
let user_id = session
.user()
.ok_or_else(|| ApiError::from(AppError::Unauthorized))?;
let (conversation_id, source_message_id) = path.into_inner();
let forks = service
.list_forks(conversation_id, user_id, source_message_id)
.await?;
let resp: Vec<ForkResponse> = forks
.into_iter()
.map(|f| ForkResponse {
id: f.id,
conversation_id: f.conversation_id,
source_message_id: f.source_message_id,
fork_message_id: f.fork_message_id,
created_at: f.created_at,
})
.collect();
Ok(ApiResponse::ok(resp).to_response())
}

View File

@ -0,0 +1,346 @@
use crate::error::ApiError;
use crate::ApiResponse;
use actix_web::{web, HttpResponse, Result};
use session::Session;
use service::error::AppError;
use uuid::Uuid;
use super::types::{CreateMessageParams, EditMessageParams, MessageListQuery, MessageResponse};
fn get_user_id(session: &Session) -> Result<Uuid, ApiError> {
session.user().ok_or_else(|| ApiError::from(AppError::Unauthorized))
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}/messages",
operation_id = "ai_message_list",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("limit" = Option<i64>, Query, description = "Max messages"),
),
responses(
(status = 200, description = "List messages", body = ApiResponse<Vec<MessageResponse>>),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn message_list(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<Uuid>,
query: web::Query<MessageListQuery>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let conversation_id = path.into_inner();
let limit = query.limit.unwrap_or(50) as u64;
let msgs = service
.list_messages(conversation_id, user_id, limit)
.await?;
let resp: Vec<MessageResponse> = msgs
.into_iter()
.map(MessageResponse::from)
.collect();
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
post,
path = "/api/ai/conversations/{conversation_id}/messages",
operation_id = "ai_message_create",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
),
request_body = CreateMessageParams,
responses(
(status = 200, description = "Message created", body = ApiResponse<MessageResponse>),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn message_create(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<Uuid>,
params: web::Json<CreateMessageParams>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let conversation_id = path.into_inner();
let msg = service
.create_message(
conversation_id,
user_id,
params.parent_message_id,
params.content.role.clone(),
params.content.content.clone(),
params.model.clone(),
params.is_fork_origin.unwrap_or(false),
params.metadata.clone(),
params.room_id,
)
.await?;
let resp = MessageResponse::from(msg);
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}",
operation_id = "ai_message_get",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("message_id" = Uuid, Path, description = "Message ID"),
),
responses(
(status = 200, description = "Get message", body = ApiResponse<MessageResponse>),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn message_get(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let (conversation_id, message_id) = path.into_inner();
let msg = service
.get_message(conversation_id, user_id, message_id)
.await?;
let resp = MessageResponse::from(msg);
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
post,
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/stop",
operation_id = "ai_message_stop",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("message_id" = Uuid, Path, description = "Message ID"),
),
responses(
(status = 200, description = "Message stopped"),
),
tag = "AI Chat"
)]
pub async fn message_stop(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let (conversation_id, message_id) = path.into_inner();
service
.stop_message(conversation_id, user_id, message_id)
.await?;
Ok(crate::api_success())
}
#[utoipa::path(
post,
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/resend",
operation_id = "ai_message_resend",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("message_id" = Uuid, Path, description = "Message ID"),
),
responses(
(status = 200, description = "Resend message", body = ApiResponse<MessageResponse>),
),
tag = "AI Chat"
)]
pub async fn message_resend(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let (conversation_id, message_id) = path.into_inner();
let new_msg = service
.resend_message(conversation_id, user_id, message_id)
.await?;
let resp = MessageResponse::from(new_msg);
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/children",
operation_id = "ai_message_children",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("message_id" = Uuid, Path, description = "Parent message ID"),
),
responses(
(status = 200, description = "List child messages", body = ApiResponse<Vec<MessageResponse>>),
),
tag = "AI Chat"
)]
pub async fn message_children(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let (conversation_id, parent_message_id) = path.into_inner();
let msgs = service
.list_child_messages(conversation_id, user_id, parent_message_id)
.await?;
let resp: Vec<MessageResponse> = msgs
.into_iter()
.map(MessageResponse::from)
.collect();
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/stream",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("message_id" = Uuid, Path, description = "Message ID"),
),
responses(
(status = 200, description = "SSE stream"),
),
tag = "AI Chat"
)]
pub async fn message_stream(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let (conversation_id, message_id) = path.into_inner();
// Verify user owns the conversation
let conv = service
.find_conversation_owned(conversation_id, user_id)
.await?;
let model = conv.model;
let response = actix_web::HttpResponse::Ok()
.content_type("text/event-stream")
.insert_header(("Cache-Control", "no-cache"))
.insert_header(("X-Accel-Buffering", "no"))
.streaming(super::super::stream::create_chat_sse_stream(
service.get_ref().clone(),
conversation_id,
message_id,
model,
user_id,
));
Ok(response.into())
}
#[utoipa::path(
post,
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/edit",
operation_id = "ai_message_edit",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("message_id" = Uuid, Path, description = "Message ID to edit"),
),
request_body = EditMessageParams,
responses(
(status = 200, description = "Message edited, new version created", body = ApiResponse<MessageResponse>),
),
tag = "AI Chat"
)]
pub async fn message_edit(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<(Uuid, Uuid)>,
params: web::Json<EditMessageParams>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let (conversation_id, message_id) = path.into_inner();
let new_msg = service
.edit_message(conversation_id, user_id, message_id, params.content.clone())
.await?;
let resp = MessageResponse::from(new_msg);
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/versions",
operation_id = "ai_message_versions",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("message_id" = Uuid, Path, description = "Message ID"),
),
responses(
(status = 200, description = "List message versions", body = ApiResponse<Vec<MessageResponse>>),
),
tag = "AI Chat"
)]
pub async fn message_versions(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<(Uuid, Uuid)>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let (conversation_id, message_id) = path.into_inner();
let versions = service
.list_message_versions(conversation_id, user_id, message_id)
.await?;
let resp: Vec<MessageResponse> = versions
.into_iter()
.map(MessageResponse::from)
.collect();
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
post,
path = "/api/ai/conversations/{conversation_id}/messages/{message_id}/switch-version",
operation_id = "ai_message_switch_version",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("message_id" = Uuid, Path, description = "Message ID"),
),
request_body = super::types::SwitchVersionParams,
responses(
(status = 200, description = "Version switched", body = ApiResponse<MessageResponse>),
),
tag = "AI Chat"
)]
pub async fn message_switch_version(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<(Uuid, Uuid)>,
params: web::Json<super::types::SwitchVersionParams>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let (conversation_id, message_id) = path.into_inner();
let msg = service
.switch_message_version(conversation_id, user_id, message_id, params.version_number)
.await?;
let resp = MessageResponse::from(msg);
Ok(ApiResponse::ok(resp).to_response())
}

View File

@ -0,0 +1,5 @@
pub mod conversation;
pub mod fork;
pub mod message;
pub mod share;
pub mod types;

View File

@ -0,0 +1,76 @@
use actix_web::{web, HttpResponse, Result};
use session::Session;
use service::error::AppError;
use uuid::Uuid;
use crate::error::ApiError;
use crate::ApiResponse;
use super::types::{ConversationResponse, ShareResponse};
fn get_user_id(session: &Session) -> Result<Uuid, ApiError> {
session.user().ok_or_else(|| ApiError::from(AppError::Unauthorized))
}
#[utoipa::path(
post,
path = "/api/ai/conversations/{conversation_id}/share",
operation_id = "ai_conversation_share",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
),
responses(
(status = 200, description = "Share token created", body = ApiResponse<ShareResponse>),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn conversation_share(
service: web::Data<service::AppService>,
session: Session,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = get_user_id(&session)?;
let conversation_id = path.into_inner();
let (share, share_token) = service
.share_conversation(conversation_id, user_id)
.await?;
let resp = ShareResponse {
id: share.id,
share_token,
view_count: share.view_count,
expires_at: share.expires_at,
};
Ok(ApiResponse::ok(resp).to_response())
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}/share/{share_token}",
operation_id = "ai_shared_conversation_get",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
("share_token" = String, Path, description = "Share token"),
),
responses(
(status = 200, description = "Get shared conversation", body = ApiResponse<ConversationResponse>),
(status = 404, description = "Not found or expired"),
),
tag = "AI Chat"
)]
pub async fn shared_conversation_get(
service: web::Data<service::AppService>,
path: web::Path<(Uuid, String)>,
) -> Result<HttpResponse, ApiError> {
let (conversation_id, share_token) = path.into_inner();
let c = service
.get_shared_conversation(conversation_id, share_token)
.await?;
let resp = ConversationResponse::from(c);
Ok(ApiResponse::ok(resp).to_response())
}

View File

@ -0,0 +1,175 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateConversationParams {
pub project_id: Option<Uuid>,
pub title: Option<String>,
pub model: Option<String>,
pub model_config: Option<serde_json::Value>,
pub access_visibility: Option<String>,
pub can_ask: Option<String>,
/// AI model UUID for model selection
pub model_uid: Option<Uuid>,
/// AI model display name
pub model_name: Option<String>,
}
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ConversationResponse {
pub id: Uuid,
pub user_id: Uuid,
pub project_id: Option<Uuid>,
pub scope: String,
pub title: Option<String>,
pub model: String,
pub model_config: Option<serde_json::Value>,
pub status: String,
pub root_message_id: Option<Uuid>,
pub fork_count: i32,
pub is_shared: bool,
pub message_count: i32,
pub token_usage_total: Option<i32>,
pub access_visibility: String,
pub can_ask: String,
pub project_uid: Option<i32>,
pub model_uid: Option<Uuid>,
pub model_name: Option<String>,
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
pub created_at: chrono::DateTime<chrono::Utc>,
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
pub updated_at: chrono::DateTime<chrono::Utc>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateConversationParams {
pub title: Option<String>,
pub model: Option<String>,
pub model_config: Option<serde_json::Value>,
pub status: Option<String>,
pub access_visibility: Option<String>,
pub can_ask: Option<String>,
pub model_uid: Option<Uuid>,
pub model_name: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct ConversationListQuery {
pub project_id: Option<Uuid>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct MessageContent {
pub role: String,
pub content: serde_json::Value,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateMessageParams {
pub parent_message_id: Option<Uuid>,
pub content: MessageContent,
pub model: Option<String>,
pub is_fork_origin: Option<bool>,
pub metadata: Option<serde_json::Value>,
pub room_id: Option<Uuid>,
}
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct MessageResponse {
pub id: Uuid,
pub conversation_id: Uuid,
pub parent_message_id: Option<Uuid>,
pub role: String,
pub content: serde_json::Value,
pub model: Option<String>,
pub is_fork_origin: bool,
pub stop_reason: Option<String>,
pub input_tokens: Option<i32>,
pub output_tokens: Option<i32>,
pub latency_ms: Option<i32>,
pub metadata: Option<serde_json::Value>,
pub room_id: Option<Uuid>,
pub version_group_id: Option<Uuid>,
pub version_number: i32,
pub is_latest: bool,
#[schema(value_type = chrono::DateTime<chrono::Utc>)]
pub created_at: chrono::DateTime<chrono::Utc>,
}
#[derive(Debug, Deserialize)]
pub struct MessageListQuery {
pub limit: Option<i64>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct EditMessageParams {
pub content: String,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct SwitchVersionParams {
pub version_number: i32,
}
#[derive(Debug, Deserialize)]
pub struct ForkParams {}
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ShareResponse {
pub id: Uuid,
pub share_token: String,
pub view_count: i32,
#[schema(value_type = Option<chrono::DateTime<chrono::Utc>>)]
pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
}
impl From<models::ai::ai_conversation::Model> for ConversationResponse {
fn from(c: models::ai::ai_conversation::Model) -> Self {
Self {
id: c.id,
user_id: c.user_id,
project_id: c.project_id,
scope: c.scope,
title: c.title,
model: c.model,
model_config: c.model_config,
status: c.status,
root_message_id: c.root_message_id,
fork_count: c.fork_count,
is_shared: c.is_shared,
message_count: c.message_count,
token_usage_total: c.token_usage_total,
access_visibility: c.access_visibility,
can_ask: c.can_ask,
project_uid: c.project_uid,
model_uid: c.model_uid,
model_name: c.model_name,
created_at: c.created_at,
updated_at: c.updated_at,
}
}
}
impl From<models::ai::ai_message::Model> for MessageResponse {
fn from(m: models::ai::ai_message::Model) -> Self {
Self {
id: m.id,
conversation_id: m.conversation_id,
parent_message_id: m.parent_message_id,
role: m.role,
content: m.content,
model: m.model,
is_fork_origin: m.is_fork_origin,
stop_reason: m.stop_reason,
input_tokens: m.input_tokens,
output_tokens: m.output_tokens,
latency_ms: m.latency_ms,
metadata: m.metadata,
room_id: m.room_id,
version_group_id: m.version_group_id,
version_number: m.version_number,
is_latest: m.is_latest,
created_at: m.created_at,
}
}
}

85
libs/api/chat/mod.rs Normal file
View File

@ -0,0 +1,85 @@
use actix_web::web;
pub mod handlers;
pub mod stream;
pub mod watch;
pub fn init_chat_routes(cfg: &mut web::ServiceConfig) {
cfg.service(
web::scope("/ai/conversations")
.route("", web::post().to(handlers::conversation::conversation_create))
.route("", web::get().to(handlers::conversation::conversation_list))
.route(
"/{conversation_id}",
web::get().to(handlers::conversation::conversation_get),
)
.route(
"/{conversation_id}",
web::patch().to(handlers::conversation::conversation_update),
)
.route(
"/{conversation_id}",
web::delete().to(handlers::conversation::conversation_delete),
)
.route(
"/{conversation_id}/watch",
web::get().to(watch::conversation_watch),
)
.route(
"/{conversation_id}/share",
web::post().to(handlers::share::conversation_share),
)
.route(
"/{conversation_id}/share/{share_token}",
web::get().to(handlers::share::shared_conversation_get),
)
.route(
"/{conversation_id}/messages",
web::get().to(handlers::message::message_list),
)
.route(
"/{conversation_id}/messages",
web::post().to(handlers::message::message_create),
)
.route(
"/{conversation_id}/messages/{message_id}",
web::get().to(handlers::message::message_get),
)
.route(
"/{conversation_id}/messages/{message_id}/stop",
web::post().to(handlers::message::message_stop),
)
.route(
"/{conversation_id}/messages/{message_id}/resend",
web::post().to(handlers::message::message_resend),
)
.route(
"/{conversation_id}/messages/{message_id}/fork/{target_message_id}",
web::post().to(handlers::fork::message_fork),
)
.route(
"/{conversation_id}/messages/{message_id}/forks",
web::get().to(handlers::fork::message_forks),
)
.route(
"/{conversation_id}/messages/{message_id}/stream",
web::get().to(handlers::message::message_stream),
)
.route(
"/{conversation_id}/messages/{message_id}/children",
web::get().to(handlers::message::message_children),
)
.route(
"/{conversation_id}/messages/{message_id}/edit",
web::post().to(handlers::message::message_edit),
)
.route(
"/{conversation_id}/messages/{message_id}/versions",
web::get().to(handlers::message::message_versions),
)
.route(
"/{conversation_id}/messages/{message_id}/switch-version",
web::post().to(handlers::message::message_switch_version),
),
);
}

463
libs/api/chat/stream.rs Normal file
View File

@ -0,0 +1,463 @@
use agent::chat::chat_execution;
use agent::chat::{normalize_thinking_content, AiChunkType, AiStreamChunk};
use agent::client::AiClientConfig;
use agent::client::types::ChatRequestMessage;
use agent::client::StreamChunkType;
use futures::StreamExt;
use models::ai::{ai_message, ai_conversation, AiMessage};
use queue::{ChatMessageEvent, ChatStreamChunkEvent};
use sea_orm::{EntityTrait, QueryFilter, ColumnTrait, QueryOrder, ActiveModelTrait, Set, PaginatorTrait};
use service::AppService;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;
/// Create an SSE stream that executes AI chat with ReAct tool-calling.
///
/// Also publishes chat messages and stream chunks via NATS JetStream for
/// multi-viewer support. The requesting client receives SSE events, while
/// other viewers receive chunks via NATS → WebSocket broadcast.
pub fn create_chat_sse_stream(
service: AppService,
conversation_id: Uuid,
user_message_id: Uuid,
model_name: String,
user_id: Uuid,
) -> Pin<Box<dyn futures::Stream<Item = Result<actix_web::web::Bytes, actix_web::Error>> + Send>> {
let (tx, rx) = tokio::sync::mpsc::channel::<String>(100);
let cache = service.cache.clone();
tokio::spawn(async move {
// Check for active stream (SSE reconnect recovery) BEFORE starting a new one
// so the frontend can recover from a page refresh.
if let Some((msg_id, started_at)) = cache.get_chat_stream_active(conversation_id).await {
let _ = tx.send(format!(
"data: {{\"event\":\"recovery\",\"data\":{{\"message_id\":\"{}\",\"started_at\":{}}}}}\n\n",
msg_id,
started_at
)).await;
}
let queue = service.queue_producer.clone();
let chunk_seq = Arc::new(AtomicU64::new(0));
// Build messages from conversation history
let messages = match build_messages_from_history(&service, conversation_id).await {
Ok(msgs) => msgs,
Err(e) => {
let _ = tx.send(format!("data: {{\"event\":\"error\",\"data\":\"{}\"}}\n\n", e)).await;
return;
}
};
// Get AI config
let api_key = match service.config.ai_api_key() {
Ok(k) => k,
Err(_) => {
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"AI not configured\"}\n\n".to_string()).await;
return;
}
};
let base_url = match service.config.ai_basic_url() {
Ok(u) => u,
Err(_) => {
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"AI not configured\"}\n\n".to_string()).await;
return;
}
};
let config = AiClientConfig::new(api_key).with_base_url(&base_url);
// Get tools from ChatService if available
let (tools, tool_registry, embed_service) = match &service.chat_service {
Some(cs) => (
cs.tools(),
cs.tool_registry().cloned(),
service.embed_service.as_ref().map(|es| (**es).clone()),
),
None => (Vec::new(), None, None),
};
// Get project_id from conversation
let project_id = match service.find_conversation(conversation_id).await {
Ok(c) => c.project_id.unwrap_or(Uuid::nil()),
Err(_) => {
let _ = tx.send("data: {\"event\":\"error\",\"data\":\"conversation not found\"}\n\n".to_string()).await;
return;
}
};
// Pre-flight balance check: verify project + user can afford at least a minimal AI call
let balance_ok = agent::billing::check_balance(
&service.db, project_id, user_id, Uuid::nil(), 500, 250,
).await;
match balance_ok {
Ok(true) => {},
Ok(false) => {
tracing::warn!(project_id = %project_id, user_id = %user_id, "Insufficient balance for chat AI call");
let _ = agent::billing::persist_billing_error(
&service.db, "user", user_id, "insufficient_balance",
&format!("Insufficient balance. Your account does not have enough funds for this AI request."),
Some(serde_json::json!({
"user_id": user_id.to_string(),
"project_id": project_id.to_string(),
})),
).await;
let error_msg = "Insufficient balance. Your account does not have enough funds to process this AI request. Please add credits to continue.";
let _ = tx.send(format!("data: {{\"event\":\"billing_error\",\"data\":\"{}\"}}\n\n", error_msg)).await;
let _ = tx.send("data: {\"event\":\"done\",\"data\":\"billing_error\"}\n\n".to_string()).await;
return;
},
Err(e) => {
tracing::warn!(error = %e, "Balance check failed, proceeding without pre-flight check");
}
}
let max_tool_depth = 99;
// Determine conversation project_id for chat message event
let conv_project_id = match service.find_conversation(conversation_id).await {
Ok(c) => c.project_id,
Err(_) => None,
};
// Broadcast chat message start event via NATS
let chat_msg = ChatMessageEvent {
message_id: user_message_id,
conversation_id,
project_id: conv_project_id,
sender_id: Uuid::nil(),
role: "assistant".to_string(),
content: String::new(),
model: Some(model_name.clone()),
input_tokens: None,
output_tokens: None,
timestamp: chrono::Utc::now(),
};
let _ = queue.publish_chat_message(&chat_msg).await;
// Mark stream as active in Redis so page refresh can recover
let _ = cache.set_chat_stream_active(conversation_id, user_message_id).await;
let on_chunk_tx = tx.clone();
let on_chunk_queue = queue.clone();
let on_chunk_seq = chunk_seq.clone();
let on_chunk_conv_id = conversation_id;
let on_chunk_msg_id = user_message_id;
let on_chunk_model = model_name.clone();
let on_chunk: agent::chat::StreamCallback = Box::new(move |chunk: AiStreamChunk| {
let tx = on_chunk_tx.clone();
let queue = on_chunk_queue.clone();
let seq = on_chunk_seq.fetch_add(1, Ordering::Relaxed);
let conv_id = on_chunk_conv_id;
let msg_id = on_chunk_msg_id;
let model = on_chunk_model.clone();
Box::pin(async move {
let event = match chunk.chunk_type {
AiChunkType::Thinking => "thinking",
AiChunkType::Answer => "token",
AiChunkType::ToolCall => "tool_call",
AiChunkType::ToolResult => "tool_result",
};
let content = match chunk.chunk_type {
AiChunkType::Thinking => normalize_thinking_content(&chunk.content),
_ => chunk.content.clone(),
};
let sse = format!(
"data: {{\"event\":\"{}\",\"data\":{}}}\n\n",
event,
serde_json::to_string(&content).unwrap_or_default()
);
let _ = tx.send(sse).await;
// Also broadcast via NATS for other viewers
let natts_chunk = ChatStreamChunkEvent {
conversation_id: conv_id,
message_id: msg_id,
seq,
content,
done: false,
error: None,
chunk_type: Some(event.to_string()),
model_name: Some(model),
};
queue.publish_chat_chunk(&natts_chunk).await;
}) as Pin<Box<dyn std::future::Future<Output = ()> + Send>>
});
let result = chat_execution::execute_chat_stream(
messages,
tools,
&model_name,
&config,
0.7, // temperature
4096, // max_tokens
max_tool_depth,
tool_registry.as_ref(),
service.db.clone(),
service.cache.clone(),
service.config.clone(),
project_id,
Uuid::nil(), // sender_uid — unknown in Chat API context
embed_service,
on_chunk,
Some(conversation_id),
).await;
// Clear stream active state (streaming finished)
let _ = cache.clear_chat_stream_active(conversation_id).await;
match result {
Ok(stream_result) => {
// Build ordered content blocks from stream chunks, merging
// consecutive blocks of the same role (thinking/assistant).
let raw_blocks: Vec<(String, String)> = stream_result.chunks.iter()
.filter(|c| matches!(c.chunk_type, StreamChunkType::Thinking | StreamChunkType::Answer))
.map(|chunk| {
let role = match chunk.chunk_type {
StreamChunkType::Thinking => "thinking",
_ => "assistant",
};
(role.to_string(), chunk.content.clone())
})
.collect();
let merged_blocks = merge_consecutive_blocks(raw_blocks);
// Apply thinking normalization to the fully merged thinking
// blocks — per-token normalization is meaningless since each
// chunk is a single token.
let normalized_blocks: Vec<(String, String)> = merged_blocks.into_iter().map(|(role, content)| {
if role == "thinking" {
(role, normalize_thinking_content(&content))
} else {
(role, content)
}
}).collect();
let content_blocks: Vec<serde_json::Value> = normalized_blocks.iter()
.map(|(role, content)| serde_json::json!({ "role": role, "content": content }))
.collect();
let content_value = if content_blocks.is_empty() {
serde_json::json!([{ "role": "assistant", "content": stream_result.content }])
} else {
serde_json::json!(content_blocks)
};
// Persist assistant message
let assistant_msg_id = Uuid::now_v7();
let assistant_msg = ai_message::ActiveModel {
id: Set(assistant_msg_id),
conversation_id: Set(conversation_id),
parent_message_id: Set(Some(user_message_id)),
role: Set("assistant".to_string()),
content: Set(content_value),
model: Set(Some(model_name.clone())),
is_fork_origin: Set(false),
stop_reason: Set(Some("stop".to_string())),
input_tokens: Set(Some(stream_result.input_tokens as i32)),
output_tokens: Set(Some(stream_result.output_tokens as i32)),
latency_ms: Set(None),
metadata: Set(None),
room_id: Set(None),
version_group_id: Set(Some(assistant_msg_id)),
version_number: Set(1),
is_latest: Set(true),
created_at: Set(chrono::Utc::now()),
};
let saved = assistant_msg.insert(service.db.writer()).await;
if let Ok(msg) = &saved {
update_conversation_after_response(&service, conversation_id, msg).await;
// After AI response, check/update conversation title and emit via SSE
if let Ok(Some(conv)) = ai_conversation::Entity::find_by_id(conversation_id)
.one(service.db.reader()).await
{
let existing_title = conv.title.clone();
let needs_title = existing_title.as_deref().map(|t| t.is_empty() || t == "New Chat").unwrap_or(true);
if needs_title {
// Generate title from first user message
let first_user_msg = AiMessage::find()
.filter(ai_message::Column::ConversationId.eq(conversation_id))
.filter(ai_message::Column::Role.eq("user"))
.order_by_asc(ai_message::Column::CreatedAt)
.one(service.db.reader()).await.ok().flatten();
if let Some(user_msg) = first_user_msg {
let content = match &user_msg.content {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => {
arr.first()
.and_then(|f| f.get("content"))
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string()
}
other => other.to_string(),
};
// Simple title extraction: first meaningful words
let title = content
.split_whitespace()
.filter(|w| w.len() > 2)
.take(5)
.collect::<Vec<_>>()
.join(" ");
if !title.is_empty() {
let truncated: String = title.chars().take(40).collect();
// Save title to DB
let mut active: ai_conversation::ActiveModel = conv.into();
active.title = Set(Some(truncated.clone()));
active.updated_at = Set(chrono::Utc::now());
let _ = active.update(service.db.writer()).await;
// Emit title via SSE
let title_payload = serde_json::json!({"title": truncated}).to_string();
let _ = tx.send(format!("data: {{\"event\":\"title\",\"data\":{}}}\n\n", title_payload)).await;
}
}
} else if let Some(title) = &existing_title {
// Title already set (e.g. by AI tool) — emit it
let title_payload = serde_json::json!({"title": title}).to_string();
let _ = tx.send(format!("data: {{\"event\":\"title\",\"data\":{}}}\n\n", title_payload)).await;
}
}
}
// Broadcast final chat message with token usage
let final_msg = ChatMessageEvent {
message_id: user_message_id,
conversation_id,
project_id: conv_project_id,
sender_id: Uuid::nil(),
role: "assistant".to_string(),
content: stream_result.content.clone(),
model: Some(model_name.clone()),
input_tokens: Some(stream_result.input_tokens as i32),
output_tokens: Some(stream_result.output_tokens as i32),
timestamp: chrono::Utc::now(),
};
let _ = queue.publish_chat_message(&final_msg).await;
// Send final SSE done event
let _ = tx.send("data: {\"event\":\"done\",\"data\":\"ok\"}\n\n".to_string()).await;
}
Err(e) => {
let _ = tx.send(format!("data: {{\"event\":\"error\",\"data\":\"{}\"}}\n\n", e)).await;
}
}
});
Box::pin(ReceiverStream::new(rx).map(|msg| Ok(actix_web::web::Bytes::from(msg))))
}
/// Update conversation metadata after an AI assistant message is saved.
async fn update_conversation_after_response(
service: &AppService,
conversation_id: Uuid,
assistant_msg: &ai_message::Model,
) {
use models::ai::ai_conversation;
use sea_orm::EntityTrait;
if let Ok(Some(conv)) = ai_conversation::Entity::find_by_id(conversation_id)
.one(service.db.reader()).await
{
let input_tokens = assistant_msg.input_tokens.unwrap_or(0) as i64;
let output_tokens = assistant_msg.output_tokens.unwrap_or(0) as i64;
let total_tokens = input_tokens + output_tokens;
let mut active: ai_conversation::ActiveModel = conv.into();
if let Ok(count) = AiMessage::find()
.filter(ai_message::Column::ConversationId.eq(conversation_id))
.count(service.db.reader()).await
{
active.message_count = Set(count as i32);
}
active.token_usage_total = Set(Some(total_tokens as i32));
active.updated_at = Set(chrono::Utc::now());
let _ = active.update(service.db.writer()).await;
}
}
/// Build ChatRequestMessage list from ai_message conversation history.
async fn build_messages_from_history(
service: &AppService,
conversation_id: Uuid,
) -> Result<Vec<ChatRequestMessage>, String> {
let msgs = AiMessage::find()
.filter(ai_message::Column::ConversationId.eq(conversation_id))
.filter(ai_message::Column::IsLatest.eq(true))
.order_by_asc(ai_message::Column::CreatedAt)
.all(service.db.reader())
.await
.map_err(|e| format!("db error: {}", e))?;
let mut chat_messages = Vec::new();
for msg in &msgs {
let role = msg.role.as_str();
let content = match &msg.content {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => {
// Content is ordered blocks: [{role:"thinking",content:"..."}, {role:"assistant","content":"..."}, ...]
// For assistant messages: concatenate all "assistant" blocks
// For user/system messages: take the first block's content
if role == "assistant" {
arr.iter()
.filter(|item| item.get("role").and_then(|r| r.as_str()) != Some("thinking"))
.filter_map(|item| item.get("content").and_then(|c| c.as_str()))
.collect::<Vec<_>>()
.join("\n")
} else if let Some(first) = arr.first() {
first.get("content")
.and_then(|c| c.as_str())
.unwrap_or("")
.to_string()
} else {
String::new()
}
}
other => other.to_string(),
};
match role {
"user" => chat_messages.push(ChatRequestMessage::user(content)),
"assistant" => chat_messages.push(ChatRequestMessage::assistant(Some(content), None)),
"system" => chat_messages.push(ChatRequestMessage::system(content)),
_ => chat_messages.push(ChatRequestMessage::user(content)),
}
}
Ok(chat_messages)
}
/// Merge consecutive content blocks of the same role into single blocks.
/// This transforms many small per-chunk blocks into clean interleaved segments:
/// [thinking, thinking, assistant, assistant] → [thinking, assistant]
/// Per-token chunks are concatenated directly — the model sends \n inside
/// the token content where needed, not between tokens.
fn merge_consecutive_blocks(blocks: Vec<(String, String)>) -> Vec<(String, String)> {
let mut merged: Vec<(String, String)> = Vec::new();
for (role, content) in blocks {
if content.is_empty() { continue; }
if let Some(last) = merged.last_mut() {
if last.0 == role {
last.1.push_str(&content);
continue;
}
}
merged.push((role, content));
}
merged
}

158
libs/api/chat/watch.rs Normal file
View File

@ -0,0 +1,158 @@
//! SSE endpoint for watching a chat conversation in real-time via NATS.
//!
//! Unlike the primary SSE stream (which triggers AI execution), this endpoint
//! passively subscribes to NATS Core subjects and forwards chat messages and
//! stream chunks to connected clients. This enables multiple viewers to watch
//! the same AI conversation in real-time.
use actix_web::{web, HttpResponse, Result};
use futures::StreamExt;
use service::AppService;
use std::pin::Pin;
use uuid::Uuid;
use crate::error::ApiError;
/// SSE endpoint for watching a chat conversation.
///
/// `GET /api/ai/conversations/{conversation_id}/watch`
///
/// Subscribes to NATS Core subjects (`chat.chunk.{id}` and `chat.message.{id}`)
/// and forwards received events as SSE to the connected client.
///
/// SSE events:
/// - `chunk` — a stream chunk (thinking, token, tool_call, tool_result, done, error)
/// - `message` — a complete chat message
/// - `error` — an error event
pub fn create_watch_sse_stream(
service: AppService,
conversation_id: Uuid,
) -> Pin<Box<dyn futures::Stream<Item = Result<actix_web::web::Bytes, actix_web::Error>> + Send>> {
let (tx, rx) = tokio::sync::mpsc::channel::<String>(200);
tokio::spawn(async move {
let nats = match &service.queue_producer.nats {
Some(n) => n.clone(),
None => {
let _ = tx.send(format!(
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
serde_json::to_string("NATS not available").unwrap_or_default()
)).await;
return;
}
};
// Subscribe to chat chunks
let chunk_subject = format!("chat.chunk.{}", conversation_id);
let mut chunk_sub = match nats.subscribe(&chunk_subject).await {
Ok(s) => s,
Err(e) => {
let _ = tx.send(format!(
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
serde_json::to_string(&e.to_string()).unwrap_or_default()
)).await;
return;
}
};
// Subscribe to chat messages
let msg_subject = format!("chat.message.{}", conversation_id);
let mut msg_sub = match nats.subscribe(&msg_subject).await {
Ok(s) => s,
Err(e) => {
let _ = tx.send(format!(
"data: {{\"event\":\"error\",\"data\":{}}}\n\n",
serde_json::to_string(&e.to_string()).unwrap_or_default()
)).await;
return;
}
};
let _ = tx.send(":ok\n\n".to_string()).await;
loop {
tokio::select! {
chunk_msg = chunk_sub.next() => {
match chunk_msg {
Some(msg) => {
let payload = String::from_utf8_lossy(&msg.payload);
// Parse to get chunk_type for the event field
let event_type = if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&payload) {
parsed.get("chunk_type")
.and_then(|v| v.as_str())
.unwrap_or("chunk")
.to_string()
} else {
"chunk".to_string()
};
let sse = format!(
"data: {{\"event\":\"{}\",\"data\":{}}}\n\n",
event_type, payload
);
if tx.send(sse).await.is_err() {
break;
}
}
None => break,
}
}
msg = msg_sub.next() => {
match msg {
Some(msg) => {
let payload = String::from_utf8_lossy(&msg.payload);
let sse = format!(
"data: {{\"event\":\"message\",\"data\":{}}}\n\n",
payload
);
if tx.send(sse).await.is_err() {
break;
}
}
None => break,
}
}
}
}
});
Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx).map(|s| {
Ok(actix_web::web::Bytes::from(s))
}))
}
#[utoipa::path(
get,
path = "/api/ai/conversations/{conversation_id}/watch",
params(
("conversation_id" = Uuid, Path, description = "Conversation ID"),
),
responses(
(status = 200, description = "SSE stream of conversation events"),
(status = 404, description = "Not found"),
),
tag = "AI Chat"
)]
pub async fn conversation_watch(
service: web::Data<AppService>,
session: session::Session,
path: web::Path<Uuid>,
) -> Result<HttpResponse, ApiError> {
let user_id = session.user().ok_or_else(|| ApiError::from(service::error::AppError::Unauthorized))?;
let conversation_id = path.into_inner();
// Verify access (view-only is sufficient)
let _conv = service
.find_conversation_owned(conversation_id, user_id)
.await?;
let response = HttpResponse::Ok()
.content_type("text/event-stream")
.insert_header(("Cache-Control", "no-cache"))
.insert_header(("X-Accel-Buffering", "no"))
.streaming(create_watch_sse_stream(
service.get_ref().clone(),
conversation_id,
));
Ok(response.into())
}

View File

@ -116,20 +116,30 @@ pub async fn serve_frontend(req: HttpRequest, path: web::Path<String>) -> HttpRe
};
let cc = cache_control_header(path_str);
// Try brotli first (best compression), then gzip, then uncompressed.
// Only serve compressed variant if client explicitly accepts it AND we have one.
let (data, encoding, etag, content_path) =
match frontend::get_frontend_asset_compressed(path_str) {
Some(r) => (r.0, r.1, r.2, path_str),
None => {
// Path not found — try index.html as SPA fallback.
// Also use "index.html" for Content-Type detection (text/html).
match frontend::get_frontend_asset_with_etag("index.html") {
Some((data, etag)) => (data, "", etag, "index.html"),
None => return HttpResponse::NotFound().finish(),
}
}
};
// Try brotli/gzip compressed variant first (best compression),
// then fall back to uncompressed if client doesn't accept the encoding.
let compressed = crate::frontend::get_frontend_asset_compressed(path_str);
let uncompressed = crate::frontend::get_frontend_asset_with_etag(path_str);
let (data, encoding, etag, content_path) = if let Some((c_data, c_enc, c_etag)) = compressed {
if accepts_encoding(&req, c_enc) {
(c_data, c_enc, c_etag, path_str)
} else if let Some((u_data, u_etag)) = uncompressed {
// Client doesn't accept the pre-compressed encoding — serve uncompressed.
(u_data, "", u_etag, path_str)
} else {
// No uncompressed fallback — still serve compressed (client must handle it).
(c_data, c_enc, c_etag, path_str)
}
} else if let Some((data, etag)) = uncompressed {
(data, "", etag, path_str)
} else {
// Path not found — try index.html as SPA fallback.
match crate::frontend::get_frontend_asset_with_etag("index.html") {
Some((data, etag)) => (data, "", etag, "index.html"),
None => return HttpResponse::NotFound().finish(),
}
};
if !encoding.is_empty() && accepts_encoding(&req, &encoding) {
build_asset_response(&req, data, etag, content_path, cc, &encoding)

Some files were not shown because too many files have changed in this diff Show More