# Exercises COPY ... FROM STDIN through PgDog with the example plugin loaded.
# Every example starts from a freshly created, empty `plugin_copy_test` table.
describe 'copy with plugin' do
  let(:conn) { PG.connect('postgres://pgdog:pgdog@127.0.0.1:6432/pgdog') }

  before do
    conn.exec 'DROP TABLE IF EXISTS plugin_copy_test'
    conn.exec 'CREATE TABLE plugin_copy_test (id BIGINT PRIMARY KEY, name VARCHAR, email VARCHAR)'
  end

  after do
    begin
      conn.exec 'DROP TABLE IF EXISTS plugin_copy_test'
    ensure
      # `let` opens a new connection for every example; close it even when
      # the cleanup statement fails so connections are not leaked.
      conn.close
    end
  end

  it 'can COPY text format through plugin' do
    conn.copy_data('COPY plugin_copy_test (id, name, email) FROM STDIN') do
      conn.put_copy_data("1\tAlice\talice@test.com\n")
      conn.put_copy_data("2\tBob\tbob@test.com\n")
      conn.put_copy_data("3\tCharlie\tcharlie@test.com\n")
    end

    rows = conn.exec 'SELECT * FROM plugin_copy_test ORDER BY id'
    expect(rows.ntuples).to eq(3)
    expect(rows[0]['name']).to eq('Alice')
    expect(rows[1]['name']).to eq('Bob')
    expect(rows[2]['name']).to eq('Charlie')
  end

  it 'can COPY CSV format through plugin' do
    conn.copy_data("COPY plugin_copy_test (id, name, email) FROM STDIN WITH (FORMAT CSV, HEADER)") do
      # First line is the header row announced by the HEADER option.
      conn.put_copy_data(CSV.generate_line(%w[id name email]))
      conn.put_copy_data(CSV.generate_line([1, 'Alice', 'alice@test.com']))
      conn.put_copy_data(CSV.generate_line([2, 'Bob', 'bob@test.com']))
      conn.put_copy_data(CSV.generate_line([3, 'Charlie', 'charlie@test.com']))
    end

    rows = conn.exec 'SELECT * FROM plugin_copy_test ORDER BY id'
    expect(rows.ntuples).to eq(3)
    expect(rows[0]['email']).to eq('alice@test.com')
    expect(rows[2]['email']).to eq('charlie@test.com')
  end

  it 'can COPY CSV with custom delimiter through plugin' do
    conn.copy_data("COPY plugin_copy_test (id, name, email) FROM STDIN WITH (FORMAT CSV, DELIMITER '|')") do
      conn.put_copy_data("1|Alice|alice@test.com\n")
      conn.put_copy_data("2|Bob|bob@test.com\n")
    end

    rows = conn.exec 'SELECT * FROM plugin_copy_test ORDER BY id'
    expect(rows.ntuples).to eq(2)
    expect(rows[0]['name']).to eq('Alice')
    expect(rows[1]['email']).to eq('bob@test.com')
  end

  it 'can COPY with NULL values through plugin' do
    conn.copy_data("COPY plugin_copy_test (id, name, email) FROM STDIN WITH (FORMAT CSV, NULL '\\N')") do
      conn.put_copy_data("1,Alice,\\N\n")
      conn.put_copy_data("2,\\N,bob@test.com\n")
    end

    rows = conn.exec 'SELECT * FROM plugin_copy_test ORDER BY id'
    expect(rows.ntuples).to eq(2)
    expect(rows[0]['email']).to be_nil
    expect(rows[1]['name']).to be_nil
  end

  it 'can COPY many rows through plugin' do
    conn.copy_data('COPY plugin_copy_test (id, name, email) FROM STDIN') do
      1000.times do |i|
        conn.put_copy_data("#{i}\tuser_#{i}\tuser_#{i}@test.com\n")
      end
    end

    rows = conn.exec 'SELECT count(*) FROM plugin_copy_test'
    expect(rows[0]['count'].to_i).to eq(1000)
  end
end
TokenStream) -> TokenStream { TokenStream::from(expanded) } +/// Generates the `pgdog_route_copy_row` method for routing COPY rows. +/// +/// The decorated function receives a [`PdCopyRow`] and returns a [`Route`]. +/// +/// ### Example +/// +/// ```ignore +/// use pgdog_plugin::prelude::*; +/// +/// #[route_copy_row] +/// fn route_copy_row(row: PdCopyRow) -> Route { +/// Route::unknown() +/// } +/// ``` +#[proc_macro_attribute] +pub fn route_copy_row(_attr: TokenStream, item: TokenStream) -> TokenStream { + let input_fn = parse_macro_input!(item as ItemFn); + let fn_name = &input_fn.sig.ident; + let fn_inputs = &input_fn.sig.inputs; + + let (first_param_name, _) = fn_inputs + .iter() + .filter_map(|input| { + if let syn::FnArg::Typed(pat_type) = input { + if let syn::Pat::Ident(pat_ident) = &*pat_type.pat { + Some((pat_ident.ident.clone(), pat_type.ty.clone())) + } else { + None + } + } else { + None + } + }) + .next() + .expect("route_copy_row function must have at least one named parameter"); + + let expanded = quote! { + #[unsafe(no_mangle)] + pub unsafe extern "C" fn pgdog_route_copy_row(#first_param_name: pgdog_plugin::PdCopyRow, output: *mut pgdog_plugin::PdRoute) { + #input_fn + + let route: pgdog_plugin::PdRoute = #fn_name(#first_param_name).into(); + unsafe { + *output = route; + } + } + }; + + TokenStream::from(expanded) +} + /// Generates the `pgdog_route` method for routing queries. #[proc_macro_attribute] pub fn route(_attr: TokenStream, item: TokenStream) -> TokenStream { diff --git a/pgdog-plugin/include/types.h b/pgdog-plugin/include/types.h index 8a4c04960..8551b6058 100644 --- a/pgdog-plugin/include/types.h +++ b/pgdog-plugin/include/types.h @@ -9,7 +9,7 @@ typedef struct PdStr { size_t len; void *data; -} RustString; +} PdStr; /** * Wrapper around output by pg_query. @@ -37,6 +37,25 @@ typedef struct PdParameters { void *format_codes; } PdParameters; +/** + * Wrapper for copy data row. 
/// Row encoding used by a COPY stream.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CopyFormat {
    Text,
    Csv,
    Binary,
}

/// A complete CSV record.
///
/// The field values are stored back to back in `data`, without delimiters;
/// `fields` holds the byte range of each value inside `data`.
#[derive(Clone)]
pub struct Record {
    /// Raw record data.
    pub data: Vec<u8>,
    /// Field ranges.
    pub fields: Vec<Range<usize>>,
    /// Delimiter.
    pub delimiter: char,
    /// Format used.
    pub format: CopyFormat,
    /// Null string.
    pub null_string: String,
}

impl std::fmt::Debug for Record {
    /// Same as a derived `Debug`, except that `data` is shown as the result
    /// of decoding it as UTF-8 rather than as a list of bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Record")
            .field("data", &from_utf8(&self.data))
            .field("fields", &self.fields)
            .field("delimiter", &self.delimiter)
            .field("format", &self.format)
            .field("null_string", &self.null_string)
            .finish()
    }
}

impl std::fmt::Display for Record {
    /// Writes the record as one delimiter-separated line ending in `\n`.
    ///
    /// For [`CopyFormat::Csv`], every field is wrapped in double quotes with
    /// embedded quotes doubled, except fields equal to the null string, which
    /// are written verbatim. Other formats write every field verbatim.
    ///
    /// # Panics
    ///
    /// Panics if a field cannot be read with [`Record::get`], i.e. it is not
    /// valid UTF-8 or its range lies outside `data`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for index in 0..self.len() {
            if index > 0 {
                write!(f, "{}", self.delimiter)?;
            }

            // Look the field up once and stream it straight into the
            // formatter; no per-field `String` or joined `Vec` is needed.
            let text = self.get(index).unwrap();
            match self.format {
                CopyFormat::Csv if text != self.null_string => {
                    write!(f, "\"{}\"", text.replace('"', "\"\""))?;
                }
                _ => f.write_str(text)?,
            }
        }

        writeln!(f)
    }
}

impl Record {
    /// Build a record from concatenated field bytes.
    ///
    /// * `data`: field values stored back to back.
    /// * `ends`: exclusive end offset of each field inside `data`; a field
    ///   starts where the previous one ended, the first one at offset 0.
    /// * `delimiter`, `format`, `null_string`: stored as-is and used when the
    ///   record is formatted with `Display`.
    ///
    /// The offsets are not validated here; fields whose range does not fit
    /// `data` simply read as `None` from [`Record::get`].
    pub fn new(
        data: &[u8],
        ends: &[usize],
        delimiter: char,
        format: CopyFormat,
        null_string: &str,
    ) -> Self {
        let mut start = 0;
        let fields: Vec<Range<usize>> = ends
            .iter()
            .map(|&end| {
                let range = start..end;
                start = end;
                range
            })
            .collect();

        Self {
            data: data.to_vec(),
            fields,
            delimiter,
            format,
            null_string: null_string.to_owned(),
        }
    }

    /// Number of fields in the record.
    pub fn len(&self) -> usize {
        self.fields.len()
    }

    /// Return true if there are no fields in the record.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the field at `index` as text.
    ///
    /// Returns `None` when there is no such field, when its range does not
    /// lie inside `data`, or when its bytes are not valid UTF-8.
    pub fn get(&self, index: usize) -> Option<&str> {
        let range = self.fields.get(index)?.clone();
        // Checked slicing: `new` is public and accepts arbitrary offsets, so
        // a bad range must yield `None` rather than panic.
        let bytes = self.data.get(range)?;
        from_utf8(bytes).ok()
    }
}

/// Copy format.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Format {
    /// Text (can be CSV or Postgres text).
    Text,
    /// Binary COPY format.
    Binary,
}
+ pub fn from_proto( + shards: usize, + record: &Record, + column_names: &[PdStr], + table_name: &PdStr, + schema_name: &PdStr, + ) -> Self { + Self { + shards: shards as u64, + record: record as *const Record as *const c_void, + num_columns: column_names.len() as u64, + columns: column_names.as_ptr() as *mut PdStr, + table_name: table_name as *const PdStr as *mut PdStr, + schema_name: schema_name as *const PdStr as *mut PdStr, + } + } + + /// Get number of shards. + pub fn shards(&self) -> u64 { + self.shards + } + + /// Get the parsed record. + pub fn record(&self) -> &Record { + unsafe { &*(self.record as *const Record) } + } + + /// Get column names. + pub fn columns(&self) -> Vec<&str> { + if self.num_columns == 0 { + return vec![]; + } + unsafe { + std::slice::from_raw_parts(self.columns, self.num_columns as usize) + .iter() + .map(|s| &**s) + .collect() + } + } + + /// Get table name. + pub fn table_name(&self) -> &str { + if self.table_name.is_null() { + return ""; + } + unsafe { &*self.table_name } + } + + /// Get schema name. + pub fn schema_name(&self) -> &str { + if self.schema_name.is_null() { + return ""; + } + unsafe { &*self.schema_name } + } +} diff --git a/pgdog-plugin/src/lib.rs b/pgdog-plugin/src/lib.rs index e9de506dd..8cc0c03cd 100644 --- a/pgdog-plugin/src/lib.rs +++ b/pgdog-plugin/src/lib.rs @@ -169,6 +169,7 @@ pub mod bindings { pub mod ast; pub mod comp; pub mod context; +pub mod copy; pub mod logging; pub mod parameters; pub mod plugin; diff --git a/pgdog-plugin/src/plugin.rs b/pgdog-plugin/src/plugin.rs index ade2b632d..c098514e8 100644 --- a/pgdog-plugin/src/plugin.rs +++ b/pgdog-plugin/src/plugin.rs @@ -8,7 +8,7 @@ use std::path::Path; use libloading::{library_filename, Library, Symbol}; -use crate::{PdConfig, PdRoute, PdRouterContext, PdStr}; +use crate::{PdConfig, PdCopyRow, PdRoute, PdRouterContext, PdStr}; /// Plugin interface. /// @@ -30,6 +30,8 @@ pub struct Plugin<'a> { config: Option>, /// Route query. 
route: Option>, + /// Route copy row. + route_copy_row: Option>, /// Compiler version. rustc_version: Option>, /// Plugin API version. @@ -82,6 +84,7 @@ impl<'a> Plugin<'a> { let plugin_version = unsafe { library.get(b"pgdog_plugin_version\0") }.ok(); let config = unsafe { library.get(b"pgdog_config\0") }.ok(); let logging_init = unsafe { library.get(b"pgdog_logging_init\0") }.ok(); + let route_copy_row = unsafe { library.get(b"pgdog_route_copy_row\0") }.ok(); Self { name: name.to_owned(), @@ -93,6 +96,7 @@ impl<'a> Plugin<'a> { plugin_version, config, logging_init, + route_copy_row, } } @@ -144,7 +148,7 @@ impl<'a> Plugin<'a> { /// * `context`: Statement context created by PgDog's query router. /// pub fn route(&self, context: PdRouterContext) -> Option { - if let Some(ref route) = &self.route { + if let Some(ref route) = self.route { let mut output = PdRoute::default(); unsafe { route(context, &mut output as *mut PdRoute); @@ -155,6 +159,19 @@ impl<'a> Plugin<'a> { } } + /// Route copy row. + pub fn route_copy_row(&self, context: PdCopyRow) -> Option { + if let Some(ref route_copy_row) = self.route_copy_row { + let mut output = PdRoute::default(); + unsafe { + route_copy_row(context, &mut output as *mut PdRoute); + } + Some(output) + } else { + None + } + } + /// Returns plugin's name. This is the same name as what /// is passed to [`Plugin::load`] function. 
pub fn name(&self) -> &str { diff --git a/pgdog-plugin/src/prelude.rs b/pgdog-plugin/src/prelude.rs index 16e1d06b0..5d39c340c 100644 --- a/pgdog-plugin/src/prelude.rs +++ b/pgdog-plugin/src/prelude.rs @@ -2,7 +2,9 @@ pub use crate::pg_query; pub use crate::{ - macros::{fini, init, route}, + bindings::PdCopyRow, + copy::{CopyFormat, Record}, + macros::{fini, init, route, route_copy_row}, parameters::{Parameter, ParameterFormat, ParameterValue, Parameters}, Context, ReadWrite, Route, Shard, }; diff --git a/pgdog/src/frontend/router/parser/copy.rs b/pgdog/src/frontend/router/parser/copy.rs index f85362f37..8155209c2 100644 --- a/pgdog/src/frontend/router/parser/copy.rs +++ b/pgdog/src/frontend/router/parser/copy.rs @@ -1,19 +1,22 @@ //! Parse COPY statement. use pg_query::{protobuf::CopyStmt, NodeEnum}; +use pgdog_plugin::{PdCopyRow, PdStr}; use crate::{ backend::{Cluster, ShardingSchema}, config::ShardedTable, frontend::router::{ - parser::Shard, + parser::{Record, Shard}, sharding::{ContextBuilder, Tables}, CopyRow, }, net::messages::{CopyData, ToBytes}, + plugin::plugins, }; use super::{binary::Data, BinaryStream, Column, CsvStream, Error, Table}; +pub use pgdog_plugin::copy::CopyFormat; /// Copy information parsed from a COPY statement. #[derive(Debug, Clone)] @@ -39,13 +42,6 @@ impl Default for CopyInfo { } } -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum CopyFormat { - Text, - Csv, - Binary, -} - #[derive(Debug, Clone)] enum CopyStream { Text(Box), @@ -58,8 +54,8 @@ pub struct CopyParser { headers: bool, /// CSV delimiter. delimiter: Option, - /// Number of columns - columns: usize, + /// Column names from the COPY statement. + column_names: Vec, /// This is a COPY coming from the client. is_from: bool, /// Stream parser. @@ -74,6 +70,10 @@ pub struct CopyParser { schema_shard: Option, /// String representing NULL values in text/CSV format. null_string: String, + /// Table name from the COPY statement. 
+ table_name: String, + /// Schema name from the COPY statement. + schema_name: String, } impl Default for CopyParser { @@ -81,7 +81,7 @@ impl Default for CopyParser { Self { headers: false, delimiter: None, - columns: 0, + column_names: vec![], is_from: false, stream: CopyStream::Text(Box::new(CsvStream::new(',', false, CopyFormat::Csv, "\\N"))), sharding_schema: ShardingSchema::default(), @@ -89,6 +89,8 @@ impl Default for CopyParser { sharded_column: 0, schema_shard: None, null_string: "\\N".to_owned(), + table_name: String::new(), + schema_name: String::new(), } } } @@ -113,7 +115,20 @@ impl CopyParser { } } + parser.column_names = stmt + .attlist + .iter() + .filter_map(|n| match &n.node { + Some(NodeEnum::String(s)) => Some(s.sval.clone()), + _ => None, + }) + .collect(); + let table = Table::from(rel); + parser.table_name = table.name.to_owned(); + if let Some(schema) = table.schema { + parser.schema_name = schema.to_owned(); + } // The CopyParser is used for replicating // data during data-sync. 
This will ensure all rows @@ -127,8 +142,6 @@ impl CopyParser { parser.sharded_column = key.position; } - parser.columns = columns.len(); - for option in &stmt.options { if let Some(NodeEnum::DefElem(ref elem)) = option.node { match elem.defname.to_lowercase().as_str() { @@ -233,6 +246,14 @@ impl CopyParser { let shard = if is_end_marker { Shard::All + } else if let Some(shard) = Self::check_plugins( + &self.column_names, + &self.sharding_schema, + &record, + &self.table_name, + &self.schema_name, + ) { + shard } else if let Some(table) = &self.sharded_table { let key = record .get(self.sharded_column) @@ -307,6 +328,37 @@ impl CopyParser { Ok(rows) } + + fn check_plugins( + column_names: &[String], + schema: &ShardingSchema, + record: &Record, + table_name: &str, + schema_name: &str, + ) -> Option { + if let Some(plugins) = plugins() { + let columns: Vec = column_names + .iter() + .map(|s| PdStr::from(s.as_str())) + .collect(); + let table_name = PdStr::from(table_name); + let schema_name = PdStr::from(schema_name); + let context = + PdCopyRow::from_proto(schema.shards, record, &columns, &table_name, &schema_name); + + for plugin in plugins { + if let Some(route) = plugin.route_copy_row(context) { + if route.shard == -1 { + return Some(Shard::All); + } else if route.shard >= 0 { + return Some(Shard::Direct(route.shard as usize)); + } + } + } + } + + None + } } #[cfg(test)] diff --git a/pgdog/src/frontend/router/parser/csv/mod.rs b/pgdog/src/frontend/router/parser/csv/mod.rs index b36419e62..b0a82d358 100644 --- a/pgdog/src/frontend/router/parser/csv/mod.rs +++ b/pgdog/src/frontend/router/parser/csv/mod.rs @@ -6,7 +6,7 @@ pub mod record; pub use iterator::Iter; pub use record::Record; -use super::CopyFormat; +use pgdog_plugin::copy::CopyFormat; static RECORD_BUFFER: usize = 4096; static ENDS_BUFFER: usize = 2048; // Max of 2048 columns in a CSV. 
diff --git a/pgdog/src/frontend/router/parser/csv/record.rs b/pgdog/src/frontend/router/parser/csv/record.rs index 5a21e200f..d5d657e39 100644 --- a/pgdog/src/frontend/router/parser/csv/record.rs +++ b/pgdog/src/frontend/router/parser/csv/record.rs @@ -1,93 +1 @@ -use super::super::CopyFormat; -use std::{ops::Range, str::from_utf8}; - -/// A complete CSV record. -#[derive(Clone)] -pub struct Record { - /// Raw record data. - pub data: Vec, - /// Field ranges. - pub fields: Vec>, - /// Delimiter. - pub delimiter: char, - /// Format used. - pub format: CopyFormat, - /// Null string. - pub null_string: String, -} - -impl std::fmt::Debug for Record { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Record") - .field("data", &from_utf8(&self.data)) - .field("fields", &self.fields) - .field("delimiter", &self.delimiter) - .field("format", &self.format) - .field("null_string", &self.null_string) - .finish() - } -} - -impl std::fmt::Display for Record { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!( - f, - "{}", - (0..self.len()) - .map(|field| match self.format { - CopyFormat::Csv => { - let text = self.get(field).unwrap(); - if text == self.null_string { - text.to_owned() - } else { - format!("\"{}\"", self.get(field).unwrap().replace("\"", "\"\"")) - } - } - _ => self.get(field).unwrap().to_string(), - }) - .collect::>() - .join(&format!("{}", self.delimiter)) - ) - } -} - -impl Record { - pub(super) fn new( - data: &[u8], - ends: &[usize], - delimiter: char, - format: CopyFormat, - null_string: &str, - ) -> Self { - let mut last = 0; - let mut fields = vec![]; - for e in ends { - fields.push(last..*e); - last = *e; - } - Self { - data: data.to_vec(), - fields, - delimiter, - format, - null_string: null_string.to_owned(), - } - } - - /// Number of fields in the record. - pub fn len(&self) -> usize { - self.fields.len() - } - - /// Return true if there are no fields in the record. 
- pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - pub fn get(&self, index: usize) -> Option<&str> { - self.fields - .get(index) - .cloned() - .and_then(|range| from_utf8(&self.data[range]).ok()) - } -} +pub use pgdog_plugin::copy::{CopyFormat, Record}; diff --git a/pgdog/src/frontend/router/parser/query/mod.rs b/pgdog/src/frontend/router/parser/query/mod.rs index 1207f22a1..cf68d7108 100644 --- a/pgdog/src/frontend/router/parser/query/mod.rs +++ b/pgdog/src/frontend/router/parser/query/mod.rs @@ -491,6 +491,7 @@ impl QueryParser { } } + // let plugin_context = context.plugin_context(&ast.ast.protobuf, &None); let parser = CopyParser::new(stmt, context.router_context.cluster)?; if !stmt.is_from { context diff --git a/plugins/pgdog-example-plugin/Cargo.toml b/plugins/pgdog-example-plugin/Cargo.toml index 905c6098d..3bf2e8c98 100644 --- a/plugins/pgdog-example-plugin/Cargo.toml +++ b/plugins/pgdog-example-plugin/Cargo.toml @@ -8,6 +8,7 @@ crate-type = ["cdylib"] [dependencies] pgdog-plugin.workspace = true +csv = "1" once_cell = "1" parking_lot = "0.12" thiserror = "2" diff --git a/plugins/pgdog-example-plugin/src/lib.rs b/plugins/pgdog-example-plugin/src/lib.rs index f15bbeca7..dc796ef9f 100644 --- a/plugins/pgdog-example-plugin/src/lib.rs +++ b/plugins/pgdog-example-plugin/src/lib.rs @@ -8,7 +8,7 @@ pub mod plugin; -use pgdog_plugin::{Context, Route, macros}; +use pgdog_plugin::{Context, PdCopyRow, Route, macros}; // This identifies this library is a PgDog plugin and adds some // required methods automatically. @@ -31,6 +31,16 @@ fn route(context: Context) -> Route { crate::plugin::route_query(context).unwrap_or(Route::unknown()) } +/// If defined, this function is called for every row during a sharded COPY. +/// +/// It receives the raw row data along with metadata parsed from the COPY statement +/// (columns, format, delimiter, etc.) and returns a routing decision. 
+/// +#[macros::route_copy_row] +fn route_copy_row(row: PdCopyRow) -> Route { + crate::plugin::route_copy(row).unwrap_or(Route::unknown()) +} + /// Run any code before PgDog is shut down. /// /// This allows for plugins to upload stats to some external service diff --git a/plugins/pgdog-example-plugin/src/plugin.rs b/plugins/pgdog-example-plugin/src/plugin.rs index d6e26adf1..3039187cb 100644 --- a/plugins/pgdog-example-plugin/src/plugin.rs +++ b/plugins/pgdog-example-plugin/src/plugin.rs @@ -105,9 +105,54 @@ pub(crate) fn route_query(context: Context) -> Result { Ok(Route::unknown()) } +/// Route a COPY row to the correct shard. +/// +/// Finds the "id" column and hashes its value to pick a shard. +pub(crate) fn route_copy(row: PdCopyRow) -> Result { + let columns = row.columns(); + let shards = row.shards() as usize; + let record = row.record(); + + if columns.is_empty() { + return Ok(Route::unknown()); + } + + // Find the position of the "id" column. + let id_pos = match columns.iter().position(|&c| c == "id") { + Some(pos) => pos, + None => return Ok(Route::unknown()), + }; + + let field = match record.get(id_pos) { + Some(f) => f, + None => return Ok(Route::unknown()), + }; + + // NULL values go to all shards. + if field == record.null_string { + return Ok(Route::new(Shard::All, ReadWrite::Write)); + } + + // Parse the id and hash to a shard. 
+ let id: i64 = match field.parse() { + Ok(v) => v, + Err(_) => return Ok(Route::unknown()), + }; + + println!( + "copy decoded row with id {} (table={}.{})", + id, + row.schema_name(), + row.table_name() + ); + + let shard = (id.unsigned_abs() as usize) % shards; + Ok(Route::new(Shard::Direct(shard), ReadWrite::Write)) +} + #[cfg(test)] mod test { - use pgdog_plugin::{PdParameters, PdStatement}; + use pgdog_plugin::{PdParameters, PdStatement, PdStr}; use super::*; @@ -132,4 +177,87 @@ mod test { assert_eq!(read_write, ReadWrite::Read); assert_eq!(shard, Shard::Unknown); } + + #[test] + fn test_copy_routes_by_id() { + let columns = [PdStr::from("id"), PdStr::from("name")]; + + // "7" + "Alice" concatenated, ends at [1, 6] + let record = Record::new(b"7Alice", &[1, 6], '\t', CopyFormat::Text, "\\N"); + let row = PdCopyRow::from_proto( + 4, + &record, + &columns, + &PdStr::from("users"), + &PdStr::from("public"), + ); + let route = route_copy(row).unwrap(); + assert_eq!(route.shard.try_into(), Ok(Shard::Direct(3))); // 7 % 4 = 3 + + // "0" + "Bob" concatenated, ends at [1, 4] + let record = Record::new(b"0Bob", &[1, 4], '\t', CopyFormat::Text, "\\N"); + let row = PdCopyRow::from_proto( + 4, + &record, + &columns, + &PdStr::from("users"), + &PdStr::from("public"), + ); + let route = route_copy(row).unwrap(); + assert_eq!(route.shard.try_into(), Ok(Shard::Direct(0))); // 0 % 4 = 0 + } + + #[test] + fn test_copy_null_id_routes_to_all() { + let columns = [PdStr::from("id"), PdStr::from("name")]; + + let record = Record::new(b"\\NAlice", &[2, 7], '\t', CopyFormat::Text, "\\N"); + let row = PdCopyRow::from_proto( + 4, + &record, + &columns, + &PdStr::from("users"), + &PdStr::from("public"), + ); + let route = route_copy(row).unwrap(); + assert_eq!(route.shard.try_into(), Ok(Shard::All)); + } + + #[test] + fn test_copy_csv_delimiter() { + let columns = [PdStr::from("id"), PdStr::from("name")]; + + let record = Record::new(b"5Charlie", &[1, 8], ',', CopyFormat::Csv, "\\N"); 
+ let row = PdCopyRow::from_proto( + 3, + &record, + &columns, + &PdStr::from("users"), + &PdStr::from("public"), + ); + let route = route_copy(row).unwrap(); + assert_eq!(route.shard.try_into(), Ok(Shard::Direct(2))); // 5 % 3 = 2 + } + + #[test] + fn test_copy_no_id_column() { + let columns = [PdStr::from("name"), PdStr::from("email")]; + + let record = Record::new( + b"Alicealice@test.com", + &[5, 19], + '\t', + CopyFormat::Text, + "\\N", + ); + let row = PdCopyRow::from_proto( + 4, + &record, + &columns, + &PdStr::from("users"), + &PdStr::from("public"), + ); + let route = route_copy(row).unwrap(); + assert_eq!(route.shard.try_into(), Ok(Shard::Unknown)); + } }