36 #if defined(GMM_USES_SUPERLU)
38 #ifndef GMM_SUPERLU_INTERFACE_H
39 #define GMM_SUPERLU_INTERFACE_H
58 #if defined(GMM_NO_SUPERLU_INCLUDE_SUBDIR)
59 #include "slu_Cnames.h"
60 #include "supermatrix.h"
62 #include "slu_scomplex.h"
63 #include "slu_dcomplex.h"
65 #include "superlu/slu_Cnames.h"
66 #include "superlu/supermatrix.h"
67 #include "superlu/slu_util.h"
68 #include "superlu/slu_scomplex.h"
69 #include "superlu/slu_dcomplex.h"
72 #if (SUPERLU_MAJOR_VERSION <= 6)
73 # define singlecomplex complex
86 sgssv(superlu_options_t *, SuperMatrix *,
int *,
int *, SuperMatrix *,
87 SuperMatrix *, SuperMatrix *, SuperLUStat_t *, int_t *info);
89 dgssv(superlu_options_t *, SuperMatrix *,
int *,
int *, SuperMatrix *,
90 SuperMatrix *, SuperMatrix *, SuperLUStat_t *, int_t *info);
92 cgssv(superlu_options_t *, SuperMatrix *,
int *,
int *, SuperMatrix *,
93 SuperMatrix *, SuperMatrix *, SuperLUStat_t *, int_t *info);
95 zgssv(superlu_options_t *, SuperMatrix *,
int *,
int *, SuperMatrix *,
96 SuperMatrix *, SuperMatrix *, SuperLUStat_t *, int_t *info);
98 sgssvx(superlu_options_t *, SuperMatrix *,
int *,
int *,
int *,
99 char *,
float *,
float *, SuperMatrix *, SuperMatrix *,
100 void *, int_t lwork, SuperMatrix *, SuperMatrix *,
101 float *,
float *,
float *,
float *,
102 GlobalLU_t *, mem_usage_t *, SuperLUStat_t *, int_t *info);
104 dgssvx(superlu_options_t *, SuperMatrix *,
int *,
int *,
int *,
105 char *,
double *,
double *, SuperMatrix *, SuperMatrix *,
106 void *, int_t lwork, SuperMatrix *, SuperMatrix *,
107 double *,
double *,
double *,
double *,
108 GlobalLU_t *, mem_usage_t *, SuperLUStat_t *, int_t *info);
110 cgssvx(superlu_options_t *, SuperMatrix *,
int *,
int *,
int *,
111 char *,
float *,
float *, SuperMatrix *, SuperMatrix *,
112 void *, int_t lwork, SuperMatrix *, SuperMatrix *,
113 float *,
float *,
float *,
float *,
114 GlobalLU_t *, mem_usage_t *, SuperLUStat_t *, int_t *info);
116 zgssvx(superlu_options_t *, SuperMatrix *,
int *,
int *,
int *,
117 char *,
double *,
double *, SuperMatrix *, SuperMatrix *,
118 void *, int_t lwork, SuperMatrix *, SuperMatrix *,
119 double *,
double *,
double *,
double *,
120 GlobalLU_t *, mem_usage_t *, SuperLUStat_t *, int_t *info);
122 sCreate_CompCol_Matrix(SuperMatrix *,
int,
int, int_t,
float *,
123 int_t *, int_t *, Stype_t, Dtype_t, Mtype_t);
125 dCreate_CompCol_Matrix(SuperMatrix *,
int,
int, int_t,
double *,
126 int_t *, int_t *, Stype_t, Dtype_t, Mtype_t);
128 cCreate_CompCol_Matrix(SuperMatrix *,
int,
int, int_t, singlecomplex *,
129 int_t *, int_t *, Stype_t, Dtype_t, Mtype_t);
131 zCreate_CompCol_Matrix(SuperMatrix *,
int,
int, int_t, doublecomplex *,
132 int_t *, int_t *, Stype_t, Dtype_t, Mtype_t);
134 sCreate_Dense_Matrix(SuperMatrix *,
int,
int,
float *,
int,
135 Stype_t, Dtype_t, Mtype_t);
137 dCreate_Dense_Matrix(SuperMatrix *,
int,
int,
double *,
int,
138 Stype_t, Dtype_t, Mtype_t);
140 cCreate_Dense_Matrix(SuperMatrix *,
int,
int, singlecomplex *,
int,
141 Stype_t, Dtype_t, Mtype_t);
143 zCreate_Dense_Matrix(SuperMatrix *,
int,
int, doublecomplex *,
int,
144 Stype_t, Dtype_t, Mtype_t);
152 inline void Create_CompCol_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
153 int nnz,
float *a,
int *ir,
int *jc) {
154 SuperLU::sCreate_CompCol_Matrix(A, m, n, nnz, a, ir, jc,
155 SuperLU::SLU_NC, SuperLU::SLU_S,
159 inline void Create_CompCol_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
160 int nnz,
double *a,
int *ir,
int *jc) {
161 SuperLU::dCreate_CompCol_Matrix(A, m, n, nnz, a, ir, jc,
162 SuperLU::SLU_NC, SuperLU::SLU_D,
166 inline void Create_CompCol_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
167 int nnz, std::complex<float> *a,
169 SuperLU::cCreate_CompCol_Matrix(A, m, n, nnz,
170 (SuperLU::singlecomplex *)(a),
171 ir, jc, SuperLU::SLU_NC, SuperLU::SLU_C,
175 inline void Create_CompCol_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
176 int nnz, std::complex<double> *a,
178 SuperLU::zCreate_CompCol_Matrix(A, m, n, nnz,
179 (SuperLU::doublecomplex *)(a), ir, jc,
180 SuperLU::SLU_NC, SuperLU::SLU_Z,
186 inline void Create_Dense_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
188 SuperLU::sCreate_Dense_Matrix(A, m, n, a, k, SuperLU::SLU_DN,
192 inline void Create_Dense_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
194 SuperLU::dCreate_Dense_Matrix(A, m, n, a, k, SuperLU::SLU_DN,
198 inline void Create_Dense_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
199 std::complex<float> *a,
int k) {
200 SuperLU::cCreate_Dense_Matrix(A, m, n,
201 (SuperLU::singlecomplex *)(a),
202 k, SuperLU::SLU_DN, SuperLU::SLU_C,
205 inline void Create_Dense_Matrix(SuperLU::SuperMatrix *A,
int m,
int n,
206 std::complex<double> *a,
int k) {
207 SuperLU::zCreate_Dense_Matrix(A, m, n, (SuperLU::doublecomplex *)(a),
208 k, SuperLU::SLU_DN, SuperLU::SLU_Z,
214 #define DECL_GSSV(FNAME,KEYTYPE) \
215 inline void SuperLU_gssv(SuperLU::superlu_options_t *options, \
216 SuperLU::SuperMatrix *A, int *p, int *q, \
217 SuperLU::SuperMatrix *L, \
218 SuperLU::SuperMatrix *U, \
219 SuperLU::SuperMatrix *B, \
220 SuperLU::SuperLUStat_t *stats, \
221 int *info, KEYTYPE) { \
222 SuperLU::FNAME(options, A, p, q, L, U, B, stats, info); \
// Instantiate one SuperLU_gssv overload per scalar type (float,
// std::complex<float>, double, std::complex<double>). Each expansion forwards
// all of its arguments to the SuperLU routine named by the first macro
// argument; the second macro argument becomes a trailing unnamed parameter
// whose only role is to select the overload.
225 DECL_GSSV(sgssv,
float)
226 DECL_GSSV(cgssv, std::complex<float>)
227 DECL_GSSV(dgssv,
double)
228 DECL_GSSV(zgssv, std::complex<double>)
232 #define DECL_GSSVX(FNAME,FLOATTYPE,KEYTYPE) \
233 inline float SuperLU_gssvx(SuperLU::superlu_options_t *options, \
234 SuperLU::SuperMatrix *A, \
235 int *perm_c, int *perm_r, int *etree, \
237 FLOATTYPE *R, FLOATTYPE *C, \
238 SuperLU::SuperMatrix *L, \
239 SuperLU::SuperMatrix *U, \
240 void *work, int lwork, \
241 SuperLU::SuperMatrix *B, \
242 SuperLU::SuperMatrix *X, \
243 FLOATTYPE *recip_pivot_growth, \
244 FLOATTYPE *rcond, FLOATTYPE *ferr, \
246 SuperLU::SuperLUStat_t *stats, \
247 int *info, KEYTYPE) { \
248 SuperLU::mem_usage_t mem_usage; \
249 SuperLU::GlobalLU_t Glu; \
250 SuperLU::FNAME(options, A, perm_c, perm_r, etree, equed, R, C, L, \
251 U, work, lwork, B, X, recip_pivot_growth, rcond, \
252 ferr, berr, &Glu, &mem_usage, stats, info); \
253 return mem_usage.for_lu; \
// Instantiate one SuperLU_gssvx overload per scalar type. Macro arguments:
//   1. the SuperLU routine the overload forwards to,
//   2. the real type used for the R, C, recip_pivot_growth, rcond and ferr
//      pointer parameters,
//   3. the scalar type used as a trailing unnamed overload-selecting parameter.
// Every generated overload returns mem_usage.for_lu as a float.
256 DECL_GSSVX(sgssvx,
float,
float)
257 DECL_GSSVX(cgssvx,
float, std::complex<float>)
258 DECL_GSSVX(dgssvx,
double,
double)
259 DECL_GSSVX(zgssvx,
double, std::complex<double>)
265 template <
typename MAT,
typename VECTX,
typename VECTB>
266 int SuperLU_solve(
const MAT &A,
const VECTX &X,
const VECTB &B,
267 double& rcond_,
int permc_spec = 3) {
275 typedef typename linalg_traits<MAT>::value_type T;
276 typedef typename number_traits<T>::magnitude_type R;
278 int m = int(mat_nrows(A)), n = int(mat_ncols(A)), nrhs = 1, info = 0;
280 csc_matrix<T> csc_A(m, n);
282 std::vector<T> rhs(m), sol(m);
285 int nz = int(
nnz(csc_A));
286 if ((2 * nz / n) >= m)
287 GMM_WARNING2(
"CAUTION : it seems that SuperLU has a problem"
288 " for nearly dense sparse matrices");
290 SuperLU::superlu_options_t options;
291 set_default_options(&options);
292 options.ColPerm = SuperLU::NATURAL;
293 options.PrintStat = SuperLU::NO;
294 options.ConditionNumber = SuperLU::YES;
295 switch (permc_spec) {
296 case 1 : options.ColPerm = SuperLU::MMD_ATA;
break;
297 case 2 : options.ColPerm = SuperLU::MMD_AT_PLUS_A;
break;
298 case 3 : options.ColPerm = SuperLU::COLAMD;
break;
300 SuperLU::SuperLUStat_t stat;
303 SuperLU::SuperMatrix SA, SL, SU, SB, SX;
304 Create_CompCol_Matrix(&SA, m, n, nz, (T *)(&csc_A.pr[0]),
305 (
int *)(&csc_A.ir[0]),
306 (
int *)(&csc_A.jc[0]));
307 Create_Dense_Matrix(&SB, m, nrhs, &rhs[0], m);
308 Create_Dense_Matrix(&SX, m, nrhs, &sol[0], m);
309 memset(&SL,0,
sizeof SL);
310 memset(&SU,0,
sizeof SU);
312 std::vector<int> etree(n);
314 std::vector<R> Rscale(m),Cscale(n);
315 std::vector<R> ferr(nrhs), berr(nrhs);
316 R recip_pivot_gross, rcond;
317 std::vector<int> perm_r(m), perm_c(n);
319 SuperLU_gssvx(&options, &SA, &perm_c[0], &perm_r[0],
335 if (SB.Store) Destroy_SuperMatrix_Store(&SB);
336 if (SX.Store) Destroy_SuperMatrix_Store(&SX);
337 if (SA.Store) Destroy_SuperMatrix_Store(&SA);
338 if (SL.Store) Destroy_SuperNode_Matrix(&SL);
339 if (SU.Store) Destroy_CompCol_Matrix(&SU);
341 GMM_ASSERT1(info != -333333333,
"SuperLU was cancelled.");
343 GMM_ASSERT1(info >= 0,
"SuperLU solve failed: info =" << info);
344 if (info > 0) GMM_WARNING1(
"SuperLU solve failed: info =" << info);
350 class SuperLU_factor {
351 typedef typename number_traits<T>::magnitude_type R;
354 mutable SuperLU::SuperMatrix SA, SL, SB, SU, SX;
355 mutable SuperLU::SuperLUStat_t stat;
356 mutable SuperLU::superlu_options_t options;
358 mutable std::vector<int> etree, perm_r, perm_c;
359 mutable std::vector<R> Rscale, Cscale;
360 mutable std::vector<R> ferr, berr;
361 mutable std::vector<T> rhs;
362 mutable std::vector<T> sol;
363 mutable bool is_init;
367 enum { LU_NOTRANSP, LU_TRANSP, LU_CONJUGATED };
368 void free_supermatrix() {
370 if (SB.Store) Destroy_SuperMatrix_Store(&SB);
371 if (SX.Store) Destroy_SuperMatrix_Store(&SX);
372 if (SA.Store) Destroy_SuperMatrix_Store(&SA);
373 if (SL.Store) Destroy_SuperNode_Matrix(&SL);
374 if (SU.Store) Destroy_CompCol_Matrix(&SU);
377 template <
class MAT>
void build_with(
const MAT &A,
int permc_spec = 3);
378 template <
typename VECTX,
typename VECTB>
382 void solve(
const VECTX &X_,
const VECTB &B,
int transp=LU_NOTRANSP)
const;
383 SuperLU_factor() { is_init =
false; }
384 SuperLU_factor(
const SuperLU_factor& other) {
385 GMM_ASSERT2(!(other.is_init),
386 "copy of initialized SuperLU_factor is forbidden");
389 SuperLU_factor& operator=(
const SuperLU_factor& other) {
390 GMM_ASSERT2(!(other.is_init) && !is_init,
391 "assignment of initialized SuperLU_factor is forbidden");
394 ~SuperLU_factor() { free_supermatrix(); }
395 float memsize() {
return memory_used; }
399 template <
class T>
template <
class MAT>
400 void SuperLU_factor<T>::build_with(
const MAT &A,
int permc_spec) {
409 int n = int(mat_nrows(A)), m = int(mat_ncols(A)), info = 0;
412 rhs.resize(m); sol.resize(m);
414 int nz = int(
nnz(csc_A));
416 set_default_options(&options);
417 options.ColPerm = SuperLU::NATURAL;
418 options.PrintStat = SuperLU::NO;
419 options.ConditionNumber = SuperLU::NO;
420 switch (permc_spec) {
421 case 1 : options.ColPerm = SuperLU::MMD_ATA;
break;
422 case 2 : options.ColPerm = SuperLU::MMD_AT_PLUS_A;
break;
423 case 3 : options.ColPerm = SuperLU::COLAMD;
break;
427 Create_CompCol_Matrix(&SA, m, n, nz, (T *)(&csc_A.pr[0]),
428 (
int *)(&csc_A.ir[0]),
429 (
int *)(&csc_A.jc[0]));
430 Create_Dense_Matrix(&SB, m, 0, &rhs[0], m);
431 Create_Dense_Matrix(&SX, m, 0, &sol[0], m);
432 memset(&SL,0,
sizeof SL);
433 memset(&SU,0,
sizeof SU);
435 Rscale.resize(m); Cscale.resize(n); etree.resize(n);
436 ferr.resize(1); berr.resize(1);
437 R recip_pivot_gross, rcond;
438 perm_r.resize(m); perm_c.resize(n);
439 memory_used = SuperLU_gssvx(&options, &SA, &perm_c[0], &perm_r[0],
456 Destroy_SuperMatrix_Store(&SB);
457 Destroy_SuperMatrix_Store(&SX);
458 Create_Dense_Matrix(&SB, m, 1, &rhs[0], m);
459 Create_Dense_Matrix(&SX, m, 1, &sol[0], m);
462 GMM_ASSERT1(info != -333333333,
"SuperLU was cancelled.");
463 GMM_ASSERT1(info == 0,
"SuperLU solve failed: info=" << info);
467 template <
class T>
template <
typename VECTX,
typename VECTB>
468 void SuperLU_factor<T>::solve(
const VECTX &X,
const VECTB &B,
471 options.Fact = SuperLU::FACTORED;
472 options.IterRefine = SuperLU::NOREFINE;
474 case LU_NOTRANSP: options.Trans = SuperLU::NOTRANS;
break;
475 case LU_TRANSP: options.Trans = SuperLU::TRANS;
break;
476 case LU_CONJUGATED: options.Trans = SuperLU::CONJ;
break;
477 default: GMM_ASSERT1(
false,
"invalid value for transposition option");
481 R recip_pivot_gross, rcond;
482 SuperLU_gssvx(&options, &SA, &perm_c[0], &perm_r[0],
498 GMM_ASSERT1(info == 0,
"SuperLU solve failed: info=" << info);
502 template <
typename T,
typename V1,
typename V2>
inline
503 void mult(
const SuperLU_factor<T>& P,
const V1 &v1,
const V2 &v2) {
507 template <
typename T,
typename V1,
typename V2>
inline
508 void transposed_mult(
const SuperLU_factor<T>& P,
const V1 &v1,
const V2 &v2) {
509 P.solve(v2, v1, SuperLU_factor<T>::LU_TRANSP);
size_type nnz(const L &l)
Count the number of non-zero entries of a vector or matrix.
void copy(const L1 &l1, L2 &l2)
*/
void clear(L &l)
Clear (fill with zeros) a vector or matrix.
void mult(const L1 &l1, const L2 &l2, L3 &l3)
*/
Include the base gmm files.