This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 899556d2da [Relax][Op][PyTorch] Supported Median operator (#18626)
899556d2da is described below

commit 899556d2da3f0bc191ec01cfb696f90f69f01b66
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Fri Jan 2 22:12:48 2026 +0700

    [Relax][Op][PyTorch] Supported Median operator (#18626)
    
    ## Summary:
    - Support the median operator: add `relax.median` and map it in the
      exported_program_translator
    - Input: tensor, axis, keepdim
    - Output: (values, indices) when a single axis is given; only the values when axis is None
    ## Expected:
    ### 1. Axis = None, KeepDim = False
    ```
    class MedianWithoutDim(nn.Module):
        def forward(self, x):
            return torch.median(x)
    ```
    
    ```
    class Module:
        def main(x: R.Tensor((2, 3, 4), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.median(x, axis=None, keepdims=False)
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                R.output(gv)
            return gv
    ```
    
    
    ### 2. Axis = 0, KeepDim = False
    ```
    class MedianDim(nn.Module):
        def forward(self, x):
            return torch.median(x, dim=0)
    ```
    ```
    class Module:
        def main(x: R.Tensor((2, 3, 4), dtype="float32")) -> R.Tuple(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="int64")):
            with R.dataflow():
                lv: R.Tuple(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="int64")) = R.median(x, axis=[0], keepdims=False)
                lv1: R.Tensor((3, 4), dtype="float32") = lv[0]
                lv2: R.Tensor((3, 4), dtype="int64") = lv[1]
                gv: R.Tuple(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="int64")) = lv1, lv2
                R.output(gv)
            return gv
    ```
    ### 3. Axis = -1, KeepDim = True
    ```
    class MedianKeepDim(nn.Module):
        def forward(self, x):
            return torch.median(x, dim=-1, keepdim=True)
    ```
    ```
    class Module:
        def main(x: R.Tensor((2, 3, 4), dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 1), dtype="float32"), R.Tensor((2, 3, 1), dtype="int64")):
            with R.dataflow():
                lv: R.Tuple(R.Tensor((2, 3, 1), dtype="float32"), R.Tensor((2, 3, 1), dtype="int64")) = R.median(x, axis=[-1], keepdims=True)
                lv1: R.Tensor((2, 3, 1), dtype="float32") = lv[0]
                lv2: R.Tensor((2, 3, 1), dtype="int64") = lv[1]
                gv: R.Tuple(R.Tensor((2, 3, 1), dtype="float32"), R.Tensor((2, 3, 1), dtype="int64")) = lv1, lv2
                R.output(gv)
            return gv
    ```
---
 .../frontend/torch/base_fx_graph_translator.py     |   7 +
 .../frontend/torch/exported_program_translator.py  |   2 +
 python/tvm/relax/op/__init__.py                    |   2 +-
 python/tvm/relax/op/statistical.py                 |  27 +++
 .../relax/transform/legalize_ops/statistical.py    |  47 ++++-
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 src/relax/op/tensor/statistical.cc                 |  82 ++++++++
 src/relax/op/tensor/statistical.h                  |   3 +
 .../relax/test_frontend_from_exported_program.py   |  68 +++++++
 tests/python/relax/test_op_statistical.py          | 226 +++++++++++++++++++++
 ...st_transform_legalize_ops_search_statistical.py |  78 +++++++
 .../relax/test_tvmscript_parser_op_statistical.py  |  19 ++
 12 files changed, 561 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index f7d54a6216..d04dfbb6c3 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1572,6 +1572,13 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         keepdim = args[2] if len(node.args) > 2 else 
node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
 
+    def _median(self, node: fx.Node) -> relax.Var:
+        """Translate a torch ``median`` node into ``relax.op.median``.
+
+        ``dim`` and ``keepdim`` are read positionally when present, otherwise from
+        the node kwargs (defaulting to ``None`` and ``False``).
+        """
+        call_args = self.retrieve_args(node)
+        data = call_args[0]
+        if len(node.args) > 1:
+            dim = call_args[1]
+        else:
+            dim = node.kwargs.get("dim", None)
+        if len(node.args) > 2:
+            keepdim = call_args[2]
+        else:
+            keepdim = node.kwargs.get("keepdim", False)
+        median_call = relax.op.median(data, dim, keepdims=keepdim)
+        return self.block_builder.emit(median_call)
+
     def _norm(self, node: fx.Node) -> relax.Var:
         data = self.env[node.args[0]]
         dtype = data.struct_info.dtype
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index b6b9723c13..0a97614eb5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1384,6 +1384,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "sum.dim_IntList": self._sum,
             "var.correction": self._var,
             "max.dim": self._max_dim,
+            "median.dim": self._median,
+            "median.default": self._median,
             # search
             "argmax.default": self._argmax_argmin(relax.op.argmax),
             "argmin.default": self._argmax_argmin(relax.op.argmin),
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 19096decd9..c6504d79c9 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -119,7 +119,7 @@ from .sampling import multinomial_from_uniform
 from .search import argmax, argmin, where, bucketize
 from .set import nonzero, unique
 from .sorting import argsort, sort, topk
-from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, 
variance
+from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, 
variance, median
 from .ternary import ewise_fma
 from .unary import (
     abs,
diff --git a/python/tvm/relax/op/statistical.py 
b/python/tvm/relax/op/statistical.py
index 502d058ffd..f11d31604a 100644
--- a/python/tvm/relax/op/statistical.py
+++ b/python/tvm/relax/op/statistical.py
@@ -341,3 +341,30 @@ def variance(x: Expr, axis: Optional[Union[int, 
List[int]]] = None, keepdims: bo
     if isinstance(axis, int):
         axis = [axis]
     return _ffi_api.variance(x, axis, keepdims)  # type: ignore
+
+
+def median(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
+    """Computes the median of tensor elements over given axes.
+
+    The default legalization picks the element at sorted position ``(n - 1) // 2``,
+    so for an even number of reduced elements the lower of the two middle values is
+    returned (the same convention as ``torch.median``).
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data tensor
+
+    axis : Optional[Union[int, List[int]]]
+        Axis along which the median is computed. The default (None) is to compute
+        the median of the entire flattened tensor. A plain int is treated as a
+        one-element list.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input tensor.
+
+    Returns
+    -------
+    result : relax.Expr
+        For an input with a known shape: when ``axis`` is None (or lists more than one
+        axis), a single tensor holding the median values; when exactly one axis is
+        given, a tuple ``(values, indices)`` where ``indices`` (int64) are the
+        positions of the medians along that axis in the input.
+    """
+    if isinstance(axis, int):
+        axis = [axis]
+    return _ffi_api.median(x, axis, keepdims)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py 
b/python/tvm/relax/transform/legalize_ops/statistical.py
index bdb79126f0..0c140187db 100644
--- a/python/tvm/relax/transform/legalize_ops/statistical.py
+++ b/python/tvm/relax/transform/legalize_ops/statistical.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name
 """Default legalization function for statistical operators."""
-from typing import List
+from typing import List, Union, Tuple
 from tvm import topi, tir, te
 from ...block_builder import BlockBuilder
 from ...expr import Call, Expr
@@ -53,6 +53,40 @@ def _te_variance(x: te.Tensor, axis: List[tir.IntImm], 
keepdims: bool) -> te.Ten
     # return _te_mean(x * x, axis, keepdims) - mean * mean
 
 
+def _te_median(
+    x: te.Tensor, axis: List[tir.IntImm], keepdims: bool
+) -> Union[te.Tensor, Tuple[te.Tensor, te.Tensor]]:
+    """TE implementation of ``relax.median``.
+
+    Sorts along the reduction axis and gathers the element at sorted position
+    ``(n - 1) // 2``, i.e. the lower median for an even count (as ``torch.median``).
+
+    Parameters
+    ----------
+    x : te.Tensor
+        The input tensor.
+    axis : List[tir.IntImm]
+        ``None`` (or empty) to take the median of all elements, or exactly one axis.
+    keepdims : bool
+        Keep the reduced dimension(s) with size one.
+
+    Returns
+    -------
+    result : Union[te.Tensor, Tuple[te.Tensor, te.Tensor]]
+        Only the median values when reducing over the whole tensor; otherwise a
+        ``(values, indices)`` pair with int64 indices into the original axis.
+
+    Raises
+    ------
+    NotImplementedError
+        If more than one axis is requested.
+    """
+    reduce_all = axis is None or len(axis) == 0
+    if not reduce_all and len(axis) > 1:
+        # Only a single reduction axis is implemented (numpy-style multi-axis median
+        # is a TODO). Reducing just axis[0] would gather out of range and disagree
+        # with the inferred output struct info, so fail loudly instead.
+        raise NotImplementedError(
+            "relax.median legalization supports at most one reduction axis, "
+            f"but got {len(axis)} axes"
+        )
+
+    orig_ndim = len(x.shape)
+    # An empty axis list means "all elements", so count every element in that case.
+    shape_prod = _compute_shape_prod(x, None if reduce_all else axis)
+    mid_index = (shape_prod - 1) // 2
+
+    if reduce_all:
+        # Flatten so one sort covers every element. Symbolic shapes have no
+        # ``.value``; pass the PrimExpr through in that case.
+        flat_len = shape_prod.value if isinstance(shape_prod, tir.IntImm) else shape_prod
+        x = topi.reshape(x, [flat_len])
+        ax = -1
+    else:
+        ax = axis[0].value
+    index_sorted = topi.argsort(x, axis=ax, is_ascend=True, dtype="int64")
+    x_sorted = topi.gather(x, axis=ax, indices=index_sorted)
+
+    # Gather the middle element of every sorted slice along ``ax``.
+    new_shape = list(x.shape)
+    new_shape[ax] = 1
+    indices = topi.full(new_shape, fill_value=mid_index, dtype="int64")
+
+    median_val = topi.gather(x_sorted, axis=ax, indices=indices)
+    median_idx = topi.gather(index_sorted, axis=ax, indices=indices)
+
+    if reduce_all:
+        if keepdims:
+            # Restore the input rank with all-ones dims, matching the inferred struct
+            # info (e.g. (1, 1, 1, 1) for a 4-D input) instead of the flattened (1,).
+            return topi.reshape(median_val, [1] * orig_ndim)
+        return topi.squeeze(median_val, axis=None)
+
+    if keepdims:
+        return median_val, median_idx
+    return topi.squeeze(median_val, axis=axis), topi.squeeze(median_idx, axis=axis)
+
+
 @register_legalize("relax.mean")
 def _mean(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(
@@ -81,6 +115,17 @@ def _variance(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.median")
+def _median(bb: BlockBuilder, call: Call) -> Expr:
+    """Lower ``relax.median`` to a TE primfunc (hinted name "median") built by _te_median."""
+    data, attrs = call.args[0], call.attrs
+    return bb.call_te(
+        _te_median, data, attrs.axis, attrs.keepdims, primfunc_name_hint="median"
+    )
+
+
 register_legalize("relax.max", _statistical(topi.max))
 register_legalize("relax.min", _statistical(topi.min))
 register_legalize("relax.prod", _statistical(topi.prod))
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index 141361a729..354a4d77ba 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -128,6 +128,7 @@ from tvm.relax.op import (
     max,
     maximum,
     mean,
+    median,
     memory,
     meshgrid,
     min,
@@ -874,6 +875,7 @@ __all__ = [
     "max",
     "maximum",
     "mean",
+    "median",
     "memory",
     "meshgrid",
     "metal",
diff --git a/src/relax/op/tensor/statistical.cc 
b/src/relax/op/tensor/statistical.cc
index 621c23d363..771f6ffb13 100644
--- a/src/relax/op/tensor/statistical.cc
+++ b/src/relax/op/tensor/statistical.cc
@@ -180,6 +180,68 @@ StructInfo InferStructInfoScan(const Call& call, const 
BlockBuilder& ctx) {
   }
 }
 
+/*!
+ * \brief Infer the output struct info of relax.median.
+ *
+ * A median over exactly one axis yields a (values, int64 indices) tuple. A median over
+ * all elements (axis undefined) or over several axes yields only the values tensor.
+ * The value dtype and vdevice are taken from the input.
+ */
+StructInfo InferStructInfoStatisticalExtension(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<StatisticalAttrs>();
+
+  // Normalized (non-negative) reduction axes; stays empty when ndim or axis is unknown.
+  std::vector<int> axes;
+  if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) {
+    axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value());
+  }
+
+  int out_ndim;
+  if (attrs->keepdims) {
+    out_ndim = data_sinfo->ndim;
+  } else if (!attrs->axis.defined()) {
+    out_ndim = 0;
+  } else if (data_sinfo->IsUnknownNdim()) {
+    out_ndim = kUnknownNDim;
+  } else {
+    out_ndim = data_sinfo->ndim - axes.size();
+    ICHECK_GE(out_ndim, 0);
+  }
+
+  // Output rules when the input shape is known:
+  // - axis undefined, keepdims is false -> a single rank-0 tensor;
+  // - axis undefined, keepdims is true -> a single tensor with the input's ndim and every
+  //   dimension equal to 1;
+  // - more than one axis -> a single tensor of median values whose shape drops the listed
+  //   axes (or sets them to 1 when keepdims is true);
+  // - exactly one axis -> a (values, indices) tuple; indices are int64 and both tensors have
+  //   the input shape with that axis removed (or set to 1 when keepdims is true).
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape == nullptr) {
+    // Shape values are not available, so only dtype/ndim can be inferred on this path.
+    // NOTE(review): this path disagrees with the known-shape path below in some cases:
+    //   * more than one axis: returns a (values, indices) tuple here, a single tensor below;
+    //   * axis undefined + keepdims with unknown ndim: returns a tuple here;
+    //   * one axis that reduces the input to rank 0 (keepdims false): returns a single
+    //     tensor here, a tuple below.
+    //   Confirm which behavior is intended.
+    if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) {
+      return TensorStructInfo(
+          ShapeExpr(ffi::Array<PrimExpr>(out_ndim, IntImm(DataType::Int(64), /*value=*/1))),
+          data_sinfo->dtype, data_sinfo->vdevice);
+    }
+    if (out_ndim == 0) {
+      return TensorStructInfo(ShapeExpr(ffi::Array<PrimExpr>()), data_sinfo->dtype,
+                              data_sinfo->vdevice);
+    }
+    return TupleStructInfo({TensorStructInfo(data_sinfo->dtype, out_ndim, data_sinfo->vdevice),
+                            TensorStructInfo(DataType::Int(64), out_ndim, data_sinfo->vdevice)});
+  }
+
+  // Keep non-reduced dimensions as-is; reduced ones (every dimension when axis is
+  // undefined) become 1 with keepdims or are dropped otherwise.
+  ffi::Array<PrimExpr> out_shape;
+  out_shape.reserve(out_ndim);
+  for (int i = 0; i < data_sinfo->ndim; ++i) {
+    if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) {
+      out_shape.push_back(data_shape->values[i]);
+    } else if (attrs->keepdims) {
+      out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1));
+    }
+  }
+  ICHECK_EQ(static_cast<int>(out_shape.size()), out_ndim);
+
+  // Only a single-axis median reports indices alongside the values.
+  if (!attrs->axis.defined() || axes.size() > 1)
+    return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
+  else
+    return TupleStructInfo(
+        {TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice),
+         TensorStructInfo(ShapeExpr(out_shape), DataType::Int(64), data_sinfo->vdevice)});
+}
+
 /* relax.cumprod */
 Expr cumprod(Expr data, ffi::Optional<int64_t> axis, ffi::Optional<DataType> 
dtype,
              Bool exclusive) {
@@ -227,6 +289,26 @@ TVM_REGISTER_OP("relax.cumsum")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScan)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.median */
+Expr median(Expr data, ffi::Optional<ffi::Array<Integer>> axis, bool keepdims) {
+  // relax.median reuses StatisticalAttrs (axis + keepdims).
+  ObjectPtr<StatisticalAttrs> median_attrs = ffi::make_object<StatisticalAttrs>();
+  median_attrs->keepdims = keepdims;
+  median_attrs->axis = std::move(axis);
+  static const Op& median_op = Op::Get("relax.median");
+  return Call(median_op, {std::move(data)}, Attrs{median_attrs}, {});
+}
+
+// Register the FFI global "relax.op.median".
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.median", median);
+}
+
+// Single-input, pure op; output struct info comes from InferStructInfoStatisticalExtension.
+TVM_REGISTER_OP("relax.median")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoStatisticalExtension)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 RELAX_REGISTER_STATISTICAL_OP_INTERFACE(max);
 RELAX_REGISTER_STATISTICAL_OP_INTERFACE(mean);
 RELAX_REGISTER_STATISTICAL_OP_INTERFACE(min);
diff --git a/src/relax/op/tensor/statistical.h 
b/src/relax/op/tensor/statistical.h
index a80ef72868..0a4f83687d 100644
--- a/src/relax/op/tensor/statistical.h
+++ b/src/relax/op/tensor/statistical.h
@@ -119,6 +119,9 @@ Expr cumsum(Expr data, ffi::Optional<int64_t> axis = 
std::nullopt,
 /*! \brief Computes the variance of tensor elements over given axes. */
 Expr variance(Expr x, ffi::Optional<ffi::Array<Integer>> axis, bool keepdims);
 
+/*!
+ * \brief Computes the median of tensor elements over given axes.
+ * \param x The input tensor.
+ * \param axis The axes to reduce; when not provided, the median of all elements is taken.
+ * \param keepdims Whether to keep reduced axes as size-one dimensions.
+ * \return For a known input shape: a (values, int64 indices) tuple when exactly one axis is
+ *         given, otherwise the median values tensor.
+ */
+Expr median(Expr x, ffi::Optional<ffi::Array<Integer>> axis, bool keepdims);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 9f8842ddcb..01a24ada1f 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4957,6 +4957,74 @@ def test_mean():
     verify_model(MeanWithoutDim(), example_args, {}, Expected3)
 
 
+def test_median():
+    """Import torch ``median`` with a dim, with ``keepdim=True``, and with no dim."""
+    class Median(Module):
+        def forward(self, input):
+            return input.median(-1)
+
+    class MedianKeepDim(Module):
+        def forward(self, input):
+            return input.median(-1, keepdim=True)
+
+    class MedianWithoutDim(Module):
+        def forward(self, input):
+            return input.median()
+
+    # With a dim, the op yields a (values, indices) tuple that is unpacked into the output.
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="int64")
+                ) = R.median(inp_0, axis=[-1], keepdims=False)
+                lv1: R.Tensor((256,), dtype="float32") = lv[0]
+                lv2: R.Tensor((256,), dtype="int64") = lv[1]
+                gv: R.Tuple(R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="int64")) = (
+                    lv1,
+                    lv2,
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1), dtype="int64")
+                ) = R.median(inp_0, axis=[-1], keepdims=True)
+                lv1: R.Tensor((256, 1), dtype="float32") = lv[0]
+                lv2: R.Tensor((256, 1), dtype="int64") = lv[1]
+                gv: R.Tuple(
+                    R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1), dtype="int64")
+                ) = (lv1, lv2)
+                R.output(gv)
+            return gv
+
+    # Without a dim, only the rank-0 median value is produced.
+    @I.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.median(inp_0, axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(256, 256, dtype=torch.float32),)
+    verify_model(Median(), example_args, {}, Expected1)
+    verify_model(MedianKeepDim(), example_args, {}, Expected2)
+    verify_model(MedianWithoutDim(), example_args, {}, Expected3)
+
+
 def test_sum():
     class Sum(Module):
         def forward(self, x):
diff --git a/tests/python/relax/test_op_statistical.py 
b/tests/python/relax/test_op_statistical.py
index a0cfc81e55..5dccbb33cc 100644
--- a/tests/python/relax/test_op_statistical.py
+++ b/tests/python/relax/test_op_statistical.py
@@ -33,6 +33,7 @@ def test_op_correctness():
     assert relax.op.std(x).op == Op.get("relax.std")
     assert relax.op.sum(x).op == Op.get("relax.sum")
     assert relax.op.variance(x).op == Op.get("relax.variance")
+    assert relax.op.median(x).op == Op.get("relax.median")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
@@ -275,5 +276,230 @@ def 
test_scan_opinfer_struct_info_wrong_input_type(scan_op: Callable):
         bb.normalize(scan_op(x1, axis=1))
 
 
+def test_statistical_ext_infer_struct_info():
+    """Struct-info inference for relax.median across known/unknown shapes, dtype and vdevice."""
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+    x4 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
+
+    _check_inference(
+        bb,
+        relax.op.median(x0, axis=[1]),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 4, 5), "float32"),
+                relax.TensorStructInfo((2, 4, 5), "int64"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x0, axis=[1], keepdims=True),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 1, 4, 5), "float32"),
+                relax.TensorStructInfo((2, 1, 4, 5), "int64"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x1, axis=[1]),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=3),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x1, axis=[1], keepdims=True),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=4),
+                relax.TensorStructInfo(dtype="int64", ndim=4),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x1, axis=None, keepdims=True),
+        relax.TensorStructInfo((1, 1, 1, 1), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x2, axis=[1]),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32"),
+                relax.TensorStructInfo(dtype="int64"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x2, axis=[1], keepdims=True),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32"),
+                relax.TensorStructInfo(dtype="int64"),
+            ]
+        ),
+    )
+    _check_inference(bb, relax.op.median(x2, axis=None), relax.TensorStructInfo((), "float32"))
+    # x3 has no dtype: the values tensor keeps the empty dtype, indices remain int64.
+    _check_inference(
+        bb,
+        relax.op.median(x3, axis=[1], keepdims=True),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 1, 4, 5), dtype=""),
+                relax.TensorStructInfo((2, 1, 4, 5), dtype="int64"),
+            ]
+        ),
+    )
+    _check_inference(bb, relax.op.median(x3, axis=None), relax.TensorStructInfo((), dtype=""))
+    _check_inference(
+        bb,
+        relax.op.median(x3, axis=None, keepdims=True),
+        relax.TensorStructInfo((1, 1, 1, 1), dtype=""),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x4, axis=[1]),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 4, 5), "float32", vdev0),
+                relax.TensorStructInfo((2, 4, 5), "int64", vdev0),
+            ]
+        ),
+    )
+
+
+def test_statistical_ext_infer_struct_info_shape_symbolic():
+    """relax.median struct-info inference keeps symbolic dims of non-reduced axes."""
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    c = tir.Var("c", "int64")
+    d = tir.Var("d", "int64")
+    x = relax.Var("x", R.Tensor((a, b, c, d), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.median(x, axis=[1]),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((a, c, d), "float32"),
+                relax.TensorStructInfo((a, c, d), "int64"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x, axis=[1], keepdims=True),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((a, 1, c, d), "float32"),
+                relax.TensorStructInfo((a, 1, c, d), "int64"),
+            ]
+        ),
+    )
+    _check_inference(bb, relax.op.median(x, axis=None), relax.TensorStructInfo((), "float32"))
+    _check_inference(
+        bb,
+        relax.op.median(x, axis=None, keepdims=True),
+        relax.TensorStructInfo((1, 1, 1, 1), "float32"),
+    )
+
+
+def test_statistical_ext_infer_struct_info_shape_var():
+    """relax.median struct-info inference when the input shape is a shape variable."""
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s1 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+    _check_inference(bb, relax.op.median(x0), relax.TensorStructInfo((), dtype="float32"))
+    _check_inference(
+        bb,
+        relax.op.median(x0, keepdims=True),
+        relax.TensorStructInfo((1, 1, 1, 1), dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x0, axis=[2]),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="int64", ndim=3),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x0, axis=[2], keepdims=True),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=4),
+                relax.TensorStructInfo(dtype="int64", ndim=4),
+            ]
+        ),
+    )
+    _check_inference(bb, relax.op.median(x1), relax.TensorStructInfo((), dtype="float32"))
+    # NOTE(review): axis=None with keepdims on an unknown-rank input infers a Tuple here,
+    # while the same call on a known-rank input (x0 above) infers a single tensor.
+    # Confirm this is intended.
+    _check_inference(
+        bb,
+        relax.op.median(x1, keepdims=True),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32"),
+                relax.TensorStructInfo(dtype="int64"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x1, axis=[2]),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32"),
+                relax.TensorStructInfo(dtype="int64"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.median(x1, axis=[2], keepdims=True),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32"),
+                relax.TensorStructInfo(dtype="int64"),
+            ]
+        ),
+    )
+
+
+def test_statistical_ext_infer_struct_info_more_input_dtype():
+    """The whole-tensor median keeps the input dtype (float16, int8)."""
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8"))
+
+    _check_inference(bb, relax.op.median(x0), relax.TensorStructInfo((), "float16"))
+    _check_inference(bb, relax.op.median(x1), relax.TensorStructInfo((), "int8"))
+
+
+def test_statistical_ext_infer_struct_info_wrong_input_type():
+    """relax.median rejects non-tensor inputs (shape and function struct info)."""
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.median(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.median(x1))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git 
a/tests/python/relax/test_transform_legalize_ops_search_statistical.py 
b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 7edfff3dfc..b28451da1b 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -684,6 +684,84 @@ def test_mean_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_median():
+    """Legalize a single-axis median: argsort, gather the middle sorted index
+    (T_full == 0 for the size-2 axis 0), then squeeze values and indices."""
+    # fmt: off
+    @tvm.script.ir_module
+    class Median:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tuple(R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")):
+            gv: R.Tuple(R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")) = R.median(x, axis=[0], keepdims=False)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tuple(R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")):
+            gv = R.call_tir(Expected.median, (x,), out_sinfo=[R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")])
+            return gv
+
+        @T.prim_func(private=True)
+        def median(var_x: T.handle, T_squeeze: T.Buffer((T.int64(3), T.int64(4), T.int64(5)), "float32"), T_squeeze_1: T.Buffer((T.int64(3), T.int64(4), T.int64(5)), "int64")):
+            T.func_attr({"tir.noalias": True})
+            data_buf = T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), align=8)
+            # with T.block("root"):
+            T_full = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(5)), "int64")
+            out_buf = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "int64", align=8)
+            T_gather = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+            T_gather_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(5)))
+            T_gather_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(5)), "int64")
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_full"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads()
+                    T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_full[v_ax0, v_ax1, v_ax2, v_ax3] = 0
+            with T.block("argsort_cpu"):
+                T.reads()
+                T.writes()
+                T.call_packed("tvm.contrib.sort.argsort", T.tvm_stack_make_array(data_buf.data,
+                                                                                 T.tvm_stack_make_shape(T.int64(2), T.int64(3), T.int64(4), T.int64(5)),
+                                                                                 0, 4, T.float32(0.0), T.int64(0)),
+                                                          T.tvm_stack_make_array(out_buf.data,
+                                                                                 T.tvm_stack_make_shape(T.int64(2), T.int64(3), T.int64(4), T.int64(5)),
+                                                                                 0, 4, T.int64(0), T.int64(0)),
+                                                          0, T.bool(True))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_gather"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(data_buf[out_buf[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3], out_buf[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_gather[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_gather[v_ax0, v_ax1, v_ax2, v_ax3] = data_buf[out_buf[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_gather_1"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_gather[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3], T_full[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_gather_1[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_gather_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_gather[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_gather_1[T.int64(0), v_ax0, v_ax1, v_ax2])
+                    T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
+                    T_squeeze[v_ax0, v_ax1, v_ax2] = T_gather_1[T.int64(0), v_ax0, v_ax1, v_ax2]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_gather_2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(out_buf[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3], T_full[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_gather_2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_gather_2[v_ax0, v_ax1, v_ax2, v_ax3] = out_buf[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_squeeze_1"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_gather_2[T.int64(0), v_ax0, v_ax1, v_ax2])
+                    T.writes(T_squeeze_1[v_ax0, v_ax1, v_ax2])
+                    T_squeeze_1[v_ax0, v_ax1, v_ax2] = T_gather_2[T.int64(0), v_ax0, v_ax1, v_ax2]
+    # fmt: on
+
+    mod = LegalizeOps()(Median)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_std():
     # fmt: off
     @tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser_op_statistical.py 
b/tests/python/relax/test_tvmscript_parser_op_statistical.py
index 910c08bf1e..6ba90c5651 100644
--- a/tests/python/relax/test_tvmscript_parser_op_statistical.py
+++ b/tests/python/relax/test_tvmscript_parser_op_statistical.py
@@ -95,6 +95,25 @@ def test_mean():
     _check(foo, bb.get()["foo"])
 
 
+def test_median():
+    """TVMScript ``R.median`` parses to the same function as the BlockBuilder-built one."""
+    @R.function
+    def foo(
+        x: R.Tensor((1, 2, 3, 4), "float32")
+    ) -> R.Tuple(R.Tensor((1, 3, 4), "float32"), R.Tensor((1, 3, 4), "int64")):
+        gv: R.Tuple(R.Tensor((1, 3, 4), "float32"), R.Tensor((1, 3, 4), "int64")) = R.median(
+            x, axis=[1]
+        )
+        return gv
+
+    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.median(x, axis=[1]))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_variance():
     @R.function
     def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1,), "float32"):

Reply via email to