MasterJH5574 commented on code in PR #16575:
URL: https://github.com/apache/tvm/pull/16575#discussion_r1491924777


##########
python/tvm/relax/frontend/nn/op.py:
##########
@@ -1825,3 +1826,260 @@ def print_(tensor: Tensor):
    filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2]
     line_info = f"{filename}:{line_number}"
     debug_func("vm.builtin.debug_print", tensor, _line_info=line_info)
+
+
+def less(a: Tensor, b: Tensor, name: str = "less") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs < rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.less(a._expr, b._expr), name)
+
+
+def less_equal(a: Tensor, b: Tensor, name: str = "less_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs <= rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.less_equal(a._expr, b._expr), name)
+
+
+def greater(a: Tensor, b: Tensor, name: str = "greater") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs > rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.greater(a._expr, b._expr), name)
+
+
+def greater_equal(a: Tensor, b: Tensor, name: str = "greater_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs >= rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.greater_equal(a._expr, b._expr), name)
+
+
+def equal(a: Tensor, b: Tensor, name: str = "equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs == rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.equal(a._expr, b._expr), name)
+
+
+def not_equal(a: Tensor, b: Tensor, name: str = "not_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs != rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.not_equal(a._expr, b._expr), name)
+
+
+def cumsum(
+    data: Tensor,
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+    name: str = "cumsum",
+) -> Tensor:
+    """Numpy style cumsum op. Return the cumulative inclusive sum of the 
elements along
+    a given axis.
+
+    Parameters
+    ----------
+    data : Tensor
+        The input data to the operator.
+
+    axis : Optional[int]
+        Axis along which the cumulative sum is computed. The default (None) is to compute
+        the cumsum over the flattened array.
+
+    dtype : Optional[str]
+        Type of the returned array and of the accumulator in which the elements are summed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    exclusive : Optional[bool]
+        If true, returns the exclusive sum, in which the first element is not
+        included.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        a = [[1, 2, 3], [4, 5, 6]]
+
+        cumsum(a)  # if axis is not provided, cumsum is done over the flattened input.
+        -> [ 1,  3,  6, 10, 15, 21]
+
+        cumsum(a, dtype="float32")
+        -> [  1.,   3.,   6.,  10.,  15.,  21.]
+
+        cumsum(a, axis=0)  # sum over rows for each of the 3 columns
+        -> [[1, 2, 3],
+            [5, 7, 9]]
+
+        cumsum(a, axis=1)
+        -> [[ 1,  3,  6],
+            [ 4,  9, 15]]
+
+        a = [1, 0, 1, 0, 1, 1, 0]  # a is a boolean array
+        cumsum(a, dtype="int32")  # dtype should be provided to get the expected results
+        -> [1, 1, 2, 2, 3, 4, 4]
+    """
+    return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)
+
+
+def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor):
+    """Returns a tensor where each row contains the index sampled from the 
multinomial
+    probability distribution located in the corresponding row of tensor prob.
+
+    Notes
+    -----
+    For better CPU performance, use 'vm.builtin.multinomial_from_uniform'.
+
+    Parameters
+    ----------
+    prob : Tensor
+        The probability tensor.

Review Comment:
   The TIR function does not require each row of the input prob distribution to sum to 1, and it does not strictly require the samples to be in the range `[0, 1]` either. Shall we add these notes to the docstring? It might be enough to say the behavior is undefined when a prob distribution does not sum to 1 or a sample is outside `[0, 1]`.
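   For illustration, a caller could keep the inputs inside the well-defined range like this (a rough numpy sketch, not TVM code; the variable names are just for the example):

```python
import numpy as np

# Normalize each row of the distribution so it sums to 1.
prob = np.array([[2.0, 3.0, 5.0], [3.0, 4.0, 3.0]], dtype="float32")
prob = prob / prob.sum(axis=1, keepdims=True)

# Keep the uniform samples inside [0, 1].
usample = np.clip(np.random.rand(2, 1).astype("float32"), 0.0, 1.0)
```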



##########
python/tvm/relax/frontend/nn/op.py:
##########
@@ -1825,3 +1826,260 @@ def print_(tensor: Tensor):
    filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2]
     line_info = f"{filename}:{line_number}"
     debug_func("vm.builtin.debug_print", tensor, _line_info=line_info)
+
+
+def less(a: Tensor, b: Tensor, name: str = "less") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs < rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.less(a._expr, b._expr), name)
+
+
+def less_equal(a: Tensor, b: Tensor, name: str = "less_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs <= rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.less_equal(a._expr, b._expr), name)
+
+
+def greater(a: Tensor, b: Tensor, name: str = "greater") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs > rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.greater(a._expr, b._expr), name)
+
+
+def greater_equal(a: Tensor, b: Tensor, name: str = "greater_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs >= rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.greater_equal(a._expr, b._expr), name)
+
+
+def equal(a: Tensor, b: Tensor, name: str = "equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs == rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.equal(a._expr, b._expr), name)
+
+
+def not_equal(a: Tensor, b: Tensor, name: str = "not_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs != rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.not_equal(a._expr, b._expr), name)
+
+
+def cumsum(
+    data: Tensor,
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+    name: str = "cumsum",
+) -> Tensor:
+    """Numpy style cumsum op. Return the cumulative inclusive sum of the 
elements along
+    a given axis.
+
+    Parameters
+    ----------
+    data : Tensor
+        The input data to the operator.
+
+    axis : Optional[int]
+        Axis along which the cumulative sum is computed. The default (None) is to compute
+        the cumsum over the flattened array.
+
+    dtype : Optional[str]
+        Type of the returned array and of the accumulator in which the elements are summed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    exclusive : Optional[bool]
+        If true, returns the exclusive sum, in which the first element is not
+        included.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        a = [[1, 2, 3], [4, 5, 6]]
+
+        cumsum(a)  # if axis is not provided, cumsum is done over the flattened input.
+        -> [ 1,  3,  6, 10, 15, 21]
+
+        cumsum(a, dtype="float32")
+        -> [  1.,   3.,   6.,  10.,  15.,  21.]
+
+        cumsum(a, axis=0)  # sum over rows for each of the 3 columns
+        -> [[1, 2, 3],
+            [5, 7, 9]]
+
+        cumsum(a, axis=1)
+        -> [[ 1,  3,  6],
+            [ 4,  9, 15]]
+
+        a = [1, 0, 1, 0, 1, 1, 0]  # a is a boolean array
+        cumsum(a, dtype="int32")  # dtype should be provided to get the expected results
+        -> [1, 1, 2, 2, 3, 4, 4]
+    """
+    return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)
+
+
+def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor):
+    """Returns a tensor where each row contains the index sampled from the 
multinomial
+    probability distribution located in the corresponding row of tensor prob.
+
+    Notes
+    -----
+    For better CPU performance, use 'vm.builtin.multinomial_from_uniform'.
+
+    Parameters
+    ----------
+    prob : Tensor
+        The probability tensor.
+
+    uniform_sample : Tensor
+        The uniform sample.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.

Review Comment:
   Shall we note the shapes of these tensors in the docstring? For example, we can say `prob` has shape `(b, n)`, `uniform_sample` has shape `(b, 1)`, and `result` has shape `(b, 1)`.
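   As a quick reference for that contract (a plain numpy sketch of the intended shapes, not the actual kernel):

```python
import numpy as np

b, n = 2, 3
prob = np.array([[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]], dtype="float32")  # (b, n)
usample = np.array([[0.4], [0.9]], dtype="float32")                   # (b, 1)

# Row-wise inclusive cumsum, then the first index whose cumulative
# probability exceeds the sample.
cum = np.cumsum(prob, axis=1)                                         # (b, n)
result = np.array(
    [[np.searchsorted(cum[i], usample[i, 0], side="right")] for i in range(b)],
    dtype="int64",
)                                                                     # (b, 1)
print(result)  # [[1], [2]], matching the docstring example
```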



##########
python/tvm/relax/frontend/nn/op.py:
##########
@@ -1825,3 +1826,260 @@ def print_(tensor: Tensor):
    filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2]
     line_info = f"{filename}:{line_number}"
     debug_func("vm.builtin.debug_print", tensor, _line_info=line_info)
+
+
+def less(a: Tensor, b: Tensor, name: str = "less") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs < rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.less(a._expr, b._expr), name)
+
+
+def less_equal(a: Tensor, b: Tensor, name: str = "less_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs <= rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.less_equal(a._expr, b._expr), name)
+
+
+def greater(a: Tensor, b: Tensor, name: str = "greater") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs > rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.greater(a._expr, b._expr), name)
+
+
+def greater_equal(a: Tensor, b: Tensor, name: str = "greater_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs >= rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.greater_equal(a._expr, b._expr), name)
+
+
+def equal(a: Tensor, b: Tensor, name: str = "equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs == rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.equal(a._expr, b._expr), name)
+
+
+def not_equal(a: Tensor, b: Tensor, name: str = "not_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs != rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.not_equal(a._expr, b._expr), name)
+
+
+def cumsum(
+    data: Tensor,
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+    name: str = "cumsum",
+) -> Tensor:
+    """Numpy style cumsum op. Return the cumulative inclusive sum of the 
elements along
+    a given axis.
+
+    Parameters
+    ----------
+    data : Tensor
+        The input data to the operator.
+
+    axis : Optional[int]
+        Axis along which the cumulative sum is computed. The default (None) is to compute
+        the cumsum over the flattened array.
+
+    dtype : Optional[str]
+        Type of the returned array and of the accumulator in which the elements are summed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    exclusive : Optional[bool]
+        If true, returns the exclusive sum, in which the first element is not
+        included.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        a = [[1, 2, 3], [4, 5, 6]]
+
+        cumsum(a)  # if axis is not provided, cumsum is done over the flattened input.
+        -> [ 1,  3,  6, 10, 15, 21]
+
+        cumsum(a, dtype="float32")
+        -> [  1.,   3.,   6.,  10.,  15.,  21.]
+
+        cumsum(a, axis=0)  # sum over rows for each of the 3 columns
+        -> [[1, 2, 3],
+            [5, 7, 9]]
+
+        cumsum(a, axis=1)
+        -> [[ 1,  3,  6],
+            [ 4,  9, 15]]
+
+        a = [1, 0, 1, 0, 1, 1, 0]  # a is a boolean array
+        cumsum(a, dtype="int32")  # dtype should be provided to get the expected results
+        -> [1, 1, 2, 2, 3, 4, 4]
+    """
+    return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)
+
+
+def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor):

Review Comment:
   Do we have tests showing this op can build and correctly run on GPU?
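   In case it helps, a rough sketch of what such a test could look like (hedged: this is one plausible compile pipeline, and the helpers used in this PR's actual test suite may differ):

```python
import numpy as np
import tvm
from tvm import dlight, relax
from tvm.relax.frontend import nn

class Model(nn.Module):
    def forward(self, prob: nn.Tensor, sample: nn.Tensor):
        return nn.op.multinomial_from_uniform(prob, sample)

mod, _ = Model().export_tvm(
    spec={
        "forward": {
            "prob": nn.spec.Tensor((2, 3), "float32"),
            "sample": nn.spec.Tensor((2, 1), "float32"),
        }
    }
)
target = tvm.target.Target("cuda")
with target:
    mod = relax.get_pipeline("zero")(mod)  # legalize high-level ops to TIR
    mod = dlight.ApplyDefaultSchedule(dlight.gpu.Fallback())(mod)
vm = relax.VirtualMachine(relax.build(mod, target), tvm.cuda())

dev = tvm.cuda()
prob = tvm.nd.array(np.array([[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]], "float32"), dev)
samp = tvm.nd.array(np.array([[0.4], [0.9]], "float32"), dev)
out = vm["forward"](prob, samp)
np.testing.assert_array_equal(out.numpy(), [[1], [2]])
```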



##########
python/tvm/relax/frontend/nn/op.py:
##########
@@ -1825,3 +1826,260 @@ def print_(tensor: Tensor):
    filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2]
     line_info = f"{filename}:{line_number}"
     debug_func("vm.builtin.debug_print", tensor, _line_info=line_info)
+
+
+def less(a: Tensor, b: Tensor, name: str = "less") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs < rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.less(a._expr, b._expr), name)
+
+
+def less_equal(a: Tensor, b: Tensor, name: str = "less_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs <= rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.less_equal(a._expr, b._expr), name)
+
+
+def greater(a: Tensor, b: Tensor, name: str = "greater") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs > rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.greater(a._expr, b._expr), name)
+
+
+def greater_equal(a: Tensor, b: Tensor, name: str = "greater_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs >= rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.greater_equal(a._expr, b._expr), name)
+
+
+def equal(a: Tensor, b: Tensor, name: str = "equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs == rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.equal(a._expr, b._expr), name)
+
+
+def not_equal(a: Tensor, b: Tensor, name: str = "not_equal") -> Tensor:
+    """Broadcasted element-wise comparison for (lhs != rhs).
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.not_equal(a._expr, b._expr), name)
+
+
+def cumsum(
+    data: Tensor,
+    axis: Optional[int] = None,
+    dtype: Optional[str] = None,
+    exclusive: Optional[bool] = None,
+    name: str = "cumsum",
+) -> Tensor:
+    """Numpy style cumsum op. Return the cumulative inclusive sum of the 
elements along
+    a given axis.
+
+    Parameters
+    ----------
+    data : Tensor
+        The input data to the operator.
+
+    axis : Optional[int]
+        Axis along which the cumulative sum is computed. The default (None) is to compute
+        the cumsum over the flattened array.
+
+    dtype : Optional[str]
+        Type of the returned array and of the accumulator in which the elements are summed.
+        If dtype is not specified, it defaults to the dtype of data.
+
+    exclusive : Optional[bool]
+        If true, returns the exclusive sum, in which the first element is not
+        included.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The result has the same size as data, and the same shape as data if axis is not None.
+        If axis is None, the result is a 1-d array.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        a = [[1, 2, 3], [4, 5, 6]]
+
+        cumsum(a)  # if axis is not provided, cumsum is done over the flattened input.
+        -> [ 1,  3,  6, 10, 15, 21]
+
+        cumsum(a, dtype="float32")
+        -> [  1.,   3.,   6.,  10.,  15.,  21.]
+
+        cumsum(a, axis=0)  # sum over rows for each of the 3 columns
+        -> [[1, 2, 3],
+            [5, 7, 9]]
+
+        cumsum(a, axis=1)
+        -> [[ 1,  3,  6],
+            [ 4,  9, 15]]
+
+        a = [1, 0, 1, 0, 1, 1, 0]  # a is a boolean array
+        cumsum(a, dtype="int32")  # dtype should be provided to get the expected results
+        -> [1, 1, 2, 2, 3, 4, 4]
+    """
+    return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)
+
+
+def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor):
+    """Returns a tensor where each row contains the index sampled from the 
multinomial
+    probability distribution located in the corresponding row of tensor prob.
+
+    Notes
+    -----
+    For better CPU performance, use 'vm.builtin.multinomial_from_uniform'.
+
+    Parameters
+    ----------
+    prob : Tensor
+        The probability tensor.
+
+    uniform_sample : Tensor
+        The uniform sample.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
+        usample = [[0.4], [0.9]]
+
+        multinomial_from_uniform(prob, usample)
+        -> [[1], [2]]
+    """
+
+    batch = prob.shape[0]
+    cumsum_prob = cumsum(prob, axis=1, exclusive=False)
+
+    @T.prim_func(private=True)
+    def _get_sample_index(A: T.handle, B: T.handle, C: T.handle):
+        batch, vocab_size = T.int64(), T.int64()
+        prob = T.match_buffer(A, (batch, vocab_size), "float32")
+        usample = T.match_buffer(B, (batch, 1), "float32")
+        output_index = T.match_buffer(C, (batch, 1), "int64")
+
+        for ax0 in T.parallel(batch):
+            for ax1 in T.parallel(vocab_size):
+                with T.block("T_get_sample_index"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.writes(output_index[v_ax0, 0])
+                    if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size:
+                        if v_ax1 == 0:
+                            output_index[v_ax0, 0] = 0
+                        else:
+                            if not (usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1 - 1]):

Review Comment:
   Just curious: is there any difference between using `not (xxx < yyy)` and 
`xxx >= yyy` here?
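   (For what it's worth, the two are not always equivalent for floats under IEEE-754: with a NaN operand every ordered comparison is false, so `not (x < y)` is true while `x >= y` is false. A quick illustration in plain Python; whether TIR's simplifier treats the two forms as interchangeable here is an assumption I have not checked:)

```python
x, y = float("nan"), 1.0
print(not (x < y))  # True: x < y is False for NaN, so the negation is True
print(x >= y)       # False: ordered comparisons involving NaN are always False
# For non-NaN operands the two forms agree.
```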



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
