refactor(git/ssh): extract helper functions into dedicated modules

Move RefUpdate, GitService, branch_protection check, and forward
function from handle.rs into separate modules.
This commit is contained in:
ZhenYi 2026-05-11 17:05:30 +08:00
parent deb25614ba
commit 3b17a0493f
6 changed files with 254 additions and 246 deletions

View File

@ -0,0 +1,67 @@
use crate::ssh::ref_update::RefUpdate;
use models::repos::repo_branch_protect;
/// Ref name matches a protection rule exactly, or as a directory prefix
/// (e.g. "refs/heads/main" matches "refs/heads/main" and "refs/heads/main/*"
/// but NOT "refs/heads/main-v2").
fn ref_matches_protection(ref_name: &str, protection_branch: &str) -> bool {
ref_name == protection_branch
|| ref_name.starts_with(&format!("{}/", protection_branch))
}
/// Granular branch protection check (same logic as HTTP handler).
/// Returns `Some(error_message)` if the push should be rejected.
pub fn check_branch_protection(
branch_protects: &[repo_branch_protect::Model],
r#ref: &RefUpdate,
) -> Option<String> {
for protection in branch_protects {
if !ref_matches_protection(&r#ref.name, &protection.branch) {
continue;
}
// Check deletion (new_oid is all zeros)
if r#ref.new_oid == "0000000000000000000000000000000000000000" {
if protection.forbid_deletion {
return Some(format!(
"Deletion of protected branch '{}' is forbidden",
r#ref.name
));
}
continue;
}
// Check tag push
if r#ref.name.starts_with("refs/tags/") {
if protection.forbid_tag_push {
return Some(format!(
"Tag push to protected branch '{}' is forbidden",
r#ref.name
));
}
continue;
}
// Check force push: old != new AND old is non-zero (non-fast-forward)
let is_new_branch = r#ref.old_oid == "0000000000000000000000000000000000000000";
if !is_new_branch
&& r#ref.old_oid != r#ref.new_oid
&& r#ref.name.starts_with("refs/heads/")
&& protection.forbid_force_push
{
return Some(format!(
"Force push to protected branch '{}' is forbidden",
r#ref.name
));
}
// Check push
if protection.forbid_push {
return Some(format!(
"Push to protected branch '{}' is forbidden",
r#ref.name
));
}
}
None
}

50
libs/git/ssh/forward.rs Normal file
View File

@ -0,0 +1,50 @@
use russh::server::Handle;
use russh::ChannelId;
use std::future::Future;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio::time::sleep;
use tokio_util::bytes::Bytes;
pub async fn forward<'a, R, Fut, Fwd>(
session_handle: &'a Handle,
chan_id: ChannelId,
r: &mut R,
mut fwd: Fwd,
) -> Result<(), russh::Error>
where
R: AsyncRead + Send + Unpin,
Fut: Future<Output = Result<(), Bytes>> + 'a,
Fwd: FnMut(&'a Handle, ChannelId, Bytes) -> Fut,
{
const BUF_SIZE: usize = 1024 * 32;
const MAX_RETRIES: usize = 5;
const RETRY_DELAY: u64 = 10; // ms
let mut buf = [0u8; BUF_SIZE];
loop {
let read = r.read(&mut buf).await?;
if read == 0 {
break;
}
let mut chunk = Bytes::copy_from_slice(&buf[..read]);
let mut retries = 0;
loop {
match fwd(session_handle, chan_id, chunk).await {
Ok(()) => break,
Err(unsent) => {
retries += 1;
if retries >= MAX_RETRIES {
return Ok(());
}
chunk = unsent;
sleep(Duration::from_millis(RETRY_DELAY)).await;
}
}
}
}
Ok(())
}

View File

@ -0,0 +1,82 @@
use std::path::PathBuf;
use std::str::FromStr;
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum GitService {
UploadPack,
ReceivePack,
UploadArchive,
}
impl FromStr for GitService {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"upload-pack" => Ok(Self::UploadPack),
"receive-pack" => Ok(Self::ReceivePack),
"upload-archive" => Ok(Self::UploadArchive),
_ => Err(()),
}
}
}
pub fn parse_git_command(cmd: &str) -> Option<(GitService, &str)> {
let (svc, path) = match cmd.split_once(' ') {
Some(("git-receive-pack", path)) => (GitService::ReceivePack, path),
Some(("git-upload-pack", path)) => (GitService::UploadPack, path),
Some(("git-upload-archive", path)) => (GitService::UploadArchive, path),
_ => return None,
};
Some((svc, strip_apostrophes(path)))
}
pub fn parse_repo_path(path: &str) -> Option<(&str, &str)> {
let path = path.trim_matches('/');
let mut parts = path.splitn(2, '/');
match (parts.next(), parts.next()) {
(Some(owner), Some(repo)) if !owner.is_empty() && !repo.is_empty() => Some((owner, repo)),
_ => None,
}
}
pub fn build_git_command(service: GitService, path: PathBuf) -> tokio::process::Command {
let mut cmd = tokio::process::Command::new("git");
let cwd = match path.canonicalize() {
Ok(p) => p,
Err(e) => {
tracing::debug!(error = %e, "path canonicalize failed, falling back to raw path");
path.clone()
}
};
cmd.current_dir(cwd);
match service {
GitService::UploadPack => { cmd.arg("upload-pack"); }
GitService::ReceivePack => { cmd.arg("receive-pack"); }
GitService::UploadArchive => { cmd.arg("upload-archive"); }
}
cmd.arg(".")
.env("GIT_CONFIG_NOSYSTEM", "1")
.env("GIT_NO_REPLACE_OBJECTS", "1");
#[cfg(unix)]
{
cmd.env("GIT_CONFIG_GLOBAL", "/dev/null")
.env("GIT_CONFIG_SYSTEM", "/dev/null");
}
#[cfg(windows)]
{
let nul = "NUL";
cmd.env("GIT_CONFIG_GLOBAL", nul)
.env("GIT_CONFIG_SYSTEM", nul);
}
cmd
}
fn strip_apostrophes(s: &str) -> &str {
s.trim_matches('\'')
}

View File

@ -7,8 +7,13 @@ use db::database::AppDatabase;
use models::repos::{repo, repo_branch_protect};
use models::users::user;
use russh::keys::{Certificate, PublicKey};
use russh::server::{Auth, Handle, Msg, Session};
use russh::{Channel, ChannelId, CryptoVec, Disconnect};
use russh::server::{Auth, Msg, Session};
use russh::{Channel, ChannelId, Disconnect};
use crate::ssh::ref_update::RefUpdate;
use crate::ssh::git_service::{GitService, parse_git_command, parse_repo_path, build_git_command};
use crate::ssh::branch_protect::check_branch_protection;
use crate::ssh::forward::forward;
use tokio_util::bytes::Bytes;
use sea_orm::ColumnTrait;
use sea_orm::EntityTrait;
use sea_orm::QueryFilter;
@ -17,53 +22,13 @@ use std::io;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::process::Stdio;
use std::str::FromStr;
use std::time::Duration;
const PRE_PACK_LIMIT: usize = 1_048_576;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tokio::io::AsyncWriteExt;
use tokio::process::ChildStdin;
use tokio::sync::mpsc::Sender;
use tokio::time::sleep;
#[derive(Clone, Debug)]
pub struct RefUpdate {
pub name: String,
pub old_oid: String,
pub new_oid: String,
}
impl RefUpdate {
/// Parse git reference update commands from SSH protocol text.
/// Format: "<old-oid> <new-oid> <ref-name>\n"
pub fn parse_ref_updates(data: &[u8]) -> Result<Vec<Self>, String> {
let text = String::from_utf8_lossy(data);
let mut refs = Vec::new();
for line in text.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') || line.starts_with("PACK") {
continue;
}
let mut parts = line.split_whitespace();
let old_oid = parts.next().map(|s| s.to_string()).unwrap_or_default();
let new_oid = parts.next().map(|s| s.to_string()).unwrap_or_default();
let name = parts
.next()
.unwrap_or("")
.trim_start_matches('\0')
.to_string();
if !name.is_empty() {
refs.push(RefUpdate {
old_oid,
new_oid,
name,
});
}
}
Ok(refs)
}
}
pub struct SSHandle {
pub repo: Option<PathBuf>,
pub model: Option<repo::Model>,
@ -379,7 +344,7 @@ impl russh::server::Handler for SSHandle {
let _ = session.extended_data(
channel,
1,
CryptoVec::from_slice(msg.as_bytes()),
Bytes::copy_from_slice(msg.as_bytes()),
);
let _ = session.exit_status_request(channel, 1);
let _ = session.eof(channel);
@ -416,7 +381,7 @@ impl russh::server::Handler for SSHandle {
let _ = session.extended_data(
channel,
1,
CryptoVec::from_slice(full_msg.as_bytes()),
Bytes::copy_from_slice(full_msg.as_bytes()),
);
let _ = session.exit_status_request(channel, 1);
let _ = session.eof(channel);
@ -480,7 +445,7 @@ impl russh::server::Handler for SSHandle {
tracing::info!("shell_request user={}", user.username);
let _ = session
.data(channel_id, CryptoVec::from_slice(welcome_msg.as_bytes()));
.data(channel_id, Bytes::copy_from_slice(welcome_msg.as_bytes()));
let _ = session.exit_status_request(channel_id, 0);
let _ = session.eof(channel_id);
let _ = session.close(channel_id);
@ -489,7 +454,7 @@ impl russh::server::Handler for SSHandle {
tracing::warn!("shell_request_unauthenticated channel={:?}", channel_id);
let msg = "Authentication required\r\n";
let _ = session
.data(channel_id, CryptoVec::from_slice(msg.as_bytes()));
.data(channel_id, Bytes::copy_from_slice(msg.as_bytes()));
let _ = session.exit_status_request(channel_id, 1);
let _ = session.eof(channel_id);
let _ = session.close(channel_id);
@ -733,200 +698,3 @@ impl russh::server::Handler for SSHandle {
Ok(())
}
}
fn parse_git_command(cmd: &str) -> Option<(GitService, &str)> {
let (svc, path) = match cmd.split_once(' ') {
Some(("git-receive-pack", path)) => (GitService::ReceivePack, path),
Some(("git-upload-pack", path)) => (GitService::UploadPack, path),
Some(("git-upload-archive", path)) => (GitService::UploadArchive, path),
_ => return None,
};
Some((svc, strip_apostrophes(path)))
}
fn parse_repo_path(path: &str) -> Option<(&str, &str)> {
let path = path.trim_matches('/');
let mut parts = path.splitn(2, '/');
match (parts.next(), parts.next()) {
(Some(owner), Some(repo)) if !owner.is_empty() && !repo.is_empty() => Some((owner, repo)),
_ => None,
}
}
fn build_git_command(service: GitService, path: PathBuf) -> tokio::process::Command {
let mut cmd = tokio::process::Command::new("git");
// Canonicalize only for validation; if it fails, fall back to the raw path.
// Using canonicalize for current_dir is safe since we validate repo existence
// before reaching this point.
let cwd = match path.canonicalize() {
Ok(p) => p,
Err(e) => {
tracing::debug!(error = %e, "path canonicalize failed, falling back to raw path");
path.clone()
}
};
cmd.current_dir(cwd);
match service {
GitService::UploadPack => { cmd.arg("upload-pack"); }
GitService::ReceivePack => { cmd.arg("receive-pack"); }
GitService::UploadArchive => { cmd.arg("upload-archive"); }
}
cmd.arg(".")
.env("GIT_CONFIG_NOSYSTEM", "1")
.env("GIT_NO_REPLACE_OBJECTS", "1");
#[cfg(unix)]
{
cmd.env("GIT_CONFIG_GLOBAL", "/dev/null")
.env("GIT_CONFIG_SYSTEM", "/dev/null");
}
#[cfg(windows)]
{
// On Windows, /dev/null doesn't exist. Set invalid paths so git
// ignores them without crashing. GIT_CONFIG_NOSYSTEM already disables
// the system config.
let nul = "NUL";
cmd.env("GIT_CONFIG_GLOBAL", nul)
.env("GIT_CONFIG_SYSTEM", nul);
}
cmd
}
fn strip_apostrophes(s: &str) -> &str {
s.trim_matches('\'')
}
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum GitService {
UploadPack,
ReceivePack,
UploadArchive,
}
impl FromStr for GitService {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"upload-pack" => Ok(Self::UploadPack),
"receive-pack" => Ok(Self::ReceivePack),
"upload-archive" => Ok(Self::UploadArchive),
_ => Err(()),
}
}
}
/// Ref name matches a protection rule exactly, or as a directory prefix
/// (e.g. "refs/heads/main" matches "refs/heads/main" and "refs/heads/main/*"
/// but NOT "refs/heads/main-v2").
fn ref_matches_protection(ref_name: &str, protection_branch: &str) -> bool {
ref_name == protection_branch
|| ref_name.starts_with(&format!("{}/", protection_branch))
}
/// Granular branch protection check (same logic as HTTP handler).
/// Returns `Some(error_message)` if the push should be rejected.
fn check_branch_protection(
branch_protects: &[repo_branch_protect::Model],
r#ref: &RefUpdate,
) -> Option<String> {
for protection in branch_protects {
if !ref_matches_protection(&r#ref.name, &protection.branch) {
continue;
}
// Check deletion (new_oid is all zeros)
if r#ref.new_oid == "0000000000000000000000000000000000000000" {
if protection.forbid_deletion {
return Some(format!(
"Deletion of protected branch '{}' is forbidden",
r#ref.name
));
}
continue;
}
// Check tag push
if r#ref.name.starts_with("refs/tags/") {
if protection.forbid_tag_push {
return Some(format!(
"Tag push to protected branch '{}' is forbidden",
r#ref.name
));
}
continue;
}
// Check force push: old != new AND old is non-zero (non-fast-forward)
let is_new_branch = r#ref.old_oid == "0000000000000000000000000000000000000000";
if !is_new_branch
&& r#ref.old_oid != r#ref.new_oid
&& r#ref.name.starts_with("refs/heads/")
&& protection.forbid_force_push
{
return Some(format!(
"Force push to protected branch '{}' is forbidden",
r#ref.name
));
}
// Check push
if protection.forbid_push {
return Some(format!(
"Push to protected branch '{}' is forbidden",
r#ref.name
));
}
}
None
}
async fn forward<'a, R, Fut, Fwd>(
session_handle: &'a Handle,
chan_id: ChannelId,
r: &mut R,
mut fwd: Fwd,
) -> Result<(), russh::Error>
where
R: AsyncRead + Send + Unpin,
Fut: Future<Output = Result<(), CryptoVec>> + 'a,
Fwd: FnMut(&'a Handle, ChannelId, CryptoVec) -> Fut,
{
const BUF_SIZE: usize = 1024 * 32;
const MAX_RETRIES: usize = 5;
const RETRY_DELAY: u64 = 10; // ms
let mut buf = [0u8; BUF_SIZE];
loop {
let read = r.read(&mut buf).await?;
if read == 0 {
break;
}
let mut chunk = CryptoVec::from_slice(&buf[..read]);
let mut retries = 0;
loop {
match fwd(session_handle, chan_id, chunk).await {
Ok(()) => break,
Err(unsent) => {
retries += 1;
if retries >= MAX_RETRIES {
// Give up — connection is likely broken. Returning Ok (not Err)
// so the outer task can clean up gracefully without logging
// a spurious error for a normal disconnection.
return Ok(());
}
chunk = unsent;
sleep(Duration::from_millis(RETRY_DELAY)).await;
}
}
}
}
Ok(())
}

View File

@ -16,9 +16,13 @@ use std::sync::Arc;
use std::time::Duration;
pub mod authz;
pub mod branch_protect;
pub mod forward;
pub mod git_service;
pub mod handle;
pub mod server;
pub mod rate_limit;
pub mod ref_update;
pub mod server;
#[derive(Clone)]
pub struct SSHHandle {
pub db: AppDatabase,
@ -106,7 +110,7 @@ impl SSHHandle {
let mut config = Config::default();
config.keys = vec![private_key];
let version = format!("SSH-2.0-GitdataAI {}", env!("CARGO_PKG_VERSION"));
config.server_id = SshId::Standard(version);
config.server_id = SshId::Standard(version.into());
let mut method = MethodSet::empty();
method.push(MethodKind::PublicKey);
method.push(MethodKind::KeyboardInteractive);

View File

@ -0,0 +1,37 @@
#[derive(Clone, Debug)]
pub struct RefUpdate {
pub name: String,
pub old_oid: String,
pub new_oid: String,
}
impl RefUpdate {
/// Parse git reference update commands from SSH protocol text.
/// Format: "<old-oid> <new-oid> <ref-name>\n"
pub fn parse_ref_updates(data: &[u8]) -> Result<Vec<Self>, String> {
let text = String::from_utf8_lossy(data);
let mut refs = Vec::new();
for line in text.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') || line.starts_with("PACK") {
continue;
}
let mut parts = line.split_whitespace();
let old_oid = parts.next().map(|s| s.to_string()).unwrap_or_default();
let new_oid = parts.next().map(|s| s.to_string()).unwrap_or_default();
let name = parts
.next()
.unwrap_or("")
.trim_start_matches('\0')
.to_string();
if !name.is_empty() {
refs.push(RefUpdate {
old_oid,
new_oid,
name,
});
}
}
Ok(refs)
}
}