diff --git a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py index 973becd60..90ecb9bec 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py @@ -44,6 +44,7 @@ from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib from orbax.checkpoint._src.metadata import empty_values from orbax.checkpoint._src.metadata import tree as tree_metadata +from orbax.checkpoint._src.path import types as path_types from orbax.checkpoint._src.serialization import limits from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import type_handler_registry as handler_registry @@ -470,7 +471,9 @@ def _concurrent_bytes( return concurrent_gb * 10**9 -class PyTreeCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler): +class PyTreeCheckpointHandler( + async_checkpoint_handler.DeferredPathAsyncCheckpointHandler +): """A CheckpointHandler implementation for any PyTree structure. See JAX documentation for more information on what consistutes a "PyTree". @@ -608,7 +611,7 @@ def __init__( async def async_save( self, - directory: epath.Path, + directory: epath.Path | path_types.PathAwaitingCreation, item: Optional[PyTree] = None, save_args: Optional[PyTreeSaveArgs] = None, args: Optional[PyTreeSaveArgs] = None, diff --git a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py index 7599e4f24..6e915c9fb 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test.py @@ -21,6 +21,7 @@ import functools import json import re +import threading from typing import Any, Iterator, List, NamedTuple, Optional, Sequence import unittest from unittest import mock @@ -54,6 +55,7 @@ from orbax.checkpoint._src.metadata import tree as tree_metadata from orbax.checkpoint._src.metadata import value as value_metadata from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.path import atomicity from orbax.checkpoint._src.serialization import limits from orbax.checkpoint._src.serialization import replica_slices from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils @@ -2948,6 +2950,56 @@ def test_partial_restore_with_omission_unexpected_keys( ) test_utils.assert_tree_equal(self, expected, restored) + async def test_save_with_deferred_path(self): + """Tests that async_save works with deferred paths.""" + deferred_path = atomicity.DeferredPath() + save_dir = self.directory / 'deferred_path_ckpt' + await_creation_called = False + original_await = atomicity.DeferredPath.await_creation + set_path_lock = threading.Lock() + + async def mock_await_creation(dp_self): + """Sets the path only once await_creation is called. + + This ensures the path is not resolved before the handler awaits it, fully + exercising the deferred path resolution contract. + + Args: + dp_self: The DeferredPath instance. + + Returns: + The result of the original await_creation method. + """ + nonlocal await_creation_called + with set_path_lock: + if not dp_self._future_path.done(): + save_dir.mkdir(parents=True, exist_ok=True) + dp_self.set_path(save_dir) + await_creation_called = True + return await original_await(dp_self) + + with self.ocdbt_checkpoint_handler(use_ocdbt=False) as handler: + with mock.patch.object( + atomicity.DeferredPath, + 'await_creation', + mock_await_creation, + ): + commit_futures = await handler.async_save( + deferred_path, + args=PyTreeSaveArgs(self.pytree), + ) + if commit_futures: + for f in commit_futures: + f.result() + + self.assertTrue(await_creation_called) + self.validate_save( + save_dir, + self.pytree, + handler, + restore_args=self.restore_args, + ) + if __name__ == '__main__': multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/_src/path/atomicity.py b/checkpoint/orbax/checkpoint/_src/path/atomicity.py index 7ccd1091b..7e3c8c061 100644 --- a/checkpoint/orbax/checkpoint/_src/path/atomicity.py +++ b/checkpoint/orbax/checkpoint/_src/path/atomicity.py @@ -54,6 +54,7 @@ import abc import asyncio +import concurrent.futures import pickle import threading import time @@ -228,6 +229,42 @@ def get_awaitable_path(self) -> path_types.PathAwaitingCreation: ... +class DeferredPath(path_types.PathAwaitingCreation): + """A path that is created asynchronously and can be awaited. + + Uses concurrent.futures.Future instead of asyncio.Task to avoid + event loop binding issues when create() runs in a different thread. + The Future is thread-safe and can be awaited from any event loop. + """ + + def __init__(self): + self._future_path: concurrent.futures.Future[epath.Path] = ( + concurrent.futures.Future() + ) + + def set_path(self, path: epath.Path) -> None: + """Sets the path result. Called by create() when allocation completes.""" + self._future_path.set_result(path) + + def __truediv__( + self, other: path_types.PathLike + ) -> path_types.PathAwaitingCreation: + child = DeferredPath() + self._future_path.add_done_callback( + lambda f: child.set_path(f.result() / other) + ) + return child + + @property + def path(self) -> epath.Path: + if not self._future_path.done(): + raise ValueError('Path has not been created yet. Call await_creation().') + return self._future_path.result() + + async def await_creation(self) -> epath.Path: + return await asyncio.wrap_future(self._future_path) + + class ReadOnlyTemporaryPath(atomicity_types.TemporaryPath): """A read-only, serializable object providing path properties access. diff --git a/checkpoint/orbax/checkpoint/_src/path/atomicity_test.py b/checkpoint/orbax/checkpoint/_src/path/atomicity_test.py index 6ff8e4b49..87b4c346a 100644 --- a/checkpoint/orbax/checkpoint/_src/path/atomicity_test.py +++ b/checkpoint/orbax/checkpoint/_src/path/atomicity_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import concurrent.futures import stat import unittest from absl.testing import absltest @@ -205,6 +207,33 @@ async def test_finalize_raises(self): ) +class DeferredPathTest(absltest.TestCase): + + def test_set_and_get_path(self): + dp = atomicity.DeferredPath() + test_path = epath.Path('/test/path') + dp.set_path(test_path) + self.assertEqual(dp.path, test_path) + + def test_path_before_set_raises(self): + dp = atomicity.DeferredPath() + with self.assertRaises(ValueError): + _ = dp.path + + def test_await_creation(self): + dp = atomicity.DeferredPath() + test_path = epath.Path('/test/path') + dp.set_path(test_path) + result = asyncio.run(dp.await_creation()) + self.assertEqual(result, test_path) + + def test_set_path_twice_raises(self): + dp = atomicity.DeferredPath() + dp.set_path(epath.Path('/first')) + with self.assertRaises(concurrent.futures.InvalidStateError): + dp.set_path(epath.Path('/second')) + + if __name__ == '__main__': absltest.main() diff --git a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py index daa5697c1..7b75807ee 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py @@ -516,7 +516,10 @@ def _serialize_batch( ' scheduled asynchronously.' ) + all_infos = infos async def _serialize(): + for info in all_infos: + await info.await_path_creation() if prioritized: arrays, infos, args = zip(*prioritized) _serialize_batch(infos, args, arrays)