Skip to content

Commit 0f49fb3

Browse files
committed
add more tests
1 parent 4430396 commit 0f49fb3

File tree

1 file changed

+159
-13
lines changed

1 file changed

+159
-13
lines changed

frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py

Lines changed: 159 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class FakeDAGBuilder(DAGBuilder):
4444

4545
def __init__(self):
4646
self._nodes = {}
47-
self._edges = []
47+
self._edges = {}
4848
self._clusters = {}
4949

5050
def add_node(self, uid, label, cluster_uid=None, **attrs) -> None:
@@ -56,13 +56,11 @@ def add_node(self, uid, label, cluster_uid=None, **attrs) -> None:
5656
}
5757

5858
def add_edge(self, from_uid: str, to_uid: str, **attrs) -> None:
59-
self._edges.append(
60-
{
61-
"from_uid": from_uid,
62-
"to_uid": to_uid,
63-
"attrs": attrs,
64-
}
65-
)
59+
# O(1) look up
60+
edge_key = (from_uid, to_uid)
61+
self._edges[edge_key] = {
62+
"attrs": attrs,
63+
}
6664

6765
def add_cluster(
6866
self,
@@ -979,17 +977,165 @@ class TestOperatorConnectivity:
979977
@pytest.mark.unit
980978
def test_static_connection_within_cluster(self):
981979
"""Tests that connections can be made within the same cluster."""
982-
pass
980+
981+
dev = qml.device("null.qubit", wires=3)
982+
983+
@xdsl_from_qjit
984+
@qml.qjit(autograph=True, target="mlir")
985+
@qml.qnode(dev)
986+
def my_workflow():
987+
qml.X(0)
988+
qml.Z(1)
989+
qml.Y(0)
990+
qml.H(1)
991+
qml.S(1)
992+
qml.T(2)
993+
994+
module = my_workflow()
995+
996+
utility = ConstructCircuitDAG(FakeDAGBuilder())
997+
utility.construct(module)
998+
999+
edges = utility.dag_builder.edges
1000+
nodes = utility.dag_builder.nodes
1001+
1002+
# node0 -> NullQubit
1003+
1004+
# Check all nodes
1005+
assert "PauliX" in nodes["node1"]["label"]
1006+
assert "PauliZ" in nodes["node2"]["label"]
1007+
assert "PauliY" in nodes["node3"]["label"]
1008+
assert "Hadamard" in nodes["node4"]["label"]
1009+
assert "S" in nodes["node5"]["label"]
1010+
assert "T" in nodes["node6"]["label"]
1011+
1012+
# Check edges
1013+
# X -> Y
1014+
# Z -> H -> S
1015+
# T
1016+
assert len(edges) == 3
1017+
assert ("node1", "node3") in edges
1018+
assert ("node2", "node4") in edges
1019+
assert ("node4", "node5") in edges
9831020

9841021
@pytest.mark.unit
985-
def test_static_connection_through_clusters(self):
986-
"""Tests that connections can be made through nested clusters."""
987-
pass
1022+
def test_static_connection_through_for_loop(self):
1023+
"""Tests that connections can be made through a for loop cluster."""
1024+
1025+
dev = qml.device("null.qubit", wires=1)
1026+
1027+
@xdsl_from_qjit
1028+
@qml.qjit(autograph=True, target="mlir")
1029+
@qml.qnode(dev)
1030+
def my_workflow():
1031+
qml.X(0)
1032+
for i in range(3):
1033+
qml.Y(0)
1034+
1035+
module = my_workflow()
1036+
1037+
utility = ConstructCircuitDAG(FakeDAGBuilder())
1038+
utility.construct(module)
1039+
1040+
edges = utility.dag_builder.edges
1041+
nodes = utility.dag_builder.nodes
1042+
# node0 -> NullQubit
1043+
1044+
# Check all nodes
1045+
assert "PauliX" in nodes["node1"]["label"]
1046+
assert "PauliY" in nodes["node2"]["label"]
1047+
1048+
# Check edges
1049+
# for loop
1050+
# X ----------> Y
1051+
assert len(edges) == 1
1052+
assert ("node1", "node2") in edges
1053+
1054+
@pytest.mark.unit
1055+
def test_static_connection_through_while_loop(self):
1056+
"""Tests that connections can be made through a while loop cluster."""
1057+
1058+
dev = qml.device("null.qubit", wires=1)
1059+
1060+
@xdsl_from_qjit
1061+
@qml.qjit(autograph=True, target="mlir")
1062+
@qml.qnode(dev)
1063+
def my_workflow():
1064+
counter = 0
1065+
qml.X(0)
1066+
while counter < 5:
1067+
qml.Y(0)
1068+
counter += 1
1069+
1070+
module = my_workflow()
1071+
1072+
utility = ConstructCircuitDAG(FakeDAGBuilder())
1073+
utility.construct(module)
1074+
1075+
edges = utility.dag_builder.edges
1076+
nodes = utility.dag_builder.nodes
1077+
# node0 -> NullQubit
1078+
1079+
# Check all nodes
1080+
assert "PauliX" in nodes["node1"]["label"]
1081+
assert "PauliY" in nodes["node2"]["label"]
1082+
1083+
# Check edges
1084+
# for loop
1085+
# X ----------> Y
1086+
assert len(edges) == 1
1087+
assert ("node1", "node2") in edges
9881088

9891089
@pytest.mark.unit
9901090
def test_static_connection_through_conditional(self):
9911091
"""Tests that connections through conditionals make sense."""
992-
pass
1092+
1093+
dev = qml.device("null.qubit", wires=1)
1094+
1095+
@xdsl_from_qjit
1096+
@qml.qjit(autograph=True, target="mlir")
1097+
@qml.qnode(dev)
1098+
def my_workflow(x):
1099+
qml.X(0)
1100+
qml.T(1)
1101+
if x == 1:
1102+
qml.RX(0, 0)
1103+
qml.S(1)
1104+
elif x == 2:
1105+
qml.RY(0, 0)
1106+
else:
1107+
qml.RZ(0, 0)
1108+
qml.H(0)
1109+
1110+
module = my_workflow()
1111+
1112+
utility = ConstructCircuitDAG(FakeDAGBuilder())
1113+
utility.construct(module)
1114+
1115+
edges = utility.dag_builder.edges
1116+
nodes = utility.dag_builder.nodes
1117+
1118+
# node0 -> NullQubit
1119+
1120+
# Check all nodes
1121+
# NOTE: depth first traversal hence T first then PauliX
1122+
assert "T" in nodes["node1"]["label"]
1123+
assert "PauliX" in nodes["node2"]["label"]
1124+
assert "RX" in nodes["node3"]["label"]
1125+
assert "S" in nodes["node4"]["label"]
1126+
assert "RY" in nodes["node5"]["label"]
1127+
assert "RZ" in nodes["node6"]["label"]
1128+
assert "Hadamard" in nodes["node7"]["label"]
1129+
1130+
# Check all edges
1131+
assert len(edges) == 7
1132+
assert ("node1", "node4") in edges
1133+
assert ("node2", "node3") in edges
1134+
assert ("node2", "node5") in edges
1135+
assert ("node2", "node6") in edges
1136+
assert ("node3", "node7") in edges
1137+
assert ("node5", "node7") in edges
1138+
assert ("node6", "node7") in edges
9931139

9941140

9951141
class TestTerminalMeasurementConnectivity:

0 commit comments

Comments
 (0)