learning-chip opened a new issue #8406:

URL: https://github.com/apache/tvm/issues/8406
## Problem description

`topi.sparse.csrmv` has `"float32"` hard-coded inside `ir_builder` and `te.extern`, so it only accepts float32 inputs and fails for float64 and other data types:

https://github.com/apache/tvm/blob/d3fc562a6f3b8cd4d0a5f86e1e3ebc503ebeba2b/python/tvm/topi/sparse/csrmv.py#L66-L68

https://github.com/apache/tvm/blob/d3fc562a6f3b8cd4d0a5f86e1e3ebc503ebeba2b/python/tvm/topi/sparse/csrmv.py#L80-L87

`topi.sparse.csrmm` has the same problem.

## Steps to reproduce

Build TVM 0.8dev from the latest master branch, and then run:

```python
# extracted from tests/python/topi/python/test_topi_sparse.py
from tvm import te
from tvm import topi
import tvm.contrib.sparse as tvmsp

dtype = "float64"  # "float32" works fine
nr, nc = (3, 5)
nnz = 6
A = tvmsp.placeholder(shape=(nr, nc), nonzeros=nnz, dtype=dtype, name="A")
B = te.placeholder((nc, 1), name="B")
out = topi.sparse.csrmv(A, B)  # TVMError: Cannot match type float64 vs float32
```

Full error message:

<details>

```
---------------------------------------------------------------------------
TVMError                                  Traceback (most recent call last)
<ipython-input-1-6daa8fd2fb08> in <module>
      9 A = tvmsp.placeholder(shape=(nr, nc), nonzeros=nnz, dtype=dtype, name="A")
     10 B = te.placeholder((nc, 1), name="B")
---> 11 out = topi.sparse.csrmv(A, B) # TVMError: Cannot match type float64 vs float32

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv(a, x, y)
    111         2-D dense matrix with shape [m, 1]
    112     """
--> 113     return csrmv_default(a.data, a.indices, a.indptr, x, y)

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv_default(data, indices, indptr, weight, bias)
     78 
     79     oshape = (batch, 1)
---> 80     matmul = te.extern(
     81         oshape,
     82         [data, indices, indptr, weight],

/tvm_install/tvm/python/tvm/te/operation.py in extern(shape, inputs, fcompute, name, dtype, in_buffers, out_buffers, tag, attrs)
    315     for shp, dt in zip(shape, dtype):
    316         output_placeholders.append(tvm.tir.decl_buffer(shp, dt, name))
--> 317     body = fcompute(input_placeholders, output_placeholders)
    318     if isinstance(body, tvm.tir.PrimExpr):
    319         body = tvm.tir.Evaluate(body)

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in <lambda>(ins, outs)
     81         oshape,
     82         [data, indices, indptr, weight],
---> 83         lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
     84         tag="csrmv",
     85         dtype="float32",

/tvm_install/tvm/python/tvm/topi/sparse/csrmv.py in csrmv_default_ir(data, indices, indptr, weight, out)
     73             with irb.for_range(0, row_elems, name="elemidx") as elemidx:
     74                 elem = row_start + elemidx
---> 75                 dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]]
     76             out_ptr[row] += dot[0]
     77         return irb.get()

/tvm_install/tvm/python/tvm/tir/expr.py in __mul__(self, other)
     75 
     76     def __mul__(self, other):
---> 77         return _generic.multiply(self, other)
     78 
     79     def __rmul__(self, other):

/tvm_install/tvm/python/tvm/topi/generic_op_impl.py in _tensor_bop_impl(lhs, rhs)
     81         """
     82         if not isinstance(lhs, te.tensor.Tensor) and not isinstance(rhs, te.tensor.Tensor):
---> 83             return orig_bop(lhs, rhs)
     84         return broadcast_bop(lhs, rhs)
     85 

/tvm_install/tvm/python/tvm/tir/generic.py in multiply(lhs, rhs, span)
     84         The result Expr of multiply operaton.
     85     """
---> 86     return _ffi_api._OpMul(lhs, rhs, span)
     87 
     88 

/tvm_install/tvm/python/tvm/_ffi/_ctypes/packed_func.py in __call__(self, *args)
    235             != 0
    236         ):
--> 237             raise get_last_ffi_error()
    238         _ = temp_args
    239         _ = args

TVMError: Traceback (most recent call last):
  3: TVMFuncCall
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::PrimExpr (tvm::PrimExpr, tvm::PrimExpr, tvm::Span)>::AssignTypedLambda<tvm::{lambda(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)#5}>(tvm::{lambda(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)#5}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::mul(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  0: tvm::BinaryOpMatchTypes(tvm::PrimExpr&, tvm::PrimExpr&, tvm::Span)
  File "/tvm_install/tvm/src/tir/op/op.cc", line 144
TVMError: Cannot match type float64 vs float32
```

</details>

## Desired fix

- `topi.sparse.{csrmv, csrmm}` should be independent of data type.
- Add unit tests to `tests/python/topi/python/test_topi_sparse.py` to make sure that multiple data types work.

## References

- https://github.com/apache/tvm/pull/1289
- https://github.com/apache/tvm/issues/1291
- https://github.com/apache/tvm/issues/4332
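To illustrate the dtype independence requested under *Desired fix*, here is a rough NumPy sketch of the same row loop that `csrmv_default_ir` builds: one accumulator per row, summing `data[elem] * x[indices[elem]]`. It is not TVM code. The helper names `csrmv_reference` and `dense_to_csr` are hypothetical, the sketch assumes only NumPy, and rejecting mismatched input dtypes is an assumption on my part. The key point is that the accumulator and the output take their dtype from the inputs instead of a literal `"float32"`. A multi-dtype unit test could compare `topi.sparse.csrmv` against a reference of this kind.

```python
import numpy as np


def csrmv_reference(data, indices, indptr, x):
    """Hypothetical NumPy reference for y = A @ x with A stored in CSR form.

    Mirrors the loop structure of csrmv_default_ir, but takes the accumulator
    and output dtype from the inputs instead of hard-coding float32.
    """
    # Assumption: CSR values and the dense vector must share one dtype.
    if data.dtype != x.dtype:
        raise TypeError(f"data and x must share a dtype, got {data.dtype} and {x.dtype}")
    num_rows = indptr.shape[0] - 1
    out = np.zeros((num_rows, 1), dtype=data.dtype)  # output in the input dtype
    for row in range(num_rows):
        dot = data.dtype.type(0)  # scalar accumulator in the input dtype
        for elem in range(indptr[row], indptr[row + 1]):
            dot += data[elem] * x[indices[elem], 0]
        out[row, 0] = dot
    return out


def dense_to_csr(dense):
    """Hypothetical helper: build (data, indices, indptr) from a 2-D array."""
    data, indices, indptr = [], [], [0]
    for row in dense:
        nz = np.nonzero(row)[0]  # column indices of nonzero entries
        data.extend(row[nz])
        indices.extend(nz)
        indptr.append(len(data))
    return (
        np.asarray(data, dtype=dense.dtype),
        np.asarray(indices, dtype="int32"),
        np.asarray(indptr, dtype="int32"),
    )


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    for dtype in ("float32", "float64"):
        # 3x5 matrix with roughly half of its entries zeroed out
        a = np.maximum(rng.uniform(size=(3, 5)) - 0.5, 0.0).astype(dtype)
        x = rng.uniform(size=(5, 1)).astype(dtype)
        y = csrmv_reference(*dense_to_csr(a), x)
        assert y.dtype == np.dtype(dtype)
        np.testing.assert_allclose(y, a @ x, rtol=1e-5)
        print(dtype, "ok")
```

Looping over `("float32", "float64")` in the `__main__` block mirrors the kind of parametrization that the proposed unit tests could use.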
