use std::{io, net::SocketAddr, sync::Arc}; use cache::AppCache; use db::database::AppDatabase; use deadpool_redis::cluster::Pool as RedisPool; use russh::server::Handler; use crate::{ ssh::{SshTokenService, handler::SSHandle, rate_limit::SshRateLimiter}, sync::ReceiveSyncService, }; pub struct SSHServer { pub db: AppDatabase, pub cache: AppCache, pub redis_pool: RedisPool, pub token_service: SshTokenService, pub rate_limiter: Arc, } impl SSHServer { pub fn new( db: AppDatabase, cache: AppCache, redis_pool: RedisPool, token_service: SshTokenService, ) -> Self { SSHServer { db, cache, redis_pool, token_service, rate_limiter: Arc::new(SshRateLimiter::new()), } } } impl russh::server::Server for SSHServer { type Handler = SSHandle; #[tracing::instrument(skip(self), fields(peer = ?addr))] fn new_client(&mut self, addr: Option) -> Self::Handler { if let Some(addr) = addr { let ip = addr.ip().to_string(); tracing::info!("New SSH connection ip={} port={}", ip, addr.port()); let limiter = self.rate_limiter.clone(); let ip_clone = ip.clone(); tokio::spawn(async move { if !limiter.is_ip_allowed(&ip_clone).await { tracing::warn!(ip = %ip_clone, "SSH connection rate limited"); } }); } else { tracing::info!("New SSH connection from unknown address"); } let sync_service = ReceiveSyncService::new(self.redis_pool.clone()); SSHandle::new( self.db.clone(), self.cache.clone(), sync_service, self.token_service.clone(), addr, ) } fn handle_session_error( &mut self, error: ::Error, ) { match error { russh::Error::Disconnect => { tracing::info!("Connection disconnected by peer"); } russh::Error::Inconsistent => { tracing::warn!("Protocol inconsistency detected"); } russh::Error::NotAuthenticated => { tracing::warn!("Authentication failed"); } russh::Error::IO(ref io_err) => { tracing::warn!( "SSH IO error kind={:?} message={} raw_os_error={:?}", io_err.kind(), io_err, io_err.raw_os_error() ); if io_err.kind() == io::ErrorKind::UnexpectedEof { tracing::warn!( "SSH peer closed the connection before a clean disconnect was received" ); } } _ => { tracing::warn!("SSH session error error={}", error); } } } }