Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions checkpoint/orbax/checkpoint/_src/path/async_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,17 @@ async def open_file(
path: epath.Path, mode: str = 'rb'
) -> AsyncIterator[AsyncFile]:
"""Async context manager for opening files."""
f = await asyncio.to_thread(path.open, mode=mode)
try:
yield AsyncFile(f)
finally:
await asyncio.to_thread(f.close)
f_or_cm = await asyncio.to_thread(path.open, mode=mode)
if hasattr(f_or_cm, 'read'):
f = f_or_cm
try:
yield AsyncFile(f)
finally:
await asyncio.to_thread(f.close)
else: # It is a context manager
cm = f_or_cm
f = await asyncio.to_thread(cm.__enter__)
try:
yield AsyncFile(f)
finally:
await asyncio.to_thread(cm.__exit__, None, None, None)
21 changes: 21 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/async_path_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.

import asyncio
import contextlib
import io
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -135,6 +138,24 @@ async def read_chunk(offset, size):

asyncio.run(_test())

def test_open_returns_context_manager_handled(self):
  """`open_file` works when `path.open` returns a context manager.

  The file on disk and the mocked stream carry different bytes, so the read
  assertion can only pass if the mocked context manager was actually entered.
  The test also checks that the manager is exited when the `async with` block
  finishes.
  """
  test_file = self.test_dir / 'test.txt'
  # Deliberately not the mocked payload: reading the real file must fail the
  # assertion below.
  test_file.write_text('content on disk')
  payload = b'hello world'
  exit_calls = []

  @contextlib.contextmanager
  def open_mock(mode):
    del mode
    try:
      yield io.BytesIO(payload)
    finally:
      # Runs when the manager's `__exit__` is invoked.
      exit_calls.append(True)

  async def _test():
    with mock.patch.object(
        test_file, 'open', return_value=open_mock('rb')
    ):
      async with async_path.open_file(test_file, 'rb') as f:
        self.assertEqual(await f.read(), payload)
    self.assertEqual(exit_calls, [True])

  asyncio.run(_test())


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
suite_name: "Safetensors Load Benchmark"

mesh_configs:
# Case 1: v5litepod-8, num_slices=2 (4 processes, 4 chips/process)
- mesh_axes: ["data", "model"]
ici_parallelism: {"data": 4, "model": 1}
dcn_parallelism: {"data": 4, "model": 1}
# Case 2: v5litepod-8, num_slices=1 (2 processes, 4 chips/process)
- mesh_axes: ["data", "model"]
ici_parallelism: {"data": 4, "model": 1}
dcn_parallelism: {"data": 2, "model": 1}
# Case 5: v5litepod-16, num_slices=1 (4 processes, 4 chips/process), ICI-only
- mesh_axes: ["data", "model"]
ici_parallelism: {"data": 8, "model": 16}
dcn_parallelism: null
- mesh_axes: ["data", "model"]
ici_parallelism: {"data": 1, "model": 16}
dcn_parallelism: null
- mesh_axes: ["data", "model"]
ici_parallelism: {"data": 1, "model": 32}
dcn_parallelism: null
- mesh_axes: ["data", "model"]
ici_parallelism: {"data": 1, "model": 64}
dcn_parallelism: null

checkpoint_config:
spec:
array: {dtype: "float32", shape: [1024, 2048], sharding: ["data", "model"]}

benchmarks:
- generator: "orbax.checkpoint._src.testing.benchmarks.safetensors_benchmark.SafetensorsBenchmark"
options:
checkpoint_path: "gs://safetensor-kimi-central/test-model-kimi"
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Benchmarks for SafetensorsLayout (V1)."""

import asyncio
import dataclasses

from absl import logging
from etils import epath
import jax
from orbax.checkpoint._src.arrays import sharding as sharding_utils
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
from orbax.checkpoint.experimental import v1 as ocp_v1


# ==============================================================================
# 1. Define the Options Dataclass
# ==============================================================================
@dataclasses.dataclass(frozen=True)
class SafetensorsBenchmarkOptions(benchmarks_core.BenchmarkOptions):
  """Configuration options for benchmarks targeting SafetensorsLayout.

  Attributes:
    checkpoint_path: Location of the Safetensors checkpoint to load, given as
      a string path (for example a local file or a `gs://` location).
      Defaults to None.
  """

  checkpoint_path: str | None = None


# ==============================================================================
# 2. Implement the Benchmark Generator
# ==============================================================================
@benchmarks_core.benchmark_options(SafetensorsBenchmarkOptions)
class SafetensorsBenchmark(benchmarks_core.BenchmarksGenerator):
  """A generator for benchmarking loads through SafetensorsLayout."""

  def test_fn(
      self, context: benchmarks_core.TestContext
  ) -> benchmarks_core.TestResult:
    """Times a metadata read and a sharded load of a Safetensors checkpoint.

    Two metrics are recorded:
      * `metadata_load`: the `ocp_v1.pytree_metadata` call.
      * `data_load_sharded`: building shardings over the mesh devices plus the
        `ocp_v1.load_pytree` call.

    Args:
      context: Provides `options` (a `SafetensorsBenchmarkOptions` whose
        `checkpoint_path` is set) and the `mesh` whose devices are used to
        build the target shardings.

    Returns:
      A `TestResult` carrying the collected metrics.

    Raises:
      ValueError: If `options.checkpoint_path` is None.
    """
    metrics = metric_lib.Metrics()
    options = context.options
    assert isinstance(options, SafetensorsBenchmarkOptions)
    if options.checkpoint_path is None:
      # Fail with an actionable message instead of an opaque error from
      # constructing a path out of None.
      raise ValueError(
          'SafetensorsBenchmarkOptions.checkpoint_path must be set.'
      )

    load_path = epath.Path(options.checkpoint_path)
    logging.info('Benchmarking Load from: %s', load_path)
    mesh = context.mesh

    async def _load_gcs():
      octx = ocp_v1.Context(
          checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS
      )
      with octx:
        # METRIC 1: Header/Index parsing (Metadata)
        with metrics.measure('metadata_load'):
          logging.info('Step 1: Parsing Safetensors metadata...')
          metadata = ocp_v1.pytree_metadata(load_path)
          abstract_state = metadata.metadata

        # METRIC 2: The actual data transfer (The sharded load)
        with metrics.measure('data_load_sharded'):
          logging.info('Step 2: Starting sharded data transfer...')

          shardings = sharding_utils.construct_maximal_shardings(
              abstract_state, list(mesh.devices.flatten())
          )
          sharded_abstract_state = jax.tree.map(
              lambda sds, sharding: jax.ShapeDtypeStruct(
                  sds.shape, sds.dtype, sharding=sharding
              ),
              abstract_state,
              shardings,
          )

          restored_pytree = ocp_v1.load_pytree(
              load_path, sharded_abstract_state
          )

        # Log where the first restored leaf ended up. Guarded so that an
        # empty tree does not fail the benchmark on a purely informational
        # message.
        leaves = jax.tree_util.tree_leaves(restored_pytree)
        if leaves:
          first_leaf = leaves[0]
          logging.info(
              'SUCCESS: Load complete. First leaf shape: %s, on devices: %s',
              first_leaf.shape,
              first_leaf.devices(),
          )
        else:
          logging.warning('Load complete, but the restored tree is empty.')

    # `asyncio.run` creates a dedicated event loop and closes it afterwards.
    # `asyncio.get_event_loop()` is deprecated when no loop is running and
    # leaves the loop it creates open.
    asyncio.run(_load_gcs())

    return benchmarks_core.TestResult(metrics=metrics)
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint._src.testing.benchmarks import safetensors_benchmark
from orbax.checkpoint._src.testing.benchmarks.core import configs as benchmarks_configs
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
from orbax.checkpoint.experimental import v1 as ocp_v1
import safetensors.numpy as safe_np

SafetensorsBenchmarkOptions = safetensors_benchmark.SafetensorsBenchmarkOptions
SafetensorsBenchmark = safetensors_benchmark.SafetensorsBenchmark


class SafetensorsBenchmarkTest(parameterized.TestCase):
  """Tests `SafetensorsBenchmark.test_fn` against a small local checkpoint."""

  def setUp(self):
    """Writes a three-entry float32 Safetensors file into a temp directory."""
    super().setUp()
    self.test_dir = epath.Path(self.create_tempdir().full_path)
    self.checkpoint_path = self.test_dir / 'fake_checkpoint.safetensors'

    self.dummy_pytree = {
        'tensor_a': jnp.ones((32, 1024), dtype=jnp.float32),
        'scalar': jnp.ones((), dtype=jnp.float32),
        'vector': jnp.ones((1024,), dtype=jnp.float32),
    }

    save_pytree = jax.tree.map(np.array, self.dummy_pytree)
    safe_np.save_file(save_pytree, str(self.checkpoint_path))

  def _run_benchmark_and_verify(self, device_grid):
    """Runs `test_fn` on a mesh built from `device_grid` and checks results.

    Args:
      device_grid: 2D array of devices; its axes are named
        ('data', 'model') in the mesh handed to the benchmark.
    """
    # 1. Setup Benchmark Generator
    generator = SafetensorsBenchmark(
        checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})],
        options=SafetensorsBenchmarkOptions(),
    )

    # 2. Create Test Context
    mesh = jax.sharding.Mesh(device_grid, ('data', 'model'))
    options = SafetensorsBenchmarkOptions(
        checkpoint_path=str(self.checkpoint_path)
    )
    context = benchmarks_core.TestContext(
        pytree={},  # Unused in this test_fn implementation
        path=self.checkpoint_path,
        options=options,
        mesh=mesh,
    )

    # 3. Run the Benchmark Test Function
    result = generator.test_fn(context)

    # 4. Verify Benchmark Metrics
    self.assertIsInstance(result, benchmarks_core.TestResult)
    self.assertIn('metadata_load_time_duration', result.metrics.results)
    self.assertIn('data_load_sharded_time_duration', result.metrics.results)

    # 5. Verify Loaded Content by Reloading
    # `test_fn` does not return the restored tree, so the file is loaded
    # again here, without any sharding applied.
    octx = ocp_v1.Context(
        checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS
    )
    with octx:
      metadata = ocp_v1.pytree_metadata(self.checkpoint_path)
      abstract_state = metadata.metadata
      restored_pytree = ocp_v1.load_pytree(
          self.checkpoint_path, abstract_state
      )

    self.assertEqual(
        jax.tree_util.tree_structure(restored_pytree),
        jax.tree_util.tree_structure(self.dummy_pytree),
    )

    def _check_leaf(restored, expected):
      self.assertTrue(np.array_equal(np.array(restored), np.array(expected)))
      self.assertEqual(restored.shape, expected.shape)
      self.assertEqual(restored.dtype, expected.dtype)

    jax.tree.map(_check_leaf, restored_pytree, self.dummy_pytree)

  def test_benchmark_test_fn_sharded_load(self):
    """All devices are placed along the 'model' axis (a 1 x N grid)."""
    devices = np.array(jax.devices())
    self._run_benchmark_and_verify(devices.reshape(1, devices.size))

  def test_benchmark_test_fn_rank_aware_sharding(self):
    """Uses a 2 x (N/2) grid when N is even and above 2, else 1 x N."""
    devices = np.array(jax.devices())
    num_devices = devices.size
    if num_devices > 2 and num_devices % 2 == 0:
      device_grid = devices.reshape(2, num_devices // 2)
    else:
      # Covers 1 or 2 devices and any odd device count.
      device_grid = devices.reshape(1, num_devices)
    self._run_benchmark_and_verify(device_grid)


if __name__ == '__main__':
absltest.main()
Loading
Loading