
[WIP][tx] Add initial implementation of RayJaxBackend #1418

Draft
andrewsykim wants to merge 14 commits into NovaSky-AI:main from andrewsykim:ray-tx

Conversation


@andrewsykim andrewsykim commented Mar 31, 2026

Fixes #1393

This PR introduces the initial implementation of RayJaxBackend and RayJaxBackendImpl to enable running skyrl-tx on a Ray cluster with a single ray job submit command. When enabled, the Tinker API and Engine run on the driver/head node, while the JAX backend operations are distributed across Ray actors running on worker nodes. This removes the need for manual multi-node orchestration for JAX distributed training.
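A rough sketch of that driver/worker split, using concurrent.futures as a stand-in for Ray actors (the worker function, the dispatch loop, and all names here are illustrative assumptions, not the PR's actual API):

```python
from concurrent.futures import ThreadPoolExecutor  # stand-in for Ray actors

def worker_forward_backward(shard):
    # Placeholder for the JAX forward/backward pass one worker would run
    # on its TPU slice; here it just sums its shard of the batch.
    return sum(shard)

def driver_step(batch, num_workers=4):
    # The driver (Tinker API + Engine) splits the batch and fans it out,
    # mirroring how RayJaxBackend would dispatch work to its JAX workers.
    shards = [batch[i::num_workers] for i in range(num_workers)]
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        partials = list(pool.map(worker_forward_backward, shards))
    # Aggregate the per-worker results (e.g. reduce losses/gradients).
    return sum(partials)

print(driver_step(list(range(16))))  # sums 0..15
```

In the real backend each worker would be a Ray actor pinned to a placement-group bundle with its own TPU resources; the thread pool above only mimics the fan-out/fan-in shape.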

Tested the changes by running the following command on my 4x4 v6e TPU cluster:

ray job submit --runtime-env-json '{"py_executable": "uv run"}' -- sh -c 'cd /home/ray/SkyRL && uv run --extra tpu --extra jax --extra tinker -m skyrl.tinker.api --base-model Qwen/Qwen3-0.6B --backend ray-jax  --backend-config '\''{"ray_pg_bundles": [{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4},{"CPU": 4, "TPU": 4}], "tensor_parallel_size": 4, "sample_max_num_sequences": 256, "train_micro_batch_size": 8, "fully_sharded_data_parallel_size": 4, "num_processes": 4}'\'''
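For reference, the backend-config JSON above decodes to the structure below; the sanity checks are consistency rules inferred from the flag values on a 4x4 v6e slice (16 chips), not validation code from the PR:

```python
import json

backend_config = json.loads("""
{"ray_pg_bundles": [{"CPU": 4, "TPU": 4}, {"CPU": 4, "TPU": 4},
                    {"CPU": 4, "TPU": 4}, {"CPU": 4, "TPU": 4}],
 "tensor_parallel_size": 4, "sample_max_num_sequences": 256,
 "train_micro_batch_size": 8, "fully_sharded_data_parallel_size": 4,
 "num_processes": 4}
""")

# One placement-group bundle per JAX process.
assert len(backend_config["ray_pg_bundles"]) == backend_config["num_processes"]

# The TP x FSDP mesh should cover all TPU chips across the bundles
# (4 bundles x 4 chips each = 16 chips on the 4x4 v6e slice).
total_chips = sum(b["TPU"] for b in backend_config["ray_pg_bundles"])
mesh_size = (backend_config["tensor_parallel_size"]
             * backend_config["fully_sharded_data_parallel_size"])
assert mesh_size == total_chips == 16
```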

Create a training script (copied from the skyrl-tx examples):

import tinker
import numpy as np
from tinker import types

# Connect to the local server
service_client = tinker.ServiceClient(base_url="http://localhost:8000", api_key="tml-dummy")
# Match the --base-model the server was launched with above
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-0.6B")
tokenizer = training_client.get_tokenizer()

# Training examples
examples = [
    {"input": "banana split", "output": "anana-bay plit-say"},
    {"input": "quantum physics", "output": "uantum-qay ysics-phay"},
    {"input": "coding wizard", "output": "oding-cay izard-way"},
]

def process_example(example, tokenizer):
    prompt = f"English: {example['input']}\nPig Latin:"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)

    tokens = prompt_tokens + completion_tokens
    weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=tokens[:-1]),
        loss_fn_inputs=dict(weights=weights[1:], target_tokens=tokens[1:])
    )

processed = [process_example(ex, tokenizer) for ex in examples]

# Training loop
for _ in range(6):
    fwdbwd = training_client.forward_backward(processed, "cross_entropy").result()
    training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()

    logprobs = np.concatenate([o['logprobs'].tolist() for o in fwdbwd.loss_fn_outputs])
    weights = np.concatenate([e.loss_fn_inputs['weights'].tolist() for e in processed])
    print(f"Loss: {-np.dot(logprobs, weights) / weights.sum():.4f}")
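The shift-by-one in process_example is what aligns each input position with its next-token target; a tokenizer-free illustration of the same slicing (the toy token ids are made up):

```python
# Toy token ids standing in for tokenizer.encode output.
prompt_tokens = [101, 5, 6]       # e.g. "English: banana split\nPig Latin:"
completion_tokens = [7, 8, 102]   # e.g. " anana-bay plit-say\n\n"

tokens = prompt_tokens + completion_tokens
weights = [0] * len(prompt_tokens) + [1] * len(completion_tokens)

model_input = tokens[:-1]   # the model sees everything but the last token
target_tokens = tokens[1:]  # position i predicts tokens[i + 1]
loss_weights = weights[1:]  # weights shifted to line up with the targets

# Prompt targets carry weight 0, so only completion tokens contribute
# to the cross-entropy loss.
assert len(model_input) == len(target_tokens) == len(loss_weights)
assert loss_weights == [0, 0, 1, 1, 1]
```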

Run this training script on the driver:

ray job submit --working-dir . -- sh -c 'cp rl_loop.py /home/ray/SkyRL/ && cd /home/ray/SkyRL && uv run rl_loop.py'

Signed-off-by: Andrew Sy Kim <andrewsy@google.com>


Development

Successfully merging this pull request may close these issues.

[tinker] Support using Ray to manage tinker API server, engine and Jax workers