Skip to content

Commit 436beaa

Browse files
committed
move ionInfo to separate file
1 parent 81db109 commit 436beaa

File tree

2 files changed

+130
-98
lines changed

2 files changed

+130
-98
lines changed

mlir/include/Ion/IR/IonInfo.h

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Copyright 2025 Xanadu Quantum Technologies Inc.
2+
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <optional>
18+
#include <string>
19+
20+
#include "llvm/ADT/SmallVector.h"
21+
#include "llvm/ADT/StringMap.h"
22+
#include "llvm/ADT/StringRef.h"
23+
24+
#include "Ion/IR/IonOps.h"
25+
26+
namespace catalyst {
27+
namespace ion {
28+
29+
/// Helper class to store and query ion information from an IonOp.
30+
class IonInfo {
31+
public:
32+
struct TransitionInfo {
33+
std::string level0;
34+
std::string level1;
35+
double einstein_a;
36+
std::string multipole;
37+
};
38+
39+
private:
40+
llvm::StringMap<double> levelEnergyMap;
41+
llvm::SmallVector<TransitionInfo> transitions;
42+
43+
public:
44+
explicit IonInfo(ion::IonOp op)
45+
{
46+
auto levelAttrs = op.getLevels();
47+
auto transitionsAttr = op.getTransitions();
48+
49+
// Map from Level label to Energy value
50+
for (auto levelAttr : levelAttrs) {
51+
auto level = mlir::cast<LevelAttr>(levelAttr);
52+
std::string label = level.getLabel().getValue().str();
53+
double energy = level.getEnergy().getValueAsDouble();
54+
levelEnergyMap[label] = energy;
55+
}
56+
57+
// Store transition information
58+
for (auto transitionAttr : transitionsAttr) {
59+
auto transition = mlir::cast<TransitionAttr>(transitionAttr);
60+
TransitionInfo info;
61+
info.level0 = transition.getLevel_0().getValue().str();
62+
info.level1 = transition.getLevel_1().getValue().str();
63+
info.einstein_a = transition.getEinsteinA().getValueAsDouble();
64+
info.multipole = transition.getMultipole().getValue().str();
65+
transitions.push_back(info);
66+
}
67+
}
68+
69+
/// Get energy of a level by label
70+
std::optional<double> getLevelEnergy(llvm::StringRef label) const
71+
{
72+
auto it = levelEnergyMap.find(label.str());
73+
if (it != levelEnergyMap.end()) {
74+
return it->second;
75+
}
76+
return std::nullopt;
77+
}
78+
79+
/// Get level energy of a transition by index
80+
template <int IndexT>
81+
std::optional<double> getTransitionLevelEnergy(size_t transitionIndex) const
82+
{
83+
static_assert(IndexT == 0 || IndexT == 1, "IndexT must be 0 or 1");
84+
85+
if (transitionIndex >= transitions.size()) {
86+
return std::nullopt;
87+
}
88+
89+
const auto &transition = transitions[transitionIndex];
90+
if constexpr (IndexT == 0) {
91+
return getLevelEnergy(transition.level0);
92+
}
93+
else {
94+
return getLevelEnergy(transition.level1);
95+
}
96+
}
97+
98+
/// Get energy difference of a transition (level1 energy - level0 energy)
99+
std::optional<double> getTransitionEnergyDiff(size_t index) const
100+
{
101+
if (index >= transitions.size()) {
102+
return std::nullopt;
103+
}
104+
105+
auto energy0 = getTransitionLevelEnergy<0>(index);
106+
auto energy1 = getTransitionLevelEnergy<1>(index);
107+
108+
if (energy0.has_value() && energy1.has_value()) {
109+
return energy1.value() - energy0.value();
110+
}
111+
112+
return std::nullopt;
113+
}
114+
115+
/// Get number of transitions
116+
size_t getNumTransitions() const { return transitions.size(); }
117+
118+
/// Get transition info by index
119+
std::optional<TransitionInfo> getTransition(size_t index) const
120+
{
121+
if (index < transitions.size()) {
122+
return transitions[index];
123+
}
124+
return std::nullopt;
125+
}
126+
};
127+
128+
} // namespace ion
129+
} // namespace catalyst

mlir/lib/Ion/Transforms/ion-to-rtio.cpp

Lines changed: 1 addition & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3333

3434
#include "Ion/IR/IonDialect.h"
35+
#include "Ion/IR/IonInfo.h"
3536
#include "Ion/IR/IonOps.h"
3637
#include "Quantum/IR/QuantumDialect.h"
3738
#include "Quantum/IR/QuantumOps.h"
@@ -170,104 +171,6 @@ Value awaitEvents(ArrayRef<Value> events, PatternRewriter &rewriter)
170171
return rewriter.create<rtio::RTIOSyncOp>(rewriter.getUnknownLoc(), eventType, events);
171172
}
172173

173-
// Helper class to store ion information
174-
class IonInfo {
175-
private:
176-
llvm::StringMap<double> levelEnergyMap;
177-
178-
struct TransitionInfo {
179-
std::string level0;
180-
std::string level1;
181-
double einstein_a;
182-
std::string multipole;
183-
};
184-
SmallVector<TransitionInfo> transitions;
185-
186-
public:
187-
IonInfo(ion::IonOp op)
188-
{
189-
auto levelAttrs = op.getLevels();
190-
auto transitionsAttr = op.getTransitions();
191-
192-
// Map from Level label to Energy value
193-
for (auto levelAttr : levelAttrs) {
194-
auto level = cast<LevelAttr>(levelAttr);
195-
std::string label = level.getLabel().getValue().str();
196-
double energy = level.getEnergy().getValueAsDouble();
197-
levelEnergyMap[label] = energy;
198-
}
199-
200-
// Store transition information
201-
for (auto transitionAttr : transitionsAttr) {
202-
auto transition = cast<TransitionAttr>(transitionAttr);
203-
TransitionInfo info;
204-
info.level0 = transition.getLevel_0().getValue().str();
205-
info.level1 = transition.getLevel_1().getValue().str();
206-
info.einstein_a = transition.getEinsteinA().getValueAsDouble();
207-
info.multipole = transition.getMultipole().getValue().str();
208-
transitions.push_back(info);
209-
}
210-
}
211-
212-
// Get energy of a level by label
213-
std::optional<double> getLevelEnergy(StringRef label) const
214-
{
215-
auto it = levelEnergyMap.find(label.str());
216-
if (it != levelEnergyMap.end()) {
217-
return it->second;
218-
}
219-
return std::nullopt;
220-
}
221-
222-
// Get level label of a transition by index
223-
template <int IndexT>
224-
std::optional<double> getTransitionLevelEnergy(size_t transitionIndex) const
225-
{
226-
static_assert(IndexT == 0 || IndexT == 1, "IndexT must be 0 or 1");
227-
228-
if (transitionIndex >= transitions.size()) {
229-
return std::nullopt;
230-
}
231-
232-
const auto &transition = transitions[transitionIndex];
233-
if constexpr (IndexT == 0) {
234-
return getLevelEnergy(transition.level0);
235-
}
236-
else {
237-
return getLevelEnergy(transition.level1);
238-
}
239-
}
240-
241-
// Get energy difference of a transition (level1 energy - level0 energy)
242-
std::optional<double> getTransitionEnergyDiff(size_t index) const
243-
{
244-
if (index >= transitions.size()) {
245-
return std::nullopt;
246-
}
247-
248-
auto energy0 = getTransitionLevelEnergy<0>(index);
249-
auto energy1 = getTransitionLevelEnergy<1>(index);
250-
251-
if (energy0.has_value() && energy1.has_value()) {
252-
return energy1.value() - energy0.value();
253-
}
254-
255-
return std::nullopt;
256-
}
257-
258-
// Get number of transitions
259-
size_t getNumTransitions() const { return transitions.size(); }
260-
261-
// Get transition info by index
262-
std::optional<TransitionInfo> getTransition(size_t index) const
263-
{
264-
if (index < transitions.size()) {
265-
return transitions[index];
266-
}
267-
return std::nullopt;
268-
}
269-
};
270-
271174
} // namespace
272175

273176
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)