4040from .prepselprep import PrepSelPrep
4141from .qubitization import Qubitization
4242
43+ is_jax_available = True
44+ is_optax_available = True
4345try :
44- from jax import config , lax
45- from jax import numpy as jnp
46- from jax import vmap
47- except ModuleNotFoundError : # pragma: no cover
48- pass
46+ import jax
47+ except ImportError :
48+ is_jax_available = False
49+
50+ try :
51+ import optax
52+ except ImportError :
53+ is_otpax_available = False
54+
55+ # try:
56+ # from jax import config, lax
57+ # from jax import numpy as jnp
58+ # from jax import vmap
59+ # except ModuleNotFoundError: # pragma: no cover
60+ # pass
4961
5062
5163def jit_if_jax_available (f , ** kwargs ):
@@ -848,7 +860,7 @@ def _cheby_pol(x, degree):
848860def _poly_func (coeffs , x ):
849861 """\sum c_kT_{k}(x) where T_k(x)=cos(karccos(x))"""
850862
851- return jnp . sum (coeffs @ vmap (_cheby_pol , in_axes = (None , 0 ))(x , np .arange (coeffs .shape [0 ])))
863+ return jax . numpy . sum (coeffs @ jax . vmap (_cheby_pol , in_axes = (None , 0 ))(x , np .arange (coeffs .shape [0 ])))
852864
853865
854866@partial (jit_if_jax_available , static_argnames = ["interface" ])
@@ -913,7 +925,7 @@ def _qsp_iterate_broadcast(phis, x, interface):
913925 """
914926
915927 # pylint: disable=import-outside-toplevel
916- qsp_iterate_list = vmap (_qsp_iterate , in_axes = (0 , None , None ))(phis [1 :], x , interface )
928+ qsp_iterate_list = jax . vmap (_qsp_iterate , in_axes = (0 , None , None ))(phis [1 :], x , interface )
917929
918930 matrix_iterate = reduce (math .dot , qsp_iterate_list )
919931 matrix_iterate = math .dot (_z_rotation (phi = phis [0 ], interface = interface ), matrix_iterate )
@@ -941,14 +953,13 @@ def _grid_pts(degree, interface):
941953@jit_if_jax_available
942954def obj_function (phi , x , y ):
943955 # Equation (23)
944- obj_func = vmap (_qsp_iterate_broadcast , in_axes = (None , 0 , None ))(phi , x , "jax" ) - y
945- obj_func = jnp .dot (obj_func , obj_func )
956+ obj_func = jax . vmap (_qsp_iterate_broadcast , in_axes = (None , 0 , None ))(phi , x , "jax" ) - y
957+ obj_func = jax . numpy .dot (obj_func , obj_func )
946958 return 1 / x .shape [0 ] * obj_func
947959
948960
949961@partial (jit_if_jax_available , static_argnames = ["maxiter" , "tol" ])
950962def optax_opt (initial_guess , x , y , maxiter , tol ):
951- import optax
952963
953964 opt = optax .lbfgs ()
954965 init_carry = (initial_guess , opt .init (initial_guess ))
@@ -968,19 +979,19 @@ def while_loop_cond(params):
968979 cost_val = optax .tree .get (state , "value" )
969980 return (num_iter == 0 ) | ((num_iter < maxiter ) & (cost_val > tol ))
970981
971- carry = lax .while_loop (while_loop_cond , optimizer_iter_update , init_carry )
982+ carry = jax . lax .while_loop (while_loop_cond , optimizer_iter_update , init_carry )
972983 return carry [0 ]
973984
974985
975986def _qsp_optimization (degree : int , coeffs_target_func , maxiter = 100 , tol = 1e-30 ):
976987 """Algorithm 1 in https://arxiv.org/pdf/2002.11649 produces the angle parameters by minimizing the distance between the target and qsp polynomail over the grid"""
977988
978- config .update ("jax_enable_x64" , True )
989+ jax . config .update ("jax_enable_x64" , True )
979990 grid_points = _grid_pts (degree , "jax" )
980991 initial_guess = [np .pi / 4 ] + [0.0 ] * (degree - 1 ) + [np .pi / 4 ]
981992
982- initial_guess = jnp .array (initial_guess )
983- targets = vmap (_poly_func , in_axes = (None , 0 ))(coeffs_target_func , grid_points )
993+ initial_guess = jax . numpy .array (initial_guess )
994+ targets = jax . vmap (_poly_func , in_axes = (None , 0 ))(coeffs_target_func , grid_points )
984995
985996 opt_params = optax_opt (initial_guess , grid_points , targets , maxiter , tol )
986997 cost_fun = obj_function (opt_params , grid_points , targets )
@@ -991,13 +1002,19 @@ def _qsp_optimization(degree: int, coeffs_target_func, maxiter=100, tol=1e-30):
9911002def _compute_qsp_angles_iteratively (
9921003 poly ,
9931004):
994- try :
995- import jax
996- import optax
997-
998- except ModuleNotFoundError as exc :
999- raise ModuleNotFoundError ("JAX and optax are required" ) from exc
1005+ # try:
1006+ # import jax
1007+ # import optax
10001008
1009+ # except ModuleNotFoundError as exc:
1010+ # raise ModuleNotFoundError("JAX and optax are required") from exc
1011+
1012+ if not is_jax_available :
1013+ raise ModuleNotFoundError ("jax is required!" )
1014+
1015+ if not is_optax_available :
1016+ raise ModuleNotFoundError ("optax is required!" )
1017+
10011018 poly_cheb = chebyshev .poly2cheb (poly )
10021019 degree = len (poly_cheb ) - 1
10031020
0 commit comments