branch: lapack-interface-simplifications commit b8f8a7f2548ef0de60b0ac5bb24b5e37b4ae30dc Author: Konstantinos Poulios <logar...@gmail.com> AuthorDate: Sun Dec 18 22:53:01 2022 +0100
Drop runtime handling of int64 lapack and fix lapack type at compile time --- src/gmm/gmm_dense_lu.h | 68 +++++++++++------------------------------- src/gmm/gmm_lapack_interface.h | 14 ++++----- src/gmm/gmm_opt.h | 2 +- 3 files changed, 24 insertions(+), 60 deletions(-) diff --git a/src/gmm/gmm_dense_lu.h b/src/gmm/gmm_dense_lu.h index 33fbbcb1..dd3470e2 100644 --- a/src/gmm/gmm_dense_lu.h +++ b/src/gmm/gmm_dense_lu.h @@ -73,47 +73,11 @@ namespace gmm { - /* ********************************************************************** */ - /* IPVT structure. */ - /* ********************************************************************** */ - // For compatibility with lapack version with 64 or 32 bit integer. - // Should be replaced by std::vector<size_type> if 32 bit integer version - // of lapack is not used anymore (and lapack_ipvt_int set to size_type) - - // Do not use iterators of this interface container - class lapack_ipvt : public std::vector<size_type> { - bool is_int64; - size_type &operator[](size_type i) - { return std::vector<size_type>::operator[](i); } - size_type operator[] (size_type i) const - { return std::vector<size_type>::operator[](i); } - void begin(void) const {} - void begin(void) {} - void end(void) const {} - void end(void) {} - - public: - void set_to_int32() { is_int64 = false; } - const size_type *pfirst() const - { return &(*(std::vector<size_type>::begin())); } - size_type *pfirst() { return &(*(std::vector<size_type>::begin())); } - - lapack_ipvt(size_type n) : std::vector<size_type>(n), is_int64(true) {} - - size_type get(size_type i) const { - const size_type *p = pfirst(); - return is_int64 ? p[i] : size_type(((const int *)(p))[i]); - } - void set(size_type i, size_type val) { - size_type *p = pfirst(); - if (is_int64) p[i] = val; else ((int *)(p))[i] = int(val); - } - }; -} - -#include "gmm_opt.h" - -namespace gmm { +#if defined(GMM_USES_LAPACK) || defined(GMM_USES_ATLAS) + typedef std::vector<BLAS_INT> lapack_ipvt; +#else + typedef std::vector<size_type> lapack_ipvt; +#endif /** LU Factorization of a general (dense) matrix (real or complex). @@ -125,23 +89,24 @@ namespace gmm { The pivot indices in ipvt are indexed starting from 1 so that this is compatible with LAPACK (Fortran). */ - template <typename DenseMatrix> - size_type lu_factor(DenseMatrix& A, lapack_ipvt& ipvt) { + template <typename DenseMatrix, typename Pvector> + size_type lu_factor(DenseMatrix& A, Pvector& ipvt) { typedef typename linalg_traits<DenseMatrix>::value_type T; + typedef typename linalg_traits<Pvector>::value_type INT; typedef typename number_traits<T>::magnitude_type R; size_type info(0), i, j, jp, M(mat_nrows(A)), N(mat_ncols(A)); size_type NN = std::min(M, N); std::vector<T> c(M), r(N); GMM_ASSERT2(ipvt.size()+1 >= NN, "IPVT too small"); - for (i = 0; i+1 < NN; ++i) ipvt.set(i, i); + for (i = 0; i+1 < NN; ++i) ipvt[i] = INT(i); if (M || N) { for (j = 0; j+1 < NN; ++j) { R max = gmm::abs(A(j,j)); jp = j; for (i = j+1; i < M; ++i) /* find pivot. */ if (gmm::abs(A(i,j)) > max) { jp = i; max = gmm::abs(A(i,j)); } - ipvt.set(j, jp + 1); + ipvt[j] = INT(jp + 1); if (max == R(0)) { info = j + 1; break; } if (jp != j) for (i = 0; i < N; ++i) std::swap(A(jp, i), A(j, i)); @@ -151,7 +116,7 @@ namespace gmm { rank_one_update(sub_matrix(A, sub_interval(j+1, M-j-1), sub_interval(j+1, N-j-1)), c, conjugated(r)); } - ipvt.set(NN-1, NN); + ipvt[NN-1] = INT(NN); } return info; } @@ -165,7 +130,7 @@ namespace gmm { typedef typename linalg_traits<DenseMatrix>::value_type T; copy(b, x); for(size_type i = 0; i < pvector.size(); ++i) { - size_type perm = pvector.get(i)-1; // permutations stored in 1's offset + size_type perm = size_type(pvector[i]-1); // permutations stored in 1's offset if(i != perm) { T aux = x[i]; x[i] = x[perm]; x[perm] = aux; } } /* solve Ax = b -> LUx = b -> Ux = L^-1 b. */ @@ -193,7 +158,7 @@ namespace gmm { lower_tri_solve(transposed(LU), x, false); upper_tri_solve(transposed(LU), x, true); for(size_type i = pvector.size(); i > 0; --i) { - size_type perm = pvector.get(i-1)-1; // permutations stored in 1's offset + size_type perm = size_type(pvector[i-1]-1); // permutations stored in 1's offset if(i-1 != perm) { T aux = x[i-1]; x[i-1] = x[perm]; x[perm] = aux; } } } @@ -263,11 +228,12 @@ namespace gmm { typename linalg_traits<DenseMatrixLU>::value_type lu_det(const DenseMatrixLU& LU, const Pvector &pvector) { typedef typename linalg_traits<DenseMatrixLU>::value_type T; + typedef typename linalg_traits<Pvector>::value_type INT; T det(1); for (size_type j = 0; j < std::min(mat_nrows(LU), mat_ncols(LU)); ++j) det *= LU(j,j); - for(size_type i = 0; i < pvector.size(); ++i) - if (i != size_type(pvector.get(i)-1)) { det = -det; } + for(INT i = 0; i < INT(pvector.size()); ++i) + if (i != pvector[i]-1) { det = -det; } return det; } @@ -284,5 +250,7 @@ namespace gmm { } +#include "gmm_opt.h" + #endif diff --git a/src/gmm/gmm_lapack_interface.h b/src/gmm/gmm_lapack_interface.h index fa38addf..e7f9a99e 100644 --- a/src/gmm/gmm_lapack_interface.h +++ b/src/gmm/gmm_lapack_interface.h @@ -106,12 +106,8 @@ namespace gmm { GMMLAPACK_TRACE("getrf_interface"); \ BLAS_INT m = BLAS_INT(mat_nrows(A)), n = BLAS_INT(mat_ncols(A)), lda(m); \ BLAS_INT info(-1); \ - if (m && n) lapack_name(&m, &n, &A(0,0), &lda, ipvt.pfirst(), &info); \ - if ((sizeof(BLAS_INT) == 4) || \ - ((info & 0xFFFFFFFF00000000L) && !(info & 0x00000000FFFFFFFFL))) \ - /* For compatibility with lapack version with 32 bit integer. */ \ - ipvt.set_to_int32(); \ - return size_type(int(info & 0x00000000FFFFFFFFL)); \ + if (m && n) lapack_name(&m, &n, &A(0,0), &lda, &ipvt[0], &info); \ + return size_type(int(info & 0x00000000FFFFFFFFL)); \ } getrf_interface(sgetrf_, BLAS_S) @@ -131,7 +127,7 @@ namespace gmm { BLAS_INT n = BLAS_INT(mat_nrows(A)), info(0), nrhs(1); \ gmm::copy(b, x); trans1; \ if (n) \ - lapack_name(&t,&n,&nrhs,&(A(0,0)),&n,ipvt.pfirst(),&x[0],&n,&info); \ + lapack_name(&t,&n,&nrhs,&(A(0,0)),&n,&ipvt[0],&x[0],&n,&info); \ } # define getrs_trans_n const char t = 'N' @@ -160,10 +156,10 @@ namespace gmm { base_type work1; \ if (n) { \ gmm::copy(LU, A); \ - lapack_name(&n, &A(0,0), &n, ipvt.pfirst(), &work1, &lwork, &info); \ + lapack_name(&n, &A(0,0), &n, &ipvt[0], &work1, &lwork, &info); \ lwork = int(gmm::real(work1)); \ std::vector<base_type> work(lwork); \ - lapack_name(&n, &A(0,0), &n, ipvt.pfirst(), &work[0], &lwork,&info); \ + lapack_name(&n, &A(0,0), &n, &ipvt[0], &work[0], &lwork, &info); \ } \ } diff --git a/src/gmm/gmm_opt.h b/src/gmm/gmm_opt.h index 7be45c44..1a4b1f14 100644 --- a/src/gmm/gmm_opt.h +++ b/src/gmm/gmm_opt.h @@ -37,7 +37,7 @@ #ifndef GMM_OPT_H__ #define GMM_OPT_H__ -#include <gmm/gmm_dense_lu.h> +#include "gmm_dense_lu.h" namespace gmm {