diff --git a/src/iceberg/CMakeLists.txt b/src/iceberg/CMakeLists.txt index 27892935a..38771ac88 100644 --- a/src/iceberg/CMakeLists.txt +++ b/src/iceberg/CMakeLists.txt @@ -155,6 +155,7 @@ if(ICEBERG_BUILD_BUNDLE) arrow/arrow_fs_file_io.cc avro/avro_data_util.cc avro/avro_direct_decoder.cc + avro/avro_direct_encoder.cc avro/avro_reader.cc avro/avro_writer.cc avro/avro_register.cc diff --git a/src/iceberg/avro/avro_direct_decoder.cc b/src/iceberg/avro/avro_direct_decoder.cc index 60f79d218..3ab525d97 100644 --- a/src/iceberg/avro/avro_direct_decoder.cc +++ b/src/iceberg/avro/avro_direct_decoder.cc @@ -45,7 +45,7 @@ namespace { Status DecodeFieldToBuilder(const ::avro::NodePtr& avro_node, ::avro::Decoder& decoder, const FieldProjection& projection, const SchemaField& projected_field, - ::arrow::ArrayBuilder* array_builder, DecodeContext* ctx); + ::arrow::ArrayBuilder* array_builder, DecodeContext& ctx); /// \brief Skip an Avro value based on its schema without decoding Status SkipAvroValue(const ::avro::NodePtr& avro_node, ::avro::Decoder& decoder) { @@ -146,7 +146,7 @@ Status SkipAvroValue(const ::avro::NodePtr& avro_node, ::avro::Decoder& decoder) Status DecodeStructToBuilder(const ::avro::NodePtr& avro_node, ::avro::Decoder& decoder, const std::span& projections, const StructType& struct_type, - ::arrow::ArrayBuilder* array_builder, DecodeContext* ctx) { + ::arrow::ArrayBuilder* array_builder, DecodeContext& ctx) { if (avro_node->type() != ::avro::AVRO_RECORD) { return InvalidArgument("Expected Avro record, got type: {}", ToString(avro_node)); } @@ -157,15 +157,15 @@ Status DecodeStructToBuilder(const ::avro::NodePtr& avro_node, ::avro::Decoder& // Build a map from Avro field index to projection index (cached per struct schema) // -1 means the field should be skipped const FieldProjection* cache_key = projections.data(); - auto cache_it = ctx->avro_to_projection_cache.find(cache_key); + auto cache_it = ctx.avro_to_projection_cache.find(cache_key); std::vector* avro_to_projection; - if (cache_it != ctx->avro_to_projection_cache.end()) { + if (cache_it != ctx.avro_to_projection_cache.end()) { // Use cached mapping avro_to_projection = &cache_it->second; } else { // Build and cache the mapping - auto [inserted_it, inserted] = ctx->avro_to_projection_cache.emplace( + auto [inserted_it, inserted] = ctx.avro_to_projection_cache.emplace( cache_key, std::vector(avro_node->leaves(), -1)); avro_to_projection = &inserted_it->second; @@ -217,7 +217,7 @@ Status DecodeStructToBuilder(const ::avro::NodePtr& avro_node, ::avro::Decoder& Status DecodeListToBuilder(const ::avro::NodePtr& avro_node, ::avro::Decoder& decoder, const FieldProjection& element_projection, const ListType& list_type, - ::arrow::ArrayBuilder* array_builder, DecodeContext* ctx) { + ::arrow::ArrayBuilder* array_builder, DecodeContext& ctx) { if (avro_node->type() != ::avro::AVRO_ARRAY) { return InvalidArgument("Expected Avro array, got type: {}", ToString(avro_node)); } @@ -247,7 +247,7 @@ Status DecodeMapToBuilder(const ::avro::NodePtr& avro_node, ::avro::Decoder& dec const FieldProjection& key_projection, const FieldProjection& value_projection, const MapType& map_type, ::arrow::ArrayBuilder* array_builder, - DecodeContext* ctx) { + DecodeContext& ctx) { auto* map_builder = internal::checked_cast<::arrow::MapBuilder*>(array_builder); if (avro_node->type() == ::avro::AVRO_MAP) { @@ -317,7 +317,7 @@ Status DecodeNestedValueToBuilder(const ::avro::NodePtr& avro_node, const std::span& projections, const NestedType& projected_type, ::arrow::ArrayBuilder* array_builder, - DecodeContext* ctx) { + DecodeContext& ctx) { switch (projected_type.type_id()) { case TypeId::kStruct: { const auto& struct_type = internal::checked_cast(projected_type); @@ -354,7 +354,7 @@ Status DecodePrimitiveValueToBuilder(const ::avro::NodePtr& avro_node, ::avro::Decoder& decoder, const SchemaField& projected_field, ::arrow::ArrayBuilder* array_builder, - DecodeContext* ctx) { + DecodeContext& ctx) { const auto& projected_type = *projected_field.type(); if (!projected_type.is_primitive()) { return InvalidArgument("Expected primitive type, got: {}", projected_type.ToString()); @@ -430,8 +430,8 @@ Status DecodePrimitiveValueToBuilder(const ::avro::NodePtr& avro_node, ToString(avro_node)); } auto* builder = internal::checked_cast<::arrow::StringBuilder*>(array_builder); - decoder.decodeString(ctx->string_scratch); - ICEBERG_ARROW_RETURN_NOT_OK(builder->Append(ctx->string_scratch)); + decoder.decodeString(ctx.string_scratch); + ICEBERG_ARROW_RETURN_NOT_OK(builder->Append(ctx.string_scratch)); return {}; } @@ -441,9 +441,9 @@ Status DecodePrimitiveValueToBuilder(const ::avro::NodePtr& avro_node, ToString(avro_node)); } auto* builder = internal::checked_cast<::arrow::BinaryBuilder*>(array_builder); - decoder.decodeBytes(ctx->bytes_scratch); + decoder.decodeBytes(ctx.bytes_scratch); ICEBERG_ARROW_RETURN_NOT_OK(builder->Append( - ctx->bytes_scratch.data(), static_cast(ctx->bytes_scratch.size()))); + ctx.bytes_scratch.data(), static_cast(ctx.bytes_scratch.size()))); return {}; } @@ -456,9 +456,9 @@ Status DecodePrimitiveValueToBuilder(const ::avro::NodePtr& avro_node, auto* builder = internal::checked_cast<::arrow::FixedSizeBinaryBuilder*>(array_builder); - ctx->bytes_scratch.resize(fixed_type.length()); - decoder.decodeFixed(fixed_type.length(), ctx->bytes_scratch); - ICEBERG_ARROW_RETURN_NOT_OK(builder->Append(ctx->bytes_scratch.data())); + ctx.bytes_scratch.resize(fixed_type.length()); + decoder.decodeFixed(fixed_type.length(), ctx.bytes_scratch); + ICEBERG_ARROW_RETURN_NOT_OK(builder->Append(ctx.bytes_scratch.data())); return {}; } @@ -472,9 +472,9 @@ Status DecodePrimitiveValueToBuilder(const ::avro::NodePtr& avro_node, auto* builder = internal::checked_cast<::arrow::FixedSizeBinaryBuilder*>(array_builder); - ctx->bytes_scratch.resize(16); - decoder.decodeFixed(16, ctx->bytes_scratch); - ICEBERG_ARROW_RETURN_NOT_OK(builder->Append(ctx->bytes_scratch.data())); + ctx.bytes_scratch.resize(16); + decoder.decodeFixed(16, ctx.bytes_scratch); + ICEBERG_ARROW_RETURN_NOT_OK(builder->Append(ctx.bytes_scratch.data())); return {}; } @@ -489,11 +489,11 @@ Status DecodePrimitiveValueToBuilder(const ::avro::NodePtr& avro_node, size_t byte_width = avro_node->fixedSize(); auto* builder = internal::checked_cast<::arrow::Decimal128Builder*>(array_builder); - ctx->bytes_scratch.resize(byte_width); - decoder.decodeFixed(byte_width, ctx->bytes_scratch); + ctx.bytes_scratch.resize(byte_width); + decoder.decodeFixed(byte_width, ctx.bytes_scratch); ICEBERG_ARROW_ASSIGN_OR_RETURN( - auto decimal, ::arrow::Decimal128::FromBigEndian(ctx->bytes_scratch.data(), - ctx->bytes_scratch.size())); + auto decimal, ::arrow::Decimal128::FromBigEndian(ctx.bytes_scratch.data(), + ctx.bytes_scratch.size())); ICEBERG_ARROW_RETURN_NOT_OK(builder->Append(decimal)); return {}; } @@ -548,7 +548,7 @@ Status DecodePrimitiveValueToBuilder(const ::avro::NodePtr& avro_node, Status DecodeFieldToBuilder(const ::avro::NodePtr& avro_node, ::avro::Decoder& decoder, const FieldProjection& projection, const SchemaField& projected_field, - ::arrow::ArrayBuilder* array_builder, DecodeContext* ctx) { + ::arrow::ArrayBuilder* array_builder, DecodeContext& ctx) { if (avro_node->type() == ::avro::AVRO_UNION) { const size_t branch_index = decoder.decodeUnionIndex(); @@ -585,7 +585,7 @@ Status DecodeFieldToBuilder(const ::avro::NodePtr& avro_node, ::avro::Decoder& d Status DecodeAvroToBuilder(const ::avro::NodePtr& avro_node, ::avro::Decoder& decoder, const SchemaProjection& projection, const Schema& projected_schema, - ::arrow::ArrayBuilder* array_builder, DecodeContext* ctx) { + ::arrow::ArrayBuilder* array_builder, DecodeContext& ctx) { return DecodeNestedValueToBuilder(avro_node, decoder, projection.fields, projected_schema, array_builder, ctx); } diff --git a/src/iceberg/avro/avro_direct_decoder_internal.h b/src/iceberg/avro/avro_direct_decoder_internal.h index df4587fd0..5a2cf2240 100644 --- a/src/iceberg/avro/avro_direct_decoder_internal.h +++ b/src/iceberg/avro/avro_direct_decoder_internal.h @@ -82,6 +82,6 @@ struct DecodeContext { Status DecodeAvroToBuilder(const ::avro::NodePtr& avro_node, ::avro::Decoder& decoder, const SchemaProjection& projection, const Schema& projected_schema, - ::arrow::ArrayBuilder* array_builder, DecodeContext* ctx); + ::arrow::ArrayBuilder* array_builder, DecodeContext& ctx); } // namespace iceberg::avro diff --git a/src/iceberg/avro/avro_direct_encoder.cc b/src/iceberg/avro/avro_direct_encoder.cc new file mode 100644 index 000000000..0b500a893 --- /dev/null +++ b/src/iceberg/avro/avro_direct_encoder.cc @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include +#include +#include + +#include "iceberg/avro/avro_constants.h" +#include "iceberg/avro/avro_direct_encoder_internal.h" +#include "iceberg/util/checked_cast.h" + +namespace iceberg::avro { + +namespace { + +// Helper to validate union structure and get branch indices +// Returns {null_index, value_index, value_node} +struct UnionBranches { + size_t null_index; + size_t value_index; + ::avro::NodePtr value_node; +}; + +Result ValidateUnion(const ::avro::NodePtr& union_node) { + if (union_node->leaves() != 2) { + return InvalidArgument("Union must have exactly 2 branches, got {}", + union_node->leaves()); + } + + const auto& branch_0 = union_node->leafAt(0); + const auto& branch_1 = union_node->leafAt(1); + + if (branch_0->type() == ::avro::AVRO_NULL && branch_1->type() != ::avro::AVRO_NULL) { + return UnionBranches{.null_index = 0, .value_index = 1, .value_node = branch_1}; + } else if (branch_1->type() == ::avro::AVRO_NULL && + branch_0->type() != ::avro::AVRO_NULL) { + return UnionBranches{.null_index = 1, .value_index = 0, .value_node = branch_0}; + } else { + return InvalidArgument("Union must have exactly one null branch"); + } +} + +} // namespace + +Status EncodeArrowToAvro(const ::avro::NodePtr& avro_node, ::avro::Encoder& encoder, + const Type& type, const ::arrow::Array& array, int64_t row_index, + EncodeContext& ctx) { + if (row_index < 0 || row_index >= array.length()) { + return InvalidArgument("Row index {} out of bounds for array of length {}", row_index, + array.length()); + } + + const bool is_null = array.IsNull(row_index); + + // Handle unions (optional fields) + if (avro_node->type() == ::avro::AVRO_UNION) { + ICEBERG_ASSIGN_OR_RAISE(auto branches, ValidateUnion(avro_node)); + + if (is_null) { + encoder.encodeUnionIndex(branches.null_index); + encoder.encodeNull(); + return {}; + } else { + encoder.encodeUnionIndex(branches.value_index); + // Continue with the value branch + return EncodeArrowToAvro(branches.value_node, encoder, type, array, row_index, ctx); + } + } + + // Non-union null handling + if (is_null) { + return InvalidArgument("Null value in non-nullable field"); + } + + // Encode based on Avro type + switch (avro_node->type()) { + case ::avro::AVRO_NULL: + encoder.encodeNull(); + return {}; + + case ::avro::AVRO_BOOL: { + const auto& bool_array = + internal::checked_cast(array); + encoder.encodeBool(bool_array.Value(row_index)); + return {}; + } + + case ::avro::AVRO_INT: { + // AVRO_INT can represent: int32, date (days since epoch) + switch (array.type()->id()) { + case ::arrow::Type::INT32: { + const auto& int32_array = + internal::checked_cast(array); + encoder.encodeInt(int32_array.Value(row_index)); + return {}; + } + case ::arrow::Type::DATE32: { + const auto& date_array = + internal::checked_cast(array); + encoder.encodeInt(date_array.Value(row_index)); + return {}; + } + default: + return InvalidArgument("AVRO_INT expects Int32Array or Date32Array, got {}", + array.type()->ToString()); + } + } + + case ::avro::AVRO_LONG: { + // AVRO_LONG can represent: int64, time (microseconds), timestamp (microseconds) + switch (array.type()->id()) { + case ::arrow::Type::INT64: { + const auto& int64_array = + internal::checked_cast(array); + encoder.encodeLong(int64_array.Value(row_index)); + return {}; + } + case ::arrow::Type::TIME64: { + const auto& time_array = + internal::checked_cast(array); + encoder.encodeLong(time_array.Value(row_index)); + return {}; + } + case ::arrow::Type::TIMESTAMP: { + const auto& timestamp_array = + internal::checked_cast(array); + encoder.encodeLong(timestamp_array.Value(row_index)); + return {}; + } + default: + return InvalidArgument( + "AVRO_LONG expects Int64Array, Time64Array, or TimestampArray, got {}", + array.type()->ToString()); + } + } + + case ::avro::AVRO_FLOAT: { + const auto& float_array = internal::checked_cast(array); + encoder.encodeFloat(float_array.Value(row_index)); + return {}; + } + + case ::avro::AVRO_DOUBLE: { + const auto& double_array = + internal::checked_cast(array); + encoder.encodeDouble(double_array.Value(row_index)); + return {}; + } + + case ::avro::AVRO_STRING: { + const auto& string_array = + internal::checked_cast(array); + std::string_view value = string_array.GetView(row_index); + encoder.encodeString(std::string(value)); + return {}; + } + + case ::avro::AVRO_BYTES: { + const auto& binary_array = + internal::checked_cast(array); + std::string_view value = binary_array.GetView(row_index); + ctx.bytes_scratch.assign(value.begin(), value.end()); + encoder.encodeBytes(ctx.bytes_scratch); + return {}; + } + + case ::avro::AVRO_FIXED: { + // Handle UUID + if (avro_node->logicalType().type() == ::avro::LogicalType::UUID) { + const auto& extension_array = + internal::checked_cast(array); + const auto& fixed_array = + internal::checked_cast( + *extension_array.storage()); + std::string_view value = fixed_array.GetView(row_index); + ctx.bytes_scratch.assign(value.begin(), value.end()); + encoder.encodeFixed(ctx.bytes_scratch.data(), ctx.bytes_scratch.size()); + return {}; + } + + // Handle DECIMAL + if (avro_node->logicalType().type() == ::avro::LogicalType::DECIMAL) { + const auto& decimal_array = + internal::checked_cast(array); + std::string_view decimal_value = decimal_array.GetView(row_index); + ctx.bytes_scratch.assign(decimal_value.begin(), decimal_value.end()); + // Arrow Decimal128 bytes are in little-endian order, Avro requires big-endian + std::ranges::reverse(ctx.bytes_scratch); + encoder.encodeFixed(ctx.bytes_scratch.data(), ctx.bytes_scratch.size()); + return {}; + } + + // Handle regular FIXED + const auto& fixed_array = + internal::checked_cast(array); + std::string_view value = fixed_array.GetView(row_index); + ctx.bytes_scratch.assign(value.begin(), value.end()); + encoder.encodeFixed(ctx.bytes_scratch.data(), ctx.bytes_scratch.size()); + return {}; + } + + case ::avro::AVRO_RECORD: { + if (array.type()->id() != ::arrow::Type::STRUCT) { + return InvalidArgument("AVRO_RECORD expects StructArray, got {}", + array.type()->ToString()); + } + if (!type.is_nested()) { + return InvalidArgument("AVRO_RECORD expects nested type, got type {}", + type.ToString()); + } + + const auto& struct_array = + internal::checked_cast(array); + + // AVRO_RECORD corresponds to Iceberg StructType (including Schema which extends + // StructType). Note: ListType and MapType are encoded as AVRO_ARRAY and AVRO_MAP + // respectively, not AVRO_RECORD. + if (type.type_id() != TypeId::kStruct) { + return InvalidArgument("AVRO_RECORD expects StructType, got type {}", + type.ToString()); + } + + // Safe cast: type_id() == kStruct guarantees this is StructType or Schema + // (Schema extends StructType) + const auto& struct_type = static_cast(type); + const size_t num_fields = avro_node->leaves(); + + // Validate field count matches + if (struct_array.num_fields() != static_cast(num_fields)) { + return InvalidArgument( + "Field count mismatch: Arrow struct has {} fields, Avro node has {} fields", + struct_array.num_fields(), num_fields); + } + if (struct_type.fields().size() != num_fields) { + return InvalidArgument( + "Field count mismatch: Iceberg struct has {} fields, Avro node has {} fields", + struct_type.fields().size(), num_fields); + } + + for (size_t i = 0; i < num_fields; ++i) { + const auto& field_node = avro_node->leafAt(i); + const auto& field_array = struct_array.field(static_cast(i)); + const auto& field_schema = struct_type.fields()[i]; + + ICEBERG_RETURN_UNEXPECTED(EncodeArrowToAvro( + field_node, encoder, *field_schema.type(), *field_array, row_index, ctx)); + } + return {}; + } + + case ::avro::AVRO_ARRAY: { + // AVRO_ARRAY can represent either: + // 1. Iceberg ListType -> Arrow ListArray + // 2. Iceberg MapType with non-string keys -> Arrow MapArray (converted to array of + // records) + + const auto& element_node = avro_node->leafAt(0); + + // Try ListArray first (most common case) + if (array.type()->id() == ::arrow::Type::LIST) { + const auto& list_array = internal::checked_cast(array); + const auto& list_type = static_cast(type); + + const auto start = list_array.value_offset(row_index); + const auto end = list_array.value_offset(row_index + 1); + const auto length = end - start; + + encoder.arrayStart(); + if (length > 0) { + encoder.setItemCount(length); + const auto& values = list_array.values(); + const auto& element_type = *list_type.fields()[0].type(); + + for (int64_t i = start; i < end; ++i) { + encoder.startItem(); + ICEBERG_RETURN_UNEXPECTED( + EncodeArrowToAvro(element_node, encoder, element_type, *values, i, ctx)); + } + } + encoder.arrayEnd(); + return {}; + } + + // Handle MapArray (for maps with non-string keys, represented as array of key-value + // records in Avro) + if (array.type()->id() == ::arrow::Type::MAP) { + const auto& map_array = internal::checked_cast(array); + const auto& map_type = static_cast(type); + + const auto start = map_array.value_offset(row_index); + const auto end = map_array.value_offset(row_index + 1); + const auto length = end - start; + + encoder.arrayStart(); + if (length > 0) { + encoder.setItemCount(length); + const auto& keys = map_array.keys(); + const auto& values = map_array.items(); + const auto& key_type = *map_type.key().type(); + const auto& value_type = *map_type.value().type(); + + // The element_node should be a RECORD with "key" and "value" fields + for (int64_t i = start; i < end; ++i) { + encoder.startItem(); + // Encode the key-value pair as a record + if (element_node->type() != ::avro::AVRO_RECORD || + element_node->leaves() != 2) { + return InvalidArgument( + "Expected AVRO_RECORD with 2 fields for map key-value pair"); + } + + // Assumption: key is always at index 0, value at index 1 + // This matches the schema generation in ToAvroNodeVisitor::Visit(const + // MapType&) + const auto& key_node = element_node->leafAt(0); + const auto& value_node = element_node->leafAt(1); + + // Encode key + ICEBERG_RETURN_UNEXPECTED( + EncodeArrowToAvro(key_node, encoder, key_type, *keys, i, ctx)); + // Encode value + ICEBERG_RETURN_UNEXPECTED( + EncodeArrowToAvro(value_node, encoder, value_type, *values, i, ctx)); + } + } + encoder.arrayEnd(); + return {}; + } + + return InvalidArgument("AVRO_ARRAY must map to ListArray or MapArray, got {}", + array.type()->ToString()); + } + + case ::avro::AVRO_MAP: { + // AVRO_MAP is for maps with string keys + // Arrow represents this as MapArray + if (array.type()->id() != ::arrow::Type::MAP) { + return InvalidArgument("AVRO_MAP expects MapArray, got {}", + array.type()->ToString()); + } + if (type.type_id() != TypeId::kMap) { + return InvalidArgument("AVRO_MAP expects MapType, got type {}", type.ToString()); + } + + const auto& map_array = internal::checked_cast(array); + const auto& map_type = static_cast(type); + + const auto start = map_array.value_offset(row_index); + const auto end = map_array.value_offset(row_index + 1); + const auto length = end - start; + + encoder.mapStart(); + if (length > 0) { + encoder.setItemCount(length); + const auto& keys = map_array.keys(); + const auto& values = map_array.items(); + const auto& value_type = *map_type.value().type(); + // In Avro maps, leafAt(0) is the key type (always string), leafAt(1) is the value + // type + const auto& value_node = avro_node->leafAt(1); + + // Validate keys are strings + if (keys->type()->id() != ::arrow::Type::STRING && + keys->type()->id() != ::arrow::Type::LARGE_STRING) { + return InvalidArgument("AVRO_MAP keys must be StringArray, got {}", + keys->type()->ToString()); + } + + for (int64_t i = start; i < end; ++i) { + encoder.startItem(); + // Encode key (must be string in Avro maps) + if (keys->type()->id() == ::arrow::Type::STRING) { + const auto& string_array = + internal::checked_cast(*keys); + std::string_view key_value = string_array.GetView(i); + encoder.encodeString(std::string(key_value)); + } else { + const auto& large_string_array = + internal::checked_cast(*keys); + std::string_view key_value = large_string_array.GetView(i); + encoder.encodeString(std::string(key_value)); + } + + // Encode value + ICEBERG_RETURN_UNEXPECTED( + EncodeArrowToAvro(value_node, encoder, value_type, *values, i, ctx)); + } + } + encoder.mapEnd(); + return {}; + } + + case ::avro::AVRO_ENUM: + return NotSupported("ENUM type encoding not yet implemented"); + + case ::avro::AVRO_UNION: + // Already handled above + return InvalidArgument("Unexpected union handling"); + + default: + return NotSupported("Unsupported Avro type: {}", + ::avro::toString(avro_node->type())); + } +} + +} // namespace iceberg::avro diff --git a/src/iceberg/avro/avro_direct_encoder_internal.h b/src/iceberg/avro/avro_direct_encoder_internal.h new file mode 100644 index 000000000..82da51934 --- /dev/null +++ b/src/iceberg/avro/avro_direct_encoder_internal.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +#include +#include +#include + +#include "iceberg/result.h" +#include "iceberg/schema.h" + +namespace iceberg::avro { + +/// \brief Context for reusing scratch buffers during Avro encoding +/// +/// Avoids frequent small allocations by reusing temporary buffers across +/// multiple encode operations. This is particularly important for string, +/// binary, and fixed-size data types. +struct EncodeContext { + // Scratch buffer for binary/fixed/uuid/decimal data (reused across rows) + std::vector bytes_scratch; +}; + +/// \brief Directly encode Arrow data to Avro without GenericDatum +/// +/// Eliminates the GenericDatum intermediate layer by directly calling Avro encoder +/// methods from Arrow arrays. +/// +/// \param avro_node The Avro schema node for the data being encoded +/// \param encoder The Avro encoder to write data to +/// \param type The Iceberg type for the data +/// \param array The Arrow array containing the data to encode +/// \param row_index The index of the row to encode within the array +/// \param ctx Encode context for reusing scratch buffers +/// \return Status::OK if successful, or an error status +Status EncodeArrowToAvro(const ::avro::NodePtr& avro_node, ::avro::Encoder& encoder, + const Type& type, const ::arrow::Array& array, int64_t row_index, + EncodeContext& ctx); + +} // namespace iceberg::avro diff --git a/src/iceberg/avro/avro_reader.cc b/src/iceberg/avro/avro_reader.cc index 106e38655..964f6d1d4 100644 --- a/src/iceberg/avro/avro_reader.cc +++ b/src/iceberg/avro/avro_reader.cc @@ -173,7 +173,7 @@ class AvroReader::Impl { ICEBERG_RETURN_UNEXPECTED(DecodeAvroToBuilder( GetReaderSchema().root(), base_reader_->decoder(), projection_, *read_schema_, - context_->builder_.get(), &context_->decode_context_)); + context_->builder_.get(), context_->decode_context_)); } else { // GenericDatum-based decoding: decode via GenericDatum intermediate if (!datum_reader_->read(*context_->datum_)) { diff --git a/src/iceberg/avro/avro_writer.cc b/src/iceberg/avro/avro_writer.cc index 9d65db15c..a764b2670 100644 --- a/src/iceberg/avro/avro_writer.cc +++ b/src/iceberg/avro/avro_writer.cc @@ -32,6 +32,7 @@ #include "iceberg/arrow/arrow_fs_file_io_internal.h" #include "iceberg/arrow/arrow_status_internal.h" #include "iceberg/avro/avro_data_util_internal.h" +#include "iceberg/avro/avro_direct_encoder_internal.h" #include "iceberg/avro/avro_register.h" #include "iceberg/avro/avro_schema_util_internal.h" #include "iceberg/avro/avro_stream_internal.h" @@ -63,6 +64,7 @@ class AvroWriter::Impl { Status Open(const WriterOptions& options) { write_schema_ = options.schema; + use_direct_encoder_ = options.properties->Get(WriterProperties::kAvroSkipDatum); ::avro::NodePtr root; ICEBERG_RETURN_UNEXPECTED(ToAvroNodeVisitor{}.Visit(*write_schema_, &root)); @@ -87,11 +89,23 @@ class AvroWriter::Impl { vec.assign(value.begin(), value.end()); metadata.emplace(key, std::move(vec)); } - writer_ = std::make_unique<::avro::DataFileWriter<::avro::GenericDatum>>( - std::move(output_stream), *avro_schema_, - options.properties->Get(WriterProperties::kAvroSyncInterval), - ::avro::NULL_CODEC /*codec*/, metadata); - datum_ = std::make_unique<::avro::GenericDatum>(*avro_schema_); + + if (use_direct_encoder_) { + // Use direct encoder (faster path) + writer_base_ = std::make_unique<::avro::DataFileWriterBase>( + std::move(output_stream), *avro_schema_, + options.properties->Get(WriterProperties::kAvroSyncInterval), + ::avro::NULL_CODEC /*codec*/, metadata); + avro_root_node_ = avro_schema_->root(); + } else { + // Use GenericDatum (legacy path) + writer_datum_ = std::make_unique<::avro::DataFileWriter<::avro::GenericDatum>>( + std::move(output_stream), *avro_schema_, + options.properties->Get(WriterProperties::kAvroSyncInterval), + ::avro::NULL_CODEC /*codec*/, metadata); + datum_ = std::make_unique<::avro::GenericDatum>(*avro_schema_); + } + ICEBERG_RETURN_UNEXPECTED(ToArrowSchema(*write_schema_, &arrow_schema_)); return {}; } @@ -100,25 +114,47 @@ class AvroWriter::Impl { ICEBERG_ARROW_ASSIGN_OR_RETURN(auto result, ::arrow::ImportArray(data, &arrow_schema_)); - for (int64_t i = 0; i < result->length(); i++) { - ICEBERG_RETURN_UNEXPECTED(ExtractDatumFromArray(*result, i, datum_.get())); - writer_->write(*datum_); + if (use_direct_encoder_) { + // Direct encoder path + for (int64_t i = 0; i < result->length(); i++) { + ICEBERG_RETURN_UNEXPECTED( + EncodeArrowToAvro(avro_root_node_, writer_base_->encoder(), *write_schema_, + *result, i, encode_ctx_)); + writer_base_->incr(); + } + } else { + // GenericDatum path + for (int64_t i = 0; i < result->length(); i++) { + ICEBERG_RETURN_UNEXPECTED(ExtractDatumFromArray(*result, i, datum_.get())); + writer_datum_->write(*datum_); + } } return {}; } Status Close() { - if (writer_ != nullptr) { - writer_->close(); - writer_.reset(); - ICEBERG_ARROW_ASSIGN_OR_RETURN(total_bytes_, arrow_output_stream_->Tell()); - ICEBERG_ARROW_RETURN_NOT_OK(arrow_output_stream_->Close()); + if (use_direct_encoder_) { + if (writer_base_ != nullptr) { + writer_base_->close(); + writer_base_.reset(); + ICEBERG_ARROW_ASSIGN_OR_RETURN(total_bytes_, arrow_output_stream_->Tell()); + ICEBERG_ARROW_RETURN_NOT_OK(arrow_output_stream_->Close()); + } + } else { + if (writer_datum_ != nullptr) { + writer_datum_->close(); + writer_datum_.reset(); + ICEBERG_ARROW_ASSIGN_OR_RETURN(total_bytes_, arrow_output_stream_->Tell()); + ICEBERG_ARROW_RETURN_NOT_OK(arrow_output_stream_->Close()); + } } return {}; } - bool Closed() const { return writer_ == nullptr; } + bool Closed() const { + return use_direct_encoder_ ? writer_base_ == nullptr : writer_datum_ == nullptr; + } Result length() { if (Closed()) { @@ -136,14 +172,29 @@ class AvroWriter::Impl { std::shared_ptr<::avro::ValidSchema> avro_schema_; // Arrow output stream of the Avro file to write std::shared_ptr<::arrow::io::OutputStream> arrow_output_stream_; - // The avro writer to write the data into a datum. - std::unique_ptr<::avro::DataFileWriter<::avro::GenericDatum>> writer_; - // Reusable Avro datum for writing individual records. - std::unique_ptr<::avro::GenericDatum> datum_; // Arrow schema to write data. ArrowSchema arrow_schema_; // Total length of the written Avro file. int64_t total_bytes_ = 0; + + // Flag to determine which encoder to use + bool use_direct_encoder_ = true; + + // Direct encoder path (fast) + // Root node of the Avro schema (only used if direct encoder is enabled) + ::avro::NodePtr avro_root_node_; + // The avro writer using direct encoder (only used if direct encoder is enabled) + std::unique_ptr<::avro::DataFileWriterBase> writer_base_; + // Encode context for reusing scratch buffers (only used if direct encoder is enabled) + EncodeContext encode_ctx_; + + // GenericDatum path (legacy) + // The avro writer to write the data into a datum (only used if direct encoder is + // disabled) + std::unique_ptr<::avro::DataFileWriter<::avro::GenericDatum>> writer_datum_; + // Reusable Avro datum for writing individual records (only used if direct encoder is + // disabled) + std::unique_ptr<::avro::GenericDatum> datum_; }; AvroWriter::~AvroWriter() = default; diff --git a/src/iceberg/file_writer.h b/src/iceberg/file_writer.h index f3540dd75..87a771737 100644 --- a/src/iceberg/file_writer.h +++ b/src/iceberg/file_writer.h @@ -49,6 +49,10 @@ class WriterProperties : public ConfigBase { /// \brief The sync interval used by Avro writer. inline static Entry kAvroSyncInterval{"write.avro.sync-interval", 16 * 1024}; + /// \brief Whether to skip GenericDatum and use direct encoder for Avro writing. + /// When true, uses direct encoder (faster). When false, uses GenericDatum. + inline static Entry kAvroSkipDatum{"write.avro.skip-datum", true}; + /// TODO(gangwu): add more properties, like compression codec, compression level, etc. /// \brief Create a default WriterProperties instance. diff --git a/src/iceberg/test/avro_test.cc b/src/iceberg/test/avro_test.cc index 215462b5d..ada4d7141 100644 --- a/src/iceberg/test/avro_test.cc +++ b/src/iceberg/test/avro_test.cc @@ -503,6 +503,448 @@ INSTANTIATE_TEST_SUITE_P(DirectDecoderModes, AvroReaderParameterizedTest, return info.param ? "DirectDecoder" : "GenericDatum"; }); +// ==================================================================================== +// Dedicated Writer Tests - Verify encoder output directly using Avro library +// ==================================================================================== + +class AvroWriterTest : public TempFileTestBase { + protected: + static void SetUpTestSuite() { RegisterAll(); } + + void SetUp() override { + TempFileTestBase::SetUp(); + local_fs_ = std::make_shared<::arrow::fs::LocalFileSystem>(); + file_io_ = std::make_shared(local_fs_); + temp_avro_file_ = CreateNewTempFilePathWithSuffix(".avro"); + } + + // Helper to write Arrow data to Avro file + void WriteAvroFile(std::shared_ptr schema, const std::string& json_data) { + ArrowSchema arrow_c_schema; + ASSERT_THAT(ToArrowSchema(*schema, &arrow_c_schema), IsOk()); + + auto arrow_schema_result = ::arrow::ImportType(&arrow_c_schema); + ASSERT_TRUE(arrow_schema_result.ok()); + auto arrow_schema = arrow_schema_result.ValueOrDie(); + + auto array_result = ::arrow::json::ArrayFromJSONString(arrow_schema, json_data); + ASSERT_TRUE(array_result.ok()); + auto array = array_result.ValueOrDie(); + + struct ArrowArray arrow_array; + auto export_result = ::arrow::ExportArray(*array, &arrow_array); + ASSERT_TRUE(export_result.ok()); + + std::unordered_map metadata = { + {"writer_test", "direct_encoder"}}; + + auto writer_properties = WriterProperties::default_properties(); + writer_properties->Set(WriterProperties::kAvroSkipDatum, skip_datum_); + + auto writer_result = WriterFactoryRegistry::Open( + FileFormatType::kAvro, {.path = temp_avro_file_, + .schema = schema, + .io = file_io_, + .metadata = metadata, + .properties = std::move(writer_properties)}); + ASSERT_TRUE(writer_result.has_value()); + auto writer = std::move(writer_result.value()); + ASSERT_THAT(writer->Write(&arrow_array), IsOk()); + ASSERT_THAT(writer->Close(), IsOk()); + } + + // Helper to read raw Avro file and verify using Avro GenericDatum + template + void VerifyAvroFileContent(VerifyFunc verify_func) { + ::avro::DataFileReader<::avro::GenericDatum> reader(temp_avro_file_.c_str()); + + // Create datum with the schema from the file + ::avro::GenericDatum datum(reader.dataSchema()); + + size_t row_count = 0; + while (reader.read(datum)) { + verify_func(datum, row_count); + row_count++; + } + reader.close(); + } + + std::shared_ptr<::arrow::fs::LocalFileSystem> local_fs_; + std::shared_ptr file_io_; + std::string temp_avro_file_; + bool skip_datum_{true}; +}; + +// Parameterized test fixture for testing both direct encoder and GenericDatum modes +class AvroWriterParameterizedTest : public AvroWriterTest, + public ::testing::WithParamInterface { + protected: + void SetUp() override { + AvroWriterTest::SetUp(); + skip_datum_ = GetParam(); + } +}; + +TEST_P(AvroWriterParameterizedTest, WritePrimitiveTypes) { + auto schema = std::make_shared(std::vector{ + SchemaField::MakeRequired(1, "bool_col", std::make_shared()), + SchemaField::MakeRequired(2, "int_col", std::make_shared()), + SchemaField::MakeRequired(3, "long_col", std::make_shared()), + SchemaField::MakeRequired(4, "float_col", std::make_shared()), + SchemaField::MakeRequired(5, "double_col", std::make_shared()), + SchemaField::MakeRequired(6, "string_col", std::make_shared())}); + + std::string test_data = R"([ + [true, 42, 1234567890, 3.14, 2.71828, "hello"], + [false, -100, -9876543210, -1.5, 0.0, "world"] + ])"; + + WriteAvroFile(schema, test_data); + + VerifyAvroFileContent([](const ::avro::GenericDatum& datum, size_t row_idx) { + ASSERT_EQ(datum.type(), ::avro::AVRO_RECORD); + const auto& record = datum.value<::avro::GenericRecord>(); + ASSERT_EQ(record.fieldCount(), 6); + + if (row_idx == 0) { + EXPECT_TRUE(record.fieldAt(0).value()); + EXPECT_EQ(record.fieldAt(1).value(), 42); + EXPECT_EQ(record.fieldAt(2).value(), 1234567890); + EXPECT_FLOAT_EQ(record.fieldAt(3).value(), 3.14f); + EXPECT_DOUBLE_EQ(record.fieldAt(4).value(), 2.71828); + EXPECT_EQ(record.fieldAt(5).value(), "hello"); + } else if (row_idx == 1) { + EXPECT_FALSE(record.fieldAt(0).value()); + EXPECT_EQ(record.fieldAt(1).value(), -100); + EXPECT_EQ(record.fieldAt(2).value(), -9876543210); + EXPECT_FLOAT_EQ(record.fieldAt(3).value(), -1.5f); + EXPECT_DOUBLE_EQ(record.fieldAt(4).value(), 0.0); + EXPECT_EQ(record.fieldAt(5).value(), "world"); + } + }); +} + +TEST_P(AvroWriterParameterizedTest, WriteTemporalTypes) { + auto schema = std::make_shared(std::vector{ + SchemaField::MakeRequired(1, "date_col", std::make_shared()), + SchemaField::MakeRequired(2, "time_col", std::make_shared()), + SchemaField::MakeRequired(3, "timestamp_col", std::make_shared())}); + + std::string test_data = R"([ + [18628, 43200000000, 1640995200000000], + [18629, 86399000000, 1641081599000000] + ])"; + + WriteAvroFile(schema, test_data); + + VerifyAvroFileContent([](const ::avro::GenericDatum& datum, size_t row_idx) { + ASSERT_EQ(datum.type(), ::avro::AVRO_RECORD); + const auto& record = datum.value<::avro::GenericRecord>(); + ASSERT_EQ(record.fieldCount(), 3); + + if (row_idx == 0) { + EXPECT_EQ(record.fieldAt(0).value(), 18628); + EXPECT_EQ(record.fieldAt(1).value(), 43200000000); + EXPECT_EQ(record.fieldAt(2).value(), 1640995200000000); + } else if (row_idx == 1) { + EXPECT_EQ(record.fieldAt(0).value(), 18629); + EXPECT_EQ(record.fieldAt(1).value(), 86399000000); + EXPECT_EQ(record.fieldAt(2).value(), 1641081599000000); + } + }); +} + +TEST_P(AvroWriterParameterizedTest, WriteNestedStruct) { + auto schema = std::make_shared(std::vector{ + SchemaField::MakeRequired(1, "id", std::make_shared()), + SchemaField::MakeRequired( + 2, "person", + std::make_shared(std::vector{ + SchemaField::MakeRequired(3, "name", std::make_shared()), + SchemaField::MakeRequired(4, "age", std::make_shared())}))}); + + std::string test_data = R"([ + [1, ["Alice", 30]], + [2, ["Bob", 25]] + ])"; + + WriteAvroFile(schema, test_data); + + VerifyAvroFileContent([](const ::avro::GenericDatum& datum, size_t row_idx) { + ASSERT_EQ(datum.type(), ::avro::AVRO_RECORD); + const auto& record = datum.value<::avro::GenericRecord>(); + ASSERT_EQ(record.fieldCount(), 2); + + if (row_idx == 0) { + EXPECT_EQ(record.fieldAt(0).value(), 1); + const auto& person = record.fieldAt(1).value<::avro::GenericRecord>(); + EXPECT_EQ(person.fieldAt(0).value(), "Alice"); + EXPECT_EQ(person.fieldAt(1).value(), 30); + } else if (row_idx == 1) { + EXPECT_EQ(record.fieldAt(0).value(), 2); + const auto& person = record.fieldAt(1).value<::avro::GenericRecord>(); + EXPECT_EQ(person.fieldAt(0).value(), "Bob"); + EXPECT_EQ(person.fieldAt(1).value(), 25); + } + }); +} + +TEST_P(AvroWriterParameterizedTest, WriteListType) { + auto schema = std::make_shared(std::vector{ + SchemaField::MakeRequired(1, "id", std::make_shared()), + SchemaField::MakeRequired(2, "tags", + std::make_shared(SchemaField::MakeRequired( + 3, "element", std::make_shared())))}); + + std::string test_data = R"([ + [1, ["tag1", "tag2", "tag3"]], + [2, ["foo", "bar"]], + [3, []] + ])"; + + WriteAvroFile(schema, test_data); + + VerifyAvroFileContent([](const ::avro::GenericDatum& datum, size_t row_idx) { + ASSERT_EQ(datum.type(), ::avro::AVRO_RECORD); + const auto& record = datum.value<::avro::GenericRecord>(); + ASSERT_EQ(record.fieldCount(), 2); + + if (row_idx == 0) { + EXPECT_EQ(record.fieldAt(0).value(), 1); + const auto& tags = record.fieldAt(1).value<::avro::GenericArray>(); + ASSERT_EQ(tags.value().size(), 3); + EXPECT_EQ(tags.value()[0].value(), "tag1"); + EXPECT_EQ(tags.value()[1].value(), "tag2"); + EXPECT_EQ(tags.value()[2].value(), "tag3"); + } else if (row_idx == 1) { + EXPECT_EQ(record.fieldAt(0).value(), 2); + const auto& tags = record.fieldAt(1).value<::avro::GenericArray>(); + ASSERT_EQ(tags.value().size(), 2); + EXPECT_EQ(tags.value()[0].value(), "foo"); + EXPECT_EQ(tags.value()[1].value(), "bar"); + } else if (row_idx == 2) { + EXPECT_EQ(record.fieldAt(0).value(), 3); + const auto& tags = record.fieldAt(1).value<::avro::GenericArray>(); + EXPECT_EQ(tags.value().size(), 0); + } + }); +} + +TEST_P(AvroWriterParameterizedTest, WriteMapTypeWithStringKey) { + auto schema = std::make_shared( + std::vector{SchemaField::MakeRequired( + 1, "properties", + std::make_shared( + SchemaField::MakeRequired(2, "key", std::make_shared()), + SchemaField::MakeRequired(3, "value", std::make_shared())))}); + + std::string test_data = R"([ + [[["key1", 100], ["key2", 200]]], + [[["a", 1], ["b", 2], ["c", 3]]] + ])"; + + WriteAvroFile(schema, test_data); + + VerifyAvroFileContent([](const ::avro::GenericDatum& datum, size_t row_idx) { + ASSERT_EQ(datum.type(), ::avro::AVRO_RECORD); + const auto& record = datum.value<::avro::GenericRecord>(); + ASSERT_EQ(record.fieldCount(), 1); + + const auto& map = record.fieldAt(0).value<::avro::GenericMap>(); + const auto& map_value = map.value(); + if (row_idx == 0) { + ASSERT_EQ(map_value.size(), 2); + // Find entries by key + bool found_key1 = false; + bool found_key2 = false; + for (const auto& entry : map_value) { + if (entry.first == "key1") { + EXPECT_EQ(entry.second.value(), 100); + found_key1 = true; + } else if (entry.first == "key2") { + EXPECT_EQ(entry.second.value(), 200); + found_key2 = true; + } + } + EXPECT_TRUE(found_key1 && found_key2); + } else if (row_idx == 1) { + ASSERT_EQ(map_value.size(), 3); + // Find entries by key + bool found_a = false; + bool found_b = false; + bool found_c = false; + for (const auto& entry : map_value) { + if (entry.first == "a") { + EXPECT_EQ(entry.second.value(), 1); + found_a = true; + } else if (entry.first == "b") { + EXPECT_EQ(entry.second.value(), 2); + found_b = true; + } else if (entry.first == "c") { + EXPECT_EQ(entry.second.value(), 3); + found_c = true; + } + } + EXPECT_TRUE(found_a && found_b && found_c); + } + }); +} + +TEST_P(AvroWriterParameterizedTest, WriteMapTypeWithNonStringKey) { + auto schema = std::make_shared( + std::vector{SchemaField::MakeRequired( + 1, "int_map", + std::make_shared( + SchemaField::MakeRequired(2, "key", std::make_shared()), + SchemaField::MakeRequired(3, "value", std::make_shared())))}); + + std::string test_data = R"([ + [[[1, "one"], [2, "two"], [3, "three"]]], + [[[10, "ten"], [20, "twenty"]]] + ])"; + + WriteAvroFile(schema, test_data); + + VerifyAvroFileContent([](const ::avro::GenericDatum& datum, size_t row_idx) { + ASSERT_EQ(datum.type(), ::avro::AVRO_RECORD); + const auto& record = datum.value<::avro::GenericRecord>(); + ASSERT_EQ(record.fieldCount(), 1); + + // Maps with non-string keys are encoded as arrays of key-value records in Avro + const auto& array = record.fieldAt(0).value<::avro::GenericArray>(); + if (row_idx == 0) { + ASSERT_EQ(array.value().size(), 3); + + const auto& entry0 = array.value()[0].value<::avro::GenericRecord>(); + EXPECT_EQ(entry0.fieldAt(0).value(), 1); + EXPECT_EQ(entry0.fieldAt(1).value(), "one"); + + const auto& entry1 = array.value()[1].value<::avro::GenericRecord>(); + EXPECT_EQ(entry1.fieldAt(0).value(), 2); + EXPECT_EQ(entry1.fieldAt(1).value(), "two"); + + const auto& entry2 = array.value()[2].value<::avro::GenericRecord>(); + EXPECT_EQ(entry2.fieldAt(0).value(), 3); + EXPECT_EQ(entry2.fieldAt(1).value(), "three"); + } else if (row_idx == 1) { + ASSERT_EQ(array.value().size(), 2); + + const auto& entry0 = array.value()[0].value<::avro::GenericRecord>(); + EXPECT_EQ(entry0.fieldAt(0).value(), 10); + EXPECT_EQ(entry0.fieldAt(1).value(), "ten"); + + const auto& entry1 = array.value()[1].value<::avro::GenericRecord>(); + EXPECT_EQ(entry1.fieldAt(0).value(), 20); + EXPECT_EQ(entry1.fieldAt(1).value(), "twenty"); + } + }); +} + +TEST_P(AvroWriterParameterizedTest, WriteEmptyMaps) { + auto schema = std::make_shared(std::vector{ + SchemaField::MakeRequired( + 1, "string_map", + std::make_shared( + SchemaField::MakeRequired(2, "key", std::make_shared()), + SchemaField::MakeRequired(3, "value", std::make_shared()))), + SchemaField::MakeRequired( + 4, "int_map", + std::make_shared( + SchemaField::MakeRequired(5, "key", std::make_shared()), + SchemaField::MakeRequired(6, "value", std::make_shared())))}); + + // Test empty maps for both string and non-string keys + std::string test_data = R"([ + [[], []], + [[["a", 1]], []] + ])"; + + // Just verify writing succeeds (empty maps are handled correctly by the encoder) + ASSERT_NO_FATAL_FAILURE(WriteAvroFile(schema, test_data)); +} + +TEST_P(AvroWriterParameterizedTest, WriteOptionalFields) { + auto schema = std::make_shared(std::vector{ + SchemaField::MakeRequired(1, "id", std::make_shared()), + SchemaField::MakeOptional(2, "name", std::make_shared()), + SchemaField::MakeOptional(3, "age", std::make_shared())}); + + std::string test_data = R"([ + [1, "Alice", 30], + [2, null, 25], + [3, "Charlie", null], + [4, null, null] + ])"; + + WriteAvroFile(schema, test_data); + + VerifyAvroFileContent([](const ::avro::GenericDatum& datum, size_t row_idx) { + ASSERT_EQ(datum.type(), ::avro::AVRO_RECORD); + const auto& record = datum.value<::avro::GenericRecord>(); + ASSERT_EQ(record.fieldCount(), 3); + + EXPECT_EQ(record.fieldAt(0).value(), static_cast(row_idx + 1)); + + if (row_idx == 0) { + EXPECT_EQ(record.fieldAt(1).unionBranch(), 1); // non-null + EXPECT_EQ(record.fieldAt(1).value(), "Alice"); + EXPECT_EQ(record.fieldAt(2).unionBranch(), 1); // non-null + EXPECT_EQ(record.fieldAt(2).value(), 30); + } else if (row_idx == 1) { + EXPECT_EQ(record.fieldAt(1).unionBranch(), 0); // null + EXPECT_EQ(record.fieldAt(2).unionBranch(), 1); // non-null + EXPECT_EQ(record.fieldAt(2).value(), 25); + } else if (row_idx == 2) { + EXPECT_EQ(record.fieldAt(1).unionBranch(), 1); // non-null + EXPECT_EQ(record.fieldAt(1).value(), "Charlie"); + EXPECT_EQ(record.fieldAt(2).unionBranch(), 0); // null + } else if (row_idx == 3) { + EXPECT_EQ(record.fieldAt(1).unionBranch(), 0); // null + EXPECT_EQ(record.fieldAt(2).unionBranch(), 0); // null + } + }); +} + +TEST_P(AvroWriterParameterizedTest, WriteLargeDataset) { + auto schema = std::make_shared(std::vector{ + SchemaField::MakeRequired(1, "id", std::make_shared()), + SchemaField::MakeRequired(2, "value", std::make_shared())}); + + // Generate large dataset JSON + std::ostringstream json; + json << "["; + for (int i = 0; i < 1000; ++i) { + if (i > 0) json << ", "; + json << "[" << i << ", " << (i * 1.5) << "]"; + } + json << "]"; + + WriteAvroFile(schema, json.str()); + + size_t expected_row_count = 1000; + size_t actual_row_count = 0; + + VerifyAvroFileContent([&](const ::avro::GenericDatum& datum, size_t row_idx) { + ASSERT_EQ(datum.type(), ::avro::AVRO_RECORD); + const auto& record = datum.value<::avro::GenericRecord>(); + ASSERT_EQ(record.fieldCount(), 2); + + EXPECT_EQ(record.fieldAt(0).value(), static_cast(row_idx)); + EXPECT_DOUBLE_EQ(record.fieldAt(1).value(), row_idx * 1.5); + + actual_row_count++; + }); + + EXPECT_EQ(actual_row_count, expected_row_count); +} + +// Instantiate parameterized tests for both direct encoder and GenericDatum paths +INSTANTIATE_TEST_SUITE_P(DirectEncoderModes, AvroWriterParameterizedTest, + ::testing::Values(true, false), + [](const ::testing::TestParamInfo& info) { + return info.param ? "DirectEncoder" : "GenericDatum"; + }); + TEST_F(AvroReaderTest, BufferSizeConfiguration) { // Test default buffer size auto properties1 = ReaderProperties::default_properties();