Skip to content

Commit 7a6869c

Browse files
committed
replacing imports try/catch blocks with indicative variable
1 parent cca0257 commit 7a6869c

File tree

2 files changed

+65
-32
lines changed

2 files changed

+65
-32
lines changed

pennylane/templates/subroutines/qsvt.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,24 @@
4040
from .prepselprep import PrepSelPrep
4141
from .qubitization import Qubitization
4242

43+
is_jax_available = True
44+
is_optax_available = True
4345
try:
44-
from jax import config, lax
45-
from jax import numpy as jnp
46-
from jax import vmap
47-
except ModuleNotFoundError: # pragma: no cover
48-
pass
46+
import jax
47+
except ImportError:
48+
is_jax_available = False
49+
50+
try:
51+
import optax
52+
except ImportError:
53+
is_otpax_available = False
54+
55+
# try:
56+
# from jax import config, lax
57+
# from jax import numpy as jnp
58+
# from jax import vmap
59+
# except ModuleNotFoundError: # pragma: no cover
60+
# pass
4961

5062

5163
def jit_if_jax_available(f, **kwargs):
@@ -848,7 +860,7 @@ def _cheby_pol(x, degree):
848860
def _poly_func(coeffs, x):
849861
"""\sum c_kT_{k}(x) where T_k(x)=cos(karccos(x))"""
850862

851-
return jnp.sum(coeffs @ vmap(_cheby_pol, in_axes=(None, 0))(x, np.arange(coeffs.shape[0])))
863+
return jax.numpy.sum(coeffs @ jax.vmap(_cheby_pol, in_axes=(None, 0))(x, np.arange(coeffs.shape[0])))
852864

853865

854866
@partial(jit_if_jax_available, static_argnames=["interface"])
@@ -913,7 +925,7 @@ def _qsp_iterate_broadcast(phis, x, interface):
913925
"""
914926

915927
# pylint: disable=import-outside-toplevel
916-
qsp_iterate_list = vmap(_qsp_iterate, in_axes=(0, None, None))(phis[1:], x, interface)
928+
qsp_iterate_list = jax.vmap(_qsp_iterate, in_axes=(0, None, None))(phis[1:], x, interface)
917929

918930
matrix_iterate = reduce(math.dot, qsp_iterate_list)
919931
matrix_iterate = math.dot(_z_rotation(phi=phis[0], interface=interface), matrix_iterate)
@@ -941,14 +953,13 @@ def _grid_pts(degree, interface):
941953
@jit_if_jax_available
942954
def obj_function(phi, x, y):
943955
# Equation (23)
944-
obj_func = vmap(_qsp_iterate_broadcast, in_axes=(None, 0, None))(phi, x, "jax") - y
945-
obj_func = jnp.dot(obj_func, obj_func)
956+
obj_func = jax.vmap(_qsp_iterate_broadcast, in_axes=(None, 0, None))(phi, x, "jax") - y
957+
obj_func = jax.numpy.dot(obj_func, obj_func)
946958
return 1 / x.shape[0] * obj_func
947959

948960

949961
@partial(jit_if_jax_available, static_argnames=["maxiter", "tol"])
950962
def optax_opt(initial_guess, x, y, maxiter, tol):
951-
import optax
952963

953964
opt = optax.lbfgs()
954965
init_carry = (initial_guess, opt.init(initial_guess))
@@ -968,19 +979,19 @@ def while_loop_cond(params):
968979
cost_val = optax.tree.get(state, "value")
969980
return (num_iter == 0) | ((num_iter < maxiter) & (cost_val > tol))
970981

971-
carry = lax.while_loop(while_loop_cond, optimizer_iter_update, init_carry)
982+
carry = jax.lax.while_loop(while_loop_cond, optimizer_iter_update, init_carry)
972983
return carry[0]
973984

974985

975986
def _qsp_optimization(degree: int, coeffs_target_func, maxiter=100, tol=1e-30):
976987
"""Algorithm 1 in https://arxiv.org/pdf/2002.11649 produces the angle parameters by minimizing the distance between the target and qsp polynomail over the grid"""
977988

978-
config.update("jax_enable_x64", True)
989+
jax.config.update("jax_enable_x64", True)
979990
grid_points = _grid_pts(degree, "jax")
980991
initial_guess = [np.pi / 4] + [0.0] * (degree - 1) + [np.pi / 4]
981992

982-
initial_guess = jnp.array(initial_guess)
983-
targets = vmap(_poly_func, in_axes=(None, 0))(coeffs_target_func, grid_points)
993+
initial_guess = jax.numpy.array(initial_guess)
994+
targets = jax.vmap(_poly_func, in_axes=(None, 0))(coeffs_target_func, grid_points)
984995

985996
opt_params = optax_opt(initial_guess, grid_points, targets, maxiter, tol)
986997
cost_fun = obj_function(opt_params, grid_points, targets)
@@ -991,13 +1002,19 @@ def _qsp_optimization(degree: int, coeffs_target_func, maxiter=100, tol=1e-30):
9911002
def _compute_qsp_angles_iteratively(
9921003
poly,
9931004
):
994-
try:
995-
import jax
996-
import optax
997-
998-
except ModuleNotFoundError as exc:
999-
raise ModuleNotFoundError("JAX and optax are required") from exc
1005+
# try:
1006+
# import jax
1007+
# import optax
10001008

1009+
# except ModuleNotFoundError as exc:
1010+
# raise ModuleNotFoundError("JAX and optax are required") from exc
1011+
1012+
if not is_jax_available:
1013+
raise ModuleNotFoundError("jax is required!")
1014+
1015+
if not is_optax_available:
1016+
raise ModuleNotFoundError("optax is required!")
1017+
10011018
poly_cheb = chebyshev.poly2cheb(poly)
10021019
degree = len(poly_cheb) - 1
10031020

tests/templates/subroutines/test_qsvt.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@
3636
_z_rotation,
3737
jit_if_jax_available,
3838
)
39+
is_jax_available = True
40+
is_optax_available = True
41+
try:
42+
import jax
43+
except ImportError:
44+
is_jax_available = False
45+
46+
try:
47+
import optax
48+
except ImportError:
49+
is_otpax_available = False
3950

4051

4152
def qfunc(A):
@@ -982,8 +993,6 @@ def test_qsp_on_poly_with_parity(self, polynomial_coeffs_in_cheby_basis):
982993
target_polynomial_coeffs = polynomial_coeffs_in_cheby_basis
983994
phis, cost_func = _qsp_optimization(degree, target_polynomial_coeffs)
984995

985-
import jax
986-
987996
key = jax.random.key(123)
988997
x_point = jax.random.uniform(key=key, shape=(1,), minval=-1, maxval=1)
989998

@@ -1018,12 +1027,13 @@ def f(x):
10181027

10191028
jit_wrapped_f = jit_if_jax_available(f)
10201029

1021-
try:
1022-
# This is for testing fallback
1023-
import jax # pylint: disable=unused-import
1030+
if is_jax_available:
1031+
# try:
1032+
# # This is for testing fallback
1033+
# import jax # pylint: disable=unused-import
10241034

10251035
assert hasattr(jit_wrapped_f, "lower")
1026-
except ModuleNotFoundError:
1036+
else:
10271037
assert jit_wrapped_f is f
10281038

10291039
@pytest.mark.parametrize(
@@ -1033,13 +1043,19 @@ def f(x):
10331043
],
10341044
)
10351045
def test_raised_exceptions(self, polynomial_coeffs_in_cheby_basis):
1036-
try:
1037-
import jax # pylint: disable=unused-import
1038-
import optax # pylint: disable=unused-import
1039-
except ModuleNotFoundError:
1040-
with pytest.raises(ModuleNotFoundError, match="JAX and optax are required"):
1046+
# try:
1047+
# import jax # pylint: disable=unused-import
1048+
# import optax # pylint: disable=unused-import
1049+
# except ModuleNotFoundError:
1050+
# with pytest.raises(ModuleNotFoundError, match="JAX and optax are required"):
1051+
# _compute_qsp_angles_iteratively(polynomial_coeffs_in_cheby_basis)
1052+
if not is_jax_available:
1053+
with pytest.raises(ModuleNotFoundError, match="jax is required!"):
10411054
_compute_qsp_angles_iteratively(polynomial_coeffs_in_cheby_basis)
1042-
1055+
elif not is_optax_available:
1056+
with pytest.raises(ModuleNotFoundError, match="optax is required!"):
1057+
_compute_qsp_angles_iteratively(polynomial_coeffs_in_cheby_basis)
1058+
10431059
@pytest.mark.parametrize(
10441060
"x, degree",
10451061
[

0 commit comments

Comments
 (0)