Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ public VariantField(DataField dataField, String path, VariantCastArgs castArgs)
this.castArgs = castArgs;
}

/**
 * Creates a variant field that extracts {@code path} into {@code dataField}, casting with
 * the default args (see {@link VariantCastArgs#defaultArgs()}).
 */
public VariantField(DataField dataField, String path) {
this(dataField, path, VariantCastArgs.defaultArgs());
}

/** Returns the data field (name and target type) this variant access produces. */
public DataField dataField() {
return dataField;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.apache.paimon.data.variant.PaimonShreddingUtils.METADATA_FIELD_NAME;
import static org.apache.paimon.data.variant.PaimonShreddingUtils.TYPED_VALUE_FIELD_NAME;
import static org.apache.paimon.data.variant.PaimonShreddingUtils.VARIANT_VALUE_FIELD_NAME;

/** Utils for variant access. */
public class VariantAccessInfoUtils {
Expand Down Expand Up @@ -61,4 +67,49 @@ public static RowType buildReadRowType(
}
return new RowType(fields);
}

/** Clip the variant schema to read with variant access fields. */
/**
 * Clips the variant shredding schema so that only the columns required by the given access
 * fields are read.
 *
 * <p>Shredded sub-fields of {@code typed_value} that are not referenced by any access path are
 * dropped; the raw {@code value} column is kept only when some referenced key is not shredded.
 * If any path cannot be mapped to a single top-level object key, the full schema is returned
 * unchanged.
 */
public static RowType clipVariantSchema(
        RowType shreddingSchema, List<VariantAccessInfo.VariantField> variantFields) {
    // Collect the top-level object keys referenced by the access paths. Any path that is
    // empty or does not start with an object extraction disables clipping entirely.
    // todo: support nested column pruning
    Set<String> requestedKeys = new HashSet<>();
    for (VariantAccessInfo.VariantField accessField : variantFields) {
        VariantPathSegment[] segments = VariantPathSegment.parse(accessField.path());
        if (segments.length < 1
                || !(segments[0] instanceof VariantPathSegment.ObjectExtraction)) {
            return shreddingSchema;
        }
        requestedKeys.add(((VariantPathSegment.ObjectExtraction) segments[0]).getKey());
    }

    // Keep only the requested sub-fields of `typed_value`, preserving their schema order.
    // Keys found here are removed so that only non-shredded keys remain afterwards.
    DataField typedValue = shreddingSchema.getField(TYPED_VALUE_FIELD_NAME);
    List<DataField> retainedTypedFields = new ArrayList<>();
    for (DataField field : ((RowType) typedValue.type()).getFields()) {
        if (requestedKeys.remove(field.name())) {
            retainedTypedFields.add(field);
        }
    }

    List<DataField> clippedFields = new ArrayList<>();
    clippedFields.add(shreddingSchema.getField(METADATA_FIELD_NAME));
    // Keys that are requested but not shredded must be served from the raw `value` column.
    if (!requestedKeys.isEmpty()) {
        clippedFields.add(shreddingSchema.getField(VARIANT_VALUE_FIELD_NAME));
    }
    clippedFields.add(typedValue.newType(new RowType(retainedTypedFields)));
    return new RowType(clippedFields);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.io.Serializable;
import java.time.ZoneId;
import java.time.ZoneOffset;

/** Several parameters used by `VariantGet.cast`. Packed together to simplify parameter passing. */
public class VariantCastArgs implements Serializable {
Expand All @@ -42,6 +43,10 @@ public ZoneId zoneId() {
return zoneId;
}

/**
 * Returns the default cast args: fail on cast errors ({@code failOnError=true}) and use the
 * UTC zone.
 */
public static VariantCastArgs defaultArgs() {
return new VariantCastArgs(true, ZoneOffset.UTC);
}

@Override
public String toString() {
return "VariantCastArgs{" + "failOnError=" + failOnError + ", zoneId=" + zoneId + '}';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,14 @@ private static ParquetField constructField(
if (type instanceof VariantType) {
if (shreddingSchema != null) {
VariantType variantType = (VariantType) type;
DataType clippedParquetType =
variantFields == null
? shreddingSchema
: VariantAccessInfoUtils.clipVariantSchema(
shreddingSchema, variantFields);
ParquetGroupField parquetField =
(ParquetGroupField)
constructField(dataField.newType(shreddingSchema), columnIO);
constructField(dataField.newType(clippedParquetType), columnIO);
DataType readType =
variantFields == null
? variantType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,22 @@

package org.apache.paimon.format.parquet;

import org.apache.paimon.data.BinaryString;
import org.apache.paimon.data.GenericRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.data.serializer.InternalRowSerializer;
import org.apache.paimon.data.variant.GenericVariant;
import org.apache.paimon.data.variant.VariantAccessInfo;
import org.apache.paimon.format.FileFormat;
import org.apache.paimon.format.FileFormatFactory;
import org.apache.paimon.format.FormatReadWriteTest;
import org.apache.paimon.format.FormatReaderContext;
import org.apache.paimon.format.FormatWriter;
import org.apache.paimon.format.FormatWriterFactory;
import org.apache.paimon.fs.PositionOutputStream;
import org.apache.paimon.options.Options;
import org.apache.paimon.reader.RecordReader;
import org.apache.paimon.types.DataField;
import org.apache.paimon.types.DataTypes;
import org.apache.paimon.types.RowType;

Expand All @@ -34,12 +43,16 @@
import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

import static org.assertj.core.api.Assertions.assertThat;

/** A parquet {@link FormatReadWriteTest}. */
public class ParquetFormatReadWriteTest extends FormatReadWriteTest {

Expand Down Expand Up @@ -88,4 +101,92 @@ public void testEnableBloomFilter(boolean enabled) throws Exception {
}
}
}

@Test
public void testReadShreddedVariant() throws Exception {
    // Shred variant column `v` into typed columns `age` (INT) and `city` (STRING).
    Options options = new Options();
    options.set(
            "parquet.variant.shreddingSchema",
            "{\"type\":\"ROW\",\"fields\":[{\"name\":\"v\",\"type\":{\"type\":\"ROW\",\"fields\":[{\"name\":\"age\",\"type\":\"INT\"},{\"name\":\"city\",\"type\":\"STRING\"}]}}]}");
    ParquetFileFormat format =
            new ParquetFileFormat(new FileFormatFactory.FormatContext(options, 1024, 1024));

    RowType writeType = DataTypes.ROW(DataTypes.FIELD(0, "v", DataTypes.VARIANT()));

    // Row 0 matches the shredding schema fully; row 1 has a key (`other`) that is not shredded.
    write(
            format.createWriterFactory(writeType),
            file,
            GenericRow.of(GenericVariant.fromJson("{\"age\":35,\"city\":\"Chicago\"}")),
            GenericRow.of(GenericVariant.fromJson("{\"age\":25,\"other\":\"Hello\"}")));

    // Phase 1: read the whole variant back without any pruning.
    List<InternalRow> fullRows = readRows(format, writeType, writeType, null);
    assertThat(fullRows.get(0).getVariant(0).toJson())
            .isEqualTo("{\"age\":35,\"city\":\"Chicago\"}");
    assertThat(fullRows.get(1).getVariant(0).toJson())
            .isEqualTo("{\"age\":25,\"other\":\"Hello\"}");

    // Phase 2: access a shredded (typed) key only -> `$.age`.
    List<VariantAccessInfo.VariantField> typedOnlyFields = new ArrayList<>();
    typedOnlyFields.add(
            new VariantAccessInfo.VariantField(
                    new DataField(0, "age", DataTypes.INT()), "$.age"));
    VariantAccessInfo[] typedOnlyAccess = {new VariantAccessInfo("v", typedOnlyFields)};
    RowType typedOnlyReadType =
            DataTypes.ROW(
                    DataTypes.FIELD(
                            0, "v", DataTypes.ROW(DataTypes.FIELD(0, "age", DataTypes.INT()))));
    List<InternalRow> typedOnlyRows = readRows(format, writeType, typedOnlyReadType, typedOnlyAccess);
    assertThat(typedOnlyRows.get(0).equals(GenericRow.of(GenericRow.of(35)))).isTrue();
    assertThat(typedOnlyRows.get(1).equals(GenericRow.of(GenericRow.of(25)))).isTrue();

    // Phase 3: access both a shredded key (`$.age`) and a non-shredded key (`$.other`).
    List<VariantAccessInfo.VariantField> mixedFields = new ArrayList<>();
    mixedFields.add(
            new VariantAccessInfo.VariantField(
                    new DataField(0, "age", DataTypes.INT()), "$.age"));
    mixedFields.add(
            new VariantAccessInfo.VariantField(
                    new DataField(1, "other", DataTypes.STRING()), "$.other"));
    VariantAccessInfo[] mixedAccess = {new VariantAccessInfo("v", mixedFields)};
    RowType mixedReadType =
            DataTypes.ROW(
                    DataTypes.FIELD(
                            0,
                            "v",
                            DataTypes.ROW(
                                    DataTypes.FIELD(0, "age", DataTypes.INT()),
                                    DataTypes.FIELD(1, "other", DataTypes.STRING()))));
    List<InternalRow> mixedRows = readRows(format, writeType, mixedReadType, mixedAccess);
    // Row 0 has no `other`, so that slot is null; row 1 serves it from the raw `value` column.
    assertThat(mixedRows.get(0).equals(GenericRow.of(GenericRow.of(35, null)))).isTrue();
    assertThat(
                    mixedRows.get(1)
                            .equals(
                                    GenericRow.of(
                                            GenericRow.of(
                                                    25, BinaryString.fromString("Hello")))))
            .isTrue();
}

/**
 * Reads every row of {@code file} with the given format, copying each row out with a
 * serializer of {@code serializerType}. When {@code variantAccess} is null the plain reader
 * factory (no variant pruning) is used.
 */
private List<InternalRow> readRows(
        ParquetFileFormat format,
        RowType writeType,
        RowType serializerType,
        VariantAccessInfo[] variantAccess)
        throws Exception {
    List<InternalRow> rows = new ArrayList<>();
    InternalRowSerializer serializer = new InternalRowSerializer(serializerType);
    try (RecordReader<InternalRow> reader =
            (variantAccess == null
                            ? format.createReaderFactory(
                                    writeType, writeType, new ArrayList<>())
                            : format.createReaderFactory(
                                    writeType, writeType, new ArrayList<>(), variantAccess))
                    .createReader(
                            new FormatReaderContext(fileIO, file, fileIO.getFileSize(file)))) {
        reader.forEachRemaining(row -> rows.add(serializer.copy(row)));
    }
    return rows;
}
}