siju-samuel commented on a change in pull request #6685:
URL: https://github.com/apache/incubator-tvm/pull/6685#discussion_r505991920



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -890,6 +890,44 @@ def _impl(inputs, attr, params, mod):
     return _impl
 
 
+def _sparse_tensor_dense_matmul():
+    # Sparse utility from Numpy

Review comment:
      The comment says the sparse utility is from Numpy, but it is imported from Scipy; please change Numpy to Scipy.
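
      For reference, a standalone sketch of the Scipy utility in question (the
      values and shape are illustrative, mirroring the construction in the diff):

          import numpy as np
          from scipy import sparse

          # COO-style triplets, as TensorFlow's SparseTensor provides them
          indices = np.array([[0, 0], [1, 2]])
          values = np.array([4.0, 8.0])
          rows, cols = indices[:, 0], indices[:, 1]

          # Build the CSR matrix the same way the converter does
          csr = sparse.csr_matrix((values, (rows, cols)), shape=(3, 4))
          print(csr.toarray())
          # [[4. 0. 0. 0.]
          #  [0. 0. 8. 0.]
          #  [0. 0. 0. 0.]]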
   

##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -890,6 +890,44 @@ def _impl(inputs, attr, params, mod):
     return _impl
 
 
+def _sparse_tensor_dense_matmul():
+    # Sparse utility from Numpy
+    from scipy import sparse
+
+    def _impl(inputs, attr, params, mod):
+        assert len(inputs) == 4, "There should be 4 input tensors"
+
+        indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
+        values_tensor = _infer_value(inputs[1], params, mod).asnumpy()
+        dense_shape_tensor = _infer_value(inputs[2], params, mod).asnumpy()
+
+        data = inputs[3]
+
+        rows = [x[0] for x in indices_tensor]
+        cols = [x[1] for x in indices_tensor]
+
+        # Create Numpy sparse Tensor(CSR)
+        weight_sp = sparse.csr_matrix(
+            (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
+        )
+        weight_sp = sparse.csr_matrix(weight_sp.transpose())
+
+        weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
+        weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
+        weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype)
+
+        ret = _op.nn.sparse_dense(data, [weight_data, weight_indices, weight_indptrs])
+
+        # If both are true, the first input was dense and the second was sparse
+        # TODO: Support the other adjoint options too
+        if attr.get("adjoint_a") and attr.get("adjoint_b"):

Review comment:
      Please raise a not-supported error for the other adjoint option combinations instead of silently falling through.
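
      A minimal sketch of such a check, assuming the only supported
      combinations are both-False and both-True as in the quoted diff (a plain
      NotImplementedError is used here; TVM's own frontend error classes would
      do equally well):

          adjoint_a = attr.get("adjoint_a", False)
          adjoint_b = attr.get("adjoint_b", False)
          if (adjoint_a, adjoint_b) not in [(False, False), (True, True)]:
              raise NotImplementedError(
                  "Mixed adjoint_a/adjoint_b combinations are not supported "
                  "for SparseTensorDenseMatMul yet"
              )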

##########
File path: tests/python/frontend/tensorflow/test_forward.py
##########
@@ -1750,6 +1750,64 @@ def test_forward_batch_matmul():
    _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)
 
 
+#######################################################################
+# SparseTensorDenseMatMul
+# ----------------------------------
+
+
+def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=False):
+    """ One iteration of sparse_dense_matmul """
+
+    # TODO: Support adjoint options too
+    for adjoint_a in [False]:
+        for adjoint_b in [False]:
+            with tf.Graph().as_default():
+                A_sp = tf.sparse.SparseTensor(
+                    indices=[[0, 0], [1, 2]], values=[4.0, 8.0], dense_shape=A_shape
+                )
+                B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")
+
+                if flip:
+                    result = tf.sparse.sparse_dense_matmul(
+                        B, A_sp, adjoint_a=adjoint_a, adjoint_b=adjoint_b
+                    )
+                else:
+                    result = tf.sparse.sparse_dense_matmul(
+                        A_sp, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b
+                    )
+
+                B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
+
+                # TODO: There is an issue in cuda scheduling for csr, work in progress
+                compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True)

Review comment:
      A follow-up PR is needed to solve the CUDA scheduling issue for CSR.

##########
File path: tests/python/frontend/tensorflow/test_forward.py
##########
@@ -1750,6 +1750,64 @@ def test_forward_batch_matmul():
    _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)
 
 
+#######################################################################
+# SparseTensorDenseMatMul
+# ----------------------------------
+
+
+def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=False):
+    """ One iteration of sparse_dense_matmul """
+
+    # TODO: Support adjoint options too
+    for adjoint_a in [False]:
+        for adjoint_b in [False]:
+            with tf.Graph().as_default():
+                A_sp = tf.sparse.SparseTensor(
+                    indices=[[0, 0], [1, 2]], values=[4.0, 8.0], dense_shape=A_shape
+                )
+                B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")
+
+                if flip:
+                    result = tf.sparse.sparse_dense_matmul(
+                        B, A_sp, adjoint_a=adjoint_a, adjoint_b=adjoint_b
+                    )
+                else:
+                    result = tf.sparse.sparse_dense_matmul(
+                        A_sp, B, adjoint_a=adjoint_a, adjoint_b=adjoint_b
+                    )
+
+                B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
+
+                # TODO: There is an issue in cuda scheduling for csr, work in progress
+                compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True)
+
+
+def test_forward_sparse_dense_matmul():
+    """ sparse_dense_matmul op test"""
+    ###################################################################
+    #
+    # Creating a SparseTensor requires 3 inputs, as below:
+    #    SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
+    #
+    # The above sparse tensor can be represented in dense form as below:
+    #    [[1, 0, 0, 0]
+    #     [0, 0, 2, 0]
+    #     [0, 0, 0, 0]]
+    #
+    # ------------------------------------------------------------------
+
+    # TODO: False case for flip needs to be supported

Review comment:
      I suggest this PR include the flip=False case as well.
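
      As a numerical sanity check for that case, the transpose identity the
      converter relies on, A @ B == (B.T @ A.T).T, can be verified with plain
      Numpy/Scipy (shapes here are hypothetical, mirroring the test's
      SparseTensor); this identity is what lets a sparse-times-dense matmul
      map onto nn.sparse_dense:

          import numpy as np
          from scipy import sparse

          # Sparse A (3 x 4) as in the test, dense B (4 x 5)
          A = sparse.csr_matrix(([4.0, 8.0], ([0, 1], [0, 2])), shape=(3, 4))
          B = np.random.uniform(high=5.0, size=(4, 5)).astype("float32")

          direct = A.toarray() @ B
          via_transpose = (B.T @ A.toarray().T).T
          np.testing.assert_allclose(direct, via_transpose, rtol=1e-5)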




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

