Skip to content

Commit caea2bf

Browse files
committed
LR: support zheevx, zheevr, elpa diago
1 parent 882cb6a commit caea2bf

9 files changed

Lines changed: 426 additions & 27 deletions

File tree

cmake/FindELPA.cmake

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,15 @@ if(ELPA_FOUND)
7878
IMPORTED_LOCATION "${ELPA_LIBRARY}"
7979
INTERFACE_INCLUDE_DIRECTORIES "${ELPA_INCLUDE_DIR}")
8080
endif()
81+
82+
# Propagate the directory that contains the ELPA library as a runtime search
# path to everything that links against the imported ELPA::ELPA target.
if(TARGET ELPA::ELPA)
  get_filename_component(_elpa_libdir "${ELPA_LIBRARY}" DIRECTORY)
  if(_elpa_libdir)
    # Use the "LINKER:" prefix instead of a hard-coded "-Wl,": CMake then emits
    # the wrapper flag expected by the active compiler driver.
    set_property(TARGET ELPA::ELPA APPEND PROPERTY
      INTERFACE_LINK_OPTIONS "LINKER:-rpath,${_elpa_libdir}")
  endif()
  unset(_elpa_libdir)
endif()
8190
endif()
8291

8392
set(CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES} ${ELPA_INCLUDE_DIR})

source/module_hsolver/genelpa/elpa_new.cpp

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,32 @@ ELPA_Solver::ELPA_Solver(const bool isReal,
3333
this->nev = nev;
3434
this->narows = narows;
3535
this->nacols = nacols;
36-
for (int i = 0; i < 9; ++i)
37-
this->desc[i] = desc[i];
38-
cblacs_ctxt = desc[1];
39-
nFull = desc[2];
40-
nblk = desc[4];
41-
lda = desc[8];
36+
if (desc)
37+
{
38+
for (int i = 0; i < 9; ++i)
39+
this->desc[i] = desc[i];
40+
cblacs_ctxt = desc[1];
41+
nFull = desc[2];
42+
nblk = desc[4];
43+
lda = desc[8];
44+
}
45+
else
46+
{
47+
cblacs_ctxt = 0;
48+
nFull = std::max(narows, nacols);
49+
nblk = nFull;
50+
lda = narows;
51+
nprows = 1;
52+
npcols = 1;
53+
myprow = 0;
54+
mypcol = 0;
55+
}
4256
// cout<<"parameters are passed\n";
4357
MPI_Comm_rank(comm, &myid);
44-
Cblacs_gridinfo(cblacs_ctxt, &nprows, &npcols, &myprow, &mypcol);
58+
if (desc)
59+
{
60+
Cblacs_gridinfo(cblacs_ctxt, &nprows, &npcols, &myprow, &mypcol);
61+
}
4562
// cout<<"blacs grid is inited\n";
4663
allocate_work();
4764
// cout<<"work array is inited\n";
@@ -111,19 +128,39 @@ ELPA_Solver::ELPA_Solver(const bool isReal,
111128
this->nev = nev;
112129
this->narows = narows;
113130
this->nacols = nacols;
114-
for (int i = 0; i < 9; ++i)
115-
this->desc[i] = desc[i];
131+
if (desc)
132+
{
133+
for (int i = 0; i < 9; ++i)
134+
this->desc[i] = desc[i];
135+
}
116136

117137
kernel_id = otherParameter[0];
118138
useQR = otherParameter[1];
119139
loglevel = otherParameter[2];
120140

121-
cblacs_ctxt = desc[1];
122-
nFull = desc[2];
123-
nblk = desc[4];
124-
lda = desc[8];
141+
if (desc)
142+
{
143+
cblacs_ctxt = desc[1];
144+
nFull = desc[2];
145+
nblk = desc[4];
146+
lda = desc[8];
147+
}
148+
else
149+
{
150+
cblacs_ctxt = 0;
151+
nFull = std::max(narows, nacols);
152+
nblk = nFull;
153+
lda = narows;
154+
nprows = 1;
155+
npcols = 1;
156+
myprow = 0;
157+
mypcol = 0;
158+
}
125159
MPI_Comm_rank(comm, &myid);
126-
Cblacs_gridinfo(cblacs_ctxt, &nprows, &npcols, &myprow, &mypcol);
160+
if (desc)
161+
{
162+
Cblacs_gridinfo(cblacs_ctxt, &nprows, &npcols, &myprow, &mypcol);
163+
}
127164
allocate_work();
128165

129166
int error;

source/module_lr/AX/test/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,8 @@ AddTest(
33
TARGET AX_test
44
LIBS parameter base ${math_libs} container device psi
55
SOURCES AX_test.cpp ../../utils/lr_util.cpp ../AX_parallel.cpp ../AX_serial.cpp
6-
)
6+
)
7+
8+
if(USE_ELPA)
9+
target_link_libraries(AX_test ELPA::ELPA genelpa)
10+
endif()

source/module_lr/dm_trans/test/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,8 @@ AddTest(
1111
# ../../../module_base/module_container/base/core/cpu_allocator.cpp
1212
# ../../../module_base/module_container/base/core/refcount.cpp
1313
# ../../../module_base/module_container/ATen/kernels/memory_impl.cpp
14-
)
14+
)
15+
16+
if(USE_ELPA)
17+
target_link_libraries(dm_trans_test ELPA::ELPA genelpa)
18+
endif()

source/module_lr/hsolver_lrtd.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,11 @@ namespace LR
5454
std::vector<T> Amat_full = hm.matrix();
5555
const int gdim = std::sqrt(Amat_full.size());
5656
eigenvalue.resize(gdim);
57-
if (hermitian) { LR_Util::diag_lapack(gdim, Amat_full.data(), eigenvalue.data()); }
57+
if (hermitian)
58+
{
59+
LR_Util::diag_elpa(gdim, Amat_full.data(), eigenvalue.data());
60+
// LR_Util::diag_lapack_zheev(gdim, Amat_full.data(), eigenvalue.data());
61+
}
5862
else
5963
{
6064
std::vector<std::complex<double>> eig_complex(gdim);
@@ -184,4 +188,4 @@ namespace LR
184188
<< " ; where current threshold is: " << hsolver::DiagoIterAssist<T>::PW_DIAG_THR << " . " << std::endl;
185189
}
186190
}
187-
}
191+
}

source/module_lr/utils/lr_util.cpp

Lines changed: 185 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#include "lr_util.h"
33
#include "module_base/lapack_connector.h"
44
#include "module_base/scalapack_connector.h"
5+
#include "module_base/blacs_connector.h"
6+
#include "module_hsolver/genelpa/elpa_solver.h"
7+
#include "module_base/module_container/base/third_party/lapack.h"
58
namespace LR_Util
69
{
710
/// =================PHYSICS====================
@@ -115,9 +118,9 @@ namespace LR_Util
115118
}
116119
#endif
117120

118-
void diag_lapack(const int& n, double* mat, double* eig)
121+
void diag_lapack_zheev(const int& n, double* mat, double* eig)
119122
{
120-
ModuleBase::TITLE("LR_Util", "diag_lapack<double>");
123+
ModuleBase::TITLE("LR_Util", "diag_lapack_zheev<double>");
121124
int info = 0;
122125
char jobz = 'V', uplo = 'U';
123126
double work_tmp;
@@ -130,9 +133,9 @@ namespace LR_Util
130133
delete[] work2;
131134
}
132135

133-
void diag_lapack(const int& n, std::complex<double>* mat, double* eig)
136+
void diag_lapack_zheev(const int& n, std::complex<double>* mat, double* eig)
134137
{
135-
ModuleBase::TITLE("LR_Util", "diag_lapack<complex<double>>");
138+
ModuleBase::TITLE("LR_Util", "diag_lapack_zheev<complex<double>>");
136139
int lwork = 2 * n;
137140
std::complex<double>* work2 = new std::complex<double>[lwork];
138141
double* rwork = new double[3 * n - 2];
@@ -144,6 +147,183 @@ namespace LR_Util
144147
delete[] work2;
145148
}
146149

150+
void diag_lapack_zheevx(const int& n, double* mat, double* eig)
151+
{
152+
ModuleBase::TITLE("LR_Util", "diag_lapack_zheevx<double>");
153+
int info = 0;
154+
char jobz = 'V', range = 'A', uplo = 'U';
155+
const double vl = 0.0, vu = 0.0, abstol = 0.0;
156+
const int il = 0, iu = 0;
157+
int m = 0;
158+
const int ldz = n;
159+
double* z = new double[ldz * n];
160+
const int lwork = std::max(8 * n, 1);
161+
const int lrwork = std::max(7 * n, 1);
162+
const int liwork = std::max(5 * n, 1);
163+
double* work = new double[lwork];
164+
double* rwork = new double[lrwork];
165+
int* iwork = new int[liwork];
166+
int* ifail = new int[n];
167+
dsyevx_(&jobz, &range, &uplo, &n, mat, &n, &vl, &vu, &il, &iu, &abstol, &m, eig, z, &ldz,
168+
work, &lwork, rwork, iwork, ifail, &info);
169+
if (info) { std::cout << "ERROR: Lapack solver dsyevx, info=" << info << std::endl; }
170+
std::copy(z, z + ldz * n, mat);
171+
delete[] ifail;
172+
delete[] iwork;
173+
delete[] rwork;
174+
delete[] work;
175+
delete[] z;
176+
}
177+
178+
void diag_lapack_zheevx(const int& n, std::complex<double>* mat, double* eig)
179+
{
180+
ModuleBase::TITLE("LR_Util", "diag_lapack_zheevx<complex<double>>");
181+
int info = 0;
182+
char jobz = 'V', range = 'A', uplo = 'U';
183+
const double vl = 0.0, vu = 0.0, abstol = 0.0;
184+
const int il = 0, iu = 0;
185+
int m = 0;
186+
const int ldz = n;
187+
std::complex<double>* z = new std::complex<double>[ldz * n];
188+
const int lwork = std::max(2 * n, 1);
189+
const int lrwork = std::max(7 * n, 1);
190+
const int liwork = std::max(5 * n, 1);
191+
std::complex<double>* work = new std::complex<double>[lwork];
192+
double* rwork = new double[lrwork];
193+
int* iwork = new int[liwork];
194+
int* ifail = new int[n];
195+
zheevx_(&jobz, &range, &uplo, &n, mat, &n, &vl, &vu, &il, &iu, &abstol, &m, eig, z, &ldz,
196+
work, &lwork, rwork, iwork, ifail, &info);
197+
if (info) { std::cout << "ERROR: Lapack solver zheevx, info=" << info << std::endl; }
198+
std::copy(z, z + ldz * n, mat);
199+
delete[] ifail;
200+
delete[] iwork;
201+
delete[] rwork;
202+
delete[] work;
203+
delete[] z;
204+
}
205+
206+
// Fortran LAPACK entry points used by the diag_lapack_zheevr overloads.
// Every argument is passed by address, following the Fortran calling convention.
extern "C"
{
    // Real symmetric eigen-solver (dsyevr).
    void dsyevr_(const char* jobz, const char* range, const char* uplo,
                 const int* n, double* a, const int* lda,
                 const double* vl, const double* vu,
                 const int* il, const int* iu,
                 const double* abstol,
                 int* m, double* w,
                 double* z, const int* ldz, int* isuppz,
                 double* work, const int* lwork,
                 int* iwork, const int* liwork,
                 int* info);

    // Complex Hermitian eigen-solver (zheevr); takes an additional real workspace.
    void zheevr_(const char* jobz, const char* range, const char* uplo,
                 const int* n, std::complex<double>* a, const int* lda,
                 const double* vl, const double* vu,
                 const int* il, const int* iu,
                 const double* abstol,
                 int* m, double* w,
                 std::complex<double>* z, const int* ldz, int* isuppz,
                 std::complex<double>* work, const int* lwork,
                 double* rwork, const int* lrwork,
                 int* iwork, const int* liwork,
                 int* info);
}
216+
217+
void diag_lapack_zheevr(const int& n, double* mat, double* eig)
218+
{
219+
ModuleBase::TITLE("LR_Util", "diag_lapack_zheevr<double>");
220+
int info = 0;
221+
char jobz = 'V', range = 'A', uplo = 'U';
222+
const double vl = 0.0, vu = 0.0, abstol = 0.0;
223+
const int il = 0, iu = 0;
224+
int m = 0;
225+
const int ldz = n;
226+
double* z = new double[ldz * n];
227+
int* isuppz = new int[2 * n];
228+
const int lwork = std::max(26 * n, 1);
229+
const int liwork = std::max(10 * n, 1);
230+
double* work = new double[lwork];
231+
int* iwork = new int[liwork];
232+
dsyevr_(&jobz, &range, &uplo, &n, mat, &n, &vl, &vu, &il, &iu, &abstol, &m, eig, z, &ldz, isuppz,
233+
work, &lwork, iwork, &liwork, &info);
234+
if (info) { std::cout << "ERROR: Lapack solver dsyevr, info=" << info << std::endl; }
235+
std::copy(z, z + ldz * n, mat);
236+
delete[] iwork;
237+
delete[] work;
238+
delete[] isuppz;
239+
delete[] z;
240+
}
241+
242+
void diag_lapack_zheevr(const int& n, std::complex<double>* mat, double* eig)
243+
{
244+
ModuleBase::TITLE("LR_Util", "diag_lapack_zheevr<complex<double>>");
245+
int info = 0;
246+
char jobz = 'V', range = 'A', uplo = 'U';
247+
const double vl = 0.0, vu = 0.0, abstol = 0.0;
248+
const int il = 0, iu = 0;
249+
int m = 0;
250+
const int ldz = n;
251+
std::complex<double>* z = new std::complex<double>[ldz * n];
252+
int* isuppz = new int[2 * n];
253+
const int lwork = std::max(2 * n, 1);
254+
const int lrwork = std::max(24 * n, 1);
255+
const int liwork = std::max(10 * n, 1);
256+
std::complex<double>* work = new std::complex<double>[lwork];
257+
double* rwork = new double[lrwork];
258+
int* iwork = new int[liwork];
259+
zheevr_(&jobz, &range, &uplo, &n, mat, &n, &vl, &vu, &il, &iu, &abstol, &m, eig, z, &ldz, isuppz,
260+
work, &lwork, rwork, &lrwork, iwork, &liwork, &info);
261+
if (info) { std::cout << "ERROR: Lapack solver zheevr, info=" << info << std::endl; }
262+
std::copy(z, z + ldz * n, mat);
263+
delete[] iwork;
264+
delete[] rwork;
265+
delete[] work;
266+
delete[] isuppz;
267+
delete[] z;
268+
}
269+
270+
void diag_elpa(const int& n, double* mat, double* eig)
271+
{
272+
ModuleBase::TITLE("LR_Util", "diag_elpa<double>");
273+
#ifdef __MPI
274+
int ctxt = Csys2blacs_handle(MPI_COMM_WORLD);
275+
char layout = 'R';
276+
Cblacs_gridinit(&ctxt, &layout, 1, 1);
277+
int desc[9];
278+
int info = 0;
279+
const int m = n, nb = std::max(1, std::min(n, 128));
280+
const int rsrc = 0, csrc = 0;
281+
const int lda = n;
282+
descinit_(desc, &m, &m, &nb, &nb, &rsrc, &csrc, &ctxt, &lda, &info);
283+
if (info) { std::cout << "ERROR: descinit in diag_elpa<double>, info=" << info << std::endl; }
284+
ELPA_Solver es(true, MPI_COMM_WORLD, n, n, n, desc);
285+
std::vector<double> vec(n * n);
286+
es.eigenvector(mat, eig, vec.data());
287+
es.exit();
288+
std::copy(vec.begin(), vec.end(), mat);
289+
Cblacs_gridexit(&ctxt);
290+
#else
291+
ELPA_Solver es(true, MPI_COMM_WORLD, n, n, n, nullptr);
292+
std::vector<double> vec(n * n);
293+
es.eigenvector(mat, eig, vec.data());
294+
es.exit();
295+
std::copy(vec.begin(), vec.end(), mat);
296+
#endif
297+
}
298+
299+
void diag_elpa(const int& n, std::complex<double>* mat, double* eig)
300+
{
301+
ModuleBase::TITLE("LR_Util", "diag_elpa<complex<double>>");
302+
#ifdef __MPI
303+
int ctxt = Csys2blacs_handle(MPI_COMM_WORLD);
304+
char layout = 'R';
305+
Cblacs_gridinit(&ctxt, &layout, 1, 1);
306+
int desc[9];
307+
int info = 0;
308+
const int m = n, nb = std::max(1, std::min(n, 128));
309+
const int rsrc = 0, csrc = 0;
310+
const int lda = n;
311+
descinit_(desc, &m, &m, &nb, &nb, &rsrc, &csrc, &ctxt, &lda, &info);
312+
if (info) { std::cout << "ERROR: descinit in diag_elpa<complex<double>>, info=" << info << std::endl; }
313+
ELPA_Solver es(false, MPI_COMM_WORLD, n, n, n, desc);
314+
std::vector<std::complex<double>> vec(n * n);
315+
es.eigenvector(mat, eig, vec.data());
316+
es.exit();
317+
std::copy(vec.begin(), vec.end(), mat);
318+
Cblacs_gridexit(&ctxt);
319+
#else
320+
ELPA_Solver es(false, MPI_COMM_WORLD, n, n, n, nullptr);
321+
std::vector<std::complex<double>> vec(n * n);
322+
es.eigenvector(mat, eig, vec.data());
323+
es.exit();
324+
std::copy(vec.begin(), vec.end(), mat);
325+
#endif
326+
}
147327
void diag_lapack_nh(const int& n, double* mat, std::complex<double>* eig)
148328
{
149329
ModuleBase::TITLE("LR_Util", "diag_lapack_nh<double>");
@@ -192,4 +372,4 @@ namespace LR_Util
192372
std::transform(str_upper.begin(), str_upper.end(), str_upper.begin(), ::toupper);
193373
return str_upper;
194374
}
195-
}
375+
}

source/module_lr/utils/lr_util.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,14 @@ namespace LR_Util
9898

9999
///=================diago-lapack====================
100100
/// @brief diagonalize a hermitian matrix
101-
void diag_lapack(const int& n, double* mat, double* eig);
102-
void diag_lapack(const int& n, std::complex<double>* mat, double* eig);
101+
void diag_lapack_zheev(const int& n, double* mat, double* eig);
102+
void diag_lapack_zheev(const int& n, std::complex<double>* mat, double* eig);
103+
void diag_lapack_zheevx(const int& n, double* mat, double* eig);
104+
void diag_lapack_zheevx(const int& n, std::complex<double>* mat, double* eig);
105+
void diag_lapack_zheevr(const int& n, double* mat, double* eig);
106+
void diag_lapack_zheevr(const int& n, std::complex<double>* mat, double* eig);
107+
void diag_elpa(const int& n, double* mat, double* eig);
108+
void diag_elpa(const int& n, std::complex<double>* mat, double* eig);
103109
/// @brief diagonalize a general matrix
104110
void diag_lapack_nh(const int& n, double* mat, std::complex<double>* eig);
105111
void diag_lapack_nh(const int& n, std::complex<double>* mat, std::complex<double>* eig);

source/module_lr/utils/test/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,9 @@ AddTest(
1010
TARGET lr_util_algo_test
1111
LIBS parameter base ${math_libs} device psi container planewave #for FFT
1212
SOURCES lr_util_algorithms_test.cpp ../lr_util.cpp
13-
)
13+
)
14+
15+
if(USE_ELPA)
16+
target_link_libraries(lr_util_phys_test ELPA::ELPA genelpa)
17+
target_link_libraries(lr_util_algo_test ELPA::ELPA genelpa)
18+
endif()

0 commit comments

Comments
 (0)