diff --git a/mostlyai/engine/_language/encoding.py b/mostlyai/engine/_language/encoding.py index 962b5aba..4260fb14 100644 --- a/mostlyai/engine/_language/encoding.py +++ b/mostlyai/engine/_language/encoding.py @@ -157,6 +157,7 @@ def _encode_partition( def encode( workspace_dir: str | Path | None = None, update_progress: ProgressCallback | None = None, + parallel_backend: str = "loky", ) -> None: _LOG.info("ENCODE_LANGUAGE started") t0 = time.time() diff --git a/mostlyai/engine/_tabular/encoding.py b/mostlyai/engine/_tabular/encoding.py index 4b1480f7..03bc7cc6 100644 --- a/mostlyai/engine/_tabular/encoding.py +++ b/mostlyai/engine/_tabular/encoding.py @@ -53,6 +53,7 @@ def encode( workspace_dir: str | Path | None = None, update_progress: ProgressCallback | None = None, + parallel_backend: str = "loky", ) -> None: _LOG.info("ENCODE_TABULAR started") t0 = time.time() @@ -84,6 +85,7 @@ def encode( ctx_partition_file=ctx_pqt_partitions[i] if has_context else None, ctx_stats=ctx_stats if has_context else None, n_jobs=min(16, max(1, cpu_count() - 1)), + parallel_backend=parallel_backend, ) progress.update(completed=i, total=len(tgt_pqt_partitions) + 1) _LOG.info(f"ENCODE_TABULAR finished in {time.time() - t0:.2f}s") @@ -97,6 +99,7 @@ def _encode_partition( ctx_partition_file: Path | None = None, ctx_stats: dict | None = None, n_jobs: int = 1, + parallel_backend: str = "loky", ) -> None: seq_len_stats = get_sequence_length_stats(tgt_stats) is_sequential = tgt_stats["is_sequential"] @@ -112,6 +115,7 @@ def _encode_partition( ctx_primary_key=None, tgt_context_key=tgt_context_key, n_jobs=n_jobs, + parallel_backend=parallel_backend, ) has_context = ctx_partition_file is not None and tgt_context_key and ctx_primary_key @@ -128,6 +132,7 @@ def _encode_partition( ctx_primary_key=ctx_primary_key, tgt_context_key=None, n_jobs=n_jobs, + parallel_backend=parallel_backend, ) # pad each list with one extra item df_ctx = pad_ctx_sequences(df_ctx) @@ -185,6 +190,7 @@ def encode_df( ctx_primary_key: str | None = None, tgt_context_key: str | None = None, n_jobs: int = 1, + parallel_backend: str = "loky", ) -> tuple[pd.DataFrame, str | None, str | None]: """ Encodes a given table represented by a DataFrame object. The result will be delivered @@ -194,6 +200,7 @@ def encode_df( :param stats: stats for each of the columns :param ctx_primary_key: context primary key :param tgt_context_key: target context key + :param parallel_backend: joblib parallel backend to use :return: encoded data and keys following columns' naming conventions """ @@ -238,7 +245,7 @@ def encode_df( ) ) if delayed_encodes: - with parallel_config("loky", n_jobs=n_jobs): + with parallel_config(parallel_backend, n_jobs=n_jobs): df_columns.extend(Parallel()(delayed_encodes)) df = pd.concat(df_columns, axis=1) if df_columns else pd.DataFrame() diff --git a/mostlyai/engine/analysis.py b/mostlyai/engine/analysis.py index 562fcc0a..00a36a4a 100644 --- a/mostlyai/engine/analysis.py +++ b/mostlyai/engine/analysis.py @@ -108,6 +108,7 @@ def analyze( differential_privacy: DifferentialPrivacyConfig | None = None, workspace_dir: str | Path = "engine-ws", update_progress: ProgressCallback | None = None, + parallel_backend: str = "loky", ) -> None: """ Generates (privacy-safe) column-level statistics of the original data, that has been `split` into the workspace. @@ -122,6 +123,7 @@ def analyze( value_protection: Whether to enable value protection for rare values. workspace_dir: Path to workspace directory containing partitioned data. update_progress: Optional callback to update progress during analysis. + parallel_backend: Joblib parallel backend to use. Options include 'loky', 'threading', 'multiprocessing', etc. """ _LOG.info("ANALYZE started") @@ -167,6 +169,7 @@ def analyze( ctx_primary_key=ctx_primary_key if has_context else None, ctx_root_key=ctx_root_key, n_jobs=min(16, max(1, cpu_count() - 1)), + parallel_backend=parallel_backend, ) progress.update(completed=i, total=len(tgt_pqt_partitions) + 1) @@ -221,6 +224,7 @@ def _analyze_partition( ctx_primary_key: str | None = None, ctx_root_key: str | None = None, n_jobs: int = 1, + parallel_backend: str = "loky", ) -> None: """ Calculates partial statistics about a single partition. @@ -252,7 +256,7 @@ def _analyze_partition( ctx_root_keys = ctx_primary_keys.rename("__rkey") # analyze all target columns - with parallel_config("loky", n_jobs=n_jobs): + with parallel_config(parallel_backend, n_jobs=n_jobs): results = Parallel()( delayed(_analyze_col)( values=tgt_df[column], @@ -293,7 +297,7 @@ def _analyze_partition( # analyze all context columns assert isinstance(ctx_encoding_types, dict) - with parallel_config("loky", n_jobs=n_jobs): + with parallel_config(parallel_backend, n_jobs=n_jobs): results = Parallel()( delayed(_analyze_col)( values=ctx_df[column], diff --git a/mostlyai/engine/encoding.py b/mostlyai/engine/encoding.py index 607e538e..cfb28ec8 100644 --- a/mostlyai/engine/encoding.py +++ b/mostlyai/engine/encoding.py @@ -23,6 +23,7 @@ def encode( *, workspace_dir: str | Path = "engine-ws", update_progress: ProgressCallback | None = None, + parallel_backend: str = "loky", ) -> None: """ Encodes data in the workspace that has already been split and analyzed. @@ -34,13 +35,18 @@ def encode( Args: workspace_dir: Directory path for workspace. update_progress: Callback for progress updates. + parallel_backend: Joblib parallel backend to use. Options include 'loky', 'threading', 'multiprocessing', etc. """ model_type = resolve_model_type(workspace_dir) if model_type == ModelType.tabular: from mostlyai.engine._tabular.encoding import encode as encode_tabular - return encode_tabular(workspace_dir=workspace_dir, update_progress=update_progress) + return encode_tabular( + workspace_dir=workspace_dir, update_progress=update_progress, parallel_backend=parallel_backend + ) else: from mostlyai.engine._language.encoding import encode as encode_language - return encode_language(workspace_dir=workspace_dir, update_progress=update_progress) + return encode_language( + workspace_dir=workspace_dir, update_progress=update_progress, parallel_backend=parallel_backend + )