zhiics commented on a change in pull request #6396:
URL: https://github.com/apache/incubator-tvm/pull/6396#discussion_r483304452



##########
File path: src/relay/op/tensor/transform.h
##########
@@ -180,6 +181,134 @@ static inline Array<Array<Layout>> ConcatenateLayout(const Attrs& attrs,
   return Array<Array<Layout>>{Array<Layout>(old_in_layouts.size(), ret), {ret}};
 }
 
+static inline Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape,

Review comment:
       I suggest we just declare the function in the header and define it in the .cc file. Inlining such a large function is usually not preferable, even though the final decision to inline it rests with the compiler.
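
       A minimal sketch of the split (my illustration, not code from the PR; the body stays exactly as in this diff):

```cpp
// transform.h -- declaration only. Note that `static inline` has to be dropped:
// a static function defined in a header gives every translation unit its own copy.
Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs);

// transform.cc -- the definition moves here unchanged.
Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs) {
  const auto* param = attrs.as<ReshapeAttrs>();
  // ... rest of the body as written above ...
}
```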

##########
File path: src/relay/op/tensor/transform.h
##########
@@ -180,6 +181,134 @@ static inline Array<Array<Layout>> ConcatenateLayout(const Attrs& attrs,
   return Array<Array<Layout>>{Array<Layout>(old_in_layouts.size(), ret), {ret}};
 }
 
+static inline Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape,
+                                              const Attrs& attrs) {
+  const auto* param = attrs.as<ReshapeAttrs>();
+  Array<IndexExpr> oshape;
+  Array<IndexExpr> ishape;
+  Array<Integer> newshape;
+
+  if (param->reverse) {
+    ishape.Assign(data_shape.rbegin(), data_shape.rend());
+    newshape.Assign(param->newshape.rbegin(), param->newshape.rend());
+  } else {
+    ishape = data_shape;
+    newshape = param->newshape;
+  }
+
+  std::unordered_set<size_t> used_input_dims;
+  std::unordered_set<size_t> used_output_dims;
+  size_t src_idx = 0;
+  int infer_idx = -1;
+
+  for (size_t i = 0; i < newshape.size(); ++i) {
+    int svalue = newshape[i]->value;
+    // special flag handling for shape inference.
+    if (svalue > 0) {
+      oshape.push_back(newshape[i]);
+      ++src_idx;
+    } else if (svalue == 0) {
+      // keep same
+      CHECK_LT(src_idx, ishape.size());
+      used_input_dims.insert(src_idx);
+      used_output_dims.insert(oshape.size());
+      oshape.push_back(ishape[src_idx++]);
+    } else if (svalue == -1) {
+      // inference based on rest
+      CHECK_LT(infer_idx, 0) << "One and only one dim can be inferred";
+      infer_idx = i;
+      oshape.push_back(1);
+      ++src_idx;
+    } else if (svalue == -2) {
+      // copy all remaining dims from source
+      while (src_idx < ishape.size()) {
+        used_input_dims.insert(src_idx);
+        used_output_dims.insert(oshape.size());
+        oshape.push_back(ishape[src_idx++]);
+      }
+    } else if (svalue == -3) {
+      // merge two dims from source
+      CHECK_LT(src_idx + 1, ishape.size());
+      used_input_dims.insert(src_idx);
+      IndexExpr d1 = ishape[src_idx++];
+      used_input_dims.insert(src_idx);
+      IndexExpr d2 = ishape[src_idx++];
+      used_output_dims.insert(oshape.size());
+      if (d1.as<AnyNode>() || d2.as<AnyNode>()) {
+        oshape.push_back(Any());
+      } else {
+        oshape.push_back(d1 * d2);
+      }
+    } else if (svalue == -4) {
+      // split the source dim s into two dims
+      // read the left dim and then the right dim (either can be -1)
+      CHECK_LT(i + 2, newshape.size());
+      CHECK_LT(src_idx, ishape.size());
+      used_input_dims.insert(src_idx);
+      IndexExpr d0 = ishape[src_idx++];
+      Integer d1 = newshape[++i];
+      Integer d2 = newshape[++i];
+      if (d1->value == -1) {
+        CHECK(d2->value != -1) << "Split dims cannot both be -1.";

Review comment:
       CHECK_NE(d2->value, -1)
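
       i.e. (my rendering of the suggestion):

```cpp
// Equivalent check, but the binary form documents the comparison directly.
CHECK_NE(d2->value, -1) << "Split dims cannot both be -1.";
```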

##########
File path: tests/python/relay/test_any.py
##########
@@ -801,6 +801,23 @@ def test_any_ndarray_size():
     verify_any_ndarray_size((2, 2))
     verify_any_ndarray_size((1, 2, 3, 4))
 
+def test_reshape_concat():
+    d0 = relay.var("d0", shape=any_dims(2), dtype='float32')
+    d1 = relay.var("d1", shape=any_dims(3), dtype='float32')
+    out = relay.op.concatenate([relay.op.reshape(d0, [-1]), relay.op.reshape(d1, [-1])], axis=0)
+    mod = tvm.IRModule()
+    mod['main'] = relay.Function([d0, d1], out)
+    relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
+
+    d0 = relay.var("d0", shape=any_dims(2), dtype='float32')
+    d1 = relay.var("d1", shape=any_dims(2), dtype='float32')
+    s0 = relay.var("s0", shape=any_dims(3), dtype='float32')
+    s1 = relay.var("s1", shape=any_dims(3), dtype='float32')
+    out = relay.op.concatenate([relay.op.reshape_like(d0, s0), relay.op.reshape_like(d1, s1)], axis=0)
+    mod = tvm.IRModule()
+    mod['main'] = relay.Function([d0, d1, s0, s1], out)
+    relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")

Review comment:
       yeah, I have added a `check_result` function. We should be able to just 
use that.
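
       Something like this, for instance (a sketch on my part; I'm assuming `check_result` takes `(args, mod, expected)` like the other helpers in this file):

```python
import numpy as np

# Concrete inputs for the any-shaped vars, plus a numpy reference result.
data0 = np.random.uniform(size=(3, 4)).astype("float32")
data1 = np.random.uniform(size=(2, 3, 4)).astype("float32")
ref = np.concatenate([data0.reshape(-1), data1.reshape(-1)], axis=0)
check_result([data0, data1], mod, ref)
```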

##########
File path: src/relay/op/tensor/transform.h
##########
@@ -180,6 +181,134 @@ static inline Array<Array<Layout>> ConcatenateLayout(const Attrs& attrs,
   return Array<Array<Layout>>{Array<Layout>(old_in_layouts.size(), ret), {ret}};
 }
 
+static inline Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape,
+                                              const Attrs& attrs) {
+  const auto* param = attrs.as<ReshapeAttrs>();
+  Array<IndexExpr> oshape;
+  Array<IndexExpr> ishape;
+  Array<Integer> newshape;
+
+  if (param->reverse) {
+    ishape.Assign(data_shape.rbegin(), data_shape.rend());
+    newshape.Assign(param->newshape.rbegin(), param->newshape.rend());
+  } else {
+    ishape = data_shape;
+    newshape = param->newshape;
+  }
+
+  std::unordered_set<size_t> used_input_dims;
+  std::unordered_set<size_t> used_output_dims;
+  size_t src_idx = 0;
+  int infer_idx = -1;
+
+  for (size_t i = 0; i < newshape.size(); ++i) {
+    int svalue = newshape[i]->value;
+    // special flag handling for shape inference.
+    if (svalue > 0) {
+      oshape.push_back(newshape[i]);
+      ++src_idx;
+    } else if (svalue == 0) {
+      // keep same
+      CHECK_LT(src_idx, ishape.size());
+      used_input_dims.insert(src_idx);
+      used_output_dims.insert(oshape.size());
+      oshape.push_back(ishape[src_idx++]);
+    } else if (svalue == -1) {
+      // inference based on rest
+      CHECK_LT(infer_idx, 0) << "One and only one dim can be inferred";
+      infer_idx = i;
+      oshape.push_back(1);
+      ++src_idx;
+    } else if (svalue == -2) {
+      // copy all remaining dims from source
+      while (src_idx < ishape.size()) {
+        used_input_dims.insert(src_idx);
+        used_output_dims.insert(oshape.size());
+        oshape.push_back(ishape[src_idx++]);
+      }
+    } else if (svalue == -3) {
+      // merge two dims from source
+      CHECK_LT(src_idx + 1, ishape.size());
+      used_input_dims.insert(src_idx);
+      IndexExpr d1 = ishape[src_idx++];
+      used_input_dims.insert(src_idx);
+      IndexExpr d2 = ishape[src_idx++];
+      used_output_dims.insert(oshape.size());
+      if (d1.as<AnyNode>() || d2.as<AnyNode>()) {
+        oshape.push_back(Any());
+      } else {
+        oshape.push_back(d1 * d2);
+      }
+    } else if (svalue == -4) {
+      // split the source dim s into two dims
+      // read the left dim and then the right dim (either can be -1)
+      CHECK_LT(i + 2, newshape.size());
+      CHECK_LT(src_idx, ishape.size());
+      used_input_dims.insert(src_idx);
+      IndexExpr d0 = ishape[src_idx++];
+      Integer d1 = newshape[++i];
+      Integer d2 = newshape[++i];
+      if (d1->value == -1) {
+        CHECK(d2->value != -1) << "Split dims cannot both be -1.";
+        used_output_dims.insert(oshape.size());
+        if (d0.as<AnyNode>()) {
+          oshape.push_back(Any());
+        } else {
+          oshape.push_back(indexdiv(d0, d2));
+        }
+        used_output_dims.insert(oshape.size());
+        oshape.push_back(d2);
+      } else {
+        used_output_dims.insert(oshape.size());
+        oshape.push_back(d1);
+        used_output_dims.insert(oshape.size());
+        if (d2->value == -1) {
+          if (d0.as<AnyNode>()) {
+            oshape.push_back(Any());
+          } else {
+            oshape.push_back(indexdiv(d0, d1));
+          }
+        } else {
+          oshape.push_back(d2);
+        }
+      }
+    } else {
+      CHECK(false) << "Unsupported special value: " << svalue;

Review comment:
       LOG(FATAL)
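
       i.e. (my rendering of the suggestion):

```cpp
// LOG(FATAL) states the intent directly instead of a check that always fails.
LOG(FATAL) << "Unsupported special value: " << svalue;
```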




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

