tkonolige commented on a change in pull request #7048:
URL: https://github.com/apache/tvm/pull/7048#discussion_r548055912



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -3167,22 +3185,200 @@ def convert_matrix_diag(self, op):
         out = _op.matrix_set_diag(input_expr, diagonal_expr)
         return out
 
+    def convert_densify(self, op):
+        """Convert TFLite DENSIFY"""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+
+        sparse_weight_tensor = input_tensors[0]
+        sparse_weight_tensor_type_str = self.get_tensor_type_str(sparse_weight_tensor.tensor.Type())
+
+        # NOTE: With current implementation in TFLite, Densify Op does not need to be present
+        # in runtime.
+        # TODO(ANSHUMAN87): we need to use the sparse_indices output
+        # from below function and use that in sparse_to_dense Op.
+        # Once the stack corruption issue is resolved in sparse_to_dense Op.
+        _, dense_weight = prepare_dense_matrix_from_sparse(
+            sparse_weight_tensor.tensor,
+            self.get_tensor_value(sparse_weight_tensor, is_sparse=True),
+            sparse_weight_tensor_type_str,
+        )
+
+        self.set_prefetched_node(output_tensor.tensor_idx, dense_weight)
+
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
     def has_expr(self, input_tensor_idx):
         return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
-    def get_tensor_expr(self, tensor):
+    def is_prefetched(self, input_tensor_idx):
+        return (
+            self.prefetched_nodes.get(get_tensor_name(self.subgraph, input_tensor_idx)) is not None
+        )
+
+    def set_prefetched_node(self, input_tensor_idx, value):
+        self.prefetched_nodes[get_tensor_name(self.subgraph, input_tensor_idx)] = value
+
+    def get_prefetched_node(self, input_tensor_idx):
+        return self.prefetched_nodes[get_tensor_name(self.subgraph, input_tensor_idx)]
+
+    def get_tensor_expr(self, tensor, is_sparse=False):
         """ Return the Relay expr for tensor. """
         if self.has_expr(tensor.tensor_idx):
             expr = self.get_expr(tensor.tensor_idx)
         else:
             type_str = self.get_tensor_type_str(tensor.tensor.Type())
-            expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)
+            expr = self.exp_tab.new_const(self.get_tensor_value(tensor, is_sparse), dtype=type_str)
         return expr
 
 
+# pylint: disable=no-else-return
+def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type):
+    """ Prepare sparse indices and dense matrix from TFLite sparse parameters. """
+    # The function is implemented based on TFLite sparse parameter specifications
+    # Please refer
+    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs#L89
+    # for details about each parameters
+    sparsity = sparse_tensor.Sparsity()
+    dense_shape = sparse_tensor.ShapeAsNumpy()
+    orig_rank = len(dense_shape)
+
+    # The traversal order of the dimensions defined in the `shape` field of the to be dense tensor.
+    traversal_order = sparsity.TraversalOrderAsNumpy()
+
+    # For an n-dimensional tensor with a k-dimensional block (0 <= k <= n),
+    # stores how a block dimension in (dn, ..., dn+k-1) maps to the original
+    # tensor dimension in (d0, ..., dn). It's stored in the order of (dn, ..., dn+k-1).
+    # If not block-sparse, this field is NULL.
+    block_map = sparsity.BlockMapAsNumpy()
+
+    total_rank = sparsity.TraversalOrderLength()
+    dense_mat = np.full(shape=dense_shape, fill_value=0, dtype=sparse_tensor_type).flatten()
+
+    from enum import Enum
+
+    # NOTE: Here the Vector term is borrowed from TFLite spec.
+    class VectorType(Enum):
+        Empty = 0
+        Int32 = 1
+        Uint16 = 2
+        Uint8 = 3
+
+    def _get_vector_flag(v_type):
+        if VectorType(v_type) == VectorType.Int32:
+            return N.Int32Flags
+        elif VectorType(v_type) == VectorType.Uint16:
+            return N.Uint16Flags
+        elif VectorType(v_type) == VectorType.Uint8:
+            return N.Uint8Flags
+        else:
+            # raise UnsupportedError("The provided type {} is not supported".format(type))

Review comment:
       uncomment or remove




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to