Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions export/orbax/export/data_processors/tf_data_processor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for TfDataProcessor."""

import os
import orbax.experimental.model.core as obm
from orbax.export.data_processors import tf_data_processor
import tensorflow as tf
Expand Down Expand Up @@ -80,6 +79,7 @@ def test_prepare_succeeds(self):

self.assertIsNotNone(processor.concrete_function)
self.assertIsNotNone(processor.obm_function)

self.assertEqual(
processor.input_signature[0][0],
obm.ShloTensorSpec(shape=(None, 3), dtype=obm.ShloDType.f64, name='x'),
Expand Down Expand Up @@ -158,17 +158,40 @@ def test_suppress_x64_output(self):
def test_convert_to_bfloat16(self):
v = tf.Variable(0.5, dtype=tf.float32)

def func(x):
return v + x
def func(x, y, z):
return v + x + y + tf.cast(z, tf.float32)

processor = tf_data_processor.TfDataProcessor(func, name='preprocessor')
converter_options = converter_options_v2_pb2.ConverterOptionsV2(
bfloat16_optimization_options=converter_options_v2_pb2.BFloat16OptimizationOptions(
scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL,
skip_safety_checks=True,
)
)

processor.prepare(
available_tensor_specs=(tf.TensorSpec(shape=(2, 3), dtype=tf.float32)),
bfloat16_options=converter_options_v2_pb2.ConverterOptionsV2(
bfloat16_optimization_options=converter_options_v2_pb2.BFloat16OptimizationOptions(
scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL,
skip_safety_checks=True,
)
available_tensor_specs=(tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
tf.TensorSpec(shape=(2, 3), dtype=tf.float32),
tf.constant(2, dtype=tf.int32)),
bfloat16_options=converter_options,
tf_trackable_resources=[v],
)

self.assertEqual(
processor.input_signature,
(
(
obm.ShloTensorSpec(
shape=(2, 3), dtype=obm.ShloDType.bf16, name='x'
),
obm.ShloTensorSpec(
shape=(2, 3), dtype=obm.ShloDType.bf16, name='y'
),
obm.ShloTensorSpec(
shape=(), dtype=obm.ShloDType.i32, name='z'
),
),
{},
),
)
self.assertEqual(
Expand All @@ -177,11 +200,22 @@ def func(x):
shape=(2, 3), dtype=obm.ShloDType.bf16, name='output_0'
),
)
self.assertLen(processor.concrete_function.variables, 1)
self.assertEqual(
processor.concrete_function.variables[0].dtype, tf.bfloat16

# Verify that the variables have been converted to bfloat16 too.
model_dir = self.create_tempdir().full_path
tf2obm.save_tf_functions(
model_dir,
{'preprocessor': processor.concrete_function},
trackable_resources=[v],
converter_options=converter_options,
)

saved_model = tf.saved_model.load(os.path.join(model_dir, 'tf_saved_model'))
restored_fn = saved_model.signatures['preprocessor']

self.assertLen(restored_fn.variables, 1)
self.assertEqual(restored_fn.variables[0].dtype, tf.bfloat16)

def test_bfloat16_convert_error(self):
processor = tf_data_processor.TfDataProcessor(
lambda x: 0.5 + x, name='preprocessor'
Expand Down
3 changes: 3 additions & 0 deletions export/orbax/export/export_testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

"""Testing utils for orbax.export."""

import os
from typing import cast

import jax
from jax import sharding
from jax.experimental import mesh_utils
Expand All @@ -27,4 +29,5 @@
from orbax.export import serving_config as osc
import tensorflow as tf


os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
2 changes: 1 addition & 1 deletion export/orbax/export/obm_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from collections.abc import Callable, Mapping, Sequence
import copy
import dataclasses
import functools
import itertools
import os
from typing import Any, cast
Expand All @@ -42,6 +41,7 @@

_obm_export_config = config.config


class ObmExport(export_base.ExportBase):
"""Defines the save and load methods for exporting a model using Orbax Model export."""

Expand Down
1 change: 0 additions & 1 deletion export/orbax/export/obm_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from collections.abc import Mapping, Sequence
import contextlib
import importlib
import os
import pathlib
from typing import Any, Callable
Expand Down
133 changes: 129 additions & 4 deletions model/orbax/experimental/model/tf2obm/_src/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""Converts TF concrete functions to OBM functions (allowing TF resources)."""

from collections.abc import Mapping, Sequence
import copy
import os
import tempfile
from typing import Any, Dict, NamedTuple, Tuple

from jax import tree_util as jax_tree_util
Expand All @@ -26,6 +28,8 @@

from .learning.brain.contrib.tpu_modeling.inference_converter_v2 import converter_options_v2_pb2
from .learning.brain.contrib.tpu_modeling.inference_converter_v2.python import converter
from tensorflow.core.protobuf import meta_graph_pb2 # pylint: disable=g-direct-tensorflow-import
from tensorflow.core.protobuf import saved_model_pb2 # pylint: disable=g-direct-tensorflow-import

TF_CONCRETE_FUNCTION_HANDLE_MIME_TYPE = (
'application/protobuf;'
Expand All @@ -52,6 +56,10 @@ def _is_args_kwargs_pattern(tree: utils.TfSignature) -> bool:
def convert_function(
fn_name: str,
fn: tf.types.experimental.ConcreteFunction,
converter_options: (
converter_options_v2_pb2.ConverterOptionsV2 | None
) = None,
trackable_resources: Any | None = None,
) -> obm.SerializableFunction:
"""Converts the TF concrete function to an OBM function.

Expand All @@ -62,16 +70,36 @@ def convert_function(
fn_name: The name to be used in the OBM manifest to refer to the TF
function.
fn: The TF concrete function.
converter_options: The converter options to use for the TF SavedModel. If
set, the TF SavedModel will be converted using Inference Converter V2 in
order to get the correct types for the input and output signatures.
trackable_resources: Trackable resources used by the function.

Returns:
The OBM function referring to the original TF function in the TF SavedModel.
"""
input_signature = fn.structured_input_signature
output_signature = get_output_signature(fn)

input_names, _, _ = _flat_input_signature(fn)
output_names = _output_names(fn)

if converter_options is not None:
converterted_signature_def = _get_converted_function_signature_def(
fn_name, fn, trackable_resources, converter_options
)
input_signature = _copy_types_from_signature_def(
fn.structured_input_signature,
converterted_signature_def.inputs,
input_names,
)
output_signature = _copy_types_from_signature_def(
get_output_signature(fn),
converterted_signature_def.outputs,
output_names,
)
else:
input_signature = fn.structured_input_signature
output_signature = get_output_signature(fn)

unstructured_data = obm.manifest_pb2.UnstructuredData(
inlined_bytes=tf_concrete_function_handle_pb2.TfConcreteFunctionHandle(
fn_name=fn_name,
Expand Down Expand Up @@ -406,12 +434,23 @@ def save_tf_functions(

target_path = os.path.join(model_dir, tf_saved_model_sub_dir)
if converter_options is not None:
# Inference Converter V2 modifies the converter_options in place, so we
# need to deepcopy it to avoid modifying the original options and keep
# them re-usable.
converter_options_copy = copy.deepcopy(converter_options)
pre_conversion_path = os.path.join(model_dir, 'tmp_tf_saved_model')
tf.saved_model.save(tf_module, pre_conversion_path, signatures=wrapped_fns)
tf.saved_model.save(
tf_module,
pre_conversion_path,
signatures=wrapped_fns,
# Function aliases are used by the Inference Converter V2 to
# identify XLA functions.
options=tf.saved_model.SaveOptions(function_aliases=wrapped_fns),
)
converter.ConvertSavedModel(
pre_conversion_path,
target_path,
converter_options,
converter_options_copy,
)
tf.io.gfile.rmtree(pre_conversion_path)
else:
Expand All @@ -422,3 +461,89 @@ def save_tf_functions(
tf_saved_model_as_obm_supplemental(tf_saved_model_sub_dir)
)
}


def _copy_types_from_signature_def(
    original_signature: Any,
    signature_def_args: Mapping[str, meta_graph_pb2.TensorInfo],
    arg_names: Sequence[str],
) -> Any:
  """Copies types from TF SignatureDef to the original signature.

  Args:
    original_signature: The original signature that needs new types.
    signature_def_args: The TF SignatureDef arguments to copy types from.
    arg_names: The argument names of the original TF function. They are used to
      infer the input order in the original signature.

  Returns:
    The original signature with types copied from the signature_def for the
    corresponding input names.

  Raises:
    ValueError: If any of the argument names is not found in the SignatureDef.
  """

  names_left = iter(arg_names)

  def _retype_leaf(leaf: Any) -> Any:
    # One name is consumed per leaf so that names stay aligned with the
    # flattened order of the signature tree, even for non-tensor leaves.
    arg_name = next(names_left)
    if arg_name not in signature_def_args:
      raise ValueError(
          f'Argument name {arg_name!r} not found in SignatureDef: '
          f'{signature_def_args.keys()!r}'
      )

    # Only tensor specs carry a dtype to overwrite; other leaves pass through
    # unchanged.
    if not isinstance(leaf, tf.TensorSpec):
      return leaf

    converted_dtype = tf.as_dtype(signature_def_args[arg_name].dtype)
    return tf.TensorSpec(shape=leaf.shape, dtype=converted_dtype, name=arg_name)

  return jax_tree_util.tree_map(_retype_leaf, original_signature)


def _get_converted_function_signature_def(
    fn_name: str,
    fn: tf.types.experimental.ConcreteFunction,
    trackable_resources: Any,
    converter_options: converter_options_v2_pb2.ConverterOptionsV2,
) -> meta_graph_pb2.SignatureDef:
  """Saves the function, converts it, returns its SignatureDef.

  Args:
    fn_name: The name of the function in the SavedModel.
    fn: The concrete function to save.
    trackable_resources: The trackable resources to save.
    converter_options: The converter options to use for the TF SavedModel.

  Returns:
    The SignatureDef of the converted function.
  """

  options = copy.deepcopy(converter_options)
  # Only the converted signature is needed here, so the checkpoint conversion
  # step can be skipped.
  options.bfloat16_optimization_options.experimental.convert_checkpoint = False

  with tempfile.TemporaryDirectory() as temp_dir:
    save_tf_functions(
        temp_dir,
        {fn_name: fn},
        trackable_resources=trackable_resources,
        converter_options=options,
    )

    saved_model_pb_path = os.path.join(
        temp_dir, OBM_TF_SAVED_MODEL_SUB_DIR, 'saved_model.pb'
    )
    with open(saved_model_pb_path, 'rb') as f:
      saved_model_proto = saved_model_pb2.SavedModel.FromString(f.read())

  return saved_model_proto.meta_graphs[0].signature_def[fn_name]
Loading