mbrookhart commented on a change in pull request #5619: URL: https://github.com/apache/incubator-tvm/pull/5619#discussion_r437004308
########## File path: src/relay/op/tensor/transform.cc ########## @@ -781,6 +781,53 @@ non-zero)doc" TVM_ADD_FILELINE) .set_attr<TOpPattern>("TOpPattern", kOpaque) .set_support_level(10); +// Scatter +TVM_REGISTER_NODE_TYPE(ScatterAttrs); + +// Scatter +bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 3); + CHECK_EQ(types.size(), 4); + auto data = types[0].as<TensorTypeNode>(); + if (data == nullptr) { + return false; + } + auto indices = types[1].as<TensorTypeNode>(); + if (indices == nullptr) { + return false; + } + auto updates = types[2].as<TensorTypeNode>(); + if (updates == nullptr) { + return false; + } + CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer"; + const auto param = attrs.as<ScatterAttrs>(); + CHECK(param != nullptr); + reporter->Assign(types[3], TensorType(data->shape, data->dtype)); + return true; +} + +TVM_REGISTER_GLOBAL("relay.op._make.scatter") + .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) { + auto attrs = make_object<ScatterAttrs>(); + attrs->axis = std::move(axis); + static const Op& op = Op::Get("scatter"); + return Call(op, {data, indices, updates}, Attrs(attrs), {}); + }); + +RELAY_REGISTER_OP("scatter") + .describe( + R"doc(Update data at positions defined by indices with values in updates)doc" TVM_ADD_FILELINE) + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input data tensor.") + .add_argument("indicies", "Tensor", "The indicies location tensor.") + .add_argument("updates", "Tensor", "The values to update the input with.") + .add_type_rel("Scatter", ScatterRel) + .set_attr<TOpIsStateful>("TOpIsStateful", false) + .set_attr<TOpPattern>("TOpPattern", kOpaque) Review comment: Because I wasn't sure and I wanted to be conservative :) It might be Injective. 
########## File path: src/relay/op/tensor/transform.cc ########## @@ -781,6 +781,53 @@ non-zero)doc" TVM_ADD_FILELINE) .set_attr<TOpPattern>("TOpPattern", kOpaque) .set_support_level(10); +// Scatter +TVM_REGISTER_NODE_TYPE(ScatterAttrs); + +// Scatter +bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 3); + CHECK_EQ(types.size(), 4); + auto data = types[0].as<TensorTypeNode>(); + if (data == nullptr) { + return false; + } + auto indices = types[1].as<TensorTypeNode>(); + if (indices == nullptr) { + return false; + } + auto updates = types[2].as<TensorTypeNode>(); + if (updates == nullptr) { + return false; + } + CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer"; + const auto param = attrs.as<ScatterAttrs>(); + CHECK(param != nullptr); + reporter->Assign(types[3], TensorType(data->shape, data->dtype)); + return true; +} + +TVM_REGISTER_GLOBAL("relay.op._make.scatter") + .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) { + auto attrs = make_object<ScatterAttrs>(); + attrs->axis = std::move(axis); + static const Op& op = Op::Get("scatter"); + return Call(op, {data, indices, updates}, Attrs(attrs), {}); + }); + +RELAY_REGISTER_OP("scatter") + .describe( + R"doc(Update data at positions defined by indices with values in updates)doc" TVM_ADD_FILELINE) + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input data tensor.") + .add_argument("indicies", "Tensor", "The indicies location tensor.") + .add_argument("updates", "Tensor", "The values to update the input with.") + .add_type_rel("Scatter", ScatterRel) + .set_attr<TOpIsStateful>("TOpIsStateful", false) + .set_attr<TOpPattern>("TOpPattern", kOpaque) Review comment: Because I wasn't sure and I wanted to be conservative :) It might be Injective. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org