mbaret commented on a change in pull request #6303: URL: https://github.com/apache/incubator-tvm/pull/6303#discussion_r475439237

##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -2652,6 +2652,77 @@ def test_forward_reverse_v2():
         _test_reverse_v2((5, 6, 4, 2), np.array([2], dtype='int32'), dtype)
 
 
+#######################################################################
+# MATRIX_SET_DIAG
+# ---------------
+
+def _test_matrix_set_diag(input_shape, input_type, quantized=False):
+    """ One iteration of MATRIX_SET_DIAG """
+    with tf.Graph().as_default():
+        diagonal_shape = list(input_shape[:-2])
+        diagonal_shape.append(min(input_shape[-2], input_shape[-1]))

Review comment: Should the broadcasting case be tested here?

##########
File path: src/relay/op/tensor/transform.cc
##########
@@ -3093,5 +3093,55 @@ RELAY_REGISTER_OP("sparse_to_dense")
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
     .set_attr<FTVMCompute>("FTVMCompute", SparseToDenseCompute);
 
+// relay.matrix_set_diag
+bool MatrixSetDiagRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                      const TypeReporter& reporter) {
+  // `types` contains: [input, diagonal, result]
+  CHECK_EQ(types.size(), 3);
+
+  const auto* input = types[0].as<TensorTypeNode>();
+  CHECK(input);
+
+  const auto* diagonal = types[1].as<TensorTypeNode>();
+  CHECK(diagonal);
+
+  int d_ndims = diagonal->shape.size();
+  for (int i = 0; i < d_ndims - 1; i++) {
+    reporter->AssertEQ(input->shape[i], diagonal->shape[i]);
+  }
+  auto min_dim = if_then_else(input->shape[d_ndims - 1] >= input->shape[d_ndims],
+                              input->shape[d_ndims], input->shape[d_ndims - 1]);

Review comment: Is `if_then_else` appropriate here? Could `a ? x : y` not be used?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
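For readers checking the new test, the operation it exercises can be sketched in plain Python. This is a hedged sketch under stated assumptions, not the TFLite, TensorFlow, or TVM implementation. It covers only the non-broadcasting case that `_test_matrix_set_diag` builds, where `diagonal` has shape `input_shape[:-2] + [min(M, N)]` and the result keeps the shape of `input`. The helpers `expected_diagonal_shape`, `shape_of`, and `matrix_set_diag_ref` are hypothetical names introduced here. The `min_dim` line mirrors the selection that `MatrixSetDiagRel` writes with `if_then_else`, but only for static Python integers; in the C++ type relation the dimensions are expressions that may be symbolic, which this sketch does not model.

```python
# Hypothetical reference sketch (not TVM/TFLite code): non-broadcasting
# matrix_set_diag on rectangular nested Python lists.


def expected_diagonal_shape(input_shape):
    """Diagonal shape as built in _test_matrix_set_diag: batch dims + [min(M, N)]."""
    if len(input_shape) < 2:
        raise ValueError("input must have at least two dimensions")
    m, n = input_shape[-2], input_shape[-1]
    # Same choice as the if_then_else in MatrixSetDiagRel, for static ints only.
    min_dim = n if m >= n else m
    return list(input_shape[:-2]) + [min_dim]


def shape_of(x):
    """Shape of a rectangular, non-empty nested list."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return shape


def _set_diag(x, diag, ndim):
    """Recurse over batch dims; on the last two axes, overwrite the main diagonal."""
    if ndim == 2:
        out = [list(row) for row in x]  # copy rows so the caller's input is untouched
        for i, value in enumerate(diag):
            out[i][i] = value
        return out
    return [_set_diag(xb, db, ndim - 1) for xb, db in zip(x, diag)]


def matrix_set_diag_ref(x, diag):
    """Return a copy of x whose main diagonal (last two axes) is replaced by diag.

    Assumption: no broadcasting, i.e. diag's shape must equal
    expected_diagonal_shape(shape_of(x)) exactly.
    """
    x_shape = shape_of(x)
    if shape_of(diag) != expected_diagonal_shape(x_shape):
        raise ValueError("diagonal shape must be input batch dims + [min(M, N)]")
    return _set_diag(x, diag, len(x_shape))


if __name__ == "__main__":
    x = [[[1, 2, 3], [4, 5, 6]],
         [[7, 8, 9], [10, 11, 12]]]  # shape [2, 2, 3]
    d = [[0, 0], [-1, -1]]           # shape [2, 2] == [2] + [min(2, 3)]
    print(matrix_set_diag_ref(x, d))
    # [[[0, 2, 3], [4, 0, 6]], [[-1, 8, 9], [10, -1, 12]]]
```

A broadcasting variant, as raised in the first review comment, would need its own shape rule and is deliberately left out of this sketch.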