This is an automated email from the ASF dual-hosted git repository. zhaowu pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 006b9b5 [Frontend][TFLite] Densify Op added (#7048) 006b9b5 is described below commit 006b9b53ab97e677933011d8b36a98a5a4ac7723 Author: ANSHUMAN TRIPATHY <anshuma...@huawei.com> AuthorDate: Wed Jan 13 19:50:02 2021 +0530 [Frontend][TFLite] Densify Op added (#7048) * [Frontend][TFLite] Densify Op added * [1] Review comments handled * TODO added for sparse_to_dense Op usage * stale comments removed --- python/tvm/relay/frontend/tflite.py | 215 +++++++++++++++++++++++++-- tests/python/frontend/tflite/test_forward.py | 48 ++++++ 2 files changed, 253 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 7a2aada..525fb41 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -65,6 +65,7 @@ class OperatorConverter(object): self.builtin_op_code = build_str_map(BuiltinOperator()) self.activation_fn_type = build_str_map(ActivationFunctionType()) self.builtin_options = build_str_map(BuiltinOptions()) + self.prefetched_nodes = {} # Add more operators self.convert_map = { @@ -80,6 +81,7 @@ class OperatorConverter(object): "CONCATENATION": self.convert_concatenation, "CONV_2D": self.convert_conv2d, "COS": self.convert_cos, + "DENSIFY": self.convert_densify, "DEPTH_TO_SPACE": self.convert_depth_to_space, "DEPTHWISE_CONV_2D": self.convert_depthwise_conv2d, "DEQUANTIZE": self.convert_dequantize, @@ -200,6 +202,10 @@ class OperatorConverter(object): assert isinstance(op, Operator) ret = self.convert_map[op_code_str](op) + # In case the Op can be prefetched, the output can be optimized out + if ret is None: + continue + if len(output_tensors) == 1: tensor_idx = output_tensors[0].tensor_idx self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret) @@ -338,7 +344,8 @@ class OperatorConverter(object): "Tensor type '{}' currently not supported".format(tensor_wrapper.tensor.Type()) ) - def 
get_tensor_value(self, tensor_wrapper): + # pylint: disable=no-else-return + def get_tensor_value(self, tensor_wrapper, is_sparse=False): """Get tensor buffer value from given tensor wrapper""" assert isinstance(tensor_wrapper, TensorWrapper) @@ -350,7 +357,10 @@ class OperatorConverter(object): else: shape = [] - return np.frombuffer(data, dtype=dtype).reshape(shape) + if is_sparse: + return np.frombuffer(data, dtype=dtype) + else: + return np.frombuffer(data, dtype=dtype).reshape(shape) def get_tensor_type_str(self, tensor_type): """Get tensor type string representation when given TFLite tensor type""" @@ -1662,11 +1672,15 @@ class OperatorConverter(object): axis = tuple(axis_value) if len(axis_value.shape) > 0 else tuple((axis_value.item(),)) # Options - keep_dims (bool) - assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions - reduce_options = ReducerOptions() - op_options = op.BuiltinOptions() - reduce_options.Init(op_options.Bytes, op_options.Pos) - keep_dims = reduce_options.KeepDims() + # In case Options are not present, set keep_dims to False(default) + if op.BuiltinOptionsType(): + assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions + reduce_options = ReducerOptions() + op_options = op.BuiltinOptions() + reduce_options.Init(op_options.Bytes, op_options.Pos) + keep_dims = reduce_options.KeepDims() + else: + keep_dims = False if input_tensor.qnn_params: in_expr = _op.cast(in_expr, "int32") @@ -2026,7 +2040,11 @@ class OperatorConverter(object): else: weight_expr = _op.transpose(weight_expr, axes=(1, 2, 3, 0)) else: - weight_value = self.get_tensor_value(weight_tensor) + if self.is_prefetched(weight_tensor.tensor_idx): + weight_value = self.get_prefetched_node(weight_tensor.tensor_idx) + else: + weight_value = self.get_tensor_value(weight_tensor) + # TFLite kernel layout: # convolution: # OC KH KW IC, we require KH KW IC OC (HWIO) @@ -3196,22 +3214,199 @@ class OperatorConverter(object): out = _op.matrix_set_diag(input_expr, 
diagonal_expr) return out + def convert_densify(self, op): + """Convert TFLite DENSIFY""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + sparse_weight_tensor = input_tensors[0] + sparse_weight_tensor_type_str = self.get_tensor_type_str(sparse_weight_tensor.tensor.Type()) + + # NOTE: With current implementation in TFLite, Densify Op does not need to be present + # in runtime. + # TODO(ANSHUMAN87): we need to use the sparse_indices output + # from below function and use that in sparse_to_dense Op. + # Once the stack corruption issue is resolved in sparse_to_dense Op. + _, dense_weight = prepare_dense_matrix_from_sparse( + sparse_weight_tensor.tensor, + self.get_tensor_value(sparse_weight_tensor, is_sparse=True), + sparse_weight_tensor_type_str, + ) + + self.set_prefetched_node(output_tensor.tensor_idx, dense_weight) + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) - def get_tensor_expr(self, tensor): + def is_prefetched(self, input_tensor_idx): + return ( + self.prefetched_nodes.get(get_tensor_name(self.subgraph, input_tensor_idx)) is not None + ) + + def set_prefetched_node(self, input_tensor_idx, value): + self.prefetched_nodes[get_tensor_name(self.subgraph, input_tensor_idx)] = value + + def get_prefetched_node(self, input_tensor_idx): + return self.prefetched_nodes[get_tensor_name(self.subgraph, input_tensor_idx)] + + def get_tensor_expr(self, tensor, is_sparse=False): """ Return the Relay expr for tensor. 
""" if self.has_expr(tensor.tensor_idx): expr = self.get_expr(tensor.tensor_idx) else: type_str = self.get_tensor_type_str(tensor.tensor.Type()) - expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str) + expr = self.exp_tab.new_const(self.get_tensor_value(tensor, is_sparse), dtype=type_str) return expr +# pylint: disable=no-else-return +def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type): + """ Prepare sparse indices and dense matrix from TFLite sparse parameters. """ + # The function is implemented based on TFLite sparse parameter specifications + # Please refer + # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs#L89 + # for details about each parameters + sparsity = sparse_tensor.Sparsity() + dense_shape = sparse_tensor.ShapeAsNumpy() + orig_rank = len(dense_shape) + + # The traversal order of the dimensions defined in the `shape` field of the to be dense tensor. + traversal_order = sparsity.TraversalOrderAsNumpy() + + # For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), + # stores how a block dimension in (dn, ..., dn+k-1) maps to the original + # tensor dimension in (d0, ..., dn). It's stored in the order of (dn, ..., dn+k-1). + # If not block-sparse, this field is NULL. + block_map = sparsity.BlockMapAsNumpy() + + total_rank = sparsity.TraversalOrderLength() + dense_mat = np.full(shape=dense_shape, fill_value=0, dtype=sparse_tensor_type).flatten() + + from enum import Enum + + # NOTE: Here the Vector term is borrowed from TFLite spec. 
+ class VectorType(Enum): + Empty = 0 + Int32 = 1 + Uint16 = 2 + Uint8 = 3 + + def _get_vector_flag(v_type): + if VectorType(v_type) == VectorType.Int32: + return N.Int32Flags + elif VectorType(v_type) == VectorType.Uint16: + return N.Uint16Flags + elif VectorType(v_type) == VectorType.Uint8: + return N.Uint8Flags + else: + raise tvm.error.OpNotImplemented("The provided type {} is not supported".format(v_type)) + + def _get_flattened_index(indices, shape): + index = 0 + sub_elements = 1 + for i in reversed(range(0, len(dense_shape))): + index += indices[i] * sub_elements + sub_elements *= shape[i] + return index + + # DimensionMetadata per dimension: the metadata needed for + # each dimension to locate the non-zero values in the original dense tensor + # inline with traversal order parameter. + # + # sp_format has 2 possible values: {DENSE = 0, SPARSE_CSR = 1} + # If format = DENSE{0} : DenseSize represents size of that dimension + # If format = SPARSE_CSR{1} : array_segments represents how to segment the indices array, + # each segment corresponds to one element in the previous dimension. array_indices + # represents the index of the non-zero elements within this dimension + # (as those in the CSR matrix format, where the first array is row pointers + # and the second array is column indices). + sp_format = np.zeros(sparsity.DimMetadataLength()) + dim_metadata = [None] * (2 * sparsity.DimMetadataLength()) + + # Below loop will fetch all meta data per dimension based on format type + # Dense or Sparse and will put it in an agnostic array for easy access + # while preparing dense buffer or indices. 
+ for i in range(sparsity.DimMetadataLength()): + sp_format[i] = sparsity.DimMetadata(i).Format() + if sp_format[i] == 0: + dim_metadata[2 * i] = [sparsity.DimMetadata(i).DenseSize()] + else: + from flatbuffers import number_types as N + + dim_metadata[2 * i] = ( + sparsity.DimMetadata(i) + .ArraySegments() + .GetVectorAsNumpy( + flags=_get_vector_flag(sparsity.DimMetadata(i).ArraySegmentsType()), off=4 + ) + ) + dim_metadata[2 * i + 1] = ( + sparsity.DimMetadata(i) + .ArrayIndices() + .GetVectorAsNumpy( + flags=_get_vector_flag(sparsity.DimMetadata(i).ArrayIndicesType()), off=4 + ) + ) + + block_dim = 0 + block_size = np.zeros(sparsity.BlockMapLength()) + + # Block size parameter if encoded in BSR format + for i in range(orig_rank): + if block_dim < sparsity.BlockMapLength() and block_map[block_dim] == i: + orig_dim = traversal_order[orig_rank + block_dim] + block_size[block_dim] = sparsity.DimMetadata(orig_dim).DenseSize() + block_dim += 1 + + indices_list = [] + + # Below function iterates through each applicable indices per dimension + # based on format type specified and finally produces the dense matrix and the NZ indices. 
+ def _def_prepare_dense_matrix_from_sparse(indices, level, prev_idx): + if level == len(indices): + start_pos = 0 + orig_idx = np.zeros(orig_rank, dtype="int32") + while start_pos < orig_rank: + orig_idx[traversal_order[start_pos]] = indices[start_pos] + start_pos += 1 + while start_pos < len(indices): + block_idx = traversal_order[start_pos] - orig_rank + orig_dim = block_map[block_idx] + orig_idx[orig_dim] = orig_idx[orig_dim] * block_size[block_idx] + indices[start_pos] + start_pos += 1 + indices_list.append(orig_idx) + nonlocal value_idx + dense_mat[_get_flattened_index(orig_idx, dense_shape)] = sparse_tensor_value[value_idx] + value_idx += 1 + else: + metadata_idx = 2 * level + if sp_format[level] == 0: + shape_of_level = dim_metadata[metadata_idx][0] + for idx in range(shape_of_level): + indices[level] = idx + _def_prepare_dense_matrix_from_sparse( + indices, level + 1, prev_idx * shape_of_level + idx + ) + else: + array_segments = dim_metadata[metadata_idx] + array_indices = dim_metadata[metadata_idx + 1] + for idx in range(array_segments[prev_idx], array_segments[prev_idx + 1]): + indices[level] = array_indices[idx] + _def_prepare_dense_matrix_from_sparse(indices, level + 1, idx) + + indices = np.zeros(total_rank) + value_idx = 0 + _def_prepare_dense_matrix_from_sparse(indices, 0, 0) + return np.array(indices_list, dtype="int32"), dense_mat.reshape(dense_shape) + + def get_scalar_from_constant(expr): """ Returns scalar value from Relay constant scalar. 
""" assert ( diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index c8bd094..f365301 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -3692,6 +3692,50 @@ def test_forward_mobilenet_v3(): ####################################################################### +# Mobilenet V1 Sparse +# ----------------- + + +def test_forward_sparse_mobilenet_v1(): + """Test the Sparse version of Mobilenet V1 TF Lite model.""" + # MobilenetV1 + tflite_model_file = download_testdata( + "https://storage.googleapis.com/fast-convnets/tflite-models/mbv1_140_90_12b4_720.tflite", + "mbv1_140_90_12b4_720.tflite", + ) + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32") + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, "float_image_input") + tvm.testing.assert_allclose( + np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5 + ) + + +####################################################################### +# Mobilenet V2 Sparse +# ----------------- + + +def test_forward_sparse_mobilenet_v2(): + """Test the Sparse version of Mobilenet V2 TF Lite model.""" + # MobilenetV2 + tflite_model_file = download_testdata( + "https://storage.googleapis.com/fast-convnets/tflite-models/mbv2_200_85_11-16b2_744.tflite", + "mbv2_200_85_11-16b2_744.tflite", + ) + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32") + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, "float_image_input") + tvm.testing.assert_allclose( + np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5 + ) + + +####################################################################### # 
Inception # --------- @@ -4197,6 +4241,10 @@ if __name__ == "__main__": test_forward_coco_ssd_mobilenet_v1() test_forward_mediapipe_hand_landmark() + # End to End Sparse models + test_forward_sparse_mobilenet_v1() + test_forward_sparse_mobilenet_v2() + # End to End quantized test_forward_qnn_inception_v1_net() test_forward_qnn_mobilenet_v1_net()