22#include " lr_util.h"
33#include " module_base/lapack_connector.h"
44#include " module_base/scalapack_connector.h"
5+ #include " module_base/blacs_connector.h"
6+ #include " module_hsolver/genelpa/elpa_solver.h"
7+ #include " module_base/module_container/base/third_party/lapack.h"
58namespace LR_Util
69{
710 // / =================PHYSICS====================
@@ -115,9 +118,9 @@ namespace LR_Util
115118 }
116119#endif
117120
118- void diag_lapack (const int & n, double * mat, double * eig)
121+ void diag_lapack_zheev (const int & n, double * mat, double * eig)
119122 {
120- ModuleBase::TITLE (" LR_Util" , " diag_lapack <double>" );
123+ ModuleBase::TITLE (" LR_Util" , " diag_lapack_zheev <double>" );
121124 int info = 0 ;
122125 char jobz = ' V' , uplo = ' U' ;
123126 double work_tmp;
@@ -130,9 +133,9 @@ namespace LR_Util
130133 delete[] work2;
131134 }
132135
133- void diag_lapack (const int & n, std::complex <double >* mat, double * eig)
136+ void diag_lapack_zheev (const int & n, std::complex <double >* mat, double * eig)
134137 {
135- ModuleBase::TITLE (" LR_Util" , " diag_lapack <complex<double>>" );
138+ ModuleBase::TITLE (" LR_Util" , " diag_lapack_zheev <complex<double>>" );
136139 int lwork = 2 * n;
137140 std::complex <double >* work2 = new std::complex <double >[lwork];
138141 double * rwork = new double [3 * n - 2 ];
@@ -144,6 +147,183 @@ namespace LR_Util
144147 delete[] work2;
145148 }
146149
150+ void diag_lapack_zheevx (const int & n, double * mat, double * eig)
151+ {
152+ ModuleBase::TITLE (" LR_Util" , " diag_lapack_zheevx<double>" );
153+ int info = 0 ;
154+ char jobz = ' V' , range = ' A' , uplo = ' U' ;
155+ const double vl = 0.0 , vu = 0.0 , abstol = 0.0 ;
156+ const int il = 0 , iu = 0 ;
157+ int m = 0 ;
158+ const int ldz = n;
159+ double * z = new double [ldz * n];
160+ const int lwork = std::max (8 * n, 1 );
161+ const int lrwork = std::max (7 * n, 1 );
162+ const int liwork = std::max (5 * n, 1 );
163+ double * work = new double [lwork];
164+ double * rwork = new double [lrwork];
165+ int * iwork = new int [liwork];
166+ int * ifail = new int [n];
167+ dsyevx_ (&jobz, &range, &uplo, &n, mat, &n, &vl, &vu, &il, &iu, &abstol, &m, eig, z, &ldz,
168+ work, &lwork, rwork, iwork, ifail, &info);
169+ if (info) { std::cout << " ERROR: Lapack solver dsyevx, info=" << info << std::endl; }
170+ std::copy (z, z + ldz * n, mat);
171+ delete[] ifail;
172+ delete[] iwork;
173+ delete[] rwork;
174+ delete[] work;
175+ delete[] z;
176+ }
177+
178+ void diag_lapack_zheevx (const int & n, std::complex <double >* mat, double * eig)
179+ {
180+ ModuleBase::TITLE (" LR_Util" , " diag_lapack_zheevx<complex<double>>" );
181+ int info = 0 ;
182+ char jobz = ' V' , range = ' A' , uplo = ' U' ;
183+ const double vl = 0.0 , vu = 0.0 , abstol = 0.0 ;
184+ const int il = 0 , iu = 0 ;
185+ int m = 0 ;
186+ const int ldz = n;
187+ std::complex <double >* z = new std::complex <double >[ldz * n];
188+ const int lwork = std::max (2 * n, 1 );
189+ const int lrwork = std::max (7 * n, 1 );
190+ const int liwork = std::max (5 * n, 1 );
191+ std::complex <double >* work = new std::complex <double >[lwork];
192+ double * rwork = new double [lrwork];
193+ int * iwork = new int [liwork];
194+ int * ifail = new int [n];
195+ zheevx_ (&jobz, &range, &uplo, &n, mat, &n, &vl, &vu, &il, &iu, &abstol, &m, eig, z, &ldz,
196+ work, &lwork, rwork, iwork, ifail, &info);
197+ if (info) { std::cout << " ERROR: Lapack solver zheevx, info=" << info << std::endl; }
198+ std::copy (z, z + ldz * n, mat);
199+ delete[] ifail;
200+ delete[] iwork;
201+ delete[] rwork;
202+ delete[] work;
203+ delete[] z;
204+ }
205+
206+ extern " C" {
207+ void dsyevr_ (const char * jobz, const char * range, const char * uplo, const int * n,
208+ double * a, const int * lda, const double * vl, const double * vu, const int * il, const int * iu,
209+ const double * abstol, int * m, double * w, double * z, const int * ldz, int * isuppz,
210+ double * work, const int * lwork, int * iwork, const int * liwork, int * info);
211+ void zheevr_ (const char * jobz, const char * range, const char * uplo, const int * n,
212+ std::complex <double >* a, const int * lda, const double * vl, const double * vu, const int * il, const int * iu,
213+ const double * abstol, int * m, double * w, std::complex <double >* z, const int * ldz, int * isuppz,
214+ std::complex <double >* work, const int * lwork, double * rwork, const int * lrwork, int * iwork, const int * liwork, int * info);
215+ }
216+
217+ void diag_lapack_zheevr (const int & n, double * mat, double * eig)
218+ {
219+ ModuleBase::TITLE (" LR_Util" , " diag_lapack_zheevr<double>" );
220+ int info = 0 ;
221+ char jobz = ' V' , range = ' A' , uplo = ' U' ;
222+ const double vl = 0.0 , vu = 0.0 , abstol = 0.0 ;
223+ const int il = 0 , iu = 0 ;
224+ int m = 0 ;
225+ const int ldz = n;
226+ double * z = new double [ldz * n];
227+ int * isuppz = new int [2 * n];
228+ const int lwork = std::max (26 * n, 1 );
229+ const int liwork = std::max (10 * n, 1 );
230+ double * work = new double [lwork];
231+ int * iwork = new int [liwork];
232+ dsyevr_ (&jobz, &range, &uplo, &n, mat, &n, &vl, &vu, &il, &iu, &abstol, &m, eig, z, &ldz, isuppz,
233+ work, &lwork, iwork, &liwork, &info);
234+ if (info) { std::cout << " ERROR: Lapack solver dsyevr, info=" << info << std::endl; }
235+ std::copy (z, z + ldz * n, mat);
236+ delete[] iwork;
237+ delete[] work;
238+ delete[] isuppz;
239+ delete[] z;
240+ }
241+
242+ void diag_lapack_zheevr (const int & n, std::complex <double >* mat, double * eig)
243+ {
244+ ModuleBase::TITLE (" LR_Util" , " diag_lapack_zheevr<complex<double>>" );
245+ int info = 0 ;
246+ char jobz = ' V' , range = ' A' , uplo = ' U' ;
247+ const double vl = 0.0 , vu = 0.0 , abstol = 0.0 ;
248+ const int il = 0 , iu = 0 ;
249+ int m = 0 ;
250+ const int ldz = n;
251+ std::complex <double >* z = new std::complex <double >[ldz * n];
252+ int * isuppz = new int [2 * n];
253+ const int lwork = std::max (2 * n, 1 );
254+ const int lrwork = std::max (24 * n, 1 );
255+ const int liwork = std::max (10 * n, 1 );
256+ std::complex <double >* work = new std::complex <double >[lwork];
257+ double * rwork = new double [lrwork];
258+ int * iwork = new int [liwork];
259+ zheevr_ (&jobz, &range, &uplo, &n, mat, &n, &vl, &vu, &il, &iu, &abstol, &m, eig, z, &ldz, isuppz,
260+ work, &lwork, rwork, &lrwork, iwork, &liwork, &info);
261+ if (info) { std::cout << " ERROR: Lapack solver zheevr, info=" << info << std::endl; }
262+ std::copy (z, z + ldz * n, mat);
263+ delete[] iwork;
264+ delete[] rwork;
265+ delete[] work;
266+ delete[] isuppz;
267+ delete[] z;
268+ }
269+
270+ void diag_elpa (const int & n, double * mat, double * eig)
271+ {
272+ ModuleBase::TITLE (" LR_Util" , " diag_elpa<double>" );
273+ #ifdef __MPI
274+ int ctxt = Csys2blacs_handle (MPI_COMM_WORLD);
275+ char layout = ' R' ;
276+ Cblacs_gridinit (&ctxt, &layout, 1 , 1 );
277+ int desc[9 ];
278+ int info = 0 ;
279+ const int m = n, nb = std::max (1 , std::min (n, 128 ));
280+ const int rsrc = 0 , csrc = 0 ;
281+ const int lda = n;
282+ descinit_ (desc, &m, &m, &nb, &nb, &rsrc, &csrc, &ctxt, &lda, &info);
283+ if (info) { std::cout << " ERROR: descinit in diag_elpa<double>, info=" << info << std::endl; }
284+ ELPA_Solver es (true , MPI_COMM_WORLD, n, n, n, desc);
285+ std::vector<double > vec (n * n);
286+ es.eigenvector (mat, eig, vec.data ());
287+ es.exit ();
288+ std::copy (vec.begin (), vec.end (), mat);
289+ Cblacs_gridexit (&ctxt);
290+ #else
291+ ELPA_Solver es (true , MPI_COMM_WORLD, n, n, n, nullptr );
292+ std::vector<double > vec (n * n);
293+ es.eigenvector (mat, eig, vec.data ());
294+ es.exit ();
295+ std::copy (vec.begin (), vec.end (), mat);
296+ #endif
297+ }
298+
299+ void diag_elpa (const int & n, std::complex <double >* mat, double * eig)
300+ {
301+ ModuleBase::TITLE (" LR_Util" , " diag_elpa<complex<double>>" );
302+ #ifdef __MPI
303+ int ctxt = Csys2blacs_handle (MPI_COMM_WORLD);
304+ char layout = ' R' ;
305+ Cblacs_gridinit (&ctxt, &layout, 1 , 1 );
306+ int desc[9 ];
307+ int info = 0 ;
308+ const int m = n, nb = std::max (1 , std::min (n, 128 ));
309+ const int rsrc = 0 , csrc = 0 ;
310+ const int lda = n;
311+ descinit_ (desc, &m, &m, &nb, &nb, &rsrc, &csrc, &ctxt, &lda, &info);
312+ if (info) { std::cout << " ERROR: descinit in diag_elpa<complex<double>>, info=" << info << std::endl; }
313+ ELPA_Solver es (false , MPI_COMM_WORLD, n, n, n, desc);
314+ std::vector<std::complex <double >> vec (n * n);
315+ es.eigenvector (mat, eig, vec.data ());
316+ es.exit ();
317+ std::copy (vec.begin (), vec.end (), mat);
318+ Cblacs_gridexit (&ctxt);
319+ #else
320+ ELPA_Solver es (false , MPI_COMM_WORLD, n, n, n, nullptr );
321+ std::vector<std::complex <double >> vec (n * n);
322+ es.eigenvector (mat, eig, vec.data ());
323+ es.exit ();
324+ std::copy (vec.begin (), vec.end (), mat);
325+ #endif
326+ }
147327 void diag_lapack_nh (const int & n, double * mat, std::complex <double >* eig)
148328 {
149329 ModuleBase::TITLE (" LR_Util" , " diag_lapack_nh<double>" );
@@ -192,4 +372,4 @@ namespace LR_Util
192372 std::transform (str_upper.begin (), str_upper.end (), str_upper.begin (), ::toupper);
193373 return str_upper;
194374 }
195- }
375+ }
0 commit comments