masahi commented on code in PR #14291:
URL: https://github.com/apache/tvm/pull/14291#discussion_r1135119116


##########
src/runtime/contrib/cublas/cublas.cc:
##########
@@ -133,6 +134,120 @@ bool CheckMixPrecisionType(DLDataType in_dtype, 
DLDataType out_dtype, bool int_s
 int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 
 #if CUDART_VERSION >= 10010
+
+void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, 
const DLTensor* bias,
+                  const DLTensor* C, bool transa, bool transb, 
cublasLtEpilogue_t epilogue) {
+  ICHECK(TypeEqual(A->dtype, B->dtype));
+  // Reversed strides indicates an in-place transpose operation.
+  transa = IsInPlaceTransposed(A) ? !transa : transa;
+  transb = IsInPlaceTransposed(B) ? !transb : transb;
+
+  auto compute_type = CUBLAS_COMPUTE_32F;
+  auto scale_type = CUDA_R_32F;
+  cudaDataType_t ab_type = CUDA_R_32F;
+  cudaDataType_t c_type = CUDA_R_32F;
+  float one_fp32 = 1.0;
+  float zero_fp32 = 0.0;
+  auto one_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 
10>(1.0);
+  auto zero_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 
10>(0.0);
+  void* alpha = &one_fp32;
+  void* beta = &zero_fp32;
+
+  if (A->dtype.bits == 16 && A->dtype.code == kDLFloat) {
+    ab_type = CUDA_R_16F;
+  }
+
+  if (C->dtype.bits == 16 && C->dtype.code == kDLFloat) {
+    c_type = CUDA_R_16F;
+    compute_type = CUBLAS_COMPUTE_16F;
+    scale_type = CUDA_R_16F;
+    alpha = &one_fp16;
+    beta = &zero_fp16;
+  }
+
+  cublasLtMatmulDesc_t op_desc;
+  cublasOperation_t op_transa = CUBLASBooleanToTranspose(transa);
+  cublasOperation_t op_transb = CUBLASBooleanToTranspose(transb);
+
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&op_desc, compute_type, 
scale_type));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_TRANSA,
+                                                    &op_transb, 
sizeof(op_transa)));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_TRANSB,
+                                                    &op_transa, 
sizeof(op_transb)));
+
+  if (bias != nullptr) {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+                                                      &bias->data, 
sizeof(float*)));
+  }
+
+  if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_EPILOGUE,
+                                                      &epilogue, 
sizeof(epilogue)));
+  }
+
+  int batch_offset_A = A->ndim - 2;
+  int batch_offset_B = B->ndim - 2;
+
+  int M = ColumnCount(B, transb, batch_offset_B);
+  int N = RowCount(A, transa, batch_offset_A);
+  int K = ColumnCount(A, transa, batch_offset_A);
+
+  int lda = transb ? K : M;
+  int ldb = transa ? N : K;
+  int ldc = M;
+
+  cublasLtMatrixLayout_t A_desc, B_desc, C_desc;
+  CHECK_CUBLAS_ERROR(
+      cublasLtMatrixLayoutCreate(&A_desc, ab_type, !transb ? M : K, !transb ? 
K : M, lda));
+  CHECK_CUBLAS_ERROR(
+      cublasLtMatrixLayoutCreate(&B_desc, ab_type, !transa ? K : N, !transa ? 
N : K, ldb));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&C_desc, c_type, M, N, ldc));
+
+  if (A->ndim != 2 || B->ndim != 2) {
+    auto get_batch_count = [](int64_t* shape, int batch_offset) {
+      int64_t count = 1;
+      for (int i = 0; i < batch_offset; ++i) {
+        count *= shape[i];
+      }
+      return count;
+    };
+    auto set_batch = [](cublasLtMatrixLayout_t mat_desc, int batch_count, 
int64_t batch_stride) {
+      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+          mat_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, 
sizeof(batch_count)));
+      CHECK_CUBLAS_ERROR(
+          cublasLtMatrixLayoutSetAttribute(mat_desc, 
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+                                           &batch_stride, 
sizeof(batch_stride)));
+    };
+
+    int batch_count_A = get_batch_count(A->shape, batch_offset_A);
+    int batch_count_B = get_batch_count(B->shape, batch_offset_B);
+    int batch_count_C = get_batch_count(C->shape, C->ndim - 2);
+    int64_t batch_stride_A = M * K;
+    int64_t batch_stride_B = K * N;
+    int64_t batch_stride_C = M * N;
+
+    // cuBLASLt does not seem to support batched GEMM with one of matrices 
having
+    // one batch (with batch_stride 0).
+    ICHECK_EQ(batch_count_A, batch_count_B);

Review Comment:
   I was not able to get ND x 2D batch GEMM working with cuBLASLt, even though 
the old API supports it perfectly 
https://github.com/apache/tvm/blob/5cacecc0c026fcd4c4c0064a8ef740374b86a699/src/runtime/contrib/cublas/cublas.cc#L305-L313
   
   I wonder if I'm missing something or if this is a known limitation cc @mnicely 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to