
Commit 3fefeeb

[spark] Merge into supports _ROW_ID shortcut (#6745)
1 parent 6177366 commit 3fefeeb
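
For context, the shortcut targets MERGE statements whose ON clause is a plain equality against the target's `_ROW_ID` metadata column. A minimal sketch of such a statement follows (table and column names are illustrative, not part of the commit):

    // Hypothetical tables: `target` has row tracking enabled, and `source` carries the
    // matching _ROW_ID values in an ordinary column. With an ON clause of this shape,
    // targetRelatedSplits can read the touched _FIRST_ROW_IDs directly from `source`
    // instead of joining `source` against `target`.
    spark.sql(
      """
        |MERGE INTO target
        |USING source
        |ON target._ROW_ID = source.row_id_col
        |WHEN MATCHED THEN UPDATE SET b = source.b
        |""".stripMargin)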

File tree: 2 files changed (+107, -11 lines)


paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/commands/MergeIntoPaimonDataEvolutionTable.scala

Lines changed: 51 additions & 11 deletions
@@ -32,7 +32,7 @@ import org.apache.paimon.table.source.DataSplit
 import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import org.apache.spark.sql.PaimonUtils._
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.resolver
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, EqualTo, Expression, Literal}
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftOuter}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -148,18 +148,21 @@ case class MergeIntoPaimonDataEvolutionTable(
   }
 
   private def targetRelatedSplits(sparkSession: SparkSession): Seq[DataSplit] = {
-    val targetDss = createDataset(
-      sparkSession,
-      targetRelation
-    )
     val sourceDss = createDataset(sparkSession, sourceRelation)
 
-    val firstRowIdsTouched = mutable.Set.empty[Long]
-
-    firstRowIdsTouched ++= findRelatedFirstRowIds(
-      targetDss.alias("_left").join(sourceDss, toColumn(matchedCondition), "inner"),
-      sparkSession,
-      "_left." + ROW_ID_NAME)
+    val firstRowIdsTouched = extractSourceRowIdMapping match {
+      case Some(sourceRowIdAttr) =>
+        // Shortcut: Directly get _FIRST_ROW_IDs from the source table.
+        findRelatedFirstRowIds(sourceDss, sparkSession, sourceRowIdAttr.name).toSet
+
+      case None =>
+        // Perform the full join to find related _FIRST_ROW_IDs.
+        val targetDss = createDataset(sparkSession, targetRelation)
+        findRelatedFirstRowIds(
+          targetDss.alias("_left").join(sourceDss, toColumn(matchedCondition), "inner"),
+          sparkSession,
+          "_left." + ROW_ID_NAME).toSet
+    }
 
     table
       .newSnapshotReader()
@@ -312,6 +315,43 @@ case class MergeIntoPaimonDataEvolutionTable(
     writer.write(toWrite)
   }
 
+  /**
+   * Attempts to identify a direct mapping from sourceTable's attribute to the target table's
+   * `_ROW_ID`.
+   *
+   * This is a shortcut optimization for `MERGE INTO` to avoid a full, expensive join when the merge
+   * condition is a simple equality on the target's `_ROW_ID`.
+   *
+   * @return
+   *   An `Option` containing the sourceTable's attribute if a pattern like
+   *   `target._ROW_ID = source.col` (or its reverse) is found, otherwise `None`.
+   */
+  private def extractSourceRowIdMapping: Option[AttributeReference] = {
+
+    // Helper to check if an attribute is the target's _ROW_ID
+    def isTargetRowId(attr: AttributeReference): Boolean = {
+      attr.name == ROW_ID_NAME && (targetRelation.output ++ targetRelation.metadataOutput)
+        .exists(_.exprId.equals(attr.exprId))
+    }
+
+    // Helper to check if an attribute belongs to the source table
+    def isSourceAttribute(attr: AttributeReference): Boolean = {
+      (sourceRelation.output ++ sourceRelation.metadataOutput).exists(_.exprId.equals(attr.exprId))
+    }
+
+    matchedCondition match {
+      // Case 1: target._ROW_ID = source.col
+      case EqualTo(left: AttributeReference, right: AttributeReference)
+          if isTargetRowId(left) && isSourceAttribute(right) =>
+        Some(right)
+      // Case 2: source.col = target._ROW_ID
+      case EqualTo(left: AttributeReference, right: AttributeReference)
+          if isSourceAttribute(left) && isTargetRowId(right) =>
+        Some(left)
+      case _ => None
+    }
+  }
+
   private def findRelatedFirstRowIds(
       dataset: Dataset[Row],
       sparkSession: SparkSession,

paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/RowTrackingTestBase.scala

Lines changed: 56 additions & 0 deletions
@@ -22,6 +22,13 @@ import org.apache.paimon.Snapshot.CommitKind
 import org.apache.paimon.spark.PaimonSparkTestBase
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.joins.BaseJoinExec
+import org.apache.spark.sql.util.QueryExecutionListener
+
+import scala.collection.mutable
 
 abstract class RowTrackingTestBase extends PaimonSparkTestBase {
 

@@ -360,6 +367,55 @@ abstract class RowTrackingTestBase extends PaimonSparkTestBase {
     }
   }
 
+  test("Data Evolution: merge into table with data-evolution with _ROW_ID shortcut") {
+    withTable("source", "target") {
+      sql("CREATE TABLE source (target_ROW_ID BIGINT, b INT, c STRING)")
+      sql(
+        "INSERT INTO source VALUES (0, 100, 'c11'), (2, 300, 'c33'), (4, 500, 'c55'), (6, 700, 'c77'), (8, 900, 'c99')")
+
+      sql(
+        "CREATE TABLE target (a INT, b INT, c STRING) TBLPROPERTIES ('row-tracking.enabled' = 'true', 'data-evolution.enabled' = 'true')")
+      sql(
+        "INSERT INTO target values (1, 10, 'c1'), (2, 20, 'c2'), (3, 30, 'c3'), (4, 40, 'c4'), (5, 50, 'c5')")
+
+      val capturedPlans: mutable.ListBuffer[LogicalPlan] = mutable.ListBuffer.empty
+      val listener = new QueryExecutionListener {
+        override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+          capturedPlans += qe.analyzed
+        }
+        override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+          capturedPlans += qe.analyzed
+        }
+      }
+      spark.listenerManager.register(listener)
+      sql(s"""
+             |MERGE INTO target
+             |USING source
+             |ON target._ROW_ID = source.target_ROW_ID
+             |WHEN MATCHED AND target.a = 5 THEN UPDATE SET b = source.b + target.b
+             |WHEN MATCHED AND source.c > 'c2' THEN UPDATE SET b = source.b, c = source.c
+             |WHEN NOT MATCHED AND c > 'c9' THEN INSERT (a, b, c) VALUES (target_ROW_ID, b * 1.1, c)
+             |WHEN NOT MATCHED THEN INSERT (a, b, c) VALUES (target_ROW_ID, b, c)
+             |""".stripMargin)
+      // Assert that no Join operator was used during
+      // `org.apache.paimon.spark.commands.MergeIntoPaimonDataEvolutionTable.targetRelatedSplits`
+      assert(capturedPlans.head.collect { case plan: Join => plan }.isEmpty)
+      spark.listenerManager.unregister(listener)
+
+      checkAnswer(
+        sql("SELECT *, _ROW_ID, _SEQUENCE_NUMBER FROM target ORDER BY a"),
+        Seq(
+          Row(1, 10, "c1", 0, 2),
+          Row(2, 20, "c2", 1, 2),
+          Row(3, 300, "c33", 2, 2),
+          Row(4, 40, "c4", 3, 2),
+          Row(5, 550, "c5", 4, 2),
+          Row(6, 700, "c77", 5, 2),
+          Row(8, 990, "c99", 6, 2))
+      )
+    }
+  }
+
   test("Data Evolution: update table throws exception") {
     withTable("t") {
       sql(
