branch: improve-expm-performance commit ee139e697b50db9198ef8257e2c492202b037a35 Author: Konstantinos Poulios <logar...@gmail.com> AuthorDate: Fri Oct 20 10:59:42 2023 +0200
Use a Pade approximation for expm, ported from Eigen/Unsupported --- src/getfem_plasticity.cc | 472 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 386 insertions(+), 86 deletions(-) diff --git a/src/getfem_plasticity.cc b/src/getfem_plasticity.cc index 5a0ec9c0..69ca66b2 100644 --- a/src/getfem_plasticity.cc +++ b/src/getfem_plasticity.cc @@ -46,109 +46,409 @@ namespace getfem { { mi.resize(2); mi[0] = mi[1] = N; } - bool expm(const base_matrix &a_, base_matrix &aexp, scalar_type tol=1e-15) { + inline void matmul(base_matrix &aa,base_matrix &bb,base_matrix &cc) + {gmm::mult(aa,bb,cc);} - const size_type itmax = 40; - base_matrix a(a_); - // scale input matrix a - int e; - frexp(gmm::mat_norminf(a), &e); - e = std::max(0, std::min(1023, e)); - gmm::scale(a, pow(scalar_type(2),-scalar_type(e))); - - base_matrix atmp(a), an(a); - gmm::copy(a, aexp); - gmm::add(gmm::identity_matrix(), aexp); - scalar_type factn(1); + bool expm(const base_matrix &a_, base_matrix &aexp) { + + const size_type N = gmm::mat_nrows(a_); bool success(false); - for (size_type n=2; n < itmax; ++n) { - factn /= scalar_type(n); - gmm::mult(an, a, atmp); - gmm::copy(atmp, an); - gmm::scale(atmp, factn); - gmm::add(atmp, aexp); - if (gmm::mat_euclidean_norm(atmp) < tol) { - success = true; - break; - } + + // Pade approximation ported from Eigen/Unsupported + base_matrix a(a_); + gmm::clear(aexp.as_vector()); + base_matrix tmp(aexp), v(aexp), u(aexp); // Pade approximant is (v+u)/(v-u) + const scalar_type l1norm = gmm::mat_norminf(a_); + int e = 0; // squarings + if (l1norm < 1.495585217958292e-002) { // matrix_exp_pade3(a, u, v) + const static std::array<scalar_type,4> b{120,60,12,1}; + base_matrix a2(a); + matmul(a, a, a2); + gmm::copy(gmm::scaled(a2,b[2]), v); // v = b2*A2 + b0*I + gmm::copy(gmm::scaled(a2,b[3]), u); // u = b3*A2 + b1*I + for (size_type ij=0; ij < N; ++ij) + { v(ij,ij) += b[0]; u(ij,ij) += b[1]; } + } else if (l1norm < 2.539398330063230e-001) { // matrix_exp_pade5(a, u, v) + const static std::array<scalar_type,6> b{30240,15120,3360,420,30,1}; + base_matrix a2(a), a4(a); + matmul(a, a, a2); + matmul(a2, a2, a4); + gmm::add(gmm::scaled(a4,b[4]), // v = b4*A4 + b2*A2 + b0*I + gmm::scaled(a2,b[2]), v); + gmm::add(gmm::scaled(a4,b[5]), // u = b5*A4 + b3*A2 + b1*I + gmm::scaled(a2,b[3]), u); + for (size_type ij=0; ij < N; ++ij) + { v(ij,ij) += b[0]; u(ij,ij) += b[1]; } + } else if (l1norm < 9.504178996162932e-001) { // matrix_exp_pade7(a, u, v) + const static std::array<scalar_type,8> + b{17297280,8648640,1995840,277200,25200,1512,56,1}; + base_matrix a2(a), a4(a), a6(a); + matmul(a, a, a2); + matmul(a2, a2, a4); + matmul(a2, a4, a6); + gmm::add(gmm::scaled(a6,b[6]), // v = b6*A6 + b4*A4 + b2*A2 + b0*I + gmm::scaled(a4,b[4]), v); + gmm::add(gmm::scaled(a2,b[2]), v); + gmm::add(gmm::scaled(a6,b[7]), // u = b7*A6 + b5*A4 + b3*A2 + b1*I + gmm::scaled(a4,b[5]), u); + gmm::add(gmm::scaled(a2,b[3]), u); + for (size_type ij=0; ij < N; ++ij) + { v(ij,ij) += b[0]; u(ij,ij) += b[1]; } + } else if (l1norm < 2.097847961257068e+000) { // matrix_exp_pade9(a, u, v) + const static std::array<scalar_type,10> + b{17643225600,8821612800,2075673600,302702400,30270240,2162160, + 110880,3960,90,1}; + base_matrix a2(a), a4(a), a6(a), a8(a); + matmul(a, a, a2); + matmul(a2, a2, a4); + matmul(a2, a4, a6); + matmul(a4, a4, a8); + gmm::add(gmm::scaled(a8,b[8]), // v = b8*A8+b6*A6+b4*A4+b2*A2+b0*I + gmm::scaled(a6,b[6]), v); + gmm::add(gmm::scaled(a4,b[4]), v); + gmm::add(gmm::scaled(a2,b[2]), v); + gmm::add(gmm::scaled(a8,b[9]), // u = b9*A8+b7*A6+b5*A4+b3*A2+b1*I + gmm::scaled(a6,b[7]), u); + gmm::add(gmm::scaled(a4,b[5]), u); + gmm::add(gmm::scaled(a2,b[3]), u); + for (size_type ij=0; ij < N; ++ij) + { v(ij,ij) += b[0]; u(ij,ij) += b[1]; } + } else { // matrix_exp_pade13(a, U, V); + const scalar_type maxnorm = 5.371920351148152; + frexp(l1norm / maxnorm, &e); + if (e <= 0) e = 0; + else for (auto &&val : a.as_vector()) { val = ldexp(val,-e); } + // <==> gmm::scale(a, pow(scalar_type(2),-scalar_type(e))); + const static std::array<scalar_type,14> + b{64764752532480000,32382376266240000,7771770303897600, + 1187353796428800,129060195264000,10559470521600,670442572800, + 33522128640,1323241920,40840800,960960,16380,182,1}; + base_matrix a2(a), a4(a), a6(a); + matmul(a, a, a2); + matmul(a2, a2, a4); + matmul(a2, a4, a6); + gmm::add(gmm::scaled(a6,b[12]), + gmm::scaled(a4,b[10]), tmp); + gmm::add(gmm::scaled(a2,b[8]), tmp); + matmul(a6, tmp, v); // v = b12*A12+b10*A10+b8*A8 + gmm::add(gmm::scaled(a6,b[6]), v); // + b6*A6+b4*A4+b2*A2+b0*I + gmm::add(gmm::scaled(a4,b[4]), v); + gmm::add(gmm::scaled(a2,b[2]), v); + gmm::add(gmm::scaled(a6,b[13]), + gmm::scaled(a4,b[11]), tmp); + gmm::add(gmm::scaled(a2,b[9]), tmp); + matmul(a6, tmp, u); // u = b13*A12+b11*A10+b9*A8 + gmm::add(gmm::scaled(a6,b[7]), u); // + b7*A6+b5*A4+b3*A2+b1*I + gmm::add(gmm::scaled(a4,b[5]), u); + gmm::add(gmm::scaled(a2,b[3]), u); + for (size_type ij=0; ij < N; ++ij) + { v(ij,ij) += b[0]; u(ij,ij) += b[1]; } } - // unscale result - for (int i=0; i < e; ++i) { - gmm::mult(aexp, aexp, atmp); - gmm::copy(atmp, aexp); + std::swap(u, tmp); + matmul(a, tmp, u); // u <-- A*u + + gmm::add(v,gmm::scaled(u,-1),tmp); // tmp = denom = v-u + gmm::lu_inverse(tmp); // tmp = (v-u)^-1 + gmm::add(u,v); // v <-- numer = v+u; + matmul(tmp,v,aexp); + success = true; + + for (int i=0; i < e; ++i) { // unscale result + std::swap(aexp, tmp); + matmul(tmp, tmp, aexp); } return success; } - bool expm_deriv(const base_matrix &a_, base_tensor &daexp, - base_matrix *paexp=NULL, scalar_type tol=1e-15) { - - const size_type itmax = 40; - size_type N = gmm::mat_nrows(a_); - size_type N2 = N*N; - base_matrix a(a_); - // scale input matrix a - int e; - frexp(gmm::mat_norminf(a), &e); - e = std::max(0, std::min(1023, e)); - scalar_type scale = pow(scalar_type(2),-scalar_type(e)); - gmm::scale(a, scale); - - base_vector factnn(itmax); - base_matrix atmp(a), an(a), aexp(a); - base_tensor ann(bgeot::multi_index(N,N,itmax)); - gmm::add(gmm::identity_matrix(), aexp); - gmm::copy(gmm::identity_matrix(), atmp); - std::copy(atmp.begin(), atmp.end(), ann.begin()); - factnn[1] = 1; - std::copy(a.begin(), a.end(), ann.begin()+N2); - size_type n; - bool success(false); - for (n=2; n < itmax; ++n) { - factnn[n] = factnn[n-1]/scalar_type(n); - gmm::mult(an, a, atmp); - gmm::copy(atmp, an); - std::copy(an.begin(), an.end(), ann.begin()+n*N2); - gmm::scale(atmp, factnn[n]); - gmm::add(atmp, aexp); - if (gmm::mat_euclidean_norm(atmp) < tol) { - success = true; - break; - } - } - if (!success) - return false; + bool expm_deriv(const base_matrix &a_, base_tensor &daexp) { + size_type N = gmm::mat_nrows(a_); + base_matrix a(a_), tmp(a_); + gmm::clear(tmp.as_vector()); + base_matrix aexp(tmp), v(tmp), u(tmp), // Pade approximant is (v+u)/(v-u) + tmp_(tmp), dv_(tmp), du_(tmp); gmm::clear(daexp.as_vector()); - gmm::scale(factnn, scale); - for (--n; n >= 1; --n) { - scalar_type factn = factnn[n]; - for (size_type m=1; m <= n; ++m) - for (size_type l=0; l < N; ++l) - for (size_type k=0; k < N; ++k) - for (size_type j=0; j < N; ++j) - for (size_type i=0; i < N; ++i) - daexp(i,j,k,l) += factn*ann(i,k,m-1)*ann(l,j,n-m); - } + base_tensor dv(daexp), du(daexp); + const scalar_type l1norm = gmm::mat_norminf(a_); + int e = 0; // squarings + if (l1norm < 1.495585217958292e-002) { // matrix_exp_pade3(a, u, v) + const static std::array<scalar_type,4> b{120,60,12,1}; + base_matrix a2(a); + matmul(a, a, a2); + gmm::copy(gmm::scaled(a2,b[2]), v); // v = b2*A2 + b0*I + gmm::copy(gmm::scaled(a2,b[3]), u); // u = b3*A2 + b1*I + for (size_type ij=0; ij < N; ++ij) + { v(ij,ij) += b[0]; u(ij,ij) += b[1]; } + + for (size_type l=0; l < N; ++l) // tmp derivative of a2 + for (size_type k=0; k < N; ++k) { + gmm::clear(dv_); gmm::clear(du_); + for (size_type ij=0; ij < N; ++ij) { + const auto &al=a(l,ij), &ak=a(ij,k); + dv_(k,ij) += b[2]*al; dv_(ij,l) += b[2]*ak; + du_(k,ij) += b[3]*al; du_(ij,l) += b[3]*ak; + } + std::swap(du_,tmp); // derivative of u <-- A*u + matmul(a,tmp,du_); + for (size_type j=0; j < N; ++j) // i == k + du_(k,j) += u(l,j); - // unscale result - base_matrix atmp1(a), atmp2(a); - for (int i=0; i < e; ++i) { + std::copy(dv_.begin(),dv_.end(), &dv(0,0,k,l)); + std::copy(du_.begin(),du_.end(), &du(0,0,k,l)); + } + } else if (l1norm < 2.539398330063230e-001) { // matrix_exp_pade5(a, u, v) + const static std::array<scalar_type,6> b{30240,15120,3360,420,30,1}; + base_matrix a2(a), a4(a); + matmul(a, a, a2); + matmul(a2, a2, a4); + gmm::add(gmm::scaled(a4,b[4]), // v = b4*A4 + b2*A2 + b0*I + gmm::scaled(a2,b[2]), v); + gmm::add(gmm::scaled(a4,b[5]), // u = b5*A4 + b3*A2 + b1*I + gmm::scaled(a2,b[3]), u); + for (size_type ij=0; ij < N; ++ij) + { v(ij,ij) += b[0]; u(ij,ij) += b[1]; } + + base_matrix da2(aexp); // zero init for (size_type l=0; l < N; ++l) for (size_type k=0; k < N; ++k) { - std::copy(&daexp(0,0,k,l), &daexp(0,0,k,l)+N*N, atmp.begin()); - gmm::mult(atmp, aexp, atmp1); - gmm::mult(aexp, atmp, atmp2); - gmm::add(atmp1, atmp2, atmp); - std::copy(atmp.begin(), atmp.end(), &daexp(0,0,k,l)); + gmm::clear(da2); gmm::clear(dv_); gmm::clear(du_); + for (size_type ij=0; ij < N; ++ij) { + const auto &al=a(l,ij), &ak=a(ij,k); + da2(k,ij) += al; da2(ij,l) += ak; + dv_(k,ij) += b[2]*al; dv_(ij,l) += b[2]*ak; + du_(k,ij) += b[3]*al; du_(ij,l) += b[3]*ak; + } + matmul(a2,da2,tmp); + matmul(da2,a2,tmp_); + gmm::add(tmp_,tmp); // tmp derivative of a4 + gmm::add(gmm::scaled(tmp,b[4]), dv_); + gmm::add(gmm::scaled(tmp,b[5]), du_); + + std::swap(du_,tmp); // derivative of u <-- A*u + matmul(a,tmp,du_); + for (size_type j=0; j < N; ++j) // i == k + du_(k,j) += u(l,j); + + std::copy(dv_.begin(),dv_.end(), &dv(0,0,k,l)); + std::copy(du_.begin(),du_.end(), &du(0,0,k,l)); + } + } else if (l1norm < 9.504178996162932e-001) { // matrix_exp_pade7(a, u, v) + const static std::array<scalar_type,8> + b{17297280,8648640,1995840,277200,25200,1512,56,1}; + base_matrix a2(a), a4(a), a6(a); + matmul(a, a, a2); + matmul(a2, a2, a4); + matmul(a2, a4, a6); + gmm::add(gmm::scaled(a6,b[6]), // v = b6*A6 + b4*A4 + b2*A2 + b0*I + gmm::scaled(a4,b[4]), v); + gmm::add(gmm::scaled(a2,b[2]), v); + gmm::add(gmm::scaled(a6,b[7]), // u = b7*A6 + b5*A4 + b3*A2 + b1*I + gmm::scaled(a4,b[5]), u); + gmm::add(gmm::scaled(a2,b[3]), u); + for (size_type ij=0; ij < N; ++ij) + { v(ij,ij) += b[0]; u(ij,ij) += b[1]; } + + base_matrix da2(aexp); // zero init + for (size_type l=0; l < N; ++l) + for (size_type k=0; k < N; ++k) { + gmm::clear(da2); gmm::clear(dv_); gmm::clear(du_); + for (size_type ij=0; ij < N; ++ij) { + const auto &al=a(l,ij), &ak=a(ij,k); + da2(k,ij) += al; da2(ij,l) += ak; + dv_(k,ij) += b[2]*al; dv_(ij,l) += b[2]*ak; + du_(k,ij) += b[3]*al; du_(ij,l) += b[3]*ak; + } + matmul(a2,da2,tmp); + matmul(da2,a2,tmp_); + gmm::add(tmp_,tmp); // tmp derivative of a4 + gmm::add(gmm::scaled(tmp,b[4]), dv_); + gmm::add(gmm::scaled(tmp,b[5]), du_); + + matmul(a2,tmp,tmp_); + matmul(da2,a4,tmp); + gmm::add(tmp_,tmp); // tmp derivative of a6 + gmm::add(gmm::scaled(tmp,b[6]), dv_); + gmm::add(gmm::scaled(tmp,b[7]), du_); + + std::swap(du_,tmp); // derivative of u <-- A*u + matmul(a,tmp,du_); + for (size_type j=0; j < N; ++j) // i == k + du_(k,j) += u(l,j); + + std::copy(dv_.begin(),dv_.end(), &dv(0,0,k,l)); + std::copy(du_.begin(),du_.end(), &du(0,0,k,l)); + } + } else if (l1norm < 2.097847961257068e+000) { // matrix_exp_pade9(a, u, v) + const static std::array<scalar_type,10> + b{17643225600,8821612800,2075673600,302702400,30270240,2162160, + 110880,3960,90,1}; + base_matrix a2(a), a4(a), a6(a), a8(a); + matmul(a, a, a2); + matmul(a2, a2, a4); + matmul(a2, a4, a6); + matmul(a4, a4, a8); + gmm::add(gmm::scaled(a8,b[8]), // v = b8*A8+b6*A6+b4*A4+b2*A2+b0*I + gmm::scaled(a6,b[6]), v); + gmm::add(gmm::scaled(a4,b[4]), v); + gmm::add(gmm::scaled(a2,b[2]), v); + gmm::add(gmm::scaled(a8,b[9]), // u = b9*A8+b7*A6+b5*A4+b3*A2+b1*I + gmm::scaled(a6,b[7]), u); + gmm::add(gmm::scaled(a4,b[5]), u); + gmm::add(gmm::scaled(a2,b[3]), u); + for (size_type ij=0; ij < N; ++ij) + { v(ij,ij) += b[0]; u(ij,ij) += b[1]; } + + base_matrix da2(aexp), da4(aexp); // zero init + for (size_type l=0; l < N; ++l) + for (size_type k=0; k < N; ++k) { + gmm::clear(da2); gmm::clear(dv_); gmm::clear(du_); + for (size_type ij=0; ij < N; ++ij) { + const auto &al=a(l,ij), &ak=a(ij,k); + da2(k,ij) += al; da2(ij,l) += ak; + dv_(k,ij) += b[2]*al; dv_(ij,l) += b[2]*ak; + du_(k,ij) += b[3]*al; du_(ij,l) += b[3]*ak; + } + matmul(a2,da2,tmp); + matmul(da2,a2,da4); + gmm::add(tmp,da4); + gmm::add(gmm::scaled(da4,b[4]), dv_); + gmm::add(gmm::scaled(da4,b[5]), du_); + + matmul(a2,da4,tmp_); + matmul(da2,a4,tmp); + gmm::add(tmp_,tmp); // tmp derivative of a6 + gmm::add(gmm::scaled(tmp,b[6]), dv_); + gmm::add(gmm::scaled(tmp,b[7]), du_); + + matmul(a4,da4,tmp); + matmul(da4,a4,tmp_); + gmm::add(tmp_,tmp); // tmp derivative of a8 + gmm::add(gmm::scaled(tmp,b[8]), dv_); + gmm::add(gmm::scaled(tmp,b[9]), du_); + + std::swap(du_,tmp); // derivative of u <-- A*u + matmul(a,tmp,du_); + for (size_type j=0; j < N; ++j) // i == k + du_(k,j) += u(l,j); + + std::copy(dv_.begin(),dv_.end(), &dv(0,0,k,l)); + std::copy(du_.begin(),du_.end(), &du(0,0,k,l)); + } + } else { // matrix_exp_pade13(a, U, V); + const scalar_type maxnorm = 5.371920351148152; + frexp(l1norm / maxnorm, &e); + if (e <= 0) e = 0; + else for (auto &&val : a.as_vector()) { val = ldexp(val,-e); } + // <==> gmm::scale(a, pow(scalar_type(2),-scalar_type(e))); + const static std::array<scalar_type,14> + b{64764752532480000,32382376266240000,7771770303897600, + 1187353796428800,129060195264000,10559470521600,670442572800, + 33522128640,1323241920,40840800,960960,16380,182,1}; + base_matrix a2(a), a4(a), a6(a), v_(a), u_(a); + matmul(a, a, a2); + matmul(a2, a2, a4); + matmul(a2, a4, a6); + gmm::add(gmm::scaled(a6,b[12]), + gmm::scaled(a4,b[10]), v_); + gmm::add(gmm::scaled(a2,b[8]), v_); + matmul(a6, v_, v); // v = b12*A12+b10*A10+b8*A8 + gmm::add(gmm::scaled(a6,b[6]), v); // + b6*A6+b4*A4+b2*A2+b0*I + gmm::add(gmm::scaled(a4,b[4]), v); + gmm::add(gmm::scaled(a2,b[2]), v); + + gmm::add(gmm::scaled(a6,b[13]), + gmm::scaled(a4,b[11]), u_); + gmm::add(gmm::scaled(a2,b[9]), u_); + matmul(a6, u_, u); // u = b13*A12+b11*A10+b9*A8 + gmm::add(gmm::scaled(a6,b[7]), u); // + b7*A6+b5*A4+b3*A2+b1*I + gmm::add(gmm::scaled(a4,b[5]), u); + gmm::add(gmm::scaled(a2,b[3]), u); + for (size_type ij=0; ij < N; ++ij) + { v(ij,ij) += b[0]; u(ij,ij) += b[1]; } + + base_matrix da2(aexp), da4(aexp), da6(aexp), + dv__(aexp), du__(aexp); + for (size_type l=0; l < N; ++l) + for (size_type k=0; k < N; ++k) { + gmm::clear(da2); gmm::clear(dv_); gmm::clear(du_); + gmm::clear(dv__); gmm::clear(du__); + for (size_type ij=0; ij < N; ++ij) { + const auto &al=a(l,ij), &ak=a(ij,k); + da2(k,ij) += al; da2(ij,l) += ak; + dv_(k,ij) += b[2]*al; dv_(ij,l) += b[2]*ak; + du_(k,ij) += b[3]*al; du_(ij,l) += b[3]*ak; + dv__(k,ij) += b[8]*al; dv__(ij,l) += b[8]*ak; + du__(k,ij) += b[9]*al; du__(ij,l) += b[9]*ak; + } + matmul(a2,da2,da4); + matmul(da2,a2,tmp); gmm::add(tmp,da4); + gmm::add(gmm::scaled(da4,b[4]), dv_); + gmm::add(gmm::scaled(da4,b[5]), du_); + gmm::add(gmm::scaled(da4,b[10]), dv__); + gmm::add(gmm::scaled(da4,b[11]), du__); + + matmul(a2,da4,da6); + matmul(da2,a4,tmp); gmm::add(tmp,da6); + gmm::add(gmm::scaled(da6,b[6]), dv_); + gmm::add(gmm::scaled(da6,b[7]), du_); + gmm::add(gmm::scaled(da6,b[12]), dv__); + gmm::add(gmm::scaled(da6,b[13]), du__); + + matmul(a6,dv__,tmp); gmm::add(tmp, dv_); + matmul(da6,v_,tmp); gmm::add(tmp, dv_); + + matmul(a6,du__,tmp); gmm::add(tmp, du_); + matmul(da6,u_,tmp); gmm::add(tmp, du_); + + std::swap(du_,tmp); // derivative of u <-- A*u + matmul(a,tmp,du_); + for (size_type j=0; j < N; ++j) // i == k + du_(k,j) += u(l,j); + + std::copy(dv_.begin(),dv_.end(), &dv(0,0,k,l)); + std::copy(du_.begin(),du_.end(), &du(0,0,k,l)); } - gmm::mult(aexp, aexp, atmp); - gmm::copy(atmp, aexp); } + std::swap(u, tmp); + matmul(a, tmp, u); // u <-- A*u + + base_matrix inv_denom(v); + gmm::add(gmm::scaled(u,-1),inv_denom); // denom = v-u + gmm::lu_inverse(inv_denom); + + gmm::add(u,v,tmp); // tmp = numer = v+u + matmul(inv_denom,tmp,aexp); + + for (size_type l=0; l < N; ++l) + for (size_type k=0; k < N; ++k) { // daexp_kl= D\(dN_kl-dD_kl*aexp) + std::copy(&dv(0,0,k,l),&dv(0,0,k,l)+N*N, tmp_.begin()); + std::copy(&du(0,0,k,l),&du(0,0,k,l)+N*N, tmp.begin()); + gmm::add(gmm::scaled(tmp_/*dv*/,-1),tmp/*du*/); // tmp = -(dv-du) + matmul(tmp,aexp,tmp_); + std::copy(&du(0,0,k,l),&du(0,0,k,l)+N*N, tmp.begin()); + gmm::add(tmp/*du*/, tmp_); + std::copy(&dv(0,0,k,l),&dv(0,0,k,l)+N*N, tmp.begin()); + gmm::add(tmp/*dv*/, tmp_); // tmp_ = (dv+du)_kl-(dv-du)_kl*aexp + matmul(inv_denom, tmp_, tmp); + std::copy(tmp.begin(), tmp.end(), &daexp(0,0,k,l)); + } + if (e) + for (auto &&val : daexp.as_vector()) { val = ldexp(val,-e); } - if (paexp) gmm::copy(aexp, *paexp); + for (int i=0; i < e; ++i) { // unscale result + for (size_type l=0; l < N; ++l) + for (size_type k=0; k < N; ++k) { + std::copy(&daexp(0,0,k,l), &daexp(0,0,k,l)+N*N, tmp.begin()); + matmul(tmp, aexp, u); // u,v used a temporaries + matmul(aexp, tmp, v); // + gmm::add(u, v, tmp); + std::copy(tmp.begin(), tmp.end(), &daexp(0,0,k,l)); + } + std::swap(aexp,tmp); + matmul(tmp, tmp, aexp); + } return true; }