Skip to content

Commit e0f4a8d

Browse files
committed
Enhance physical-expr projection handling
This PR adds trait implementations, a project_batch() method, and fixes a bug in update_expr() for literal expressions. Also adds comprehensive tests. Part of #18627
1 parent 0cfc1fe commit e0f4a8d

File tree

1 file changed

+135
-3
lines changed

1 file changed

+135
-3
lines changed

datafusion/physical-expr/src/projection.rs

Lines changed: 135 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use crate::expressions::Column;
2222
use crate::utils::collect_columns;
2323
use crate::PhysicalExpr;
2424

25+
use arrow::array::{RecordBatch, RecordBatchOptions};
2526
use arrow::datatypes::{Field, Schema, SchemaRef};
2627
use datafusion_common::stats::{ColumnStatistics, Precision};
2728
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -47,6 +48,14 @@ pub struct ProjectionExpr {
4748
pub alias: String,
4849
}
4950

51+
impl PartialEq for ProjectionExpr {
52+
fn eq(&self, other: &Self) -> bool {
53+
self.expr.eq(&other.expr) && self.alias == other.alias
54+
}
55+
}
56+
57+
impl Eq for ProjectionExpr {}
58+
5059
impl std::fmt::Display for ProjectionExpr {
5160
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
5261
if self.expr.to_string() == self.alias {
@@ -99,7 +108,7 @@ impl From<ProjectionExpr> for (Arc<dyn PhysicalExpr>, String) {
99108
/// This struct encapsulates multiple `ProjectionExpr` instances,
100109
/// representing a complete projection operation and provides
101110
/// methods to manipulate and analyze the projection as a whole.
102-
#[derive(Debug, Clone)]
111+
#[derive(Debug, Clone, PartialEq, Eq)]
103112
pub struct ProjectionExprs {
104113
exprs: Vec<ProjectionExpr>,
105114
}
@@ -192,7 +201,7 @@ impl ProjectionExprs {
192201
/// assert_eq!(projection_with_dups.as_ref()[1].alias, "a"); // duplicate
193202
/// assert_eq!(projection_with_dups.as_ref()[2].alias, "b");
194203
/// ```
195-
pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Self {
204+
pub fn from_indices(indices: &[usize], schema: &Schema) -> Self {
196205
let projection_exprs = indices.iter().map(|&i| {
197206
let field = schema.field(i);
198207
ProjectionExpr {
@@ -396,6 +405,35 @@ impl ProjectionExprs {
396405
))
397406
}
398407

408+
/// Project a RecordBatch.
409+
///
410+
/// This function accepts a pre-computed output schema instead of calling [`ProjectionExprs::project_schema`]
411+
/// so that repeated calls do not have schema projection overhead.
412+
pub fn project_batch(
413+
&self,
414+
batch: &RecordBatch,
415+
output_schema: SchemaRef,
416+
) -> Result<RecordBatch> {
417+
let arrays = self
418+
.exprs
419+
.iter()
420+
.map(|expr| {
421+
expr.expr
422+
.evaluate(batch)
423+
.and_then(|v| v.into_array(batch.num_rows()))
424+
})
425+
.collect::<Result<Vec<_>>>()?;
426+
427+
if arrays.is_empty() {
428+
let options =
429+
RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
430+
RecordBatch::try_new_with_options(output_schema, arrays, &options)
431+
.map_err(Into::into)
432+
} else {
433+
RecordBatch::try_new(output_schema, arrays).map_err(Into::into)
434+
}
435+
}
436+
399437
/// Project statistics according to this projection.
400438
/// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1,
401439
/// if the input statistics has column statistics for columns `a`, `b`, and `c`, the output statistics would have column statistics for columns `x` and `y`.
@@ -545,7 +583,13 @@ pub fn update_expr(
545583
})
546584
.data()?;
547585

548-
Ok((state == RewriteState::RewrittenValid).then_some(new_expr))
586+
match state {
587+
RewriteState::RewrittenInvalid => Ok(None),
588+
// Both Unchanged and RewrittenValid are valid:
589+
// - Unchanged means no columns to rewrite (e.g., literals)
590+
// - RewrittenValid means columns were successfully rewritten
591+
RewriteState::Unchanged | RewriteState::RewrittenValid => Ok(Some(new_expr)),
592+
}
549593
}
550594

551595
/// Stores target expressions, along with their indices, that associate with a
@@ -2009,6 +2053,94 @@ pub(crate) mod tests {
20092053
);
20102054
}
20112055

2056+
#[test]
2057+
fn test_merge_empty_projection_with_literal() -> Result<()> {
2058+
// This test reproduces the issue from roundtrip_empty_projection test
2059+
// Query like: SELECT 1 FROM table
2060+
// where the file scan needs no columns (empty projection)
2061+
// but we project a literal on top
2062+
2063+
// Empty base projection (no columns needed from file)
2064+
let base_projection = ProjectionExprs::new(vec![]);
2065+
2066+
// Top projection with a literal expression: SELECT 1
2067+
let top_projection = ProjectionExprs::new(vec![ProjectionExpr {
2068+
expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
2069+
alias: "Int64(1)".to_string(),
2070+
}]);
2071+
2072+
// This should succeed - literals don't reference columns so they should
2073+
// pass through unchanged when merged with an empty projection
2074+
let merged = base_projection.try_merge(&top_projection)?;
2075+
assert_snapshot!(format!("{merged}"), @"Projection[1 AS Int64(1)]");
2076+
2077+
Ok(())
2078+
}
2079+
2080+
#[test]
2081+
fn test_update_expr_with_literal() -> Result<()> {
2082+
// Test that update_expr correctly handles expressions without column references
2083+
let literal_expr: Arc<dyn PhysicalExpr> =
2084+
Arc::new(Literal::new(ScalarValue::Int64(Some(42))));
2085+
let empty_projection: Vec<ProjectionExpr> = vec![];
2086+
2087+
// Updating a literal with an empty projection should return the literal unchanged
2088+
let result = update_expr(&literal_expr, &empty_projection, true)?;
2089+
assert!(result.is_some(), "Literal expression should be valid");
2090+
2091+
let result_expr = result.unwrap();
2092+
assert_eq!(
2093+
result_expr
2094+
.as_any()
2095+
.downcast_ref::<Literal>()
2096+
.unwrap()
2097+
.value(),
2098+
&ScalarValue::Int64(Some(42))
2099+
);
2100+
2101+
Ok(())
2102+
}
2103+
2104+
#[test]
2105+
fn test_update_expr_with_complex_literal_expr() -> Result<()> {
2106+
// Test update_expr with an expression containing both literals and a column
2107+
// This tests the case where we have: literal + column
2108+
let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
2109+
Arc::new(Literal::new(ScalarValue::Int64(Some(10)))),
2110+
Operator::Plus,
2111+
Arc::new(Column::new("x", 0)),
2112+
));
2113+
2114+
// Base projection that maps column 0 to a different expression
2115+
let base_projection = vec![ProjectionExpr {
2116+
expr: Arc::new(Column::new("a", 5)),
2117+
alias: "x".to_string(),
2118+
}];
2119+
2120+
// The expression should be updated: 10 + x@0 becomes 10 + a@5
2121+
let result = update_expr(&expr, &base_projection, true)?;
2122+
assert!(result.is_some(), "Expression should be valid");
2123+
2124+
let result_expr = result.unwrap();
2125+
let binary = result_expr
2126+
.as_any()
2127+
.downcast_ref::<BinaryExpr>()
2128+
.expect("Should be a BinaryExpr");
2129+
2130+
// Left side should still be the literal
2131+
assert!(binary.left().as_any().downcast_ref::<Literal>().is_some());
2132+
2133+
// Right side should be updated to reference column at index 5
2134+
let right_col = binary
2135+
.right()
2136+
.as_any()
2137+
.downcast_ref::<Column>()
2138+
.expect("Right should be a Column");
2139+
assert_eq!(right_col.index(), 5);
2140+
2141+
Ok(())
2142+
}
2143+
20122144
#[test]
20132145
fn test_project_schema_simple_columns() -> Result<()> {
20142146
// Input schema: [col0: Int64, col1: Utf8, col2: Float32]

0 commit comments

Comments
 (0)