masahi commented on code in PR #14291:
URL: https://github.com/apache/tvm/pull/14291#discussion_r1135119116
##########
src/runtime/contrib/cublas/cublas.cc:
##########
@@ -133,6 +134,120 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
 int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 
 #if CUDART_VERSION >= 10010
+
+void CallCublasLt(cublasLtHandle_t hdl, const DLTensor* A, const DLTensor* B, const DLTensor* bias,
+                  const DLTensor* C, bool transa, bool transb, cublasLtEpilogue_t epilogue) {
+  ICHECK(TypeEqual(A->dtype, B->dtype));
+  // Reversed strides indicate an in-place transpose operation.
+  transa = IsInPlaceTransposed(A) ? !transa : transa;
+  transb = IsInPlaceTransposed(B) ? !transb : transb;
+
+  auto compute_type = CUBLAS_COMPUTE_32F;
+  auto scale_type = CUDA_R_32F;
+  cudaDataType_t ab_type = CUDA_R_32F;
+  cudaDataType_t c_type = CUDA_R_32F;
+  float one_fp32 = 1.0;
+  float zero_fp32 = 0.0;
+  auto one_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(1.0);
+  auto zero_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(0.0);
+  void* alpha = &one_fp32;
+  void* beta = &zero_fp32;
+
+  if (A->dtype.bits == 16 && A->dtype.code == kDLFloat) {
+    ab_type = CUDA_R_16F;
+  }
+
+  if (C->dtype.bits == 16 && C->dtype.code == kDLFloat) {
+    c_type = CUDA_R_16F;
+    compute_type = CUBLAS_COMPUTE_16F;
+    scale_type = CUDA_R_16F;
+    alpha = &one_fp16;
+    beta = &zero_fp16;
+  }
+
+  cublasLtMatmulDesc_t op_desc;
+  cublasOperation_t op_transa = CUBLASBooleanToTranspose(transa);
+  cublasOperation_t op_transb = CUBLASBooleanToTranspose(transb);
+
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSA,
+                                                    &op_transb, sizeof(op_transb)));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSB,
+                                                    &op_transa, sizeof(op_transa)));
+
+  if (bias != nullptr) {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+                                                      &bias->data, sizeof(float*)));
+  }
+
+  if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
+                                                      &epilogue, sizeof(epilogue)));
+  }
+
+  int batch_offset_A = A->ndim - 2;
+  int batch_offset_B = B->ndim - 2;
+
+  int M = ColumnCount(B, transb, batch_offset_B);
+  int N = RowCount(A, transa, batch_offset_A);
+  int K = ColumnCount(A, transa, batch_offset_A);
+
+  int lda = transb ? K : M;
+  int ldb = transa ? N : K;
+  int ldc = M;
+
+  cublasLtMatrixLayout_t A_desc, B_desc, C_desc;
+  CHECK_CUBLAS_ERROR(
+      cublasLtMatrixLayoutCreate(&A_desc, ab_type, !transb ? M : K, !transb ? K : M, lda));
+  CHECK_CUBLAS_ERROR(
+      cublasLtMatrixLayoutCreate(&B_desc, ab_type, !transa ? K : N, !transa ? N : K, ldb));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&C_desc, c_type, M, N, ldc));
+
+  if (A->ndim != 2 || B->ndim != 2) {
+    auto get_batch_count = [](int64_t* shape, int batch_offset) {
+      int64_t count = 1;
+      for (int i = 0; i < batch_offset; ++i) {
+        count *= shape[i];
+      }
+      return count;
+    };
+    auto set_batch = [](cublasLtMatrixLayout_t mat_desc, int batch_count, int64_t batch_stride) {
+      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+          mat_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
+      CHECK_CUBLAS_ERROR(
+          cublasLtMatrixLayoutSetAttribute(mat_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+                                           &batch_stride, sizeof(batch_stride)));
+    };
+
+    int batch_count_A = get_batch_count(A->shape, batch_offset_A);
+    int batch_count_B = get_batch_count(B->shape, batch_offset_B);
+    int batch_count_C = get_batch_count(C->shape, C->ndim - 2);
+    int64_t batch_stride_A = M * K;
+    int64_t batch_stride_B = K * N;
+    int64_t batch_stride_C = M * N;
+
+    // cuBLASLt does not seem to support batched GEMM with one of the matrices having
+    // a single batch (with batch_stride 0).
+    ICHECK_EQ(batch_count_A, batch_count_B);

Review Comment:
   I was not able to get ND x 2D batch GEMM working with cuBLASLt, even though the old API supports it perfectly:
   https://github.com/apache/tvm/blob/5cacecc0c026fcd4c4c0064a8ef740374b86a699/src/runtime/contrib/cublas/cublas.cc#L305-L313

   I wonder if I'm missing something or if this is a known limitation.

   cc @mnicely
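   For reference, a minimal sketch of how the pre-Lt strided-batched API expresses this ND x 2D case, by giving the shared 2D operand a batch stride of 0 (hypothetical function and buffer names, CUDA 11+ assumed; not code from this PR):

   ```cpp
   #include <cublas_v2.h>

   // Batched C[i] = A[i] * B for i in [0, batch): B is a single 2D matrix,
   // broadcast across the batch by setting its batch stride to 0.
   // Column-major convention; d_A, d_B, d_C are device buffers filled by the caller.
   cublasStatus_t BatchedGemmBroadcastB(cublasHandle_t handle, const float* d_A, const float* d_B,
                                        float* d_C, int m, int n, int k, int batch) {
     const float alpha = 1.0f, beta = 0.0f;
     return cublasGemmStridedBatchedEx(
         handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha,
         d_A, CUDA_R_32F, m, static_cast<long long>(m) * k,  // A advances per batch
         d_B, CUDA_R_32F, k, 0,                              // stride 0 broadcasts B
         &beta,
         d_C, CUDA_R_32F, m, static_cast<long long>(m) * n,  // C advances per batch
         batch, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
   }
   ```

   The cuBLASLt analogue would presumably be a `CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET` of 0 on the broadcast matrix's layout, which is exactly the case the `ICHECK_EQ(batch_count_A, batch_count_B)` above rules out.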