eric-haibin-lin commented on a change in pull request #10371: [MXNET-263] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU URL: https://github.com/apache/incubator-mxnet/pull/10371#discussion_r183261461
########## File path: tests/python/unittest/test_sparse_operator.py ##########
@@ -1205,6 +1205,27 @@ def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad=
 @with_seed()
 def test_sparse_dot():
+    def test_infer_forward_stype(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_a, trans_b):
+        all_stypes = ["default", "csr", "row_sparse"]
+        lhs_nd = rand_ndarray(lhs_shape, 'default', density=lhs_density)
+        rhs_nd = rand_ndarray(rhs_shape, 'default', density=rhs_density)
+        out_nd = mx.nd.dot(lhs_nd, rhs_nd, transpose_a=trans_a, transpose_b=trans_b)
+        out_np = out_nd.asnumpy()
+        for lhs_stype in all_stypes:
+            for rhs_stype in all_stypes:
+                for forward_stype in all_stypes:
+                    lhs = lhs_nd.tostype(lhs_stype)
+                    rhs = rhs_nd.tostype(rhs_stype)
+                    out = mx.nd.dot(lhs, rhs, forward_stype=forward_stype,
+                                    transpose_a=trans_a, transpose_b=trans_b)
+                    assert_almost_equal(out.tostype('default').asnumpy(), out_np, rtol=1e-4, atol=1e-5)
+                    lhs_var = mx.symbol.Variable('lhs', stype=lhs_stype)
+                    rhs_var = mx.symbol.Variable('rhs', stype=rhs_stype)
+                    out = mx.symbol.sparse.dot(lhs_var, rhs_var,
+                                               forward_stype=forward_stype,
+                                               transpose_a=trans_a, transpose_b=trans_b)
+                    location = {'lhs': lhs, 'rhs': rhs}
+                    check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)

Review comment:
Can we test check_symbolic_backward for dot(dns, csr)?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

With regards,
Apache Git Services