diff --git a/checkpoint/orbax/checkpoint/_src/path/async_path.py b/checkpoint/orbax/checkpoint/_src/path/async_path.py index e574d7397..faa038525 100644 --- a/checkpoint/orbax/checkpoint/_src/path/async_path.py +++ b/checkpoint/orbax/checkpoint/_src/path/async_path.py @@ -159,8 +159,17 @@ async def open_file( path: epath.Path, mode: str = 'rb' ) -> AsyncIterator[AsyncFile]: """Async context manager for opening files.""" - f = await asyncio.to_thread(path.open, mode=mode) - try: - yield AsyncFile(f) - finally: - await asyncio.to_thread(f.close) + f_or_cm = await asyncio.to_thread(path.open, mode=mode) + if hasattr(f_or_cm, 'read'): + f = f_or_cm + try: + yield AsyncFile(f) + finally: + await asyncio.to_thread(f.close) + else: # It is a context manager + cm = f_or_cm + f = await asyncio.to_thread(cm.__enter__) + try: + yield AsyncFile(f) + finally: + await asyncio.to_thread(cm.__exit__, None, None, None) diff --git a/checkpoint/orbax/checkpoint/_src/path/async_path_test.py b/checkpoint/orbax/checkpoint/_src/path/async_path_test.py index 99a340974..5ac4cbe19 100644 --- a/checkpoint/orbax/checkpoint/_src/path/async_path_test.py +++ b/checkpoint/orbax/checkpoint/_src/path/async_path_test.py @@ -13,6 +13,9 @@ # limitations under the License. import asyncio +import contextlib +import io +from unittest import mock from absl.testing import absltest from absl.testing import parameterized @@ -135,6 +138,24 @@ async def read_chunk(offset, size): asyncio.run(_test()) + def test_open_returns_context_manager_handled(self): + test_file = self.test_dir / 'test.txt' + test_file.write_text('hello world') + + @contextlib.contextmanager + def open_mock(mode): + del mode + yield io.BytesIO(b'hello world') + + async def _test(): + with mock.patch.object( + test_file, 'open', return_value=open_mock('rb') + ): + async with async_path.open_file(test_file, 'rb') as f: + self.assertEqual(await f.read(), b'hello world') + + asyncio.run(_test()) + if __name__ == '__main__': absltest.main() diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/safetensors_benchmark.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/safetensors_benchmark.yaml new file mode 100644 index 000000000..025f06f26 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/safetensors_benchmark.yaml @@ -0,0 +1,33 @@ +suite_name: "Safetensors Load Benchmark" + +mesh_configs: + # Case 1: v5litepod-8, num_slices=2 (4 processes, 4 chips/process) + - mesh_axes: ["data", "model"] + ici_parallelism: {"data": 4, "model": 1} + dcn_parallelism: {"data": 4, "model": 1} + # Case 2: v5litepod-8, num_slices=1 (2 processes, 4 chips/process) + - mesh_axes: ["data", "model"] + ici_parallelism: {"data": 4, "model": 1} + dcn_parallelism: {"data": 2, "model": 1} + # Case 5: v5litepod-16, num_slices=1 (4 processes, 4 chips/process), ICI-only + - mesh_axes: ["data", "model"] + ici_parallelism: {"data": 8, "model": 16} + dcn_parallelism: null + - mesh_axes: ["data", "model"] + ici_parallelism: {"data": 1, "model": 16} + dcn_parallelism: null + - mesh_axes: ["data", "model"] + ici_parallelism: {"data": 1, "model": 32} + dcn_parallelism: null + - mesh_axes: ["data", "model"] + ici_parallelism: {"data": 1, "model": 64} + dcn_parallelism: null + +checkpoint_config: + spec: + array: {dtype: "float32", shape: [1024, 2048], sharding: ["data", "model"]} + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.safetensors_benchmark.SafetensorsBenchmark" + options: + checkpoint_path: "gs://safetensor-kimi-central/test-model-kimi" diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/safetensors_benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/safetensors_benchmark.py new file mode 100644 index 000000000..8931ebf70 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/safetensors_benchmark.py @@ -0,0 +1,105 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmarks for SafetensorsLayout (V1).""" + +import asyncio +import dataclasses + +from absl import logging +from etils import epath +import jax +from orbax.checkpoint._src.arrays import sharding as sharding_utils +from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core +from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib +from orbax.checkpoint.experimental import v1 as ocp_v1 + + +# ============================================================================== +# Define the Options Dataclass +# ============================================================================== +@dataclasses.dataclass(frozen=True) +class SafetensorsBenchmarkOptions(benchmarks_core.BenchmarkOptions): + """Configuration options for benchmarks targeting SafetensorsLayout. + + Attributes: + checkpoint_config_path: The path to the checkpoint config file. + """ + + checkpoint_path: str | None = None + + +# ============================================================================== +# 2. Implement the Benchmark Generator +# ============================================================================== +@benchmarks_core.benchmark_options(SafetensorsBenchmarkOptions) +class SafetensorsBenchmark(benchmarks_core.BenchmarksGenerator): + """A generator for benchmarking SafetensorsLayout.""" + + def test_fn( + self, context: benchmarks_core.TestContext + ) -> benchmarks_core.TestResult: + """The core test logic for a single save/restore cycle using V1 API.""" + metrics = metric_lib.Metrics() + options = context.options + assert isinstance(options, SafetensorsBenchmarkOptions) + + load_path = epath.Path(options.checkpoint_path) + logging.info('Benchmarking Load from: %s', load_path) + mesh = context.mesh + + async def _load_gcs(): + octx = ocp_v1.Context( + checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS + ) + with octx: + # METRIC 1: Header/Index parsing (Metadata) + with metrics.measure('metadata_load'): + logging.info('Step 1: Parsing Safetensors metadata...') + metadata = ocp_v1.pytree_metadata(load_path) + abstract_state = metadata.metadata + + # METRIC 2: The actual data transfer (The sharded load) + with metrics.measure('data_load_sharded'): + logging.info('Step 2: Starting sharded data transfer...') + + shardings = sharding_utils.construct_maximal_shardings( + abstract_state, list(mesh.devices.flatten()) + ) + sharded_abstract_state = jax.tree.map( + lambda sds, sharding: jax.ShapeDtypeStruct( + sds.shape, sds.dtype, sharding=sharding + ), + abstract_state, + shardings, + ) + + restored_pytree = ocp_v1.load_pytree( + load_path, sharded_abstract_state + ) + + # Verify the result landed on TPU + first_leaf = jax.tree_util.tree_leaves(restored_pytree)[0] + logging.info( + 'SUCCESS: Load complete. First leaf shape: %s, on devices: %s', + first_leaf.shape, + first_leaf.devices(), + ) + return restored_pytree + + # Safe execution for benchmark environments + loop = asyncio.get_event_loop() + loop.run_until_complete(_load_gcs()) + + return benchmarks_core.TestResult(metrics=metrics) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/safetensors_benchmark_test.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/safetensors_benchmark_test.py new file mode 100644 index 000000000..27c19f082 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/safetensors_benchmark_test.py @@ -0,0 +1,186 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +import jax +import jax.numpy as jnp +import numpy as np +from orbax.checkpoint._src.testing.benchmarks import safetensors_benchmark +from orbax.checkpoint._src.testing.benchmarks.core import configs as benchmarks_configs +from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core +from orbax.checkpoint.experimental import v1 as ocp_v1 +import safetensors.numpy as safe_np + +SafetensorsBenchmarkOptions = safetensors_benchmark.SafetensorsBenchmarkOptions +SafetensorsBenchmark = safetensors_benchmark.SafetensorsBenchmark + + +class SafetensorsBenchmarkTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_dir = epath.Path(self.create_tempdir().full_path) + self.checkpoint_path = self.test_dir / 'fake_checkpoint.safetensors' + + self.dummy_pytree = { + 'tensor_a': jnp.ones((32, 1024), dtype=jnp.float32), + 'scalar': jnp.ones((), dtype=jnp.float32), + 'vector': jnp.ones((1024,), dtype=jnp.float32), + } + + save_pytree = jax.tree.map(np.array, self.dummy_pytree) + safe_np.save_file(save_pytree, str(self.checkpoint_path)) + + def test_benchmark_test_fn_sharded_load(self): + # 1. Setup Benchmark Generator + generator = SafetensorsBenchmark( + checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})], + options=SafetensorsBenchmarkOptions(), + ) + + # 2. Create Test Context + devices = np.array(jax.devices()) + if devices.size == 1: + devices = devices.reshape(1, 1) + else: + devices = devices.reshape(1, devices.size) # Keep it simple for this test + mesh = jax.sharding.Mesh(devices, ('data', 'model')) + options = SafetensorsBenchmarkOptions( + checkpoint_path=str(self.checkpoint_path) + ) + + context = benchmarks_core.TestContext( + pytree={}, # Unused in this test_fn implementation + path=self.checkpoint_path, + options=options, + mesh=mesh, + ) + + # 3. Run the Benchmark Test Function + result = generator.test_fn(context) + + # 4. Verify Benchmark Metrics + self.assertIsInstance(result, benchmarks_core.TestResult) + self.assertIn('metadata_load_time_duration', result.metrics.results) + self.assertIn('data_load_sharded_time_duration', result.metrics.results) + + # 5. Verify Loaded Content by Reloading + octx = ocp_v1.Context( + checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS + ) + with octx: + metadata = ocp_v1.pytree_metadata(self.checkpoint_path) + abstract_state = metadata.metadata + restored_pytree = ocp_v1.load_pytree(self.checkpoint_path, abstract_state) + + self.assertEqual( + jax.tree_util.tree_structure(restored_pytree), + jax.tree_util.tree_structure(self.dummy_pytree), + ) + jax.tree.map( + self.assertTrue, + jax.tree.map( + lambda a, b: np.array_equal(np.array(a), np.array(b)), + restored_pytree, + self.dummy_pytree, + ), + ) + jax.tree.map( + self.assertEqual, + jax.tree.map(lambda a: a.shape, restored_pytree), + jax.tree.map(lambda a: a.shape, self.dummy_pytree), + ) + jax.tree.map( + self.assertEqual, + jax.tree.map(lambda a: a.dtype, restored_pytree), + jax.tree.map(lambda a: a.dtype, self.dummy_pytree), + ) + + def test_benchmark_test_fn_rank_aware_sharding(self): + # 1. Setup Benchmark Generator + generator = SafetensorsBenchmark( + checkpoint_configs=[benchmarks_configs.CheckpointConfig(spec={})], + options=SafetensorsBenchmarkOptions(), + ) + + # 2. Create Test Context + devices = np.array(jax.devices()) + # Reshape devices to be 2D for the mesh axis names ('data', 'model') + num_devices = devices.size + if num_devices == 1: + devices = devices.reshape(1, 1) + elif num_devices == 2: + devices = devices.reshape(1, 2) + elif num_devices % 2 == 0: + devices = devices.reshape(2, num_devices // 2) + else: # Fallback for odd numbers, should not happen in typical test envs + devices = devices.reshape(1, num_devices) + mesh = jax.sharding.Mesh(devices, ('data', 'model')) + options = SafetensorsBenchmarkOptions( + checkpoint_path=str(self.checkpoint_path) + ) + + context = benchmarks_core.TestContext( + pytree={}, # Unused + path=self.checkpoint_path, + options=options, + mesh=mesh, + ) + + # 3. Run the Benchmark Test Function + result = generator.test_fn(context) + + # 4. Verify Benchmark Metrics + self.assertIsInstance(result, benchmarks_core.TestResult) + self.assertIn('metadata_load_time_duration', result.metrics.results) + self.assertIn('data_load_sharded_time_duration', result.metrics.results) + + # 5. Verify Loaded Content by Reloading + octx = ocp_v1.Context( + checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS + ) + with octx: + metadata = ocp_v1.pytree_metadata(self.checkpoint_path) + abstract_state = metadata.metadata + # Note: Sharding is not applied here, loading as is from the file. + restored_pytree = ocp_v1.load_pytree(self.checkpoint_path, abstract_state) + + self.assertEqual( + jax.tree_util.tree_structure(restored_pytree), + jax.tree_util.tree_structure(self.dummy_pytree), + ) + jax.tree.map( + self.assertTrue, + jax.tree.map( + lambda a, b: np.array_equal(np.array(a), np.array(b)), + restored_pytree, + self.dummy_pytree, + ), + ) + jax.tree.map( + self.assertEqual, + jax.tree.map(lambda a: a.shape, restored_pytree), + jax.tree.map(lambda a: a.shape, self.dummy_pytree), + ) + jax.tree.map( + self.assertEqual, + jax.tree.map(lambda a: a.dtype, restored_pytree), + jax.tree.map(lambda a: a.dtype, self.dummy_pytree), + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/smart_batching.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/smart_batching.py new file mode 100644 index 000000000..07f19996b --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/smart_batching.py @@ -0,0 +1,576 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Script for batch-converting Safetensors to native Orbax layout.""" + +import asyncio +from concurrent import futures +import inspect +import itertools +import os +import time +from typing import Any, cast, Dict, Sequence, Tuple + +from absl import app +from absl import flags +from absl import logging +from etils import epath +import huggingface_hub +import jax +import numpy as np +from orbax.checkpoint.experimental import v1 as ocp_v1 +from orbax.checkpoint.experimental.v1._src.layout import safetensors_layout +from tensorflow.io import gfile + + +ThreadPoolExecutor = futures.ThreadPoolExecutor +_INPUT_DIR = flags.DEFINE_string( + 'input_dir', None, 'Directory containing Safetensors files.' +) +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', None, 'Directory to save the converted Orbax checkpoint.' +) +_MAX_BATCH_SIZE_GB = flags.DEFINE_float( + 'max_batch_size_gb', 5.0, 'Maximum size of a single batch in GB.' +) +_WORKERS = flags.DEFINE_integer( + 'workers', + 1, + 'Number of worker threads to use for concurrent processing of batches.', +) + + +_USE_FILE_LOGGER_ONLY = flags.DEFINE_boolean( + 'use_file_logger_only', + True, + 'If True, only logs from this file are printed to avoid excessive logging ' + 'from other modules. If False, enables global INFO logging.', +) + +_SAVING_ENABLED = flags.DEFINE_boolean( + 'saving_enabled', + False, + 'If True, saves the converted Orbax checkpoint to the output directory.', +) + +_OFFICIAL_LOAD_ENABLED = flags.DEFINE_boolean( + 'official_load_enabled', + False, + 'If True, loads the same checkpoint using official Safetensors library.', +) + + +def _log_info(msg: str, *args): + """Designated method for internal logging with volume control.""" + if _USE_FILE_LOGGER_ONLY.value: + print('INFO: ' + (msg % args if args else msg)) + else: + logging.info(msg, *args) + + +def _log_error(msg: str, *args): + """Designated method for internal error logging with volume control.""" + if _USE_FILE_LOGGER_ONLY.value: + print('ERROR: ' + str(msg % args if args else msg)) + else: + logging.error(msg, *args) + + +def benchmark_official_safetensors(repo_id: str) -> float: + """Benchmarks loading checkpoints using the standard Safetensors library. + + It simulates a real-world scenario: Download -> RAM -> Parse. + + Args: + repo_id: The HuggingFace repo ID to download from. + + Returns: + The total time in seconds taken to download and process the Safetensor + files. + """ + # 1. Get only the safetensor filenames + all_files = huggingface_hub.list_repo_files(repo_id) + safetensor_files = [f for f in all_files if f.endswith('.safetensors')] + safetensor_files.sort() + total_start_time = time.time() + local_dir = os.path.expanduser('~/safetensor_layout_loading_test/tmp') + os.makedirs(local_dir, exist_ok=True) + + for filename in safetensor_files: + print(f'--- Processing {filename} ---') + file_start = time.time() + file_path = huggingface_hub.hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=local_dir, + local_dir_use_symlinks=False, + ) + + file_end = time.time() + _log_info('Finished %s in %.2f seconds', filename, file_end - file_start) + if os.path.exists(file_path): + _log_info('Removing file: %s', file_path) + os.remove(file_path) + + return time.time() - total_start_time + + +async def verify_local_integrity( + input_dir: str, + local_output_dir: str, + model_metadata: Any, +): + """Verifies integrity by comparing source and destination tensors. + + Loads entire Source (Safetensors) and Destination (Local Orbax) and compares + them tensor-by-tensor for exact matches. + + Args: + input_dir: Directory containing Safetensors files. + local_output_dir: Directory to save the converted Orbax checkpoint. + model_metadata: Model metadata. + + Raises: + RuntimeError: If any tensor does not match exactly between the source and + destination. + """ + _log_info(f'🔎 STARTING DIRECT INTEGRITY CHECK on {local_output_dir}') + + input_path = epath.Path(input_dir) + output_path = epath.Path(local_output_dir) + + with ocp_v1.Context( + checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS + ): + _log_info('Loading entire Safetensors checkpoint into memory...') + load_start_time = time.time() + st_data = ocp_v1.load_pytree( + path=input_path, abstract_pytree=model_metadata + ) + load_end_time = time.time() + _log_info( + '✅ Safetensors Load Time: %.2f seconds', + load_end_time - load_start_time, + ) + + if inspect.iscoroutine(st_data): + st_data = await st_data + st_data = cast(Dict[str, Any], st_data) + + if 'pytree' in st_data and len(st_data) == 1: + st_data = st_data['pytree'] + + with ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.ORBAX): + _log_info('Loading entire Local Orbax checkpoint into memory...') + load_start_time = time.time() + try: + orbax_data = ocp_v1.load_pytree(output_path, {'pytree': model_metadata}) + except Exception as e: + _log_error( + '❌ FATAL: Could not read Local Orbax checkpoint from %s', output_path + ) + _log_error('Error details: %s', e) + if output_path.exists(): + _log_info('📂 Directory contents of %s:', output_path) + for f in output_path.iterdir(): + _log_info(' - %s', f.name) + raise e + + if isinstance(orbax_data, dict) and 'pytree' in orbax_data: + orbax_data = orbax_data['pytree'] + + load_end_time = time.time() + _log_info( + '✅ Orbax Load Time: %.2f seconds', + load_end_time - load_start_time, + ) + + _log_info('Comparing checkpoints...') + + flat_meta, _ = jax.tree_util.tree_flatten_with_path(model_metadata) + + def get_val(d: Any, keys: Sequence[str]) -> Any: + try: + for k in keys: + d = d[k] + return d + except KeyError: + return None + + total_tensors = 0 + mismatches = 0 + + for path, _ in flat_meta: + key_tuple = tuple(cast(Any, p).key for p in path) + tensor_name = '.'.join(key_tuple) + total_tensors += 1 + + val_src = get_val(st_data, key_tuple) + val_dst = get_val(orbax_data, key_tuple) + + if val_dst is None: + _log_error('❌ MISSING TENSOR: %s not found in Orbax load.', tensor_name) + mismatches += 1 + continue + + if val_src.shape != val_dst.shape: + _log_error( + '❌ SHAPE MISMATCH: %s | Src %s != Dst %s', + tensor_name, + val_src.shape, + val_dst.shape, + ) + mismatches += 1 + continue + + if not np.allclose(val_src, val_dst, equal_nan=True, atol=1e-6): + diff = np.abs(val_src - val_dst) + _log_error( + '❌ VALUE MISMATCH: %s | Max Diff: %s', tensor_name, np.max(diff) + ) + mismatches += 1 + # else: + # # Optional: Log success for very large tensors just to be sure + # val_src_size = val_src.nbytes / (1024 * 1024 * 1024) + # val_dst_size = val_dst.nbytes / (1024 * 1024 * 1024) + # _log_info( + # '✅ Verified Large Tensor: %s | Src Size: %s GB | Dst Size: %s GB', + # tensor_name, + # val_src_size, + # val_dst_size, + # ) + + del st_data + del orbax_data + + if mismatches == 0: + _log_info( + '🎉 SUCCESS: All %s tensors verified matching exactly!', total_tensors + ) + else: + _log_error( + '💀 FAILURE: Found %s mismatches during verification.', mismatches + ) + raise RuntimeError('Sanity check failed! Stopping upload.') + + +# --- NEW HELPER FUNCTION FOR SIZE TRACKING --- +def get_dir_size_mb(start_path: str) -> float: + """Calculates size of the checkpoint, including hidden partial folders.""" + files_to_stat = [] + # Orbax writes to a folder with a suffix like '.partial_save' + # We check both the main folder and the partial folder to be sure. + paths_to_check = [start_path, start_path + '.partial_save'] + + for p in paths_to_check: + if gfile.exists(p): + try: + for dirpath, _, filenames in gfile.Walk(p): + for f in filenames: + fp = os.path.join(dirpath, f) + files_to_stat.append(fp) + except gfile.GOSError as e: + _log_error('Failed to walk files in %s: %s', p, e) + raise e + + if not files_to_stat: + return 0.0 + + try: + stats = gfile.BulkStatWithException(files_to_stat) + total_size = sum(s.length for s in stats if s.length > -1) + except gfile.GOSError as e: + _log_error('Failed to stat files in %s: %s', start_path, e) + raise e + + return total_size / (1024 * 1024) + + +def analyze_model_structure(metadata_tree: Any) -> None: + """Logs detailed statistics about the model structure and expected size.""" + flat_metadata, _ = jax.tree_util.tree_flatten_with_path(metadata_tree) + + total_params = 0 + total_bytes = 0 + dtype_counts = {} + + for _, leaf in flat_metadata: + # Get path string (e.g., "model/layers/0/self_attn/q_proj") + if hasattr(leaf, 'shape') and hasattr(leaf, 'dtype'): + # Calculate size for this specific tensor + shape = leaf.shape + dtype = leaf.dtype + param_count = np.prod(shape) + byte_size = param_count * np.dtype(dtype).itemsize + + # Update totals + total_params += param_count + total_bytes += byte_size + + # Track dtype distribution + dtype_str = str(dtype) + dtype_counts[dtype_str] = dtype_counts.get(dtype_str, 0) + 1 + + total_gb = total_bytes / (1024**3) + total_mb = total_bytes / (1024**2) + + _log_info('Total Tensors: %d', len(flat_metadata)) + _log_info('Total Parameters: %s', f'{total_params:,}') + _log_info( + 'Expected Raw Size (Uncompressed): %.2f MB (%.4f GB)', + total_mb, + total_gb, + ) + _log_info('Dtype Distribution: %s', dtype_counts) + + +def get_param_size_bytes(param_info: jax.ShapeDtypeStruct) -> int: + """Calculates size in bytes for a single parameter from metadata.""" + dtype_size = np.dtype(param_info.dtype).itemsize + return np.prod(param_info.shape) * dtype_size + + +async def _execute_batch_async( + layout: safetensors_layout.SafetensorsLayout, + input_path: epath.Path, + output_dir: str, + plan: Any, + flat_map: Dict[str, Any], + batch_index: int, + total_batches: int, +) -> Tuple[float, float, int]: + """Loads and saves a single batch.""" + _log_info( + '\033[1mProcessing batch %s/%s...\033[0m', batch_index, total_batches + ) + batch_abstract_pytree = {} + for i, batch_keys in enumerate(plan, 1): + if i == batch_index: + for k in batch_keys: + batch_abstract_pytree[k] = flat_map[k] + break + load_start_time = time.time() + _log_info( + 'Loading batch into Host RAM from Safetensors, starting at %.2f seconds', + load_start_time, + ) + tensors = await layout.load_pytree( + path=input_path, abstract_pytree=batch_abstract_pytree + ) + load_end_time = time.time() + _log_info( + '✅ Safetensors Load Time: %.2f seconds', + load_end_time - load_start_time, + ) + tensors = cast(Dict[str, Any], tensors) + + total_save_time = 0.0 + if _SAVING_ENABLED.value: + save_start_time = time.time() + + tensors_to_save = {'pytree': tensors} + ocp_v1.partial.save_pytree(output_dir, tensors_to_save) + save_end_time = time.time() + _log_info( + '✅ Orbax Save Time: %.2f seconds in directory %s', + save_end_time - save_start_time, + output_dir, + ) + total_save_time = save_end_time - save_start_time + return load_end_time - load_start_time, total_save_time, batch_index - 1 + + +def _execute_batch( + layout: safetensors_layout.SafetensorsLayout, + input_path: epath.Path, + output_dir: str, + plan: Any, + flat_map: Dict[str, Any], + batch_index: int, + total_batches: int, +) -> Tuple[float, float, int]: + """Sync wrapper for _execute_batch_async to run in ThreadPoolExecutor.""" + return asyncio.run( + _execute_batch_async( + layout, + input_path, + output_dir, + plan, + flat_map, + batch_index, + total_batches, + ) + ) + + +async def run_cpu_batching( + input_dir: str, output_dir: str, max_batch_size_gb: float +): + """Orchestrates the metadata sizing, planning, and batch execution loop.""" + if gfile.exists(output_dir): + _log_info('Removing existing checkpoint directory: %s', output_dir) + gfile.DeleteRecursively(output_dir) + partial_save_dir = output_dir + '.partial_save' + if gfile.exists(partial_save_dir): + _log_info('Removing existing partial save directory: %s', partial_save_dir) + gfile.DeleteRecursively(partial_save_dir) + start_time = time.time() + input_path = epath.Path(input_dir) + layout = safetensors_layout.SafetensorsLayout() + + _log_info('=' * 60) + _log_info(f'🔎 STARTING CONVERSION FROM SAFETENSORS TO ORBAX for {input_path}') + _log_info('Output directory: %s', output_dir) + _log_info('=' * 60) + _log_info('Conversion will be done in following steps.') + _log_info('step 1: Reading Safetensors Metadata') + _log_info('step 2: Analyzing Model Structure') + _log_info('step 3: Creating Loading Plan') + _log_info('step 4: Executing Loading Loop') + _log_info('step 5: Finalizing Native Orbax Checkpoint') + _log_info('step 6: Time to load using Safetensors Flax API') + _log_info('step 7: Verifying Local Integrity') + + # Step 1: Reading Safetensors Metadata... + _log_info('---------- Step 1: Reading Safetensors Metadata -----------') + metadata_start_time = time.time() + metadata_ckpt = await layout.metadata(input_path) + metadata_end_time = time.time() + _log_info( + 'Safetensors Metadata Time: %.2f seconds', + metadata_end_time - metadata_start_time, + ) + + if 'pytree' in metadata_ckpt.metadata: + model_metadata = metadata_ckpt.metadata['pytree'] + else: + model_metadata = metadata_ckpt.metadata + + # Step 2: Analyzing Model Structure... + _log_info('------------ Step 2: Analyzing Model Structure --------------') + analyze_start_time = time.time() + analyze_model_structure(model_metadata) + analyze_end_time = time.time() + _log_info( + 'Model Structure Analysis Time: %.2f seconds', + analyze_end_time - analyze_start_time, + ) + + # Step 3: Creating Loading Plan... + _log_info('-------------- Step 3: Creating Loading Plan ----------------') + plan_start_time = time.time() + plan, batch_sizes = await layout.create_loading_plan( + input_path, max_batch_size_gb + ) + _log_info( + 'Created a plan containing %s batches in %.2f seconds', + len(plan), + time.time() - plan_start_time, + ) + + # Step 4: Executing Loading Loop... + _log_info('------------- Step 4: Executing Loading Loop ----------------') + flat_metadata, _ = jax.tree_util.tree_flatten_with_path(model_metadata) + flat_map = {path[0].key: leaf for path, leaf in flat_metadata} + + global_start = time.time() + total_load_size = 0 + num_batches = len(plan) + with ThreadPoolExecutor(max_workers=_WORKERS.value) as pool: + results = pool.map( + _execute_batch, + itertools.repeat(layout), + itertools.repeat(input_path), + itertools.repeat(output_dir), + itertools.repeat(plan), + itertools.repeat(flat_map), + range(1, num_batches + 1), + itertools.repeat(num_batches), + ) + for r in results: + load_time, _, batch_index = r + batch_size = batch_sizes[batch_index] + total_load_size += batch_size + print( + f'Read Batch {batch_index} with {batch_size / (1024 * 1024)} MB in' + f' {load_time} sec (throughput:' + f' {batch_size / (1024 * 1024) / load_time} MB/sec)' + ) + global_end = time.time() + _log_info( + 'Total Loading Time: %.2f seconds', + global_end - global_start, + ) + # total_load_time = sum(r[0] for r in results) + total_save_time = 0.0 + if _SAVING_ENABLED.value: + # Step 5: Finalizing Native Orbax Checkpoint... + finalize_start_time = time.time() + _log_info('------- Step 5: Finalizing Native Orbax Checkpoint ----------') + ocp_v1.partial.finalize(output_dir) + total_save_time += time.time() - finalize_start_time + + # final_size = get_dir_size_mb(output_dir) + official_load_time = 0.0 + if _OFFICIAL_LOAD_ENABLED.value: + # Step 6: Time to load using Safetensors Flax API.. + _log_info( + '------------ Step 6: Time to load using Safetensors Flax API' + ' ------------\n' + ) + _log_info('Loading entire checkpoint in RAM using official Safetensors...') + repo_id_1gb = 'Qwen/Qwen2.5-0.5B-Instruct' + repo_id_50gb = 'Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4' + repo_id_150gb = 'Qwen/Qwen2.5-72B-Instruct' + repo_id_1tb = 'moonshotai/Kimi-K2-Instruct' + if input_dir == 'gs://safetensor-kimi-central/test-model-1gb': + official_load_time = benchmark_official_safetensors(repo_id_1gb) + elif input_dir == 'gs://safetensor-kimi-central/test-model-50gb': + official_load_time = benchmark_official_safetensors(repo_id_50gb) + elif input_dir == 'gs://safetensor-kimi-central/test-model-150gb': + official_load_time = benchmark_official_safetensors(repo_id_150gb) + elif input_dir == 'gs://safetensor-kimi-central/test-model-1tb': + official_load_time = benchmark_official_safetensors(repo_id_1tb) + _log_info( + '✅ \033[1mTotal Conversion Time: %.2f seconds | Final Size: %.2f MB |' + ' Total Load Throughput: %.2f MB/s | Total Save Time: %.2f seconds |' + ' Official Load Time: %.2f seconds | Final checkpoint location:' + ' %s\033[0m', + time.time() - start_time, + total_load_size / (1024 * 1024), + total_load_size / (1024 * 1024) / (time.time() - start_time), + total_save_time, + official_load_time, + output_dir, + ) + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + if not _INPUT_DIR.value or not _OUTPUT_DIR.value: + raise app.UsageError('--input_dir and --output_dir must be prov ided.') + + if not _USE_FILE_LOGGER_ONLY.value: + logging.set_stderrthreshold('INFO') + + asyncio.run( + run_cpu_batching( + _INPUT_DIR.value, _OUTPUT_DIR.value, _MAX_BATCH_SIZE_GB.value + ) + ) + + +if __name__ == '__main__': + app.run(main) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/test_loading_orbax.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/test_loading_orbax.py new file mode 100644 index 000000000..34842ad8d --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/test_loading_orbax.py @@ -0,0 +1,182 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests loading Orbax checkpoints from CNS.""" + +from concurrent import futures +import time +from typing import Sequence + +from absl import app +from etils import epath +from google.cloud import storage +import zstandard as zstd + + +ThreadPoolExecutor = futures.ThreadPoolExecutor +CNS_PATH = "s/ig-d/home/gemax-prod-team/llama-checkpoint/llama-3.1-70B-checkpoints/0/items/" +GCS_PATH = "gs://safetensor-kimi-central/test_model_orbax/llama-3.1-70B-checkpoints/0/items/items" + +# Initialize the client globally so workers share the connection pool. +# This prevents opening a brand new TCP connection for every single file. +_GCS_CLIENT = None + + +def get_gcs_client(): + global _GCS_CLIENT + if _GCS_CLIENT is None: + _GCS_CLIENT = storage.Client() + return _GCS_CLIENT + + +def read_file_gcs(p: epath.Path): + """Reads a file from GCS using the native cloud storage client.""" + start = time.time() + + # 1. Parse the gs:// path to get the bucket and blob name + path_str = str(p) + if not path_str.startswith("/big" + "store/"): + raise ValueError(f"Expected a /bistore/ path, got {path_str}") + + # Strip "gs://" and split + path_without_scheme = path_str[10:] + bucket_name, blob_name = path_without_scheme.split("/", 1) + + # 2. Download directly to memory using the GCS client + client = get_gcs_client() + bucket = client.bucket(bucket_name) + blob = bucket.blob(blob_name) + + # download_as_bytes() releases the GIL during network transit + data = blob.download_as_bytes() + + # 3. Decompress (Same logic as before) + zstd_magic = b"\x28\xb5\x2f\xfd" + start_offset = data.find(zstd_magic) + + if start_offset == -1: + raise ValueError("Could not find a valid Zstd magic number.") + + compressed_payload = data[start_offset:] + + dctx = zstd.ZstdDecompressor() + dctx.decompress(compressed_payload, max_output_size=int(20e9)) + + end = time.time() + mb = len(data) / 1e6 + return p.name, mb, end - start + + +def read_file_cns(p: epath.Path): + """Reads and decompresses a single Orbax checkpoint file. + + Args: + p: The epath.Path to the file to read. + + Returns: + A tuple containing the file name, size in MB, and the time taken to read + and decompress. + + Raises: + ValueError: If a valid Zstd magic number is not found in the file. + """ + start = time.time() + data = p.read_bytes() + + # Zstandard frames always start with the magic number: 0xFD2FB528 + # We need to find where this starts in the raw OCDBT file + zstd_magic = b"\x28\xb5\x2f\xfd" + start_offset = data.find(zstd_magic) + + if start_offset == -1: + raise ValueError( + "Could not find a valid Zstd magic number in the file. " + "It might be uncompressed or use a different algorithm." + ) + + # print(f"Found Zstd payload at offset: {start_offset}") + + # Slice the data from the magic number to the end + compressed_payload = data[start_offset:] + + dctx = zstd.ZstdDecompressor() + dctx.decompress(compressed_payload, max_output_size=int(20e9)) + + end = time.time() + mb = len(data) / 1e6 + return p.name, mb, end - start + + +def read_concurrently(workers): + """Reads and decompresses Orbax checkpoint files concurrently. + + Args: + workers: The number of worker threads to use for concurrent reading. + """ + global_start = time.time() + + filepaths_cns = list(epath.Path(CNS_PATH + "/cn").glob("ocdbt.process*/d/*")) + filepaths_gcs = list(epath.Path(GCS_PATH).glob("ocdbt.process*/d/*")) + print(f"Total files in CNS: {len(filepaths_cns)}") + print(f"Total files in GCS: {len(filepaths_gcs)}") + with ThreadPoolExecutor(max_workers=workers) as pool: + # --------------------------------------------------------- + # 1. READ FROM CNS + # --------------------------------------------------------- + print("\n--- Reading from CNS ---") + cns_start = time.time() + cns_total_mb = 0 + results_cns = pool.map(read_file_cns, filepaths_cns) + for r in results_cns: + name, mb, t = r + cns_total_mb += mb + print(f"CNS: Opened {name} ({mb:.2f} MB) in {t:.2f} sec") + cns_total_time = time.time() - cns_start + print(f"CNS Finished: {cns_total_mb:.2f} MB in {cns_total_time:.2f} sec") + print(f"CNS Throughput: {cns_total_mb / cns_total_time:.2f} MB/sec") + + # --------------------------------------------------------- + # 2. READ FROM GCS + # --------------------------------------------------------- + print("\n--- Reading from GCS ---") + gcs_start = time.time() + gcs_total_mb = 0 + # The workers are now free and will immediately start on this: + results_gcs = pool.map(read_file_gcs, filepaths_gcs) + for r in results_gcs: + name, mb, t = r + gcs_total_mb += mb + print(f"GCS: Opened {name} ({mb:.2f} MB) in {t:.2f} sec") + gcs_total_time = time.time() - gcs_start + print(f"GCS Finished: {gcs_total_mb:.2f} MB in {gcs_total_time:.2f} sec") + print(f"GCS Throughput: {gcs_total_mb / gcs_total_time:.2f} MB/sec") + + global_end = time.time() + print( + f"\nTotal benchmark time for {workers} workers:" + f" {global_end - global_start:.2f} sec" + ) + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + read_concurrently(4) + read_concurrently(16) + read_concurrently(32) + + +if __name__ == "__main__": + app.run(main) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile index bc9e9465a..65f57cb65 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/Dockerfile @@ -48,7 +48,7 @@ ARG JAX_VERSION=newest ARG DEVICE=tpu # Install GCSFS and Portpicker -RUN pip install --no-cache-dir gcsfs portpicker clu tensorflow +RUN pip install --no-cache-dir gcsfs portpicker clu tensorflow safetensors # Install requirements from repo root if it exists RUN if [ -f "requirements.txt" ]; then pip install --no-cache-dir -r requirements.txt; fi @@ -90,6 +90,7 @@ ENV PYTHONPATH=/app/orbax_repo/checkpoint # Verify installation RUN python3 -c "import orbax.checkpoint; print('Orbax installed:', orbax.checkpoint.__file__)" +RUN python3 -c "import safetensors.numpy; print('SafeTensors installed:', safetensors.numpy.__file__)" # 6. Entrypoint # We point to the benchmark script relative to the repo root structure diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py index 0f64b2c7c..466da50ae 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout.py @@ -12,12 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Defines `SafetensorsLayout`, a class to handle Safetensors checkpoint formats.""" +"""Safetensors checkpoint format layout.""" +import asyncio import collections import json -from typing import Any, Awaitable, Sequence, cast +import mmap +import os +import time +from typing import Any, Awaitable, List, Sequence, Tuple, cast +from google.cloud import storage +from google.cloud.storage import transfer_manager import jax import numpy as np from orbax.checkpoint._src.arrays import numpy_utils @@ -32,6 +38,7 @@ HEADER_NUM_BYTES = 8 SAFETENSORS_SUFFIX = ".safetensors" +_HEADER_CACHE = {} def _get_dtypes() -> dict[str, Any]: @@ -50,6 +57,8 @@ def _get_dtypes() -> dict[str, Any]: "F32": np.float32, "F64": np.float64, "BF16": jax.numpy.bfloat16, + "F8_E4M3": jax.numpy.float8_e4m3fn, + "F8_E5M2": jax.numpy.float8_e5m2, "F8_E8M0": "float8_e8m0fnu (specialized ML dtype)", "F4": "float4_e2m1fn_x2 (specialized ML dtype)", } @@ -63,8 +72,21 @@ async def _get_safetensors_file_list(path: Path) -> Sequence[Path]: return [path] +async def get_tensor_to_path_indexing(path): + """Returns a mapping from tensor name to safetensors file.""" + path_ = Path(str(path) + "/model.safetensors.index.json") + async with async_path.open_file(path_, mode="rb") as f: + raw_data = await f.read() + index_data = json.loads(raw_data) + return index_data["weight_map"] + + async def _read_safetensors_header(path: Path) -> tuple[dict[str, Any], int]: - """Reads a safetensors file header, returning the header and data start offset.""" + """Reads a safetensors file header, returning header and data start offset.""" + path_str = str(path) + if path_str in _HEADER_CACHE: + return _HEADER_CACHE[path_str] + async with async_path.open_file(path, mode="rb") as f: header_size_bytes = await f.read(HEADER_NUM_BYTES) if not header_size_bytes: @@ -77,6 +99,8 @@ async def _read_safetensors_header(path: Path) -> tuple[dict[str, Any], int]: header = json.loads(header_bytes) data_start_offset = HEADER_NUM_BYTES + header_size + + _HEADER_CACHE[path_str] = header, data_start_offset return header, data_start_offset @@ -133,19 +157,26 @@ async def _read_non_contiguous_slice( for i in range(len(stored_shape) - 2, -1, -1): global_strides[i] = global_strides[i + 1] * stored_shape[i + 1] - async def _read_slice_recursively(dim: int, base_offset: int) -> bytes: - # TODO(b/438763866) - @zachmeyers to consider alternative methods. - s = idx[dim] # The slice for the current dimension. + # Pre-calculate which dimensions are fully selected. + is_full_dim = [False] * len(stored_shape) + for i, s in enumerate(idx): + if s.start == 0 and s.stop == stored_shape[i] and s.step == 1: + is_full_dim[i] = True - # If we are at the last dimension, the data is contiguous. - if dim == len(stored_shape) - 1: + async def _read_slice_recursively(dim: int, base_offset: int) -> bytes: + # If all remaining dimensions are fully selected, we can read the entire + # contiguous block for the current dimension's slice. + if dim == len(stored_shape) - 1 or all(is_full_dim[dim + 1 :]): + s = idx[dim] start = base_offset + s.start * global_strides[dim] - num_bytes = (s.stop - s.start) * itemsize + num_bytes = (s.stop - s.start) * global_strides[dim] + await f.seek(tensor_file_offset + start) return cast(bytes, await f.read(num_bytes)) # For all other dimensions, iterate through the indices # of the slice and make a recursive call for the next dimension. + s = idx[dim] chunks = [] for i in range(s.start, s.stop): offset = base_offset + i * global_strides[dim] @@ -178,8 +209,208 @@ async def _load_safetensors_as_numpy(path: Path) -> dict[str, np.ndarray]: return tensors +def _process_bytes_to_jax( + tensor_bytes: bytes, + tensor_name: str, + abstract_leaf: Any, + header: dict[str, Any], +) -> jax.Array: + """Universal helper to safely parse bytes into sharded or host JAX arrays.""" + stored_shape, stored_dtype = _get_array_properties(header[tensor_name]) + sharding = abstract_leaf.sharding + target_shape = abstract_leaf.shape + target_dtype = abstract_leaf.dtype + + np_array = np.frombuffer(tensor_bytes, dtype=stored_dtype).reshape( + stored_shape + ) + if np_array.dtype != target_dtype: + np_array = np_array.astype(target_dtype) + + # Fallback to Host RAM if no sharding is provided + if sharding is None: + arr = jax.device_put(np_array) + del np_array + return arr + + # Distributed Sharding + device_indices_map = sharding.addressable_devices_indices_map(target_shape) + device_map = list() + + for device in sharding.addressable_devices: + if device in device_indices_map: + idx = device_indices_map[device] + resolved_idx = numpy_utils.resolve_slice(idx, stored_shape) + shard_np = np_array[resolved_idx] + device_map.append(jax.device_put(shard_np.copy(), device)) + del shard_np + + del np_array + return jax.make_array_from_single_device_arrays( + target_shape, sharding, device_map + ) + + +async def _load_safetensors_on_device_local( + path: Path, abstract_pytree: dict[str, Any] +) -> dict[str, jax.Array]: + """Fast path for local NVMe/SSD files using zero-copy memory mapping.""" + header, data_start_offset = await _read_safetensors_header(path) + restored_tensors = {} + results_to_block = list() + + with open(path, "rb") as f: + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm: + for tensor_name, abstract_leaf in abstract_pytree.items(): + if tensor_name not in header: + continue + + start_offset, end_offset = header[tensor_name]["data_offsets"] + absolute_start = data_start_offset + start_offset + absolute_end = data_start_offset + end_offset + + # Wrap in memoryview to ensure zero-copy bridging to NumPy + tensor_bytes = memoryview(mm)[absolute_start:absolute_end] + + jax_array = _process_bytes_to_jax( + tensor_bytes, tensor_name, abstract_leaf, header + ) + restored_tensors[tensor_name] = jax_array + results_to_block.append(jax_array) + + del tensor_bytes + + # Ensure hardware has ingested the data before the mmap file lock is + # released + jax.block_until_ready(results_to_block) + + return restored_tensors + + +async def _load_safetensors_on_device_gcs( + path: Path, abstract_pytree: dict[str, Any] +) -> dict[str, jax.Array]: + """High-bandwidth parallel downloader for Google Cloud Storage.""" + header, _ = await _read_safetensors_header(path) + restored_tensors = {} + + path_str = str(path) + + if path_str.startswith("/big" + "store/"): + path_str = "gs://" + path_str[10:] + + if not path_str.startswith("gs://"): + raise ValueError(f"Unsupported remote path format: {path_str}") + + bucket_name, blob_name = path_str[5:].split("/", 1) + + min_start = float("inf") + max_end = 0 + tensors_to_load = {} + + # 1. Calculate the exact Bounding Box of the batch + for t_name, abstract_leaf in abstract_pytree.items(): + if t_name in header: + start, end = header[t_name]["data_offsets"] + if start < min_start: min_start = start + if end > max_end: max_end = end + tensors_to_load[t_name] = (abstract_leaf, start, end) + + if not tensors_to_load: + return restored_tensors + + span_size = max_end - min_start + + start_read_time = time.time() + + client = storage.Client() + blob = client.bucket(bucket_name).blob(blob_name) + + safe_temp_name = blob_name.replace("/", "_") + ram_disk_path = f"/dev/shm/{safe_temp_name}_temp.bin" + + # 3. Execute the highly parallel download + # Using 64 processes handles the massive 17GB file effortlessly + print(f"----------------------- [{path.name}] ----------------------------- ") + print( + f"Starting parallel download from GCS to RAM disk..." + f" {span_size / (1024 * 1024):.2f} MB in" + ) + transfer_manager.download_chunks_concurrently( + blob, + ram_disk_path, + chunk_size=128 * 1024 * 1024, # 256 MB chunks per connection + max_workers=128, # Half of your 128 vCPUs + worker_type="process", # Crucial: Uses processes to bypass the GIL + ) + + print( + f"Single-shot network read from GCS finished:" + f" {span_size / (1024 * 1024):.2f} MB in" + f" {time.time() - start_read_time:.2f}s Throughput:" + f" ({span_size / (time.time() - start_read_time) / (1024 * 1024):.2f}" + " MB/s)" + ) + + print("Loading data into Python memory...") + start_load_time = time.time() + with open(ram_disk_path, "rb") as f: + span_bytes = f.read() + + # 5. Clean up the RAM disk + os.remove(ram_disk_path) + print( + f"Python Memory Load Time:" + f" {time.time() - start_load_time:.2f} seconds." + ) + + # 6. Load into JAX arrays + print("Loading data into JAX arrays...") + start_jax_load_time = time.time() + for t_name, (abstract_leaf, start, end) in tensors_to_load.items(): + rel_start = start - min_start + rel_end = end - min_start + + tensor_bytes = span_bytes[rel_start:rel_end] + restored_tensors[t_name] = _process_bytes_to_jax( + tensor_bytes, t_name, abstract_leaf, header + ) + del tensor_bytes + print( + f"JAX Load Time:" + f" {time.time() - start_jax_load_time:.2f} seconds." + ) + print( + f"Total (Network + Python + JAX) Time:" + f" {time.time() - start_read_time:.2f} seconds." + ) + print("----------------------------------------------------------------- ") + # 5. Destroy the large buffer immediately after the GPUs take over + del span_bytes + + return restored_tensors + + async def _load_safetensors_on_device( path: Path, abstract_pytree: dict[str, Any] +) -> dict[str, jax.Array]: + """Intelligent Router to load Safetensors based on storage topology.""" + is_local = False + try: + with open(path, "rb"): + is_local = True + except (FileNotFoundError, OSError): + pass + + if is_local: + result = await _load_safetensors_on_device_local(path, abstract_pytree) + else: + result = await _load_safetensors_on_device_gcs(path, abstract_pytree) + return result + + +async def _load_safetensors_on_device_old( + path: Path, abstract_pytree: dict[str, Any] ) -> dict[str, jax.Array]: """Loads tensors from a safetensors file into on-device JAX arrays.""" header, data_start_offset = await _read_safetensors_header(path) @@ -275,17 +506,29 @@ async def _load_safetensors( # has a `weight_map` that maps tensor names to file paths. From the # abstract tree, we can look up only the keys that are actually needed to # load using the index.json. + + start_time = time.time() tensor_to_path = {} - for path in paths: - header, _ = await _read_safetensors_header(path) - for name in header: - if name == "__metadata__": - continue - if name in tensor_to_path: - raise ValueError(f"Duplicate tensor {name} found in multiple files.") - tensor_to_path[name] = path + file_to_path = {} + + for file_ in paths: + file_to_path[str(Path(file_).name)] = file_ + + indexing_results = await get_tensor_to_path_indexing(paths[0].parent) + + for name, path in indexing_results.items(): + if name in tensor_to_path: + raise ValueError(f"Duplicate tensor {name} found in multiple files.") + tensor_to_path[name] = file_to_path[str(path)] + + end_time = time.time() + print( + f"----- Mapping tensor names to files took" + f" {end_time - start_time:.2f} seconds." + ) # 2. Split abstract_pytree by file + all_files = [] file_abstract_trees = collections.defaultdict( dict ) # Path -> dict[str, abstract_leaf] @@ -297,14 +540,23 @@ async def _load_safetensors( ) path = tensor_to_path[tensor_name] + all_files.append(path.name) file_abstract_trees[path][tensor_name] = abstract_leaf + unique_files = set(all_files) + print(f"Unique files: {len(unique_files)}, files: {unique_files}") # 3. Load from each file restored_pytree = {} for path, sub_tree in file_abstract_trees.items(): sub_restored = await _load_safetensors_on_device(path, sub_tree) restored_pytree.update(sub_restored) + end_time = time.time() + print( + f"----- [{paths[0].name}] Loading {len(file_abstract_trees)} files took" + f" {end_time - start_time:.2f} seconds." + ) + return restored_pytree @@ -322,6 +574,54 @@ class SafetensorsLayout(CheckpointLayout): def __init__(self): pass + async def create_loading_plan( + self, + path: Path, + max_batch_size_gb: float, + ) -> Tuple[List[List[Tuple[str, ...]]], List[float]]: + """Saves the checkpoint to the given directory.""" + files = await _get_safetensors_file_list(path) + + async def _fetch_header(p): + h, _ = await _read_safetensors_header(p) + return p, h + + all_headers = await asyncio.gather(*[_fetch_header(p) for p in files]) + batches = [] + batches_size = [] + current_batch = [] + current_batch_size = 0 + max_bytes = max_batch_size_gb * (1024**3) + for _, header in all_headers: + current_file_size = 0 + current_file_tensors = [] + for tensor_name, leaf_meta in header.items(): + if tensor_name == "__metadata__": + continue + else: + if "shape" in leaf_meta and "dtype" in leaf_meta: + shape, dtype = _get_array_properties(leaf_meta) + dtype_size = np.dtype(dtype).itemsize + size = np.prod(shape) * dtype_size + else: + size = 0 + current_file_tensors.append(tensor_name) + current_file_size += size + if current_batch_size + current_file_size > max_bytes: + print(f"Batch size: {current_batch_size / (1024 * 1024):.2f} MB") + batches.append(current_batch) + batches_size.append(current_batch_size) + current_batch = current_file_tensors + current_batch_size = current_file_size + else: + current_batch.extend(current_file_tensors) + current_batch_size += current_file_size + if current_batch: + print(f"Batch size: {current_batch_size / (1024 * 1024):.2f} MB") + batches.append(current_batch) + batches_size.append(current_batch_size) + return batches, batches_size + async def metadata( self, path: Path ) -> metadata_types.CheckpointMetadata[dict[str, Any]]: @@ -332,9 +632,13 @@ async def metadata( # Track the latest commit timestamp. commit_timestamp_nsecs = None + async def _fetch_header(p): + h, _ = await _read_safetensors_header(p) + return p, h + + header_results = await asyncio.gather(*[_fetch_header(p) for p in files]) - for path in files: - header, _ = await _read_safetensors_header(path) + for path, header in header_results: stat = await async_path.async_stat(path) ts = int(stat.mtime) if commit_timestamp_nsecs is None or ts > commit_timestamp_nsecs: @@ -408,7 +712,8 @@ async def load_pytree( """ del checkpointable_name files = await _get_safetensors_file_list(path) - return _load_safetensors(files, abstract_pytree) + result = await _load_safetensors(files, abstract_pytree) + return result async def save( self,