diff --git a/Cargo.toml b/Cargo.toml index bb1d0c03d1..7331cd8228 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ object_store = { version = "0.12.1" } parquet = { version = "57" } # datafusion -datafusion = "51.0" +datafusion = { version = "51.0", features = ["default", "parquet_encryption"] } datafusion-ffi = "51.0" datafusion-proto = "51.0" @@ -60,7 +60,7 @@ tempfile = { version = "3" } uuid = { version = "1" } # runtime / async -async-trait = { version = "0.1" } +async-trait = { version = "0.1.89" } futures = { version = "0.3" } tokio = { version = "1" } typed-builder = { version = "0.23.0" } @@ -144,4 +144,4 @@ incremental = false inherits = "release" opt-level = 3 codegen-units = 1 -lto = "fat" +lto = "fat" diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 0c959e3eb2..8db638b542 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -31,7 +31,7 @@ arrow-ord = { workspace = true } arrow-row = { workspace = true } arrow-schema = { workspace = true, features = ["serde"] } arrow-select = { workspace = true } -parquet = { workspace = true, features = ["async", "object_store"] } +parquet = { workspace = true, features = ["async", "object_store", "encryption"] } object_store = { workspace = true } # datafusion @@ -94,6 +94,7 @@ datatest-stable = "0.3" deltalake-test = { path = "../test" } dotenvy = "0" fs_extra = "1.2.0" +paste = "1" pretty_assertions = "1.2.1" pretty_env_logger = "0.5.0" rstest = { version = "0.26.1" } @@ -145,6 +146,10 @@ required-features = ["datafusion"] name = "command_vacuum" required-features = ["datafusion"] +[[test]] +name = "commands_with_encryption" +required-features = ["datafusion"] + [[test]] name = "commit_info_format" required-features = ["datafusion"] diff --git a/crates/core/src/delta_datafusion/table_provider.rs b/crates/core/src/delta_datafusion/table_provider.rs index 9dcc071269..42d753128d 100644 --- a/crates/core/src/delta_datafusion/table_provider.rs +++ b/crates/core/src/delta_datafusion/table_provider.rs @@ -343,11 +343,18 @@ impl DeltaScanConfigBuilder { None }; + let table_parquet_options = snapshot + .load_config() + .file_format_options + .as_ref() + .map(|ffo| ffo.table_options().parquet); + Ok(DeltaScanConfig { file_column_name, wrap_partition_values: self.wrap_partition_values.unwrap_or(true), enable_parquet_pushdown: self.enable_parquet_pushdown, schema: self.schema.clone(), + table_parquet_options, }) } } @@ -363,6 +370,9 @@ pub struct DeltaScanConfig { pub enable_parquet_pushdown: bool, /// Schema to read as pub schema: Option, + /// Options that control how Parquet files are read + #[serde(skip)] + pub table_parquet_options: Option, } pub(crate) struct DeltaScanBuilder<'a> { @@ -644,13 +654,31 @@ impl<'a> DeltaScanBuilder<'a> { let stats = stats.unwrap_or(Statistics::new_unknown(&schema)); - let parquet_options = TableParquetOptions { - global: self.session.config().options().execution.parquet.clone(), - ..Default::default() - }; + let parquet_options: TableParquetOptions = config + .table_parquet_options + .clone() + .unwrap_or_else(|| self.session.table_options().parquet.clone()); + + // We have to set the encryption factory on the ParquetSource based on the Parquet options, + // as this is usually handled by the ParquetFormat type in DataFusion, + // which is not used in delta-rs. 
+ let encryption_factory = parquet_options + .crypto + .factory_id + .as_ref() + .map(|factory_id| { + self.session + .runtime_env() + .parquet_encryption_factory(factory_id) + }) + .transpose()?; let mut file_source = ParquetSource::new(parquet_options); + if let Some(encryption_factory) = encryption_factory { + file_source = file_source.with_encryption_factory(encryption_factory); + } + // Sometimes (i.e Merge) we want to prune files that don't make the // filter and read the entire contents for files that do match the // filter @@ -731,6 +759,9 @@ impl TableProvider for DeltaTable { limit: Option, ) -> Result> { register_store(self.log_store(), session.runtime_env().as_ref()); + if let Some(format_options) = &self.config.file_format_options { + format_options.update_session(session)?; + } let filter_expr = conjunction(filters.iter().cloned()); let scan = DeltaScanBuilder::new(self.snapshot()?.snapshot(), self.log_store(), session) @@ -819,6 +850,10 @@ impl TableProvider for DeltaTableProvider { limit: Option, ) -> Result> { register_store(self.log_store.clone(), session.runtime_env().as_ref()); + if let Some(format_options) = &self.snapshot.load_config().file_format_options { + format_options.update_session(session)?; + } + let filter_expr = conjunction(filters.iter().cloned()); let mut scan = DeltaScanBuilder::new(&self.snapshot, self.log_store.clone(), session) diff --git a/crates/core/src/operations/add_column.rs b/crates/core/src/operations/add_column.rs index e789cf64ce..10b02f5ba2 100644 --- a/crates/core/src/operations/add_column.rs +++ b/crates/core/src/operations/add_column.rs @@ -1,10 +1,9 @@ //! Add a new column to a table -use std::sync::Arc; - use delta_kernel::schema::StructType; use futures::future::BoxFuture; use itertools::Itertools; +use std::sync::Arc; use super::{CustomExecuteHandler, Operation}; use crate::kernel::schema::merge_delta_struct; diff --git a/crates/core/src/operations/create.rs b/crates/core/src/operations/create.rs index d73bf74117..b62dcccc6d 100644 --- a/crates/core/src/operations/create.rs +++ b/crates/core/src/operations/create.rs @@ -1,12 +1,11 @@ //! 
Command for creating a new delta table // https://github.com/delta-io/delta/blob/master/core/src/main/scala/org/apache/spark/sql/delta/commands/CreateDeltaTableCommand.scala -use std::collections::HashMap; -use std::sync::Arc; - use delta_kernel::schema::MetadataValue; use futures::future::BoxFuture; use serde_json::Value; +use std::collections::HashMap; +use std::sync::Arc; use tracing::log::*; use uuid::Uuid; @@ -21,7 +20,7 @@ use crate::logstore::LogStoreRef; use crate::protocol::{DeltaOperation, SaveMode}; use crate::table::builder::ensure_table_uri; use crate::table::config::TableProperty; -use crate::{DeltaTable, DeltaTableBuilder}; +use crate::{DeltaTable, DeltaTableBuilder, DeltaTableConfig}; #[derive(thiserror::Error, Debug)] enum CreateError { @@ -61,6 +60,7 @@ pub struct CreateBuilder { storage_options: Option>, actions: Vec, log_store: Option, + table_config: DeltaTableConfig, configuration: HashMap>, /// Additional information to add to the commit commit_properties: CommitProperties, @@ -98,6 +98,7 @@ impl CreateBuilder { storage_options: None, actions: Default::default(), log_store: None, + table_config: DeltaTableConfig::default(), configuration: Default::default(), commit_properties: CommitProperties::default(), raise_if_key_not_exists: true, @@ -238,6 +239,12 @@ impl CreateBuilder { self } + /// Set configuration options for the table + pub fn with_table_config(mut self, table_config: DeltaTableConfig) -> Self { + self.table_config = table_config; + self + } + /// Set a custom execute handler, for pre and post execution pub fn with_custom_execute_handler(mut self, handler: Arc) -> Self { self.custom_execute_handler = Some(handler); @@ -262,7 +269,7 @@ impl CreateBuilder { let (storage_url, table) = if let Some(log_store) = self.log_store { ( ensure_table_uri(log_store.root_uri())?.as_str().to_string(), - DeltaTable::new(log_store, Default::default()), + DeltaTable::new(log_store, self.table_config.clone()), ) } else { let storage_url = @@ -270,6 +277,7 @@ impl CreateBuilder { ( storage_url.as_str().to_string(), DeltaTableBuilder::from_uri(storage_url)? 
+ .with_table_config(self.table_config.clone()) .with_storage_options(self.storage_options.clone().unwrap_or_default()) .build()?, ) diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index 119b057fff..b87a3f5ca6 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -19,7 +19,7 @@ use async_trait::async_trait; use datafusion::catalog::Session; -use datafusion::common::ScalarValue; +use datafusion::common::{exec_datafusion_err, ScalarValue}; use datafusion::dataframe::DataFrame; use datafusion::datasource::provider_as_source; use datafusion::error::Result as DataFusionResult; @@ -59,6 +59,9 @@ use crate::operations::write::WriterStatsConfig; use crate::operations::CustomExecuteHandler; use crate::protocol::DeltaOperation; use crate::table::config::TablePropertiesExt as _; +use crate::table::file_format_options::{ + state_with_file_format_options, IntoWriterPropertiesFactoryRef, WriterPropertiesFactoryRef, +}; use crate::table::state::DeltaTableState; use crate::{DeltaTable, DeltaTableError}; @@ -78,7 +81,7 @@ pub struct DeleteBuilder { /// Datafusion session state relevant for executing the input plan session: Option>, /// Properties passed to underlying parquet writer for when files are rewritten - writer_properties: Option, + writer_properties_factory: Option, /// Commit properties and configuration commit_properties: CommitProperties, custom_execute_handler: Option>, @@ -126,13 +129,22 @@ impl super::Operation for DeleteBuilder { impl DeleteBuilder { /// Create a new [`DeleteBuilder`] pub(crate) fn new(log_store: LogStoreRef, snapshot: Option) -> Self { + let file_format_options = snapshot + .as_ref() + .map(|ss| ss.load_config().file_format_options.clone()); + let writer_properties_factory = match file_format_options { + Some(file_format_options) => file_format_options + .clone() + .map(|ffo| ffo.writer_properties_factory()), + None => None, + }; Self { predicate: None, snapshot, log_store, session: None, commit_properties: CommitProperties::default(), - writer_properties: None, + writer_properties_factory, custom_execute_handler: None, } } @@ -157,7 +169,8 @@ impl DeleteBuilder { /// Writer properties passed to parquet writer for when files are rewritten pub fn with_writer_properties(mut self, writer_properties: WriterProperties) -> Self { - self.writer_properties = Some(writer_properties); + let writer_properties_factory = writer_properties.into_factory_ref(); + self.writer_properties_factory = Some(writer_properties_factory); self } @@ -189,6 +202,20 @@ impl std::future::IntoFuture for DeleteBuilder { register_store(this.log_store.clone(), session.runtime_env().as_ref()); + let file_format_options = &snapshot.load_config().file_format_options; + let session_state = + session + .as_any() + .downcast_ref::() + .ok_or_else(|| { + exec_datafusion_err!("Failed to downcast Session to SessionState") + })?; + + let session = Arc::new(state_with_file_format_options( + session_state.clone(), + file_format_options.as_ref(), + )?); + let predicate = match this.predicate { Some(predicate) => match predicate { Expression::DataFusion(expr) => Some(expr), @@ -203,8 +230,8 @@ impl std::future::IntoFuture for DeleteBuilder { predicate, this.log_store.clone(), snapshot, - session.as_ref(), - this.writer_properties, + session.as_ref().clone(), + this.writer_properties_factory, this.commit_properties, operation_id, this.custom_execute_handler.as_ref(), @@ -268,7 +295,7 @@ async fn execute_non_empty_expr( expression: &Expr, 
rewrite: &[Add], metrics: &mut DeleteMetrics, - writer_properties: Option, + writer_properties_factory: Option, partition_scan: bool, operation_id: Uuid, ) -> DeltaResult> { @@ -322,7 +349,7 @@ async fn execute_non_empty_expr( log_store.object_store(Some(operation_id)), Some(snapshot.table_properties().target_file_size().get() as usize), None, - writer_properties.clone(), + writer_properties_factory.clone(), writer_stats_config.clone(), ) .await?; @@ -359,7 +386,7 @@ async fn execute_non_empty_expr( log_store.object_store(Some(operation_id)), Some(snapshot.table_properties().target_file_size().get() as usize), None, - writer_properties, + writer_properties_factory, writer_stats_config, ) .await?; @@ -375,8 +402,8 @@ async fn execute( predicate: Option, log_store: LogStoreRef, snapshot: EagerSnapshot, - session: &dyn Session, - writer_properties: Option, + session: SessionState, + writer_properties_factory: Option, mut commit_properties: CommitProperties, operation_id: Uuid, handle: Option<&Arc>, @@ -389,7 +416,7 @@ async fn execute( let mut metrics = DeleteMetrics::default(); let scan_start = Instant::now(); - let candidates = find_files(&snapshot, log_store.clone(), session, predicate.clone()).await?; + let candidates = find_files(&snapshot, log_store.clone(), &session, predicate.clone()).await?; metrics.scan_time_ms = Instant::now().duration_since(scan_start).as_millis() as u64; let predicate = predicate.unwrap_or(lit(true)); @@ -399,11 +426,11 @@ async fn execute( let add = execute_non_empty_expr( &snapshot, log_store.clone(), - session, + &session, &predicate, &candidates.candidates, &mut metrics, - writer_properties, + writer_properties_factory.clone(), candidates.partition_scan, operation_id, ) diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 03804b6edc..4eee686a34 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -58,7 +58,6 @@ use datafusion::{ physical_plan::ExecutionPlan, prelude::{cast, DataFrame, SessionContext}, }; - use delta_kernel::engine::arrow_conversion::{TryIntoArrow as _, TryIntoKernel as _}; use delta_kernel::schema::{ColumnMetadataKey, StructType}; use filter::try_construct_early_filter; @@ -92,6 +91,9 @@ use crate::operations::write::generated_columns::{ use crate::operations::write::WriterStatsConfig; use crate::protocol::{DeltaOperation, MergePredicate}; use crate::table::config::TablePropertiesExt as _; +use crate::table::file_format_options::{ + state_with_file_format_options, IntoWriterPropertiesFactoryRef, WriterPropertiesFactoryRef, +}; use crate::table::state::DeltaTableState; use crate::{DeltaResult, DeltaTable, DeltaTableError}; @@ -146,7 +148,7 @@ pub struct MergeBuilder { /// Datafusion session state relevant for executing the input plan state: Option>, /// Properties passed to underlying parquet writer for when files are rewritten - writer_properties: Option, + writer_properties_factory: Option, /// Additional information to add to the commit commit_properties: CommitProperties, /// safe_cast determines how data types that do not match the underlying table are handled @@ -173,6 +175,15 @@ impl MergeBuilder { source: DataFrame, ) -> Self { let predicate = predicate.into(); + let file_format_options = snapshot + .as_ref() + .map(|ss| ss.load_config().file_format_options.clone()); + let writer_properties_factory = match file_format_options { + Some(file_format_options) => file_format_options + .clone() + .map(|ffo| ffo.writer_properties_factory()), + 
None => None, + }; Self { predicate, source, @@ -182,7 +193,7 @@ impl MergeBuilder { target_alias: None, state: None, commit_properties: CommitProperties::default(), - writer_properties: None, + writer_properties_factory, merge_schema: false, match_operations: Vec::new(), not_match_operations: Vec::new(), @@ -395,9 +406,10 @@ impl MergeBuilder { self } - /// Writer properties passed to parquet writer for when fiiles are rewritten + /// Writer properties passed to parquet writer for when files are rewritten pub fn with_writer_properties(mut self, writer_properties: WriterProperties) -> Self { - self.writer_properties = Some(writer_properties); + let writer_properties_factory = writer_properties.into_factory_ref(); + self.writer_properties_factory = Some(writer_properties_factory); self } @@ -740,7 +752,7 @@ async fn execute( log_store: LogStoreRef, snapshot: EagerSnapshot, state: SessionState, - writer_properties: Option, + writer_properties_factory: Option, mut commit_properties: CommitProperties, _safe_cast: bool, streaming: bool, @@ -775,6 +787,9 @@ async fn execute( let current_metadata = snapshot.metadata(); let merge_planner = DeltaPlanner::new(); + let file_format_options = snapshot.load_config().file_format_options.clone(); + let state = state_with_file_format_options(state, file_format_options.as_ref())?; + let state = SessionStateBuilder::new_from_existing(state) .with_query_planner(merge_planner) .build(); @@ -1408,7 +1423,7 @@ async fn execute( log_store.object_store(Some(operation_id)), Some(snapshot.table_properties().target_file_size().get() as usize), None, - writer_properties.clone(), + writer_properties_factory.clone(), writer_stats_config.clone(), None, should_cdc, // if true, write execution plan splits batches in [normal, cdc] data before writing @@ -1563,7 +1578,7 @@ impl std::future::IntoFuture for MergeBuilder { this.log_store.clone(), snapshot, state, - this.writer_properties, + this.writer_properties_factory, this.commit_properties, this.safe_cast, this.streaming, diff --git a/crates/core/src/operations/mod.rs b/crates/core/src/operations/mod.rs index e122ce5584..8f13b94846 100644 --- a/crates/core/src/operations/mod.rs +++ b/crates/core/src/operations/mod.rs @@ -36,6 +36,8 @@ use crate::logstore::LogStoreRef; use crate::operations::generate::GenerateBuilder; use crate::table::builder::{ensure_table_uri, DeltaTableBuilder}; use crate::table::config::{TablePropertiesExt as _, DEFAULT_NUM_INDEX_COLS}; +use crate::table::file_format_options::FileFormatRef; +use crate::table::state::DeltaTableState; use crate::DeltaTable; pub mod add_column; @@ -178,6 +180,28 @@ impl DeltaOps { } } + /// Set options for parquet files + pub async fn with_file_format_options( + mut self, + file_format_options: FileFormatRef, + ) -> DeltaResult { + // Update table-level config so future loads/operations use these options + self.0.config.file_format_options = Some(file_format_options); + + // Update the in-memory state and snapshot config to match the top level table config + if self.0.state.is_some() { + self.0.state = Some( + DeltaTableState::try_new( + &self.0.log_store, + self.0.config.clone(), + Some(self.0.state.unwrap().version()), + ) + .await?, + ); + } + Ok(self) + } + /// Create a [`DeltaOps`] instance from uri string with storage options (deprecated) #[deprecated(note = "Use try_from_uri_with_storage_options with url::Url instead")] pub async fn try_from_uri_str_with_storage_options( @@ -221,7 +245,9 @@ impl DeltaOps { /// ``` #[must_use] pub fn create(self) -> CreateBuilder { 
- CreateBuilder::default().with_log_store(self.0.log_store) + CreateBuilder::default() + .with_log_store(self.0.log_store) + .with_table_config(self.0.config.clone()) } /// Generate a symlink_format_manifest for other engines diff --git a/crates/core/src/operations/optimize.rs b/crates/core/src/operations/optimize.rs index c3e61da231..ebf39025f6 100644 --- a/crates/core/src/operations/optimize.rs +++ b/crates/core/src/operations/optimize.rs @@ -28,10 +28,12 @@ use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use arrow::array::RecordBatch; use arrow::datatypes::SchemaRef; use datafusion::catalog::Session; +use datafusion::error::DataFusionError; use datafusion::execution::context::SessionState; use datafusion::execution::memory_pool::FairSpillPool; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::SessionStateBuilder; +use datafusion::prelude::SessionContext; use delta_kernel::engine::arrow_conversion::TryIntoArrow as _; use delta_kernel::expressions::Scalar; use delta_kernel::table_properties::DataSkippingNumIndexedCols; @@ -41,8 +43,10 @@ use futures::{Future, StreamExt, TryStreamExt}; use indexmap::IndexMap; use itertools::Itertools; use num_cpus; +use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStreamBuilder}; use parquet::basic::{Compression, ZstdLevel}; +use parquet::encryption::decrypt::FileDecryptionProperties; use parquet::errors::ParquetError; use parquet::file::properties::WriterProperties; use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize, Serializer}; @@ -59,6 +63,9 @@ use crate::kernel::{scalars::ScalarExt, Action, Add, PartitionsExt, Remove}; use crate::logstore::{LogStore, LogStoreRef, ObjectStoreRef}; use crate::protocol::DeltaOperation; use crate::table::config::TablePropertiesExt as _; +use crate::table::file_format_options::{ + FileFormatRef, IntoWriterPropertiesFactoryRef, WriterPropertiesFactoryRef, +}; use crate::table::state::DeltaTableState; use crate::writer::utils::arrow_schema_without_partitions; use crate::{crate_version, DeltaTable, ObjectMeta, PartitionFilter}; @@ -208,7 +215,7 @@ pub struct OptimizeBuilder<'a> { /// Desired file size after bin-packing files target_size: Option, /// Properties passed to underlying parquet writer - writer_properties: Option, + writer_properties_factory: Option, /// Commit properties and configuration commit_properties: CommitProperties, /// Whether to preserve insertion order within files (default false) @@ -237,12 +244,29 @@ impl super::Operation for OptimizeBuilder<'_> { impl<'a> OptimizeBuilder<'a> { /// Create a new [`OptimizeBuilder`] pub(crate) fn new(log_store: LogStoreRef, snapshot: Option) -> Self { + let file_format_options = snapshot + .as_ref() + .map(|ss| ss.load_config().file_format_options.clone()); + let writer_properties_factory = match file_format_options.as_ref() { + None => { + let wp = WriterProperties::builder() + .set_compression(Compression::ZSTD(ZstdLevel::try_new(4).unwrap())) + .set_created_by(format!("delta-rs version {}", crate_version())) + .build(); + let wpf = wp.into_factory_ref(); + Some(wpf) + } + Some(file_format_options) => file_format_options + .clone() + .map(|ffo| ffo.writer_properties_factory()), + }; + Self { snapshot, log_store, filters: &[], target_size: None, - writer_properties: None, + writer_properties_factory, commit_properties: CommitProperties::default(), preserve_insertion_order: false, max_concurrent_tasks: num_cpus::get(), @@ 
-274,7 +298,8 @@ impl<'a> OptimizeBuilder<'a> { /// Writer properties passed to parquet writer pub fn with_writer_properties(mut self, writer_properties: WriterProperties) -> Self { - self.writer_properties = Some(writer_properties); + let writer_properties_factory = writer_properties.into_factory_ref(); + self.writer_properties_factory = Some(writer_properties_factory); self } @@ -338,13 +363,6 @@ impl<'a> std::future::IntoFuture for OptimizeBuilder<'a> { let operation_id = this.get_operation_id(); this.pre_execute(operation_id).await?; - - let writer_properties = this.writer_properties.unwrap_or_else(|| { - WriterProperties::builder() - .set_compression(Compression::ZSTD(ZstdLevel::try_new(4).unwrap())) - .set_created_by(format!("delta-rs version {}", crate_version())) - .build() - }); let session = this .session .and_then(|session| session.as_any().downcast_ref::().cloned()) @@ -365,7 +383,7 @@ impl<'a> std::future::IntoFuture for OptimizeBuilder<'a> { &snapshot, this.filters, this.target_size.to_owned(), - writer_properties, + this.writer_properties_factory, session, ) .await?; @@ -492,7 +510,7 @@ pub struct MergeTaskParameters { /// Schema of written files file_schema: SchemaRef, /// Properties passed to parquet writer - writer_properties: WriterProperties, + writer_properties_factory: Option, /// Num index cols to collect stats for num_indexed_cols: DataSkippingNumIndexedCols, /// Stats columns, specific columns to collect stats from, takes precedence over num_indexed_cols @@ -548,7 +566,7 @@ impl MergePlan { let writer_config = PartitionWriterConfig::try_new( task_parameters.file_schema.clone(), partition_values.clone(), - Some(task_parameters.writer_properties.clone()), + task_parameters.writer_properties_factory.clone(), Some(task_parameters.input_parameters.target_size as usize), None, None, @@ -646,6 +664,7 @@ impl MergePlan { let operations = std::mem::take(&mut self.operations); info!("starting optimize execution"); let object_store = log_store.object_store(Some(operation_id)); + let ffo = snapshot.load_config().file_format_options.clone(); let stream = match operations { OptimizeOperations::Compact(bins) => futures::stream::iter(bins) @@ -661,17 +680,39 @@ impl MergePlan { debug!(" file {}", file.path); } let object_store_ref = object_store.clone(); + let file_format_options = Arc::new(ffo.clone()); let batch_stream = futures::stream::iter(files.clone()) .then(move |file| { let object_store_ref = object_store_ref.clone(); let meta = ObjectMeta::try_from(file).unwrap(); + let file_format_options = file_format_options.clone(); async move { + let decrypt: Option> = + match &*file_format_options { + Some(ffo) => { + get_file_decryption_properties(ffo, &meta.location) + .await + .map_err(|e| { + ParquetError::General(format!( + "Error getting file decryption properties: {e}" + )) + })? + } + None => None, + }; let file_reader = ParquetObjectReader::new(object_store_ref, meta.location) .with_file_size(meta.size); - ParquetRecordBatchStreamBuilder::new(file_reader) - .await? - .build() + let mut options = ArrowReaderOptions::new(); + if let Some(decrypt) = decrypt { + options = options.with_file_decryption_properties(decrypt); + } + ParquetRecordBatchStreamBuilder::new_with_options( + file_reader, + options, + ) + .await? 
+ .build() } }) .try_flatten() @@ -711,16 +752,15 @@ impl MergePlan { let log_store = log_store.clone(); futures::stream::iter(bins) .map(move |(_, (partition, files))| { - let batch_stream = Self::read_zorder( - files.clone(), - exec_context.clone(), - DeltaTableProvider::try_new( - snapshot.clone(), - log_store.clone(), - scan_config.clone(), - ) - .unwrap(), - ); + let dtp = DeltaTableProvider::try_new( + snapshot.clone(), + log_store.clone(), + scan_config.clone(), + ) + .unwrap(); + + let batch_stream = + Self::read_zorder(files.clone(), exec_context.clone(), dtp); let rewrite_result = tokio::task::spawn(Self::rewrite_files( task_parameters.clone(), partition, @@ -832,7 +872,7 @@ pub async fn create_merge_plan( snapshot: &EagerSnapshot, filters: &[PartitionFilter], target_size: Option, - writer_properties: WriterProperties, + writer_properties_factory: Option, session: SessionState, ) -> Result { let target_size = @@ -879,7 +919,7 @@ pub async fn create_merge_plan( task_parameters: Arc::new(MergeTaskParameters { input_parameters, file_schema, - writer_properties, + writer_properties_factory, num_indexed_cols: snapshot.table_properties().num_indexed_cols(), stats_columns: snapshot .table_properties() @@ -1080,6 +1120,29 @@ async fn build_zorder_plan( Ok((operation, metrics)) } +async fn get_file_decryption_properties( + file_format_options: &FileFormatRef, + file_path: &object_store::path::Path, +) -> Result>, DataFusionError> { + let parquet_options = file_format_options.table_options().parquet; + if let Some(props) = &parquet_options.crypto.file_decryption { + return Ok(Some(Arc::new(props.clone().into()))); + } + if let Some(factory_id) = &parquet_options.crypto.factory_id { + // Create a temporary DataFusion session to access the encryption factory + let ctx = SessionContext::default(); + let state = ctx.state(); + file_format_options.update_session(&state)?; + let encryption_factory = state.runtime_env().parquet_encryption_factory(factory_id)?; + let config = &parquet_options.crypto.factory_options; + encryption_factory + .get_file_decryption_properties(config, file_path) + .await + } else { + Ok(None) + } +} + pub(super) mod util { use super::*; use futures::Future; diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index 8ab6683352..c5340936bb 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -56,6 +56,9 @@ use super::{ use crate::logstore::LogStoreRef; use crate::operations::cdc::*; use crate::protocol::DeltaOperation; +use crate::table::file_format_options::{ + state_with_file_format_options, IntoWriterPropertiesFactoryRef, WriterPropertiesFactoryRef, +}; use crate::table::state::DeltaTableState; use crate::{ delta_datafusion::{ @@ -99,7 +102,7 @@ pub struct UpdateBuilder { /// Datafusion session state relevant for executing the input plan session: Option>, /// Properties passed to underlying parquet writer for when files are rewritten - writer_properties: Option, + writer_properties_factory: Option, /// Additional information to add to the commit commit_properties: CommitProperties, /// safe_cast determines how data types that do not match the underlying table are handled @@ -137,13 +140,22 @@ impl super::Operation for UpdateBuilder { impl UpdateBuilder { /// Create a new ['UpdateBuilder'] pub(crate) fn new(log_store: LogStoreRef, snapshot: Option) -> Self { + let file_format_options = snapshot + .as_ref() + .map(|ss| ss.load_config().file_format_options.clone()); + let 
writer_properties_factory = match file_format_options { + Some(file_format_options) => file_format_options + .clone() + .map(|ffo| ffo.writer_properties_factory()), + None => None, + }; Self { predicate: None, updates: HashMap::new(), snapshot, log_store, session: None, - writer_properties: None, + writer_properties_factory, commit_properties: CommitProperties::default(), safe_cast: false, custom_execute_handler: None, @@ -180,7 +192,8 @@ impl UpdateBuilder { /// Writer properties passed to parquet writer for when fiiles are rewritten pub fn with_writer_properties(mut self, writer_properties: WriterProperties) -> Self { - self.writer_properties = Some(writer_properties); + let writer_properties_factory = writer_properties.into_factory_ref(); + self.writer_properties_factory = Some(writer_properties_factory); self } @@ -257,7 +270,7 @@ async fn execute( log_store: LogStoreRef, snapshot: EagerSnapshot, session: SessionState, - writer_properties: Option, + writer_properties_factory: Option, mut commit_properties: CommitProperties, _safe_cast: bool, operation_id: Uuid, @@ -285,6 +298,9 @@ async fn execute( .cloned() .collect(); + let file_format_options = snapshot.load_config().file_format_options.clone(); + let session = state_with_file_format_options(session, file_format_options.as_ref())?; + let update_planner = DeltaPlanner::new(); let session = SessionStateBuilder::from(session) @@ -401,6 +417,14 @@ async fn execute( let tracker = CDCTracker::new(df, updated_df); + let writer_properties_factory = if writer_properties_factory.is_some() { + writer_properties_factory + } else { + file_format_options + .clone() + .map(|ffo| ffo.writer_properties_factory()) + }; + let add_actions = write_execution_plan( Some(&snapshot), &session, @@ -409,7 +433,7 @@ async fn execute( log_store.object_store(Some(operation_id)).clone(), Some(snapshot.table_properties().target_file_size().get() as usize), None, - writer_properties.clone(), + writer_properties_factory.clone(), writer_stats_config.clone(), ) .await?; @@ -471,7 +495,7 @@ async fn execute( log_store.object_store(Some(operation_id)), Some(snapshot.table_properties().target_file_size().get() as usize), None, - writer_properties, + writer_properties_factory, writer_stats_config, ) .await?; @@ -523,7 +547,7 @@ impl std::future::IntoFuture for UpdateBuilder { this.log_store.clone(), snapshot, state.clone(), - this.writer_properties, + this.writer_properties_factory, this.commit_properties, this.safe_cast, operation_id, diff --git a/crates/core/src/operations/write/execution.rs b/crates/core/src/operations/write/execution.rs index c852ea4093..97c966b494 100644 --- a/crates/core/src/operations/write/execution.rs +++ b/crates/core/src/operations/write/execution.rs @@ -14,7 +14,6 @@ use datafusion::prelude::DataFrame; use delta_kernel::engine::arrow_conversion::TryIntoKernel as _; use futures::StreamExt; use object_store::prefix::PrefixStore; -use parquet::file::properties::WriterProperties; use tokio::sync::mpsc; use tracing::log::*; use uuid::Uuid; @@ -31,6 +30,7 @@ use crate::logstore::{LogStoreRef, ObjectStoreRef}; use crate::operations::cdc::{should_write_cdc, CDC_COLUMN_NAME}; use crate::operations::write::WriterStatsConfig; use crate::table::config::TablePropertiesExt as _; +use crate::table::file_format_options::WriterPropertiesFactoryRef; use crate::table::Constraint as DeltaConstraint; use crate::DeltaTableError; @@ -62,7 +62,7 @@ pub(crate) async fn write_execution_plan_cdc( object_store: ObjectStoreRef, target_file_size: Option, write_batch_size: 
Option, - writer_properties: Option, + writer_properties_factory: Option, writer_stats_config: WriterStatsConfig, ) -> DeltaResult> { let cdc_store = Arc::new(PrefixStore::new(object_store, "_change_data")); @@ -75,7 +75,7 @@ pub(crate) async fn write_execution_plan_cdc( cdc_store, target_file_size, write_batch_size, - writer_properties, + writer_properties_factory, writer_stats_config, ) .await? @@ -109,7 +109,7 @@ pub(crate) async fn write_execution_plan( object_store: ObjectStoreRef, target_file_size: Option, write_batch_size: Option, - writer_properties: Option, + writer_properties_factory: Option, writer_stats_config: WriterStatsConfig, ) -> DeltaResult> { let (actions, _) = write_execution_plan_v2( @@ -120,7 +120,7 @@ pub(crate) async fn write_execution_plan( object_store, target_file_size, write_batch_size, - writer_properties, + writer_properties_factory, writer_stats_config, None, false, @@ -137,7 +137,7 @@ pub(crate) async fn execute_non_empty_expr( partition_columns: Vec, expression: &Expr, rewrite: &[Add], - writer_properties: Option, + writer_properties_factory: Option, writer_stats_config: WriterStatsConfig, partition_scan: bool, operation_id: Uuid, @@ -177,7 +177,7 @@ pub(crate) async fn execute_non_empty_expr( log_store.object_store(Some(operation_id)), Some(snapshot.table_properties().target_file_size().get() as usize), None, - writer_properties.clone(), + writer_properties_factory.clone(), writer_stats_config.clone(), ) .await?; @@ -212,7 +212,7 @@ pub(crate) async fn prepare_predicate_actions( snapshot: &EagerSnapshot, session: &dyn Session, partition_columns: Vec, - writer_properties: Option, + writer_properties_factory: Option, deletion_timestamp: i64, writer_stats_config: WriterStatsConfig, operation_id: Uuid, @@ -232,7 +232,7 @@ pub(crate) async fn prepare_predicate_actions( partition_columns, &predicate, &candidates.candidates, - writer_properties, + writer_properties_factory, writer_stats_config, candidates.partition_scan, operation_id, @@ -267,7 +267,7 @@ pub(crate) async fn write_execution_plan_v2( object_store: ObjectStoreRef, target_file_size: Option, write_batch_size: Option, - writer_properties: Option, + writer_properties_factory: Option, writer_stats_config: WriterStatsConfig, predicate: Option, contains_cdc: bool, @@ -299,7 +299,7 @@ pub(crate) async fn write_execution_plan_v2( let config = WriterConfig::new( schema.clone(), partition_columns.clone(), - writer_properties.clone(), + writer_properties_factory.clone(), target_file_size, write_batch_size, writer_stats_config.num_indexed_cols, @@ -393,7 +393,7 @@ pub(crate) async fn write_execution_plan_v2( let normal_config = WriterConfig::new( write_schema.clone(), partition_columns.clone(), - writer_properties.clone(), + writer_properties_factory.clone(), target_file_size, write_batch_size, writer_stats_config.num_indexed_cols, @@ -403,7 +403,7 @@ pub(crate) async fn write_execution_plan_v2( let cdf_config = WriterConfig::new( cdf_schema.clone(), partition_columns.clone(), - writer_properties.clone(), + writer_properties_factory.clone(), target_file_size, write_batch_size, writer_stats_config.num_indexed_cols, diff --git a/crates/core/src/operations/write/mod.rs b/crates/core/src/operations/write/mod.rs index ee5904bb02..de0a5d1030 100644 --- a/crates/core/src/operations/write/mod.rs +++ b/crates/core/src/operations/write/mod.rs @@ -67,6 +67,9 @@ use crate::kernel::{ }; use crate::logstore::LogStoreRef; use crate::protocol::{DeltaOperation, SaveMode}; +use crate::table::file_format_options::{ + 
IntoWriterPropertiesFactoryRef, WriterPropertiesFactoryRef, +}; use crate::DeltaTable; pub mod configs; @@ -150,7 +153,7 @@ pub struct WriteBuilder { /// how to handle cast failures, either return NULL (safe=true) or return ERR (safe=false) safe_cast: bool, /// Parquet writer properties - writer_properties: Option, + writer_properties_factory: Option, /// Additional information to add to the commit commit_properties: CommitProperties, /// Name of the table, only used when table doesn't exist yet @@ -189,6 +192,10 @@ impl super::Operation for WriteBuilder { impl WriteBuilder { /// Create a new [`WriteBuilder`] pub fn new(log_store: LogStoreRef, snapshot: Option) -> Self { + let ffo = snapshot + .as_ref() + .and_then(|s| s.load_config().file_format_options.clone()); + let writer_properties_factory = ffo.map(|ffo| ffo.writer_properties_factory()); Self { snapshot, log_store, @@ -201,7 +208,7 @@ impl WriteBuilder { write_batch_size: None, safe_cast: false, schema_mode: None, - writer_properties: None, + writer_properties_factory, commit_properties: CommitProperties::default(), name: None, description: None, @@ -278,7 +285,8 @@ impl WriteBuilder { /// Specify the writer properties to use when writing a parquet file pub fn with_writer_properties(mut self, writer_properties: WriterProperties) -> Self { - self.writer_properties = Some(writer_properties); + let writer_properties_factory = writer_properties.into_factory_ref(); + self.writer_properties_factory = Some(writer_properties_factory); self } @@ -642,7 +650,7 @@ impl std::future::IntoFuture for WriteBuilder { snapshot, session.as_ref(), partition_columns.clone(), - this.writer_properties.clone(), + this.writer_properties_factory.clone(), deletion_timestamp, writer_stats_config.clone(), operation_id, @@ -686,7 +694,7 @@ impl std::future::IntoFuture for WriteBuilder { this.log_store.object_store(Some(operation_id)).clone(), target_file_size, this.write_batch_size, - this.writer_properties, + this.writer_properties_factory, writer_stats_config.clone(), predicate.clone(), contains_cdc, diff --git a/crates/core/src/operations/write/writer.rs b/crates/core/src/operations/write/writer.rs index 3752f22083..3a7dfa675f 100644 --- a/crates/core/src/operations/write/writer.rs +++ b/crates/core/src/operations/write/writer.rs @@ -23,6 +23,9 @@ use crate::crate_version; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Add, PartitionsExt}; use crate::logstore::ObjectStoreRef; +use crate::table::file_format_options::{ + IntoWriterPropertiesFactoryRef, WriterPropertiesFactoryRef, +}; use crate::writer::record_batch::{divide_by_partition_values, PartitionResult}; use crate::writer::stats::create_add; use crate::writer::utils::{ @@ -130,7 +133,7 @@ pub struct WriterConfig { /// Column names for columns the table is partitioned by partition_columns: Vec, /// Properties passed to underlying parquet writer - writer_properties: WriterProperties, + writer_properties_factory: WriterPropertiesFactoryRef, /// Size above which we will write a buffered parquet file to disk. target_file_size: usize, /// Row chunks passed to parquet writer. 
This and the internal parquet writer settings @@ -147,16 +150,18 @@ impl WriterConfig { pub fn new( table_schema: ArrowSchemaRef, partition_columns: Vec, - writer_properties: Option, + writer_properties_factory: Option, target_file_size: Option, write_batch_size: Option, num_indexed_cols: DataSkippingNumIndexedCols, stats_columns: Option>, ) -> Self { - let writer_properties = writer_properties.unwrap_or_else(|| { - WriterProperties::builder() + let writer_properties_factory = writer_properties_factory.unwrap_or_else(|| { + // Keep these compression defaults for backwards compatibility + let wp = WriterProperties::builder() .set_compression(Compression::SNAPPY) - .build() + .build(); + wp.into_factory_ref() }); let target_file_size = target_file_size.unwrap_or(DEFAULT_TARGET_FILE_SIZE); let write_batch_size = write_batch_size.unwrap_or(DEFAULT_WRITE_BATCH_SIZE); @@ -164,7 +169,7 @@ impl WriterConfig { Self { table_schema, partition_columns, - writer_properties, + writer_properties_factory, target_file_size, write_batch_size, num_indexed_cols, @@ -200,7 +205,8 @@ impl DeltaWriter { /// Apply custom writer_properties to the underlying parquet writer pub fn with_writer_properties(mut self, writer_properties: WriterProperties) -> Self { - self.config.writer_properties = writer_properties; + let writer_properties_factory = writer_properties.into_factory_ref(); + self.config.writer_properties_factory = writer_properties_factory; self } @@ -237,7 +243,7 @@ impl DeltaWriter { let config = PartitionWriterConfig::try_new( self.config.file_schema(), partition_values.clone(), - Some(self.config.writer_properties.clone()), + Some(self.config.writer_properties_factory.clone()), Some(self.config.target_file_size), Some(self.config.write_batch_size), None, @@ -300,7 +306,7 @@ pub struct PartitionWriterConfig { /// Values for all partition columns partition_values: IndexMap, /// Properties passed to underlying parquet writer - writer_properties: WriterProperties, + writer_properties_factory: WriterPropertiesFactoryRef, /// Size above which we will write a buffered parquet file to disk. target_file_size: usize, /// Row chunks passed to parquet writer. 
This and the internal parquet writer settings @@ -315,17 +321,19 @@ impl PartitionWriterConfig { pub fn try_new( file_schema: ArrowSchemaRef, partition_values: IndexMap, - writer_properties: Option, + writer_properties_factory: Option, target_file_size: Option, write_batch_size: Option, max_concurrency_tasks: Option, ) -> DeltaResult { let part_path = partition_values.hive_partition_path(); let prefix = Path::parse(part_path)?; - let writer_properties = writer_properties.unwrap_or_else(|| { - WriterProperties::builder() + let writer_properties_factory = writer_properties_factory.unwrap_or_else(|| { + // These particular compression settings are required by writer::tests::test_unflushed_row_group_size + let wp = WriterProperties::builder() .set_created_by(format!("delta-rs version {}", crate_version())) - .build() + .build(); + wp.into_factory_ref() }); let target_file_size = target_file_size.unwrap_or(DEFAULT_TARGET_FILE_SIZE); let write_batch_size = write_batch_size.unwrap_or(DEFAULT_WRITE_BATCH_SIZE); @@ -334,7 +342,7 @@ impl PartitionWriterConfig { file_schema, prefix, partition_values, - writer_properties, + writer_properties_factory, target_file_size, write_batch_size, max_concurrency_tasks: max_concurrency_tasks.unwrap_or_else(get_max_concurrency_tasks), @@ -359,10 +367,14 @@ impl LazyArrowWriter { ) .with_max_concurrency(config.max_concurrency_tasks), ); + let writer_properties = config + .writer_properties_factory + .create_writer_properties(path, &config.file_schema) + .await?; let mut arrow_writer = AsyncArrowWriter::try_new( writer, config.file_schema.clone(), - Some(config.writer_properties.clone()), + Some(writer_properties), )?; arrow_writer.write(batch).await?; *self = LazyArrowWriter::Writing(path.clone(), arrow_writer); @@ -395,6 +407,7 @@ pub struct PartitionWriter { config: PartitionWriterConfig, writer: LazyArrowWriter, part_counter: usize, + data_path: Path, /// Num index cols to collect stats for num_indexed_cols: DataSkippingNumIndexedCols, /// Stats columns, specific columns to collect stats from, takes precedence over num_indexed_cols @@ -411,15 +424,21 @@ impl PartitionWriter { stats_columns: Option>, ) -> DeltaResult { let writer_id = uuid::Uuid::new_v4(); - let first_path = next_data_path(&config.prefix, 0, &writer_id, &config.writer_properties); - let writer = Self::create_writer(object_store.clone(), first_path.clone(), &config)?; + let data_path = next_data_path( + &config.prefix, + 0, + &writer_id, + config.writer_properties_factory.clone(), + ); + let writer = Self::create_writer(object_store.clone(), data_path.clone(), &config)?; Ok(Self { object_store, writer_id, config, writer, part_counter: 0, + data_path, num_indexed_cols, stats_columns, in_flight_writers: JoinSet::new(), @@ -442,7 +461,7 @@ impl PartitionWriter { &self.config.prefix, self.part_counter, &self.writer_id, - &self.config.writer_properties, + self.config.writer_properties_factory.clone(), ) } @@ -552,10 +571,12 @@ mod tests { target_file_size: Option, write_batch_size: Option, ) -> DeltaWriter { + let writer_properties_factory = writer_properties.map(|wp| wp.into_factory_ref()); + let config = WriterConfig::new( batch.schema(), vec![], - writer_properties, + writer_properties_factory, target_file_size, write_batch_size, DataSkippingNumIndexedCols::NumColumns(DEFAULT_NUM_INDEX_COLS), @@ -564,17 +585,18 @@ mod tests { DeltaWriter::new(object_store, config) } - fn get_partition_writer( + async fn get_partition_writer( object_store: ObjectStoreRef, batch: &RecordBatch, writer_properties: 
Option, target_file_size: Option, write_batch_size: Option, ) -> PartitionWriter { + let writer_properties_factory = writer_properties.map(|wp| wp.into_factory_ref()); let config = PartitionWriterConfig::try_new( batch.schema(), IndexMap::new(), - writer_properties, + writer_properties_factory, target_file_size, write_batch_size, None, @@ -599,7 +621,7 @@ mod tests { let batch = get_record_batch(None, false); // write single un-partitioned batch - let mut writer = get_partition_writer(object_store.clone(), &batch, None, None, None); + let mut writer = get_partition_writer(object_store.clone(), &batch, None, None, None).await; writer.write(&batch).await.unwrap(); let files = list(object_store.as_ref(), None).await.unwrap(); assert_eq!(files.len(), 0); @@ -634,7 +656,7 @@ mod tests { .build(); // configure small target file size and and row group size so we can observe multiple files written let mut writer = - get_partition_writer(object_store, &batch, Some(properties), Some(10_000), None); + get_partition_writer(object_store, &batch, Some(properties), Some(10_000), None).await; writer.write(&batch).await.unwrap(); // check that we have written more then once file, and no more then 1 is below target size @@ -662,7 +684,7 @@ mod tests { .unwrap() .object_store(None); // configure small target file size so we can observe multiple files written - let mut writer = get_partition_writer(object_store, &batch, None, Some(10_000), None); + let mut writer = get_partition_writer(object_store, &batch, None, Some(10_000), None).await; writer.write(&batch).await.unwrap(); // check that we have written more then once file, and no more then 1 is below target size @@ -691,7 +713,8 @@ mod tests { .object_store(None); // configure high batch size and low file size to observe one file written and flushed immediately // upon writing batch, then ensures the buffer is empty upon closing writer - let mut writer = get_partition_writer(object_store, &batch, None, Some(9000), Some(10000)); + let mut writer = + get_partition_writer(object_store, &batch, None, Some(9000), Some(10000)).await; writer.write(&batch).await.unwrap(); let adds = writer.close().await.unwrap(); diff --git a/crates/core/src/table/builder.rs b/crates/core/src/table/builder.rs index d86c182b51..9aaa7920d3 100644 --- a/crates/core/src/table/builder.rs +++ b/crates/core/src/table/builder.rs @@ -13,6 +13,7 @@ use url::Url; use crate::logstore::storage::IORuntime; use crate::logstore::{object_store_factories, LogStoreRef, StorageConfig}; +use crate::table::file_format_options::FileFormatRef; use crate::{DeltaResult, DeltaTable, DeltaTableError}; /// possible version specifications for loading a delta table @@ -51,6 +52,11 @@ pub struct DeltaTableConfig { /// when processing record batches. 
pub log_batch_size: usize, + #[serde(skip_serializing, skip_deserializing)] + #[delta(skip)] + /// Options to apply when operating on the table files + pub file_format_options: Option, + #[serde(skip_serializing, skip_deserializing)] #[delta(skip)] /// When a runtime handler is provided, all IO tasks are spawn in that handle @@ -63,6 +69,7 @@ impl Default for DeltaTableConfig { require_files: true, log_buffer_size: num_cpus::get() * 4, log_batch_size: 1024, + file_format_options: None, io_runtime: None, } } @@ -128,6 +135,12 @@ impl DeltaTableBuilder { }) } + /// Sets the overall table configuration + pub fn with_table_config(mut self, table_config: DeltaTableConfig) -> Self { + self.table_config = table_config; + self + } + /// Sets `require_files=false` to the builder pub fn without_files(mut self) -> Self { self.table_config.require_files = false; @@ -223,6 +236,12 @@ impl DeltaTableBuilder { self } + /// Set the file options to use when reading/writing individual files in the table. + pub fn with_file_format_options(mut self, file_format_options: FileFormatRef) -> Self { + self.table_config.file_format_options = Some(file_format_options); + self + } + /// Provide a custom runtime handle or runtime config pub fn with_io_runtime(mut self, io_runtime: IORuntime) -> Self { self.table_config.io_runtime = Some(io_runtime); diff --git a/crates/core/src/table/file_format_options.rs b/crates/core/src/table/file_format_options.rs new file mode 100644 index 0000000000..2740822285 --- /dev/null +++ b/crates/core/src/table/file_format_options.rs @@ -0,0 +1,172 @@ +#[cfg(feature = "datafusion")] +use datafusion::catalog::Session; +#[cfg(feature = "datafusion")] +pub use datafusion::config::{ConfigFileType, TableOptions, TableParquetOptions}; +#[cfg(feature = "datafusion")] +use datafusion::execution::SessionState; +use std::fmt::Debug; + +use crate::{crate_version, DeltaResult}; +use arrow_schema::Schema as ArrowSchema; + +use async_trait::async_trait; + +use object_store::path::Path; +use parquet::basic::Compression; +use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder}; +use parquet::schema::types::ColumnPath; +use std::sync::Arc; +use tracing::debug; + +// Top level trait for file format options used by a DeltaTable +pub trait FileFormatOptions: Send + Sync + std::fmt::Debug + 'static { + #[cfg(feature = "datafusion")] + fn table_options(&self) -> TableOptions; + + fn writer_properties_factory(&self) -> WriterPropertiesFactoryRef; + + #[cfg(feature = "datafusion")] + fn update_session(&self, _session: &dyn Session) -> DeltaResult<()> { + // Default implementation does nothing + Ok(()) + } +} + +/// Convenience alias for file format options reference used across the codebase +pub type FileFormatRef = Arc; + +/// Convenience alias for writer properties factory reference used across the codebase +pub type WriterPropertiesFactoryRef = Arc; + +#[cfg(feature = "datafusion")] +#[derive(Clone, Debug, Default)] +pub struct SimpleFileFormatOptions { + table_options: TableOptions, +} + +#[cfg(feature = "datafusion")] +impl SimpleFileFormatOptions { + pub fn new(table_options: TableOptions) -> Self { + Self { table_options } + } +} + +#[cfg(feature = "datafusion")] +impl FileFormatOptions for SimpleFileFormatOptions { + fn table_options(&self) -> TableOptions { + self.table_options.clone() + } + + fn writer_properties_factory(&self) -> WriterPropertiesFactoryRef { + build_writer_properties_factory_tpo(&Some(self.table_options.parquet.clone())).unwrap() + } +} + +pub trait 
FileFormatToWriterPropertiesFactory { + fn into_writer_properties_factory_ref_or_default(self) -> WriterPropertiesFactoryRef; +} + +impl FileFormatToWriterPropertiesFactory for Option { + fn into_writer_properties_factory_ref_or_default(self) -> WriterPropertiesFactoryRef { + self.map(|ffo| ffo.writer_properties_factory()) + .unwrap_or_else(|| build_writer_properties_factory_default()) + } +} + +#[cfg(feature = "datafusion")] +pub fn state_with_file_format_options( + state: SessionState, + file_format_options: Option<&FileFormatRef>, +) -> DeltaResult { + if let Some(ffo) = file_format_options { + ffo.update_session(&state)?; + } + Ok(state) +} + +#[cfg(feature = "datafusion")] +fn build_writer_properties_tpo( + table_parquet_options: &Option, +) -> Option { + table_parquet_options.as_ref().map(|tpo| { + let mut tpo = tpo.clone(); + tpo.global.skip_arrow_metadata = true; + let mut wp_build = WriterPropertiesBuilder::try_from(&tpo) + .expect("Failed to convert TableParquetOptions to ParquetWriterOptions"); + if let Some(enc) = tpo.crypto.file_encryption { + // Convert config encryption properties into parquet FileEncryptionProperties + // and wrap into Arc as required by the builder. + wp_build = wp_build.with_file_encryption_properties(Arc::new(enc.into())); + } + wp_build.build() + }) +} + +#[cfg(feature = "datafusion")] +fn build_writer_properties_factory_tpo( + table_parquet_options: &Option, +) -> Option { + let props = build_writer_properties_tpo(table_parquet_options); + props.map(|wp| Arc::new(SimpleWriterPropertiesFactory::new(wp)) as WriterPropertiesFactoryRef) +} + +pub trait IntoWriterPropertiesFactoryRef { + fn into_factory_ref(self) -> WriterPropertiesFactoryRef; +} + +impl IntoWriterPropertiesFactoryRef for WriterProperties { + fn into_factory_ref(self) -> WriterPropertiesFactoryRef { + Arc::new(SimpleWriterPropertiesFactory::new(self)) + } +} + +pub fn build_writer_properties_factory_default() -> WriterPropertiesFactoryRef { + Arc::new(SimpleWriterPropertiesFactory::default()) +} + +#[async_trait] +pub trait WriterPropertiesFactory: Send + Sync + std::fmt::Debug + 'static { + fn compression(&self, column_path: &ColumnPath) -> Compression; + async fn create_writer_properties( + &self, + file_path: &Path, + file_schema: &Arc, + ) -> DeltaResult; +} + +#[derive(Clone, Debug)] +pub struct SimpleWriterPropertiesFactory { + writer_properties: WriterProperties, +} + +impl SimpleWriterPropertiesFactory { + pub fn new(writer_properties: WriterProperties) -> Self { + Self { writer_properties } + } +} + +impl Default for SimpleWriterPropertiesFactory { + fn default() -> Self { + let writer_properties = WriterProperties::builder() + .set_compression(Compression::SNAPPY) // Code assumes Snappy by default + .set_created_by(format!("delta-rs version {}", crate_version())) + .build(); + Self { writer_properties } + } +} + +#[async_trait] +impl WriterPropertiesFactory for SimpleWriterPropertiesFactory { + fn compression(&self, column_path: &ColumnPath) -> Compression { + self.writer_properties.compression(column_path) + } + + async fn create_writer_properties( + &self, + file_path: &Path, + _file_schema: &Arc, + ) -> DeltaResult { + debug!("Called create_writer_properties for file: {file_path}"); + Ok(self.writer_properties.clone()) + } +} diff --git a/crates/core/src/table/mod.rs b/crates/core/src/table/mod.rs index 30d8bfb1e0..3c1d5592a3 100644 --- a/crates/core/src/table/mod.rs +++ b/crates/core/src/table/mod.rs @@ -1,9 +1,5 @@ //! 
Delta Table read and write implementation -use std::cmp::{min, Ordering}; -use std::fmt; -use std::fmt::Formatter; - use chrono::{DateTime, Utc}; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; @@ -11,6 +7,9 @@ use object_store::{path::Path, ObjectStore}; use serde::de::{Error, SeqAccess, Visitor}; use serde::ser::SerializeSeq; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::cmp::{min, Ordering}; +use std::fmt; +use std::fmt::Formatter; use self::builder::DeltaTableConfig; use self::state::DeltaTableState; @@ -30,6 +29,7 @@ pub mod config; pub mod state; mod columns; +pub mod file_format_options; // Re-exposing for backwards compatibility pub use columns::*; diff --git a/crates/core/src/test_utils/kms_encryption.rs b/crates/core/src/test_utils/kms_encryption.rs new file mode 100644 index 0000000000..5318e33328 --- /dev/null +++ b/crates/core/src/test_utils/kms_encryption.rs @@ -0,0 +1,240 @@ +//! This module contains classes and functions to support encryption with a KMS. +//! These are not part of the core API but are used in the encryption tests and examples. +//! +//! The first main class is `TableEncryption`, which encapsulates the encryption configuration +//! and the encryption factory. +//! +//! The second main class is `KmsFileFormatOptions` which configures the file format options for +//! KMS encryption. It is used to create a `FileFormatOptions` instance that can be +//! passed to the `DeltaTable::create` method. This class can also be directly used in +//! `DeltaOps` via the `with_file_format_options` method. +//! See `crates/deltalake/examples/basic_operations_encryption.rs` for a working example. +//! +//! The `MockKmsClient` struct provides a mock implementation of `EncryptionFactory` for testing +//! purposes. It generates unique encryption keys for each file and stores them for later decryption. 
+ +use crate::table::file_format_options::{ + FileFormatOptions, TableOptions, WriterPropertiesFactory, WriterPropertiesFactoryRef, +}; +use crate::{crate_version, DeltaResult}; +use arrow_schema::Schema as ArrowSchema; +use async_trait::async_trait; +use datafusion::catalog::Session; +use datafusion::config::{ConfigField, EncryptionFactoryOptions, ExtensionOptions}; +use datafusion::execution::parquet_encryption::EncryptionFactory; +use object_store::path::Path; +use parquet::basic::Compression; +use parquet::encryption::decrypt::FileDecryptionProperties; +use parquet::encryption::encrypt::FileEncryptionProperties; +use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder}; +use parquet::schema::types::ColumnPath; +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::{Arc, Mutex}; +use uuid::Uuid; + +pub type SchemaRef = Arc; + +#[derive(Clone, Debug)] +pub struct TableEncryption { + encryption_factory: Arc, + configuration: EncryptionFactoryOptions, +} + +impl TableEncryption { + pub fn new( + encryption_factory: Arc, + configuration: EncryptionFactoryOptions, + ) -> Self { + Self { + encryption_factory, + configuration, + } + } + + pub fn new_with_extension_options( + encryption_factory: Arc, + options: &T, + ) -> DeltaResult { + let mut configuration = EncryptionFactoryOptions::default(); + for entry in options.entries() { + if let Some(value) = &entry.value { + configuration.set(&entry.key, value)?; + } + } + Ok(Self { + encryption_factory, + configuration, + }) + } + + pub fn encryption_factory(&self) -> &Arc { + &self.encryption_factory + } + + pub fn configuration(&self) -> &EncryptionFactoryOptions { + &self.configuration + } + + pub async fn update_writer_properties( + &self, + mut builder: WriterPropertiesBuilder, + file_path: &Path, + file_schema: &SchemaRef, + ) -> DeltaResult { + let encryption_properties = self + .encryption_factory + .get_file_encryption_properties(&self.configuration, file_schema, file_path) + .await?; + if let Some(encryption_properties) = encryption_properties { + builder = builder.with_file_encryption_properties(encryption_properties); + } + Ok(builder) + } +} + +// More advanced factory with KMS support +#[derive(Clone, Debug)] +pub struct KMSWriterPropertiesFactory { + writer_properties: WriterProperties, + encryption: Option, +} + +impl KMSWriterPropertiesFactory { + pub fn with_encryption(table_encryption: TableEncryption) -> Self { + let writer_properties = WriterProperties::builder() + .set_compression(Compression::SNAPPY) // Code assumes Snappy by default + .set_created_by(format!("delta-rs version {}", crate_version())) + .build(); + Self { + writer_properties, + encryption: Some(table_encryption), + } + } +} + +#[async_trait] +impl WriterPropertiesFactory for KMSWriterPropertiesFactory { + fn compression(&self, column_path: &ColumnPath) -> Compression { + self.writer_properties.compression(column_path) + } + + async fn create_writer_properties( + &self, + file_path: &Path, + file_schema: &Arc, + ) -> DeltaResult { + let mut builder: WriterPropertiesBuilder = self.writer_properties.clone().into(); + if let Some(encryption) = self.encryption.as_ref() { + builder = encryption + .update_writer_properties(builder, file_path, file_schema) + .await?; + } + Ok(builder.build()) + } +} + +// ------------------------------------------------------------------------------------------------- +// FileFormatOptions for KMS encryption based on settings in TableEncryption 
+// ------------------------------------------------------------------------------------------------- +pub struct KmsFileFormatOptions { + table_encryption: TableEncryption, + writer_properties_factory: WriterPropertiesFactoryRef, + encryption_factory_id: String, +} + +impl KmsFileFormatOptions { + pub fn new(table_encryption: TableEncryption) -> Self { + let encryption_factory_id = format!("delta-{}", Uuid::new_v4()); + let writer_properties_factory = Arc::new(KMSWriterPropertiesFactory::with_encryption( + table_encryption.clone(), + )); + Self { + table_encryption, + writer_properties_factory, + encryption_factory_id, + } + } +} + +impl Debug for KmsFileFormatOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("KmsFileFormatOptions") + .finish_non_exhaustive() + } +} + +impl FileFormatOptions for KmsFileFormatOptions { + fn table_options(&self) -> TableOptions { + let mut table_options = TableOptions::default(); + table_options.parquet.crypto.factory_id = Some(self.encryption_factory_id.clone()); + table_options.parquet.crypto.factory_options = + self.table_encryption.configuration().clone(); + table_options + } + + fn writer_properties_factory(&self) -> WriterPropertiesFactoryRef { + Arc::clone(&self.writer_properties_factory) + } + + fn update_session(&self, session: &dyn Session) -> DeltaResult<()> { + // Ensure DataFusion has the encryption factory registered + session.runtime_env().register_parquet_encryption_factory( + &self.encryption_factory_id, + Arc::clone(self.table_encryption.encryption_factory()), + ); + Ok(()) + } +} + +// ------------------------------------------------------------------------------------------------- +// Mock KMS client for testing purposes +// ------------------------------------------------------------------------------------------------- + +/// Mock encryption factory implementation for use in tests. +/// Generates unique encryption keys for each file and stores them for later decryption. 
+#[derive(Debug, Default)]
+pub struct MockKmsClient {
+    encryption_keys: Mutex<HashMap<Path, Vec<u8>>>,
+    counter: AtomicU8,
+}
+
+impl MockKmsClient {
+    pub fn new() -> Self {
+        Self {
+            encryption_keys: Mutex::new(HashMap::new()),
+            counter: AtomicU8::new(0),
+        }
+    }
+}
+
+#[async_trait]
+impl EncryptionFactory for MockKmsClient {
+    async fn get_file_encryption_properties(
+        &self,
+        _config: &EncryptionFactoryOptions,
+        _schema: &SchemaRef,
+        file_path: &Path,
+    ) -> datafusion::error::Result<Option<FileEncryptionProperties>> {
+        let file_idx = self.counter.fetch_add(1, Ordering::Relaxed);
+        let key = vec![file_idx, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
+        let mut keys = self.encryption_keys.lock().unwrap();
+        keys.insert(file_path.clone(), key.clone());
+        let encryption_properties = FileEncryptionProperties::builder(key).build()?;
+        Ok(Some(encryption_properties))
+    }
+
+    async fn get_file_decryption_properties(
+        &self,
+        _config: &EncryptionFactoryOptions,
+        file_path: &Path,
+    ) -> datafusion::error::Result<Option<FileDecryptionProperties>> {
+        let keys = self.encryption_keys.lock().unwrap();
+        let key = keys.get(file_path).ok_or_else(|| {
+            datafusion::error::DataFusionError::Execution(format!("No key for file {file_path:?}"))
+        })?;
+        let decryption_properties = FileDecryptionProperties::builder(key.clone()).build()?;
+        Ok(Some(decryption_properties))
+    }
+}
diff --git a/crates/core/src/test_utils/mod.rs b/crates/core/src/test_utils/mod.rs
index c00be04f22..12426dc484 100644
--- a/crates/core/src/test_utils/mod.rs
+++ b/crates/core/src/test_utils/mod.rs
@@ -1,4 +1,6 @@
 mod factories;
+#[cfg(feature = "datafusion")]
+pub mod kms_encryption;
 
 use std::{collections::HashMap, path::PathBuf, process::Command};
diff --git a/crates/core/src/writer/json.rs b/crates/core/src/writer/json.rs
index 7644abc04e..b67835bb19 100644
--- a/crates/core/src/writer/json.rs
+++ b/crates/core/src/writer/json.rs
@@ -10,10 +10,7 @@ use delta_kernel::expressions::Scalar;
 use indexmap::IndexMap;
 use itertools::Itertools;
 use object_store::path::Path;
-use parquet::{
-    arrow::ArrowWriter, basic::Compression, errors::ParquetError,
-    file::properties::WriterProperties,
-};
+use parquet::{arrow::ArrowWriter, errors::ParquetError, file::properties::WriterProperties};
 use serde_json::Value;
 use tracing::*;
 use url::Url;
@@ -30,6 +27,9 @@ use crate::kernel::{scalars::ScalarExt, Add, PartitionsExt};
 use crate::logstore::ObjectStoreRetryExt;
 use crate::table::builder::{ensure_table_uri, DeltaTableBuilder};
 use crate::table::config::TablePropertiesExt as _;
+use crate::table::file_format_options::{
+    FileFormatToWriterPropertiesFactory, WriterPropertiesFactoryRef,
+};
 use crate::writer::utils::ShareableBuffer;
 use crate::DeltaTable;
 
@@ -41,7 +41,7 @@ pub struct JsonWriter {
     table: DeltaTable,
     /// Optional schema to use, otherwise try to rely on the schema from the [DeltaTable]
     schema_ref: Option<ArrowSchemaRef>,
-    writer_properties: WriterProperties,
+    writer_properties_factory: WriterPropertiesFactoryRef,
     partition_columns: Vec<String>,
     arrow_writers: HashMap<String, DataArrowWriter>,
 }
 
@@ -55,6 +55,7 @@ pub(crate) struct DataArrowWriter {
     arrow_writer: ArrowWriter<ShareableBuffer>,
     partition_values: IndexMap<String, Scalar>,
     buffered_record_batch_count: usize,
+    path: Path,
 }
 
 impl DataArrowWriter {
@@ -117,11 +118,6 @@ impl DataArrowWriter {
         partition_columns: &[String],
         record_batch: RecordBatch,
     ) -> Result<(), DeltaWriterError> {
-        if self.partition_values.is_empty() {
-            let partition_values = extract_partition_values(partition_columns, &record_batch)?;
-            self.partition_values = partition_values;
-        }
-
         // Copy current buffered bytes so we can recover from failures
         let
buffer_bytes = self.buffer.to_vec(); @@ -153,6 +149,8 @@ impl DataArrowWriter { fn new( arrow_schema: Arc, writer_properties: WriterProperties, + partition_values: IndexMap, + path: Path, ) -> Result { let buffer = ShareableBuffer::default(); let arrow_writer = Self::new_underlying_writer( @@ -161,7 +159,6 @@ impl DataArrowWriter { writer_properties.clone(), )?; - let partition_values = IndexMap::new(); let buffered_record_batch_count = 0; Ok(Self { @@ -171,6 +168,7 @@ impl DataArrowWriter { arrow_writer, partition_values, buffered_record_batch_count, + path, }) } @@ -195,16 +193,18 @@ impl JsonWriter { .with_storage_options(storage_options.unwrap_or_default()) .load() .await?; - // Initialize writer properties for the underlying arrow writer - let writer_properties = WriterProperties::builder() - // NOTE: Consider extracting config for writer properties and setting more than just compression - .set_compression(Compression::SNAPPY) - .build(); + + let writer_properties_factory = table + .snapshot()? + .load_config() + .file_format_options + .clone() + .into_writer_properties_factory_ref_or_default(); Ok(Self { table, schema_ref: Some(schema_ref), - writer_properties, + writer_properties_factory, partition_columns: partition_columns.unwrap_or_default(), arrow_writers: HashMap::new(), }) @@ -216,15 +216,16 @@ impl JsonWriter { let metadata = table.snapshot()?.metadata(); let partition_columns = metadata.partition_columns().clone(); - // Initialize writer properties for the underlying arrow writer - let writer_properties = WriterProperties::builder() - // NOTE: Consider extracting config for writer properties and setting more than just compression - .set_compression(Compression::SNAPPY) - .build(); + let writer_properties_factory = table + .snapshot()? 
+ .load_config() + .file_format_options + .clone() + .into_writer_properties_factory_ref_or_default(); Ok(Self { table: table.clone(), - writer_properties, + writer_properties_factory, partition_columns, schema_ref: None, arrow_writers: HashMap::new(), @@ -323,7 +324,6 @@ impl DeltaWriter> for JsonWriter { let arrow_schema = self.arrow_schema(); let divided = self.divide_by_partition_values(values)?; let partition_columns = self.partition_columns.clone(); - let writer_properties = self.writer_properties.clone(); for (key, values) in divided { match self.arrow_writers.get_mut(&key) { @@ -335,7 +335,22 @@ impl DeltaWriter> for JsonWriter { } None => { let schema = arrow_schema_without_partitions(&arrow_schema, &partition_columns); - let mut writer = DataArrowWriter::new(schema, writer_properties.clone())?; + + let record_batch = + record_batch_from_message(arrow_schema.clone(), &values[..1])?; + let partition_values = + extract_partition_values(&partition_columns, &record_batch)?; + let prefix = Path::parse(partition_values.hive_partition_path())?; + let uuid = Uuid::new_v4(); + let path = + next_data_path(&prefix, 0, &uuid, self.writer_properties_factory.clone()); + let writer_properties = self + .writer_properties_factory + .create_writer_properties(&path, &arrow_schema) + .await?; + let mut writer = + DataArrowWriter::new(schema, writer_properties, partition_values, path)?; + let result = writer .write_values(&partition_columns, arrow_schema.clone(), values) .await; @@ -381,26 +396,27 @@ impl DeltaWriter> for JsonWriter { for (_, writer) in writers { let metadata = writer.arrow_writer.close()?; - let prefix = writer.partition_values.hive_partition_path(); - let prefix = Path::parse(prefix)?; - let uuid = Uuid::new_v4(); - let path = next_data_path(&prefix, 0, &uuid, &writer.writer_properties); let obj_bytes = Bytes::from(writer.buffer.to_vec()); let file_size = obj_bytes.len() as i64; - debug!(path = %path, size = file_size, rows = metadata.file_metadata().num_rows(), "writing data file"); + debug!( + path = writer.path.to_string(), + size = file_size, + rows = metadata.file_metadata().num_rows(), + "writing data file" + ); self.table .object_store() - .put_with_retries(&path, obj_bytes.into(), 15) + .put_with_retries(&writer.path, obj_bytes.into(), 15) .await?; let table_config = self.table.snapshot()?.table_config(); actions.push(create_add( &writer.partition_values, - path.to_string(), + writer.path.to_string(), file_size, &metadata, table_config.num_indexed_cols(), diff --git a/crates/core/src/writer/record_batch.rs b/crates/core/src/writer/record_batch.rs index 7ee1090411..0642f3d1ef 100644 --- a/crates/core/src/writer/record_batch.rs +++ b/crates/core/src/writer/record_batch.rs @@ -18,8 +18,8 @@ use delta_kernel::expressions::Scalar; use delta_kernel::table_properties::DataSkippingNumIndexedCols; use indexmap::IndexMap; use object_store::{path::Path, ObjectStore}; +use parquet::file::properties::WriterProperties; use parquet::{arrow::ArrowWriter, errors::ParquetError}; -use parquet::{basic::Compression, file::properties::WriterProperties}; use tracing::log::*; use uuid::Uuid; @@ -37,6 +37,9 @@ use crate::kernel::{scalars::ScalarExt, Action, Add, PartitionsExt}; use crate::logstore::ObjectStoreRetryExt; use crate::table::builder::DeltaTableBuilder; use crate::table::config::DEFAULT_NUM_INDEX_COLS; +use crate::table::file_format_options::{ + FileFormatToWriterPropertiesFactory, IntoWriterPropertiesFactoryRef, WriterPropertiesFactoryRef, +}; use crate::DeltaTable; /// Writes messages 
to a delta lake table. @@ -44,7 +47,7 @@ pub struct RecordBatchWriter { storage: Arc, arrow_schema_ref: ArrowSchemaRef, original_schema_ref: ArrowSchemaRef, - writer_properties: WriterProperties, + writer_properties_factory: WriterPropertiesFactoryRef, should_evolve: bool, partition_columns: Vec, arrow_writers: HashMap, @@ -72,11 +75,12 @@ impl RecordBatchWriter { let delta_table = DeltaTableBuilder::from_uri(table_url)? .with_storage_options(storage_options.unwrap_or_default()) .build()?; - // Initialize writer properties for the underlying arrow writer - let writer_properties = WriterProperties::builder() - // NOTE: Consider extracting config for writer properties and setting more than just compression - .set_compression(Compression::SNAPPY) - .build(); + let writer_properties_factory = delta_table + .snapshot()? + .load_config() + .file_format_options + .clone() + .into_writer_properties_factory_ref_or_default(); // if metadata fails to load, use an empty hashmap and default values for num_indexed_cols and stats_columns let configuration = delta_table.snapshot().map_or_else( @@ -88,7 +92,7 @@ impl RecordBatchWriter { storage: delta_table.object_store(), arrow_schema_ref: schema.clone(), original_schema_ref: schema, - writer_properties, + writer_properties_factory, partition_columns: partition_columns.unwrap_or_default(), should_evolve: false, arrow_writers: HashMap::new(), @@ -127,18 +131,20 @@ impl RecordBatchWriter { let arrow_schema_ref = Arc::new(arrow_schema); let partition_columns = metadata.partition_columns().clone(); - // Initialize writer properties for the underlying arrow writer - let writer_properties = WriterProperties::builder() - // NOTE: Consider extracting config for writer properties and setting more than just compression - .set_compression(Compression::SNAPPY) - .build(); + let writer_properties_factory = table + .snapshot()? 
+ .load_config() + .file_format_options + .clone() + .into_writer_properties_factory_ref_or_default(); + let configuration = table.snapshot()?.metadata().configuration().clone(); Ok(Self { storage: table.object_store(), arrow_schema_ref: arrow_schema_ref.clone(), original_schema_ref: arrow_schema_ref.clone(), - writer_properties, + writer_properties_factory, partition_columns, should_evolve: false, arrow_writers: HashMap::new(), @@ -191,6 +197,8 @@ impl RecordBatchWriter { partition_values: &IndexMap, mode: WriteMode, ) -> Result { + let arrow_schema = + arrow_schema_without_partitions(&self.arrow_schema_ref, &self.partition_columns); let partition_key = partition_values.hive_partition_path(); let record_batch = record_batch_without_partitions(&record_batch, &self.partition_columns)?; @@ -198,13 +206,19 @@ impl RecordBatchWriter { let written_schema = match self.arrow_writers.get_mut(&partition_key) { Some(writer) => writer.write(&record_batch, mode)?, None => { + let prefix = Path::parse(&partition_key)?; + let uuid = Uuid::new_v4(); + let path = + next_data_path(&prefix, 0, &uuid, self.writer_properties_factory.clone()); + let writer_properties = self + .writer_properties_factory + .create_writer_properties(&path, &arrow_schema) + .await?; let mut writer = PartitionWriter::new( - arrow_schema_without_partitions( - &self.arrow_schema_ref, - &self.partition_columns, - ), + arrow_schema, partition_values.clone(), - self.writer_properties.clone(), + writer_properties, + path, )?; let schema = writer.write(&record_batch, mode)?; // Currently schema evolution is not supported with partition columns which means @@ -223,7 +237,8 @@ impl RecordBatchWriter { /// Sets the writer properties for the underlying arrow writer. pub fn with_writer_properties(mut self, writer_properties: WriterProperties) -> Self { - self.writer_properties = writer_properties; + let writer_properties_factory = writer_properties.into_factory_ref(); + self.writer_properties_factory = writer_properties_factory; self } @@ -278,18 +293,15 @@ impl DeltaWriter for RecordBatchWriter { for (_, writer) in writers { let metadata = writer.arrow_writer.close()?; - let prefix = Path::parse(writer.partition_values.hive_partition_path())?; - let uuid = Uuid::new_v4(); - let path = next_data_path(&prefix, 0, &uuid, &writer.writer_properties); let obj_bytes = Bytes::from(writer.buffer.to_vec()); let file_size = obj_bytes.len() as i64; self.storage - .put_with_retries(&path, obj_bytes.into(), 15) + .put_with_retries(&writer.path, obj_bytes.into(), 15) .await?; actions.push(create_add( &writer.partition_values, - path.to_string(), + writer.path.to_string(), file_size, &metadata, self.num_indexed_cols, @@ -340,6 +352,7 @@ struct PartitionWriter { pub(super) arrow_writer: ArrowWriter, pub(super) partition_values: IndexMap, pub(super) buffered_record_batch_count: usize, + pub(super) path: Path, } impl PartitionWriter { @@ -347,6 +360,7 @@ impl PartitionWriter { arrow_schema: ArrowSchemaRef, partition_values: IndexMap, writer_properties: WriterProperties, + path: Path, ) -> Result { let buffer = ShareableBuffer::default(); let arrow_writer = ArrowWriter::try_new( @@ -364,6 +378,7 @@ impl PartitionWriter { arrow_writer, partition_values, buffered_record_batch_count, + path, }) } diff --git a/crates/core/src/writer/utils.rs b/crates/core/src/writer/utils.rs index 98c8ba823e..7940aea6cb 100644 --- a/crates/core/src/writer/utils.rs +++ b/crates/core/src/writer/utils.rs @@ -10,12 +10,12 @@ use arrow_schema::{Schema as ArrowSchema, SchemaRef as 
ArrowSchemaRef}; use object_store::path::Path; use parking_lot::RwLock; use parquet::basic::Compression; -use parquet::file::properties::WriterProperties; use parquet::schema::types::ColumnPath; use serde_json::Value; use uuid::Uuid; use crate::errors::DeltaResult; +use crate::table::file_format_options::WriterPropertiesFactoryRef; use crate::writer::DeltaWriterError; /// Generate the name of the file to be written @@ -26,7 +26,7 @@ pub(crate) fn next_data_path( prefix: &Path, part_count: usize, writer_id: &Uuid, - writer_properties: &WriterProperties, + writer_properties_factory: WriterPropertiesFactoryRef, ) -> Path { fn compression_to_str(compression: &Compression) -> &str { match compression { @@ -46,7 +46,7 @@ pub(crate) fn next_data_path( // We can not access the default column properties but the current implementation will return // the default compression when the column is not found let column_path = ColumnPath::new(Vec::new()); - let compression = writer_properties.compression(&column_path); + let compression = writer_properties_factory.compression(&column_path); let part = format!("{part_count:0>5}"); @@ -63,7 +63,7 @@ pub fn record_batch_from_message( arrow_schema: Arc, json: &[Value], ) -> DeltaResult { - let mut decoder = ReaderBuilder::new(arrow_schema).build_decoder().unwrap(); + let mut decoder = ReaderBuilder::new(arrow_schema).build_decoder()?; decoder.serialize(json)?; decoder .flush()? @@ -158,7 +158,9 @@ impl Write for ShareableBuffer { #[cfg(test)] mod tests { use super::*; + use crate::table::file_format_options::IntoWriterPropertiesFactoryRef; use parquet::basic::{BrotliLevel, GzipLevel, ZstdLevel}; + use parquet::file::properties::WriterProperties; #[test] fn test_data_path() { @@ -171,7 +173,7 @@ mod tests { .build(); assert_eq!( - next_data_path(&prefix, 1, &uuid, &props).as_ref(), + next_data_path(&prefix, 1, &uuid, props.into_factory_ref()).as_ref(), "x=0/y=0/part-00001-02f09a3f-1624-3b1d-8409-44eff7708208-c000.parquet" ); @@ -179,7 +181,7 @@ mod tests { .set_compression(Compression::SNAPPY) .build(); assert_eq!( - next_data_path(&prefix, 1, &uuid, &props).as_ref(), + next_data_path(&prefix, 1, &uuid, props.into_factory_ref()).as_ref(), "x=0/y=0/part-00001-02f09a3f-1624-3b1d-8409-44eff7708208-c000.snappy.parquet" ); @@ -187,7 +189,7 @@ mod tests { .set_compression(Compression::GZIP(GzipLevel::default())) .build(); assert_eq!( - next_data_path(&prefix, 1, &uuid, &props).as_ref(), + next_data_path(&prefix, 1, &uuid, props.into_factory_ref()).as_ref(), "x=0/y=0/part-00001-02f09a3f-1624-3b1d-8409-44eff7708208-c000.gz.parquet" ); @@ -195,7 +197,7 @@ mod tests { .set_compression(Compression::LZ4) .build(); assert_eq!( - next_data_path(&prefix, 1, &uuid, &props).as_ref(), + next_data_path(&prefix, 1, &uuid, props.into_factory_ref()).as_ref(), "x=0/y=0/part-00001-02f09a3f-1624-3b1d-8409-44eff7708208-c000.lz4.parquet" ); @@ -203,7 +205,7 @@ mod tests { .set_compression(Compression::ZSTD(ZstdLevel::default())) .build(); assert_eq!( - next_data_path(&prefix, 1, &uuid, &props).as_ref(), + next_data_path(&prefix, 1, &uuid, props.into_factory_ref()).as_ref(), "x=0/y=0/part-00001-02f09a3f-1624-3b1d-8409-44eff7708208-c000.zstd.parquet" ); @@ -211,7 +213,7 @@ mod tests { .set_compression(Compression::LZ4_RAW) .build(); assert_eq!( - next_data_path(&prefix, 1, &uuid, &props).as_ref(), + next_data_path(&prefix, 1, &uuid, props.into_factory_ref()).as_ref(), "x=0/y=0/part-00001-02f09a3f-1624-3b1d-8409-44eff7708208-c000.lz4raw.parquet" ); @@ -219,7 +221,7 @@ mod tests { 
.set_compression(Compression::BROTLI(BrotliLevel::default())) .build(); assert_eq!( - next_data_path(&prefix, 1, &uuid, &props).as_ref(), + next_data_path(&prefix, 1, &uuid, props.into_factory_ref()).as_ref(), "x=0/y=0/part-00001-02f09a3f-1624-3b1d-8409-44eff7708208-c000.br.parquet" ); } diff --git a/crates/core/tests/command_optimize.rs b/crates/core/tests/command_optimize.rs index da3a00b429..136de4972b 100644 --- a/crates/core/tests/command_optimize.rs +++ b/crates/core/tests/command_optimize.rs @@ -301,7 +301,7 @@ async fn test_conflict_for_remove_actions() -> Result<(), Box> { dt.snapshot()?.snapshot(), &filter, None, - WriterProperties::builder().build(), + None, df_context.state(), ) .await?; @@ -363,12 +363,12 @@ async fn test_no_conflict_for_append_actions() -> Result<(), Box> { let filter = vec![PartitionFilter::try_from(("date", "=", "2022-05-22"))?]; let plan = create_merge_plan( - &dt.log_store(), + &*dt.log_store(), OptimizeType::Compact, dt.snapshot()?.snapshot(), &filter, None, - WriterProperties::builder().build(), + None, df_context.state(), ) .await?; @@ -432,7 +432,7 @@ async fn test_commit_interval() -> Result<(), Box> { dt.snapshot()?.snapshot(), &[], None, - WriterProperties::builder().build(), + None, context.state(), ) .await?; diff --git a/crates/core/tests/commands_with_encryption.rs b/crates/core/tests/commands_with_encryption.rs new file mode 100644 index 0000000000..ce79125f72 --- /dev/null +++ b/crates/core/tests/commands_with_encryption.rs @@ -0,0 +1,501 @@ +use arrow::{ + array::{Int32Array, StringArray, TimestampMicrosecondArray}, + datatypes::{DataType as ArrowDataType, Field, Schema, Schema as ArrowSchema, TimeUnit}, + record_batch::RecordBatch, +}; +use datafusion::{ + assert_batches_sorted_eq, + config::{ConfigFileType, TableOptions, TableParquetOptions}, + dataframe::DataFrame, + logical_expr::{col, lit}, + prelude::SessionContext, +}; +use deltalake_core::kernel::{DataType, PrimitiveType, StructField}; +use deltalake_core::operations::collect_sendable_stream; +use deltalake_core::parquet::encryption::decrypt::FileDecryptionProperties; +use deltalake_core::table::file_format_options::{FileFormatRef, SimpleFileFormatOptions}; +use deltalake_core::test_utils::kms_encryption::{ + KmsFileFormatOptions, MockKmsClient, TableEncryption, +}; +use deltalake_core::{arrow, parquet, DeltaOps}; +use deltalake_core::{operations::optimize::OptimizeType, DeltaTable, DeltaTableError}; + +use datafusion::config::EncryptionFactoryOptions; +use paste::paste; +use std::{fs, sync::Arc}; +use tempfile::TempDir; +use url::Url; + +fn get_table_columns() -> Vec { + vec![ + StructField::new( + String::from("int"), + DataType::Primitive(PrimitiveType::Integer), + false, + ), + StructField::new( + String::from("string"), + DataType::Primitive(PrimitiveType::String), + true, + ), + StructField::new( + String::from("timestamp"), + DataType::Primitive(PrimitiveType::TimestampNtz), + true, + ), + ] +} + +fn get_table_schema() -> Arc { + Arc::new(ArrowSchema::new(vec![ + Field::new("int", ArrowDataType::Int32, false), + Field::new("string", ArrowDataType::Utf8, true), + Field::new( + "timestamp", + ArrowDataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + ])) +} + +fn get_table_batches() -> RecordBatch { + let schema = get_table_schema(); + + let int_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); + let str_values = StringArray::from(vec!["A", "B", "C", "B", "A", "C", "A", "B", "B", "A", "A"]); + let ts_values = TimestampMicrosecondArray::from(vec![ + 
1000000012, 1000000012, 1000000012, 1000000012, 500012305, 500012305, 500012305, 500012305, + 500012305, 500012305, 500012305, + ]); + RecordBatch::try_new( + schema, + vec![ + Arc::new(int_values), + Arc::new(str_values), + Arc::new(ts_values), + ], + ) + .unwrap() +} + +// Create a DeltaOps instance with the specified file_format_options to apply crypto settings. +async fn ops_from_uri(uri: &str) -> Result { + let prefix_uri = format!("file://{}", uri); + let url = Url::parse(&*prefix_uri).unwrap(); + let ops = DeltaOps::try_from_uri(url).await?; + Ok(ops) +} + +// Create a DeltaOps instance with the specified file_format_options to apply crypto settings. +async fn ops_with_crypto( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result { + let ops = ops_from_uri(uri).await?; + let ops = ops + .with_file_format_options(file_format_options.clone()) + .await?; + Ok(ops) +} + +async fn create_table( + uri: &str, + table_name: &str, + file_format_options: &FileFormatRef, +) -> Result { + fs::remove_dir_all(uri)?; + fs::create_dir(uri)?; + let ops = ops_with_crypto(uri, file_format_options).await?; + + // The operations module uses a builder pattern that allows specifying several options + // on how the command behaves. The builders implement `Into`, so once + // options are set you can run the command using `.await`. + let table = ops + .create() + .with_columns(get_table_columns()) + .with_table_name(table_name) + .with_comment("A table to show how delta-rs works") + .await?; + + assert_eq!(table.version(), Some(0)); + + let batch = get_table_batches(); + let table = DeltaOps(table).write(vec![batch.clone()]).await?; + + assert_eq!(table.version(), Some(1)); + + // Append records to the table + let table = DeltaOps(table).write(vec![batch]).await?; + + assert_eq!(table.version(), Some(2)); + + Ok(table) +} + +async fn read_table( + uri: &str, + file_format_options: &FileFormatRef, + use_file_format_options: bool, +) -> Result, DeltaTableError> { + let ops = match use_file_format_options { + true => ops_with_crypto(uri, file_format_options).await?, + false => ops_from_uri(uri).await?, + }; + let (_table, stream) = ops.load().await?; + let data: Vec = collect_sendable_stream(stream).await?; + Ok(data) +} + +async fn update_table( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result<(), DeltaTableError> { + let ops = ops_with_crypto(uri, file_format_options).await?; + let table: DeltaTable = ops.into(); + let version = table.version(); + let ops: DeltaOps = table.into(); + + let (table, _metrics) = ops + .update() + .with_predicate(col("int").eq(lit(1))) + .with_update("int", "100") + .await + .unwrap(); + + assert_eq!(table.version(), Some(version.unwrap() + 1)); + + Ok(()) +} + +async fn delete_from_table( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result<(), DeltaTableError> { + let ops = ops_with_crypto(uri, file_format_options).await?; + let table: DeltaTable = ops.into(); + let version = table.version(); + let ops: DeltaOps = table.into(); + + let (table, _metrics) = ops + .delete() + .with_predicate(col("int").eq(lit(2))) + .await + .unwrap(); + + assert_eq!(table.version(), Some(version.unwrap() + 1)); + + Ok(()) +} + +// Secondary table to merge with primary data +fn merge_source() -> DataFrame { + let ctx = SessionContext::new(); + let schema = get_table_schema(); + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + 
Arc::new(TimestampMicrosecondArray::from(vec![ + 1000000012, 1000000012, 1000000012, + ])), + ], + ) + .unwrap(); + ctx.read_batch(batch).unwrap() +} + +// Apply merge operation to the primary table +async fn merge_table( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result<(), DeltaTableError> { + let ops = ops_with_crypto(uri, file_format_options).await?; + + let source = merge_source(); + + let (_table, _metrics) = ops + .merge(source, col("target.int").eq(col("source.int"))) + .with_source_alias("source") + .with_target_alias("target") + .when_not_matched_by_source_delete(|delete| delete) + .unwrap() + .await?; + Ok(()) +} + +async fn optimize_table( + uri: &str, + file_format_options: &FileFormatRef, + optimize_type: OptimizeType, +) -> Result<(), DeltaTableError> { + let ops = ops_with_crypto(uri, file_format_options).await?; + let (_table, _metrics) = ops.optimize().with_type(optimize_type).await?; + Ok(()) +} + +async fn optimize_table_z_order( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result<(), DeltaTableError> { + optimize_table( + uri, + file_format_options, + OptimizeType::ZOrder(vec!["timestamp".to_string(), "int".to_string()]), + ) + .await +} + +async fn optimize_table_compact( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result<(), DeltaTableError> { + optimize_table(uri, file_format_options, OptimizeType::Compact).await +} + +// Create a direct encryption / decryption configuration using EncryptionProperties and the provided keys +fn create_plain_crypto_format( + encrypt_key: Vec, + decrypt_key: Vec, +) -> Result { + let crypt = + parquet::encryption::encrypt::FileEncryptionProperties::builder(encrypt_key.clone()) + .with_column_key("int", encrypt_key.clone()) + .with_column_key("string", encrypt_key.clone()) + .build()?; + + let decrypt = FileDecryptionProperties::builder(decrypt_key.clone()) + .with_column_key("int", decrypt_key.clone()) + .with_column_key("string", decrypt_key.clone()) + .build()?; + + let mut tpo: TableParquetOptions = TableParquetOptions::default(); + tpo.crypto.file_encryption = Some((&crypt).into()); + tpo.crypto.file_decryption = Some((&decrypt).into()); + let mut tbl_options = TableOptions::new(); + tbl_options.parquet = tpo; + tbl_options.current_format = Some(ConfigFileType::PARQUET); + let file_format_options = Arc::new(SimpleFileFormatOptions::new(tbl_options)) as FileFormatRef; + Ok(file_format_options) +} + +fn plain_crypto_format() -> Result { + let key: Vec<_> = b"1234567890123450".to_vec(); + create_plain_crypto_format(key.clone(), key.clone()) +} + +fn plain_crypto_format_bad_decryptor() -> Result { + let encryption_key: Vec<_> = b"1234567890123450".to_vec(); + let decryption_key: Vec<_> = b"0123456789012345".to_vec(); + create_plain_crypto_format(encryption_key.clone(), decryption_key.clone()) +} + +fn kms_crypto_format() -> Result { + let encryption_factory = Arc::new(MockKmsClient::new()); + let configuration = EncryptionFactoryOptions::default(); + let table_encryption = TableEncryption::new(encryption_factory, configuration); + let file_format_options = + Arc::new(KmsFileFormatOptions::new(table_encryption)) as FileFormatRef; + Ok(file_format_options) +} + +fn full_table_data() -> Vec<&'static str> { + vec![ + "+-----+--------+----------------------------+", + "| int | string | timestamp |", + "+-----+--------+----------------------------+", + "| 1 | A | 1970-01-01T00:16:40.000012 |", + "| 2 | B | 1970-01-01T00:16:40.000012 |", + "| 3 | C | 1970-01-01T00:16:40.000012 |", + "| 4 | B | 
1970-01-01T00:16:40.000012 |", + "| 5 | A | 1970-01-01T00:08:20.012305 |", + "| 6 | C | 1970-01-01T00:08:20.012305 |", + "| 7 | A | 1970-01-01T00:08:20.012305 |", + "| 8 | B | 1970-01-01T00:08:20.012305 |", + "| 9 | B | 1970-01-01T00:08:20.012305 |", + "| 10 | A | 1970-01-01T00:08:20.012305 |", + "| 11 | A | 1970-01-01T00:08:20.012305 |", + "| 1 | A | 1970-01-01T00:16:40.000012 |", + "| 2 | B | 1970-01-01T00:16:40.000012 |", + "| 3 | C | 1970-01-01T00:16:40.000012 |", + "| 4 | B | 1970-01-01T00:16:40.000012 |", + "| 5 | A | 1970-01-01T00:08:20.012305 |", + "| 6 | C | 1970-01-01T00:08:20.012305 |", + "| 7 | A | 1970-01-01T00:08:20.012305 |", + "| 8 | B | 1970-01-01T00:08:20.012305 |", + "| 9 | B | 1970-01-01T00:08:20.012305 |", + "| 10 | A | 1970-01-01T00:08:20.012305 |", + "| 11 | A | 1970-01-01T00:08:20.012305 |", + "+-----+--------+----------------------------+", + ] +} + +type ModifyFn = for<'a> fn( + uri: &'a str, + file_format_options: &'a FileFormatRef, +) -> std::pin::Pin< + Box> + Send + 'a>, +>; + +// Create the table, modify it, and read it back. Verify that the final data is as expected. +async fn run_modify_test( + file_format_options: FileFormatRef, + modifier: ModifyFn, + expected: Vec, + decrypt_final_read: bool, +) { + let temp_dir = TempDir::new().unwrap(); + let uri = temp_dir.path().to_str().unwrap(); + let table_name = "test"; + create_table(uri, table_name, &file_format_options) + .await + .expect("Failed to create encrypted table"); + modifier(uri, &file_format_options) + .await + .expect("Failed to modify encrypted table"); + let data = read_table(uri, &file_format_options, decrypt_final_read) + .await + .expect("Failed to read encrypted table"); + let expected_refs: Vec<&str> = expected.iter().map(AsRef::as_ref).collect(); + assert_batches_sorted_eq!(&expected_refs, &data); +} + +async fn test_create_and_read(file_format_options: FileFormatRef, decrypt_final_read: bool) { + // Use the shared modify test template with a no-op modifier + let expected: Vec = full_table_data().iter().map(|s| s.to_string()).collect(); + run_modify_test( + file_format_options, + |_uri, _opts| Box::pin(async { Ok(()) }), + expected, + decrypt_final_read, + ) + .await; +} + +// Macro to generate the common encryption test matrix for a given runner function +macro_rules! encryption_tests { + ($runner:ident) => { + paste! 
{ + #[tokio::test] + async fn [<$runner _plain_crypto>]() { + let file_format_options = plain_crypto_format().unwrap(); + $runner(file_format_options, true).await; + } + + #[tokio::test] + #[should_panic(expected = "Failed to read encrypted table")] + async fn [<$runner _plain_crypto_no_decryptor>]() { + let file_format_options = plain_crypto_format().unwrap(); + $runner(file_format_options, false).await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn [<$runner _kms>]() { + let file_format_options = kms_crypto_format().unwrap(); + $runner(file_format_options, true).await; + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + #[should_panic(expected = "Failed to read encrypted table")] + async fn [<$runner _kms_no_decryptor>]() { + let file_format_options = kms_crypto_format().unwrap(); + $runner(file_format_options, false).await; + } + } + }; +} + +encryption_tests!(test_create_and_read); + +#[tokio::test] +#[should_panic(expected = "Failed to read encrypted table")] +async fn test_create_and_read_bad_crypto() { + let file_format_options = plain_crypto_format_bad_decryptor().unwrap(); + test_create_and_read(file_format_options, true).await; +} + +async fn test_optimize_compact(file_format_options: FileFormatRef, decrypt_final_read: bool) { + // Use the shared modify test template; perform optimization steps inside the modifier + let expected: Vec = full_table_data().iter().map(|s| s.to_string()).collect(); + run_modify_test( + file_format_options, + |uri, opts| Box::pin(optimize_table_compact(uri, opts)), + expected, + decrypt_final_read, + ) + .await; +} + +async fn test_optimize_z_order(file_format_options: FileFormatRef, decrypt_final_read: bool) { + // Use the shared modify test template; perform optimization steps inside the modifier + let expected: Vec = full_table_data().iter().map(|s| s.to_string()).collect(); + run_modify_test( + file_format_options.clone(), + |uri, opts| Box::pin(optimize_table_z_order(uri, opts)), + expected, + decrypt_final_read, + ) + .await; +} + +encryption_tests!(test_optimize_compact); + +encryption_tests!(test_optimize_z_order); + +async fn test_update(file_format_options: FileFormatRef, decrypt_final_read: bool) { + let base = full_table_data(); + let expected: Vec = base + .iter() + // If the value of the int column is 1, so we expect the value to be updated to 100 + .map(|s| s.to_string().replace("| 1 |", "| 100 |")) + .collect(); + run_modify_test( + file_format_options, + |uri, opts| Box::pin(update_table(uri, opts)), + expected, + decrypt_final_read, + ) + .await; +} + +encryption_tests!(test_update); + +async fn test_delete(file_format_options: FileFormatRef, decrypt_final_read: bool) { + let base = full_table_data(); + let expected: Vec = base + .iter() + // If the value of the int column is 2, we expect the row to be deleted + .filter(|s| !s.contains("| 2 |")) + .map(|s| s.to_string()) + .collect(); + run_modify_test( + file_format_options, + |uri, opts| Box::pin(delete_from_table(uri, opts)), + expected, + decrypt_final_read, + ) + .await; +} + +encryption_tests!(test_delete); + +async fn test_merge(file_format_options: FileFormatRef, decrypt_final_read: bool) { + let expected_str = vec![ + "+-----+--------+----------------------------+", + "| int | string | timestamp |", + "+-----+--------+----------------------------+", + "| 10 | A | 1970-01-01T00:08:20.012305 |", + "| 10 | A | 1970-01-01T00:08:20.012305 |", + "+-----+--------+----------------------------+", + ]; + let expected: Vec = 
expected_str.iter().map(|s| s.to_string()).collect(); + run_modify_test( + file_format_options, + |uri, opts| Box::pin(merge_table(uri, opts)), + expected, + decrypt_final_read, + ) + .await; +} + +encryption_tests!(test_merge); diff --git a/crates/deltalake/Cargo.toml b/crates/deltalake/Cargo.toml index 1735e0ab6a..c069bc41aa 100644 --- a/crates/deltalake/Cargo.toml +++ b/crates/deltalake/Cargo.toml @@ -46,6 +46,7 @@ datafusion-ext = ["datafusion"] gcs = ["deltalake-gcp"] glue = ["deltalake-catalog-glue"] hdfs = ["deltalake-hdfs"] +integration-test = ["deltalake-core/integration_test"] json = ["deltalake-core/json"] python = ["deltalake-core/python"] s3-native-tls = ["deltalake-aws/native-tls", "native-tls"] @@ -59,12 +60,17 @@ rustls = ["deltalake-core/rustls"] tokio = { version = "1", features = ["macros", "rt-multi-thread"] } chrono = { workspace = true, default-features = false, features = ["clock"] } tracing = { workspace = true } +tempfile = "3.21.0" url = { workspace = true } [[example]] name = "basic_operations" required-features = ["datafusion"] +[[example]] +name = "basic_operations_encryption" +required-features = ["datafusion", "integration-test"] + [[example]] name = "load_table" required-features = ["datafusion"] diff --git a/crates/deltalake/examples/basic_operations_encryption.rs b/crates/deltalake/examples/basic_operations_encryption.rs new file mode 100644 index 0000000000..a088a1e31b --- /dev/null +++ b/crates/deltalake/examples/basic_operations_encryption.rs @@ -0,0 +1,338 @@ +use deltalake::arrow::{ + array::{Int32Array, StringArray, TimestampMicrosecondArray}, + datatypes::{DataType as ArrowDataType, Field, Schema, Schema as ArrowSchema, TimeUnit}, + record_batch::RecordBatch, +}; +use deltalake::datafusion::{ + assert_batches_sorted_eq, + config::{ConfigFileType, TableOptions, TableParquetOptions}, + dataframe::DataFrame, + logical_expr::{col, lit}, + prelude::SessionContext, +}; +use deltalake::kernel::{DataType, PrimitiveType, StructField}; +use deltalake::operations::collect_sendable_stream; +use deltalake::parquet::encryption::decrypt::FileDecryptionProperties; +use deltalake::{arrow, parquet, DeltaOps}; +use deltalake_core::table::file_format_options::{FileFormatRef, SimpleFileFormatOptions}; +use deltalake_core::test_utils::kms_encryption::{ + KmsFileFormatOptions, MockKmsClient, TableEncryption, +}; +use deltalake_core::{ + datafusion::common::test_util::format_batches, operations::optimize::OptimizeType, DeltaTable, + DeltaTableError, +}; + +use deltalake::datafusion::config::EncryptionFactoryOptions; +use std::{fs, sync::Arc}; +use tempfile::TempDir; +use url::Url; + +fn get_table_columns() -> Vec { + vec![ + StructField::new( + String::from("int"), + DataType::Primitive(PrimitiveType::Integer), + false, + ), + StructField::new( + String::from("string"), + DataType::Primitive(PrimitiveType::String), + true, + ), + StructField::new( + String::from("timestamp"), + DataType::Primitive(PrimitiveType::TimestampNtz), + true, + ), + ] +} + +fn get_table_schema() -> Arc { + Arc::new(ArrowSchema::new(vec![ + Field::new("int", ArrowDataType::Int32, false), + Field::new("string", ArrowDataType::Utf8, true), + Field::new( + "timestamp", + ArrowDataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + ])) +} + +fn get_table_batches() -> RecordBatch { + let schema = get_table_schema(); + + let int_values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); + let str_values = StringArray::from(vec!["A", "B", "C", "B", "A", "C", "A", "B", "B", "A", "A"]); + 
let ts_values = TimestampMicrosecondArray::from(vec![ + 1000000012, 1000000012, 1000000012, 1000000012, 500012305, 500012305, 500012305, 500012305, + 500012305, 500012305, 500012305, + ]); + RecordBatch::try_new( + schema, + vec![ + Arc::new(int_values), + Arc::new(str_values), + Arc::new(ts_values), + ], + ) + .unwrap() +} + +async fn ops_with_crypto( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result { + let prefix_uri = format!("file://{}", uri); + let url = Url::parse(&*prefix_uri).unwrap(); + let ops = DeltaOps::try_from_uri(url).await?; + let ops = ops + .with_file_format_options(file_format_options.clone()) + .await?; + Ok(ops) +} + +async fn create_table( + uri: &str, + table_name: &str, + file_format_options: &FileFormatRef, +) -> Result { + fs::remove_dir_all(uri)?; + fs::create_dir(uri)?; + let ops = ops_with_crypto(uri, file_format_options).await?; + + // The operations module uses a builder pattern that allows specifying several options + // on how the command behaves. The builders implement `Into`, so once + // options are set you can run the command using `.await`. + let table = ops + .create() + .with_columns(get_table_columns()) + .with_table_name(table_name) + .with_comment("A table to show how delta-rs works") + .await?; + + assert_eq!(table.version(), Some(0)); + + let batch = get_table_batches(); + let table = DeltaOps(table).write(vec![batch.clone()]).await?; + + assert_eq!(table.version(), Some(1)); + + // Append records to the table + let table = DeltaOps(table).write(vec![batch.clone()]).await?; + + assert_eq!(table.version(), Some(2)); + + Ok(table) +} + +async fn read_table(uri: &str, file_format_options: &FileFormatRef) -> Result<(), DeltaTableError> { + let ops = ops_with_crypto(uri, file_format_options).await?; + let (_table, stream) = ops.load().await?; + let data: Vec = collect_sendable_stream(stream).await?; + + let formatted = format_batches(&*data)?.to_string(); + println!("Final table:"); + println!("{}", formatted); + + Ok(()) +} + +async fn update_table( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result<(), DeltaTableError> { + let ops = ops_with_crypto(uri, file_format_options).await?; + let table: DeltaTable = ops.into(); + let version = table.version(); + let ops: DeltaOps = table.into(); + + let (table, _metrics) = ops + .update() + .with_predicate(col("int").eq(lit(1))) + .with_update("int", "100") + .await + .unwrap(); + + assert_eq!(table.version(), Some(version.unwrap() + 1)); + + Ok(()) +} + +async fn delete_from_table( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result<(), DeltaTableError> { + let ops = ops_with_crypto(uri, file_format_options).await?; + let table: DeltaTable = ops.into(); + let version = table.version(); + let ops: DeltaOps = table.into(); + + let (table, _metrics) = ops + .delete() + .with_predicate(col("int").eq(lit(2))) + .await + .unwrap(); + + assert_eq!(table.version(), Some(version.unwrap() + 1)); + + if false { + println!("Table after delete:"); + let (_table, stream) = DeltaOps(table).load().await?; + let data: Vec = collect_sendable_stream(stream).await?; + + println!("{data:?}"); + } + + Ok(()) +} + +fn merge_source(schema: Arc) -> DataFrame { + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(TimestampMicrosecondArray::from(vec![ + 1000000012, 1000000012, 1000000012, + ])), + ], + ) + 
.unwrap(); + ctx.read_batch(batch).unwrap() +} + +async fn merge_table( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result<(), DeltaTableError> { + let ops = ops_with_crypto(uri, file_format_options).await?; + + let schema = get_table_schema(); + let source = merge_source(schema); + + let (table, _metrics) = ops + .merge(source, col("target.int").eq(col("source.int"))) + .with_source_alias("source") + .with_target_alias("target") + .when_not_matched_by_source_delete(|delete| delete) + .unwrap() + .await + .unwrap(); + + let expected = vec![ + "+-----+--------+----------------------------+", + "| int | string | timestamp |", + "+-----+--------+----------------------------+", + "| 10 | A | 1970-01-01T00:08:20.012305 |", + "| 10 | A | 1970-01-01T00:08:20.012305 |", + "+-----+--------+----------------------------+", + ]; + + let (_table, stream) = DeltaOps(table).load().await?; + let data: Vec = collect_sendable_stream(stream).await?; + + assert_batches_sorted_eq!(&expected, &data); + Ok(()) +} + +async fn optimize_table_z_order( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result<(), DeltaTableError> { + let ops = ops_with_crypto(uri, file_format_options).await?; + let (_table, metrics) = ops + .optimize() + .with_type(OptimizeType::ZOrder(vec![ + "timestamp".to_string(), + "int".to_string(), + ])) + .await?; + println!("\nOptimize Z-Order:\n{metrics:?}\n"); + Ok(()) +} + +async fn optimize_table_compact( + uri: &str, + file_format_options: &FileFormatRef, +) -> Result<(), DeltaTableError> { + let ops = ops_with_crypto(uri, file_format_options).await?; + let (_table, metrics) = ops.optimize().with_type(OptimizeType::Compact).await?; + println!("\nOptimize Compact:\n{metrics:?}\n"); + Ok(()) +} + +fn plain_crypto_format() -> Result { + let key: Vec<_> = b"1234567890123450".to_vec(); + let _wrong_key: Vec<_> = b"9234567890123450".to_vec(); // Can use to check encryption + + let crypt = parquet::encryption::encrypt::FileEncryptionProperties::builder(key.clone()) + .with_column_key("int", key.clone()) + .with_column_key("string", key.clone()) + .build()?; + + let decrypt = FileDecryptionProperties::builder(key.clone()) + .with_column_key("int", key.clone()) + .with_column_key("string", key.clone()) + .build()?; + + let mut tpo: TableParquetOptions = TableParquetOptions::default(); + tpo.crypto.file_encryption = Some((&crypt).into()); + tpo.crypto.file_decryption = Some((&decrypt).into()); + let mut tbl_options = TableOptions::new(); + tbl_options.parquet = tpo; + tbl_options.current_format = Some(ConfigFileType::PARQUET); + let file_format_options = Arc::new(SimpleFileFormatOptions::new(tbl_options)) as FileFormatRef; + Ok(file_format_options) +} + +fn kms_crypto_format() -> Result { + let encryption_factory = Arc::new(MockKmsClient::new()); + let configuration = EncryptionFactoryOptions::default(); + let table_encryption = TableEncryption::new(encryption_factory, configuration); + let file_format_options = + Arc::new(KmsFileFormatOptions::new(table_encryption)) as FileFormatRef; + Ok(file_format_options) +} + +async fn round_trip_test( + file_format_options: FileFormatRef, +) -> Result<(), deltalake::errors::DeltaTableError> { + let temp_dir = TempDir::new()?; + let uri = temp_dir.path().to_str().unwrap(); + + let table_name = "roundtrip"; + + create_table(uri, table_name, &file_format_options).await?; + optimize_table_z_order(uri, &file_format_options).await?; + // Re-create and append to table again so compact has work to do + create_table(uri, table_name, 
&file_format_options).await?; + optimize_table_compact(uri, &file_format_options).await?; + update_table(uri, &file_format_options).await?; + delete_from_table(uri, &file_format_options).await?; + merge_table(uri, &file_format_options).await?; + read_table(uri, &file_format_options).await?; + Ok(()) +} + +#[tokio::main(flavor = "multi_thread", worker_threads = 1)] +async fn main() -> Result<(), DeltaTableError> { + println!("===================="); + println!("Begin Plain encryption test"); + let file_format_options = plain_crypto_format()?; + round_trip_test(file_format_options).await?; + println!("End Plain encryption test"); + println!("===================="); + + println!("\n\n"); + println!("===================="); + println!("Begin KMS encryption test"); + let file_format_options = kms_crypto_format()?; + round_trip_test(file_format_options).await?; + println!("End KMS encryption test"); + println!("===================="); + + Ok(()) +}
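
// -------------------------------------------------------------------------------------------------
// A condensed usage sketch distilled from the example above: open a table with the uniform-key
// ("plain") Parquet encryption configuration and read it back. It assumes the FileFormatRef /
// SimpleFileFormatOptions / DeltaOps::with_file_format_options API introduced in this diff; the
// table URI below is illustrative only.
// -------------------------------------------------------------------------------------------------
use std::sync::Arc;

use deltalake::datafusion::config::{ConfigFileType, TableOptions, TableParquetOptions};
use deltalake::operations::collect_sendable_stream;
use deltalake::parquet::encryption::decrypt::FileDecryptionProperties;
use deltalake::parquet::encryption::encrypt::FileEncryptionProperties;
use deltalake::DeltaOps;
use deltalake_core::table::file_format_options::{FileFormatRef, SimpleFileFormatOptions};
use deltalake_core::DeltaTableError;
use url::Url;

async fn read_encrypted_table() -> Result<(), DeltaTableError> {
    // Single 16-byte AES key used for both the footer and all columns.
    let key = b"1234567890123450".to_vec();
    let encrypt = FileEncryptionProperties::builder(key.clone()).build()?;
    let decrypt = FileDecryptionProperties::builder(key.clone()).build()?;

    // Route the keys through DataFusion's TableParquetOptions, as plain_crypto_format() does above.
    let mut parquet_options = TableParquetOptions::default();
    parquet_options.crypto.file_encryption = Some((&encrypt).into());
    parquet_options.crypto.file_decryption = Some((&decrypt).into());
    let mut table_options = TableOptions::new();
    table_options.parquet = parquet_options;
    table_options.current_format = Some(ConfigFileType::PARQUET);
    let format: FileFormatRef = Arc::new(SimpleFileFormatOptions::new(table_options));

    // Attach the format options to the table handle: subsequent writes encrypt their Parquet
    // files and scans pick up the matching decryption settings.
    let url = Url::parse("file:///tmp/encrypted_delta_table").unwrap(); // illustrative path
    let ops = DeltaOps::try_from_uri(url)
        .await?
        .with_file_format_options(format)
        .await?;
    let (_table, stream) = ops.load().await?;
    let _batches = collect_sendable_stream(stream).await?;
    Ok(())
}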