branch: master commit 9216a7acc4c992e99beaf8eeccfce871348a2e55 Author: Konstantinos Poulios <logar...@gmail.com> AuthorDate: Fri Apr 12 16:16:53 2024 +0200
Avoid nested preprocessor macros in gmm blas interface --- src/gmm/gmm_blas_interface.h | 44 +++++++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/src/gmm/gmm_blas_interface.h b/src/gmm/gmm_blas_interface.h index 4c8deddf..1b83f1cd 100644 --- a/src/gmm/gmm_blas_interface.h +++ b/src/gmm/gmm_blas_interface.h @@ -385,41 +385,39 @@ namespace gmm { } -# define axpy_interface(param1, trans1, blas_name, base_type) \ - inline void add(param1(base_type), std::vector<base_type> &y) { \ +# define axpy_interface(blas_name, base_type) \ + inline void add(const std::vector<base_type> &x, \ + std::vector<base_type> &y) { \ GMMLAPACK_TRACE("axpy_interface"); \ - BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); trans1(base_type); \ + BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); base_type a(1); \ if(n == 0) return; \ else if(n < 25) add_for_short_vectors(x, y, n); \ else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \ } -# define axpy2_interface(param1, trans1, blas_name, base_type) \ - inline void add(param1(base_type), std::vector<base_type> &y) { \ + axpy_interface(saxpy_, BLAS_S) + axpy_interface(daxpy_, BLAS_D) + axpy_interface(caxpy_, BLAS_C) + axpy_interface(zaxpy_, BLAS_Z) + + +# define axpy2_interface(blas_name, base_type) \ + inline void add \ + (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \ + std::vector<base_type> &y) { \ GMMLAPACK_TRACE("axpy_interface"); \ - BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); trans1(base_type); \ + BLAS_INT inc(1), n(BLAS_INT(vect_size(y))); \ + const std::vector<base_type>& x = *(linalg_origin(x_)); \ + base_type a(x_.r); \ if(n == 0) return; \ else if(n < 25) add_for_short_vectors(x, y, a, n); \ else blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \ } -# define axpy_p1(base_type) const std::vector<base_type> &x -# define axpy_trans1(base_type) base_type a(1) -# define axpy_p1_s(base_type) \ - const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_ -# define axpy_trans1_s(base_type) \ - const std::vector<base_type> &x = *(linalg_origin(x_)); \ - base_type a(x_.r) - - axpy_interface(axpy_p1, axpy_trans1, saxpy_, BLAS_S) - axpy_interface(axpy_p1, axpy_trans1, daxpy_, BLAS_D) - axpy_interface(axpy_p1, axpy_trans1, caxpy_, BLAS_C) - axpy_interface(axpy_p1, axpy_trans1, zaxpy_, BLAS_Z) - - axpy2_interface(axpy_p1_s, axpy_trans1_s, saxpy_, BLAS_S) - axpy2_interface(axpy_p1_s, axpy_trans1_s, daxpy_, BLAS_D) - axpy2_interface(axpy_p1_s, axpy_trans1_s, caxpy_, BLAS_C) - axpy2_interface(axpy_p1_s, axpy_trans1_s, zaxpy_, BLAS_Z) + axpy2_interface(saxpy_, BLAS_S) + axpy2_interface(daxpy_, BLAS_D) + axpy2_interface(caxpy_, BLAS_C) + axpy2_interface(zaxpy_, BLAS_Z) /* ********************************************************************* */