lixiaoquan commented on a change in pull request #5699:
URL: https://github.com/apache/incubator-tvm/pull/5699#discussion_r433043210



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -955,34 +930,22 @@ def _impl(inputs, attr, params, prelude):
             v = tensor_func(inputs[2])
             write_func = prelude.get_var('tensor_array_write', dtype_str)
         else:
-            # For write operation, it is possible to write to a newly create
-            # tensor array. We need to check and recreate its input tensor array.
-            if input_ta in _static_tensor_array_map and \
-                    _static_tensor_array_map[input_ta] is None:
-                static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                               dtype_str,
-                                                               input_t_shape)
-                static_tensor_array_ops.register()
-                ta_constructor = prelude.get_var_static('tensor_array',
-                                                        dtype_str,
-                                                        input_t_shape)
-                new_ta = ta_constructor(input_ta.args[0])
-                _static_tensor_array_map[input_ta] = new_ta
-                input_ta = new_ta
-                input_ta_shape = input_t_shape
-            else:
-                input_ta_rank = len(input_ta_shape)
-                assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \
-                    format(input_ta_rank, input_rank)
-                static_tensor_array_ops = StaticTensorArrayOps(prelude,
-                                                               dtype_str,
-                                                               input_ta_shape)
-                static_tensor_array_ops.register()
+            input_ta_rank = len(input_ta_shape)
+            assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \
+                format(input_ta_rank, input_rank)
+            static_tensor_array_ops = StaticTensorArrayOps(prelude,
+                                                           dtype_str,
+                                                           input_ta_shape)
+            static_tensor_array_ops.register()
 
             tensor_func = prelude.get_var_static("tensor_constructor",
                                                  dtype_str,
                                                  input_ta_shape)
-            v = tensor_func(inputs[2])
+            actual_shape = _get_more_static_shape(input_t_shape, input_ta_shape)

Review comment:
       actual_shape may contain tvm.tir.Any, so it may need to be converted
before being passed to reshape().




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to