Skip to content

Commit f8f7091

Browse files
[arrow] Improve customization capabilities for data type conversion. (#6695)
1 parent 1a1ff56 commit f8f7091

File tree

7 files changed

+294
-14
lines changed

7 files changed

+294
-14
lines changed

paimon-arrow/src/main/java/org/apache/paimon/arrow/ArrowUtils.java

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,15 @@ public class ArrowUtils {
6464

6565
public static VectorSchemaRoot createVectorSchemaRoot(
6666
RowType rowType, BufferAllocator allocator) {
67-
return createVectorSchemaRoot(rowType, allocator, true);
67+
return createVectorSchemaRoot(
68+
rowType, allocator, true, ArrowFieldTypeConversion.ARROW_FIELD_TYPE_VISITOR);
6869
}
6970

7071
public static VectorSchemaRoot createVectorSchemaRoot(
71-
RowType rowType, BufferAllocator allocator, boolean caseSensitive) {
72+
RowType rowType,
73+
BufferAllocator allocator,
74+
boolean caseSensitive,
75+
ArrowFieldTypeConversion.ArrowFieldTypeVisitor visitor) {
7276
List<Field> fields =
7377
rowType.getFields().stream()
7478
.map(
@@ -77,23 +81,53 @@ public static VectorSchemaRoot createVectorSchemaRoot(
7781
toLowerCaseIfNeed(f.name(), caseSensitive),
7882
f.id(),
7983
f.type(),
80-
0))
84+
0,
85+
visitor))
8186
.collect(Collectors.toList());
8287
return VectorSchemaRoot.create(new Schema(fields), allocator);
8388
}
8489

90+
public static FieldVector createVector(
91+
DataField dataField,
92+
BufferAllocator allocator,
93+
boolean caseSensitive,
94+
ArrowFieldTypeConversion.ArrowFieldTypeVisitor visitor) {
95+
return toArrowField(
96+
toLowerCaseIfNeed(dataField.name(), caseSensitive),
97+
dataField.id(),
98+
dataField.type(),
99+
0,
100+
visitor)
101+
.createVector(allocator);
102+
}
103+
85104
public static FieldVector createVector(
86105
DataField dataField, BufferAllocator allocator, boolean caseSensitive) {
87106
return toArrowField(
88107
toLowerCaseIfNeed(dataField.name(), caseSensitive),
89108
dataField.id(),
90109
dataField.type(),
91-
0)
110+
0,
111+
ArrowFieldTypeConversion.ARROW_FIELD_TYPE_VISITOR)
92112
.createVector(allocator);
93113
}
94114

95115
public static Field toArrowField(String fieldName, int fieldId, DataType dataType, int depth) {
96-
FieldType fieldType = dataType.accept(ArrowFieldTypeConversion.ARROW_FIELD_TYPE_VISITOR);
116+
return toArrowField(
117+
fieldName,
118+
fieldId,
119+
dataType,
120+
depth,
121+
ArrowFieldTypeConversion.ARROW_FIELD_TYPE_VISITOR);
122+
}
123+
124+
public static Field toArrowField(
125+
String fieldName,
126+
int fieldId,
127+
DataType dataType,
128+
int depth,
129+
ArrowFieldTypeConversion.ArrowFieldTypeVisitor visitor) {
130+
FieldType fieldType = dataType.accept(visitor);
97131
fieldType =
98132
new FieldType(
99133
fieldType.isNullable(),
@@ -107,7 +141,8 @@ public static Field toArrowField(String fieldName, int fieldId, DataType dataTyp
107141
ListVector.DATA_VECTOR_NAME,
108142
fieldId,
109143
((ArrayType) dataType).getElementType(),
110-
depth + 1);
144+
depth + 1,
145+
visitor);
111146
FieldType typeInner = field.getFieldType();
112147
field =
113148
new Field(
@@ -128,7 +163,11 @@ public static Field toArrowField(String fieldName, int fieldId, DataType dataTyp
128163

129164
Field keyField =
130165
toArrowField(
131-
MapVector.KEY_NAME, fieldId, mapType.getKeyType().notNull(), depth + 1);
166+
MapVector.KEY_NAME,
167+
fieldId,
168+
mapType.getKeyType().notNull(),
169+
depth + 1,
170+
visitor);
132171
FieldType keyType = keyField.getFieldType();
133172
keyField =
134173
new Field(
@@ -145,7 +184,12 @@ public static Field toArrowField(String fieldName, int fieldId, DataType dataTyp
145184
keyField.getChildren());
146185

147186
Field valueField =
148-
toArrowField(MapVector.VALUE_NAME, fieldId, mapType.getValueType(), depth + 1);
187+
toArrowField(
188+
MapVector.VALUE_NAME,
189+
fieldId,
190+
mapType.getValueType(),
191+
depth + 1,
192+
visitor);
149193
FieldType valueType = valueField.getFieldType();
150194
valueField =
151195
new Field(
@@ -179,7 +223,7 @@ public static Field toArrowField(String fieldName, int fieldId, DataType dataTyp
179223
RowType rowType = (RowType) dataType;
180224
children = new ArrayList<>();
181225
for (DataField field : rowType.getFields()) {
182-
children.add(toArrowField(field.name(), field.id(), field.type(), 0));
226+
children.add(toArrowField(field.name(), field.id(), field.type(), 0, visitor));
183227
}
184228
}
185229
return new Field(fieldName, fieldType, children);

paimon-arrow/src/main/java/org/apache/paimon/arrow/converter/Arrow2PaimonVectorConverter.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,18 @@ static Arrow2PaimonVectorConverter construct(DataType type) {
9595
return type.accept(Arrow2PaimonVectorConvertorVisitor.INSTANCE);
9696
}
9797

98+
static Arrow2PaimonVectorConverter construct(
99+
Arrow2PaimonVectorConvertorVisitor visitor, DataType type) {
100+
return type.accept(visitor);
101+
}
102+
98103
ColumnVector convertVector(FieldVector vector);
99104

100105
/** Visitor to create convertor from arrow to paimon. */
101106
class Arrow2PaimonVectorConvertorVisitor
102107
implements DataTypeVisitor<Arrow2PaimonVectorConverter> {
103108

104-
private static final Arrow2PaimonVectorConvertorVisitor INSTANCE =
109+
public static final Arrow2PaimonVectorConvertorVisitor INSTANCE =
105110
new Arrow2PaimonVectorConvertorVisitor();
106111

107112
@Override

paimon-arrow/src/main/java/org/apache/paimon/arrow/reader/ArrowBatchReader.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,24 @@ public class ArrowBatchReader {
4444
private final boolean caseSensitive;
4545

4646
public ArrowBatchReader(RowType rowType, boolean caseSensitive) {
47+
this(
48+
rowType,
49+
caseSensitive,
50+
Arrow2PaimonVectorConverter.Arrow2PaimonVectorConvertorVisitor.INSTANCE);
51+
}
52+
53+
public ArrowBatchReader(
54+
RowType rowType,
55+
boolean caseSensitive,
56+
Arrow2PaimonVectorConverter.Arrow2PaimonVectorConvertorVisitor visitor) {
4757
ColumnVector[] columnVectors = new ColumnVector[rowType.getFieldCount()];
4858
this.convertors = new Arrow2PaimonVectorConverter[rowType.getFieldCount()];
4959
this.batch = new VectorizedColumnBatch(columnVectors);
5060
this.projectedRowType = rowType;
5161

5262
for (int i = 0; i < columnVectors.length; i++) {
53-
this.convertors[i] = Arrow2PaimonVectorConverter.construct(rowType.getTypeAt(i));
63+
this.convertors[i] =
64+
Arrow2PaimonVectorConverter.construct(visitor, rowType.getTypeAt(i));
5465
}
5566
this.caseSensitive = caseSensitive;
5667
}

paimon-arrow/src/main/java/org/apache/paimon/arrow/vector/ArrowFormatCWriter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public ArrowFormatCWriter(
5656
this(new ArrowFormatWriter(rowType, writeBatchSize, caseSensitive, allocator, null));
5757
}
5858

59-
private ArrowFormatCWriter(ArrowFormatWriter arrowFormatWriter) {
59+
public ArrowFormatCWriter(ArrowFormatWriter arrowFormatWriter) {
6060
this.realWriter = arrowFormatWriter;
6161
BufferAllocator allocator = realWriter.getAllocator();
6262
array = ArrowArray.allocateNew(allocator);

paimon-arrow/src/main/java/org/apache/paimon/arrow/vector/ArrowFormatWriter.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
package org.apache.paimon.arrow.vector;
2020

21+
import org.apache.paimon.arrow.ArrowFieldTypeConversion;
2122
import org.apache.paimon.arrow.ArrowUtils;
2223
import org.apache.paimon.arrow.writer.ArrowFieldWriter;
2324
import org.apache.paimon.arrow.writer.ArrowFieldWriterFactoryVisitor;
@@ -65,16 +66,36 @@ public ArrowFormatWriter(
6566
boolean caseSensitive,
6667
BufferAllocator allocator,
6768
@Nullable Long memoryUsedMaxInBytes) {
69+
this(
70+
rowType,
71+
writeBatchSize,
72+
caseSensitive,
73+
new RootAllocator(),
74+
memoryUsedMaxInBytes,
75+
ArrowFieldTypeConversion.ARROW_FIELD_TYPE_VISITOR,
76+
ArrowFieldWriterFactoryVisitor.INSTANCE);
77+
}
78+
79+
public ArrowFormatWriter(
80+
RowType rowType,
81+
int writeBatchSize,
82+
boolean caseSensitive,
83+
BufferAllocator allocator,
84+
@Nullable Long memoryUsedMaxInBytes,
85+
ArrowFieldTypeConversion.ArrowFieldTypeVisitor fieldTypeVisitor,
86+
ArrowFieldWriterFactoryVisitor fieldWriterFactory) {
6887
this.allocator = allocator;
6988

70-
vectorSchemaRoot = ArrowUtils.createVectorSchemaRoot(rowType, allocator, caseSensitive);
89+
vectorSchemaRoot =
90+
ArrowUtils.createVectorSchemaRoot(
91+
rowType, allocator, caseSensitive, fieldTypeVisitor);
7192

7293
fieldWriters = new ArrowFieldWriter[rowType.getFieldCount()];
7394

7495
for (int i = 0; i < fieldWriters.length; i++) {
7596
DataType type = rowType.getFields().get(i).type();
7697
fieldWriters[i] =
77-
type.accept(ArrowFieldWriterFactoryVisitor.INSTANCE)
98+
type.accept(fieldWriterFactory)
7899
.create(vectorSchemaRoot.getVector(i), type.isNullable());
79100
}
80101

paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
package org.apache.paimon.arrow.vector;
2020

2121
import org.apache.paimon.arrow.ArrowBundleRecords;
22+
import org.apache.paimon.arrow.ArrowFieldTypeConversion;
23+
import org.apache.paimon.arrow.converter.Arrow2PaimonVectorConverter;
2224
import org.apache.paimon.arrow.reader.ArrowBatchReader;
25+
import org.apache.paimon.arrow.writer.ArrowFieldWriterFactoryVisitor;
2326
import org.apache.paimon.data.BinaryString;
2427
import org.apache.paimon.data.Decimal;
2528
import org.apache.paimon.data.GenericArray;
@@ -319,6 +322,37 @@ public void testWriteArrayMapTwice() {
319322
}
320323
}
321324

325+
@Test
326+
public void testCustomArrowFormatCWriter() {
327+
// Create custom field type visitor that converts decimals to binary
328+
ArrowFieldTypeConversion.ArrowFieldTypeVisitor customFieldTypeVisitor =
329+
new CustomDecimalArrowConversion.CustomArrowFieldTypeFactory();
330+
331+
// Create custom field writer factory visitor for decimal to binary conversion
332+
ArrowFieldWriterFactoryVisitor customFieldWriterVisitor =
333+
new CustomDecimalArrowConversion.CustomArrowFieldWriterFactory();
334+
335+
// Create custom vector converter visitor for binary to decimal conversion
336+
Arrow2PaimonVectorConverter.Arrow2PaimonVectorConvertorVisitor customConverterVisitor =
337+
new CustomDecimalArrowConversion.CustomArrow2PaimonVectorConvertorVisitor();
338+
339+
try (RootAllocator allocator = new RootAllocator()) {
340+
// Create writer with custom visitors
341+
try (ArrowFormatCWriter writer =
342+
new ArrowFormatCWriter(
343+
new ArrowFormatWriter(
344+
PRIMITIVE_TYPE,
345+
4096,
346+
true,
347+
allocator,
348+
null,
349+
customFieldTypeVisitor,
350+
customFieldWriterVisitor))) {
351+
writeAndCheckCustom(writer, customConverterVisitor);
352+
}
353+
}
354+
}
355+
322356
private void writeAndCheckArrayMap(ArrowFormatWriter arrowFormatWriter) {
323357
GenericRow genericRow = new GenericRow(1);
324358
Map<BinaryString, BinaryString> map = new HashMap<>();
@@ -452,6 +486,40 @@ private void writeAndCheck(ArrowFormatCWriter writer) {
452486
vectorSchemaRoot.close();
453487
}
454488

489+
private void writeAndCheckCustom(
490+
ArrowFormatCWriter writer,
491+
Arrow2PaimonVectorConverter.Arrow2PaimonVectorConvertorVisitor visitor) {
492+
List<InternalRow> list = new ArrayList<>();
493+
List<InternalRow.FieldGetter> fieldGetters = new ArrayList<>();
494+
495+
for (int i = 0; i < PRIMITIVE_TYPE.getFieldCount(); i++) {
496+
fieldGetters.add(InternalRow.createFieldGetter(PRIMITIVE_TYPE.getTypeAt(i), i));
497+
}
498+
for (int i = 0; i < 1000; i++) {
499+
list.add(GenericRow.of(randomRowValues(null)));
500+
}
501+
502+
list.forEach(writer::write);
503+
504+
writer.flush();
505+
VectorSchemaRoot vectorSchemaRoot = writer.getVectorSchemaRoot();
506+
507+
ArrowBatchReader arrowBatchReader = new ArrowBatchReader(PRIMITIVE_TYPE, true, visitor);
508+
Iterable<InternalRow> rows = arrowBatchReader.readBatch(vectorSchemaRoot);
509+
510+
Iterator<InternalRow> iterator = rows.iterator();
511+
for (int i = 0; i < 1000; i++) {
512+
InternalRow actual = iterator.next();
513+
InternalRow expectec = list.get(i);
514+
515+
for (InternalRow.FieldGetter fieldGetter : fieldGetters) {
516+
assertThat(fieldGetter.getFieldOrNull(actual))
517+
.isEqualTo(fieldGetter.getFieldOrNull(expectec));
518+
}
519+
}
520+
vectorSchemaRoot.close();
521+
}
522+
455523
private Object[] randomRowValues(boolean[] nullable) {
456524
Object[] values = new Object[18];
457525
values[0] = BinaryString.fromString(StringUtils.getRandomString(RND, 10, 10));

0 commit comments

Comments
 (0)