*ahem*, attached

On Wed, 30 Mar 2022, Elijah Stone wrote:

Recent Apple ARM CPUs include a hardware coprocessor for matrix multiplication. It is nominally not directly accessible to user code (though it has been partly reverse engineered); instead, it must be accessed through Apple's BLAS implementation. The attached trivial patch makes J use this rather than its own routines for large matrix multiplication on darwin/arm. The performance delta is quite good. Before:

   a=. ?1e3 2e3$0
   b=. ?2e3 3e3$0
   100 timex 'a +/ . * b'
0.103497

after:

   100 timex 'a +/ . * b'
0.0274741
   0.103497%0.0274741
3.76708

Nearly 4x faster!

There seems to be a warmup period (big buffers go brrr...), so the gemm threshold should perhaps be tuned. I did not take detailed measurements.

(Fine print: benchmarks taken on a 14in macbook w/m1pro.)

Also of note: on desktop (Zen 2), NumPy is 3x faster than J. I tried swapping out J's matrix-multiply microkernel for the newest one from BLIS and got only a modest boost, so the problem is not there. I think NumPy is using OpenBLAS. (On ARM, J and NumPy are reasonably close, and the hardware accelerator smokes both.)

 -E
----------------------------------------------------------------------
For information about J forums see http://www.jsoftware.com/forums.htm
diff --git a/jsrc/gemm.c b/jsrc/gemm.c
index 5333c7b3..c573a8f9 100644
--- a/jsrc/gemm.c
+++ b/jsrc/gemm.c
@@ -29,6 +29,24 @@
 #define MR  BLIS_DEFAULT_MR_D
 #define NR  BLIS_DEFAULT_NR_D
 
+#ifdef SYSTEM_BLAS  // build option: route dgemm_nn/zgemm_nn through the platform's CBLAS
+enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102 };
+enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113, 
AtlasConj=114};
+// Hand-written prototypes for the two CBLAS routines called by dgemm_nn/zgemm_nn below. note: C may not alias B or A
+void cblas_dgemm(const enum CBLAS_ORDER Order,
+                 const enum CBLAS_TRANSPOSE TransA,
+                 const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
+                 const int K, const double alpha, const double *A,
+                 const int lda, const double *B, const int ldb,
+                 const double beta, double *C, const int ldc);
+void cblas_zgemm(const enum CBLAS_ORDER Order,
+                 const enum CBLAS_TRANSPOSE TransA,
+                 const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
+                 const int K, const dcomplex *alpha, const dcomplex *A, const 
int lda,
+                 const dcomplex *B, const int ldb, const dcomplex *beta, 
dcomplex *C,
+                 const int ldc);
+#endif // SYSTEM_BLAS
+
 
 //
 //  Packing complete panels from A (i.e. without padding)
@@ -265,6 +283,27 @@ dgemm_macro_kernel(dim_t   mc,
 //
 //  Compute C <- beta*C + alpha*A*B
 //
+#ifdef SYSTEM_BLAS
+void
+dgemm_nn         (I              m,
+                  I              n,
+                  I              k,
+                  double         alpha,
+                  double         *A,
+                  I              rs_a, // forwarded to cblas_dgemm as lda
+                  I              cs_a, // not used in this variant
+                  double         *B,
+                  I              rs_b, // forwarded to cblas_dgemm as ldb
+                  I              cs_b, // not used in this variant
+                  double         beta,
+                  double         *C,
+                  I              rs_c, // forwarded to cblas_dgemm as ldc
+                  I              cs_c) // not used in this variant
+{ // SYSTEM_BLAS build: delegate C <- beta*C + alpha*A*B to cblas_dgemm (row-major, no transposes)
+ // note (from the patch author): we never get passed significant values for cs_x, so only the rs_x strides are forwarded
+ cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, 
rs_a, B, rs_b, beta, C, rs_c); // NOTE(review): m, n, k and rs_x are of type I but the prototype takes int -- confirm they cannot exceed int range
+}
+#else
 void
 dgemm_nn         (I              m,
                   I              n,
@@ -335,6 +374,7 @@ dgemm_nn         (I              m,
         }
     }
 }
+#endif //SYSTEM_BLAS
 
 
 // -----------------------------------------------------------------
@@ -778,6 +818,27 @@ zgemm_macro_kernel(dim_t   mc,
 //
 //  Compute C <- beta*C + alpha*A*B
 //
+#ifdef SYSTEM_BLAS
+void
+zgemm_nn         (I              m,
+                  I              n,
+                  I              k,
+                  dcomplex       alpha,
+                  dcomplex       *A,
+                  I              rs_a, // forwarded to cblas_zgemm as lda
+                  I              cs_a, // not used in this variant
+                  dcomplex       *B,
+                  I              rs_b, // forwarded to cblas_zgemm as ldb
+                  I              cs_b, // not used in this variant
+                  dcomplex       beta,
+                  dcomplex       *C,
+                  I              rs_c, // forwarded to cblas_zgemm as ldc
+                  I              cs_c) // not used in this variant
+{ // SYSTEM_BLAS build: delegate C <- beta*C + alpha*A*B to cblas_zgemm (row-major, no transposes)
+ // note (from the patch author): we never get passed significant values for cs_x, so only the rs_x strides are forwarded
+ cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, &alpha, A, 
rs_a, B, rs_b, &beta, C, rs_c); // alpha/beta are passed by address because the cblas_zgemm prototype takes them as pointers. NOTE(review): m, n, k and rs_x are of type I but the prototype takes int -- confirm they cannot exceed int range
+}
+#else
 void
 zgemm_nn         (I              m,
                   I              n,
@@ -848,4 +909,5 @@ zgemm_nn         (I              m,
         }
     }
 }
+#endif //SYSTEM_BLAS
 
----------------------------------------------------------------------
For information about J forums see http://www.jsoftware.com/forums.htm

Reply via email to