gitdataai/lib/service/pull_request/pull_request.rs
2026-05-30 01:38:40 +08:00

410 lines
13 KiB
Rust

use db::{sqlx, sqlx::AssertSqlSafe};
use model::pull_request::PullRequestModel;
use serde::Deserialize;
use session::Session;
use crate::{
AppService, Pagination,
error::AppError,
issues::types::issue_author,
pull_request::types::{PullRequestFilter, PullRequestResponse},
session_user,
};
// Request body for opening a pull request; consumed by `AppService::pr_create`.
#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)]
pub struct CreatePullRequest {
    // Required; surrounding whitespace is trimmed and a blank title is rejected.
    pub title: String,
    // Optional description, stored as given.
    pub body: Option<String>,
    // Branch holding the proposed changes, looked up in `source_repo`.
    pub source_branch: String,
    // Branch to merge into; defaults to the target repository's default branch.
    pub target_branch: Option<String>,
    // Name of the repository holding `source_branch`, resolved in the same workspace;
    // defaults to the target repository itself.
    pub source_repo: Option<String>,
    // Whether to open the pull request as a draft; defaults to `false`.
    pub draft: Option<bool>,
}
#[derive(Debug, Clone, Deserialize, utoipa::ToSchema)]
pub struct UpdatePullRequest {
pub title: Option<String>,
pub body: Option<Option<String>>,
pub draft: Option<bool>,
pub state: Option<String>,
}
impl AppService {
    /// Opens a new pull request in `wk_name/repo_name` on behalf of the session user.
    ///
    /// The source branch may live in another repository of the same workspace
    /// (`params.source_repo`); the target branch defaults to the repository's default
    /// branch. The current head commits of both branches are recorded as `source_sha`
    /// and `target_sha`, and the pull request is inserted in the `open` state.
    ///
    /// # Errors
    /// - `BadRequest` if the trimmed title is empty.
    /// - `Conflict` if both branch heads point at the same commit.
    /// - `DatabaseError` if the insert fails.
    /// - Any error returned while resolving the session, repositories or branches.
    pub async fn pr_create(
        &self,
        ctx: &Session,
        wk_name: &str,
        repo_name: &str,
        params: CreatePullRequest,
    ) -> Result<PullRequestResponse, AppError> {
        let user_uid = session_user(ctx)?;
        let (repo_id, repo) =
            self.pr_resolve_repo(ctx, wk_name, repo_name).await?;
        let title = params.title.trim();
        if title.is_empty() {
            return Err(AppError::BadRequest(
                "pull request title is required".to_string(),
            ));
        }
        let target_branch = params
            .target_branch
            .unwrap_or_else(|| repo.default_branch.clone());
        let source_repo_name =
            params.source_repo.unwrap_or_else(|| repo_name.to_string());
        let source_repo_id = if source_repo_name == repo_name {
            repo_id
        } else {
            // Cross-repository pull request: the source repo is looked up in the same
            // workspace as the target.
            let wk = self.workspace_resolve(wk_name).await?;
            let source_repo =
                self.repo_resolve(wk.id, &source_repo_name).await?;
            source_repo.id
        };
        let source_sha = self
            .branch_head_sha(source_repo_id, &params.source_branch)
            .await?;
        let target_sha = self.branch_head_sha(repo_id, &target_branch).await?;
        if source_sha == target_sha {
            return Err(AppError::Conflict(
                "source and target branches are the same".to_string(),
            ));
        }
        let now = chrono::Utc::now();
        let id = uuid::Uuid::now_v7();
        // The number is MAX(number) + 1 over *all* rows of the repository, including
        // soft-deleted ones, so a deleted pull request's number is never handed out again.
        // NOTE(review): two concurrent inserts can still compute the same number; this
        // presumably relies on a unique (repo, number) constraint — confirm in the schema.
        let pr = sqlx::query_as::<_, PullRequestModel>(
            "INSERT INTO pull_request (id, repo, number, title, body, state, draft, author, \
            source_repo, source_branch, source_sha, target_branch, target_sha, \
            merged_by, merged_at, closed_by, closed_at, created_at, updated_at) \
            VALUES ($1, $2, (SELECT COALESCE(MAX(number), 0) + 1 FROM pull_request WHERE repo = $2), \
            $3, $4, 'open', $5, $6, $7, $8, $9, $10, $11, NULL, NULL, NULL, NULL, $12, $12) \
            RETURNING id, repo, number, title, body, state, draft, author, \
            source_repo, source_branch, source_sha, target_branch, target_sha, \
            merged_by, merged_at, closed_by, closed_at, created_at, updated_at, deleted_at",
        )
        .bind(id)
        .bind(repo_id)
        .bind(title)
        .bind(&params.body)
        .bind(params.draft.unwrap_or(false))
        .bind(user_uid)
        .bind(source_repo_id)
        .bind(&params.source_branch)
        .bind(&source_sha)
        .bind(&target_branch)
        .bind(&target_sha)
        .bind(now)
        .fetch_one(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        self.pr_build_response(pr).await
    }

    /// Lists pull requests of `wk_name/repo_name`, newest first.
    ///
    /// Requires the session user to be a workspace member. Optional filters, combined
    /// with `AND`: `state` (exact match), `author` and `assignee` (by username) and
    /// `label` (by label name). Soft-deleted pull requests are excluded. Each row is
    /// expanded with `pr_build_response`, which issues several queries per pull request.
    pub async fn pr_list(
        &self,
        ctx: &Session,
        wk_name: &str,
        repo_name: &str,
        filter: PullRequestFilter,
        pagination: Pagination,
    ) -> Result<Vec<PullRequestResponse>, AppError> {
        let user_uid = session_user(ctx)?;
        let wk = self.workspace_resolve(wk_name).await?;
        self.workspace_require_member(wk.id, user_uid).await?;
        let repo = self.repo_resolve(wk.id, repo_name).await?;
        // Placeholders are numbered in exactly the order the values are bound below;
        // both sequences must be kept in sync. `$1` is always the repository id.
        let mut conditions = vec![
            "pr.repo = $1".to_string(),
            "pr.deleted_at IS NULL".to_string(),
        ];
        let mut param_idx = 2;
        if filter.state.is_some() {
            conditions.push(format!("pr.state = ${param_idx}"));
            param_idx += 1;
        }
        if filter.author.is_some() {
            conditions.push(format!(
                "EXISTS(SELECT 1 FROM \"user\" u WHERE u.id = pr.author AND u.username = ${param_idx})"
            ));
            param_idx += 1;
        }
        if filter.assignee.is_some() {
            conditions.push(format!(
                "EXISTS(SELECT 1 FROM pull_request_assignee pa INNER JOIN \"user\" u ON u.id = pa.\"user\" \
                WHERE pa.pull_request = pr.id AND u.username = ${param_idx})"
            ));
            param_idx += 1;
        }
        if filter.label.is_some() {
            conditions.push(format!(
                "EXISTS(SELECT 1 FROM pull_request_label pl INNER JOIN label l ON l.id = pl.label \
                WHERE pl.pull_request = pr.id AND l.name = ${param_idx})"
            ));
            param_idx += 1;
        }
        let where_clause = conditions.join(" AND ");
        let limit_idx = param_idx;
        let offset_idx = param_idx + 1;
        // Only fixed SQL fragments and placeholder indices are interpolated; every
        // user-supplied value is bound as a parameter.
        let query = format!(
            "SELECT pr.id, pr.repo, pr.number, pr.title, pr.body, pr.state, pr.draft, pr.author, \
            pr.source_repo, pr.source_branch, pr.source_sha, pr.target_branch, pr.target_sha, \
            pr.merged_by, pr.merged_at, pr.closed_by, pr.closed_at, pr.created_at, pr.updated_at, pr.deleted_at \
            FROM pull_request pr WHERE {where_clause} \
            ORDER BY pr.created_at DESC LIMIT ${limit_idx} OFFSET ${offset_idx}"
        );
        let mut q = sqlx::query_as::<_, PullRequestModel>(AssertSqlSafe(query))
            .bind(repo.id);
        if let Some(state) = &filter.state {
            q = q.bind(state);
        }
        if let Some(author) = &filter.author {
            q = q.bind(author);
        }
        if let Some(assignee) = &filter.assignee {
            q = q.bind(assignee);
        }
        if let Some(label) = &filter.label {
            q = q.bind(label);
        }
        q = q
            .bind(pagination.limit() as i64)
            .bind(pagination.offset() as i64);
        let prs = q
            .fetch_all(self.db.reader())
            .await
            .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        let mut results = Vec::with_capacity(prs.len());
        for pr in prs {
            results.push(self.pr_build_response(pr).await?);
        }
        Ok(results)
    }

    /// Fetches a single pull request of `wk_name/repo_name` by its per-repository `number`.
    pub async fn pr_get(
        &self,
        ctx: &Session,
        wk_name: &str,
        repo_name: &str,
        number: i64,
    ) -> Result<PullRequestResponse, AppError> {
        let (repo_id, _) =
            self.pr_resolve_repo(ctx, wk_name, repo_name).await?;
        let pr = self.pr_resolve(repo_id, number).await?;
        self.pr_build_response(pr).await
    }

    /// Edits a pull request's title/body/draft flag and/or changes its state.
    ///
    /// - `state` alone (`"open"` / `"closed"`) delegates to [`Self::pr_reopen`] /
    ///   [`Self::pr_close`].
    /// - Any title/body/draft change requires the session user to be the author and the
    ///   pull request to be open (or being reopened by the same request).
    /// - `state: "closed"` with edits writes the edits first and then closes.
    /// - `state: "open"` with edits validates the edits, reopens the pull request if it is
    ///   not already open, and then writes them. All permission and title checks run
    ///   before the first write.
    ///
    /// # Errors
    /// - `BadRequest` for an unknown `state`, a blank title, or an edit on a pull request
    ///   that is closed/merged (and not being reopened).
    /// - `Forbidden` if a non-author attempts a field edit.
    /// - `DatabaseError` if a write fails.
    pub async fn pr_update(
        &self,
        ctx: &Session,
        wk_name: &str,
        repo_name: &str,
        number: i64,
        params: UpdatePullRequest,
    ) -> Result<PullRequestResponse, AppError> {
        // Decode the requested transition before touching the database so an invalid
        // value can never cause a partial write. `Some(true)` = reopen, `Some(false)` = close.
        let reopen = match params.state.as_deref() {
            None => None,
            Some("open") => Some(true),
            Some("closed") => Some(false),
            Some(other) => {
                return Err(AppError::BadRequest(format!(
                    "invalid state '{}': must be 'open' or 'closed'",
                    other
                )));
            }
        };
        let has_field_changes =
            params.title.is_some() || params.body.is_some() || params.draft.is_some();
        if !has_field_changes {
            // Pure state transitions keep their dedicated code paths (no author check).
            match reopen {
                Some(false) => return self.pr_close(ctx, wk_name, repo_name, number).await,
                Some(true) => return self.pr_reopen(ctx, wk_name, repo_name, number).await,
                None => {}
            }
        }
        let user_uid = session_user(ctx)?;
        let (repo_id, _) =
            self.pr_resolve_repo(ctx, wk_name, repo_name).await?;
        let pr = self.pr_resolve(repo_id, number).await?;
        if pr.author != user_uid {
            return Err(AppError::Forbidden(
                "only the author can update this pull request".to_string(),
            ));
        }
        let updated = match reopen {
            Some(true) => {
                // Validate the title before reopening so a rejected edit cannot leave the
                // pull request reopened as a side effect.
                let next_title = Self::pr_next_title(params.title, &pr.title)?;
                let open_pr = if pr.state == "open" {
                    pr
                } else {
                    self.pr_reopen_resolved(&pr).await?
                };
                self.pr_write_fields(&open_pr, next_title, params.body, params.draft)
                    .await?
            }
            close_after => {
                if pr.state != "open" {
                    return Err(AppError::BadRequest(
                        "cannot update a closed or merged pull request".to_string(),
                    ));
                }
                let next_title = Self::pr_next_title(params.title, &pr.title)?;
                // Edits are only allowed while open, so write them before closing.
                let edited = self
                    .pr_write_fields(&pr, next_title, params.body, params.draft)
                    .await?;
                if close_after == Some(false) {
                    self.pr_close_resolved(ctx, &edited).await?
                } else {
                    edited
                }
            }
        };
        self.pr_build_response(updated).await
    }

    /// Closes an open pull request, recording the session user as `closed_by`.
    ///
    /// # Errors
    /// `BadRequest` if the pull request is already closed or merged (including when a
    /// concurrent request closed/merged it first).
    pub async fn pr_close(
        &self,
        ctx: &Session,
        wk_name: &str,
        repo_name: &str,
        number: i64,
    ) -> Result<PullRequestResponse, AppError> {
        // Authenticate before any lookups so anonymous callers get the session error first.
        session_user(ctx)?;
        let (repo_id, _) =
            self.pr_resolve_repo(ctx, wk_name, repo_name).await?;
        let pr = self.pr_resolve(repo_id, number).await?;
        let closed = self.pr_close_resolved(ctx, &pr).await?;
        self.pr_build_response(closed).await
    }

    /// Reopens a closed, unmerged pull request and clears `closed_by` / `closed_at`.
    ///
    /// # Errors
    /// `BadRequest` if the pull request is not closed or has been merged.
    pub async fn pr_reopen(
        &self,
        ctx: &Session,
        wk_name: &str,
        repo_name: &str,
        number: i64,
    ) -> Result<PullRequestResponse, AppError> {
        // Authenticate before any lookups so anonymous callers get the session error first.
        session_user(ctx)?;
        let (repo_id, _) =
            self.pr_resolve_repo(ctx, wk_name, repo_name).await?;
        let pr = self.pr_resolve(repo_id, number).await?;
        let reopened = self.pr_reopen_resolved(&pr).await?;
        self.pr_build_response(reopened).await
    }

    /// Soft-deletes a pull request by stamping `deleted_at`; requires the access granted by
    /// `pr_resolve_repo_admin`.
    pub async fn pr_delete(
        &self,
        ctx: &Session,
        wk_name: &str,
        repo_name: &str,
        number: i64,
    ) -> Result<(), AppError> {
        let (repo_id, _) =
            self.pr_resolve_repo_admin(ctx, wk_name, repo_name).await?;
        let pr = self.pr_resolve(repo_id, number).await?;
        // The `deleted_at IS NULL` guard keeps the first deletion timestamp if two
        // deletes race.
        sqlx::query(
            "UPDATE pull_request SET deleted_at = $1 WHERE id = $2 AND deleted_at IS NULL",
        )
        .bind(chrono::Utc::now())
        .bind(pr.id)
        .execute(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?;
        Ok(())
    }

    /// Expands a stored pull request row into the API response.
    ///
    /// Resolves the author and (when set) the `merged_by` / `closed_by` users, and loads
    /// the labels, assignees and reviews. The lookups run sequentially.
    pub async fn pr_build_response(
        &self,
        pr: PullRequestModel,
    ) -> Result<PullRequestResponse, AppError> {
        let author = self.users_find_by_id(pr.author).await?;
        let merged_by = if let Some(uid) = pr.merged_by {
            Some(issue_author(self.users_find_by_id(uid).await?))
        } else {
            None
        };
        let closed_by = if let Some(uid) = pr.closed_by {
            Some(issue_author(self.users_find_by_id(uid).await?))
        } else {
            None
        };
        let labels = self.pr_labels(pr.id).await?;
        let assignees = self.pr_assignees_list(pr.id).await?;
        let reviews = self.pr_reviews_list(pr.id).await?;
        Ok(PullRequestResponse {
            number: pr.number,
            title: pr.title,
            body: pr.body,
            state: pr.state,
            draft: pr.draft,
            author: issue_author(author),
            source_repo: pr.source_repo,
            source_branch: pr.source_branch,
            source_sha: pr.source_sha,
            target_branch: pr.target_branch,
            target_sha: pr.target_sha,
            merged_by,
            merged_at: pr.merged_at,
            closed_by,
            closed_at: pr.closed_at,
            created_at: pr.created_at,
            updated_at: pr.updated_at,
            labels,
            assignees,
            reviews,
        })
    }

    /// Returns the title to store: the trimmed `requested` title, or `current` when no
    /// title was requested. Rejects a blank result with `BadRequest`.
    fn pr_next_title(
        requested: Option<String>,
        current: &str,
    ) -> Result<String, AppError> {
        let next = match requested {
            Some(title) => title.trim().to_string(),
            None => current.to_string(),
        };
        if next.is_empty() {
            return Err(AppError::BadRequest(
                "pull request title is required".to_string(),
            ));
        }
        Ok(next)
    }

    /// Writes title/body/draft to an open pull request and returns the updated row.
    ///
    /// `body: None` keeps the current body and `Some(None)` clears it; `draft: None` keeps
    /// the current flag. The `state = 'open'` guard rejects the write with `BadRequest` if
    /// the pull request was closed or merged after it was read.
    async fn pr_write_fields(
        &self,
        pr: &PullRequestModel,
        title: String,
        body: Option<Option<String>>,
        draft: Option<bool>,
    ) -> Result<PullRequestModel, AppError> {
        let next_body = body.unwrap_or_else(|| pr.body.clone());
        let next_draft = draft.unwrap_or(pr.draft);
        let now = chrono::Utc::now();
        sqlx::query_as::<_, PullRequestModel>(
            "UPDATE pull_request SET title = $1, body = $2, draft = $3, updated_at = $4 \
            WHERE id = $5 AND state = 'open' \
            RETURNING id, repo, number, title, body, state, draft, author, \
            source_repo, source_branch, source_sha, target_branch, target_sha, \
            merged_by, merged_at, closed_by, closed_at, created_at, updated_at, deleted_at",
        )
        .bind(&title)
        .bind(&next_body)
        .bind(next_draft)
        .bind(now)
        .bind(pr.id)
        .fetch_optional(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?
        .ok_or_else(|| {
            AppError::BadRequest("cannot update a closed or merged pull request".to_string())
        })
    }

    /// Transitions an already-resolved pull request from `open` to `closed`, recording
    /// the session user as `closed_by`, and returns the updated row.
    ///
    /// The `state = 'open'` guard makes the transition atomic: if a concurrent request
    /// closed or merged the pull request first, this returns `BadRequest` instead of
    /// overwriting that state.
    async fn pr_close_resolved(
        &self,
        ctx: &Session,
        pr: &PullRequestModel,
    ) -> Result<PullRequestModel, AppError> {
        let user_uid = session_user(ctx)?;
        if pr.state != "open" {
            return Err(AppError::BadRequest(
                "pull request is already closed or merged".to_string(),
            ));
        }
        let now = chrono::Utc::now();
        sqlx::query_as::<_, PullRequestModel>(
            "UPDATE pull_request SET state = 'closed', closed_by = $1, closed_at = $2, updated_at = $2 \
            WHERE id = $3 AND state = 'open' \
            RETURNING id, repo, number, title, body, state, draft, author, \
            source_repo, source_branch, source_sha, target_branch, target_sha, \
            merged_by, merged_at, closed_by, closed_at, created_at, updated_at, deleted_at",
        )
        .bind(user_uid)
        .bind(now)
        .bind(pr.id)
        .fetch_optional(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?
        .ok_or_else(|| {
            AppError::BadRequest("pull request is already closed or merged".to_string())
        })
    }

    /// Transitions an already-resolved, unmerged pull request from `closed` back to
    /// `open`, clearing `closed_by` / `closed_at`, and returns the updated row.
    ///
    /// The WHERE guard repeats the state/merge checks so a concurrent change is reported
    /// as `BadRequest` rather than overwritten.
    async fn pr_reopen_resolved(
        &self,
        pr: &PullRequestModel,
    ) -> Result<PullRequestModel, AppError> {
        if pr.state != "closed" {
            return Err(AppError::BadRequest(
                "pull request is not closed".to_string(),
            ));
        }
        if pr.merged_by.is_some() {
            return Err(AppError::BadRequest(
                "merged pull request cannot be reopened".to_string(),
            ));
        }
        let now = chrono::Utc::now();
        sqlx::query_as::<_, PullRequestModel>(
            "UPDATE pull_request SET state = 'open', closed_by = NULL, closed_at = NULL, updated_at = $1 \
            WHERE id = $2 AND state = 'closed' AND merged_by IS NULL \
            RETURNING id, repo, number, title, body, state, draft, author, \
            source_repo, source_branch, source_sha, target_branch, target_sha, \
            merged_by, merged_at, closed_by, closed_at, created_at, updated_at, deleted_at",
        )
        .bind(now)
        .bind(pr.id)
        .fetch_optional(self.db.writer())
        .await
        .map_err(|e| AppError::DatabaseError(e.to_string()))?
        .ok_or_else(|| AppError::BadRequest("pull request is not closed".to_string()))
    }
}