This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new b8893b557a  [Relay][Op] Trilu operator implementation (#12124)
b8893b557a is described below

commit b8893b557a6c213dfe06f4069fad3cf5ad70051e
Author: Josh Fromm <jwfr...@octoml.ai>
AuthorDate: Tue Aug 2 12:48:59 2022 -0700

    [Relay][Op] Trilu operator implementation (#12124)

    * Added topi trilu implementation
    * Implemented and tested full Trilu op.
    * Fix test type.
    * Add tril zero tests.
    * Add pytorch trilu integration.
    * Clean up torch integration.
    * Readded skip for zero tests.
---
 include/tvm/relay/attrs/transform.h             |  9 ++++
 python/tvm/relay/frontend/onnx.py               | 15 +++++++
 python/tvm/relay/frontend/pytorch.py            | 35 ++++-----------
 python/tvm/relay/op/_transform.py               |  4 ++
 python/tvm/relay/op/op_attrs.py                 |  5 +++
 python/tvm/relay/op/strategy/generic.py         | 28 ++++++++++++
 python/tvm/relay/op/transform.py                | 43 ++++++++++++++++++
 python/tvm/topi/transform.py                    | 58 +++++++++++++++++++++++++
 src/relay/op/tensor/transform.cc                | 50 +++++++++++++++++++++
 tests/python/frontend/onnx/test_forward.py      | 16 -------
 tests/python/frontend/pytorch/test_forward.py   | 10 +++++
 tests/python/relay/test_op_level3.py            | 29 +++++++++++++
 tests/python/topi/python/test_topi_transform.py | 39 +++++++++++++++++
 13 files changed, 298 insertions(+), 43 deletions(-)

diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index b9f8c6e1e8..2741d68eec 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -575,6 +575,15 @@ struct StftAttrs : public tvm::AttrsNode<StftAttrs> {
   }
 };  // struct StftAttrs
 
+struct TriluAttrs : public tvm::AttrsNode<TriluAttrs> {
+  bool upper;
+
+  TVM_DECLARE_ATTRS(TriluAttrs, "relay.attrs.TriluAttrs") {
+    TVM_ATTR_FIELD(upper).set_default(true).describe(
+        "Whether to keep the upper or lower half of the diagonal.");
+  }
+};  // struct TriluAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 3b5bf9acfa..e78e65dc4e 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -4685,6 +4685,20 @@ class Einsum(OnnxOpConverter):
         return _op.einsum(inputs, equation)
 
 
+class Trilu(OnnxOpConverter):
+    """Operator converter for Trilu"""
+
+    @classmethod
+    def _impl_v14(cls, inputs, attr, params):
+        upper = attr.get("upper", True)
+        if len(inputs) == 2:
+            data, k = inputs
+        else:
+            data = inputs[0]
+            k = 0
+        return _op.trilu(data, k, upper)
+
+
 class RandomNormal(OnnxOpConverter):
     """Operator converter for random_normal"""
 
@@ -5345,6 +5359,7 @@ def _get_convert_map(opset):
         "CumSum": CumSum.get_converter(opset),
         "Unique": Unique.get_converter(opset),
         "Einsum": Einsum.get_converter(opset),
+        "Trilu": Trilu.get_converter(opset),
         # defs/control_flow
         "Loop": Loop.get_converter(opset),
         "If": If.get_converter(opset),
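For reference, a minimal sketch of exercising the new converter end to end (illustrative, not part of this commit; it assumes the onnx package is installed and uses opset 14, the first opset that defines Trilu):

    import onnx
    from onnx import TensorProto, helper
    import tvm.relay as relay

    # A one-node ONNX graph containing Trilu; upper=1 keeps the upper triangle.
    node = helper.make_node("Trilu", inputs=["x"], outputs=["y"], upper=1)
    graph = helper.make_graph(
        [node],
        "trilu_example",
        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 3])],
        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 3])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)])

    # The importer dispatches to Trilu._impl_v14 above and emits relay.trilu.
    mod, params = relay.frontend.from_onnx(model, shape={"x": (3, 3)})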
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1bd3232871..74ea249a47 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -318,31 +318,6 @@ class PyTorchOpConverter:
         (dtype,) = input_types
         return _op.power(inputs[0], _expr.const(2, dtype))
 
-    def tril(self, inputs, input_types):
-        data = inputs[0]
-        if len(inputs) == 2:
-            k_value = inputs[1]
-        else:
-            k_value = 0
-        input_shape = self.infer_shape(data)
-        k1, k2 = input_shape[-2:]
-        k1 = k_value + 1
-        diag_input = _op.zeros(input_shape, dtype=input_types[0])
-        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
-
-    def triu(self, inputs, input_types):
-        data = inputs[0]
-        if len(inputs) == 2:
-            k_value = inputs[1]
-        else:
-            k_value = 0
-        input_shape = self.infer_shape(data)
-        k1, k2 = input_shape[-2:]
-        k1 = (k1 * -1) - 1
-        k2 = k_value - 1
-        diag_input = _op.zeros(input_shape, dtype=input_types[0])
-        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
-
     def lerp(self, inputs, input_types):
         if len(inputs) != 3:
             msg = "Wrong number of arguments (%d) to parse." % (len(inputs))
@@ -3405,6 +3380,12 @@ class PyTorchOpConverter:
             inputs[0], grid, interpolate_str, layout, padding_mode_str, align_corners
         )
 
+    def trilu(self, inputs, input_types, mode):
+        data = inputs[0]
+        k = inputs[1] if inputs[1] else 0
+        upper = True if mode == "triu" else False
+        return _op.trilu(data, k, upper)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -3567,8 +3548,8 @@ class PyTorchOpConverter:
             "aten::sqrt": self.make_unary("sqrt"),
             "aten::rsqrt": self.make_unary("rsqrt"),
             "aten::square": self.square,
-            "aten::tril": self.tril,
-            "aten::triu": self.triu,
+            "aten::tril": functools.partial(self.trilu, mode="tril"),
+            "aten::triu": functools.partial(self.trilu, mode="triu"),
             "aten::ceil": self.make_unary("ceil"),
             "aten::floor": self.make_unary("floor"),
             "aten::round": self.make_unary("round"),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index baf616a946..951de06967 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -191,6 +191,10 @@ def stft_shape_func(attrs, inputs, _):
     ]
 
 
+# trilu
+_reg.register_strategy("trilu", strategy.trilu_strategy)
+
+
 # scatter_add
 @_reg.register_compute("scatter_add")
 def compute_scatter_add(attrs, inputs, output_type):
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 8b92fdf267..7e8367abbb 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -617,3 +617,8 @@ class NLLLossAttrs(Attrs):
 @tvm._ffi.register_object("relay.attrs.FixedPointMultiplyAttrs")
 class FixedPointMultiplyAttrs(Attrs):
     """Attributes used in fixed_point_multiply operators"""
+
+
+@tvm._ffi.register_object("relay.attrs.TriluAttrs")
+class TriluAttrs(Attrs):
+    """Attributes used in trilu operators"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 6074b0a69c..95558b5f3d 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1460,6 +1460,34 @@ def wrap_compute_stft(topi_compute):
     return _compute_stft
 
 
+# trilu
+@override_native_generic_func("trilu_strategy")
+def trilu_strategy(attrs, outs, out_type, target):
+    """trilu generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_trilu(topi.trilu),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="trilu.generic",
+    )
+    return strategy
+
+
+def wrap_compute_trilu(topi_compute):
+    """Wrap trilu compute"""
+
+    def _compute_trilu(attrs, inputs, output_type):
+        return [
+            topi_compute(
+                inputs[0],
+                inputs[1],
+                attrs.upper,
+            )
+        ]
+
+    return _compute_trilu
+
+
 # roi_pool
 @generic_func
 def schedule_roi_pool(attrs, outs, target):
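For reference, a minimal sketch of how torch.triu/torch.tril now reach the shared converter above (illustrative; the module and input name are assumptions, not part of this commit):

    import torch
    import tvm.relay as relay

    class Triu(torch.nn.Module):
        def forward(self, x):
            # Traced as aten::triu, which now maps to trilu with mode="triu".
            return torch.triu(x, diagonal=1)

    inp = torch.rand(4, 4)
    traced = torch.jit.trace(Triu().eval(), inp)
    mod, params = relay.frontend.from_pytorch(traced, [("x", (4, 4))])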
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index b5d44781e5..e7ae5f7d83 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1889,3 +1889,46 @@ def stft(
         window = _make.ones([n_fft], "int32")
 
     return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)
+
+
+def trilu(data, k, upper=True):
+    """
+    Given a 2-D matrix or batches of 2-D matrices, returns the
+    upper or lower triangular part of the tensor.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The tensor that trilu will be applied to. Must be either
+        a 2D matrix or a tensor of batches of 2D matrices.
+
+    k : int
+        The number of diagonals above or below the main diagonal
+        to exclude or include.
+
+    upper : bool, optional
+        If True, only upper triangular values of the input are kept;
+        if False, the lower triangular values are kept.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The new tensor with appropriate diagonals set to zero.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[0, 1, 2],
+             [3, 4, 5],
+             [6, 7, 8]]
+
+        relay.trilu(x, 0, True) =
+            [[0, 1, 2],
+             [0, 4, 5],
+             [0, 0, 8]]
+    """
+    if not isinstance(k, Expr):
+        k = const(k, dtype="int32")
+    return _make.trilu(data, k, upper)
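A quick numerical check of the new Relay op against NumPy, mirroring the unit tests added further down (a sketch; the target and device choices are illustrative):

    import numpy as np
    import tvm
    from tvm import relay

    data = relay.var("data", shape=(3, 3), dtype="float32")
    mod = tvm.ir.IRModule.from_expr(relay.trilu(data, 1, upper=False))

    x = np.arange(9, dtype="float32").reshape(3, 3)
    res = (
        relay.create_executor("graph", mod=mod, device=tvm.cpu(), target="llvm")
        .evaluate()(x)
        .numpy()
    )
    np.testing.assert_allclose(res, np.tril(x, 1))  # lower triangle, offset k=1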
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index d99d6772b0..e12f80e2ef 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -1001,3 +1001,61 @@ def sliding_window(data, axis, window_shape, strides):
         The resulting tensor.
     """
     return cpp.sliding_window(data, axis, window_shape, strides)
+
+
+def trilu(data, k, upper):
+    """
+    Given a 2-D matrix or batches of 2-D matrices, returns the
+    upper or lower triangular part of the tensor.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The tensor that trilu will be applied to. Must be either
+        a 2D matrix or a tensor of batches of 2D matrices.
+
+    k : tvm.te.Tensor
+        The number of diagonals above or below the main diagonal
+        to exclude or include.
+
+    upper : bool
+        If True, only upper triangular values of the input are kept;
+        if False, the lower triangular values are kept.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The new tensor with appropriate diagonals set to zero.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[0, 1, 2],
+             [3, 4, 5],
+             [6, 7, 8]]
+
+        topi.trilu(x, 0, True) =
+            [[0, 1, 2],
+             [0, 4, 5],
+             [0, 0, 8]]
+    """
+    # Make sure datatype is consistent.
+    if k.dtype != "int32":
+        k = tvm.tir.Cast("int32", k)
+
+    # Check either above or below diagonal depending on upper.
+    check_op = tvm.tir.GE
+    if upper:
+        check_op = tvm.tir.LE
+
+    def _apply_trilu(*indices):
+        row_index = indices[-2]
+        col_index = indices[-1]
+        other_indices = indices[:-2]
+        check_position = check_op(row_index, col_index - k)
+        value = data(*other_indices, row_index, col_index)
+        return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))
+
+    return te.compute(data.shape, _apply_trilu, name="trilu")
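The Select condition above reduces to a single index comparison per element; a NumPy rendering of the same masking rule (illustrative only, not part of the commit):

    import numpy as np

    def trilu_ref(data, k=0, upper=True):
        # Keep element (r, c) when r <= c - k (upper) or r >= c - k (lower),
        # matching check_op above; everything else becomes zero.
        rows = np.arange(data.shape[-2]).reshape(-1, 1)
        cols = np.arange(data.shape[-1])
        keep = rows <= cols - k if upper else rows >= cols - k
        return np.where(keep, data, 0)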
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 989ab2ad25..f90cd91e92 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -4230,5 +4230,55 @@ RELAY_REGISTER_OP("invert_permutation")
     .set_attr<TOpPattern>("TOpPattern", kInjective)
     .set_attr<TOpIsStateful>("TOpIsStateful", false);
 
+// Trilu
+
+TVM_REGISTER_NODE_TYPE(TriluAttrs);
+
+bool TriluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+              const TypeReporter& reporter) {
+  // types: [data, k, result]
+  ICHECK_EQ(types.size(), 3) << "Trilu: expect 3 types but " << types.size() << " provided";
+  ICHECK_EQ(num_inputs, 2) << "Trilu: expect 2 inputs but " << num_inputs << " provided";
+  auto data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    ICHECK(types[0].as<IncompleteTypeNode>())
+        << "Trilu: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+
+  auto k = types[1].as<TensorTypeNode>();
+  if (k == nullptr) {
+    ICHECK(types[1].as<IncompleteTypeNode>())
+        << "Trilu: expect k type to be TensorType but get " << types[1];
+    return false;
+  }
+
+  ICHECK(k->shape.size() == 0) << "Trilu: k must be a 0-D tensor but get " << k;
+
+  // Output shape is the same as input shape.
+  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
+  return true;
+}
+
+Expr MakeTrilu(Expr data, Expr k, bool upper) {
+  auto attrs = make_object<TriluAttrs>();
+  attrs->upper = upper;
+  static const Op& op = Op::Get("trilu");
+  return Call(op, {data, k}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.trilu").set_body_typed(MakeTrilu);
+
+RELAY_REGISTER_OP("trilu")
+    .describe(
+        R"code(Filters out the upper or lower portion of an input tensor on one side of a diagonal.
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor")
+    .add_argument("k", "Tensor", "The number of diagonals above or below the main to exclude.")
+    .add_type_rel("trilu", TriluRel)
+    .set_support_level(3)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0b2e51e544..e500f0902c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5242,23 +5242,7 @@ unsupported_onnx_tests = [
     "test_training_dropout_mask",
     "test_training_dropout_zero_ratio",
     "test_training_dropout_zero_ratio_mask",
-    "test_tril",
-    "test_tril_pos",
-    "test_tril_square",
-    "test_tril_square_neg",
-    "test_tril_neg",
-    "test_tril_one_row_neg",
-    "test_tril_out_neg",
-    "test_tril_out_pos",
     "test_tril_zero",
-    "test_triu",
-    "test_triu_one_row",
-    "test_triu_out_neg_out",
-    "test_triu_out_pos",
-    "test_triu_neg",
-    "test_triu_pos",
-    "test_triu_square",
-    "test_triu_square_neg",
     "test_triu_zero",
     "test_unique_sorted_with_axis",
     "test_unique_sorted_with_axis_3d",
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index f52c7168b3..1d07c780b7 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4616,5 +4616,15 @@ def test_lerp():
     verify_model(test_fn, [x, y, w[0]])
 
 
+def test_trilu():
+    def _test_trilu(op, diagonal):
+        return lambda inp: op(inp, diagonal)
+
+    for op in [torch.triu, torch.tril]:
+        verify_model(_test_trilu(op, 0), [torch.rand(size=[3, 3])])
+        verify_model(_test_trilu(op, 1), [torch.rand(size=[6, 6])])
+        verify_model(_test_trilu(op, -2), [torch.rand(size=[6, 6])])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index f91a027de4..b641ba1fdb 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -2207,5 +2207,34 @@ class TestSTFT:
     )
 
 
+def test_trilu(target="llvm", dev=tvm.cpu()):
+    def verify_trilu(data_shape, upper=True, k=0):
+        data = relay.var("data", relay.TensorType(data_shape, "float32"))
+        y = relay.trilu(data, k, upper)
+        mod = tvm.ir.IRModule.from_expr(y)
+
+        data_np = np.random.normal(size=data_shape).astype("float32")
+        tvm_res = (
+            relay.create_executor("graph", mod=mod, device=dev, target=target)
+            .evaluate()(data_np)
+            .numpy()
+        )
+        if upper:
+            np_res = np.triu(data_np, k)
+        else:
+            np_res = np.tril(data_np, k)
+        tvm.testing.assert_allclose(tvm_res, np_res)
+
+    # Test upper and lower triangle
+    verify_trilu((3, 3), True, 0)
+    verify_trilu((3, 3), False, 0)
+    # Test larger matrices with offset.
+    verify_trilu((6, 6), True, 1)
+    verify_trilu((6, 6), False, 2)
+    verify_trilu((6, 6), False, -2)
+    # Test batch size
+    verify_trilu((8, 6, 6), False, -2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index 180f267650..c3155c948a 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -812,6 +812,31 @@ def verify_adv_index(data_shape, index_shapes, indice_dtype="int64"):
         check_device(target, dev)
 
 
+def verify_trilu(input_shape, upper, k=0):
+    x = te.placeholder(shape=input_shape, name="x", dtype="float32")
+    k_tir = tvm.tir.const(k, dtype="int32")
+    trilu_result = topi.transform.trilu(x, k_tir, upper)
+
+    def check_device(target, dev):
+        print("Running on target: %s" % target)
+        with tvm.target.Target(target):
+            s = tvm.topi.testing.get_injective_schedule(target)(trilu_result)
+        fn = tvm.build(s, [x, trilu_result], target, name="trilu")
+        x_npy = np.random.normal(size=input_shape).astype(x.dtype)
+        if upper:
+            out_npy = np.triu(x_npy, k)
+        else:
+            out_npy = np.tril(x_npy, k)
+        x_nd = tvm.nd.array(x_npy, dev)
+        out_nd = tvm.nd.array(np.empty(x_npy.shape).astype(trilu_result.dtype), dev)
+        fn(x_nd, out_nd)
+        out_topi = out_nd.numpy()
+        tvm.testing.assert_allclose(out_topi, out_npy)
+
+    for target, dev in tvm.testing.enabled_targets():
+        check_device(target, dev)
+
+
 @tvm.testing.uses_gpu
 def test_strided_slice():
     verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
@@ -1256,6 +1281,19 @@ def test_adv_index():
         verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)], indice_dtype=indice_dtype)
 
 
+@tvm.testing.uses_gpu
+def test_trilu():
+    # Test upper and lower triangle
+    verify_trilu((3, 3), True, 0)
+    verify_trilu((3, 3), False, 0)
+    # Test larger matrices with offset.
+    verify_trilu((6, 6), True, 1)
+    verify_trilu((6, 6), False, 2)
+    verify_trilu((6, 6), False, -2)
+    # Test batch size
+    verify_trilu((8, 6, 6), False, -2)
+
+
 if __name__ == "__main__":
     test_strided_slice()
     test_concatenate()
@@ -1283,3 +1321,4 @@ if __name__ == "__main__":
     test_sparse_to_dense()
     test_matrix_set_diag()
     test_adv_index()
+    test_trilu()
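The new coverage can be run locally with the usual pytest selectors (paths relative to the TVM repository root):

    pytest tests/python/relay/test_op_level3.py -k trilu
    pytest tests/python/frontend/pytorch/test_forward.py -k trilu
    pytest tests/python/topi/python/test_topi_transform.py -k trilu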