Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.path import types as path_types
from orbax.checkpoint._src.serialization import limits
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handler_registry as handler_registry
Expand Down Expand Up @@ -470,7 +471,9 @@ def _concurrent_bytes(
return concurrent_gb * 10**9


class PyTreeCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler):
class PyTreeCheckpointHandler(
async_checkpoint_handler.DeferredPathAsyncCheckpointHandler
):
"""A CheckpointHandler implementation for any PyTree structure.

See JAX documentation for more information on what consistutes a "PyTree".
Expand Down Expand Up @@ -608,7 +611,7 @@ def __init__(

async def async_save(
self,
directory: epath.Path,
directory: epath.Path | path_types.PathAwaitingCreation,
item: Optional[PyTree] = None,
save_args: Optional[PyTreeSaveArgs] = None,
args: Optional[PyTreeSaveArgs] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import functools
import json
import re
import threading
from typing import Any, Iterator, List, NamedTuple, Optional, Sequence
import unittest
from unittest import mock
Expand Down Expand Up @@ -54,6 +55,7 @@
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.metadata import value as value_metadata
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import atomicity
from orbax.checkpoint._src.serialization import limits
from orbax.checkpoint._src.serialization import replica_slices
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
Expand Down Expand Up @@ -2948,6 +2950,56 @@ def test_partial_restore_with_omission_unexpected_keys(
)
test_utils.assert_tree_equal(self, expected, restored)

async def test_save_with_deferred_path(self):
"""Tests that async_save works with deferred paths."""
deferred_path = atomicity.DeferredPath()
save_dir = self.directory / 'deferred_path_ckpt'
await_creation_called = False
original_await = atomicity.DeferredPath.await_creation
set_path_lock = threading.Lock()

async def mock_await_creation(dp_self):
"""Sets the path only once await_creation is called.

This ensures the path is not resolved before the handler awaits it, fully
exercising the deferred path resolution contract.

Args:
dp_self: The DeferredPath instance.

Returns:
The result of the original await_creation method.
"""
nonlocal await_creation_called
with set_path_lock:
if not dp_self._future_path.done():
save_dir.mkdir(parents=True, exist_ok=True)
dp_self.set_path(save_dir)
await_creation_called = True
return await original_await(dp_self)

with self.ocdbt_checkpoint_handler(use_ocdbt=False) as handler:
with mock.patch.object(
atomicity.DeferredPath,
'await_creation',
mock_await_creation,
):
commit_futures = await handler.async_save(
deferred_path,
args=PyTreeSaveArgs(self.pytree),
)
if commit_futures:
for f in commit_futures:
f.result()

self.assertTrue(await_creation_called)
self.validate_save(
save_dir,
self.pytree,
handler,
restore_args=self.restore_args,
)


if __name__ == '__main__':
multiprocess_test.main()
37 changes: 37 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/atomicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

import abc
import asyncio
import concurrent.futures
import pickle
import threading
import time
Expand Down Expand Up @@ -228,6 +229,42 @@
...


class DeferredPath(path_types.PathAwaitingCreation):
"""A path that is created asynchronously and can be awaited.

Uses concurrent.futures.Future instead of asyncio.Task to avoid
event loop binding issues when create() runs in a different thread.
The Future is thread-safe and can be awaited from any event loop.
"""

def __init__(self):
self._future_path: concurrent.futures.Future[epath.Path] = (
concurrent.futures.Future()
)

def set_path(self, path: epath.Path) -> None:
"""Sets the path result. Called by create() when allocation completes."""
self._future_path.set_result(path)

def __truediv__(
self, other: path_types.PathLike
) -> path_types.PathAwaitingCreation:
child = DeferredPath()
self._future_path.add_done_callback(
lambda f: child.set_path(f.result() / other)
)
return child

@property
def path(self) -> epath.Path:
if not self._future_path.done():
raise ValueError('Path has not been created yet. Call await_creation().')

Check failure on line 261 in checkpoint/orbax/checkpoint/_src/path/atomicity.py

View workflow job for this annotation

GitHub Actions / multiprocess-unit-tests (Python 3.11, jax=newest)

Path has not been created yet. Call await_creation().

Check failure on line 261 in checkpoint/orbax/checkpoint/_src/path/atomicity.py

View workflow job for this annotation

GitHub Actions / multiprocess-unit-tests (Python 3.11, jax=newest)

Path has not been created yet. Call await_creation().

Check failure on line 261 in checkpoint/orbax/checkpoint/_src/path/atomicity.py

View workflow job for this annotation

GitHub Actions / multiprocess-unit-tests (Python 3.11, jax=newest)

Path has not been created yet. Call await_creation().

Check failure on line 261 in checkpoint/orbax/checkpoint/_src/path/atomicity.py

View workflow job for this annotation

GitHub Actions / multiprocess-unit-tests (Python 3.11, jax=newest)

Path has not been created yet. Call await_creation().

Check failure on line 261 in checkpoint/orbax/checkpoint/_src/path/atomicity.py

View workflow job for this annotation

GitHub Actions / multiprocess-unit-tests (Python 3.11, jax=newest)

Path has not been created yet. Call await_creation().

Check failure on line 261 in checkpoint/orbax/checkpoint/_src/path/atomicity.py

View workflow job for this annotation

GitHub Actions / multiprocess-unit-tests (Python 3.11, jax=newest)

Path has not been created yet. Call await_creation().

Check failure on line 261 in checkpoint/orbax/checkpoint/_src/path/atomicity.py

View workflow job for this annotation

GitHub Actions / multiprocess-unit-tests (Python 3.11, jax=newest)

Path has not been created yet. Call await_creation().

Check failure on line 261 in checkpoint/orbax/checkpoint/_src/path/atomicity.py

View workflow job for this annotation

GitHub Actions / multiprocess-unit-tests (Python 3.11, jax=newest)

Path has not been created yet. Call await_creation().

Check failure on line 261 in checkpoint/orbax/checkpoint/_src/path/atomicity.py

View workflow job for this annotation

GitHub Actions / multiprocess-unit-tests (Python 3.11, jax=newest)

Path has not been created yet. Call await_creation().

Check failure on line 261 in checkpoint/orbax/checkpoint/_src/path/atomicity.py

View workflow job for this annotation

GitHub Actions / multiprocess-unit-tests (Python 3.11, jax=newest)

Path has not been created yet. Call await_creation().
return self._future_path.result()

async def await_creation(self) -> epath.Path:
return await asyncio.wrap_future(self._future_path)


class ReadOnlyTemporaryPath(atomicity_types.TemporaryPath):
"""A read-only, serializable object providing path properties access.

Expand Down
29 changes: 29 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/atomicity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import concurrent.futures
import stat
import unittest
from absl.testing import absltest
Expand Down Expand Up @@ -205,6 +207,33 @@ async def test_finalize_raises(self):
)


class DeferredPathTest(absltest.TestCase):

def test_set_and_get_path(self):
dp = atomicity.DeferredPath()
test_path = epath.Path('/test/path')
dp.set_path(test_path)
self.assertEqual(dp.path, test_path)

def test_path_before_set_raises(self):
dp = atomicity.DeferredPath()
with self.assertRaises(ValueError):
_ = dp.path

def test_await_creation(self):
dp = atomicity.DeferredPath()
test_path = epath.Path('/test/path')
dp.set_path(test_path)
result = asyncio.run(dp.await_creation())
self.assertEqual(result, test_path)

def test_set_path_twice_raises(self):
dp = atomicity.DeferredPath()
dp.set_path(epath.Path('/first'))
with self.assertRaises(concurrent.futures.InvalidStateError):
dp.set_path(epath.Path('/second'))



if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,10 @@ def _serialize_batch(
' scheduled asynchronously.'
)

all_infos = infos
async def _serialize():
for info in all_infos:
await info.await_path_creation()
if prioritized:
arrays, infos, args = zip(*prioritized)
_serialize_batch(infos, args, arrays)
Expand Down
Loading