//! read_sql — parse and analyze SQL files. use crate::file_tools::MAX_FILE_SIZE; use crate::git_tools::ctx::GitToolCtx; use agent::{ToolDefinition, ToolHandler, ToolParam, ToolRegistry, ToolSchema}; use sqlparser::ast::{Statement, ColumnDef}; use sqlparser::dialect::{GenericDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect}; use sqlparser::parser::Parser; use std::collections::HashMap; async fn read_sql_exec( ctx: GitToolCtx, args: serde_json::Value, ) -> Result { let p: serde_json::Map = serde_json::from_value(args).map_err(|e| e.to_string())?; let project_name = p .get("project_name") .and_then(|v| v.as_str()) .ok_or("missing project_name")?; let repo_name = p .get("repo_name") .and_then(|v| v.as_str()) .ok_or("missing repo_name")?; let path = p.get("path").and_then(|v| v.as_str()).ok_or("missing path")?; let rev = p .get("rev") .and_then(|v| v.as_str()) .map(String::from) .unwrap_or_else(|| "HEAD".to_string()); let dialect = p.get("dialect").and_then(|v| v.as_str()).unwrap_or("generic"); let domain = ctx.open_repo(project_name, repo_name).await?; let commit_oid = if rev.len() >= 40 { git::commit::types::CommitOid::new(&rev) } else { domain .commit_get_prefix(&rev) .map_err(|e| e.to_string())? .oid }; let entry = domain .tree_entry_by_path_from_commit(&commit_oid, path) .map_err(|e| e.to_string())?; let content = domain.blob_content(&entry.oid).map_err(|e| e.to_string())?; let data = &content.content; if data.len() > MAX_FILE_SIZE { return Err(format!( "file too large ({} bytes), max {} bytes", data.len(), MAX_FILE_SIZE )); } let text = String::from_utf8_lossy(data); let parser_dialect: Box = match dialect { "mysql" => Box::new(MySqlDialect {}), "postgresql" | "postgres" => Box::new(PostgreSqlDialect {}), "sqlite" => Box::new(SQLiteDialect {}), _ => Box::new(GenericDialect {}), }; let statements = Parser::parse_sql(parser_dialect.as_ref(), &text) .map_err(|e| format!("SQL parse error: {}", e))?; let mut tables: Vec = Vec::new(); let mut views: Vec = Vec::new(); let mut functions: Vec = Vec::new(); let mut indexes: Vec = Vec::new(); let mut statement_kinds: std::collections::HashMap = std::collections::HashMap::new(); for statement in &statements { let kind = format!("{:?}", statement).split('{').next().unwrap_or("unknown").to_string(); *statement_kinds.entry(kind).or_insert(0) += 1; match statement { Statement::CreateTable(stmt) => { let name = stmt.name.to_string(); let columns: Vec = stmt.columns.iter().map(format_column_def).collect(); tables.push(serde_json::json!({ "name": name, "columns": columns, "if_not_exists": stmt.if_not_exists, })); } Statement::CreateView { name, query, .. } => { views.push(serde_json::json!({ "name": name.to_string(), "query": query.to_string(), })); } Statement::CreateIndex(stmt) => { indexes.push(serde_json::json!({ "name": stmt.name.as_ref().map(|n| n.to_string()).unwrap_or_default(), "table": stmt.table_name.to_string(), "columns": stmt.columns.iter().map(|c| c.to_string()).collect::>(), })); } Statement::CreateFunction(stmt) => { functions.push(serde_json::json!({ "name": stmt.name.to_string(), "args": stmt.args.iter().flat_map(|args| args.iter().filter_map(|a| a.name.as_ref().map(|n| n.to_string()))).collect::>(), "return_type": stmt.return_type.as_ref().map(|r| r.to_string()).unwrap_or_default(), })); } _ => {} } } Ok(serde_json::json!({ "path": path, "rev": rev, "dialect": dialect, "statement_count": statements.len(), "statement_kinds": statement_kinds, "tables": tables, "views": views, "functions": functions, "indexes": indexes, })) } fn format_column_def(col: &ColumnDef) -> String { let name = col.name.to_string(); let data_type = col.data_type.to_string(); format!("{} {}", name, data_type) } pub fn register_sql_tools(registry: &mut ToolRegistry) { let p = HashMap::from([ ("project_name".into(), ToolParam { name: "project_name".into(), param_type: "string".into(), description: Some("Project name (slug)".into()), required: true, properties: None, items: None }), ("repo_name".into(), ToolParam { name: "repo_name".into(), param_type: "string".into(), description: Some("Repository name".into()), required: true, properties: None, items: None }), ("path".into(), ToolParam { name: "path".into(), param_type: "string".into(), description: Some("File path to the SQL file".into()), required: true, properties: None, items: None }), ("rev".into(), ToolParam { name: "rev".into(), param_type: "string".into(), description: Some("Git revision (default: HEAD)".into()), required: false, properties: None, items: None }), ("dialect".into(), ToolParam { name: "dialect".into(), param_type: "string".into(), description: Some("SQL dialect: generic, mysql, postgresql, sqlite. Default: generic".into()), required: false, properties: None, items: None }), ]); let schema = ToolSchema { schema_type: "object".into(), properties: Some(p), required: Some(vec!["project_name".into(), "repo_name".into(), "path".into()]) }; registry.register( ToolDefinition::new("read_sql") .description("Parse and analyze a SQL file. Extracts CREATE TABLE statements (with columns and types), CREATE VIEW, CREATE INDEX, CREATE FUNCTION, and counts all statement types.") .parameters(schema), ToolHandler::new(|ctx, args| { let gctx = GitToolCtx::new(ctx); Box::pin(async move { read_sql_exec(gctx, args).await.map_err(agent::ToolError::ExecutionError) }) }), ); }