diff --git a/crates/core/src/delta_datafusion/engine/expressions/mod.rs b/crates/core/src/delta_datafusion/engine/expressions/mod.rs new file mode 100644 index 0000000000..6b2bed37e0 --- /dev/null +++ b/crates/core/src/delta_datafusion/engine/expressions/mod.rs @@ -0,0 +1,79 @@ +pub use self::to_datafusion::*; +pub(crate) use self::to_kernel::*; + +mod to_datafusion; +mod to_json; +mod to_kernel; + +#[cfg(test)] +mod tests { + use std::ops::Not; + + use datafusion::logical_expr::{col, lit}; + use delta_kernel::schema::DataType; + + use super::*; + + #[test] + fn test_roundtrip_simple_and() { + let df_expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); + let delta_expr = to_delta_expression(&df_expr).unwrap(); + let df_expr_roundtrip = to_datafusion_expr(&delta_expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(df_expr, df_expr_roundtrip); + } + + #[test] + fn test_roundtrip_nested_and() { + let df_expr = col("a") + .eq(lit(1)) + .and(col("b").eq(lit(2))) + .and(col("c").eq(lit(3))) + .and(col("d").eq(lit(4))); + let delta_expr = to_delta_expression(&df_expr).unwrap(); + let df_expr_roundtrip = to_datafusion_expr(&delta_expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(df_expr, df_expr_roundtrip); + } + + #[test] + fn test_roundtrip_mixed_and_or() { + let df_expr = col("a") + .eq(lit(1)) + .and(col("b").eq(lit(2))) + .or(col("c").eq(lit(3)).and(col("d").eq(lit(4)))); + let delta_expr = to_delta_expression(&df_expr).unwrap(); + let df_expr_roundtrip = to_datafusion_expr(&delta_expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(df_expr, df_expr_roundtrip); + } + + #[test] + fn test_roundtrip_unary() { + let df_expr = !col("a").eq(lit(1)); + let delta_expr = to_delta_expression(&df_expr).unwrap(); + let df_expr_roundtrip = to_datafusion_expr(&delta_expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(df_expr, df_expr_roundtrip); + } + + #[test] + fn test_roundtrip_is_null() { + let df_expr = col("a").is_null(); + let delta_expr = to_delta_expression(&df_expr).unwrap(); + let df_expr_roundtrip = to_datafusion_expr(&delta_expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(df_expr, df_expr_roundtrip); + } + + #[test] + fn test_roundtrip_binary_ops() { + let df_expr = col("a") + col("b") * col("c"); + let delta_expr = to_delta_expression(&df_expr).unwrap(); + let df_expr_roundtrip = to_datafusion_expr(&delta_expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(df_expr, df_expr_roundtrip); + } + + #[test] + fn test_roundtrip_comparison_ops() { + let df_expr = col("a").gt(col("b")).and(col("c").gt(col("d")).not()); + let delta_expr = to_delta_expression(&df_expr).unwrap(); + let df_expr_roundtrip = to_datafusion_expr(&delta_expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(df_expr, df_expr_roundtrip); + } +} diff --git a/crates/core/src/delta_datafusion/engine/expressions/to_datafusion.rs b/crates/core/src/delta_datafusion/engine/expressions/to_datafusion.rs new file mode 100644 index 0000000000..dae4e24f79 --- /dev/null +++ b/crates/core/src/delta_datafusion/engine/expressions/to_datafusion.rs @@ -0,0 +1,883 @@ +use std::sync::Arc; + +use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue, not_impl_err}; +use datafusion::functions::core::expr_ext::FieldAccessor; +use datafusion::functions::expr_fn::named_struct; +use datafusion::logical_expr::expr::ScalarFunction; +use datafusion::logical_expr::{BinaryExpr, Expr, Operator, col, lit}; +use delta_kernel::Predicate; +use delta_kernel::arrow::datatypes::{DataType as ArrowDataType, Field as 
ArrowField}; +use delta_kernel::engine::arrow_conversion::TryIntoArrow; +use delta_kernel::expressions::{ + BinaryExpression, BinaryExpressionOp, BinaryPredicate, BinaryPredicateOp, Expression, + JunctionPredicate, JunctionPredicateOp, Scalar, UnaryExpression, UnaryExpressionOp, + UnaryPredicate, UnaryPredicateOp, +}; +use delta_kernel::schema::DataType; +use itertools::Itertools; + +use crate::delta_datafusion::engine::expressions::to_json::to_json; + +pub fn to_datafusion_expr(expr: &Expression, output_type: &DataType) -> DFResult { + match expr { + Expression::Literal(scalar) => scalar_to_df(scalar).map(lit), + Expression::Column(name) => { + let mut name_iter = name.iter(); + let base_name = name_iter.next().ok_or_else(|| { + DataFusionError::Internal("Expected at least one column name".into()) + })?; + Ok(name_iter.fold(col(base_name), |acc, n| acc.field(n))) + } + Expression::Predicate(expr) => predicate_to_df(expr, output_type), + Expression::Struct(fields) => struct_to_df(fields, output_type), + Expression::Binary(expr) => binary_to_df(expr, output_type), + Expression::Unary(expr) => unary_to_df(expr, output_type), + Expression::Opaque(_) => not_impl_err!("Opaque expressions are not yet supported"), + Expression::Unknown(_) => not_impl_err!("Unknown expressions are not yet supported"), + Expression::Transform(_) => not_impl_err!("Transform expressions are not yet supported"), + Expression::Variadic(_) => not_impl_err!("Variadic expressions are not yet supported"), + } +} + +pub(crate) fn scalar_to_df(scalar: &Scalar) -> DFResult { + Ok(match scalar { + Scalar::Boolean(value) => ScalarValue::Boolean(Some(*value)), + Scalar::String(value) => ScalarValue::Utf8(Some(value.clone())), + Scalar::Byte(value) => ScalarValue::Int8(Some(*value)), + Scalar::Short(value) => ScalarValue::Int16(Some(*value)), + Scalar::Integer(value) => ScalarValue::Int32(Some(*value)), + Scalar::Long(value) => ScalarValue::Int64(Some(*value)), + Scalar::Float(value) => ScalarValue::Float32(Some(*value)), + Scalar::Double(value) => ScalarValue::Float64(Some(*value)), + Scalar::Timestamp(value) => { + ScalarValue::TimestampMicrosecond(Some(*value), Some("UTC".into())) + } + Scalar::TimestampNtz(value) => ScalarValue::TimestampMicrosecond(Some(*value), None), + Scalar::Date(value) => ScalarValue::Date32(Some(*value)), + Scalar::Binary(value) => ScalarValue::Binary(Some(value.clone())), + Scalar::Decimal(data) => { + ScalarValue::Decimal128(Some(data.bits()), data.precision(), data.scale() as i8) + } + Scalar::Struct(data) => { + let fields: Vec = data + .fields() + .iter() + .map(|f| f.try_into_arrow()) + .try_collect()?; + let values: Vec<_> = data.values().iter().map(scalar_to_df).try_collect()?; + fields + .into_iter() + .zip(values.into_iter()) + .fold(ScalarStructBuilder::new(), |builder, (field, value)| { + builder.with_scalar(field, value) + }) + .build()? + } + Scalar::Array(_) => { + return Err(DataFusionError::NotImplemented( + "Array scalar values not implemented".into(), + )); + } + Scalar::Map(_) => { + return Err(DataFusionError::NotImplemented( + "Map scalar values not implemented".into(), + )); + } + Scalar::Null(data_type) => { + let data_type: ArrowDataType = data_type + .try_into_arrow() + .map_err(|e| DataFusionError::External(e.into()))?; + ScalarValue::try_from(&data_type)? 
+ } + }) +} + +fn binary_to_df(bin: &BinaryExpression, output_type: &DataType) -> DFResult { + let BinaryExpression { left, op, right } = bin; + let left_expr = to_datafusion_expr(left, output_type)?; + let right_expr = to_datafusion_expr(right, output_type)?; + Ok(match op { + BinaryExpressionOp::Plus => left_expr + right_expr, + BinaryExpressionOp::Minus => left_expr - right_expr, + BinaryExpressionOp::Multiply => left_expr * right_expr, + BinaryExpressionOp::Divide => left_expr / right_expr, + }) +} + +fn unary_to_df(un: &UnaryExpression, output_type: &DataType) -> DFResult { + let UnaryExpression { op, expr } = un; + let expr = to_datafusion_expr(expr, output_type)?; + Ok(match op { + UnaryExpressionOp::ToJson => Expr::ScalarFunction(ScalarFunction { + func: to_json(), + args: vec![expr], + }), + }) +} + +fn binary_pred_to_df(bin: &BinaryPredicate, output_type: &DataType) -> DFResult { + let BinaryPredicate { left, op, right } = bin; + let left_expr = to_datafusion_expr(left, output_type)?; + let right_expr = to_datafusion_expr(right, output_type)?; + + Ok(match op { + BinaryPredicateOp::Equal => left_expr.eq(right_expr), + BinaryPredicateOp::LessThan => left_expr.lt(right_expr), + BinaryPredicateOp::GreaterThan => left_expr.gt(right_expr), + BinaryPredicateOp::Distinct => Expr::BinaryExpr(BinaryExpr { + left: left_expr.into(), + op: Operator::IsDistinctFrom, + right: right_expr.into(), + }), + BinaryPredicateOp::In => Err(DataFusionError::NotImplemented( + "IN operator not supported".into(), + ))?, + }) +} + +fn predicate_to_df(predicate: &Predicate, output_type: &DataType) -> DFResult { + match predicate { + Predicate::BooleanExpression(expr) => to_datafusion_expr(expr, output_type), + Predicate::Not(expr) => Ok(!(predicate_to_df(expr, output_type)?)), + Predicate::Unary(expr) => unary_pred_to_df(expr, output_type), + Predicate::Binary(expr) => binary_pred_to_df(expr, output_type), + Predicate::Junction(expr) => junction_to_df(expr, output_type), + Predicate::Opaque(_) => not_impl_err!("Opaque predicates are not yet supported"), + Predicate::Unknown(_) => not_impl_err!("Unknown predicates are not yet supported"), + } +} + +fn unary_pred_to_df(unary: &UnaryPredicate, output_type: &DataType) -> DFResult { + let UnaryPredicate { op, expr } = unary; + let df_expr = to_datafusion_expr(expr, output_type)?; + Ok(match op { + UnaryPredicateOp::IsNull => df_expr.is_null(), + }) +} + +fn junction_to_df(junction: &JunctionPredicate, output_type: &DataType) -> DFResult { + let JunctionPredicate { op, preds } = junction; + let df_exprs: Vec<_> = preds + .iter() + .map(|e| predicate_to_df(e, output_type)) + .try_collect()?; + match op { + JunctionPredicateOp::And => Ok(df_exprs + .into_iter() + .reduce(|a, b| a.and(b)) + .unwrap_or(lit(true))), + JunctionPredicateOp::Or => Ok(df_exprs + .into_iter() + .reduce(|a, b| a.or(b)) + .unwrap_or(lit(false))), + } +} + +fn struct_to_df(fields: &[Arc], output_type: &DataType) -> DFResult { + let DataType::Struct(struct_type) = output_type else { + return Err(DataFusionError::Execution( + "expected struct output type".into(), + )); + }; + let df_exprs: Vec<_> = fields + .iter() + .zip(struct_type.fields()) + .map(|(expr, field)| { + Ok(vec![ + lit(field.name().to_string()), + to_datafusion_expr(expr, field.data_type())?, + ]) + }) + .flatten_ok() + .try_collect::<_, _, DataFusionError>()?; + Ok(named_struct(df_exprs)) +} + +#[cfg(test)] +mod tests { + use std::ops::Not; + + use datafusion::logical_expr::{col, lit}; + use 
delta_kernel::expressions::ColumnName; + use delta_kernel::expressions::{ArrayData, BinaryExpression, MapData, Scalar, StructData}; + use delta_kernel::schema::{ArrayType, DataType, MapType, StructField, StructType}; + + use super::*; + + /// Test conversion of primitive scalar types to DataFusion scalar values + #[test] + fn test_scalar_to_df_primitives() { + let test_cases = vec![ + (Scalar::Boolean(true), ScalarValue::Boolean(Some(true))), + ( + Scalar::String("test".to_string()), + ScalarValue::Utf8(Some("test".to_string())), + ), + (Scalar::Integer(42), ScalarValue::Int32(Some(42))), + (Scalar::Long(42), ScalarValue::Int64(Some(42))), + (Scalar::Float(42.0), ScalarValue::Float32(Some(42.0))), + (Scalar::Double(42.0), ScalarValue::Float64(Some(42.0))), + (Scalar::Byte(42), ScalarValue::Int8(Some(42))), + (Scalar::Short(42), ScalarValue::Int16(Some(42))), + ]; + + for (input, expected) in test_cases { + let result = scalar_to_df(&input).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test conversion of temporal scalar types to DataFusion scalar values + #[test] + fn test_scalar_to_df_temporal() { + let test_cases = vec![ + ( + Scalar::Timestamp(1234567890), + ScalarValue::TimestampMicrosecond(Some(1234567890), Some("UTC".into())), + ), + ( + Scalar::TimestampNtz(1234567890), + ScalarValue::TimestampMicrosecond(Some(1234567890), None), + ), + (Scalar::Date(18262), ScalarValue::Date32(Some(18262))), + ]; + + for (input, expected) in test_cases { + let result = scalar_to_df(&input).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test conversion of binary and decimal scalar types to DataFusion scalar values + #[test] + fn test_scalar_to_df_binary_decimal() { + let binary_data = vec![1, 2, 3]; + let decimal_data = Scalar::decimal(123456789, 10, 2).unwrap(); + + let test_cases = vec![ + ( + Scalar::Binary(binary_data.clone()), + ScalarValue::Binary(Some(binary_data)), + ), + ( + decimal_data, + ScalarValue::Decimal128(Some(123456789), 10, 2), + ), + ]; + + for (input, expected) in test_cases { + let result = scalar_to_df(&input).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test conversion of struct scalar type to DataFusion scalar value + #[test] + fn test_scalar_to_df_struct() { + let result = scalar_to_df(&Scalar::Struct( + StructData::try_new( + vec![ + StructField::nullable("a", DataType::INTEGER), + StructField::nullable("b", DataType::STRING), + ], + vec![Scalar::Integer(42), Scalar::String("test".to_string())], + ) + .unwrap(), + )) + .unwrap(); + + // Create the expected struct value + let expected = ScalarStructBuilder::new() + .with_scalar( + ArrowField::new("a", ArrowDataType::Int32, true), + ScalarValue::Int32(Some(42)), + ) + .with_scalar( + ArrowField::new("b", ArrowDataType::Utf8, true), + ScalarValue::Utf8(Some("test".to_string())), + ) + .build() + .unwrap(); + + assert_eq!(result, expected); + } + + /// Test conversion of null scalar types to DataFusion scalar values + #[test] + fn test_scalar_to_df_null() { + let test_cases = vec![ + (Scalar::Null(DataType::INTEGER), ScalarValue::Int32(None)), + (Scalar::Null(DataType::STRING), ScalarValue::Utf8(None)), + (Scalar::Null(DataType::BOOLEAN), ScalarValue::Boolean(None)), + (Scalar::Null(DataType::DOUBLE), ScalarValue::Float64(None)), + ]; + + for (input, expected) in test_cases { + let result = scalar_to_df(&input).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test error cases for unsupported scalar types (Array and Map) + #[test] + fn test_scalar_to_df_errors() { + let array_data = 
ArrayData::try_new( + ArrayType::new(DataType::INTEGER, true), + vec![Scalar::Integer(1), Scalar::Integer(2)], + ) + .unwrap(); + + let map_data = MapData::try_new( + MapType::new(DataType::STRING, DataType::INTEGER, true), + vec![ + (Scalar::String("key1".to_string()), Scalar::Integer(1)), + (Scalar::String("key2".to_string()), Scalar::Integer(2)), + ], + ) + .unwrap(); + + let test_cases = vec![ + ( + Scalar::Array(array_data), + "Array scalar values not implemented", + ), + (Scalar::Map(map_data), "Map scalar values not implemented"), + ]; + + for (input, expected_error) in test_cases { + let result = scalar_to_df(&input); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains(expected_error)); + } + } + + /// Test basic column reference: `test_col` + #[test] + fn test_column_expression() { + let expr = Expression::Column(ColumnName::new(["test_col"])); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, col("test_col")); + + let expr = Expression::Column(ColumnName::new(["test_col", "field_1", "field_2"])); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, col("test_col").field("field_1").field("field_2")); + } + + /// Test various literal values: + /// - `true` (boolean) + /// - `"test"` (string) + /// - `42` (integer) + /// - `42L` (long) + /// - `42.0f` (float) + /// - `42.0` (double) + /// - `NULL` (null boolean) + #[test] + fn test_literal_expressions() { + // Test various scalar types + let test_cases = vec![ + (Expression::Literal(Scalar::Boolean(true)), lit(true)), + ( + Expression::Literal(Scalar::String("test".to_string())), + lit("test"), + ), + (Expression::Literal(Scalar::Integer(42)), lit(42)), + (Expression::Literal(Scalar::Long(42)), lit(42i64)), + (Expression::Literal(Scalar::Float(42.0)), lit(42.0f32)), + (Expression::Literal(Scalar::Double(42.0)), lit(42.0)), + ( + Expression::Literal(Scalar::Null(DataType::BOOLEAN)), + lit(ScalarValue::Boolean(None)), + ), + ]; + + for (input, expected) in test_cases { + let result = to_datafusion_expr(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test binary operations: + /// - `a = 1` (equality) + /// - `a + b` (addition) + /// - `a * 2` (multiplication) + #[test] + fn test_binary_expressions() { + let test_cases = vec![ + ( + Expression::Binary(BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Plus, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }), + col("a") + col("b"), + ), + ( + Expression::Binary(BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Multiply, + right: Box::new(Expression::Literal(Scalar::Integer(2))), + }), + col("a") * lit(2), + ), + ]; + + for (input, expected) in test_cases { + let result = to_datafusion_expr(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test binary operations: + /// - `a = 1` (equality) + /// - `a + b` (addition) + /// - `a * 2` (multiplication) + #[test] + fn test_binary_predicate() { + let test_cases = vec![( + BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::Equal, + right: Box::new(Expression::Literal(Scalar::Integer(1))), + }, + col("a").eq(lit(1)), + )]; + + for (input, expected) in test_cases { + let result = binary_pred_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test 
unary operations: + /// - `a IS NULL` (null check) + /// - `NOT a` (logical negation) + #[test] + fn test_unary_expressions() { + let test_cases = vec![( + UnaryPredicate { + op: UnaryPredicateOp::IsNull, + expr: Box::new(Expression::Column(ColumnName::new(["a"]))), + }, + col("a").is_null(), + )]; + + for (input, expected) in test_cases { + let result = unary_pred_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test junction operations: + /// - `a AND b` (logical AND) + /// - `a OR b` (logical OR) + #[test] + fn test_junction_expressions() { + let test_cases = vec![ + ( + JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + ], + }, + col("a").and(col("b")), + ), + ( + JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + ], + }, + col("a").or(col("b")), + ), + ]; + + for (input, expected) in test_cases { + let result = junction_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test complex nested expression: + /// `(a > 1 AND b < 2) OR (c = 3)` + #[test] + fn test_complex_nested_expressions() { + // Test a complex expression: (a > 1 AND b < 2) OR (c = 3) + let expr = Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::GreaterThan, + right: Box::new(Expression::Literal(Scalar::Integer(1))), + }), + Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["b"]))), + op: BinaryPredicateOp::LessThan, + right: Box::new(Expression::Literal(Scalar::Integer(2))), + }), + ], + }), + Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["c"]))), + op: BinaryPredicateOp::Equal, + right: Box::new(Expression::Literal(Scalar::Integer(3))), + }), + ], + }); + + let result = predicate_to_df(&expr, &DataType::BOOLEAN).unwrap(); + let expected = (col("a").gt(lit(1)).and(col("b").lt(lit(2)))).or(col("c").eq(lit(3))); + assert_eq!(result, expected); + } + + #[test] + fn test_struct_expression() { + let expr = Expression::Struct(vec![ + Expression::Column(ColumnName::new(["a"])).into(), + Expression::Column(ColumnName::new(["b"])).into(), + ]); + let result = to_datafusion_expr( + &expr, + &DataType::Struct(Box::new( + StructType::try_new(vec![ + StructField::nullable("a", DataType::INTEGER), + StructField::nullable("b", DataType::INTEGER), + ]) + .unwrap(), + )), + ) + .unwrap(); + assert_eq!( + result, + named_struct(vec![lit("a"), col("a"), lit("b"), col("b")]) + ); + } + + /// Test binary expression conversions: + /// - Addition: a + b + /// - Subtraction: a - b + /// - Multiplication: a * b + /// - Division: a / b + #[test] + fn test_binary_to_df() { + let test_cases = vec![ + ( + BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Plus, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a") + col("b"), + ), + ( + BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: 
BinaryExpressionOp::Minus, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a") - col("b"), + ), + ( + BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Multiply, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a") * col("b"), + ), + ( + BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Divide, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a") / col("b"), + ), + ]; + + for (input, expected) in test_cases { + let result = binary_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test binary expression conversions: + /// - Equality: a = b + /// - Inequality: a != b + /// - Less than: a < b + /// - Less than or equal: a <= b + /// - Greater than: a > b + /// - Greater than or equal: a >= b + #[test] + fn test_binary_pred_to_df() { + let test_cases = vec![ + ( + BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::Equal, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a").eq(col("b")), + ), + ( + BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::LessThan, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a").lt(col("b")), + ), + ( + BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::GreaterThan, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }, + col("a").gt(col("b")), + ), + ]; + + for (input, expected) in test_cases { + let result = binary_pred_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + + let test_cases = vec![ + ( + Predicate::Not(Box::new(Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::Equal, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }))), + col("a").eq(col("b")).not(), + ), + ( + Predicate::Not(Box::new(Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::GreaterThan, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }))), + col("a").gt(col("b")).not(), + ), + ( + Predicate::Not(Box::new(Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::LessThan, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + }))), + col("a").lt(col("b")).not(), + ), + ]; + + for (input, expected) in test_cases { + let result = predicate_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test junction expression conversions: + /// - Simple AND: a AND b + /// - Simple OR: a OR b + /// - Multiple AND: a AND b AND c + /// - Multiple OR: a OR b OR c + /// - Empty AND (should return true) + /// - Empty OR (should return false) + #[test] + fn test_junction_to_df() { + let test_cases = vec![ + // Simple AND + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + ], + }), + col("a").and(col("b")), + ), + // Simple OR + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + 
Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + ], + }), + col("a").or(col("b")), + ), + // Multiple AND + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["c"]))), + ], + }), + col("a").and(col("b")).and(col("c")), + ), + // Multiple OR + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["a"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["c"]))), + ], + }), + col("a").or(col("b")).or(col("c")), + ), + // Empty AND (should return true) + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![], + }), + lit(true), + ), + // Empty OR (should return false) + ( + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![], + }), + lit(false), + ), + ]; + + for (input, expected) in test_cases { + let result = predicate_to_df(&input, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, expected); + } + } + + /// Test to_datafusion_expr with various expression types and combinations: + /// - Column expressions with nested fields + /// - Complex unary expressions + /// - Nested binary expressions + /// - Mixed junction expressions + /// - Struct expressions with nested fields + /// - Complex combinations of all expression types + #[test] + fn test_to_datafusion_expr_comprehensive() { + // Test column expressions with nested fields + let expr = Expression::Column(ColumnName::new(["struct", "field", "nested"])); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, col("struct").field("field").field("nested")); + + // Test complex unary expressions + let expr = Expression::Predicate(Box::new(Predicate::Not(Box::new(Predicate::Unary( + UnaryPredicate { + op: UnaryPredicateOp::IsNull, + expr: Box::new(Expression::Column(ColumnName::new(["a"]))), + }, + ))))); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, !col("a").is_null()); + + // Test nested binary expressions + let expr = Expression::Binary(BinaryExpression { + left: Box::new(Expression::Binary(BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryExpressionOp::Plus, + right: Box::new(Expression::Column(ColumnName::new(["b"]))), + })), + op: BinaryExpressionOp::Multiply, + right: Box::new(Expression::Column(ColumnName::new(["c"]))), + }); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, (col("a") + col("b")) * col("c")); + + // Test mixed junction expressions + let expr = Expression::Predicate(Box::new(Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["a"]))), + op: BinaryPredicateOp::GreaterThan, + right: Box::new(Expression::Literal(Scalar::Integer(0))), + }), + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["b"]))), + 
Predicate::BooleanExpression(Expression::Column(ColumnName::new(["c"]))), + ], + }), + ], + }))); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!(result, col("a").gt(lit(0)).and(col("b").or(col("c")))); + + // Test struct expressions with nested fields + let expr = Expression::Struct(vec![ + Expression::Column(ColumnName::new(["a"])).into(), + Expression::Binary(BinaryExpression { + left: Box::new(Expression::Column(ColumnName::new(["b"]))), + op: BinaryExpressionOp::Plus, + right: Box::new(Expression::Column(ColumnName::new(["c"]))), + }) + .into(), + ]); + let result = to_datafusion_expr( + &expr, + &DataType::Struct(Box::new( + StructType::try_new(vec![ + StructField::nullable("a", DataType::INTEGER), + StructField::nullable("sum", DataType::INTEGER), + ]) + .unwrap(), + )), + ) + .unwrap(); + assert_eq!( + result, + named_struct(vec![lit("a"), col("a"), lit("sum"), col("b") + col("c")]) + ); + + // Test complex combination of all expression types + let expr = Expression::Predicate(Box::new(Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: vec![ + Predicate::Not(Box::new(Predicate::BooleanExpression(Expression::Column( + ColumnName::new(["a"]), + )))), + Predicate::Binary(BinaryPredicate { + left: Box::new(Expression::Column(ColumnName::new(["b"]))), + op: BinaryPredicateOp::Equal, + right: Box::new(Expression::Literal(Scalar::Integer(42))), + }), + Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::Or, + preds: vec![ + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["c"]))), + Predicate::BooleanExpression(Expression::Column(ColumnName::new(["d"]))), + ], + }), + ], + }))); + let result = to_datafusion_expr(&expr, &DataType::BOOLEAN).unwrap(); + assert_eq!( + result, + (!col("a")) + .and(col("b").eq(lit(42))) + .and(col("c").or(col("d"))) + ); + + // Test error case: empty column name + let expr = Expression::Column(ColumnName::new::<&str>([])); + assert!(to_datafusion_expr(&expr, &DataType::BOOLEAN).is_err()); + } +} diff --git a/crates/core/src/delta_datafusion/engine/expressions/to_json.rs b/crates/core/src/delta_datafusion/engine/expressions/to_json.rs new file mode 100644 index 0000000000..9990337a7b --- /dev/null +++ b/crates/core/src/delta_datafusion/engine/expressions/to_json.rs @@ -0,0 +1,131 @@ +use std::sync::Arc; +use std::{any::Any, sync::LazyLock}; + +use arrow::datatypes::DataType; +use datafusion::common::Result; +use datafusion::error::DataFusionError; +use datafusion::logical_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion::logical_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, scalar_doc_sections::DOC_SECTION_STRUCT, +}; +use delta_kernel::engine::arrow_expression::evaluate_expression::to_json as to_json_kernel; + +pub fn to_json() -> Arc { + static INSTANCE: LazyLock> = + LazyLock::new(|| Arc::new(ScalarUDF::new_from_impl(ToJson::new()))); + Arc::clone(&INSTANCE) +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ToJson { + signature: Signature, +} + +impl ToJson { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Stable), + } + } +} + +static DOCUMENTATION: LazyLock = LazyLock::new(|| { + Documentation::builder( + DOC_SECTION_STRUCT, + "Serialize data as a JSON string.", + "to_json()", + ) + .with_argument("data", "The data to be converted to JSON format.") + .build() +}); + +fn get_doc() -> &'static Documentation { + &DOCUMENTATION +} + +/// 
Implement the ScalarUDFImpl trait for AddOne +impl ScalarUDFImpl for ToJson { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "to_json" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { + args, number_rows, .. + } = args; + let Some(data) = args.first().map(|c| c.to_array(number_rows)).transpose()? else { + return Err(DataFusionError::Internal( + "to_json requires one argument".to_string(), + )); + }; + Ok(ColumnarValue::Array(to_json_kernel(&data)?)) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_doc()) + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + // The function preserves the order of its argument. + Ok(input[0].sort_properties) + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{ArrayRef, Float64Array, Int8Array, RecordBatch}; + use arrow_array::StructArray; + use datafusion::{ + assert_batches_eq, + prelude::{SessionContext, col}, + }; + + use super::*; + + #[tokio::test] + async fn test_to_json() -> Result<(), Box> { + let long: ArrayRef = Arc::new(Float64Array::from(vec![100.0, -122.4783, -122.4783])); + let lat: ArrayRef = Arc::new(Float64Array::from(vec![45.0, 37.8199, 37.8199])); + let res: ArrayRef = Arc::new(Int8Array::from(vec![6, 13, 16])); + let struct_array: ArrayRef = + Arc::new(StructArray::from(RecordBatch::try_from_iter(vec![ + ("long", long), + ("lat", lat), + ("res", res), + ])?)); + let batch = RecordBatch::try_from_iter(vec![("geometry", struct_array)])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + + let df = ctx.table("t").await?; + let df = df.select(vec![to_json().call(vec![col("geometry")]).alias("json")])?; + + let results = df.collect().await?; + let expected = vec![ + r#"+-------------------------------------------+"#, + r#"| json |"#, + r#"+-------------------------------------------+"#, + r#"| {"long":100.0,"lat":45.0,"res":6} |"#, + r#"| {"long":-122.4783,"lat":37.8199,"res":13} |"#, + r#"| {"long":-122.4783,"lat":37.8199,"res":16} |"#, + r#"+-------------------------------------------+"#, + ]; + assert_batches_eq!(expected, &results); + + Ok(()) + } +} diff --git a/crates/core/src/delta_datafusion/engine/expressions/to_kernel.rs b/crates/core/src/delta_datafusion/engine/expressions/to_kernel.rs new file mode 100644 index 0000000000..d2b33125b5 --- /dev/null +++ b/crates/core/src/delta_datafusion/engine/expressions/to_kernel.rs @@ -0,0 +1,551 @@ +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::logical_expr::{BinaryExpr, Expr, Operator}; +use delta_kernel::Error as DeltaKernelError; +use delta_kernel::expressions::{ + BinaryExpression, BinaryExpressionOp, BinaryPredicate, BinaryPredicateOp, DecimalData, + Expression, JunctionPredicate, JunctionPredicateOp, Predicate, Scalar, UnaryPredicate, + UnaryPredicateOp, +}; +use delta_kernel::schema::{DataType, DecimalType, PrimitiveType}; +use itertools::Itertools; + +pub(crate) fn to_df_err(e: DeltaKernelError) -> DataFusionError { + DataFusionError::External(Box::new(e)) +} + +pub(crate) fn to_delta_predicate(filters: &[Expr]) -> DFResult { + if filters.is_empty() { + return Ok(Predicate::BooleanExpression(Expression::Literal( + Scalar::Boolean(true), + ))); + }; + if filters.len() == 1 { + return to_predicate(&filters[0]); + } + 
Ok(Predicate::Junction(JunctionPredicate { + op: JunctionPredicateOp::And, + preds: filters.iter().map(to_predicate).try_collect()?, + })) +} + +pub(crate) fn to_predicate(expr: &Expr) -> DFResult { + match to_delta_expression(expr)? { + Expression::Predicate(pred) => Ok(pred.as_ref().clone()), + expr => Ok(Predicate::BooleanExpression(expr)), + } +} + +/// Convert a DataFusion expression to a Delta expression. +pub(crate) fn to_delta_expression(expr: &Expr) -> DFResult { + match expr { + Expr::Column(column) => Ok(Expression::Column( + column + .name + .parse() + .map_err(|e| DataFusionError::External(Box::new(e)))?, + )), + Expr::Literal(scalar, _meta) => { + Ok(Expression::Literal(datafusion_scalar_to_scalar(scalar)?)) + } + Expr::BinaryExpr(BinaryExpr { + op: op @ (Operator::And | Operator::Or), + .. + }) => { + let preds = flatten_junction_expr(expr, *op)?; + Ok(Expression::Predicate(Box::new(Predicate::Junction( + JunctionPredicate { + op: to_junction_op(*op), + preds, + }, + )))) + } + Expr::BinaryExpr(BinaryExpr { + op: op @ (Operator::Eq | Operator::Lt | Operator::Gt | Operator::IsDistinctFrom), + left, + right, + }) => Ok(Expression::Predicate(Box::new(Predicate::Binary( + BinaryPredicate { + left: Box::new(to_delta_expression(left.as_ref())?), + op: to_binary_predicate_op(*op)?, + right: Box::new(to_delta_expression(right.as_ref())?), + }, + )))), + Expr::BinaryExpr(BinaryExpr { + op: op @ (Operator::NotEq | Operator::LtEq | Operator::GtEq), + left, + right, + }) => { + let inverted = match op { + Operator::NotEq => Operator::Eq, + Operator::LtEq => Operator::Gt, + Operator::GtEq => Operator::Lt, + _ => unreachable!(), + }; + Ok(Expression::Predicate(Box::new(Predicate::Not(Box::new( + Predicate::Binary(BinaryPredicate { + left: Box::new(to_delta_expression(left.as_ref())?), + op: to_binary_predicate_op(inverted)?, + right: Box::new(to_delta_expression(right.as_ref())?), + }), + ))))) + } + Expr::BinaryExpr(BinaryExpr { + op: Operator::IsNotDistinctFrom, + left, + right, + }) => Ok(Expression::Predicate(Box::new(Predicate::Not(Box::new( + Predicate::Binary(BinaryPredicate { + left: Box::new(to_delta_expression(left.as_ref())?), + op: to_binary_predicate_op(Operator::IsDistinctFrom)?, + right: Box::new(to_delta_expression(right.as_ref())?), + }), + ))))), + Expr::BinaryExpr(BinaryExpr { op, left, right }) => { + Ok(Expression::Binary(BinaryExpression { + left: Box::new(to_delta_expression(left.as_ref())?), + op: to_binary_op(*op)?, + right: Box::new(to_delta_expression(right.as_ref())?), + })) + } + Expr::IsNull(expr) => Ok(Expression::Predicate(Box::new(Predicate::Unary( + UnaryPredicate { + op: UnaryPredicateOp::IsNull, + expr: Box::new(to_delta_expression(expr.as_ref())?), + }, + )))), + Expr::Not(expr) => Ok(Expression::Predicate(Box::new(Predicate::Not(Box::new( + Predicate::BooleanExpression(to_delta_expression(expr.as_ref())?), + ))))), + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported expression: {:?}", + expr + ))), + } +} + +fn datafusion_scalar_to_scalar(scalar: &ScalarValue) -> DFResult { + match scalar { + ScalarValue::Boolean(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Boolean(*value)), + None => Ok(Scalar::Null(DataType::BOOLEAN)), + }, + ScalarValue::Utf8(maybe_value) + | ScalarValue::LargeUtf8(maybe_value) + | ScalarValue::Utf8View(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::String(value.clone())), + None => Ok(Scalar::Null(DataType::STRING)), + }, + ScalarValue::Int8(maybe_value) => match maybe_value { + 
Some(value) => Ok(Scalar::Byte(*value)), + None => Ok(Scalar::Null(DataType::BYTE)), + }, + ScalarValue::Int16(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Short(*value)), + None => Ok(Scalar::Null(DataType::SHORT)), + }, + ScalarValue::Int32(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Integer(*value)), + None => Ok(Scalar::Null(DataType::INTEGER)), + }, + ScalarValue::Int64(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Long(*value)), + None => Ok(Scalar::Null(DataType::LONG)), + }, + ScalarValue::Float32(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Float(*value)), + None => Ok(Scalar::Null(DataType::FLOAT)), + }, + ScalarValue::Float64(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Double(*value)), + None => Ok(Scalar::Null(DataType::DOUBLE)), + }, + ScalarValue::TimestampMicrosecond(maybe_value, Some(_)) => match maybe_value { + Some(value) => Ok(Scalar::Timestamp(*value)), + None => Ok(Scalar::Null(DataType::TIMESTAMP)), + }, + ScalarValue::TimestampMicrosecond(maybe_value, None) => match maybe_value { + Some(value) => Ok(Scalar::TimestampNtz(*value)), + None => Ok(Scalar::Null(DataType::TIMESTAMP_NTZ)), + }, + ScalarValue::Date32(maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Date(*value)), + None => Ok(Scalar::Null(DataType::DATE)), + }, + ScalarValue::Binary(maybe_value) + | ScalarValue::LargeBinary(maybe_value) + | ScalarValue::BinaryView(maybe_value) + | ScalarValue::FixedSizeBinary(_, maybe_value) => match maybe_value { + Some(value) => Ok(Scalar::Binary(value.clone())), + None => Ok(Scalar::Null(DataType::BINARY)), + }, + ScalarValue::Decimal128(maybe_value, precision, scale) => match maybe_value { + Some(value) => Ok(Scalar::Decimal( + DecimalData::try_new( + *value, + DecimalType::try_new(*precision, *scale as u8).map_err(to_df_err)?, + ) + .map_err(to_df_err)?, + )), + None => Ok(Scalar::Null(DataType::Primitive(PrimitiveType::Decimal( + DecimalType::try_new(*precision, *scale as u8).map_err(to_df_err)?, + )))), + }, + ScalarValue::Dictionary(_, value) => datafusion_scalar_to_scalar(value.as_ref()), + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported scalar value: {:?}", + scalar + ))), + } +} + +fn to_binary_predicate_op(op: Operator) -> DFResult { + match op { + Operator::Eq => Ok(BinaryPredicateOp::Equal), + Operator::Lt => Ok(BinaryPredicateOp::LessThan), + Operator::Gt => Ok(BinaryPredicateOp::GreaterThan), + Operator::IsDistinctFrom => Ok(BinaryPredicateOp::Distinct), + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported operator: {:?}", + op + ))), + } +} + +fn to_binary_op(op: Operator) -> DFResult { + match op { + Operator::Plus => Ok(BinaryExpressionOp::Plus), + Operator::Minus => Ok(BinaryExpressionOp::Minus), + Operator::Multiply => Ok(BinaryExpressionOp::Multiply), + Operator::Divide => Ok(BinaryExpressionOp::Divide), + _ => Err(DataFusionError::NotImplemented(format!( + "Unsupported operator: {:?}", + op + ))), + } +} + +/// Helper function to flatten nested AND/OR expressions into a single junction expression +fn flatten_junction_expr(expr: &Expr, target_op: Operator) -> DFResult> { + match expr { + Expr::BinaryExpr(BinaryExpr { op, left, right }) if *op == target_op => { + let mut left_exprs = flatten_junction_expr(left.as_ref(), target_op)?; + let mut right_exprs = flatten_junction_expr(right.as_ref(), target_op)?; + left_exprs.append(&mut right_exprs); + Ok(left_exprs) + } + _ => { + let delta_expr = to_predicate(expr)?; + 
Ok(vec![delta_expr]) + } + } +} + +fn to_junction_op(op: Operator) -> JunctionPredicateOp { + match op { + Operator::And => JunctionPredicateOp::And, + Operator::Or => JunctionPredicateOp::Or, + _ => unimplemented!("Unsupported operator: {:?}", op), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::logical_expr::{col, lit}; + use delta_kernel::expressions::{BinaryExpressionOp, JunctionPredicateOp, Scalar}; + + fn assert_junction_expr( + expr: &Expr, + expected_op: JunctionPredicateOp, + expected_children: usize, + ) { + let delta_expr = to_delta_expression(expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Junction(junction) => { + assert_eq!(junction.op, expected_op); + assert_eq!(junction.preds.len(), expected_children); + } + _ => panic!("Expected Junction predicate, got {:?}", predicate), + }, + _ => panic!("Expected Junction expression, got {:?}", delta_expr), + } + } + + #[test] + fn test_simple_and() { + let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); + assert_junction_expr(&expr, JunctionPredicateOp::And, 2); + } + + #[test] + fn test_simple_or() { + let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2))); + assert_junction_expr(&expr, JunctionPredicateOp::Or, 2); + } + + #[test] + fn test_nested_and() { + let expr = col("a") + .eq(lit(1)) + .and(col("b").eq(lit(2))) + .and(col("c").eq(lit(3))) + .and(col("d").eq(lit(4))); + assert_junction_expr(&expr, JunctionPredicateOp::And, 4); + } + + #[test] + fn test_nested_or() { + let expr = col("a") + .eq(lit(1)) + .or(col("b").eq(lit(2))) + .or(col("c").eq(lit(3))) + .or(col("d").eq(lit(4))); + assert_junction_expr(&expr, JunctionPredicateOp::Or, 4); + } + + #[test] + fn test_mixed_nested_and_or() { + // (a AND b) OR (c AND d) + let left = col("a").eq(lit(1)).and(col("b").eq(lit(2))); + let right = col("c").eq(lit(3)).and(col("d").eq(lit(4))); + let expr = left.or(right); + + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Junction(junction) => { + assert_eq!(junction.op, JunctionPredicateOp::Or); + assert_eq!(junction.preds.len(), 2); + + // Check that both children are AND junctions + for child in &junction.preds { + match child { + Predicate::Junction(binary) => { + assert_eq!(binary.op, JunctionPredicateOp::And); + } + _ => panic!("Expected Binary expression in child: {:?}", child), + } + } + } + _ => panic!("Expected Junction predicate, got {:?}", predicate), + }, + _ => panic!("Expected Junction expression"), + } + } + + #[test] + fn test_deeply_nested_and() { + // (((a AND b) AND c) AND d) + let expr = col("a") + .eq(lit(1)) + .and(col("b").eq(lit(2))) + .and(col("c").eq(lit(3))) + .and(col("d").eq(lit(4))); + assert_junction_expr(&expr, JunctionPredicateOp::And, 4); + } + + #[test] + fn test_complex_expression() { + // (a AND b) OR ((c AND d) AND e) + let left = col("a").eq(lit(1)).and(col("b").eq(lit(2))); + let right = col("c") + .eq(lit(3)) + .and(col("d").eq(lit(4))) + .and(col("e").eq(lit(5))); + let expr = left.or(right); + + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Junction(junction) => { + assert_eq!(junction.op, JunctionPredicateOp::Or); + assert_eq!(junction.preds.len(), 2); + + // First child should be an AND with 2 expressions + match &junction.preds[0] { + Predicate::Junction(child_junction) => { + 
assert_eq!(child_junction.op, JunctionPredicateOp::And); + assert_eq!(child_junction.preds.len(), 2); + } + _ => panic!("Expected Junction expression in first child"), + } + + // Second child should be an AND with 3 expressions + match &junction.preds[1] { + Predicate::Junction(child_junction) => { + assert_eq!(child_junction.op, JunctionPredicateOp::And); + assert_eq!(child_junction.preds.len(), 3); + } + _ => panic!("Expected Junction expression in second child"), + } + } + _ => panic!("Expected Junction predicate, got {:?}", predicate), + }, + _ => panic!("Expected Junction expression"), + } + } + + #[test] + fn test_column_expression() { + let expr = col("test_column"); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Column(name) => assert_eq!(&name.to_string(), "test_column"), + _ => panic!("Expected Column expression, got {:?}", delta_expr), + } + } + + #[test] + fn test_literal_expressions() { + // Test boolean literal + let expr = lit(true); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Literal(Scalar::Boolean(value)) => assert!(value), + _ => panic!("Expected Boolean literal, got {:?}", delta_expr), + } + + // Test string literal + let expr = lit("test"); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Literal(Scalar::String(value)) => assert_eq!(value, "test"), + _ => panic!("Expected String literal, got {:?}", delta_expr), + } + + // Test integer literal + let expr = lit(42i32); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Literal(Scalar::Integer(value)) => assert_eq!(value, 42), + _ => panic!("Expected Integer literal, got {:?}", delta_expr), + } + + // Test decimal literal + let expr = lit(ScalarValue::Decimal128(Some(12345), 10, 2)); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Literal(Scalar::Decimal(data)) => { + assert_eq!(data.bits(), 12345); + assert_eq!(data.precision(), 10); + assert_eq!(data.scale(), 2); + } + _ => panic!("Expected Decimal literal, got {:?}", delta_expr), + } + } + + #[test] + fn test_binary_expressions() { + // Test comparison operators + let test_cases = vec![ + (col("a").eq(lit(1)), BinaryPredicateOp::Equal), + (col("a").lt(lit(1)), BinaryPredicateOp::LessThan), + (col("a").gt(lit(1)), BinaryPredicateOp::GreaterThan), + ]; + + for (expr, expected_op) in test_cases { + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Binary(binary) => { + assert_eq!(binary.op, expected_op); + match binary.left.as_ref() { + Expression::Column(name) => assert_eq!(name.to_string(), "a"), + _ => panic!("Expected Column expression in left operand"), + } + match *binary.right.as_ref() { + Expression::Literal(Scalar::Integer(value)) => assert_eq!(value, 1), + _ => panic!("Expected Integer literal in right operand"), + } + } + _ => panic!("Expected Binary predicate, got {:?}", predicate), + }, + _ => panic!("Expected Binary expression, got {:?}", delta_expr), + } + } + + // Test arithmetic operators + let test_cases = vec![ + (col("a") + lit(1), BinaryExpressionOp::Plus), + (col("a") - lit(1), BinaryExpressionOp::Minus), + (col("a") * lit(1), BinaryExpressionOp::Multiply), + (col("a") / lit(1), BinaryExpressionOp::Divide), + ]; + + for (expr, expected_op) in test_cases { + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + 
Expression::Binary(binary) => { + assert_eq!(binary.op, expected_op); + match binary.left.as_ref() { + Expression::Column(name) => assert_eq!(name.to_string(), "a"), + _ => panic!("Expected Column expression in left operand"), + } + match *binary.right.as_ref() { + Expression::Literal(Scalar::Integer(value)) => assert_eq!(value, 1), + _ => panic!("Expected Integer literal in right operand"), + } + } + _ => panic!("Expected Binary expression, got {:?}", delta_expr), + } + } + } + + #[test] + fn test_unary_expressions() { + // Test IS NULL + let expr = col("a").is_null(); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Unary(unary) => { + assert_eq!(unary.op, UnaryPredicateOp::IsNull); + match unary.expr.as_ref() { + Expression::Column(name) => assert_eq!(name.to_string(), "a"), + _ => panic!("Expected Column expression in operand"), + } + } + _ => panic!("Expected Unary predicate, got {:?}", predicate), + }, + _ => panic!("Expected Unary expression, got {:?}", delta_expr), + } + + // Test NOT + let expr = !col("a"); + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Predicate(predicate) => match predicate.as_ref() { + Predicate::Not(unary) => match unary.as_ref() { + Predicate::BooleanExpression(expr) => match expr { + Expression::Column(name) => assert_eq!(name.to_string(), "a"), + _ => panic!("Expected Column expression in operand"), + }, + _ => panic!("Expected Boolean expression in operand"), + }, + _ => panic!("Expected Unary predicate, got {:?}", predicate), + }, + _ => panic!("Expected Unary expression, got {:?}", delta_expr), + } + } + + #[test] + fn test_null_literals() { + let test_cases = vec![ + (lit(ScalarValue::Boolean(None)), DataType::BOOLEAN), + (lit(ScalarValue::Utf8(None)), DataType::STRING), + (lit(ScalarValue::Int32(None)), DataType::INTEGER), + (lit(ScalarValue::Float64(None)), DataType::DOUBLE), + ]; + + for (expr, expected_type) in test_cases { + let delta_expr = to_delta_expression(&expr).unwrap(); + match delta_expr { + Expression::Literal(Scalar::Null(data_type)) => { + assert_eq!(data_type, expected_type); + } + _ => panic!("Expected Null literal, got {:?}", delta_expr), + } + } + } +} diff --git a/crates/core/src/delta_datafusion/engine/mod.rs b/crates/core/src/delta_datafusion/engine/mod.rs index 478541df5c..50ff89c3c6 100644 --- a/crates/core/src/delta_datafusion/engine/mod.rs +++ b/crates/core/src/delta_datafusion/engine/mod.rs @@ -6,9 +6,12 @@ use delta_kernel::{Engine, EvaluationHandler, JsonHandler, ParquetHandler, Stora use tokio::runtime::Handle; use self::file_formats::DataFusionFileFormatHandler; -use self::storage::DataFusionStorageHandler; use crate::kernel::ARROW_HANDLER; +pub use expressions::*; +pub(crate) use storage::*; + +mod expressions; mod file_formats; mod storage; diff --git a/crates/core/src/delta_datafusion/mod.rs b/crates/core/src/delta_datafusion/mod.rs index d7ded37751..98af920165 100644 --- a/crates/core/src/delta_datafusion/mod.rs +++ b/crates/core/src/delta_datafusion/mod.rs @@ -224,7 +224,7 @@ fn _arrow_schema( ) -> ArrowSchemaRef { let fields = schema .fields() - .into_iter() + .iter() .filter(|f| !partition_columns.contains(&f.name().to_string())) .cloned() .chain( @@ -235,7 +235,7 @@ fn _arrow_schema( let corrected = if wrap_partitions { match field.data_type() { // Only dictionary-encode types that may be large - // // https://github.com/apache/arrow-datafusion/pull/5545 + // 
https://github.com/apache/arrow-datafusion/pull/5545 ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Binary @@ -780,14 +780,14 @@ impl TableProviderFactory for DeltaTableFactory { _ctx: &dyn Session, cmd: &CreateExternalTable, ) -> datafusion::error::Result> { - let provider = if cmd.options.is_empty() { + let table = if cmd.options.is_empty() { let table_url = ensure_table_uri(&cmd.to_owned().location)?; open_table(table_url).await? } else { let table_url = ensure_table_uri(&cmd.to_owned().location)?; open_table_with_storage_options(table_url, cmd.to_owned().options).await? }; - Ok(Arc::new(provider)) + Ok(Arc::new(table)) } } @@ -1605,9 +1605,10 @@ mod tests { .unwrap(); let datafusion = SessionContext::new(); - let table = Arc::new(table); - datafusion.register_table("snapshot", table).unwrap(); + datafusion + .register_table("snapshot", Arc::new(table)) + .unwrap(); let df = datafusion .sql("select * from snapshot where id > 10000 and id < 20000") diff --git a/crates/core/src/delta_datafusion/session.rs b/crates/core/src/delta_datafusion/session.rs index 9a55c20d9c..6216dd1044 100644 --- a/crates/core/src/delta_datafusion/session.rs +++ b/crates/core/src/delta_datafusion/session.rs @@ -19,6 +19,20 @@ pub fn create_session() -> DeltaSessionContext { DeltaSessionContext::default() } +#[cfg(test)] +pub fn create_test_session() -> DeltaSessionContext { + use std::sync::Arc; + + use object_store::memory::InMemory; + + let session = DeltaSessionContext::default(); + session.inner.runtime_env().register_object_store( + &url::Url::parse("memory:///").unwrap(), + Arc::new(InMemory::new()), + ); + session +} + // Given a `Session` reference, get the concrete `SessionState` reference // Note: this may stop working in future versions, #[deprecated( diff --git a/crates/core/src/delta_datafusion/table_provider.rs b/crates/core/src/delta_datafusion/table_provider.rs index d29e22bbc3..9dd0096b66 100644 --- a/crates/core/src/delta_datafusion/table_provider.rs +++ b/crates/core/src/delta_datafusion/table_provider.rs @@ -43,8 +43,10 @@ use datafusion::{ prelude::Expr, scalar::ScalarValue, }; +use delta_kernel::Version; use delta_kernel::table_properties::DataSkippingNumIndexedCols; use futures::StreamExt as _; +use futures::future::BoxFuture; use itertools::Itertools; use object_store::ObjectMeta; use serde::{Deserialize, Serialize}; @@ -57,13 +59,16 @@ use crate::delta_datafusion::{ }; use crate::kernel::schema::cast::cast_record_batch; use crate::kernel::transaction::{CommitBuilder, PROTOCOL}; -use crate::kernel::{Action, Add, EagerSnapshot, Remove}; +use crate::kernel::{Action, Add, EagerSnapshot, Remove, resolve_snapshot}; +use crate::logstore::LogStore; use crate::operations::write::WriterStatsConfig; use crate::operations::write::writer::{DeltaWriter, WriterConfig}; use crate::protocol::{DeltaOperation, SaveMode}; use crate::table::normalize_table_url; use crate::{DeltaResult, DeltaTable, DeltaTableError, logstore::LogStoreRef}; +pub(crate) mod next; + const PATH_COLUMN: &str = "__delta_rs_path"; /// DataSink implementation for delta lake @@ -700,6 +705,64 @@ impl<'a> DeltaScanBuilder<'a> { } } +pub struct TableProviderBuilder { + log_store: Arc, + snapshot: Option, + file_column: Option, + table_version: Option, +} + +impl TableProviderBuilder { + fn new(log_store: Arc, snapshot: Option) -> Self { + Self { + log_store, + snapshot, + file_column: None, + table_version: None, + } + } + + /// Specify the version of the table to provide + pub fn with_table_version(mut self, version: 
impl Into>) -> Self { + self.table_version = version.into(); + self + } + + pub fn with_file_column(mut self, file_column: String) -> Self { + self.file_column = Some(file_column); + self + } +} + +impl std::future::IntoFuture for TableProviderBuilder { + type Output = Result>; + type IntoFuture = BoxFuture<'static, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + let this = self; + + Box::pin(async move { + let snapshot = + resolve_snapshot(&this.log_store, this.snapshot, false, this.table_version).await?; + Ok(Arc::new(snapshot) as Arc) + }) + } +} + +impl DeltaTable { + /// Get a table provider for the table referenced by this DeltaTable. + /// + /// See [`TableProviderBuilder`] for options when building the provider. + pub fn table_provider(&self) -> TableProviderBuilder { + TableProviderBuilder::new( + self.log_store(), + self.snapshot() + .ok() + .map(|snapshot| snapshot.snapshot().clone()), + ) + } +} + // TODO: implement this for Snapshot, not for DeltaTable since DeltaTable has unknown load state. // the unwraps in the schema method are a dead giveaway .. #[async_trait::async_trait] diff --git a/crates/core/src/delta_datafusion/table_provider/next/mod.rs b/crates/core/src/delta_datafusion/table_provider/next/mod.rs new file mode 100644 index 0000000000..adec9d7654 --- /dev/null +++ b/crates/core/src/delta_datafusion/table_provider/next/mod.rs @@ -0,0 +1,572 @@ +//! Datafusion TableProvider implementation for Delta Lake tables. +//! +//! This module provides an implementation of the DataFusion `TableProvider` trait +//! for Delta Lake tables, allowing seamless integration with DataFusion's query engine. +//! +//!
+//! +//! The table provider is based on Snapshots of a Delta Table. Therefore, it represents +//! a static view of the table at a specific point in time. Changes to the underlying +//! Delta Table after the snapshot was taken will not be reflected in queries executed against this provider. +//! +//! To work with a dynamic view of the table that reflects ongoing changes, consider using +//! the catalog abstractions in this crate, which provide a higher-level interface for managing +//! Delta Tables within DataFusion sessions. +//! +//!
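+//! A minimal usage sketch (assuming the object store backing the table is already
+//! registered with the session's runtime environment; this mirrors the tests at the
+//! end of this module):
+//!
+//! ```ignore
+//! use std::sync::Arc;
+//!
+//! use datafusion::prelude::SessionContext;
+//! use deltalake_core::kernel::Snapshot;
+//! use deltalake_core::logstore::LogStoreRef;
+//!
+//! # async fn example(log_store: LogStoreRef) -> Result<(), Box<dyn std::error::Error>> {
+//! // A `Snapshot` implements `TableProvider`, so it can be registered like any other table.
+//! let snapshot = Arc::new(Snapshot::try_new(&log_store, Default::default(), None).await?);
+//! let ctx = SessionContext::new();
+//! ctx.register_table("delta_table", snapshot)?;
+//! let batches = ctx.sql("SELECT * FROM delta_table").await?.collect().await?;
+//! # Ok(())
+//! # }
+//! ```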
+//! +//! The TableProvider provides converts the planning information from delta-kernel into +//! DataFusion's physical plan representation. It supports filter pushdown and projection +//! pushdown to optimize data access. The actual reading of data files is delegated to +//! DataFusion's existing Parquet reader implementations, with additional handling for +//! Delta Lake-specific features. +//! +//! +use std::any::Any; +use std::pin::Pin; +use std::{borrow::Cow, sync::Arc}; + +use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow_schema::Schema; +use dashmap::DashMap; +use datafusion::catalog::memory::DataSourceExec; +use datafusion::common::{DataFusionError, HashMap, HashSet, Result}; +use datafusion::datasource::TableType; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::parquet::DefaultParquetFileReaderFactory; +use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::logical_expr::TableProviderFilterPushDown; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; +use datafusion::physical_plan::union::UnionExec; +use datafusion::prelude::Expr; +use datafusion::scalar::ScalarValue; +use datafusion::{ + catalog::{Session, TableProvider}, + logical_expr::{LogicalPlan, dml::InsertOp}, + physical_plan::ExecutionPlan, +}; +use delta_kernel::Engine; +use delta_kernel::engine::arrow_conversion::{TryIntoArrow, TryIntoKernel}; +use delta_kernel::scan::ScanMetadata; +use delta_kernel::schema::SchemaRef as KernelSchemaRef; +use futures::future::ready; +use futures::{Stream, TryStreamExt as _}; +use itertools::Itertools; +use object_store::path::Path; + +use crate::DeltaTableError; +use crate::delta_datafusion::DataFusionMixins as _; +use crate::delta_datafusion::engine::{ + AsObjectStoreUrl as _, DataFusionEngine, to_delta_predicate, +}; +use crate::delta_datafusion::table_provider::get_pushdown_filters; +use crate::delta_datafusion::table_provider::next::replay::{ScanFileContext, ScanFileStream}; +pub use crate::delta_datafusion::table_provider::next::scan::DeltaScanExec; +use crate::kernel::{EagerSnapshot, Scan, Snapshot}; + +mod replay; +mod scan; + +pub type ScanMetadataStream = + Pin> + Send>>; + +#[async_trait::async_trait] +impl TableProvider for Snapshot { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.read_schema() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + fn get_table_definition(&self) -> Option<&str> { + None + } + + fn get_logical_plan(&self) -> Option> { + None + } + + async fn scan( + &self, + session: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let (scan, drop_count) = self.kernel_scan(projection, filters)?; + let engine = DataFusionEngine::new_from_session(session); + let stream = scan.scan_metadata(engine.clone()); + self.execution_plan(session, scan, stream, engine, limit, drop_count) + .await + } + + fn supports_filters_pushdown( + &self, + filter: &[&Expr], + ) -> Result> { + Ok(get_pushdown_filters( + filter, + self.metadata().partition_columns(), + )) + } + + /// Insert the data into the delta table + /// Insert operation is only supported for Append and Overwrite + /// Return the execution plan + async fn insert_into( + &self, + _state: &dyn Session, + _input: Arc, + _insert_op: InsertOp, + ) -> Result> { + todo!("Implement insert_into method") + } +} + 
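+// Illustrative sketch (not part of this change): driving `TableProvider::scan` on a
+// `Snapshot` directly with a projection and a filter, mirroring the integration tests
+// for this provider. The helper name `example_scan_with_filter` and the column name
+// `id` are hypothetical; `scan` and `collect_partitioned` are the same APIs exercised
+// by the tests at the end of this module.
+#[allow(dead_code)]
+async fn example_scan_with_filter(
+    snapshot: &Snapshot,
+    session: &datafusion::prelude::SessionContext,
+) -> Result<Vec<arrow::array::RecordBatch>> {
+    use datafusion::physical_plan::collect_partitioned;
+    use datafusion::prelude::{col, lit};
+
+    // The filter is translated into a kernel predicate and used for file skipping
+    // (see `kernel_scan`); exact (partition) filter columns missing from the
+    // projection are added internally and dropped again by `DeltaScanExec`.
+    let filter = col("id").gt(lit(5_i64));
+    let plan = snapshot
+        .scan(&session.state(), Some(&vec![0]), &[filter], None)
+        .await?;
+    // Execute all partitions and flatten the per-partition record batches.
+    let batches = collect_partitioned(plan, session.task_ctx())
+        .await?
+        .into_iter()
+        .flatten()
+        .collect();
+    Ok(batches)
+}
+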
+#[async_trait::async_trait] +impl TableProvider for EagerSnapshot { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + TableProvider::schema(self.snapshot()) + } + + fn table_type(&self) -> TableType { + self.snapshot().table_type() + } + + fn get_table_definition(&self) -> Option<&str> { + self.snapshot().get_table_definition() + } + + fn get_logical_plan(&self) -> Option> { + self.snapshot().get_logical_plan() + } + + async fn scan( + &self, + session: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let (scan, drop_count) = self.snapshot().kernel_scan(projection, filters)?; + let engine = DataFusionEngine::new_from_session(session); + let stream = if let Ok(files) = self.files() { + scan.scan_metadata_from( + engine.clone(), + self.snapshot().version() as u64, + Box::new(files.to_vec().into_iter()), + None, + ) + } else { + scan.scan_metadata(engine.clone()) + }; + self.snapshot() + .execution_plan(session, scan, stream, engine, limit, drop_count) + .await + } + + fn supports_filters_pushdown( + &self, + filter: &[&Expr], + ) -> Result> { + self.snapshot().supports_filters_pushdown(filter) + } + + /// Insert the data into the delta table + /// Insert operation is only supported for Append and Overwrite + /// Return the execution plan + async fn insert_into( + &self, + state: &dyn Session, + input: Arc, + insert_op: InsertOp, + ) -> Result> { + self.snapshot().insert_into(state, input, insert_op).await + } +} + +impl Snapshot { + /// Create a kernel scan with the given projection and filters. + /// + /// The projection is adjusted to include columns referenced in the filters which are not + /// part of the original projection. The consumer of the generated Scan is responsible to + /// project the final output schema after the scan. + fn kernel_scan( + &self, + projection: Option<&Vec>, + filters: &[Expr], + ) -> Result<(Arc, usize)> { + let pushdowns = get_pushdown_filters( + &filters.iter().collect::>(), + self.metadata().partition_columns().as_slice(), + ); + let mut project_away = 0usize; + let projection = projection + .map(|p| { + let mut projection = p.clone(); + let mut column_refs = HashSet::new(); + for (f, pd) in filters.iter().zip(pushdowns.iter()) { + match pd { + TableProviderFilterPushDown::Exact => { + for col in f.column_refs() { + column_refs.insert(col); + } + } + // TODO: Inexact and Unknown pushdowns should always be included in the projection, + // otherwise the upstream filters cannot be applied correctly. we could validate this + // here, but for now we assume datafusion handles this correctly. An error is raised + // later in the execution plan if this is not the case. 
+ _ => (), + } + } + for col in column_refs { + let col_idx = self.read_schema().index_of(col.name())?; + if !projection.contains(&col_idx) { + projection.push(col_idx); + project_away += 1; + } + } + Ok::<_, DataFusionError>(projection) + }) + .transpose()?; + + let (_, projected_kernel_schema) = project_schema(self.read_schema(), projection.as_ref())?; + Ok(( + Arc::new( + self.scan_builder() + .with_schema(projected_kernel_schema.clone()) + .with_predicate(Arc::new(to_delta_predicate(filters)?)) + .build()?, + ), + project_away, + )) + } + + async fn execution_plan( + &self, + session: &dyn Session, + scan: Arc, + stream: ScanMetadataStream, + engine: Arc, + limit: Option, + drop_count: usize, + ) -> Result> { + let mut stream = ScanFileStream::new(engine, &scan, stream); + let mut files = Vec::new(); + while let Some(file) = stream.try_next().await? { + files.extend(file); + } + + let transforms: HashMap<_, _> = files + .iter_mut() + .flat_map(|file| { + file.transform + .take() + .map(|t| (file.file_url.to_string(), t)) + }) + .collect(); + let dv_stream = stream.dv_stream.build(); + let dvs: DashMap<_, _> = dv_stream + .try_filter_map(|(url, dv)| ready(Ok(dv.map(|dv| (url.to_string(), dv))))) + .try_collect() + .await?; + + let metrics = ExecutionPlanMetricsSet::new(); + MetricBuilder::new(&metrics) + .global_counter("count_files_skipped") + .add(stream.metrics.num_skipped); + MetricBuilder::new(&metrics) + .global_counter("count_files_scanned") + .add(stream.metrics.num_scanned); + + let file_id_column = "__delta_rs_file_id".to_string(); + + // Convert the files into datafusions `PartitionedFile`s grouped by the object store they are stored in + // this is used to create a DataSourceExec plan for each store + // To correlate the data with the original file, we add the file url as a partition value + // This is required to apply the correct transform to the data in downstream processing. + let to_partitioned_file = |f: ScanFileContext| { + let file_path = Path::from_url_path(f.file_url.path())?; + let mut partitioned_file = PartitionedFile::new(file_path.to_string(), f.size) + .with_statistics(Arc::new(f.stats)); + partitioned_file.partition_values = + vec![ScalarValue::Utf8(Some(f.file_url.to_string()))]; + // NB: we need to reassign the location since the 'new' method does incompatible path encoding internally. + partitioned_file.object_meta.location = file_path; + Ok::<_, DataFusionError>(( + f.file_url.as_object_store_url(), + (partitioned_file, None::>), + )) + }; + + // Group the files by their object store url. Since datafusion assumes that all files in a + // DataSourceExec are stored in the same object store, we need to create one plan per store + let files_by_store = files + .into_iter() + .map(to_partitioned_file) + .try_collect::<_, Vec<_>, _>()? + .into_iter() + .into_group_map(); + + let physical_schema: SchemaRef = + Arc::new(scan.physical_schema().as_ref().try_into_arrow()?); + + let pq_plan = get_read_plan( + files_by_store, + &physical_schema, + session, + limit, + Field::new( + file_id_column.clone(), + DataType::Dictionary(DataType::UInt16.into(), DataType::Utf8.into()), + false, + ), + &metrics, + ) + .await?; + + // we collect the logical fields from the read schema to ensure + // we consider the correct physical types as well (e.g. for StringView etc.) 
+ let logical_fields = scan + .logical_schema() + .fields() + .flat_map(|f| { + self.read_schema() + .column_with_name(f.name()) + .map(|(_, c)| c.clone()) + }) + .collect_vec(); + let projected_arrow_schema = Arc::new(Schema::new(logical_fields)); + let exec = DeltaScanExec::new( + projected_arrow_schema, + scan.logical_schema().clone(), + pq_plan, + Arc::new(transforms), + Arc::new(dvs), + file_id_column, + metrics, + drop_count, + ); + + Ok(Arc::new(exec)) + } +} + +type FilesByStore = (ObjectStoreUrl, Vec<(PartitionedFile, Option>)>); +async fn get_read_plan( + files_by_store: impl IntoIterator, + physical_schema: &SchemaRef, + state: &dyn Session, + limit: Option, + file_id_field: Field, + _metrics: &ExecutionPlanMetricsSet, +) -> Result> { + // TODO: update parquet source. + let source = ParquetSource::default(); + + let mut plans = Vec::new(); + + for (store_url, files) in files_by_store.into_iter() { + // state.ensure_object_store(store_url.as_ref()).await?; + + let store = state.runtime_env().object_store(&store_url)?; + let _reader_factory = source + .parquet_file_reader_factory() + .cloned() + .unwrap_or_else(|| Arc::new(DefaultParquetFileReaderFactory::new(store))); + + // let file_group = compute_parquet_access_plans(&reader_factory, files, &metrics).await?; + let file_group = files.into_iter().map(|file| file.0); + + // TODO: convert passed predicate to an expression in terms of physical columns + // and add it to the FileScanConfig + // let file_source = + // source.with_schema_adapter_factory(Arc::new(NestedSchemaAdapterFactory))?; + let file_source = Arc::new(source.clone()); + let config = FileScanConfigBuilder::new(store_url, physical_schema.clone(), file_source) + .with_file_group(file_group.into_iter().collect()) + .with_table_partition_cols(vec![file_id_field.clone()]) + .with_limit(limit) + .build(); + let plan: Arc = DataSourceExec::from_data_source(config); + plans.push(plan); + } + + let plan = match plans.len() { + 1 => plans.remove(0), + _ => UnionExec::try_new(plans)?, + }; + Ok(match plan.with_fetch(limit) { + Some(limit) => limit, + None => plan, + }) +} + +fn project_schema( + schema: SchemaRef, + projection: Option<&Vec>, +) -> Result<(SchemaRef, KernelSchemaRef)> { + let projected_arrow_schema = match projection { + Some(p) => Arc::new(schema.project(p)?), + None => schema, + }; + let projected_kernel_schema: KernelSchemaRef = Arc::new( + projected_arrow_schema + .as_ref() + .try_into_kernel() + .map_err(DeltaTableError::from)?, + ); + Ok((projected_arrow_schema, projected_kernel_schema)) +} + +#[cfg(test)] +mod tests { + use datafusion::{ + datasource::{physical_plan::FileScanConfig, source::DataSource}, + physical_plan::{ExecutionPlanVisitor, collect_partitioned, visit_execution_plan}, + }; + + use crate::{ + assert_batches_sorted_eq, + delta_datafusion::session::create_test_session, + kernel::Snapshot, + test_utils::{TestResult, TestTables}, + }; + + use super::*; + + /// Extracts fields from the parquet scan + #[derive(Default)] + struct DeltaScanVisitor { + num_skipped: Option, + num_scanned: Option, + total_bytes_scanned: Option, + } + + impl DeltaScanVisitor { + fn pre_visit_delta_scan( + &mut self, + delta_scan_exec: &DeltaScanExec, + ) -> Result { + let Some(metrics) = delta_scan_exec.metrics() else { + return Ok(true); + }; + + self.num_skipped = metrics + .sum_by_name("count_files_skipped") + .map(|v| v.as_usize()); + self.num_scanned = metrics + .sum_by_name("count_files_scanned") + .map(|v| v.as_usize()); + + Ok(true) + } + + fn 
pre_visit_data_source( + &mut self, + datasource_exec: &DataSourceExec, + ) -> Result { + let Some(scan_config) = datasource_exec + .data_source() + .as_any() + .downcast_ref::() + else { + return Ok(true); + }; + + let pq_metrics = scan_config + .metrics() + .clone_inner() + .sum_by_name("bytes_scanned"); + self.total_bytes_scanned = pq_metrics.map(|v| v.as_usize()); + + // if let Some(parquet_source) = scan_config + // .file_source + // .as_any() + // .downcast_ref::() + // { + // parquet_source + // } + + Ok(true) + } + } + + impl ExecutionPlanVisitor for DeltaScanVisitor { + type Error = DataFusionError; + + fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { + if let Some(delta_scan_exec) = plan.as_any().downcast_ref::() { + return self.pre_visit_delta_scan(delta_scan_exec); + }; + + if let Some(datasource_exec) = plan.as_any().downcast_ref::() { + return self.pre_visit_data_source(datasource_exec); + } + + Ok(true) + } + } + + #[tokio::test] + async fn test_query_simple_table() -> TestResult { + let log_store = TestTables::Simple.table_builder()?.build_storage()?; + let snapshot = Arc::new(Snapshot::try_new(&log_store, Default::default(), None).await?); + + let session = Arc::new(create_test_session().into_inner()); + session.register_table("delta_table", snapshot).unwrap(); + + let df = session.sql("SELECT * FROM delta_table").await.unwrap(); + let batches = df.collect().await?; + + let expected = vec![ + "+----+", "| id |", "+----+", "| 5 |", "| 7 |", "| 9 |", "+----+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn test_scan_simple_table() -> TestResult { + let log_store = TestTables::Simple.table_builder()?.build_storage()?; + let snapshot = Snapshot::try_new(&log_store, Default::default(), None).await?; + + let session = Arc::new(create_test_session().into_inner()); + let state = session.state_ref().read().clone(); + + let plan = snapshot.scan(&state, None, &[], None).await?; + + let batches: Vec<_> = collect_partitioned(plan.clone(), session.task_ctx()) + .await? 
+ .into_iter() + .flatten() + .collect(); + + let mut visitor = DeltaScanVisitor::default(); + visit_execution_plan(plan.as_ref(), &mut visitor).unwrap(); + + assert_eq!(visitor.num_scanned, Some(5)); + assert_eq!(visitor.num_skipped, Some(28)); + assert_eq!(visitor.total_bytes_scanned, Some(231)); + + let expected = vec![ + "+----+", "| id |", "+----+", "| 5 |", "| 7 |", "| 9 |", "+----+", + ]; + + assert_batches_sorted_eq!(&expected, &batches); + + Ok(()) + } +} diff --git a/crates/core/src/delta_datafusion/table_provider/next/replay.rs b/crates/core/src/delta_datafusion/table_provider/next/replay.rs new file mode 100644 index 0000000000..83989a8336 --- /dev/null +++ b/crates/core/src/delta_datafusion/table_provider/next/replay.rs @@ -0,0 +1,342 @@ +use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use arrow::{array::BooleanArray, compute::filter_record_batch}; +use datafusion::{ + common::{ + ColumnStatistics, HashMap, Statistics, error::DataFusionErrorBuilder, stats::Precision, + }, + error::DataFusionError, + scalar::ScalarValue, +}; +use delta_kernel::{ + Engine, ExpressionRef, + engine::{arrow_conversion::TryIntoArrow, arrow_data::ArrowEngineData}, + expressions::{Scalar, StructData}, + scan::{ + Scan as KernelScan, ScanMetadata, + state::{DvInfo, ScanFile}, + }, +}; +use futures::Stream; +use itertools::Itertools; +use pin_project_lite::pin_project; +use url::Url; + +use crate::{ + DeltaResult, + delta_datafusion::engine::scalar_to_df, + kernel::{ + LogicalFileView, ReceiverStreamBuilder, Scan, StructDataExt, + arrow::engine_ext::stats_schema, parse_stats_column_with_schema, + }, +}; + +#[derive(Debug)] +pub(crate) struct ReplayStats { + pub(crate) num_skipped: usize, + pub(crate) num_scanned: usize, +} + +impl ReplayStats { + fn new() -> Self { + Self { + num_skipped: 0, + num_scanned: 0, + } + } +} + +pin_project! { + /// Stream to read scan file contexts from a scan metadata stream. + pub(crate) struct ScanFileStream { + pub(crate) metrics: ReplayStats, + + engine: Arc, + + table_root: Url, + + kernel_scan: Arc, + + pub(crate) dv_stream: ReceiverStreamBuilder<(Url, Option>)>, + + #[pin] + stream: S, + } +} + +impl ScanFileStream { + pub(crate) fn new(engine: Arc, scan: &Arc, stream: S) -> Self { + Self { + metrics: ReplayStats::new(), + dv_stream: ReceiverStreamBuilder::<(Url, Option>)>::new(100), + engine, + table_root: scan.table_root().clone(), + kernel_scan: scan.inner.clone(), + stream, + } + } +} + +impl Stream for ScanFileStream +where + S: Stream>, +{ + type Item = DeltaResult>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + let physical_arrow = this + .kernel_scan + .physical_schema() + .as_ref() + .try_into_arrow() + .unwrap(); + match this.stream.poll_next(cx) { + Poll::Ready(Some(Ok(scan_data))) => { + let mut ctx = ScanContext::new(this.table_root.clone()); + ctx = match scan_data.visit_scan_files(ctx, visit_scan_file) { + Ok(ctx) => ctx, + Err(err) => return Poll::Ready(Some(Err(err.into()))), + }; + + // Spawn tasks to read the deletion vectors from disk. 
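+                // Each load runs on the blocking pool via `ReceiverStreamBuilder::spawn_blocking`;
+                // the resulting selection vectors are sent through `dv_stream`, keyed by file
+                // URL, and later collected into a map in `execution_plan` so that
+                // `DeltaScanExec` can apply them per file.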
+ for file in &ctx.files { + let engine = this.engine.clone(); + let dv_info = file.dv_info.clone(); + let file_url = file.file_url.clone(); + let table_root = this.table_root.clone(); + let tx = this.dv_stream.tx(); + if dv_info.has_vector() { + let load_dv = move || { + let dv = dv_info.get_selection_vector(engine.as_ref(), &table_root)?; + let _ = tx.blocking_send(Ok((file_url, dv))); + Ok(()) + }; + this.dv_stream.spawn_blocking(load_dv); + } + } + + this.metrics.num_scanned += ctx.count; + this.metrics.num_skipped += scan_data + .scan_files + .selection_vector() + .len() + .saturating_sub(ctx.count); + + let (data, selection_vector) = scan_data.scan_files.into_parts(); + let batch = ArrowEngineData::try_from_engine_data(data)?.into(); + let scan_files = + filter_record_batch(&batch, &BooleanArray::from(selection_vector))?; + + let stats_schema = Arc::new(stats_schema( + this.kernel_scan.physical_schema(), + this.kernel_scan.snapshot().table_properties(), + )); + let parsed_stats = parse_stats_column_with_schema( + this.kernel_scan.snapshot().as_ref(), + &scan_files, + stats_schema, + )?; + + let mut file_statistics = extract_file_statistics(this.kernel_scan, parsed_stats); + + Poll::Ready(Some(Ok(ctx + .files + .into_iter() + .map(|ctx| { + let stats = file_statistics + .remove(&ctx.file_url) + .unwrap_or_else(|| Statistics::new_unknown(&physical_arrow)); + ScanFileContext::new(ctx, stats) + }) + .collect_vec()))) + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), + Poll::Pending => Poll::Pending, + } + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} + +fn extract_file_statistics( + scan: &KernelScan, + parsed_stats: arrow_array::RecordBatch, +) -> HashMap { + (0..parsed_stats.num_rows()) + .map(move |idx| LogicalFileView::new(parsed_stats.clone(), idx)) + .filter_map(|view| { + let num_rows = view + .num_records() + .map(Precision::Exact) + .unwrap_or(Precision::Absent); + let total_byte_size = Precision::Exact(view.size() as usize); + + let null_counts = extract_struct(view.null_counts()); + let max_values = extract_struct(view.max_values()); + let min_values = extract_struct(view.min_values()); + + let column_statistics = scan + .physical_schema() + .fields() + .map(|f| { + let null_count = if let Some(field_index) = + null_counts.as_ref().and_then(|v| v.index_of(f.name())) + { + null_counts + .as_ref() + .map(|v| match v.values()[field_index] { + Scalar::Integer(int_val) => Precision::Exact(int_val as usize), + Scalar::Long(long_val) => Precision::Exact(long_val as usize), + _ => Precision::Absent, + }) + .unwrap_or_default() + } else { + Precision::Absent + }; + + let max_value = extract_precision(&max_values, f.name()); + let min_value = extract_precision(&min_values, f.name()); + + ColumnStatistics { + null_count, + max_value, + min_value, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + } + }) + .collect_vec(); + + Some(( + parse_path(scan.snapshot().table_root(), view.path().as_ref()).ok()?, + Statistics { + num_rows, + total_byte_size, + column_statistics, + }, + )) + }) + .collect() +} + +fn extract_precision(data: &Option, name: impl AsRef) -> Precision { + if let Some(field_index) = data.as_ref().and_then(|v| v.index_of(name.as_ref())) { + data.as_ref() + .map(|v| match scalar_to_df(&v.values()[field_index]) { + Ok(df) => Precision::Exact(df), + _ => Precision::Absent, + }) + .unwrap_or_default() + } else { + Precision::Absent + } +} + +fn extract_struct(scalar: 
Option) -> Option { + match scalar { + Some(Scalar::Struct(data)) => Some(data), + _ => None, + } +} + +#[derive(Debug)] +pub struct ScanFileContext { + /// Fully qualified URL of the file. + pub file_url: Url, + /// Size of the file on disk. + pub size: u64, + /// Selection vector to filter the data in the file. + // pub selection_vector: Option>, + /// Transformations to apply to the data in the file. + pub transform: Option, + /// Statistics about the data in the file. + /// + /// The query engine may choose to use these statistics to further optimize the scan. + pub stats: Statistics, +} + +impl ScanFileContext { + /// Create a new `ScanFileContext` with the given file URL, size, and statistics. + fn new(inner: ScanFileContextInner, stats: Statistics) -> Self { + Self { + file_url: inner.file_url, + size: inner.size, + transform: inner.transform, + stats, + } + } +} + +/// Metadata to read a data file from object storage. +struct ScanFileContextInner { + /// Fully qualified URL of the file. + pub file_url: Url, + /// Size of the file on disk. + pub size: u64, + /// Selection vector to filter the data in the file. + // pub selection_vector: Option>, + /// Transformations to apply to the data in the file. + pub transform: Option, + + pub dv_info: DvInfo, +} + +struct ScanContext { + /// Table root URL + table_root: Url, + /// Files to be scanned. + files: Vec, + /// Errors encountered during the scan. + errs: DataFusionErrorBuilder, + count: usize, +} + +impl ScanContext { + fn new(table_root: Url) -> Self { + Self { + table_root, + files: Vec::new(), + errs: DataFusionErrorBuilder::new(), + count: 0, + } + } + + fn parse_path(&self, path: &str) -> DeltaResult { + parse_path(&self.table_root, path) + } +} + +fn parse_path(url: &Url, path: &str) -> DeltaResult { + Ok(match Url::parse(path) { + Ok(url) => url, + Err(_) => url + .join(path) + .map_err(|e| DataFusionError::External(Box::new(e)))?, + }) +} + +fn visit_scan_file(ctx: &mut ScanContext, scan_file: ScanFile) { + let file_url = match ctx.parse_path(&scan_file.path) { + Ok(v) => v, + Err(e) => { + ctx.errs.add_error(e); + return; + } + }; + + ctx.files.push(ScanFileContextInner { + dv_info: scan_file.dv_info, + transform: scan_file.transform, + file_url, + size: scan_file.size as u64, + }); + ctx.count += 1; +} diff --git a/crates/core/src/delta_datafusion/table_provider/next/scan.rs b/crates/core/src/delta_datafusion/table_provider/next/scan.rs new file mode 100644 index 0000000000..15ae7beae8 --- /dev/null +++ b/crates/core/src/delta_datafusion/table_provider/next/scan.rs @@ -0,0 +1,369 @@ +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::array::{ArrayAccessor, AsArray, RecordBatch, StringArray}; +use arrow::compute::filter_record_batch; +use arrow::datatypes::{SchemaRef, UInt16Type}; +use arrow_array::BooleanArray; +use arrow_schema::Schema; +use dashmap::DashMap; +use datafusion::common::HashMap; +use datafusion::common::config::ConfigOptions; +use datafusion::common::error::{DataFusionError, Result}; +use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::execution_plan::{CardinalityEffect, PlanProperties}; +use datafusion::physical_plan::filter_pushdown::{FilterDescription, FilterPushdownPhase}; +use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use datafusion::physical_plan::{ + DisplayAs, 
DisplayFormatType, ExecutionPlan, PhysicalExpr, Statistics, +}; +use delta_kernel::engine::arrow_conversion::TryIntoKernel; +use delta_kernel::schema::{DataType as KernelDataType, SchemaRef as KernelSchemaRef}; +use delta_kernel::{EvaluationHandler, ExpressionRef}; +use futures::stream::{Stream, StreamExt}; +use itertools::Itertools; + +use crate::cast_record_batch; +use crate::kernel::ARROW_HANDLER; +use crate::kernel::arrow::engine_ext::ExpressionEvaluatorExt; + +#[derive(Clone, Debug)] +pub struct DeltaScanExec { + /// Output schema for processed data. + kernel_logical_schema: KernelSchemaRef, + /// Execution plan yielding the raw data read from data files. + input: Arc, + /// Transforms to be applied to data eminating from individual files + transforms: Arc>, + /// Deletion vectors for the table + selection_vectors: Arc>>, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Column name for the file id + file_id_column: String, + + /// plan properties + properties: PlanProperties, + drop_count: usize, +} + +impl DisplayAs for DeltaScanExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + // TODO: actually implement formatting according to the type + match t { + DisplayFormatType::Default + | DisplayFormatType::Verbose + | DisplayFormatType::TreeRender => { + write!(f, "DeltaScanExec: file_id_column={}", self.file_id_column) + } + } + } +} + +impl DeltaScanExec { + pub(crate) fn new( + logical_schema: SchemaRef, + kernel_logical_schema: KernelSchemaRef, + input: Arc, + transforms: Arc>, + selection_vectors: Arc>>, + file_id_column: String, + metrics: ExecutionPlanMetricsSet, + drop_count: usize, + ) -> Self { + let max_idx = logical_schema.fields().len().saturating_sub(drop_count); + let output_fields = logical_schema + .fields() + .into_iter() + .enumerate() + .filter_map(|(i, f)| if i < max_idx { Some(f.clone()) } else { None }); + let logical_schema = Arc::new(Schema::new(output_fields.collect_vec())); + let properties = PlanProperties::new( + EquivalenceProperties::new(logical_schema), + input.properties().partitioning.clone(), + input.properties().emission_type, + input.properties().boundedness, + ); + Self { + kernel_logical_schema, + input, + transforms, + selection_vectors, + metrics, + file_id_column, + properties, + drop_count, + } + } +} + +impl ExecutionPlan for DeltaScanExec { + fn name(&self) -> &'static str { + "DeltaScanExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + // fn maintains_input_order(&self) -> Vec { + // // Tell optimizer this operator doesn't reorder its input + // vec![true] + // } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() != 1 { + return Err(DataFusionError::Plan(format!( + "DeltaScan: wrong number of children {}", + children.len() + ))); + } + Ok(Arc::new(Self::new( + self.properties.eq_properties.schema().clone(), + self.kernel_logical_schema.clone(), + children[0].clone(), + self.transforms.clone(), + self.selection_vectors.clone(), + self.file_id_column.clone(), + self.metrics.clone(), + self.drop_count, + ))) + } + + fn repartitioned( + &self, + target_partitions: usize, + config: &ConfigOptions, + ) -> Result>> { + if let Some(input) = self.input.repartitioned(target_partitions, config)? 
{ + Ok(Some(Arc::new(Self { + input, + ..self.clone() + }))) + } else { + Ok(None) + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + Ok(Box::pin(DeltaScanStream { + schema: Arc::clone(self.properties.eq_properties.schema()), + kernel_type: Arc::clone(&self.kernel_logical_schema).into(), + input: self.input.execute(partition, context)?, + kernel_input_schema: Arc::new(self.input.schema().as_ref().try_into_kernel()?), + baseline_metrics: BaselineMetrics::new(&self.metrics, partition), + transforms: Arc::clone(&self.transforms), + selection_vectors: Arc::clone(&self.selection_vectors), + file_id_column: self.file_id_column.clone(), + drop_count: self.drop_count, + })) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + // self.input.partition_statistics(None) + Ok(Statistics::new_unknown(self.schema().as_ref())) + } + + fn supports_limit_pushdown(&self) -> bool { + true + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } + + fn fetch(&self) -> Option { + self.input.fetch() + } + + fn with_fetch(&self, limit: Option) -> Option> { + if let Some(new_input) = self.input.with_fetch(limit) { + let mut new_plan = self.clone(); + new_plan.input = new_input; + Some(Arc::new(new_plan)) + } else { + None + } + } + + fn partition_statistics(&self, partition: Option) -> Result { + // TODO: handle statistics conversion properly to leverage parquet plan statistics. + // self.input.partition_statistics(partition) + Ok(Statistics::new_unknown(self.schema().as_ref())) + } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + // TODO(roeap): this will likely not do much for column mapping enabled tables + // since the default methods determines this based on existence of columns in child + // schemas. In the case of column mapping all columns will have a different name. + FilterDescription::from_children(parent_filters, &self.children()) + } +} + +/// Stream of RecordBatches produced by scanning a Delta table. +/// +/// The data returned by this stream represents the logical data caontained in the table. +/// This means all transformations according to the Delta protocol are applied. This includes: +/// - partition values +/// - column mapping to the logical schema +/// - deletion vectors +struct DeltaScanStream { + /// Output schema for processed data. + schema: SchemaRef, + /// Kernel data type for the data after transformations + kernel_type: KernelDataType, + /// Input stream yielding raw data read from data files. + input: SendableRecordBatchStream, + /// Kernel schema for the data before transformations + kernel_input_schema: KernelSchemaRef, + baseline_metrics: BaselineMetrics, + /// Transforms to be applied to data read from individual files + transforms: Arc>, + /// Selection vectors to be applied to data read from individual files + selection_vectors: Arc>>, + /// Column name for the file id + file_id_column: String, + drop_count: usize, +} + +impl DeltaScanStream { + /// Apply the per-file transformation to a RecordBatch. 
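+    ///
+    /// The steps mirror the per-file handling required by the Delta protocol: the
+    /// synthetic file-id partition column is removed, any deletion-vector selection
+    /// mask for the file is applied, the kernel transform (partition values, column
+    /// mapping) is evaluated, trailing filter-only columns are dropped, and the
+    /// result is cast to the logical output schema.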
+ fn batch_project(&mut self, mut batch: RecordBatch) -> Result { + let _timer = self.baseline_metrics.elapsed_compute().timer(); + + let (file_id, file_id_idx) = extract_file_id(&batch, &self.file_id_column)?; + batch.remove_column(file_id_idx); + + let selection = if let Some(mut selection_vector) = self.selection_vectors.get_mut(&file_id) + { + if selection_vector.len() >= batch.num_rows() { + let sv: Vec = selection_vector.drain(0..batch.num_rows()).collect(); + Some(sv) + } else { + let remaining = batch.num_rows() - selection_vector.len(); + let sel_len = selection_vector.len(); + let mut sv: Vec = selection_vector.drain(0..sel_len).collect(); + sv.extend(vec![true; remaining]); + Some(sv) + } + } else { + None + }; + + // NOTE: this case may occur e.g. in a COUNT(*) query where no columns are projected + if batch.num_columns() == 0 { + if let Some(selection) = selection { + let filtered_batch = filter_record_batch(&batch, &BooleanArray::from(selection))?; + return Ok(filtered_batch); + } + return Ok(batch); + } + + let batch = if let Some(selection) = selection { + filter_record_batch(&batch, &BooleanArray::from(selection))? + } else { + batch + }; + + let Some(transform) = self.transforms.get(&file_id) else { + let batch = RecordBatch::try_new(self.schema.clone(), batch.columns().to_vec())?; + return Ok(batch); + }; + + let evaluator = ARROW_HANDLER + .new_expression_evaluator( + self.kernel_input_schema.clone(), + transform.clone(), + self.kernel_type.clone(), + ) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + let mut result = evaluator + .evaluate_arrow(batch) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + for _ in 0..self.drop_count { + result.remove_column(result.num_columns() - 1); + } + + // TODO: all casting should be done in the expression evaluator + Ok(cast_record_batch(&result, self.schema.clone(), true, true)?) + } +} + +impl Stream for DeltaScanStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let poll = self.input.poll_next_unpin(cx).map(|x| match x { + Some(Ok(batch)) => Some(self.batch_project(batch)), + other => other, + }); + self.baseline_metrics.record_poll(poll) + } + + fn size_hint(&self) -> (usize, Option) { + self.input.size_hint() + } +} + +impl RecordBatchStream for DeltaScanStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +fn extract_file_id(batch: &RecordBatch, file_id_column: &str) -> Result<(String, usize)> { + let file_id_idx = batch + .schema_ref() + .fields() + .iter() + .position(|f| f.name() == file_id_column) + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Expected column '{}' to be present in the input", + file_id_column + )) + })?; + + let file_id = batch + .column(file_id_idx) + .as_dictionary::() + .downcast_dict::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Expected file id column to be a dictionary of strings" + )) + })? 
+ .value(0) + .to_string(); + + Ok((file_id, file_id_idx)) +} diff --git a/crates/core/src/kernel/snapshot/iterators.rs b/crates/core/src/kernel/snapshot/iterators.rs index 45b3ba848a..a517d8246e 100644 --- a/crates/core/src/kernel/snapshot/iterators.rs +++ b/crates/core/src/kernel/snapshot/iterators.rs @@ -20,6 +20,8 @@ use crate::kernel::scalars::ScalarExt; use crate::kernel::{Add, DeletionVectorDescriptor, Remove}; use crate::{DeltaResult, DeltaTableError}; +#[cfg(feature = "datafusion")] +pub(crate) use self::scan_row::parse_stats_column_with_schema; pub(crate) use self::scan_row::{ScanRowOutStream, scan_row_in_eval}; pub use self::tombstones::TombstoneView; diff --git a/crates/core/src/kernel/snapshot/iterators/scan_row.rs b/crates/core/src/kernel/snapshot/iterators/scan_row.rs index 2e4d8e7ebf..246c7655b5 100644 --- a/crates/core/src/kernel/snapshot/iterators/scan_row.rs +++ b/crates/core/src/kernel/snapshot/iterators/scan_row.rs @@ -13,8 +13,8 @@ use delta_kernel::engine::parse_json; use delta_kernel::expressions::Scalar; use delta_kernel::expressions::UnaryExpressionOp; use delta_kernel::scan::scan_row_schema; -use delta_kernel::schema::DataType; use delta_kernel::schema::PrimitiveType; +use delta_kernel::schema::{DataType, SchemaRef as KernelSchemaRef}; use delta_kernel::snapshot::Snapshot as KernelSnapshot; use delta_kernel::table_features::ColumnMappingMode; use delta_kernel::{EvaluationHandler, Expression, ExpressionEvaluator}; @@ -94,7 +94,19 @@ pub(crate) fn scan_row_in_eval( )?) } -fn parse_stats_column(sn: &KernelSnapshot, batch: &RecordBatch) -> DeltaResult { +pub(crate) fn parse_stats_column( + sn: &KernelSnapshot, + batch: &RecordBatch, +) -> DeltaResult { + let stats_schema = sn.stats_schema()?; + parse_stats_column_with_schema(sn, batch, stats_schema) +} + +pub(crate) fn parse_stats_column_with_schema( + sn: &KernelSnapshot, + batch: &RecordBatch, + stats_schema: KernelSchemaRef, +) -> DeltaResult { let Some((stats_idx, _)) = batch.schema_ref().column_with_name("stats") else { return Err(DeltaTableError::SchemaMismatch { msg: "stats column not found".to_string(), @@ -105,7 +117,6 @@ fn parse_stats_column(sn: &KernelSnapshot, batch: &RecordBatch) -> DeltaResult>) -> Self { self.inner = self.inner.with_predicate(predicate); self @@ -70,7 +71,7 @@ impl ScanBuilder { #[derive(Debug)] pub struct Scan { - inner: Arc, + pub(crate) inner: Arc, } impl From for Scan { diff --git a/crates/core/src/operations/load.rs b/crates/core/src/operations/load.rs index d864634449..0a9c562591 100644 --- a/crates/core/src/operations/load.rs +++ b/crates/core/src/operations/load.rs @@ -115,11 +115,13 @@ impl std::future::IntoFuture for LoadBuilder { #[cfg(test)] mod tests { + use crate::delta_datafusion::create_session; use crate::operations::collect_sendable_stream; use crate::writer::test_utils::{TestResult, get_record_batch}; use crate::{DeltaTable, DeltaTableBuilder}; use datafusion::assert_batches_sorted_eq; use std::path::Path; + use std::sync::Arc; use url::Url; #[tokio::test] @@ -157,7 +159,16 @@ mod tests { .write(vec![batch.clone()]) .await?; - let (_table, stream) = table.scan_table().await?; + let session = create_session().into_inner(); + session.runtime_env().register_object_store( + &url::Url::parse("memory:///")?, + table.log_store().object_store(None), + ); + + let (_table, stream) = table + .scan_table() + .with_session_state(Arc::new(session.state())) + .await?; let data = collect_sendable_stream(stream).await?; let expected = vec![ @@ -190,7 +201,17 @@ mod tests { 
.write(vec![batch.clone()]) .await?; - let (_table, stream) = table.scan_table().with_columns(["id", "value"]).await?; + let session = create_session().into_inner(); + session.runtime_env().register_object_store( + &url::Url::parse("memory:///")?, + table.log_store().object_store(None), + ); + + let (_table, stream) = table + .scan_table() + .with_columns(["id", "value"]) + .with_session_state(Arc::new(session.state())) + .await?; let data = collect_sendable_stream(stream).await?; let expected = vec![ diff --git a/crates/core/src/operations/vacuum.rs b/crates/core/src/operations/vacuum.rs index 8b81f4673e..9b143787d5 100644 --- a/crates/core/src/operations/vacuum.rs +++ b/crates/core/src/operations/vacuum.rs @@ -553,13 +553,14 @@ async fn get_stale_files( #[cfg(test)] mod tests { + use std::path::Path; + use std::{io::Read, time::SystemTime}; + use object_store::{PutPayload, local::LocalFileSystem, memory::InMemory}; + use url::Url; use super::*; use crate::{ensure_table_uri, open_table}; - use std::path::Path; - use std::{io::Read, time::SystemTime}; - use url::Url; #[tokio::test] async fn test_vacuum_full() -> DeltaResult<()> { @@ -704,7 +705,7 @@ mod tests { async fn test_vacuum_keep_version_validity() { use datafusion::prelude::SessionContext; use object_store::GetResultPayload; - let store = InMemory::new(); + let store = Arc::new(InMemory::new()); let source = LocalFileSystem::new_with_prefix("../test/tests/data/simple_table").unwrap(); let mut stream = source.list(None); @@ -726,7 +727,7 @@ mod tests { let table_url = url::Url::parse("memory:///").unwrap(); let mut table = crate::DeltaTableBuilder::from_url(table_url.clone()) .unwrap() - .with_storage_backend(Arc::new(store), table_url) + .with_storage_backend(store.clone(), table_url) .build() .unwrap(); table.load().await.unwrap(); @@ -752,6 +753,8 @@ mod tests { assert_eq!(Some(6), table.version()); let ctx = SessionContext::new(); + ctx.runtime_env() + .register_object_store(&url::Url::parse("memory:///").unwrap(), store); ctx.register_table("test", Arc::new(table)).unwrap(); let _batches = ctx .sql("SELECT * FROM test") diff --git a/crates/core/tests/dat.rs b/crates/core/tests/dat.rs index 127221e18c..46586d9e21 100644 --- a/crates/core/tests/dat.rs +++ b/crates/core/tests/dat.rs @@ -5,12 +5,7 @@ use deltalake_test::{TestResult, acceptance::read_dat_case}; use pretty_assertions::assert_eq; use rstest::rstest; -static SKIPPED_TESTS: &[&str; 4] = &[ - "iceberg_compat_v1", - "column_mapping", - "check_constraints", - "deletion_vectors", -]; +static SKIPPED_TESTS: &[&str; 1] = &["iceberg_compat_v1"]; #[rstest] #[tokio::test] diff --git a/crates/core/tests/datafusion_dat.rs b/crates/core/tests/datafusion_dat.rs new file mode 100644 index 0000000000..b7dc73e32b --- /dev/null +++ b/crates/core/tests/datafusion_dat.rs @@ -0,0 +1,77 @@ +#![cfg(feature = "datafusion")] +use std::{path::PathBuf, sync::Arc}; + +use datafusion::{ + catalog::{Session, TableProvider}, + datasource::{ + file_format::parquet::ParquetFormat, + listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, + }, +}; +use deltalake_core::{DeltaTableBuilder, delta_datafusion::create_session}; +use deltalake_test::acceptance::assert_data_matches; +use deltalake_test::{ + TestResult, + acceptance::{TableVersion, read_dat_case}, +}; +use rstest::rstest; +use url::Url; + +static SKIPPED_TESTS: &[&str; 1] = &["iceberg_compat_v1"]; + +async fn parquet_provider( + table_path: &TableVersion, + state: &dyn Session, +) -> TestResult> { + let table_url = 
Url::from_directory_path(&table_path.data_dir).unwrap(); + let table_path = ListingTableUrl::parse(table_url)?; + let file_format = ParquetFormat::new(); + let listing_options = + ListingOptions::new(Arc::new(file_format)).with_file_extension(".parquet"); + let config = ListingTableConfig::new(table_path).with_listing_options(listing_options); + let config = config.infer_schema(state).await?; + Ok(Arc::new(ListingTable::try_new(config)?)) +} + +#[rstest] +#[tokio::test] +async fn scan_dat( + #[files("../../dat/v0.0.3/reader_tests/generated/**/test_case_info.json")] path: PathBuf, +) -> TestResult<()> { + let case_dir = path.parent().expect("parent dir"); + if SKIPPED_TESTS.iter().any(|c| case_dir.ends_with(c)) { + println!("Skipping test: {case_dir:?}"); + return Ok(()); + } + + let case = read_dat_case(case_dir)?; + let ctx = create_session().into_inner(); + let table = DeltaTableBuilder::from_url(case.table_root()?)?.build()?; + + for version in case.all_table_versions()? { + let version = version?; + + let pq = parquet_provider(&version, &ctx.state()).await?; + let schema = pq.schema(); + let columns = schema + .fields() + .iter() + .map(|f| f.name().as_str()) + .collect::>(); + let expected = ctx.read_table(pq)?.collect().await?; + + let delta = table + .table_provider() + .with_table_version(version.meta.version) + .await?; + let actual = ctx + .read_table(delta)? + .select_columns(&columns)? + .collect() + .await?; + + assert_data_matches(&actual, &expected)?; + } + + Ok(()) +} diff --git a/crates/core/tests/datafusion_table_provider.rs b/crates/core/tests/datafusion_table_provider.rs new file mode 100644 index 0000000000..96055f5c01 --- /dev/null +++ b/crates/core/tests/datafusion_table_provider.rs @@ -0,0 +1,236 @@ +#![cfg(feature = "datafusion")] +use std::sync::Arc; + +use arrow_array::RecordBatch; +use datafusion::assert_batches_sorted_eq; +use datafusion::datasource::TableProvider; +use datafusion::physical_plan::{ExecutionPlan, collect_partitioned}; +use datafusion::prelude::{SessionContext, col, lit}; +use deltalake_core::delta_datafusion::create_session; +use deltalake_core::delta_datafusion::engine::DataFusionEngine; +use deltalake_core::kernel::Snapshot; +use deltalake_test::TestResult; +use deltalake_test::acceptance::read_dat_case; + +async fn scan_dat(case: &str) -> TestResult<(Snapshot, SessionContext)> { + let root_dir = format!( + "{}/../../dat/v0.0.3/reader_tests/generated/{}/", + env!["CARGO_MANIFEST_DIR"], + case + ); + let root_dir = std::fs::canonicalize(root_dir)?; + let case = read_dat_case(root_dir)?; + + let session = create_session().into_inner(); + let engine = DataFusionEngine::new_from_session(&session.state()); + + let snapshot = + Snapshot::try_new_with_engine(engine.clone(), case.table_root()?, Default::default(), None) + .await?; + + Ok((snapshot, session)) +} + +async fn collect_plan( + plan: Arc, + session: &SessionContext, +) -> TestResult> { + let batches: Vec<_> = collect_partitioned(plan, session.task_ctx()) + .await? 
+ .into_iter() + .flatten() + .collect(); + Ok(batches) +} + +#[tokio::test] +async fn test_all_primitive_types() -> TestResult<()> { + let (snapshot, session) = scan_dat("all_primitive_types").await?; + + let plan = snapshot.scan(&session.state(), None, &[], None).await?; + let batches: Vec<_> = collect_plan(plan, &session).await?; + let expected = vec![ + "+------+-------+-------+-------+------+---------+---------+-------+----------+---------+------------+----------------------+", + "| utf8 | int64 | int32 | int16 | int8 | float32 | float64 | bool | binary | decimal | date32 | timestamp |", + "+------+-------+-------+-------+------+---------+---------+-------+----------+---------+------------+----------------------+", + "| 0 | 0 | 0 | 0 | 0 | 0.0 | 0.0 | true | | 10.000 | 1970-01-01 | 1970-01-01T00:00:00Z |", + "| 1 | 1 | 1 | 1 | 1 | 1.0 | 1.0 | false | 00 | 11.000 | 1970-01-02 | 1970-01-01T01:00:00Z |", + "| 2 | 2 | 2 | 2 | 2 | 2.0 | 2.0 | true | 0000 | 12.000 | 1970-01-03 | 1970-01-01T02:00:00Z |", + "| 3 | 3 | 3 | 3 | 3 | 3.0 | 3.0 | false | 000000 | 13.000 | 1970-01-04 | 1970-01-01T03:00:00Z |", + "| 4 | 4 | 4 | 4 | 4 | 4.0 | 4.0 | true | 00000000 | 14.000 | 1970-01-05 | 1970-01-01T04:00:00Z |", + "+------+-------+-------+-------+------+---------+---------+-------+----------+---------+------------+----------------------+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + let plan = snapshot + .scan(&session.state(), Some(&vec![1, 3]), &[], None) + .await?; + let batches: Vec<_> = collect_plan(plan, &session).await?; + let expected = vec![ + "+-------+-------+", + "| int64 | int16 |", + "+-------+-------+", + "| 0 | 0 |", + "| 1 | 1 |", + "| 2 | 2 |", + "| 3 | 3 |", + "| 4 | 4 |", + "+-------+-------+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + let plan = snapshot + .scan(&session.state(), Some(&vec![1, 3]), &[], Some(2)) + .await?; + let batches: Vec<_> = collect_plan(plan, &session).await?; + let expected = vec![ + "+-------+-------+", + "| int64 | int16 |", + "+-------+-------+", + "| 0 | 0 |", + "| 1 | 1 |", + "+-------+-------+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + // While we are passing a filter, the table provider does not yet push down + // the filter to the parquet scan, so we expect the same result as above. 
+ let pred = col("float64").gt(lit(2.0_f64)); + let plan = snapshot + .scan(&session.state(), Some(&vec![1, 3, 6]), &[pred], None) + .await?; + let batches: Vec<_> = collect_plan(plan, &session).await?; + let expected = vec![ + "+-------+-------+---------+", + "| int64 | int16 | float64 |", + "+-------+-------+---------+", + "| 0 | 0 | 0.0 |", + "| 1 | 1 | 1.0 |", + "| 2 | 2 | 2.0 |", + "| 3 | 3 | 3.0 |", + "| 4 | 4 | 4.0 |", + "+-------+-------+---------+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + Ok(()) +} + +#[tokio::test] +async fn test_multi_partitioned() -> TestResult<()> { + let (snapshot, session) = scan_dat("multi_partitioned").await?; + + let plan = snapshot.scan(&session.state(), None, &[], None).await?; + let batches: Vec<_> = collect_plan(plan, &session).await?; + let expected = vec![ + "+--------+--------+------------+------------+", + "| number | letter | date | data |", + "+--------+--------+------------+------------+", + "| 6 | /%20%f | 1970-01-01 | 68656c6c6f |", + "| 7 | b | 1970-01-01 | f09f9888 |", + "+--------+--------+------------+------------+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + // While are not yet pushing down predicates to the parquet scan, the number values + // are separated by files, so we expect only one file to be read here. + let pred = col("number").gt(lit(6_i64)); + let plan = snapshot.scan(&session.state(), None, &[pred], None).await?; + let batches: Vec<_> = collect_plan(plan, &session).await?; + let expected = vec![ + "+--------+--------+------------+----------+", + "| number | letter | date | data |", + "+--------+--------+------------+----------+", + "| 7 | b | 1970-01-01 | f09f9888 |", + "+--------+--------+------------+----------+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + let pred = col("letter").eq(lit("/%20%f")); + let plan = snapshot.scan(&session.state(), None, &[pred], None).await?; + let batches: Vec<_> = collect_plan(plan, &session).await?; + let expected = vec![ + "+--------+--------+------------+------------+", + "| number | letter | date | data |", + "+--------+--------+------------+------------+", + "| 6 | /%20%f | 1970-01-01 | 68656c6c6f |", + "+--------+--------+------------+------------+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + // Since we are able to do an exact filter on a partition column, + // Datafusion will not include partition columns that are only referenced + // in a predicate but not part of the end result in the passed projection. + let pred = col("letter").eq(lit("/%20%f")); + let plan = snapshot + .scan(&session.state(), Some(&vec![0]), &[pred], None) + .await?; + let batches: Vec<_> = collect_plan(plan, &session).await?; + let expected = vec![ + "+--------+", + "| number |", + "+--------+", + "| 6 |", + "+--------+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + // COUNT(*) queries may not include any columns in the projection, + // but we still need to process predicates properly within the scan. 
+ let pred = col("letter").eq(lit("/%20%f")); + let plan = snapshot + .scan(&session.state(), Some(&vec![]), &[pred], None) + .await?; + let batches: Vec<_> = collect_plan(plan, &session).await?; + let expected = vec!["++", "++", "++"]; + assert_batches_sorted_eq!(&expected, &batches); + let n_rows = batches.iter().map(|b| b.num_rows()).sum::(); + assert_eq!(n_rows, 1); + + Ok(()) +} + +#[tokio::test] +async fn test_column_mapping() -> TestResult<()> { + let (snapshot, session) = scan_dat("column_mapping").await?; + + let plan = snapshot.scan(&session.state(), None, &[], None).await?; + let batches: Vec<_> = collect_plan(plan, &session).await?; + let expected = vec![ + "+--------+---------+------------+", + "| letter | new_int | date |", + "+--------+---------+------------+", + "| a | 25 | 2017-05-01 |", + "| a | 25 | 2017-05-01 |", + "| a | 604 | 1997-01-01 |", + "| a | 604 | 1997-01-01 |", + "| a | 692 | 2017-09-01 |", + "| a | 692 | 2017-09-01 |", + "| a | 95 | 1983-04-01 |", + "| a | 95 | 1983-04-01 |", + "| b | 228 | 1978-12-01 |", + "| b | 228 | 1978-12-01 |", + "+--------+---------+------------+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + Ok(()) +} + +#[tokio::test] +async fn test_deletion_vectors() -> TestResult<()> { + let (snapshot, session) = scan_dat("deletion_vectors").await?; + + let plan = snapshot.scan(&session.state(), None, &[], None).await?; + let batches: Vec<_> = collect_plan(plan, &session).await?; + let expected = vec![ + "+--------+-----+------------+", + "| letter | int | date |", + "+--------+-----+------------+", + "| b | 228 | 1978-12-01 |", + "+--------+-----+------------+", + ]; + assert_batches_sorted_eq!(&expected, &batches); + + Ok(()) +}