Skip to content
Merged
198 changes: 197 additions & 1 deletion docs/transforms/index.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,199 @@
# Transforms

🚧 Coming Soon 🚧
Transforms allow you to post-process model outputs after ONNX inference and before returning results. They run inside the model binary, operating directly on tensors for high performance.

Transforms run on Lua 5.4 in a sandboxed environment. As of right now, the transforms feature does not support LuaJIT.

## Why Use Transforms?

Common use cases:
- **Normalize embeddings** for cosine similarity
- **Apply softmax** to convert logits to probabilities
- **Pool embeddings** to create sentence representations
- **Scale outputs** for specific downstream tasks

## Getting Started

A transform is a Lua script that defines a `Postprocess` function:

```lua
---@param arr Tensor
---@return Tensor
function Postprocess(arr, ...)
-- your postprocessing logic
return tensor
end
```

With a handful of exceptions, the `Postprocess` function must return a `Tensor` with the exact same shape as the input `Tensor` provided for that model type. The exceptions are as follows:

- Embedding and sentence embedding models can modify the length of `hidden` (useful for matryoshka embeddings)
- Sentence embeddings are given a `Tensor` of shape `[batch_size, seq_len, hidden]` and attention mask of `[batch_size, seq_len]`, and must return a `Tensor` of shape `[batch_size, hidden]`. In other words, it expects a pooling operation along dimension `seq_len`.

!!! note "Note on indexing"
Lua is 1-indexed, meaning that it starts counting at 1 instead of 0. The `Tensor` API reflects this, meaning that you must count your axes and indices starting at 1 instead of 0.

We provide a built-in API for standard tensor operations. To learn more, check out our [Tensor API reference page](reference). You can find the stub file [here](https://github.com/mozilla-ai/encoderfile/blob/main/encoderfile-core/stubs/lua/tensor.lua).

If you don't see an op that you need, please don't hesitate to [create an issue](https://github.com/mozilla-ai/encoderfile/issues) on Github.

## Input Signatures

The input signature of `Postprocess` depends on the type of model being used.

### Embedding

```lua
--- input: 3d tensor of shape [batch_size, seq_len, hidden]
---@param arr Tensor
---output: 3d tensor of shape [batch_size, seq_len, hidden]
---@return Tensor
function Postprocess(arr)
-- your postprocessing logic
return tensor
end
```

### Sequence Classification

```lua
--- input: 2d tensor of shape [batch_size, n_labels]
---@param arr Tensor
---output: 2d tensor of shape [batch_size, n_labels]
---@return Tensor
function Postprocess(arr)
-- your postprocessing logic
return tensor
end
```

### Token Classification

```lua
--- input: 3d tensor of shape [batch_size, seq_len, n_labels]
---@param arr Tensor
---output: 3d tensor of shape [batch_size, seq_len, n_labels]
---@return Tensor
function Postprocess(arr)
-- your postprocessing logic
return tensor
end
```

### Sentence Embedding


!!! note "Mean Pooling"
To mean-pool embeddings, you can use the `Tensor:mean_pool` function like this: `tensor:mean_pool(mask)`.

```lua
--- input: 3d tensor of shape [batch_size, seq_len, hidden]
---@param arr Tensor
-- input: 2d tensor of shape [batch_size, seq_len]
-- This is automatically provided to the function and is equivalent to 🤗 transformer's attention_mask.
---@param mask Tensor
---output: 2d tensor of shape [batch_size, hidden]
---@return Tensor
function Postprocess(arr, mask)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specify a sample mask value here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@angpt The mask is going to look like a 2d matrix with a bunch of 1s and 0s. It's automatically provided by the transform engine as a parameter to the Postprocess function.

-- your postprocessing logic
return tensor
end
```

## Typical Transform Patterns

Most transforms fall into one of 3 patterns:

### 1. Elementwise Transforms

Safe: they preserve shape automatically.

Examples:

- scaling (`tensor * 1.5`)
- activation functions (`tensor:exp()`)

### 2. Normalization Across Axis

These also preserve shape.

Examples:

- Lp normalization: (`tensor:lp_normalize(p, axis)`)
- subtracting mean per batch or per token
- applying softmax across a specific dimension (`tensor:softmax(2)`)

### 3. Mask-aware adjustments

When working with sentence embedding models:

```lua
function Postprocess(arr, mask)
-- embeddings: [batch, seq, hidden]
-- mask: [batch, seq]

-- operations here must output [batch, hidden]
return ...
end
```

## Best Practices

!!! warning "Performance Implications"
Transforms run synchronously during inference, so expensive Lua-side loops will increase latency. If you don't see an op that you need, please don't hesitate to [create an issue](https://github.com/mozilla-ai/encoderfile/issues) on Github.

A typical transform follows this structure:

```lua
function Postprocess(arr, ...)
-- Step 1: apply elementwise or axis-based operations
local modified = arr:exp() -- example

-- Step 2: ensure the output shape matches the input shape
-- (all built-in ops described in the Tensor API preserve shape)

return modified
end
```

## Debugging Transforms

You can inspect shape and values using:

```lua
print("ndim:", t:ndim())
print("len:", #t)
print(tostring(t))
```

Errors typically fall into:

- axis out of range
→ axis must be 1-indexed and ≤ tensor rank

- broadcasting errors
→ the two shapes are incompatible

- returned value is not a tensor
→ must return a Tensor userdata object

- shape mismatch
→ you modified rank or dimensions

## Configuration

Transforms are embedded at build time. You can specify them in your config.yml either as a file path or inline.

```yml
transform:
path: path/to/your/transform/here
```

Or, they can be passed inline:
```yml
transform: |
function Postprocess(arr)
...
return arr
```