diff --git a/export/orbax/export/data_processors/tf_data_processor_test.py b/export/orbax/export/data_processors/tf_data_processor_test.py index ef62da777..2bfbde109 100644 --- a/export/orbax/export/data_processors/tf_data_processor_test.py +++ b/export/orbax/export/data_processors/tf_data_processor_test.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for TfDataProcessor.""" - +import os import orbax.experimental.model.core as obm from orbax.export.data_processors import tf_data_processor import tensorflow as tf @@ -80,6 +79,7 @@ def test_prepare_succeeds(self): self.assertIsNotNone(processor.concrete_function) self.assertIsNotNone(processor.obm_function) + self.assertEqual( processor.input_signature[0][0], obm.ShloTensorSpec(shape=(None, 3), dtype=obm.ShloDType.f64, name='x'), @@ -158,17 +158,40 @@ def test_suppress_x64_output(self): def test_convert_to_bfloat16(self): v = tf.Variable(0.5, dtype=tf.float32) - def func(x): - return v + x + def func(x, y, z): + return v + x + y + tf.cast(z, tf.float32) processor = tf_data_processor.TfDataProcessor(func, name='preprocessor') + converter_options = converter_options_v2_pb2.ConverterOptionsV2( + bfloat16_optimization_options=converter_options_v2_pb2.BFloat16OptimizationOptions( + scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL, + skip_safety_checks=True, + ) + ) + processor.prepare( - available_tensor_specs=(tf.TensorSpec(shape=(2, 3), dtype=tf.float32)), - bfloat16_options=converter_options_v2_pb2.ConverterOptionsV2( - bfloat16_optimization_options=converter_options_v2_pb2.BFloat16OptimizationOptions( - scope=converter_options_v2_pb2.BFloat16OptimizationOptions.ALL, - skip_safety_checks=True, - ) + available_tensor_specs=(tf.TensorSpec(shape=(2, 3), dtype=tf.float32), + tf.TensorSpec(shape=(2, 3), dtype=tf.float32), + tf.constant(2, dtype=tf.int32)), + bfloat16_options=converter_options, + tf_trackable_resources=[v], + ) + + self.assertEqual( + processor.input_signature, + ( + ( + obm.ShloTensorSpec( + shape=(2, 3), dtype=obm.ShloDType.bf16, name='x' + ), + obm.ShloTensorSpec( + shape=(2, 3), dtype=obm.ShloDType.bf16, name='y' + ), + obm.ShloTensorSpec( + shape=(), dtype=obm.ShloDType.i32, name='z' + ), + ), + {}, ), ) self.assertEqual( @@ -177,11 +200,22 @@ def func(x): shape=(2, 3), dtype=obm.ShloDType.bf16, name='output_0' ), ) - self.assertLen(processor.concrete_function.variables, 1) - self.assertEqual( - processor.concrete_function.variables[0].dtype, tf.bfloat16 + + # Verify that the variables have been converted to bfloat16 too. + model_dir = self.create_tempdir().full_path + tf2obm.save_tf_functions( + model_dir, + {'preprocessor': processor.concrete_function}, + trackable_resources=[v], + converter_options=converter_options, ) + saved_model = tf.saved_model.load(os.path.join(model_dir, 'tf_saved_model')) + restored_fn = saved_model.signatures['preprocessor'] + + self.assertLen(restored_fn.variables, 1) + self.assertEqual(restored_fn.variables[0].dtype, tf.bfloat16) + def test_bfloat16_convert_error(self): processor = tf_data_processor.TfDataProcessor( lambda x: 0.5 + x, name='preprocessor' diff --git a/export/orbax/export/export_testing_utils.py b/export/orbax/export/export_testing_utils.py index f71e1911a..19ba04832 100644 --- a/export/orbax/export/export_testing_utils.py +++ b/export/orbax/export/export_testing_utils.py @@ -13,8 +13,10 @@ # limitations under the License. """Testing utils for orbax.export.""" + import os from typing import cast + import jax from jax import sharding from jax.experimental import mesh_utils @@ -27,4 +29,5 @@ from orbax.export import serving_config as osc import tensorflow as tf + os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' diff --git a/export/orbax/export/obm_export.py b/export/orbax/export/obm_export.py index b04fc2d6a..0913e82db 100644 --- a/export/orbax/export/obm_export.py +++ b/export/orbax/export/obm_export.py @@ -17,7 +17,6 @@ from collections.abc import Callable, Mapping, Sequence import copy import dataclasses -import functools import itertools import os from typing import Any, cast @@ -42,6 +41,7 @@ _obm_export_config = config.config + class ObmExport(export_base.ExportBase): """Defines the save and load methods for exporting a model using Orbax Model export.""" diff --git a/export/orbax/export/obm_export_test.py b/export/orbax/export/obm_export_test.py index b173dfdeb..ec28dc7d1 100644 --- a/export/orbax/export/obm_export_test.py +++ b/export/orbax/export/obm_export_test.py @@ -14,7 +14,6 @@ from collections.abc import Mapping, Sequence import contextlib -import importlib import os import pathlib from typing import Any, Callable diff --git a/model/orbax/experimental/model/tf2obm/_src/converter.py b/model/orbax/experimental/model/tf2obm/_src/converter.py index 6b1c83eb2..7f8faf5ac 100644 --- a/model/orbax/experimental/model/tf2obm/_src/converter.py +++ b/model/orbax/experimental/model/tf2obm/_src/converter.py @@ -15,7 +15,9 @@ """Converts TF concrete functions to OBM functions (allowing TF resources).""" from collections.abc import Mapping, Sequence +import copy import os +import tempfile from typing import Any, Dict, NamedTuple, Tuple from jax import tree_util as jax_tree_util @@ -26,6 +28,8 @@ from .learning.brain.contrib.tpu_modeling.inference_converter_v2 import converter_options_v2_pb2 from .learning.brain.contrib.tpu_modeling.inference_converter_v2.python import converter +from tensorflow.core.protobuf import meta_graph_pb2 # pylint: disable=g-direct-tensorflow-import +from tensorflow.core.protobuf import saved_model_pb2 # pylint: disable=g-direct-tensorflow-import TF_CONCRETE_FUNCTION_HANDLE_MIME_TYPE = ( 'application/protobuf;' @@ -52,6 +56,10 @@ def _is_args_kwargs_pattern(tree: utils.TfSignature) -> bool: def convert_function( fn_name: str, fn: tf.types.experimental.ConcreteFunction, + converter_options: ( + converter_options_v2_pb2.ConverterOptionsV2 | None + ) = None, + trackable_resources: Any | None = None, ) -> obm.SerializableFunction: """Converts the TF concrete function to an OBM function. @@ -62,16 +70,36 @@ def convert_function( fn_name: The name to be used in the OBM manifest to refer to the TF function. fn: The TF concrete function. + converter_options: The converter options to use for the TF SavedModel. If + set, the TF SavedModel will be converted using Inference Converter V2 in + order to get the correct types for the input and output signatures. + trackable_resources: Trackable resources used by the function. Returns: The OBM function referring to the original TF function in the TF SavedModel. """ - input_signature = fn.structured_input_signature - output_signature = get_output_signature(fn) input_names, _, _ = _flat_input_signature(fn) output_names = _output_names(fn) + if converter_options is not None: + converterted_signature_def = _get_converted_function_signature_def( + fn_name, fn, trackable_resources, converter_options + ) + input_signature = _copy_types_from_signature_def( + fn.structured_input_signature, + converterted_signature_def.inputs, + input_names, + ) + output_signature = _copy_types_from_signature_def( + get_output_signature(fn), + converterted_signature_def.outputs, + output_names, + ) + else: + input_signature = fn.structured_input_signature + output_signature = get_output_signature(fn) + unstructured_data = obm.manifest_pb2.UnstructuredData( inlined_bytes=tf_concrete_function_handle_pb2.TfConcreteFunctionHandle( fn_name=fn_name, @@ -406,12 +434,23 @@ def save_tf_functions( target_path = os.path.join(model_dir, tf_saved_model_sub_dir) if converter_options is not None: + # Inference Converter V2 modifies the converter_options in place, so we + # need to deepcopy it to avoid modifying the original options and keep + # them re-usable. + converter_options_copy = copy.deepcopy(converter_options) pre_conversion_path = os.path.join(model_dir, 'tmp_tf_saved_model') - tf.saved_model.save(tf_module, pre_conversion_path, signatures=wrapped_fns) + tf.saved_model.save( + tf_module, + pre_conversion_path, + signatures=wrapped_fns, + # Function aliases are used by the Inference Converter V2 to + # identify XLA functions. + options=tf.saved_model.SaveOptions(function_aliases=wrapped_fns), + ) converter.ConvertSavedModel( pre_conversion_path, target_path, - converter_options, + converter_options_copy, ) tf.io.gfile.rmtree(pre_conversion_path) else: @@ -422,3 +461,89 @@ def save_tf_functions( tf_saved_model_as_obm_supplemental(tf_saved_model_sub_dir) ) } + + +def _copy_types_from_signature_def( + original_signature: Any, + signature_def_args: Mapping[str, meta_graph_pb2.TensorInfo], + arg_names: Sequence[str], +) -> Any: + """Copies types from TF SignatureDef to the original signature. + + Args: + original_signature: The original signature that needs new types. + signature_def_args: The TF SignatureDef arguments to copy types from. + arg_names: The argument names of the original TF function. They are used to + infer the input order in the original signature. + + Returns: + The original signature with types copied from the signature_def for the + corresponding input names. + + Raises: + ValueError: If any of the argument names is not found in the SignatureDef. + """ + + arg_names_iter = iter(arg_names) + + def _copy_type(t: Any) -> Any: + arg_name = next(arg_names_iter) + if arg_name not in signature_def_args: + raise ValueError( + f'Argument name {arg_name!r} not found in SignatureDef: ' + f'{signature_def_args.keys()!r}' + ) + + if not isinstance(t, tf.TensorSpec): + return t + + return tf.TensorSpec( + shape=t.shape, + dtype=tf.as_dtype(signature_def_args[arg_name].dtype), + name=arg_name, + ) + + return jax_tree_util.tree_map( + _copy_type, + original_signature, + ) + + +def _get_converted_function_signature_def( + fn_name: str, + fn: tf.types.experimental.ConcreteFunction, + trackable_resources: Any, + converter_options: converter_options_v2_pb2.ConverterOptionsV2, +) -> meta_graph_pb2.SignatureDef: + """Saves the function, converts it, returns its SignatureDef. + + Args: + fn_name: The name of the function in the SavedModel. + fn: The concrete function to save. + trackable_resources: The trackable resources to save. + converter_options: The converter options to use for the TF SavedModel. + + Returns: + The SignatureDef of the converted function. + """ + + opts_copy = copy.deepcopy(converter_options) + # There is no need to convert the checkpoint in this case, since we are only + # interested in the signature. + opts_copy.bfloat16_optimization_options.experimental.convert_checkpoint = ( + False + ) + + with tempfile.TemporaryDirectory() as temp_dir: + save_tf_functions( + temp_dir, + {fn_name: fn}, + trackable_resources=trackable_resources, + converter_options=opts_copy, + ) + + converted_model_path = os.path.join(temp_dir, OBM_TF_SAVED_MODEL_SUB_DIR) + with open(os.path.join(converted_model_path, 'saved_model.pb'), 'rb') as f: + saved_model_proto = saved_model_pb2.SavedModel.FromString(f.read()) + + return saved_model_proto.meta_graphs[0].signature_def[fn_name]