inadob commented on a change in pull request #4704: [Relay][Frontend][TFLite]
Add parser support for arg_min_max
URL: https://github.com/apache/incubator-tvm/pull/4704#discussion_r385636005
##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -826,6 +828,50 @@ def _convert_reduce_prod(self, op):
     def _convert_reduce_sum(self, op):
         return self._convert_reduce(_op.reduce.sum, op)
+    def _convert_arg_min_max(self, relay_op, op):
+        """Generic method to convert TFLite arg_min_max"""
+        try:
+            from tflite.Operator import Operator
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ArgMinOptions import ArgMinOptions
+            from tflite.ArgMaxOptions import ArgMaxOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 2, "input tensors length should be 2"
+
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+        axis_tensor = input_tensors[1]
+        # we support the case when the axis is a scalar not a tensor
+        axis_value = int(self.get_tensor_value(axis_tensor))
+
+        if op.BuiltinOptionsType() == BuiltinOptions.ArgMinOptions:
+            arg_min_max_options = ArgMinOptions()
+        elif op.BuiltinOptionsType() == BuiltinOptions.ArgMaxOptions:
+            arg_min_max_options = ArgMaxOptions()
+        op_options = op.BuiltinOptions()
+        arg_min_max_options.Init(op_options.Bytes, op_options.Pos)
+        output_dtype = arg_min_max_options.OutputType()
+
+        # set keepdims to True since tflite 1.13 removes all dims of size 1
+        # WARNING: all other versions of tflite > 1.13 need keepdims=False
Review comment:
After some further investigation, it turns out that `arg_min` and
`arg_max` are fundamentally broken in TFL 1.13: the behaviour described in
the flatbuffer file differs from what is actually computed when we run
inference directly (as in all our parser tests).
You can see the changes between TFL 1.13 and TFL 1.14 for these ops here:
https://github.com/tensorflow/tensorflow/commit/1d0552e94d4e091ce03248675ac92c1e1e0accae
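
To make the version dependence in the code comment concrete, here is a rough sketch of how the `keepdims` choice could be gated. It is not the code in this PR: `needs_keepdims`, `parse_major_minor` and `argmax_last_axis` are hypothetical names, the version is assumed to arrive as a plain string such as "1.13.1", and the argmax is done on nested Python lists only to show the shape difference.

```python
def parse_major_minor(version):
    """Turn a version string such as '1.13.1' into the tuple (1, 13)."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def needs_keepdims(tflite_version):
    """Hypothetical helper: True only for TFLite 1.13, as the diff comment states.

    Versions older than 1.13 are not covered by the comment; returning
    False for them is an assumption of this sketch.
    """
    return parse_major_minor(tflite_version) == (1, 13)


def argmax_last_axis(rows, keepdims):
    """Argmax over the last axis of a 2-D nested list.

    With keepdims=True the reduced axis is kept with size 1,
    otherwise it is dropped.
    """
    indices = [max(range(len(row)), key=row.__getitem__) for row in rows]
    return [[i] for i in indices] if keepdims else indices


data = [[1, 9, 3], [7, 2, 8]]
kept = argmax_last_axis(data, keepdims=needs_keepdims("1.13.1"))
dropped = argmax_last_axis(data, keepdims=needs_keepdims("1.14.0"))
```

With this input, `kept` retains the reduced axis (two rows of one index each) while `dropped` is a flat list of two indices, which is the shape difference the tests would need to account for between the two versions.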