docs: transforms API #134
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,199 @@ | ||
| # Transforms | ||
|
|
||
| 🚧 Coming Soon 🚧 | ||
| Transforms allow you to post-process model outputs after ONNX inference and before returning results. They run inside the model binary, operating directly on tensors for high performance. | ||
|
|
||
| Transforms run on Lua 5.4 in a sandboxed environment. LuaJIT is not currently supported. | ||
|
|
||
| ## Why Use Transforms? | ||
|
|
||
| Common use cases: | ||
| - **Normalize embeddings** for cosine similarity | ||
| - **Apply softmax** to convert logits to probabilities | ||
| - **Pool embeddings** to create sentence representations | ||
| - **Scale outputs** for specific downstream tasks | ||
|
|
||
| ## Getting Started | ||
|
|
||
| A transform is a Lua script that defines a `Postprocess` function: | ||
|
|
||
| ```lua | ||
| ---@param arr Tensor | ||
| ---@return Tensor | ||
| function Postprocess(arr, ...) | ||
| -- your postprocessing logic | ||
| return arr | ||
| end | ||
| ``` | ||
|
|
||
| The `Postprocess` function must return a `Tensor` with exactly the same shape as the input `Tensor` provided for that model type, with two exceptions: | ||
|
|
||
| - Embedding and sentence embedding models can modify the length of `hidden` (useful for matryoshka embeddings) | ||
| - Sentence embedding models are given a `Tensor` of shape `[batch_size, seq_len, hidden]` and an attention mask of shape `[batch_size, seq_len]`, and must return a `Tensor` of shape `[batch_size, hidden]`. In other words, the transform is expected to pool along the `seq_len` dimension. | ||
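
The first exception (shrinking `hidden`) can be pictured with a minimal sketch in plain Lua. It uses nested tables as a stand-in for a `Tensor`, so it runs in any Lua 5.4 interpreter. `truncate_hidden` is a hypothetical helper written for illustration and is not part of the Tensor API.

```lua
-- Illustration only: nested Lua tables stand in for a Tensor.
-- `truncate_hidden` is a hypothetical helper, not a Tensor API function.
local function truncate_hidden(embeddings, k)
  local out = {}
  for b, seq in ipairs(embeddings) do
    out[b] = {}
    for s, vec in ipairs(seq) do
      -- table.move copies elements 1..k of vec into a fresh table
      out[b][s] = table.move(vec, 1, k, 1, {})
    end
  end
  return out
end

-- Shape [batch_size=1, seq_len=2, hidden=4] becomes [1, 2, 2].
local input = { { {0.1, 0.2, 0.3, 0.4}, {0.5, 0.6, 0.7, 0.8} } }
local output = truncate_hidden(input, 2)
print(#output, #output[1], #output[1][1])
```

Only the last dimension changes. `batch_size` and `seq_len` are preserved, which matches the shape rule stated above.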
|
|
||
| !!! note "Note on indexing" | ||
| Lua is 1-indexed: counting starts at 1 rather than 0. The `Tensor` API follows this convention, so you must count axes and indices starting at 1. | ||
|
||
|
|
||
| We provide a built-in API for standard tensor operations. To learn more, check out our [Tensor API reference page](reference). You can find the stub file [here](https://github.com/mozilla-ai/encoderfile/blob/main/encoderfile-core/stubs/lua/tensor.lua). | ||
|
|
||
| If you don't see an op that you need, please don't hesitate to [create an issue](https://github.com/mozilla-ai/encoderfile/issues) on GitHub. | ||
|
|
||
| ## Input Signatures | ||
|
|
||
| The input signature of `Postprocess` depends on the type of model being used. | ||
|
|
||
| ### Embedding | ||
|
|
||
| ```lua | ||
| --- input: 3d tensor of shape [batch_size, seq_len, hidden] | ||
| ---@param arr Tensor | ||
| ---output: 3d tensor of shape [batch_size, seq_len, hidden] | ||
| ---@return Tensor | ||
| function Postprocess(arr) | ||
| -- your postprocessing logic | ||
| return arr | ||
| end | ||
| ``` | ||
|
|
||
| ### Sequence Classification | ||
|
|
||
| ```lua | ||
| --- input: 2d tensor of shape [batch_size, n_labels] | ||
| ---@param arr Tensor | ||
| ---output: 2d tensor of shape [batch_size, n_labels] | ||
| ---@return Tensor | ||
| function Postprocess(arr) | ||
| -- your postprocessing logic | ||
| return arr | ||
| end | ||
| ``` | ||
|
|
||
| ### Token Classification | ||
|
|
||
| ```lua | ||
| --- input: 3d tensor of shape [batch_size, seq_len, n_labels] | ||
| ---@param arr Tensor | ||
| ---output: 3d tensor of shape [batch_size, seq_len, n_labels] | ||
| ---@return Tensor | ||
| function Postprocess(arr) | ||
| -- your postprocessing logic | ||
| return arr | ||
| end | ||
| ``` | ||
|
|
||
| ### Sentence Embedding | ||
|
|
||
|
|
||
| !!! note "Mean Pooling" | ||
| To mean-pool embeddings, you can use the `Tensor:mean_pool` function like this: `tensor:mean_pool(mask)`. | ||
|
|
||
| ```lua | ||
| --- input: 3d tensor of shape [batch_size, seq_len, hidden] | ||
| ---@param arr Tensor | ||
| --- input: 2d tensor of shape [batch_size, seq_len] | ||
| --- The mask is provided to the function automatically and is equivalent to the attention_mask in 🤗 Transformers. | ||
| ---@param mask Tensor | ||
| ---output: 2d tensor of shape [batch_size, hidden] | ||
| ---@return Tensor | ||
| function Postprocess(arr, mask) | ||
|
Contributor: Specify a sample mask value here.
Member (Author): @angpt The mask is going to look like a 2d matrix with a bunch of 1s and 0s. It's automatically provided by the transform engine as a parameter to the |
||
| -- your postprocessing logic | ||
| return arr:mean_pool(mask) | ||
| end | ||
| ``` | ||
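
To make the role of the mask concrete, here is a hedged reference sketch of masked mean pooling in plain Lua. Nested tables stand in for tensors, and the local `mean_pool` function illustrates the arithmetic only; it is not the implementation behind `Tensor:mean_pool`. The guard against an all-zero mask row is an assumption of this sketch, and the built-in op may handle that case differently.

```lua
-- Illustration only: nested tables stand in for Tensors.
-- embeddings: [batch_size][seq_len][hidden]
-- mask:       [batch_size][seq_len], 1 = real token, 0 = padding
local function mean_pool(embeddings, mask)
  local pooled = {}
  for b, seq in ipairs(embeddings) do
    local hidden = #seq[1]
    local sum, count = {}, 0
    for h = 1, hidden do sum[h] = 0 end
    for s, vec in ipairs(seq) do
      if mask[b][s] == 1 then
        count = count + 1
        for h = 1, hidden do sum[h] = sum[h] + vec[h] end
      end
    end
    -- Assumption of this sketch: avoid dividing by zero on an all-padding row.
    local denom = math.max(count, 1)
    for h = 1, hidden do sum[h] = sum[h] / denom end
    pooled[b] = sum
  end
  return pooled
end

local embeddings = {
  { {1, 2}, {3, 4}, {100, 100} }, -- third token is padding
}
local mask = { {1, 1, 0} }
local pooled = mean_pool(embeddings, mask)
print(pooled[1][1], pooled[1][2])
```

The padded token is ignored, so the result has shape `[batch_size, hidden]` and averages only the two real tokens.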
|
|
||
| ## Typical Transform Patterns | ||
|
|
||
| Most transforms fall into one of three patterns: | ||
|
|
||
| ### 1. Elementwise Transforms | ||
|
|
||
| Safe: they preserve shape automatically. | ||
|
|
||
| Examples: | ||
|
|
||
| - scaling (`tensor * 1.5`) | ||
| - activation functions (`tensor:exp()`) | ||
|
|
||
| ### 2. Normalization Across Axis | ||
|
|
||
| These also preserve shape. | ||
|
|
||
| Examples: | ||
|
|
||
| - Lp normalization (`tensor:lp_normalize(p, axis)`) | ||
| - subtracting mean per batch or per token | ||
| - applying softmax across a specific dimension (`tensor:softmax(2)`) | ||
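
For readers who want to see what these axis-wise operations compute, the following sketch reimplements softmax and Lp normalization over the last axis of a 2D input in plain Lua. Nested tables stand in for tensors. `softmax_rows` and `lp_normalize_rows` are hypothetical names for this illustration and are not the Tensor API. The sketch does not handle all-zero rows in the normalization.

```lua
-- Illustration only: rows of a [batch_size][n_labels] nested table.
local function softmax_rows(logits)
  local out = {}
  for b, row in ipairs(logits) do
    -- subtract the row maximum for numerical stability
    local max = -math.huge
    for _, v in ipairs(row) do max = math.max(max, v) end
    local exps, total = {}, 0
    for i, v in ipairs(row) do
      exps[i] = math.exp(v - max)
      total = total + exps[i]
    end
    for i = 1, #exps do exps[i] = exps[i] / total end
    out[b] = exps
  end
  return out
end

local function lp_normalize_rows(rows, p)
  local out = {}
  for b, row in ipairs(rows) do
    local norm = 0
    for _, v in ipairs(row) do norm = norm + math.abs(v) ^ p end
    norm = norm ^ (1 / p)
    out[b] = {}
    for i, v in ipairs(row) do out[b][i] = v / norm end
  end
  return out
end

local probs = softmax_rows({ {0, 0}, {1, 2} })
local unit = lp_normalize_rows({ {3, 4} }, 2)
print(probs[1][1], unit[1][1], unit[1][2])
```

Both functions return a table with the same shape as their input, which is why this pattern is shape-preserving: each row of `probs` sums to 1, and each row of `unit` has an L2 norm of 1.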
|
|
||
| ### 3. Mask-Aware Adjustments | ||
|
|
||
| When working with sentence embedding models: | ||
|
|
||
| ```lua | ||
| function Postprocess(arr, mask) | ||
| -- arr: [batch, seq, hidden] | ||
| -- mask: [batch, seq] | ||
|
|
||
| -- operations here must output [batch, hidden], for example mean pooling | ||
| return arr:mean_pool(mask) | ||
| end | ||
| ``` | ||
|
|
||
| ## Best Practices | ||
|
|
||
| !!! warning "Performance Implications" | ||
| Transforms run synchronously during inference, so expensive Lua-side loops will increase latency. Prefer built-in tensor ops where possible. If you don't see an op that you need, please don't hesitate to [create an issue](https://github.com/mozilla-ai/encoderfile/issues) on GitHub. | ||
|
|
||
| A typical transform follows this structure: | ||
|
|
||
| ```lua | ||
| function Postprocess(arr, ...) | ||
| -- Step 1: apply elementwise or axis-based operations | ||
| local modified = arr:exp() -- example | ||
|
|
||
| -- Step 2: ensure the output shape matches what the model type expects | ||
| -- (elementwise and axis-normalization ops preserve shape; pooling ops such as mean_pool do not) | ||
|
|
||
| return modified | ||
| end | ||
| ``` | ||
|
|
||
| ## Debugging Transforms | ||
|
|
||
| You can inspect shape and values using: | ||
|
|
||
| ```lua | ||
| print("ndim:", t:ndim()) | ||
| print("len:", #t) | ||
| print(tostring(t)) | ||
| ``` | ||
|
|
||
| Errors typically fall into one of these categories: | ||
|
|
||
| - axis out of range | ||
| → axis must be 1-indexed and ≤ tensor rank | ||
|
|
||
| - broadcasting errors | ||
| → the two shapes are incompatible | ||
|
|
||
| - returned value is not a tensor | ||
| → must return a Tensor userdata object | ||
|
|
||
| - shape mismatch | ||
| → you modified rank or dimensions | ||
|
|
||
| ## Configuration | ||
|
|
||
| Transforms are embedded at build time. You can specify them in your `config.yml` either as a file path or inline. | ||
|
|
||
| ```yml | ||
| transform: | ||
|   path: path/to/your/transform/here | ||
| ``` | ||
|
|
||
| Or, they can be passed inline: | ||
| ```yml | ||
| transform: | | ||
|   function Postprocess(arr) | ||
|     -- your postprocessing logic | ||
|     return arr | ||
|   end | ||
| ``` | ||
|
|
||