diff --git a/README.md b/README.md index b2ade50..8ea09d6 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,10 @@ the [pgwire](https://github.com/sunng87/pgwire) project. - [x] pgcli - [x] VSCode SQLTools - [ ] Intellij Datagrip -- BI +- BI & Visualization - [x] Metabase - [ ] PowerBI + - [x] Grafana ## Quick Start diff --git a/datafusion-pg-catalog/src/pg_catalog.rs b/datafusion-pg-catalog/src/pg_catalog.rs index 88651d4..5703d98 100644 --- a/datafusion-pg-catalog/src/pg_catalog.rs +++ b/datafusion-pg-catalog/src/pg_catalog.rs @@ -42,6 +42,7 @@ pub mod pg_settings; pub mod pg_stat_gssapi; pub mod pg_tables; pub mod pg_views; +pub mod quote_ident_udf; const PG_CATALOG_TABLE_PG_AGGREGATE: &str = "pg_aggregate"; const PG_CATALOG_TABLE_PG_AM: &str = "pg_am"; @@ -1481,6 +1482,8 @@ where session_context.register_udf(create_pg_stat_get_numscans()); session_context.register_udf(create_pg_get_constraintdef()); session_context.register_udf(create_pg_get_partition_ancestors_udf()); + session_context.register_udf(quote_ident_udf::create_quote_ident_udf()); + session_context.register_udf(quote_ident_udf::create_parse_ident_udf()); Ok(()) } diff --git a/datafusion-pg-catalog/src/pg_catalog/quote_ident_udf.rs b/datafusion-pg-catalog/src/pg_catalog/quote_ident_udf.rs new file mode 100644 index 0000000..b206792 --- /dev/null +++ b/datafusion-pg-catalog/src/pg_catalog/quote_ident_udf.rs @@ -0,0 +1,428 @@ +use std::sync::Arc; + +use datafusion::arrow::array::{Array, ArrayRef, AsArray, ListBuilder, StringBuilder}; +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::error::Result; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion::prelude::create_udf; + +/// Create a PostgreSQL quote_ident UDF +pub fn create_quote_ident_udf() -> ScalarUDF { + let func = move |args: &[ColumnarValue]| { + let args = ColumnarValue::values_to_arrays(args)?; + let string_array = 
args[0].as_string::<i32>();
+
+        let mut builder = StringBuilder::new();
+        for ident in string_array.iter() {
+            if let Some(ident) = ident {
+                // PostgreSQL quote_ident implementation:
+                // 1. If identifier is already quoted and contains no special chars, return as-is
+                // 2. If identifier contains no special chars and is not a reserved word, return as-is
+                // 3. Otherwise, wrap in double quotes and escape any internal double quotes
+                let quoted = if ident.starts_with('"') && ident.ends_with('"') {
+                    // Already quoted, just escape internal quotes
+                    ident.replace('"', "\"\"")
+                } else if needs_quoting(ident) {
+                    // Needs quoting - wrap in quotes and escape internal quotes
+                    format!("\"{}\"", ident.replace('"', "\"\""))
+                } else {
+                    // No quoting needed
+                    ident.to_string()
+                };
+                builder.append_value(&quoted);
+            } else {
+                builder.append_null();
+            }
+        }
+        let array: ArrayRef = Arc::new(builder.finish());
+
+        Ok(ColumnarValue::Array(array))
+    };
+
+    create_udf(
+        "quote_ident",
+        vec![DataType::Utf8],
+        DataType::Utf8,
+        Volatility::Stable,
+        Arc::new(func),
+    )
+}
+
+#[derive(Debug, Hash, PartialEq, Eq)]
+pub struct ParseIdentUDF {
+    signature: Signature,
+}
+
+impl ParseIdentUDF {
+    pub fn new() -> ParseIdentUDF {
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    TypeSignature::Exact(vec![DataType::Utf8]),
+                    TypeSignature::Exact(vec![DataType::Utf8, DataType::Boolean]),
+                ],
+                Volatility::Stable,
+            ),
+        }
+    }
+
+    pub fn into_scalar_udf(self) -> ScalarUDF {
+        ScalarUDF::new_from_impl(self)
+    }
+}
+
+impl Default for ParseIdentUDF {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ScalarUDFImpl for ParseIdentUDF {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::List(Arc::new(Field::new(
+            "item",
+            DataType::Utf8,
+            true,
+        ))))
+    }
+
+    fn name(&self) -> &str {
+        "parse_ident"
+    }
+
+    fn invoke_with_args(&self, args: 
ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let args = ColumnarValue::values_to_arrays(&args.args)?;
+        let string_array = args[0].as_string::<i32>();
+        let strict_mode = if args.len() > 1 {
+            args[1].as_boolean()
+        } else {
+            &datafusion::arrow::array::BooleanArray::from(vec![false; string_array.len()])
+        };
+
+        let mut builder = ListBuilder::new(StringBuilder::new());
+
+        for (i, ident) in string_array.iter().enumerate() {
+            if let Some(ident) = ident {
+                let strict = strict_mode.value(i);
+                match parse_ident_string(ident, strict) {
+                    Ok(parts) => {
+                        for part in parts {
+                            builder.values().append_value(part);
+                        }
+                        builder.append(true);
+                    }
+                    Err(_) => {
+                        if strict {
+                            return Err(datafusion::error::DataFusionError::Execution(format!(
+                                "invalid identifier: {}",
+                                ident
+                            )));
+                        } else {
+                            // In non-strict mode, return empty array for invalid identifiers
+                            builder.append(true);
+                        }
+                    }
+                }
+            } else {
+                builder.append_null();
+            }
+        }
+
+        let array: ArrayRef = Arc::new(builder.finish());
+        Ok(ColumnarValue::Array(array))
+    }
+}
+
+/// Create a PostgreSQL parse_ident UDF
+pub fn create_parse_ident_udf() -> ScalarUDF {
+    ParseIdentUDF::new().into_scalar_udf()
+}
+
+/// Parse an identifier string into its component parts
+fn parse_ident_string(ident: &str, strict: bool) -> Result<Vec<String>, &'static str> {
+    if ident.is_empty() {
+        return Err("empty identifier");
+    }
+
+    let mut parts = Vec::new();
+    let mut chars = ident.chars().peekable();
+    let mut current_part = String::new();
+    let mut in_quotes = false;
+
+    while let Some(&c) = chars.peek() {
+        match c {
+            '"' if !in_quotes => {
+                // Start of quoted identifier
+                in_quotes = true;
+                chars.next(); // consume the quote
+            }
+            '"' if in_quotes => {
+                // Check for escaped quote (double quote)
+                chars.next(); // consume first quote
+                if let Some(&'"') = chars.peek() {
+                    // Escaped quote
+                    current_part.push('"');
+                    chars.next(); // consume second quote
+                } else {
+                    // End of quoted identifier
+                    in_quotes = false;
+                    if !current_part.is_empty() {
+ parts.push(current_part); + current_part = String::new(); + } + } + } + '.' if !in_quotes => { + // Separator between parts + chars.next(); // consume the dot + if !current_part.is_empty() { + parts.push(current_part); + current_part = String::new(); + } else if strict { + return Err("empty identifier part"); + } + } + _ => { + current_part.push(c); + chars.next(); + } + } + } + + // Handle the last part + if in_quotes { + return Err("unterminated quoted identifier"); + } + + if !current_part.is_empty() { + parts.push(current_part); + } else if ident.ends_with('.') && strict { + // In strict mode, trailing dot indicates empty identifier part + return Err("empty identifier part"); + } + + if parts.is_empty() { + Err("no valid identifier parts") + } else { + Ok(parts) + } +} + +/// Check if an identifier needs quoting according to PostgreSQL rules +fn needs_quoting(ident: &str) -> bool { + if ident.is_empty() { + return true; + } + + // Check if identifier starts with a letter or underscore and contains only letters, digits, underscores + let mut chars = ident.chars(); + if let Some(first_char) = chars.next() { + if !first_char.is_alphabetic() && first_char != '_' { + return true; + } + } + + // Check remaining characters + for c in chars { + if !c.is_alphanumeric() && c != '_' { + return true; + } + } + + // Check if it's a PostgreSQL reserved word + is_reserved_word(ident) +} + +/// Check if identifier is a PostgreSQL reserved word +fn is_reserved_word(word: &str) -> bool { + let reserved_words = [ + "ALL", + "ANALYSE", + "ANALYZE", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASYMMETRIC", + "AUTHORIZATION", + "BETWEEN", + "BINARY", + "BOTH", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "CONCURRENTLY", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT_CATALOG", + "CURRENT_DATE", + "CURRENT_ROLE", + "CURRENT_SCHEMA", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "DEFAULT", + "DEFERRABLE", + "DESC", + "DISTINCT", + "DO", + "ELSE", + 
"END", + "EXCEPT", + "FALSE", + "FETCH", + "FOR", + "FOREIGN", + "FROM", + "FULL", + "GRANT", + "GROUP", + "HAVING", + "ILIKE", + "IN", + "INITIALLY", + "INNER", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + "LATERAL", + "LEADING", + "LEFT", + "LIKE", + "LIMIT", + "LOCALTIME", + "LOCALTIMESTAMP", + "NATURAL", + "NOT", + "NOTNULL", + "NULL", + "OFFSET", + "ON", + "ONLY", + "OR", + "ORDER", + "OUTER", + "OVERLAPS", + "PLACING", + "PRIMARY", + "REFERENCES", + "RETURNING", + "RIGHT", + "SELECT", + "SESSION_USER", + "SIMILAR", + "SOME", + "SYMMETRIC", + "TABLE", + "TABLESAMPLE", + "THEN", + "TO", + "TRAILING", + "TRUE", + "UNION", + "UNIQUE", + "USER", + "USING", + "VARIADIC", + "VERBOSE", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + ]; + + reserved_words.contains(&word.to_uppercase().as_str()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_quote_ident() { + // Test the helper functions directly + assert_eq!(needs_quoting("simple"), false); + assert_eq!(needs_quoting("_underscore"), false); + assert_eq!(needs_quoting("with space"), true); + assert_eq!(needs_quoting("123start"), true); + assert_eq!(needs_quoting("with-dash"), true); + assert_eq!(needs_quoting("select"), true); + assert_eq!(needs_quoting("SELECT"), true); + assert_eq!(needs_quoting(""), true); + + // Test reserved word detection + assert_eq!(is_reserved_word("select"), true); + assert_eq!(is_reserved_word("SELECT"), true); + assert_eq!(is_reserved_word("not_reserved"), false); + } + + #[test] + fn test_parse_ident() { + // Test basic parsing + assert_eq!(parse_ident_string("simple", false).unwrap(), vec!["simple"]); + assert_eq!( + parse_ident_string("schema.table", false).unwrap(), + vec!["schema", "table"] + ); + assert_eq!( + parse_ident_string("db.schema.table", false).unwrap(), + vec!["db", "schema", "table"] + ); + + // Test quoted identifiers + assert_eq!( + parse_ident_string("\"quoted.ident\"", false).unwrap(), + vec!["quoted.ident"] + ); + assert_eq!( + 
parse_ident_string("\"schema\".\"table\"", false).unwrap(), + vec!["schema", "table"] + ); + + // Test escaped quotes + assert_eq!( + parse_ident_string("\"quote\"\"test\"", false).unwrap(), + vec!["quote\"test"] + ); + + // Test mixed quoted and unquoted + assert_eq!( + parse_ident_string("schema.\"table.name\"", false).unwrap(), + vec!["schema", "table.name"] + ); + + // Test edge cases + assert!(parse_ident_string("", false).is_err()); + assert!(parse_ident_string("unterminated\"", false).is_err()); + assert_eq!( + parse_ident_string("trailing.", false).unwrap(), + vec!["trailing"] + ); + assert_eq!( + parse_ident_string(".leading", false).unwrap(), + vec!["leading"] + ); + + // Test strict mode + assert!(parse_ident_string("trailing.", true).is_err()); + assert!(parse_ident_string(".leading", true).is_err()); + assert!(parse_ident_string("..", true).is_err()); + } +} diff --git a/datafusion-pg-catalog/src/sql/parser.rs b/datafusion-pg-catalog/src/sql/parser.rs index 0420c9c..a110ca7 100644 --- a/datafusion-pg-catalog/src/sql/parser.rs +++ b/datafusion-pg-catalog/src/sql/parser.rs @@ -7,12 +7,11 @@ use datafusion::sql::sqlparser::parser::ParserError; use datafusion::sql::sqlparser::tokenizer::Token; use datafusion::sql::sqlparser::tokenizer::TokenWithSpan; -use crate::sql::rules::FixVersionColumnName; - use super::rules::AliasDuplicatedProjectionRewrite; use super::rules::CurrentUserVariableToSessionUserFunctionCall; use super::rules::FixArrayLiteral; use super::rules::FixCollate; +use super::rules::FixVersionColumnName; use super::rules::PrependUnqualifiedPgTableName; use super::rules::RemoveQualifier; use super::rules::RemoveSubqueryFromProjection; @@ -164,6 +163,17 @@ const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[ NULL::TEXT AS _2 WHERE false" ), + + // grafana array index magic + (r#"SELECT + CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END + FROM + generate_series( + array_lower(string_to_array(current_setting('search_path'),','),1), + 
array_upper(string_to_array(current_setting('search_path'),','),1) + ) as i, + string_to_array(current_setting('search_path'),',') s"#, +"''") ]; /// A parser with Postgres Compatibility for Datafusion @@ -173,7 +183,7 @@ const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[ /// statement to a similar version if rewrite doesn't worth the effort for now. #[derive(Debug)] pub struct PostgresCompatibilityParser { - blacklist: Vec<(Vec, Statement)>, + blacklist: Vec<(Vec, Vec)>, rewrite_rules: Vec>, } @@ -200,8 +210,11 @@ impl PostgresCompatibilityParser { Parser::new(&PostgreSqlDialect {}) .try_with_sql(sql_to) .unwrap() - .parse_statement() - .unwrap(), + .into_tokens() + .into_iter() + .map(|t| t.token) + .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon)) + .collect(), )); } @@ -225,52 +238,110 @@ impl PostgresCompatibilityParser { } } - /// return statement if matched - fn parse_and_replace(&self, input: &str) -> Result { + /// return tokens with replacements applied + fn maybe_replace_tokens(&self, input: &str) -> Result, ParserError> { let parser = Parser::new(&PostgreSqlDialect {}); let tokens = parser.try_with_sql(input)?.into_tokens(); - let tokens_without_whitespace = tokens + // Get token values (without spans) and filter out only whitespace + // Keep semicolons as they separate statements + let filtered_tokens: Vec = tokens .iter() - .filter(|t| !matches!(t.token, Token::Whitespace(_) | Token::SemiColon)) - .collect::>(); + .map(|t| t.token.clone()) + .filter(|t| !matches!(t, Token::Whitespace(_))) + .collect(); + + // Handle empty input + if filtered_tokens.is_empty() { + return Ok(Vec::new()); + } + + // Build result by processing filtered tokens sequentially + let mut result = Vec::new(); + let mut i = 0; - for (blacklisted_sql_tokens, replacement) in &self.blacklist { - if blacklisted_sql_tokens.len() == tokens_without_whitespace.len() { - let matches = blacklisted_sql_tokens - .iter() - .zip(tokens_without_whitespace.iter()) - .all(|(a, b)| 
{ - if matches!(a, Token::Placeholder(_)) { - true - } else { - *a == b.token + while i < filtered_tokens.len() { + // Keep semicolons as-is + if matches!(&filtered_tokens[i], Token::SemiColon) { + result.push(filtered_tokens[i].clone()); + i += 1; + continue; + } + + // Try to find a blacklist pattern match starting at this position + let mut matched = false; + for (pattern, replacement) in &self.blacklist { + if pattern.is_empty() { + continue; + } + + // Check if we have enough tokens remaining + let mut j = 0; + let mut pattern_idx = 0; + while i + j < filtered_tokens.len() && pattern_idx < pattern.len() { + // Skip semicolons in the input when matching patterns + if matches!(&filtered_tokens[i + j], Token::SemiColon) { + j += 1; + continue; + } + + match &pattern[pattern_idx] { + Token::Placeholder(_) => { + // Placeholder matches any non-semicolon token + pattern_idx += 1; + j += 1; } - }); - if matches { - return Ok(MatchResult::Matches(Box::new(replacement.clone()))); + _ => { + if filtered_tokens[i + j] != pattern[pattern_idx] { + break; + } + pattern_idx += 1; + j += 1; + } + } + } + + // Check if we matched the entire pattern + if pattern_idx == pattern.len() { + // Add replacement tokens + result.extend(replacement.iter().cloned()); + // Skip the matched pattern (including any semicolons we skipped) + i += j; + matched = true; + break; } - } else { - continue; + } + + if !matched { + // No match, keep the original token + result.push(filtered_tokens[i].clone()); + i += 1; } } - Ok(MatchResult::Unmatches(tokens)) + Ok(result) } - fn parse_tokens(&self, tokens: Vec) -> Result, ParserError> { + fn parse_tokens(&self, tokens: Vec) -> Result, ParserError> { let parser = Parser::new(&PostgreSqlDialect {}); - parser.with_tokens_with_locations(tokens).parse_statements() + // Convert tokens to TokenWithSpan with dummy spans + let tokens_with_spans: Vec = tokens + .into_iter() + .map(|token| TokenWithSpan { + token, + span: 
datafusion::sql::sqlparser::tokenizer::Span::empty(), + }) + .collect(); + parser + .with_tokens_with_locations(tokens_with_spans) + .parse_statements() } pub fn parse(&self, input: &str) -> Result, ParserError> { - let statements = match self.parse_and_replace(input)? { - MatchResult::Matches(statement) => vec![*statement], - MatchResult::Unmatches(tokens) => self.parse_tokens(tokens)?, - }; - - let statements = statements.into_iter().map(|s| self.rewrite(s)).collect(); + let tokens = self.maybe_replace_tokens(input)?; + let statements = self.parse_tokens(tokens)?; + let statements: Vec<_> = statements.into_iter().map(|s| self.rewrite(s)).collect(); Ok(statements) } @@ -283,17 +354,12 @@ impl PostgresCompatibilityParser { } } -pub(crate) enum MatchResult { - Matches(Box), - Unmatches(Vec), -} - #[cfg(test)] mod tests { use super::*; #[test] - fn test_sql_mapping() { + fn test_full_match() { let sql = "SELECT pol.polname, pol.polpermissive, CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END, pg_catalog.pg_get_expr(pol.polqual, pol.polrelid), @@ -308,8 +374,32 @@ mod tests { WHERE pol.polrelid = '16384' ORDER BY 1;"; let parser = PostgresCompatibilityParser::new(); - let match_result = parser.parse_and_replace(sql).expect("failed to parse sql"); - assert!(matches!(match_result, MatchResult::Matches(_))); + let actual_tokens = parser + .maybe_replace_tokens(sql) + .expect("failed to parse sql") + .into_iter() + .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon)) + .collect::>(); + + let expected_sql = r#"SELECT + NULL::TEXT AS polname, + NULL::TEXT AS polpermissive, + NULL::TEXT AS array_to_string, + NULL::TEXT AS pg_get_expr_1, + NULL::TEXT AS pg_get_expr_2, + NULL::TEXT AS cmd + WHERE false"#; + + let expected_tokens = Parser::new(&PostgreSqlDialect {}) + .try_with_sql(expected_sql) + .unwrap() + .into_tokens() + .into_iter() + 
.map(|t| t.token) + .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon)) + .collect::>(); + + assert_eq!(actual_tokens, expected_tokens); let sql = "SELECT n.nspname schema_name, t.typname type_name @@ -333,8 +423,80 @@ mod tests { ORDER BY 1, 2"; let parser = PostgresCompatibilityParser::new(); - let match_result = parser.parse_and_replace(sql).expect("failed to parse sql"); - assert!(matches!(match_result, MatchResult::Matches(_))); + + let actual_tokens = parser + .maybe_replace_tokens(sql) + .expect("failed to parse sql") + .into_iter() + .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon)) + .collect::>(); + + let expected_sql = + r#"SELECT NULL::TEXT AS schema_name, NULL::TEXT AS type_name WHERE false"#; + + let expected_tokens = Parser::new(&PostgreSqlDialect {}) + .try_with_sql(expected_sql) + .unwrap() + .into_tokens() + .into_iter() + .map(|t| t.token) + .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon)) + .collect::>(); + + assert_eq!(actual_tokens, expected_tokens); + + let sql = "SELECT pubname + , NULL + , NULL + FROM pg_catalog.pg_publication p + JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid + JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid + WHERE pc.oid ='16384' and pg_catalog.pg_relation_is_publishable('16384') + UNION + SELECT pubname + , pg_get_expr(pr.prqual, c.oid) + , (CASE WHEN pr.prattrs IS NOT NULL THEN + (SELECT string_agg(attname, ', ') + FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s, + pg_catalog.pg_attribute + WHERE attrelid = pr.prrelid AND attnum = prattrs[s]) + ELSE NULL END) FROM pg_catalog.pg_publication p + JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid + JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid + WHERE pr.prrelid = '16384' + UNION + SELECT pubname + , NULL + , NULL + FROM pg_catalog.pg_publication p + WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable('16384') + ORDER BY 
1;"; + + let parser = PostgresCompatibilityParser::new(); + + let actual_tokens = parser + .maybe_replace_tokens(sql) + .expect("failed to parse sql") + .into_iter() + .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon)) + .collect::>(); + + let expected_sql = r#"SELECT + NULL::TEXT AS pubname, + NULL::TEXT AS _1, + NULL::TEXT AS _2 + WHERE false"#; + + let expected_tokens = Parser::new(&PostgreSqlDialect {}) + .try_with_sql(expected_sql) + .unwrap() + .into_tokens() + .into_iter() + .map(|t| t.token) + .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon)) + .collect::>(); + + assert_eq!(actual_tokens, expected_tokens); } #[test] @@ -349,4 +511,92 @@ mod tests { let result = parser.parse(";").expect("failed to parse sql"); assert!(result.is_empty()); } + + #[test] + fn test_partial_match() { + let parser = PostgresCompatibilityParser::new(); + + // Test partial match where the beginning matches a blacklisted query + // Using a simpler query that doesn't have placeholders for easier testing + let sql = r#"SELECT + CASE WHEN + quote_ident(table_schema) IN ( + SELECT + CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END + FROM + generate_series( + array_lower(string_to_array(current_setting('search_path'),','),1), + array_upper(string_to_array(current_setting('search_path'),','),1) + ) as i, + string_to_array(current_setting('search_path'),',') s + ) + THEN quote_ident(table_name) + ELSE quote_ident(table_schema) || '.' 
|| quote_ident(table_name) + END AS "table" + FROM information_schema.tables + WHERE quote_ident(table_schema) NOT IN ('information_schema', + 'pg_catalog', + '_timescaledb_cache', + '_timescaledb_catalog', + '_timescaledb_internal', + '_timescaledb_config', + 'timescaledb_information', + 'timescaledb_experimental') + ORDER BY CASE WHEN + quote_ident(table_schema) IN ( + SELECT + CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END + FROM + generate_series( + array_lower(string_to_array(current_setting('search_path'),','),1), + array_upper(string_to_array(current_setting('search_path'),','),1) + ) as i, + string_to_array(current_setting('search_path'),',') s + ) THEN 0 ELSE 1 END, 1"#; + + let tokens = parser + .maybe_replace_tokens(sql) + .expect("failed to parse sql"); + // Should have the beginning replaced with 'SELECT' and the rest preserved + assert!(tokens.len() > 0); + + let expected_sql = r#"SELECT + CASE WHEN + quote_ident(table_schema) IN ( + '') + THEN quote_ident(table_name) + ELSE quote_ident(table_schema) || '.' 
|| quote_ident(table_name) + END AS "table" + FROM information_schema.tables + WHERE quote_ident(table_schema) NOT IN ('information_schema', + 'pg_catalog', + '_timescaledb_cache', + '_timescaledb_catalog', + '_timescaledb_internal', + '_timescaledb_config', + 'timescaledb_information', + 'timescaledb_experimental') + ORDER BY CASE WHEN + quote_ident(table_schema) IN ( + '' + ) THEN 0 ELSE 1 END, 1"#; + + let expected_tokens = Parser::new(&PostgreSqlDialect {}) + .try_with_sql(expected_sql) + .unwrap() + .into_tokens(); + + // Compare token values (ignoring spans and whitespace) + let actual_tokens: Vec<_> = tokens + .iter() + .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon)) + .collect(); + let expected_token_values: Vec<_> = expected_tokens + .iter() + .map(|t| &t.token) + .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon)) + .collect(); + + assert_eq!(actual_tokens, expected_token_values); + } } diff --git a/datafusion-postgres/src/testing.rs b/datafusion-postgres/src/testing.rs index 98a1b69..f2d53db 100644 --- a/datafusion-postgres/src/testing.rs +++ b/datafusion-postgres/src/testing.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, sync::Arc}; -use datafusion::prelude::SessionContext; +use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_pg_catalog::pg_catalog::setup_pg_catalog; use futures::Sink; use pgwire::{ @@ -13,7 +13,9 @@ use pgwire::{ use crate::{auth::AuthManager, DfSessionService}; pub fn setup_handlers() -> DfSessionService { - let session_context = SessionContext::new(); + let session_config = SessionConfig::new().with_information_schema(true); + let session_context = SessionContext::new_with_config(session_config); + setup_pg_catalog( &session_context, "datafusion", diff --git a/datafusion-postgres/tests/grafana.rs b/datafusion-postgres/tests/grafana.rs new file mode 100644 index 0000000..b6b14bd --- /dev/null +++ b/datafusion-postgres/tests/grafana.rs @@ -0,0 +1,73 @@ +use 
pgwire::api::query::SimpleQueryHandler; + +use datafusion_postgres::testing::*; + +const GRAFANA_QUERIES: &[&str] = &[ + r#"SELECT + CASE WHEN + quote_ident(table_schema) IN ( + SELECT + CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END + FROM + generate_series( + array_lower(string_to_array(current_setting('search_path'),','),1), + array_upper(string_to_array(current_setting('search_path'),','),1) + ) as i, + string_to_array(current_setting('search_path'),',') s + ) + THEN quote_ident(table_name) + ELSE quote_ident(table_schema) || '.' || quote_ident(table_name) + END AS "table" + FROM information_schema.tables + WHERE quote_ident(table_schema) NOT IN ('information_schema', + 'pg_catalog', + '_timescaledb_cache', + '_timescaledb_catalog', + '_timescaledb_internal', + '_timescaledb_config', + 'timescaledb_information', + 'timescaledb_experimental') + ORDER BY CASE WHEN + quote_ident(table_schema) IN ( + SELECT + CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END + FROM + generate_series( + array_lower(string_to_array(current_setting('search_path'),','),1), + array_upper(string_to_array(current_setting('search_path'),','),1) + ) as i, + string_to_array(current_setting('search_path'),',') s + ) THEN 0 ELSE 1 END, 1"#, + r#"SELECT quote_ident(column_name) AS "column", data_type AS "type" + FROM information_schema.columns + WHERE + CASE WHEN array_length(parse_ident('public.games'),1) = 2 + THEN quote_ident(table_schema) = (parse_ident('public.games'))[1] + AND quote_ident(table_name) = (parse_ident('public.games'))[2] + ELSE quote_ident(table_name) = 'public.games' + AND + quote_ident(table_schema) IN ( + SELECT + CASE WHEN trim(s[i]) = '"$user"' THEN user ELSE trim(s[i]) END + FROM + generate_series( + array_lower(string_to_array(current_setting('search_path'),','),1), + array_upper(string_to_array(current_setting('search_path'),','),1) + ) as i, + string_to_array(current_setting('search_path'),',') s + ) + END"#, +]; + +#[tokio::test] +pub 
async fn test_grafana_sql() { + env_logger::init(); + let service = setup_handlers(); + let mut client = MockClient::new(); + + for query in GRAFANA_QUERIES { + SimpleQueryHandler::do_query(&service, &mut client, query) + .await + .unwrap_or_else(|e| panic!("failed to run sql: {query}\n{e}")); + } +}