Skip to content

Commit b8c2619

Browse files
IR2Vec.cpp refactor
1 parent 76b3142 commit b8c2619

1 file changed

Lines changed: 125 additions & 87 deletions

File tree

src/IR2Vec.cpp

Lines changed: 125 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
#include <stdio.h>
1717
#include <time.h>
1818

19+
#include <ctime>
20+
#include <fstream>
21+
#include <string>
22+
#include <utility>
23+
1924
using namespace llvm;
2025
using namespace IR2Vec;
2126

@@ -73,9 +78,102 @@ void printVersion(raw_ostream &ostream) {
7378
cl::PrintVersionMessage();
7479
}
7580

76-
int main(int argc, char **argv) {
77-
cl::SetVersionPrinter(printVersion);
78-
cl::HideUnrelatedOptions(category);
81+
struct SymOutputs {
82+
std::ofstream out;
83+
};
84+
85+
struct FAOutputs : SymOutputs {
86+
std::ofstream miss;
87+
std::ofstream cyclic;
88+
};
89+
90+
inline SymOutputs openSymOutputs(const std::string &baseName) {
91+
SymOutputs f;
92+
f.out.open(baseName, std::ios_base::app);
93+
return f;
94+
}
95+
96+
inline FAOutputs openFAOutputs(const std::string &baseName) {
97+
FAOutputs f;
98+
f.out.open(baseName, std::ios_base::app);
99+
f.miss.open("missCount_" + baseName, std::ios_base::app);
100+
f.cyclic.open("cyclicCount_" + baseName, std::ios_base::app);
101+
return f;
102+
}
103+
104+
template <class F>
105+
inline void runMaybeTimed(bool shouldTime, const char *timingMsgFmt, F &&job) {
106+
if (shouldTime) {
107+
const clock_t start = clock();
108+
std::forward<F>(job)();
109+
const clock_t end = clock();
110+
const double elapsed = static_cast<double>(end - start) / CLOCKS_PER_SEC;
111+
std::printf(timingMsgFmt, elapsed);
112+
} else {
113+
std::forward<F>(job)();
114+
}
115+
}
116+
117+
template <class Encoder, class Outputs, class OutputsFactory, class Body>
118+
inline void executeEncoder(const char *timingMsgFmt, bool shouldTime,
119+
OutputsFactory &&makeOutputs, Body &&body) {
120+
auto M = getLLVMIR();
121+
auto vocabulary = VocabularyFactory::createVocabulary(DIM)->getVocabulary();
122+
Encoder encoder(*M, vocabulary);
123+
auto files = std::forward<OutputsFactory>(makeOutputs)(oname);
124+
125+
auto job = [&] { std::forward<Body>(body)(encoder, files); };
126+
runMaybeTimed(shouldTime, timingMsgFmt, job);
127+
}
128+
129+
void generateFAEncodingsFunction(std::string funcName) {
130+
executeEncoder<IR2Vec_FA, FAOutputs>(
131+
"Time taken by on-demand generation of flow-aware encodings is: %.6f "
132+
"seconds.\n",
133+
printTime, openFAOutputs, [&, funcName](IR2Vec_FA &FA, FAOutputs &files) {
134+
FA.generateFlowAwareEncodingsForFunction(&files.out, funcName,
135+
&files.miss, &files.cyclic);
136+
});
137+
}
138+
139+
void generateFAEncodings() {
140+
executeEncoder<IR2Vec_FA, FAOutputs>(
141+
"Time taken by normal generation of flow-aware encodings is: %.6f "
142+
"seconds.\n",
143+
printTime, openFAOutputs, [&](IR2Vec_FA &FA, FAOutputs &files) {
144+
FA.generateFlowAwareEncodings(&files.out, &files.miss, &files.cyclic);
145+
});
146+
}
147+
148+
void generateSymEncodingsFunction(std::string funcName) {
149+
executeEncoder<IR2Vec_Symbolic, SymOutputs>(
150+
"Time taken by on-demand generation of symbolic encodings is: %.6f "
151+
"seconds.\n",
152+
printTime, openSymOutputs,
153+
[&, funcName](IR2Vec_Symbolic &SYM, SymOutputs &files) {
154+
SYM.generateSymbolicEncodingsForFunction(&files.out, funcName);
155+
});
156+
}
157+
158+
void generateSYMEncodings() {
159+
executeEncoder<IR2Vec_Symbolic, SymOutputs>(
160+
"Time taken by normal generation of symbolic encodings is: %.6f "
161+
"seconds.\n",
162+
printTime, openSymOutputs, [&](IR2Vec_Symbolic &SYM, SymOutputs &files) {
163+
SYM.generateSymbolicEncodings(&files.out);
164+
});
165+
}
166+
167+
void collectIRfunc() {
168+
auto M = getLLVMIR();
169+
CollectIR cir(M);
170+
std::ofstream o;
171+
o.open(oname, std::ios_base::app);
172+
cir.generateTriplets(o);
173+
o.close();
174+
}
175+
176+
void setGlobalVars(int argc, char **argv) {
79177
cl::ParseCommandLineOptions(argc, argv);
80178

81179
fa = cl_fa;
@@ -92,113 +190,53 @@ int main(int argc, char **argv) {
92190
WT = cl_WT;
93191
debug = cl_debug;
94192
printTime = cl_printTime;
193+
}
95194

195+
void checkFailureConditions() {
96196
bool failed = false;
97-
if (!((sym ^ fa) ^ collectIR)) {
98-
errs() << "Either of sym, fa or collectIR should be specified\n";
197+
198+
if (!(sym || fa || collectIR)) {
199+
errs() << "Either of sym, fa, or collectIR should be specified\n";
99200
failed = true;
100201
}
101202

203+
if (failed)
204+
exit(1);
205+
102206
if (sym || fa) {
103207
if (level != 'p' && level != 'f') {
104208
errs() << "Invalid level specified: Use either p or f\n";
105209
failed = true;
106210
}
107211
} else {
108-
if (!collectIR) {
109-
errs() << "Either of sym, fa or collectIR should be specified\n";
110-
failed = true;
111-
} else if (level)
212+
assert(collectIR == true);
213+
214+
if (collectIR && level) {
112215
errs() << "[WARNING] level would not be used in collectIR mode\n";
216+
}
113217
}
114218

115219
if (failed)
116220
exit(1);
221+
}
117222

118-
auto M = getLLVMIR();
119-
auto vocabulary = VocabularyFactory::createVocabulary(DIM)->getVocabulary();
223+
int main(int argc, char **argv) {
224+
cl::SetVersionPrinter(printVersion);
225+
cl::HideUnrelatedOptions(category);
226+
setGlobalVars(argc, argv);
227+
checkFailureConditions();
120228

121-
// newly added
122229
if (sym && !(funcName.empty())) {
123-
IR2Vec_Symbolic SYM(*M, vocabulary);
124-
std::ofstream o;
125-
o.open(oname, std::ios_base::app);
126-
if (printTime) {
127-
clock_t start = clock();
128-
SYM.generateSymbolicEncodingsForFunction(&o, funcName);
129-
clock_t end = clock();
130-
double elapsed = double(end - start) / CLOCKS_PER_SEC;
131-
printf("Time taken by on-demand generation of symbolic encodings "
132-
"is: %.6f "
133-
"seconds.\n",
134-
elapsed);
135-
} else {
136-
SYM.generateSymbolicEncodingsForFunction(&o, funcName);
137-
}
138-
o.close();
230+
generateSymEncodingsFunction(funcName);
139231
} else if (fa && !(funcName.empty())) {
140-
IR2Vec_FA FA(*M, vocabulary);
141-
std::ofstream o, missCount, cyclicCount;
142-
o.open(oname, std::ios_base::app);
143-
missCount.open("missCount_" + oname, std::ios_base::app);
144-
cyclicCount.open("cyclicCount_" + oname, std::ios_base::app);
145-
if (printTime) {
146-
clock_t start = clock();
147-
FA.generateFlowAwareEncodingsForFunction(&o, funcName, &missCount,
148-
&cyclicCount);
149-
clock_t end = clock();
150-
double elapsed = double(end - start) / CLOCKS_PER_SEC;
151-
printf("Time taken by on-demand generation of flow-aware encodings "
152-
"is: %.6f "
153-
"seconds.\n",
154-
elapsed);
155-
} else {
156-
FA.generateFlowAwareEncodingsForFunction(&o, funcName, &missCount,
157-
&cyclicCount);
158-
}
159-
o.close();
232+
generateFAEncodingsFunction(funcName);
160233
} else if (fa) {
161-
IR2Vec_FA FA(*M, vocabulary);
162-
std::ofstream o, missCount, cyclicCount;
163-
o.open(oname, std::ios_base::app);
164-
missCount.open("missCount_" + oname, std::ios_base::app);
165-
cyclicCount.open("cyclicCount_" + oname, std::ios_base::app);
166-
if (printTime) {
167-
clock_t start = clock();
168-
FA.generateFlowAwareEncodings(&o, &missCount, &cyclicCount);
169-
clock_t end = clock();
170-
double elapsed = double(end - start) / CLOCKS_PER_SEC;
171-
printf("Time taken by normal generation of flow-aware encodings "
172-
"is: %.6f "
173-
"seconds.\n",
174-
elapsed);
175-
} else {
176-
FA.generateFlowAwareEncodings(&o, &missCount, &cyclicCount);
177-
}
178-
o.close();
234+
generateFAEncodings();
179235
} else if (sym) {
180-
IR2Vec_Symbolic SYM(*M, vocabulary);
181-
std::ofstream o;
182-
o.open(oname, std::ios_base::app);
183-
if (printTime) {
184-
clock_t start = clock();
185-
SYM.generateSymbolicEncodings(&o);
186-
clock_t end = clock();
187-
double elapsed = double(end - start) / CLOCKS_PER_SEC;
188-
printf("Time taken by normal generation of symbolic encodings is: "
189-
"%.6f "
190-
"seconds.\n",
191-
elapsed);
192-
} else {
193-
SYM.generateSymbolicEncodings(&o);
194-
}
195-
o.close();
236+
generateSYMEncodings();
196237
} else if (collectIR) {
197-
CollectIR cir(M);
198-
std::ofstream o;
199-
o.open(oname, std::ios_base::app);
200-
cir.generateTriplets(o);
201-
o.close();
238+
collectIRfunc();
202239
}
240+
203241
return 0;
204242
}

0 commit comments

Comments
 (0)