[FEA] Store categorical encoding in Treelite model representation #639

@hcho3

Background

Many data science applications involve categorical variables (features). A categorical variable takes one of a limited set of possible values. Often these values arrive as human-readable strings.

Example: A U.S. Residency Status variable could take the string values Citizen, Permanent Resident, Alien, and Other.

In a typical data science pipeline, string-valued categories are first encoded as numerical indices (0, 1, 2, ...). Encoding is done because most tree libraries represent test conditions using numerical indices for the sake of space efficiency.

Example: The test condition U.S. Residency Status == Citizen or Permanent Resident would be represented as x \in {0, 1}, where Citizen has been mapped to 0 and Permanent Resident to 1.
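To make the space-efficiency argument concrete, here is a minimal sketch of how such a test condition could be evaluated once categories are integers. The names `encoding`, `left_categories`, `goes_left`, and `goes_left_bitmask` are made up for illustration, and the bitmask is just one possible compact representation, not a description of how any particular tree library stores its splits.

```python
# Hypothetical encoding for the U.S. Residency Status example.
encoding = {"Citizen": 0, "Permanent Resident": 1, "Alien": 2, "Other": 3}

# The test condition "Citizen or Permanent Resident" as a set of indices.
left_categories = {0, 1}


def goes_left(value: str) -> bool:
    """Evaluate the test condition x in {0, 1} on a string category."""
    return encoding[value] in left_categories


# The same set packed into a single integer, one bit per category index.
mask = sum(1 << index for index in left_categories)


def goes_left_bitmask(value: str) -> bool:
    """Evaluate the same condition by testing one bit of the mask."""
    return (mask >> encoding[value]) & 1 == 1


for status in encoding:
    print(status, goes_left(status), goes_left_bitmask(status))
```

Storing one small integer per split is much cheaper than storing the strings themselves, which is why the string-to-index mapping has to be kept somewhere else.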

The mapping between string categories and numerical indices needs to be stored along with the rest of the tree model, so that we can apply the same encoding when we run inference on new data:

  1. Scan the training data set and create an encoding map for each categorical feature.
    Example: The U.S. Residency Status variable would have the following mapping:

    • Citizen -> 0
    • Permanent Resident -> 1
    • Alien -> 2
    • Other -> 3
  2. Store the encoding map as part of the tree model.

  3. At inference time, apply the same encoding map to the test data set and run inference.

It is crucial that the same encoding be used in training and inference stages to obtain correct outputs.
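The three steps above can be sketched in plain Python. This is an illustration only: `build_encoding` and `apply_encoding` are hypothetical helpers, and assigning indices in order of first appearance is an assumption (real libraries may sort categories or use another rule). The last lines show what goes wrong when the encoding is re-derived from the inference data instead of being reused.

```python
def build_encoding(values):
    """Step 1: map each distinct category to an index, in order of first appearance."""
    encoding = {}
    for value in values:
        if value not in encoding:
            encoding[value] = len(encoding)
    return encoding


def apply_encoding(values, encoding, unknown=-1):
    """Step 3: encode values with a stored map; unseen categories get `unknown`."""
    return [encoding.get(value, unknown) for value in values]


train = ["Citizen", "Permanent Resident", "Alien", "Other", "Citizen"]
encoding = build_encoding(train)  # Step 2 would store this next to the model

test = ["Alien", "Citizen", "Other"]
correct = apply_encoding(test, encoding)

# Mistake: building a fresh encoding from the test data assigns different indices.
rederived = apply_encoding(test, build_encoding(test))

print(correct)    # [2, 0, 3]
print(rederived)  # [0, 1, 2]
```

With the re-derived encoding, Alien is silently treated as if it were Citizen, so the model follows the wrong branches without raising any error.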

State-of-the-art practices

The Python data science community has developed best practices to discover and store encodings for categorical variables.

Example: Pandas has built-in support for categorical columns and will create encodings on the fly.

import pandas as pd
from pandas.api.types import CategoricalDtype

s1 = pd.Series(["a", "b", "c", "a"], dtype="category")
print(s1.cat.categories)
# Shows the categories ['a', 'b', 'c'], which represent the mapping a -> 0, b -> 1, c -> 2
print(s1.cat.codes)
# Shows the codes [0, 1, 2, 0], the result of applying the categorical mapping to column s1

# Create a new column s2 and apply the same encoding as column s1
s2 = pd.Series(["b", "c", "a"]).astype(CategoricalDtype(categories=s1.cat.categories))
print(s2.cat.codes)
# Shows the codes [1, 2, 0], the result of applying the categorical mapping to column s2

ML libraries such as scikit-learn, LightGBM, and XGBoost (version 3.1+) use Pandas to apply a consistent categorical encoding in the training and inference stages. The libraries store the categorical encoding as part of the model file so that it can be reused later.

Status quo

Treelite does not store any information regarding categorical encodings. This has a few important consequences:

  • Users will need to manually save the categorical encoding to a separate file, and remember to apply the encoding later at inference time.
  • GTIL, Treelite's reference implementation for tree inference, only accepts NumPy arrays as input and does not accept Pandas DataFrames.
  • Treelite throws an error when the user attempts to load a scikit-learn HistGradientBoosting model that was fit on string categorical variables.
  • Downstream applications, such as Forest Inference Library (FIL), have no way to natively support categorical encodings.

Overall, the status quo represents a subpar user experience.
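For reference, the manual workaround from the first bullet looks roughly like the sketch below. The sidecar file name `model.categories.json` and the feature name `residency_status` are invented for this example; nothing here is a Treelite API.

```python
import json
import tempfile
from pathlib import Path

# Category lists the user has to track by hand, keyed by feature name.
categories = {
    "residency_status": ["Citizen", "Permanent Resident", "Alien", "Other"],
}

with tempfile.TemporaryDirectory() as tmp:
    # Training time: write the encoding to a separate file next to the model.
    sidecar = Path(tmp) / "model.categories.json"
    sidecar.write_text(json.dumps(categories))

    # Inference time: the user must remember to load it again.
    loaded = json.loads(sidecar.read_text())

# Rebuild the category -> index lookup and encode incoming data by hand.
index = {cat: i for i, cat in enumerate(loaded["residency_status"])}
codes = [index[value] for value in ["Alien", "Citizen"]]
print(codes)  # [2, 0]
```

Every step here lives outside the model file, so the model and its encoding can drift apart or get separated.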

Proposal

Update the Treelite model representation to explicitly store encodings for categorical features. The encoding can be stored as an array. For the U.S. Residency Status example above, we'd store ["Citizen", "Permanent Resident", "Alien", "Other"].

I propose adding a categorical_mapping field to the Treelite model representation:

{
...
  "categorical_mapping": [
    {"feature_id": [id], "mapping": [ list of categories ] },
    {"feature_id": [id], "mapping": [ list of categories ] },
    ...
  ]
...
}
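As a rough sketch of how a consumer might use this field, the snippet below encodes one input row into the numeric form an inference routine expects. It assumes that `[id]` in the layout above is a placeholder for an integer feature index and that unknown categories should become NaN; both are assumptions of this example, and `encode_row` is a hypothetical helper, not part of Treelite.

```python
import math

# Hypothetical model fragment following the proposed layout.
model_fragment = {
    "categorical_mapping": [
        {
            "feature_id": 1,
            "mapping": ["Citizen", "Permanent Resident", "Alien", "Other"],
        },
    ]
}


def encode_row(row, categorical_mapping):
    """Replace string categories with their index in the stored mapping."""
    lookup = {
        entry["feature_id"]: {cat: i for i, cat in enumerate(entry["mapping"])}
        for entry in categorical_mapping
    }
    encoded = []
    for feature_id, value in enumerate(row):
        if feature_id in lookup:
            # Assumption: categories missing from the mapping become NaN.
            encoded.append(float(lookup[feature_id].get(value, math.nan)))
        else:
            encoded.append(float(value))
    return encoded


row = [35.0, "Permanent Resident", 2.5]
encoded = encode_row(row, model_fragment["categorical_mapping"])
print(encoded)  # [35.0, 1.0, 2.5]
```

Because the position of each category in the array is its index, the array alone is enough to reconstruct the full mapping.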
