Compare commits
No commits in common. "39d126d843037c77ec67f088cf7a11475e45f3c7" and "329b526bfb21b335df82cd8d771dd178bff234cc" have entirely different histories.
39d126d843
...
329b526bfb
@ -107,10 +107,6 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let service = AppService::new(cfg.clone()).await?;
|
let service = AppService::new(cfg.clone()).await?;
|
||||||
slog::info!(log, "AppService initialized");
|
slog::info!(log, "AppService initialized");
|
||||||
|
|
||||||
// Spawn background task: sync OpenRouter models immediately on startup,
|
|
||||||
// then every 10 minutes.
|
|
||||||
let _model_sync_handle = service.clone().start_sync_task();
|
|
||||||
|
|
||||||
let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
|
let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel::<()>(1);
|
||||||
let worker_service = service.clone();
|
let worker_service = service.clone();
|
||||||
let log_for_http = log.clone();
|
let log_for_http = log.clone();
|
||||||
|
|||||||
@ -2,6 +2,7 @@ use crate::ssh::ReceiveSyncService;
|
|||||||
use crate::ssh::RepoReceiveSyncTask;
|
use crate::ssh::RepoReceiveSyncTask;
|
||||||
use crate::ssh::SshTokenService;
|
use crate::ssh::SshTokenService;
|
||||||
use crate::ssh::authz::SshAuthService;
|
use crate::ssh::authz::SshAuthService;
|
||||||
|
use crate::ssh::rate_limit::SshRateLimiter;
|
||||||
use db::cache::AppCache;
|
use db::cache::AppCache;
|
||||||
use db::database::AppDatabase;
|
use db::database::AppDatabase;
|
||||||
use models::repos::{repo, repo_branch_protect};
|
use models::repos::{repo, repo_branch_protect};
|
||||||
@ -19,6 +20,7 @@ use std::net::SocketAddr;
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::process::Stdio;
|
use std::process::Stdio;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
|
||||||
use tokio::process::ChildStdin;
|
use tokio::process::ChildStdin;
|
||||||
@ -78,6 +80,7 @@ pub struct SSHandle {
|
|||||||
pub sync: ReceiveSyncService,
|
pub sync: ReceiveSyncService,
|
||||||
pub upload_pack_eof_sent: HashSet<ChannelId>,
|
pub upload_pack_eof_sent: HashSet<ChannelId>,
|
||||||
pub logger: Logger,
|
pub logger: Logger,
|
||||||
|
pub rate_limiter: Arc<SshRateLimiter>,
|
||||||
pub token_service: SshTokenService,
|
pub token_service: SshTokenService,
|
||||||
pub client_addr: Option<SocketAddr>,
|
pub client_addr: Option<SocketAddr>,
|
||||||
}
|
}
|
||||||
@ -88,6 +91,7 @@ impl SSHandle {
|
|||||||
cache: AppCache,
|
cache: AppCache,
|
||||||
sync: ReceiveSyncService,
|
sync: ReceiveSyncService,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
|
rate_limiter: Arc<SshRateLimiter>,
|
||||||
token_service: SshTokenService,
|
token_service: SshTokenService,
|
||||||
client_addr: Option<SocketAddr>,
|
client_addr: Option<SocketAddr>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
@ -111,6 +115,7 @@ impl SSHandle {
|
|||||||
sync,
|
sync,
|
||||||
upload_pack_eof_sent: HashSet::new(),
|
upload_pack_eof_sent: HashSet::new(),
|
||||||
logger,
|
logger,
|
||||||
|
rate_limiter,
|
||||||
token_service,
|
token_service,
|
||||||
client_addr,
|
client_addr,
|
||||||
}
|
}
|
||||||
@ -196,6 +201,17 @@ impl russh::server::Handler for SSHandle {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let user_id = user_model.uid.to_string();
|
||||||
|
if !self.rate_limiter.is_user_allowed(&user_id).await {
|
||||||
|
warn!(
|
||||||
|
self.logger,
|
||||||
|
"SSH token auth rate limit exceeded: {}, client: {}",
|
||||||
|
user_model.username,
|
||||||
|
client_info
|
||||||
|
);
|
||||||
|
return Err(russh::Error::NotAuthenticated);
|
||||||
|
}
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
self.logger,
|
self.logger,
|
||||||
"SSH token authentication successful: user={}, client={}",
|
"SSH token authentication successful: user={}, client={}",
|
||||||
@ -262,6 +278,16 @@ impl russh::server::Handler for SSHandle {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let user_id = user_model.uid.to_string();
|
||||||
|
if !self.rate_limiter.is_user_allowed(&user_id).await {
|
||||||
|
let msg = format!(
|
||||||
|
"User rate limit exceeded: {}, client: {}",
|
||||||
|
user_model.username, client_info
|
||||||
|
);
|
||||||
|
warn!(self.logger, "{}", msg);
|
||||||
|
return Err(russh::Error::NotAuthenticated);
|
||||||
|
}
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
self.logger,
|
self.logger,
|
||||||
"SSH authentication successful: user={}, client={}", user_model.username, client_info
|
"SSH authentication successful: user={}, client={}", user_model.username, client_info
|
||||||
@ -319,6 +345,16 @@ impl russh::server::Handler for SSHandle {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let user_id = user_model.uid.to_string();
|
||||||
|
if !self.rate_limiter.is_user_allowed(&user_id).await {
|
||||||
|
let msg = format!(
|
||||||
|
"User rate limit exceeded: {}, client: {}",
|
||||||
|
user_model.username, client_info
|
||||||
|
);
|
||||||
|
warn!(self.logger, "{}", msg);
|
||||||
|
return Err(russh::Error::NotAuthenticated);
|
||||||
|
}
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
self.logger,
|
self.logger,
|
||||||
"SSH authentication successful: user={}, client={}", user_model.username, client_info
|
"SSH authentication successful: user={}, client={}", user_model.username, client_info
|
||||||
@ -335,7 +371,10 @@ impl russh::server::Handler for SSHandle {
|
|||||||
channel: ChannelId,
|
channel: ChannelId,
|
||||||
_: &mut Session,
|
_: &mut Session,
|
||||||
) -> Result<(), Self::Error> {
|
) -> Result<(), Self::Error> {
|
||||||
info!(self.logger, "{}", format!("channel_close channel={:?} client={:?}", channel, self.client_addr));
|
info!(self.logger, "channel_close";
|
||||||
|
"channel" => ?channel,
|
||||||
|
"client" => ?self.client_addr
|
||||||
|
);
|
||||||
self.cleanup_channel(channel);
|
self.cleanup_channel(channel);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -643,6 +682,19 @@ impl russh::server::Handler for SSHandle {
|
|||||||
return Err(russh::Error::Disconnect);
|
return Err(russh::Error::Disconnect);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let user_id = operator.uid.to_string();
|
||||||
|
let repo_path = format!("{}/{}", owner, &repo.repo_name);
|
||||||
|
if !self
|
||||||
|
.rate_limiter
|
||||||
|
.is_repo_access_allowed(&user_id, &repo_path)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
let msg = format!("Rate limit exceeded for repository access: {}", repo_path);
|
||||||
|
warn!(self.logger, "{}", format!("Repo access rate limit exceeded user={} repo={}", operator.username, repo.repo_name));
|
||||||
|
session.disconnect(Disconnect::ByApplication, &msg, "").ok();
|
||||||
|
return Err(russh::Error::Disconnect);
|
||||||
|
}
|
||||||
|
|
||||||
info!(self.logger, "{}", format!("Access granted user={} repo={} is_write={}", operator.username, repo.repo_name, is_write));
|
info!(self.logger, "{}", format!("Access granted user={} repo={} is_write={}", operator.username, repo.repo_name, is_write));
|
||||||
|
|
||||||
let repo_path = PathBuf::from(&repo.storage_path);
|
let repo_path = PathBuf::from(&repo.storage_path);
|
||||||
|
|||||||
@ -19,6 +19,7 @@ use std::time::Duration;
|
|||||||
|
|
||||||
pub mod authz;
|
pub mod authz;
|
||||||
pub mod handle;
|
pub mod handle;
|
||||||
|
pub mod rate_limit;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -147,6 +148,7 @@ impl SSHHandle {
|
|||||||
|
|
||||||
// Start the rate limiter cleanup background task so the HashMap
|
// Start the rate limiter cleanup background task so the HashMap
|
||||||
// doesn't grow unbounded over time.
|
// doesn't grow unbounded over time.
|
||||||
|
let _cleanup = server.rate_limiter.clone().start_cleanup();
|
||||||
let ssh_port = self.app.ssh_port()?;
|
let ssh_port = self.app.ssh_port()?;
|
||||||
let bind_addr = format!("0.0.0.0:{}", ssh_port);
|
let bind_addr = format!("0.0.0.0:{}", ssh_port);
|
||||||
let public_host = self.app.ssh_domain()?;
|
let public_host = self.app.ssh_domain()?;
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
use crate::ssh::ReceiveSyncService;
|
use crate::ssh::ReceiveSyncService;
|
||||||
use crate::ssh::SshTokenService;
|
use crate::ssh::SshTokenService;
|
||||||
use crate::ssh::handle::SSHandle;
|
use crate::ssh::handle::SSHandle;
|
||||||
|
use crate::ssh::rate_limit::SshRateLimiter;
|
||||||
use db::cache::AppCache;
|
use db::cache::AppCache;
|
||||||
use db::database::AppDatabase;
|
use db::database::AppDatabase;
|
||||||
use deadpool_redis::cluster::Pool as RedisPool;
|
use deadpool_redis::cluster::Pool as RedisPool;
|
||||||
@ -8,12 +9,14 @@ use russh::server::Handler;
|
|||||||
use slog::{Logger, info, warn};
|
use slog::{Logger, info, warn};
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
pub struct SSHServer {
|
pub struct SSHServer {
|
||||||
pub db: AppDatabase,
|
pub db: AppDatabase,
|
||||||
pub cache: AppCache,
|
pub cache: AppCache,
|
||||||
pub redis_pool: RedisPool,
|
pub redis_pool: RedisPool,
|
||||||
pub logger: Logger,
|
pub logger: Logger,
|
||||||
|
pub rate_limiter: Arc<SshRateLimiter>,
|
||||||
pub token_service: SshTokenService,
|
pub token_service: SshTokenService,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -30,6 +33,7 @@ impl SSHServer {
|
|||||||
cache,
|
cache,
|
||||||
redis_pool,
|
redis_pool,
|
||||||
logger,
|
logger,
|
||||||
|
rate_limiter: Arc::new(SshRateLimiter::new()),
|
||||||
token_service,
|
token_service,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -39,12 +43,21 @@ impl russh::server::Server for SSHServer {
|
|||||||
|
|
||||||
fn new_client(&mut self, addr: Option<SocketAddr>) -> Self::Handler {
|
fn new_client(&mut self, addr: Option<SocketAddr>) -> Self::Handler {
|
||||||
if let Some(addr) = addr {
|
if let Some(addr) = addr {
|
||||||
|
let ip = addr.ip().to_string();
|
||||||
info!(
|
info!(
|
||||||
self.logger,
|
self.logger,
|
||||||
"New SSH connection from {}:{}",
|
"New SSH connection from {}:{}",
|
||||||
addr.ip(),
|
ip,
|
||||||
addr.port()
|
addr.port()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let rate_limiter = self.rate_limiter.clone();
|
||||||
|
let logger = self.logger.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if !rate_limiter.is_ip_allowed(&ip).await {
|
||||||
|
warn!(logger, "{}", format!("IP rate limit exceeded ip={}", ip));
|
||||||
|
}
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
info!(self.logger, "New SSH connection from unknown address");
|
info!(self.logger, "New SSH connection from unknown address");
|
||||||
}
|
}
|
||||||
@ -54,6 +67,7 @@ impl russh::server::Server for SSHServer {
|
|||||||
self.cache.clone(),
|
self.cache.clone(),
|
||||||
sync_service,
|
sync_service,
|
||||||
self.logger.clone(),
|
self.logger.clone(),
|
||||||
|
self.rate_limiter.clone(),
|
||||||
self.token_service.clone(),
|
self.token_service.clone(),
|
||||||
addr,
|
addr,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -6,15 +6,6 @@
|
|||||||
//! OpenRouter returns rich metadata per model including `context_length`,
|
//! OpenRouter returns rich metadata per model including `context_length`,
|
||||||
//! `pricing`, and `architecture.modality` — these are used to populate all
|
//! `pricing`, and `architecture.modality` — these are used to populate all
|
||||||
//! five model tables without any hard-coded heuristics.
|
//! five model tables without any hard-coded heuristics.
|
||||||
//!
|
|
||||||
//! Usage: call `start_sync_task()` to launch a background task that syncs
|
|
||||||
//! immediately and then every 10 minutes. On app startup, run it once
|
|
||||||
//! eagerly before accepting traffic.
|
|
||||||
|
|
||||||
use std::time::Duration;
|
|
||||||
use tokio::time::interval;
|
|
||||||
use tokio::task::JoinHandle;
|
|
||||||
use slog::Logger;
|
|
||||||
|
|
||||||
use crate::AppService;
|
use crate::AppService;
|
||||||
use crate::error::AppError;
|
use crate::error::AppError;
|
||||||
@ -158,32 +149,131 @@ fn infer_capability(name: &str) -> ModelCapability {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_context_length(ctx: Option<u64>) -> i64 {
|
fn infer_context_length(name: &str) -> i64 {
|
||||||
ctx.map(|c| c as i64).unwrap_or(8_192)
|
let lower = name.to_lowercase();
|
||||||
}
|
// Hard-coded fallback table for known models
|
||||||
|
let fallbacks: &[(&str, i64)] = &[
|
||||||
fn infer_max_output(top_provider_max: Option<u64>) -> Option<i64> {
|
("gpt-4o", 128_000),
|
||||||
top_provider_max.map(|v| v as i64)
|
("chatgpt-4o", 128_000),
|
||||||
}
|
("o1-preview", 128_000),
|
||||||
|
("o1-mini", 65_536),
|
||||||
fn infer_capability_list(arch: &OpenRouterArchitecture) -> Vec<(CapabilityType, bool)> {
|
("o1", 65_536),
|
||||||
// Derive capabilities purely from OpenRouter architecture data.
|
("o3-mini", 65_536),
|
||||||
// FunctionCall is a safe baseline for chat models.
|
("gpt-4-turbo", 128_000),
|
||||||
let mut caps = vec![(CapabilityType::FunctionCall, true)];
|
("gpt-4-32k", 32_768),
|
||||||
|
("gpt-4", 8_192),
|
||||||
// Vision capability from modality.
|
("gpt-4o-mini", 128_000),
|
||||||
if let Some(m) = &arch.modality {
|
("chatgpt-4o-mini", 128_000),
|
||||||
let m = m.to_lowercase();
|
("gpt-3.5-turbo-16k", 16_384),
|
||||||
if m.contains("image") || m.contains("vision") {
|
("gpt-3.5-turbo", 16_385),
|
||||||
caps.push((CapabilityType::Vision, true));
|
("text-embedding-3-large", 8_191),
|
||||||
}
|
("text-embedding-3-small", 8_191),
|
||||||
if m.contains("text") || m.contains("chat") {
|
("text-embedding-ada", 8_191),
|
||||||
caps.push((CapabilityType::ToolUse, true));
|
("dall-e", 4_096),
|
||||||
|
("whisper", 30_000),
|
||||||
|
("gpt-image-1", 16_384),
|
||||||
|
];
|
||||||
|
for (prefix, ctx) in fallbacks {
|
||||||
|
if lower.starts_with(prefix) {
|
||||||
|
return *ctx;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
8_192
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_max_output(name: &str, top_provider_max: Option<u64>) -> Option<i64> {
|
||||||
|
if let Some(v) = top_provider_max {
|
||||||
|
return Some(v as i64);
|
||||||
|
}
|
||||||
|
let lower = name.to_lowercase();
|
||||||
|
let fallbacks: &[(&str, i64)] = &[
|
||||||
|
("gpt-4o", 16_384),
|
||||||
|
("chatgpt-4o", 16_384),
|
||||||
|
("o1-preview", 32_768),
|
||||||
|
("o1-mini", 65_536),
|
||||||
|
("o1", 100_000),
|
||||||
|
("o3-mini", 100_000),
|
||||||
|
("gpt-4-turbo", 4_096),
|
||||||
|
("gpt-4-32k", 32_768),
|
||||||
|
("gpt-4", 8_192),
|
||||||
|
("gpt-4o-mini", 16_384),
|
||||||
|
("chatgpt-4o-mini", 16_384),
|
||||||
|
("gpt-3.5-turbo", 4_096),
|
||||||
|
("gpt-image-1", 1_024),
|
||||||
|
];
|
||||||
|
for (prefix, max) in fallbacks {
|
||||||
|
if lower.starts_with(prefix) {
|
||||||
|
return Some(*max);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lower.starts_with("gpt") || lower.starts_with("o1") || lower.starts_with("o3") {
|
||||||
|
Some(4_096)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_capability_list(name: &str) -> Vec<(CapabilityType, bool)> {
|
||||||
|
let lower = name.to_lowercase();
|
||||||
|
let mut caps = Vec::new();
|
||||||
|
caps.push((CapabilityType::FunctionCall, true));
|
||||||
|
|
||||||
|
if lower.contains("gpt-") || lower.contains("o1") || lower.contains("o3") {
|
||||||
|
caps.push((CapabilityType::ToolUse, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
if lower.contains("vision")
|
||||||
|
|| lower.contains("gpt-4o")
|
||||||
|
|| lower.contains("gpt-image")
|
||||||
|
|| lower.contains("dall-e")
|
||||||
|
{
|
||||||
|
caps.push((CapabilityType::Vision, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
if lower.contains("o1") || lower.contains("o3") {
|
||||||
|
caps.push((CapabilityType::Reasoning, true));
|
||||||
|
}
|
||||||
|
|
||||||
caps
|
caps
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn infer_pricing_fallback(name: &str) -> Option<(String, String)> {
|
||||||
|
let lower = name.to_lowercase();
|
||||||
|
if lower.contains("gpt-4o-mini") || lower.contains("chatgpt-4o-mini") {
|
||||||
|
Some(("0.075".to_string(), "0.30".to_string()))
|
||||||
|
} else if lower.contains("gpt-4o") || lower.contains("chatgpt-4o") {
|
||||||
|
Some(("2.50".to_string(), "10.00".to_string()))
|
||||||
|
} else if lower.contains("gpt-4-turbo") {
|
||||||
|
Some(("10.00".to_string(), "30.00".to_string()))
|
||||||
|
} else if lower.contains("gpt-4") && !lower.contains("4o") {
|
||||||
|
Some(("15.00".to_string(), "60.00".to_string()))
|
||||||
|
} else if lower.contains("gpt-3.5-turbo") {
|
||||||
|
Some(("0.50".to_string(), "1.50".to_string()))
|
||||||
|
} else if lower.contains("o1-preview") {
|
||||||
|
Some(("15.00".to_string(), "60.00".to_string()))
|
||||||
|
} else if lower.contains("o1-mini") {
|
||||||
|
Some(("3.00".to_string(), "12.00".to_string()))
|
||||||
|
} else if lower.contains("o1") {
|
||||||
|
Some(("15.00".to_string(), "60.00".to_string()))
|
||||||
|
} else if lower.contains("o3-mini") {
|
||||||
|
Some(("1.50".to_string(), "6.00".to_string()))
|
||||||
|
} else if lower.contains("embedding-3-small") {
|
||||||
|
Some(("0.02".to_string(), "0.00".to_string()))
|
||||||
|
} else if lower.contains("embedding-3-large") {
|
||||||
|
Some(("0.13".to_string(), "0.00".to_string()))
|
||||||
|
} else if lower.contains("embedding-ada") {
|
||||||
|
Some(("0.10".to_string(), "0.00".to_string()))
|
||||||
|
} else if lower.contains("embedding") {
|
||||||
|
Some(("0.10".to_string(), "0.00".to_string()))
|
||||||
|
} else if lower.contains("dall-e") {
|
||||||
|
Some(("0.00".to_string(), "4.00".to_string()))
|
||||||
|
} else if lower.contains("whisper") {
|
||||||
|
Some(("0.00".to_string(), "0.006".to_string()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Provider helpers -----------------------------------------------------------
|
// Provider helpers -----------------------------------------------------------
|
||||||
|
|
||||||
/// Extract provider slug from OpenRouter model ID (e.g. "anthropic/claude-3.5-sonnet" → "anthropic").
|
/// Extract provider slug from OpenRouter model ID (e.g. "anthropic/claude-3.5-sonnet" → "anthropic").
|
||||||
@ -274,10 +364,13 @@ async fn upsert_model(
|
|||||||
let capability = infer_capability(model_id_str);
|
let capability = infer_capability(model_id_str);
|
||||||
|
|
||||||
// OpenRouter context_length takes priority; fall back to inference
|
// OpenRouter context_length takes priority; fall back to inference
|
||||||
let context_length = infer_context_length(or_model.context_length);
|
let context_length = or_model
|
||||||
|
.context_length
|
||||||
|
.map(|c| c as i64)
|
||||||
|
.unwrap_or_else(|| infer_context_length(model_id_str));
|
||||||
|
|
||||||
let max_output =
|
let max_output =
|
||||||
infer_max_output(or_model.top_provider.as_ref().and_then(|p| p.max_completion_tokens));
|
infer_max_output(model_id_str, or_model.top_provider.as_ref().and_then(|p| p.max_completion_tokens));
|
||||||
|
|
||||||
use models::agents::model::Column as MCol;
|
use models::agents::model::Column as MCol;
|
||||||
if let Some(existing) = ModelEntity::find()
|
if let Some(existing) = ModelEntity::find()
|
||||||
@ -349,6 +442,7 @@ async fn upsert_pricing(
|
|||||||
db: &AppDatabase,
|
db: &AppDatabase,
|
||||||
version_uuid: Uuid,
|
version_uuid: Uuid,
|
||||||
pricing: Option<&OpenRouterPricing>,
|
pricing: Option<&OpenRouterPricing>,
|
||||||
|
model_name: &str,
|
||||||
) -> Result<bool, AppError> {
|
) -> Result<bool, AppError> {
|
||||||
use models::agents::model_pricing::Column as PCol;
|
use models::agents::model_pricing::Column as PCol;
|
||||||
let existing = PricingEntity::find()
|
let existing = PricingEntity::find()
|
||||||
@ -359,9 +453,11 @@ async fn upsert_pricing(
|
|||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenRouter prices are per-million-tokens strings; if missing, insert zero prices.
|
|
||||||
let (input_str, output_str) = if let Some(p) = pricing {
|
let (input_str, output_str) = if let Some(p) = pricing {
|
||||||
|
// OpenRouter prices are per-million-tokens strings
|
||||||
(p.prompt.clone(), p.completion.clone())
|
(p.prompt.clone(), p.completion.clone())
|
||||||
|
} else if let Some((i, o)) = infer_pricing_fallback(model_name) {
|
||||||
|
(i, o)
|
||||||
} else {
|
} else {
|
||||||
("0.00".to_string(), "0.00".to_string())
|
("0.00".to_string(), "0.00".to_string())
|
||||||
};
|
};
|
||||||
@ -382,16 +478,10 @@ async fn upsert_pricing(
|
|||||||
async fn upsert_capabilities(
|
async fn upsert_capabilities(
|
||||||
db: &AppDatabase,
|
db: &AppDatabase,
|
||||||
version_uuid: Uuid,
|
version_uuid: Uuid,
|
||||||
arch: Option<&OpenRouterArchitecture>,
|
model_name: &str,
|
||||||
) -> Result<i64, AppError> {
|
) -> Result<i64, AppError> {
|
||||||
use models::agents::model_capability::Column as CCol;
|
use models::agents::model_capability::Column as CCol;
|
||||||
let caps = infer_capability_list(arch.unwrap_or(&OpenRouterArchitecture {
|
let caps = infer_capability_list(model_name);
|
||||||
modality: None,
|
|
||||||
input_modalities: None,
|
|
||||||
output_modalities: None,
|
|
||||||
tokenizer: None,
|
|
||||||
instruct_type: None,
|
|
||||||
}));
|
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
let mut created = 0i64;
|
let mut created = 0i64;
|
||||||
|
|
||||||
@ -456,19 +546,31 @@ async fn upsert_parameter_profile(
|
|||||||
impl AppService {
|
impl AppService {
|
||||||
/// Sync models from OpenRouter into the local database.
|
/// Sync models from OpenRouter into the local database.
|
||||||
///
|
///
|
||||||
/// Calls OpenRouter's public `GET /api/v1/models` endpoint (no auth required),
|
/// Calls OpenRouter's `GET /api/v1/models` using `OPENROUTER_API_KEY`
|
||||||
/// then upserts provider / model / version / pricing / capability /
|
/// (falls back to `AI_API_KEY` if not set), then upserts provider /
|
||||||
/// parameter-profile records.
|
/// model / version / pricing / capability / parameter-profile records.
|
||||||
///
|
///
|
||||||
/// OpenRouter returns `context_length`, `pricing`, and `architecture.modality`
|
/// OpenRouter returns `context_length`, `pricing`, and `architecture.modality`
|
||||||
/// per model — these drive all field population. No model names are hardcoded.
|
/// per model — these drive all inference-free field population.
|
||||||
|
/// Capabilities are still inferred from model name patterns.
|
||||||
pub async fn sync_upstream_models(
|
pub async fn sync_upstream_models(
|
||||||
&self,
|
&self,
|
||||||
_ctx: &Session,
|
_ctx: &Session,
|
||||||
) -> Result<SyncModelsResponse, AppError> {
|
) -> Result<SyncModelsResponse, AppError> {
|
||||||
|
// Resolve API key: prefer OPENROUTER_API_KEY env var, fall back to AI_API_KEY.
|
||||||
|
let api_key = std::env::var("OPENROUTER_API_KEY")
|
||||||
|
.ok()
|
||||||
|
.or_else(|| self.config.ai_api_key().ok())
|
||||||
|
.ok_or_else(|| {
|
||||||
|
AppError::InternalServerError(
|
||||||
|
"OPENROUTER_API_KEY or AI_API_KEY must be configured to sync models".into(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let resp: OpenRouterResponse = client
|
let resp: OpenRouterResponse = client
|
||||||
.get("https://openrouter.ai/api/v1/models")
|
.get("https://openrouter.ai/api/v1/models")
|
||||||
|
.header("Authorization", format!("Bearer {api_key}"))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| AppError::InternalServerError(format!("OpenRouter API request failed: {}", e)))?
|
.map_err(|e| AppError::InternalServerError(format!("OpenRouter API request failed: {}", e)))?
|
||||||
@ -504,29 +606,26 @@ impl AppService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let (version_record, version_is_new) =
|
let (version_record, version_is_new) =
|
||||||
match upsert_version(&self.db, model_record.id).await {
|
upsert_version(&self.db, model_record.id).await?;
|
||||||
Ok(v) => v,
|
|
||||||
Err(e) => {
|
|
||||||
slog::warn!(self.logs, "{}", format!("sync_upstream_models: upsert_version error: {:?}", e));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if version_is_new {
|
if version_is_new {
|
||||||
versions_created += 1;
|
versions_created += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(e) = upsert_pricing(&self.db, version_record.id, or_model.pricing.as_ref()).await {
|
if upsert_pricing(
|
||||||
slog::warn!(self.logs, "{}", format!("sync_upstream_models: upsert_pricing error: {:?}", e));
|
&self.db,
|
||||||
} else {
|
version_record.id,
|
||||||
|
or_model.pricing.as_ref(),
|
||||||
|
&or_model.id,
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
{
|
||||||
pricing_created += 1;
|
pricing_created += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
capabilities_created +=
|
capabilities_created +=
|
||||||
upsert_capabilities(&self.db, version_record.id, or_model.architecture.as_ref())
|
upsert_capabilities(&self.db, version_record.id, &or_model.id).await?;
|
||||||
.await
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
if upsert_parameter_profile(&self.db, version_record.id, &or_model.id).await.unwrap_or(false) {
|
if upsert_parameter_profile(&self.db, version_record.id, &or_model.id).await? {
|
||||||
profiles_created += 1;
|
profiles_created += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -540,133 +639,4 @@ impl AppService {
|
|||||||
profiles_created,
|
profiles_created,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Spawn a background task that syncs OpenRouter models immediately
|
|
||||||
/// and then every 10 minutes. Returns the `JoinHandle`.
|
|
||||||
///
|
|
||||||
/// Failures are logged but do not stop the task — it keeps retrying.
|
|
||||||
pub fn start_sync_task(self) -> JoinHandle<()> {
|
|
||||||
let db = self.db.clone();
|
|
||||||
let log = self.logs.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
// Run once immediately on startup before taking traffic.
|
|
||||||
Self::sync_once(&db, &log).await;
|
|
||||||
|
|
||||||
let mut tick = interval(Duration::from_secs(60 * 10));
|
|
||||||
loop {
|
|
||||||
tick.tick().await;
|
|
||||||
Self::sync_once(&db, &log).await;
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Perform a single sync pass. Errors are logged and silently swallowed
|
|
||||||
/// so the periodic task never stops.
|
|
||||||
async fn sync_once(db: &AppDatabase, log: &Logger) {
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let resp = match client
|
|
||||||
.get("https://openrouter.ai/api/v1/models")
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(r) => match r.error_for_status() {
|
|
||||||
Ok(resp) => match resp.json::<OpenRouterResponse>().await {
|
|
||||||
Ok(resp) => resp,
|
|
||||||
Err(e) => {
|
|
||||||
slog::error!(log, "{}", format!("OpenRouter model sync: failed to parse response: {}", e));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Err(e) => {
|
|
||||||
slog::error!(log, "{}", format!("OpenRouter model sync: API error: {}", e));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Err(e) => {
|
|
||||||
slog::error!(log, "{}", format!("OpenRouter model sync: request failed: {}", e));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut models_created = 0i64;
|
|
||||||
let mut models_updated = 0i64;
|
|
||||||
let mut versions_created = 0i64;
|
|
||||||
let mut pricing_created = 0i64;
|
|
||||||
let mut capabilities_created = 0i64;
|
|
||||||
let mut profiles_created = 0i64;
|
|
||||||
|
|
||||||
for or_model in resp.data {
|
|
||||||
if or_model.id == "openrouter/auto" {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let provider_slug = extract_provider(&or_model.id);
|
|
||||||
let provider = match upsert_provider(db, provider_slug).await {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(e) => {
|
|
||||||
slog::warn!(log, "{}", format!("OpenRouter model sync: upsert_provider error: {:?}", e));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let model_record = match upsert_model(db, provider.id, &or_model.id, &or_model).await {
|
|
||||||
Ok((m, true)) => {
|
|
||||||
models_created += 1;
|
|
||||||
m
|
|
||||||
}
|
|
||||||
Ok((m, false)) => {
|
|
||||||
models_updated += 1;
|
|
||||||
m
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
slog::warn!(log, "{}", format!("OpenRouter model sync: upsert_model error: {:?}", e));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let (version_record, version_is_new) = match upsert_version(db, model_record.id).await {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(e) => {
|
|
||||||
slog::warn!(log, "{}", format!("OpenRouter model sync: upsert_version error: {:?}", e));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if version_is_new {
|
|
||||||
versions_created += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
if upsert_pricing(db, version_record.id, or_model.pricing.as_ref())
|
|
||||||
.await
|
|
||||||
.unwrap_or(false)
|
|
||||||
{
|
|
||||||
pricing_created += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
capabilities_created +=
|
|
||||||
upsert_capabilities(db, version_record.id, or_model.architecture.as_ref())
|
|
||||||
.await
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
if upsert_parameter_profile(db, version_record.id, &or_model.id)
|
|
||||||
.await
|
|
||||||
.unwrap_or(false)
|
|
||||||
{
|
|
||||||
profiles_created += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
slog::info!(log, "{}",
|
|
||||||
format!(
|
|
||||||
"OpenRouter model sync complete: created={} updated={} \
|
|
||||||
versions={} pricing={} capabilities={} profiles={}",
|
|
||||||
models_created,
|
|
||||||
models_updated,
|
|
||||||
versions_created,
|
|
||||||
pricing_created,
|
|
||||||
capabilities_created,
|
|
||||||
profiles_created
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user