
Commit b46e69b

Refactor primitive read translators to reduce code duplication
- Add translateUnboundedSource() helper method for common unbounded source translation logic
- Add translateBoundedSource() helper method for common bounded source translation logic
- Simplify UnboundedReadSourceTranslator to use the helper method
- Simplify PrimitiveUnboundedReadTranslator to use the helper method
- Simplify BoundedReadSourceTranslator to use the helper method
- Simplify PrimitiveBoundedReadTranslator to use the helper method
- Add missing try-catch, .returns(), and batch-mode slot sharing to PrimitiveBoundedReadTranslator

This refactoring removes ~100 lines of duplicated code and ensures consistent behavior across all read translators.
1 parent 282f84a commit b46e69b
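
For orientation, below is a minimal, self-contained sketch of the refactoring pattern this commit applies: each translator keeps only its own source-extraction step and delegates the shared wiring to a single helper. All names in the sketch (ReadTranslatorSketch, Context, translateSource, translateBoundedRead, translatePrimitiveBoundedRead) are hypothetical stand-ins, not the Beam/Flink types; the real helpers and translators are in the diff below.

// Hypothetical illustration of the "extract shared helper" pattern (not Beam/Flink code).
public final class ReadTranslatorSketch {

  /** Hypothetical stand-in for the Flink translation context. */
  interface Context {
    void setOutput(String label);
  }

  /** Shared helper: the once-duplicated wiring now lives in exactly one place. */
  private static void translateSource(String rawSource, String transformName, Context context) {
    // In the real code this is where parallelism, source construction,
    // uid/returns and slot-sharing decisions are made, consistently.
    context.setOutput(transformName + " <- " + rawSource);
  }

  /** Translator A: keeps only its source-extraction step, then delegates. */
  static void translateBoundedRead(String transform, Context context) {
    String rawSource = "bounded:" + transform; // translator-specific extraction
    translateSource(rawSource, transform, context);
  }

  /** Translator B: same one-line delegation, so behavior cannot drift between translators. */
  static void translatePrimitiveBoundedRead(String transform, Context context) {
    String rawSource = "primitive-bounded:" + transform;
    translateSource(rawSource, transform, context);
  }

  public static void main(String[] args) {
    Context ctx = label -> System.out.println("wired " + label);
    translateBoundedRead("Read/BoundedSource", ctx);
    translatePrimitiveBoundedRead("Read/PrimitiveBoundedRead", ctx);
  }
}

The design point: source extraction stays translator-specific, while the wiring that must behave identically (parallelism, uid, returns, slot sharing) exists once, so a fix in the helper reaches every translator.
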

1 file changed (+119, −188)

runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java

Lines changed: 119 additions & 188 deletions
@@ -97,6 +97,7 @@
 import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
 import org.apache.flink.api.common.eventtime.WatermarkStrategy;
 import org.apache.flink.api.common.functions.FlatMapFunction;
@@ -195,29 +196,127 @@ public static String getCurrentTransformName(FlinkStreamingTranslationContext co
   // Transformation Implementations
   // --------------------------------------------------------------------------------------------
 
+  /** Common translation logic for unbounded sources. */
+  @SuppressWarnings("unchecked")
+  private static <T> void translateUnboundedSource(
+      UnboundedSource<T, ?> rawSource,
+      String transformName,
+      FlinkStreamingTranslationContext context) {
+
+    PCollection<T> output =
+        (PCollection<T>)
+            Iterables.getOnlyElement(context.getCurrentTransform().getOutputs().values());
+
+    DataStream<WindowedValue<T>> source;
+    DataStream<WindowedValue<ValueWithRecordId<T>>> nonDedupSource;
+    TypeInformation<WindowedValue<T>> outputTypeInfo = context.getTypeInfo(output);
+
+    Coder<T> coder = output.getCoder();
+
+    TypeInformation<WindowedValue<ValueWithRecordId<T>>> withIdTypeInfo =
+        new CoderTypeInformation<>(
+            WindowedValues.getFullCoder(
+                ValueWithRecordId.ValueWithRecordIdCoder.of(coder),
+                output.getWindowingStrategy().getWindowFn().windowCoder()),
+            context.getPipelineOptions());
+
+    String fullName = getCurrentTransformName(context);
+    try {
+      int parallelism =
+          context.getExecutionEnvironment().getMaxParallelism() > 0
+              ? context.getExecutionEnvironment().getMaxParallelism()
+              : context.getExecutionEnvironment().getParallelism();
+
+      FlinkUnboundedSource<T> unboundedSource =
+          FlinkSource.unbounded(
+              transformName,
+              rawSource,
+              new SerializablePipelineOptions(context.getPipelineOptions()),
+              parallelism);
+      nonDedupSource =
+          context
+              .getExecutionEnvironment()
+              .fromSource(
+                  unboundedSource, WatermarkStrategy.noWatermarks(), fullName, withIdTypeInfo)
+              .uid(fullName);
+
+      if (rawSource.requiresDeduping()) {
+        source =
+            nonDedupSource
+                .keyBy(new ValueWithRecordIdKeySelector<>())
+                .transform(
+                    "deduping",
+                    outputTypeInfo,
+                    new DedupingOperator<>(context.getPipelineOptions()))
+                .uid(format("%s/__deduplicated__", fullName));
+      } else {
+        source =
+            nonDedupSource
+                .flatMap(new StripIdsMap<>(context.getPipelineOptions()))
+                .returns(outputTypeInfo);
+      }
+    } catch (Exception e) {
+      throw new RuntimeException("Error while translating UnboundedSource: " + rawSource, e);
+    }
+
+    context.setOutputDataStream(output, source);
+  }
+
+  /** Common translation logic for bounded sources. */
+  @SuppressWarnings("unchecked")
+  private static <T> void translateBoundedSource(
+      BoundedSource<T> rawSource, String transformName, FlinkStreamingTranslationContext context) {
+
+    PCollection<T> output =
+        (PCollection<T>)
+            Iterables.getOnlyElement(context.getCurrentTransform().getOutputs().values());
+
+    TypeInformation<WindowedValue<T>> outputTypeInfo = context.getTypeInfo(output);
+
+    String fullName = getCurrentTransformName(context);
+    int parallelism =
+        context.getExecutionEnvironment().getMaxParallelism() > 0
+            ? context.getExecutionEnvironment().getMaxParallelism()
+            : context.getExecutionEnvironment().getParallelism();
+
+    FlinkBoundedSource<T> flinkBoundedSource =
+        FlinkSource.bounded(
+            transformName,
+            rawSource,
+            new SerializablePipelineOptions(context.getPipelineOptions()),
+            parallelism);
+
+    SingleOutputStreamOperator<WindowedValue<T>> source;
+    try {
+      source =
+          context
+              .getExecutionEnvironment()
+              .fromSource(
+                  flinkBoundedSource, WatermarkStrategy.noWatermarks(), fullName, outputTypeInfo)
+              .uid(fullName)
+              .returns(outputTypeInfo);
+
+      if (!context.isStreaming()
+          && context
+              .getPipelineOptions()
+              .as(FlinkPipelineOptions.class)
+              .getForceSlotSharingGroup()) {
+        source = source.slotSharingGroup(FORCED_SLOT_GROUP);
+      }
+    } catch (Exception e) {
+      throw new RuntimeException("Error while translating BoundedSource: " + rawSource, e);
+    }
+
+    context.setOutputDataStream(output, source);
+  }
+
   private static class UnboundedReadSourceTranslator<T>
       extends FlinkStreamingPipelineTranslator.StreamTransformTranslator<
           PTransform<PBegin, PCollection<T>>> {
 
     @Override
     public void translateNode(
         PTransform<PBegin, PCollection<T>> transform, FlinkStreamingTranslationContext context) {
-      PCollection<T> output = context.getOutput(transform);
-
-      DataStream<WindowedValue<T>> source;
-      DataStream<WindowedValue<ValueWithRecordId<T>>> nonDedupSource;
-      TypeInformation<WindowedValue<T>> outputTypeInfo =
-          context.getTypeInfo(context.getOutput(transform));
-
-      Coder<T> coder = context.getOutput(transform).getCoder();
-
-      TypeInformation<WindowedValue<ValueWithRecordId<T>>> withIdTypeInfo =
-          new CoderTypeInformation<>(
-              WindowedValues.getFullCoder(
-                  ValueWithRecordId.ValueWithRecordIdCoder.of(coder),
-                  output.getWindowingStrategy().getWindowFn().windowCoder()),
-              context.getPipelineOptions());
-
       UnboundedSource<T, ?> rawSource;
       try {
         rawSource =
@@ -227,47 +326,7 @@ public void translateNode(
       } catch (IOException e) {
         throw new RuntimeException(e);
       }
-
-      String fullName = getCurrentTransformName(context);
-      try {
-        int parallelism =
-            context.getExecutionEnvironment().getMaxParallelism() > 0
-                ? context.getExecutionEnvironment().getMaxParallelism()
-                : context.getExecutionEnvironment().getParallelism();
-
-        FlinkUnboundedSource<T> unboundedSource =
-            FlinkSource.unbounded(
-                transform.getName(),
-                rawSource,
-                new SerializablePipelineOptions(context.getPipelineOptions()),
-                parallelism);
-        nonDedupSource =
-            context
-                .getExecutionEnvironment()
-                .fromSource(
-                    unboundedSource, WatermarkStrategy.noWatermarks(), fullName, withIdTypeInfo)
-                .uid(fullName);
-
-        if (rawSource.requiresDeduping()) {
-          source =
-              nonDedupSource
-                  .keyBy(new ValueWithRecordIdKeySelector<>())
-                  .transform(
-                      "deduping",
-                      outputTypeInfo,
-                      new DedupingOperator<>(context.getPipelineOptions()))
-                  .uid(format("%s/__deduplicated__", fullName));
-        } else {
-          source =
-              nonDedupSource
-                  .flatMap(new StripIdsMap<>(context.getPipelineOptions()))
-                  .returns(outputTypeInfo);
-        }
-      } catch (Exception e) {
-        throw new RuntimeException("Error while translating UnboundedSource: " + rawSource, e);
-      }
-
-      context.setOutputDataStream(output, source);
+      translateUnboundedSource(rawSource, transform.getName(), context);
     }
   }
 
@@ -285,66 +344,7 @@ private static class PrimitiveUnboundedReadTranslator<T>
     public void translateNode(
         SplittableParDo.PrimitiveUnboundedRead<T> transform,
         FlinkStreamingTranslationContext context) {
-
-      PCollection<T> output = context.getOutput(transform);
-
-      DataStream<WindowedValue<T>> source;
-      DataStream<WindowedValue<ValueWithRecordId<T>>> nonDedupSource;
-      TypeInformation<WindowedValue<T>> outputTypeInfo =
-          context.getTypeInfo(context.getOutput(transform));
-
-      Coder<T> coder = context.getOutput(transform).getCoder();
-
-      TypeInformation<WindowedValue<ValueWithRecordId<T>>> withIdTypeInfo =
-          new CoderTypeInformation<>(
-              WindowedValues.getFullCoder(
-                  ValueWithRecordId.ValueWithRecordIdCoder.of(coder),
-                  output.getWindowingStrategy().getWindowFn().windowCoder()),
-              context.getPipelineOptions());
-
-      // Get source directly from PrimitiveUnboundedRead (not via ReadTranslation)
-      UnboundedSource<T, ?> rawSource = transform.getSource();
-
-      String fullName = getCurrentTransformName(context);
-      try {
-        int parallelism =
-            context.getExecutionEnvironment().getMaxParallelism() > 0
-                ? context.getExecutionEnvironment().getMaxParallelism()
-                : context.getExecutionEnvironment().getParallelism();
-
-        FlinkUnboundedSource<T> unboundedSource =
-            FlinkSource.unbounded(
-                transform.getName(),
-                rawSource,
-                new SerializablePipelineOptions(context.getPipelineOptions()),
-                parallelism);
-        nonDedupSource =
-            context
-                .getExecutionEnvironment()
-                .fromSource(
-                    unboundedSource, WatermarkStrategy.noWatermarks(), fullName, withIdTypeInfo)
-                .uid(fullName);
-
-        if (rawSource.requiresDeduping()) {
-          source =
-              nonDedupSource
-                  .keyBy(new ValueWithRecordIdKeySelector<>())
-                  .transform(
-                      "deduping",
-                      outputTypeInfo,
-                      new DedupingOperator<>(context.getPipelineOptions()))
-                  .uid(format("%s/__deduplicated__", fullName));
-        } else {
-          source =
-              nonDedupSource
-                  .flatMap(new StripIdsMap<>(context.getPipelineOptions()))
-                  .returns(outputTypeInfo);
-        }
-      } catch (Exception e) {
-        throw new RuntimeException("Error while translating UnboundedSource: " + rawSource, e);
-      }
-
-      context.setOutputDataStream(output, source);
+      translateUnboundedSource(transform.getSource(), transform.getName(), context);
     }
   }
 
@@ -362,35 +362,7 @@ private static class PrimitiveBoundedReadTranslator<T>
     public void translateNode(
         SplittableParDo.PrimitiveBoundedRead<T> transform,
         FlinkStreamingTranslationContext context) {
-
-      PCollection<T> output = context.getOutput(transform);
-      TypeInformation<WindowedValue<T>> outputTypeInfo =
-          context.getTypeInfo(context.getOutput(transform));
-
-      // Get source directly from PrimitiveBoundedRead (not via ReadTranslation)
-      BoundedSource<T> rawSource = transform.getSource();
-
-      String fullName = getCurrentTransformName(context);
-      int parallelism =
-          context.getExecutionEnvironment().getMaxParallelism() > 0
-              ? context.getExecutionEnvironment().getMaxParallelism()
-              : context.getExecutionEnvironment().getParallelism();
-
-      FlinkBoundedSource<T> flinkBoundedSource =
-          FlinkSource.bounded(
-              fullName,
-              rawSource,
-              new SerializablePipelineOptions(context.getPipelineOptions()),
-              parallelism);
-
-      DataStream<WindowedValue<T>> source =
-          context
-              .getExecutionEnvironment()
-              .fromSource(
-                  flinkBoundedSource, WatermarkStrategy.noWatermarks(), fullName, outputTypeInfo)
-              .uid(fullName);
-
-      context.setOutputDataStream(output, source);
+      translateBoundedSource(transform.getSource(), transform.getName(), context);
     }
   }
 
@@ -503,11 +475,6 @@ private static class BoundedReadSourceTranslator<T>
     @Override
     public void translateNode(
         PTransform<PBegin, PCollection<T>> transform, FlinkStreamingTranslationContext context) {
-      PCollection<T> output = context.getOutput(transform);
-
-      TypeInformation<WindowedValue<T>> outputTypeInfo =
-          context.getTypeInfo(context.getOutput(transform));
-
       BoundedSource<T> rawSource;
       try {
         rawSource =
@@ -517,43 +484,7 @@ public void translateNode(
       } catch (IOException e) {
         throw new RuntimeException(e);
       }
-
-      String fullName = getCurrentTransformName(context);
-      int parallelism =
-          context.getExecutionEnvironment().getMaxParallelism() > 0
-              ? context.getExecutionEnvironment().getMaxParallelism()
-              : context.getExecutionEnvironment().getParallelism();
-
-      FlinkBoundedSource<T> flinkBoundedSource =
-          FlinkSource.bounded(
-              transform.getName(),
-              rawSource,
-              new SerializablePipelineOptions(context.getPipelineOptions()),
-              parallelism);
-
-      TypeInformation<WindowedValue<T>> typeInfo = context.getTypeInfo(output);
-
-      SingleOutputStreamOperator<WindowedValue<T>> source;
-      try {
-        source =
-            context
-                .getExecutionEnvironment()
-                .fromSource(
-                    flinkBoundedSource, WatermarkStrategy.noWatermarks(), fullName, outputTypeInfo)
-                .uid(fullName)
-                .returns(typeInfo);
-
-        if (!context.isStreaming()
-            && context
-                .getPipelineOptions()
-                .as(FlinkPipelineOptions.class)
-                .getForceSlotSharingGroup()) {
-          source = source.slotSharingGroup(FORCED_SLOT_GROUP);
-        }
-      } catch (Exception e) {
-        throw new RuntimeException("Error while translating BoundedSource: " + rawSource, e);
-      }
-      context.setOutputDataStream(output, source);
+      translateBoundedSource(rawSource, transform.getName(), context);
     }
   }
 