eric-haibin-lin commented on a change in pull request #10371: [MXNET-263] 
Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU
URL: https://github.com/apache/incubator-mxnet/pull/10371#discussion_r183261461
 
 

 ##########
 File path: tests/python/unittest/test_sparse_operator.py
 ##########
 @@ -1205,6 +1205,27 @@ def check_cast_storage(shape, density, from_stype, 
to_stype, check_numeric_grad=
 
 @with_seed()
 def test_sparse_dot():
+    def test_infer_forward_stype(lhs_shape, rhs_shape, lhs_density, 
rhs_density, trans_a, trans_b):
+        all_stypes = ["default", "csr", "row_sparse"]
+        lhs_nd = rand_ndarray(lhs_shape, 'default', density=lhs_density)
+        rhs_nd = rand_ndarray(rhs_shape, 'default', density=rhs_density)
+        out_nd = mx.nd.dot(lhs_nd, rhs_nd, transpose_a=trans_a, 
transpose_b=trans_b)
+        out_np = out_nd.asnumpy()
+        for lhs_stype in all_stypes:
+            for rhs_stype in all_stypes:
+                for forward_stype in all_stypes:
+                    lhs = lhs_nd.tostype(lhs_stype)
+                    rhs = rhs_nd.tostype(rhs_stype)
+                    out = mx.nd.dot(lhs, rhs, forward_stype=forward_stype,
+                                    transpose_a=trans_a, transpose_b=trans_b)
+                    assert_almost_equal(out.tostype('default').asnumpy(), 
out_np, rtol=1e-4, atol=1e-5)
+                    lhs_var = mx.symbol.Variable('lhs', stype=lhs_stype)
+                    rhs_var = mx.symbol.Variable('rhs', stype=rhs_stype)
+                    out = mx.symbol.sparse.dot(lhs_var, rhs_var,
+                                               forward_stype=forward_stype,
+                                               transpose_a=trans_a, 
transpose_b=trans_b)
+                    location = {'lhs': lhs, 'rhs': rhs}
+                    check_symbolic_forward(out, location, [out_np], rtol=1e-3, 
atol=1e-4)
 
 Review comment:
   Can we also test `check_symbolic_backward` for dot(dns, csr)?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to