yongwww commented on a change in pull request #5459:
URL: https://github.com/apache/incubator-tvm/pull/5459#discussion_r427496368



##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1019,7 +1016,8 @@ RELAY_REGISTER_OP("ones")
 
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<InitOpAttrs>()
-    .set_num_inputs(0)
+    .set_num_inputs(1)
+    .add_argument("shape", "tensor", "Target shape.")

Review comment:
   ```suggestion
       .add_argument("shape", "Tensor", "Target shape.")
   ```

##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1000,16 +994,19 @@ RELAY_REGISTER_OP("zeros")
 
 )code" TVM_ADD_FILELINE)
     .set_attrs_type<InitOpAttrs>()
-    .set_num_inputs(0)
+    .set_num_inputs(1)
+    .add_argument("shape", "tensor", "Target shape.")

Review comment:
   ```suggestion
       .add_argument("shape", "Tensor", "Target shape.")
   ```

##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -1579,39 +1577,52 @@ RELAY_REGISTER_OP("collapse_sum_like")
 // BroadCastTo: <A, B> -> B where BroadCast(A, B) = B
 bool BroadCastToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 2);
-  auto ioattrs = attrs.as<InitOpAttrs>();
-  CHECK(ioattrs);
-  auto intt = types[0].as<TensorTypeNode>();
-  if (intt == nullptr) {
-    return false;
+  CHECK_EQ(types.size(), 3);
+  const InitOpAttrs* param = attrs.as<InitOpAttrs>();
+  const auto* target_shape = types[1].as<TensorTypeNode>();
+  DataType out_dtype = types[0].as<TensorTypeNode>()->dtype;
+
+  const IntImmNode* shape_shape = target_shape->shape[0].as<IntImmNode>();
+  CHECK(shape_shape) << "Parameter shape must have static shape";
+
+  std::vector<IndexExpr> oshape;
+  if (param->shape) {
+    const Array<Integer>& cshape_array = param->shape.value();
+    for (size_t i = 0; i < cshape_array.size(); ++i) {
+      oshape.push_back(cshape_array[i]);
+    }
+  } else {
+    for (int i = 0; i < shape_shape->value; ++i) {
+      oshape.push_back(Any::make());
+    }
   }
-  auto type = TensorType(ioattrs->shape, intt->dtype);
-  reporter->Assign(types[1], type);
-  return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return BroadcastRel({types[0], types[2], types[2]}, 2, Attrs(), reporter);
 }
 
-Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) {
+Expr MakeBroadCastTo(Expr data, Expr shape) {
   static const Op& op = Op::Get("broadcast_to");
   auto attrs = make_object<InitOpAttrs>();
-  attrs->shape = std::move(shape);
-  return Call(op, {data}, Attrs(attrs), {});
+  if (const auto* cshape = shape.as<ConstantNode>()) {
+    attrs->shape = ToVector(cshape->data);
+  }
+  return Call(op, {data, shape}, Attrs(attrs), {});
 }
 
 Array<te::Tensor> BroadCastToCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                      const Type& out_type) {
-  auto ioattrs = attrs.as<InitOpAttrs>();
-  CHECK(ioattrs != nullptr);
-  return {topi::broadcast_to(inputs[0], ioattrs->shape)};
+  const auto* out_ttype = out_type.as<TensorTypeNode>();
+  return {topi::broadcast_to(inputs[0], out_ttype->shape)};
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastTo);
 
 RELAY_REGISTER_OP("broadcast_to")
     .describe(R"code(Broadcast the first input to match the shape argument.
 )code" TVM_ADD_FILELINE)
-    .set_num_inputs(1)
+    .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("shape", "tensor", "Target shape.")

Review comment:
   ```suggestion
       .add_argument("shape", "Tensor", "Target shape.")
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to