diff --git a/.schema/pgdog.schema.json b/.schema/pgdog.schema.json index 846a01a8..87166b21 100644 --- a/.schema/pgdog.schema.json +++ b/.schema/pgdog.schema.json @@ -1173,6 +1173,11 @@ ], "format": "float" }, + "level": { + "description": "What kind of statements to replicate.", + "$ref": "#/$defs/MirroringLevel", + "default": "all" + }, "queue_length": { "description": "The length of the queue to provision for mirrored transactions. See [mirroring](https://docs.pgdog.dev/features/mirroring/) for more details. This overrides the [`mirror_queue`](https://docs.pgdog.dev/configuration/pgdog.toml/general/#mirror_queue) setting.\n\nhttps://docs.pgdog.dev/configuration/pgdog.toml/mirroring/#queue_depth", "type": [ @@ -1193,6 +1198,25 @@ "destination_db" ] }, + "MirroringLevel": { + "oneOf": [ + { + "description": "Replicate all statements.", + "type": "string", + "const": "all" + }, + { + "description": "Only DML (e.g., insert, update, delete, etc),", + "type": "string", + "const": "dml" + }, + { + "description": "Only DDL (CREATE, DROP, etc.)", + "type": "string", + "const": "ddl" + } + ] + }, "MultiTenant": { "description": "multi-tenant routing configuration, mapping queries to shards via a tenant identifier column.", "type": "object", diff --git a/integration/mirror/dev.sh b/integration/mirror/dev.sh index af8fe472..d6fd7582 100644 --- a/integration/mirror/dev.sh +++ b/integration/mirror/dev.sh @@ -9,7 +9,3 @@ mkdir -p ${GEM_HOME} bundle install bundle exec rspec *_spec.rb popd - -pushd ${SCRIPT_DIR}/php -bash run.sh -popd diff --git a/integration/mirror/pgdog.toml b/integration/mirror/pgdog.toml index 1e60ea82..24d0a62b 100644 --- a/integration/mirror/pgdog.toml +++ b/integration/mirror/pgdog.toml @@ -1,6 +1,7 @@ [general] mirror_exposure = 1.0 openmetrics_port = 9090 +query_parser = "on" [rewrite] enabled = false @@ -16,9 +17,20 @@ name = "pgdog_mirror" host = "127.0.0.1" database_name = "pgdog1" +[[databases]] +name = "pgdog_mirror2" +host = "127.0.0.1" +database_name = "pgdog2" + [[mirroring]] source_db = "pgdog" destination_db = "pgdog_mirror" + +[[mirroring]] +source_db = "pgdog" +destination_db = "pgdog_mirror2" +level = "ddl" + [admin] password = "pgdog" diff --git a/integration/mirror/ruby/copy_spec.rb b/integration/mirror/ruby/copy_spec.rb index f0c2c728..743f414d 100644 --- a/integration/mirror/ruby/copy_spec.rb +++ b/integration/mirror/ruby/copy_spec.rb @@ -80,3 +80,80 @@ conn.exec 'DROP TABLE IF EXISTS mirror_copy_test' end end + +describe 'ddl-only mirror' do + conn = PG.connect('postgres://pgdog:pgdog@127.0.0.1:6432/pgdog') + ddl_mirror = PG.connect('postgres://pgdog:pgdog@127.0.0.1:6432/pgdog_mirror2') + + before do + conn.exec 'DROP TABLE IF EXISTS ddl_mirror_test' + ddl_mirror.exec 'DROP TABLE IF EXISTS ddl_mirror_test' + end + + it 'replicates DDL to ddl-only mirror' do + conn.exec 'CREATE TABLE ddl_mirror_test (id BIGINT PRIMARY KEY, value VARCHAR)' + sleep(0.5) + + # DDL should be mirrored + result = ddl_mirror.exec "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'ddl_mirror_test')" + expect(result[0]['exists']).to eq('t') + end + + it 'does not replicate DML to ddl-only mirror' do + conn.exec 'CREATE TABLE ddl_mirror_test (id BIGINT PRIMARY KEY, value VARCHAR)' + sleep(0.5) + + conn.exec "INSERT INTO ddl_mirror_test VALUES (1, 'should not mirror')" + sleep(0.5) + + # Table should exist on ddl mirror (DDL was mirrored) + result = ddl_mirror.exec 'SELECT count(*) FROM ddl_mirror_test' + # But no rows (DML was not mirrored) + expect(result[0]['count'].to_i).to eq(0) + end + + it 'does not replicate UPDATE to ddl-only mirror' do + conn.exec 'CREATE TABLE ddl_mirror_test (id BIGINT PRIMARY KEY, value VARCHAR)' + sleep(0.5) + + # Insert directly into ddl mirror so we can check UPDATE doesn't propagate + ddl_mirror.exec "INSERT INTO ddl_mirror_test VALUES (1, 'original')" + + conn.exec "INSERT INTO ddl_mirror_test VALUES (1, 'original')" + conn.exec "UPDATE ddl_mirror_test SET value = 'updated' WHERE id = 1" + sleep(0.5) + + result = ddl_mirror.exec 'SELECT value FROM ddl_mirror_test WHERE id = 1' + expect(result[0]['value']).to eq('original') + end + + it 'replicates ALTER TABLE to ddl-only mirror' do + conn.exec 'CREATE TABLE ddl_mirror_test (id BIGINT PRIMARY KEY, value VARCHAR)' + sleep(0.5) + + conn.exec 'ALTER TABLE ddl_mirror_test ADD COLUMN extra TEXT' + sleep(0.5) + + result = ddl_mirror.exec "SELECT column_name FROM information_schema.columns WHERE table_name = 'ddl_mirror_test' AND column_name = 'extra'" + expect(result.ntuples).to eq(1) + end + + it 'replicates DROP TABLE to ddl-only mirror' do + conn.exec 'CREATE TABLE ddl_mirror_test (id BIGINT PRIMARY KEY, value VARCHAR)' + sleep(0.5) + + result = ddl_mirror.exec "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'ddl_mirror_test')" + expect(result[0]['exists']).to eq('t') + + conn.exec 'DROP TABLE ddl_mirror_test' + sleep(0.5) + + result = ddl_mirror.exec "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'ddl_mirror_test')" + expect(result[0]['exists']).to eq('f') + end + + after do + conn.exec 'DROP TABLE IF EXISTS ddl_mirror_test' + ddl_mirror.exec 'DROP TABLE IF EXISTS ddl_mirror_test' + end +end diff --git a/integration/mirror/run.sh b/integration/mirror/run.sh index 70970a1c..3b8ffad5 100644 --- a/integration/mirror/run.sh +++ b/integration/mirror/run.sh @@ -8,4 +8,10 @@ wait_for_pgdog bash ${SCRIPT_DIR}/dev.sh + +pushd ${SCRIPT_DIR}/php +bash run.sh +popd + + stop_pgdog diff --git a/integration/mirror/users.toml b/integration/mirror/users.toml index 84f71659..1dc8b1bb 100644 --- a/integration/mirror/users.toml +++ b/integration/mirror/users.toml @@ -7,3 +7,8 @@ database = "pgdog" name = "pgdog" password = "pgdog" database = "pgdog_mirror" + +[[users]] +name = "pgdog" +password = "pgdog" +database = "pgdog_mirror2" diff --git a/pgdog-config/src/core.rs b/pgdog-config/src/core.rs index 5e5e8c14..204e899b 100644 --- a/pgdog-config/src/core.rs +++ b/pgdog-config/src/core.rs @@ -18,7 +18,7 @@ use super::error::Error; use super::general::General; use super::networking::{MultiTenant, Tcp, TlsVerifyMode}; use super::pooling::PoolerMode; -use super::replication::{MirrorConfig, Mirroring, ReplicaLag, Replication}; +use super::replication::{MirrorConfig, Mirroring, MirroringLevel, ReplicaLag, Replication}; use super::rewrite::Rewrite; use super::sharding::{ManualQuery, OmnishardedTables, ShardedMapping, ShardedTable}; use super::users::{Admin, Plugin, ServerAuth, Users}; @@ -94,15 +94,13 @@ impl ConfigAndUsers { warn!("admin password has been randomly generated"); } - let mut config_and_users = ConfigAndUsers { + let config_and_users = ConfigAndUsers { config, users, config_path: config_path.to_owned(), users_path: users_path.to_owned(), }; - config_and_users.check()?; - Ok(config_and_users) } @@ -424,6 +422,7 @@ impl Config { role: Role, role_warned: bool, parser_warned: bool, + mirror_parser_warned: bool, have_replicas: bool, sharded: bool, } @@ -471,6 +470,7 @@ impl Config { role: database.role, role_warned: false, parser_warned: false, + mirror_parser_warned: false, have_replicas: database.role == Role::Replica, sharded: database.shard > 0, }, @@ -517,7 +517,30 @@ impl Config { } } - for (database, check) in checks { + for mirror in &self.mirroring { + if mirror.level == MirroringLevel::All { + continue; + } + if let Some(check) = checks.get_mut(&mirror.source_db) { + if check.mirror_parser_warned { + continue; + } + let parser_enabled = match self.general.query_parser { + QueryParserLevel::On => true, + QueryParserLevel::Off => false, + QueryParserLevel::Auto => check.have_replicas || check.sharded, + }; + if !parser_enabled { + check.mirror_parser_warned = true; + warn!( + r#"mirroring from "{}" with level "{}" requires the query parser to classify statements, but it won't be enabled, set query_parser = "on""#, + mirror.source_db, mirror.level + ); + } + } + } + + for (database, check) in &checks { if !check.have_replicas && self.general.read_write_split == ReadWriteSplit::ExcludePrimary { @@ -560,6 +583,7 @@ impl Config { .map(|m| MirrorConfig { queue_length: m.queue_length.unwrap_or(self.general.mirror_queue), exposure: m.exposure.unwrap_or(self.general.mirror_exposure), + level: m.level, }) } @@ -571,6 +595,7 @@ impl Config { let config = MirrorConfig { queue_length: mirror.queue_length.unwrap_or(self.general.mirror_queue), exposure: mirror.exposure.unwrap_or(self.general.mirror_exposure), + level: mirror.level, }; result diff --git a/pgdog-config/src/replication.rs b/pgdog-config/src/replication.rs index ccdc29f9..4b0acff6 100644 --- a/pgdog-config/src/replication.rs +++ b/pgdog-config/src/replication.rs @@ -156,6 +156,45 @@ pub struct Mirroring { /// /// https://docs.pgdog.dev/configuration/pgdog.toml/mirroring/#exposure pub exposure: Option, + + /// What kind of statements to replicate. + #[serde(default)] + pub level: MirroringLevel, +} + +#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, JsonSchema, Copy)] +#[serde(deny_unknown_fields, rename_all = "lowercase")] +pub enum MirroringLevel { + /// Replicate all statements. + #[default] + All, + /// Only DML (e.g., insert, update, delete, etc), + Dml, + /// Only DDL (CREATE, DROP, etc.) + Ddl, +} + +impl std::fmt::Display for MirroringLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::All => write!(f, "all"), + Self::Dml => write!(f, "dml"), + Self::Ddl => write!(f, "ddl"), + } + } +} + +impl FromStr for MirroringLevel { + type Err = (); + + fn from_str(value: &str) -> Result { + match value { + "all" => Ok(Self::All), + "dml" => Ok(Self::Dml), + "ddl" => Ok(Self::Ddl), + _ => Err(()), + } + } } impl FromStr for Mirroring { @@ -166,6 +205,7 @@ impl FromStr for Mirroring { let mut destination_db = None; let mut queue_length = None; let mut exposure = None; + let mut level = MirroringLevel::default(); for pair in s.split('&') { let parts: Vec<&str> = pair.split('=').collect(); @@ -190,6 +230,7 @@ impl FromStr for Mirroring { .map_err(|_| format!("Invalid exposure: {}", parts[1]))?, ); } + "level" => level = MirroringLevel::from_str(parts[1]).unwrap_or_default(), _ => return Err(format!("Unknown parameter: {}", parts[0])), } } @@ -202,15 +243,18 @@ impl FromStr for Mirroring { destination_db, queue_length, exposure, + level, }) } } /// Runtime mirror configuration with defaults resolved from global settings. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct MirrorConfig { /// Effective queue length for this mirror. pub queue_length: usize, /// Effective exposure fraction for this mirror. pub exposure: f32, + /// What kind of statements to mirror. + pub level: MirroringLevel, } diff --git a/pgdog/src/backend/databases.rs b/pgdog/src/backend/databases.rs index 6711ea4f..682a9aa8 100644 --- a/pgdog/src/backend/databases.rs +++ b/pgdog/src/backend/databases.rs @@ -689,6 +689,7 @@ pub fn from_config(config: &ConfigAndUsers) -> Databases { exposure: mirror .exposure .unwrap_or(config.config.general.mirror_exposure), + level: mirror.level, }; mirror_configs.insert( (mirror.source_db.clone(), mirror.destination_db.clone()), @@ -852,8 +853,7 @@ mod tests { config.mirroring = vec![crate::config::Mirroring { source_db: "db1".to_string(), destination_db: "db1_mirror".to_string(), - queue_length: None, - exposure: None, + ..Default::default() }]; let users = crate::config::Users { @@ -932,8 +932,7 @@ mod tests { config.mirroring = vec![crate::config::Mirroring { source_db: "source_db".to_string(), destination_db: "dest_db".to_string(), - queue_length: None, - exposure: None, + ..Default::default() }]; let users = crate::config::Users { @@ -1011,6 +1010,7 @@ mod tests { destination_db: "dest_db".to_string(), queue_length: Some(256), exposure: Some(0.5), + ..Default::default() }]; let users = crate::config::Users { @@ -1087,8 +1087,7 @@ mod tests { config.mirroring = vec![crate::config::Mirroring { source_db: "db1".to_string(), destination_db: "db2".to_string(), - queue_length: None, - exposure: None, + ..Default::default() }]; let users = crate::config::Users { @@ -1168,13 +1167,13 @@ mod tests { source_db: "primary".to_string(), destination_db: "mirror1".to_string(), queue_length: Some(200), // Override queue only - exposure: None, + ..Default::default() }, crate::config::Mirroring { source_db: "primary".to_string(), destination_db: "mirror2".to_string(), - queue_length: None, exposure: Some(0.25), // Override exposure only + ..Default::default() }, ]; @@ -1259,6 +1258,7 @@ mod tests { destination_db: "dest".to_string(), queue_length: Some(256), exposure: Some(0.5), + ..Default::default() }]; // Create user mismatch - user1 for source, user2 for dest @@ -1325,6 +1325,7 @@ mod tests { destination_db: "dest_db".to_string(), queue_length: Some(256), exposure: Some(0.5), + ..Default::default() }]; // No users at all diff --git a/pgdog/src/backend/pool/connection/mirror/handler.rs b/pgdog/src/backend/pool/connection/mirror/handler.rs index 20769c61..0c667542 100644 --- a/pgdog/src/backend/pool/connection/mirror/handler.rs +++ b/pgdog/src/backend/pool/connection/mirror/handler.rs @@ -5,8 +5,10 @@ use super::*; use crate::backend::pool::MirrorStats; +use crate::frontend::router::parser::cache::StatementType; use crate::frontend::ClientRequest; use parking_lot::Mutex; +use pgdog_config::MirrorConfig; use std::sync::Arc; /// Mirror handle state. @@ -27,8 +29,7 @@ enum MirrorHandlerState { pub struct MirrorHandler { /// Sender. tx: Sender, - /// Percentage of requests being mirrored. 0 = 0%, 1.0 = 100%. - exposure: f32, + config: MirrorConfig, /// Mirror handle state. state: MirrorHandlerState, /// Request buffer. @@ -46,10 +47,14 @@ impl MirrorHandler { } /// Create new mirror handle with exposure. - pub fn new(tx: Sender, exposure: f32, stats: Arc>) -> Self { + pub fn new( + tx: Sender, + config: &MirrorConfig, + stats: Arc>, + ) -> Self { Self { tx, - exposure, + config: config.clone(), state: MirrorHandlerState::Idle, buffer: vec![], timer: Instant::now(), @@ -62,19 +67,33 @@ impl MirrorHandler { /// Returns true if request will be sent, false otherwise. /// pub fn send(&mut self, buffer: &ClientRequest) -> bool { + let stmt_type = buffer.ast.as_ref().map(|ast| ast.statement_type()); + if let Some(stmt_type) = stmt_type { + match (self.config.level, stmt_type) { + (MirroringLevel::Ddl, StatementType::Dml) => { + debug!("mirror dropping dml (level=ddl)"); + return false; + } + (MirroringLevel::Dml, StatementType::Ddl) => { + debug!("mirror dropping ddl (level=dml)"); + return false; + } + _ => (), + } + } match self.state { MirrorHandlerState::Dropping => { debug!("mirror dropping request"); false } MirrorHandlerState::Idle => { - let roll = if self.exposure < 1.0 { + let roll = if self.config.exposure < 1.0 { rng().random_range(0.0..1.0) } else { 0.99 }; - if roll < self.exposure { + if roll < self.config.exposure { self.state = MirrorHandlerState::Sending; self.buffer.push(BufferWithDelay { buffer: buffer.clone(), @@ -84,7 +103,10 @@ impl MirrorHandler { true } else { self.state = MirrorHandlerState::Dropping; - debug!("mirror dropping transaction [exposure: {}]", self.exposure); + debug!( + "mirror dropping transaction [exposure: {}]", + self.config.exposure + ); false } } @@ -191,7 +213,14 @@ mod tests { ) { let (tx, rx) = channel(1000); // Keep receiver to prevent channel closure let stats = Arc::new(Mutex::new(MirrorStats::default())); - let handler = MirrorHandler::new(tx, exposure, stats.clone()); + let handler = MirrorHandler::new( + tx, + &MirrorConfig { + exposure, + ..Default::default() + }, + stats.clone(), + ); (handler, stats, rx) } @@ -415,7 +444,14 @@ mod tests { fn test_queue_length_with_channel_overflow() { let (tx, _rx) = channel(1); // Channel with capacity of 1 let stats = Arc::new(Mutex::new(MirrorStats::default())); - let mut handler = MirrorHandler::new(tx, 1.0, stats.clone()); + let mut handler = MirrorHandler::new( + tx, + &MirrorConfig { + exposure: 1.0, + ..Default::default() + }, + stats.clone(), + ); // Fill the channel assert!(handler.send(&vec![].into())); @@ -439,6 +475,105 @@ mod tests { } } + fn create_test_handler_with_level( + exposure: f32, + level: MirroringLevel, + ) -> ( + MirrorHandler, + Arc>, + Receiver, + ) { + let (tx, rx) = channel(1000); + let stats = Arc::new(Mutex::new(MirrorStats::default())); + let handler = MirrorHandler::new( + tx, + &MirrorConfig { + exposure, + level, + ..Default::default() + }, + stats.clone(), + ); + (handler, stats, rx) + } + + fn request_with_ast(query: &str) -> ClientRequest { + use crate::frontend::router::Ast; + let ast = Ast::from_parse_result(pg_query::parse(query).unwrap()); + ClientRequest { + messages: vec![], + route: None, + ast: Some(ast), + } + } + + #[test] + fn test_ddl_level_drops_dml() { + let (mut handler, _, _rx) = create_test_handler_with_level(1.0, MirroringLevel::Ddl); + + // DML should be dropped + assert!(!handler.send(&request_with_ast("SELECT 1"))); + assert!(!handler.send(&request_with_ast("INSERT INTO t VALUES (1)"))); + assert!(!handler.send(&request_with_ast("UPDATE t SET x = 1"))); + assert!(!handler.send(&request_with_ast("DELETE FROM t"))); + assert!(!handler.send(&request_with_ast("BEGIN"))); + + // DDL should be sent + assert!(handler.send(&request_with_ast("CREATE TABLE t (id INT)"))); + assert!(handler.send(&request_with_ast("DROP TABLE t"))); + assert!(handler.send(&request_with_ast("ALTER TABLE t ADD COLUMN x INT"))); + } + + #[test] + fn test_dml_level_drops_ddl() { + let (mut handler, _, _rx) = create_test_handler_with_level(1.0, MirroringLevel::Dml); + + // DDL should be dropped + assert!(!handler.send(&request_with_ast("CREATE TABLE t (id INT)"))); + assert!(!handler.send(&request_with_ast("DROP TABLE t"))); + assert!(!handler.send(&request_with_ast("ALTER TABLE t ADD COLUMN x INT"))); + + // DML should be sent + assert!(handler.send(&request_with_ast("SELECT 1"))); + assert!(handler.send(&request_with_ast("INSERT INTO t VALUES (1)"))); + assert!(handler.send(&request_with_ast("UPDATE t SET x = 1"))); + assert!(handler.send(&request_with_ast("DELETE FROM t"))); + assert!(handler.send(&request_with_ast("BEGIN"))); + } + + #[test] + fn test_all_level_sends_everything() { + let (mut handler, _, _rx) = create_test_handler_with_level(1.0, MirroringLevel::All); + + assert!(handler.send(&request_with_ast("SELECT 1"))); + assert!(handler.send(&request_with_ast("CREATE TABLE t (id INT)"))); + assert!(handler.send(&request_with_ast("SET search_path TO public"))); + assert!(handler.send(&request_with_ast("BEGIN"))); + } + + #[test] + fn test_session_statements_pass_through_all_levels() { + for level in [ + MirroringLevel::Ddl, + MirroringLevel::Dml, + MirroringLevel::All, + ] { + let (mut handler, _, _rx) = create_test_handler_with_level(1.0, level); + assert!( + handler.send(&request_with_ast("SET search_path TO public")), + "SET should pass through at level {:?}", + level, + ); + } + } + + #[test] + fn test_no_ast_passes_through() { + // Requests without AST (e.g. Sync-only) should always be sent + let (mut handler, _, _rx) = create_test_handler_with_level(1.0, MirroringLevel::Ddl); + assert!(handler.send(&vec![].into())); + } + #[test] fn test_queue_length_never_negative() { // Test to ensure queue_length never goes negative even with mismatched increment/decrement diff --git a/pgdog/src/backend/pool/connection/mirror/mod.rs b/pgdog/src/backend/pool/connection/mirror/mod.rs index 661bd7e2..6193f8ec 100644 --- a/pgdog/src/backend/pool/connection/mirror/mod.rs +++ b/pgdog/src/backend/pool/connection/mirror/mod.rs @@ -2,6 +2,7 @@ use std::time::Duration; +use pgdog_config::MirroringLevel; use rand::{rng, Rng}; use tokio::select; use tokio::time::{sleep, Instant}; @@ -105,11 +106,12 @@ impl Mirror { .unwrap_or_else(|| crate::config::MirrorConfig { queue_length: config.config.general.mirror_queue, exposure: config.config.general.mirror_exposure, + level: MirroringLevel::default(), }); // Mirror queue. let (tx, mut rx) = channel(mirror_config.queue_length); - let handler = MirrorHandler::new(tx, mirror_config.exposure, cluster.stats()); + let handler = MirrorHandler::new(tx, &mirror_config, cluster.stats()); let stats_for_errors = cluster.stats(); spawn(async move { @@ -165,6 +167,8 @@ impl Mirror { #[cfg(test)] mod test { + use pgdog_config::MirrorConfig; + use crate::{ backend::pool::Request, config::{self, PoolerMode, PreparedStatements as PreparedStatementsLevel}, @@ -181,7 +185,14 @@ mod test { let (tx, rx) = channel(25); let stats = Arc::new(Mutex::new(MirrorStats::default())); - let mut handle = MirrorHandler::new(tx.clone(), 1.0, stats.clone()); + let mut handle = MirrorHandler::new( + tx.clone(), + &MirrorConfig { + exposure: 1.0, + ..Default::default() + }, + stats.clone(), + ); for _ in 0..25 { assert!( @@ -195,7 +206,14 @@ mod test { let (tx, rx) = channel(25); let stats2 = Arc::new(Mutex::new(MirrorStats::default())); - let mut handle = MirrorHandler::new(tx.clone(), 0.5, stats2); + let mut handle = MirrorHandler::new( + tx.clone(), + &MirrorConfig { + exposure: 0.5, + ..Default::default() + }, + stats2, + ); let dropped = (0..25) .map(|_| handle.send(&vec![].into()) && handle.send(&vec![].into()) && handle.flush()) .filter(|s| !s) diff --git a/pgdog/src/frontend/router/parser/cache/ast.rs b/pgdog/src/frontend/router/parser/cache/ast.rs index 2be0e0fb..c34d865d 100644 --- a/pgdog/src/frontend/router/parser/cache/ast.rs +++ b/pgdog/src/frontend/router/parser/cache/ast.rs @@ -214,4 +214,42 @@ impl Ast { guard.direct += 1; } } + + /// Get statement type. + pub fn statement_type(&self) -> StatementType { + let root = self + .ast + .protobuf + .stmts + .first() + .and_then(|s| s.stmt.as_ref()) + .and_then(|s| s.node.as_ref()); + + match root { + Some(NodeEnum::SelectStmt(_)) + | Some(NodeEnum::InsertStmt(_)) + | Some(NodeEnum::UpdateStmt(_)) + | Some(NodeEnum::DeleteStmt(_)) + | Some(NodeEnum::CopyStmt(_)) + | Some(NodeEnum::ExplainStmt(_)) + | Some(NodeEnum::TransactionStmt(_)) => StatementType::Dml, + + Some(NodeEnum::VariableSetStmt(_)) + | Some(NodeEnum::VariableShowStmt(_)) + | Some(NodeEnum::DeallocateStmt(_)) + | Some(NodeEnum::ListenStmt(_)) + | Some(NodeEnum::NotifyStmt(_)) + | Some(NodeEnum::UnlistenStmt(_)) + | Some(NodeEnum::DiscardStmt(_)) => StatementType::Session, + + _ => StatementType::Ddl, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum StatementType { + Ddl, + Dml, + Session, }