t-vi commented on a change in pull request #6449:
URL: https://github.com/apache/incubator-tvm/pull/6449#discussion_r486806914



##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
       I must admit that I'd appreciate more commentary on the typing changes here.
   - In my opinion (and I could be wrong), it would be helpful to have an overview of what kinds of types `input_types` and `inputs` can have, and a single place where we do implicit type promotion. I had hoped `_pytorch_promote_types` could be that.
   - If `_pytorch_promote_types` doesn't do the job, maybe we can comment on why it doesn't. Also, why is this change apparently particular to elementwise ops, as opposed to amending `_pytorch_promote_types`?

   I know this looks like I'm asking for busywork when you're mostly interested in getting a particular model to work, but I have the impression that we would want to avoid ad hoc type workarounds as much as possible if we want to avoid subtle bugs whenever someone uses something outside what our unit tests catch.
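   To make the first point concrete, the single promotion point I have in mind would look roughly like this. It is only a sketch; `_promote_two` and `_common_dtype` are made-up names for illustration, not existing code:

```python
def _promote_two(inputs, input_types):
    lhs, rhs = inputs[:2]
    dt0, dt1 = input_types[:2]
    # prefer the inferred Relay dtype over the recorded one when available
    if isinstance(lhs, _expr.Expr):
        dt0 = _infer_type(lhs).checked_type.dtype
    if isinstance(rhs, _expr.Expr):
        dt1 = _infer_type(rhs).checked_type.dtype
    dtype = _common_dtype(dt0, dt1)  # hypothetical PyTorch-style promotion rule
    # turn Python scalars into constants of the promoted dtype, cast Exprs as needed
    if not isinstance(lhs, _expr.Expr):
        lhs = _expr.const(lhs, dtype=dtype)
    elif dt0 != dtype:
        lhs = _op.cast(lhs, dtype)
    if not isinstance(rhs, _expr.Expr):
        rhs = _expr.const(rhs, dtype=dtype)
    elif dt1 != dtype:
        rhs = _op.cast(rhs, dtype)
    return lhs, rhs
```

   With one such place, individual converters like `_elemwise` would not need to re-derive dtypes themselves.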

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -364,7 +438,11 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
+        try:
+            k = int(_infer_value(inputs[1], {}).asnumpy().tolist())
+            k = _expr.const(k)
+        except Exception:
+            k = inputs[1]

Review comment:
       Is the `int` not needed here?
   Also, it might be worth trying to avoid `try: ... except Exception:` for non-error handling, in favour of `if isinstance(...): ... else:`.
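   Roughly the kind of `if` I mean, as a sketch (assuming `inputs[1]` is either a plain Python int or a Relay expression):

```python
if isinstance(inputs[1], _expr.Constant):
    # constant k recorded in the graph: fold it into a Relay const
    k = _expr.const(int(inputs[1].data.asnumpy()))
elif isinstance(inputs[1], _expr.Expr):
    # genuinely dynamic k: pass the expression through unchanged
    k = inputs[1]
else:
    # plain Python int
    k = _expr.const(int(inputs[1]))
```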

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       For which types do we want to do this (or, alternatively, which types can go straight through)?

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")

Review comment:
       Casting `begin` to int32 here while `index_size_limit` is 2**63 - 1 feels strange.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -429,25 +507,56 @@ def _impl(inputs, input_types):
 
     return _impl
 
+def _full_impl(data, fill_value, dtype):
+    size = []
+    need_reshape = False
+    new_shape = []
+    for dim in data:
+        if isinstance(dim, _expr.Expr):
+            if isinstance(dim, _expr.Constant):
+                dim = int(dim.data.asnumpy())
+                if isinstance(size, list):
+                    size.append(dim)
+                new_shape.append(dim)
+            else:
+                try:
+                    dim = int(_infer_value(dim, {}).asnumpy())

Review comment:
       Here, too, maybe avoid `try: ... except:`.
   (There are more places; I didn't flag them all, but I think they should all be changed to use a plain `if`.)
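   One way to get to a plain `if` would be a small helper along these lines (hypothetical name; this assumes `tvm.relay.analysis.free_vars` can be used here):

```python
from tvm.relay import analysis as _analysis

def _static_int_or_none(expr):
    """Return a Python int if `expr` is statically computable, else None."""
    if isinstance(expr, int):
        return expr
    if isinstance(expr, _expr.Constant):
        return int(expr.data.asnumpy())
    if isinstance(expr, _expr.Expr) and not _analysis.free_vars(expr):
        # no free variables, so the value cannot depend on runtime inputs
        return int(_infer_value(expr, {}).asnumpy())
    return None
```

   The dynamic case then becomes an explicit `if _static_int_or_none(dim) is None:` branch instead of exception handling.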

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -364,7 +438,11 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
+        try:
+            k = int(_infer_value(inputs[1], {}).asnumpy().tolist())
+            k = _expr.const(k)
+        except Exception:
+            k = inputs[1]

Review comment:
       Still, I would prefer looking at what the type of `inputs[1]` is and having an `if`. We should at least know which types are fine to leave as-is (the current `except` branch).

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
       I think we would eventually want to look at using type propagation more.
   However, the issue here is that PyTorch's default dtype for integral tensors is int64. I don't think we should be hacking around that, really, because we're bound to end up with cases where int64 is the right thing to have. If I understood the discussions on the forum correctly, the idea was to downcast 64-bit indexing to 32-bit if it is considered safe.
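   For the record, the kind of check I mean for that downcast would be something like this sketch (`_maybe_int32` is a made-up name for illustration):

```python
INT32_MAX = 2**31 - 1

def _maybe_int32(value):
    # value is a statically known index; keep int64 when it may not fit into int32
    if -INT32_MAX - 1 <= value <= INT32_MAX:
        return _expr.const(value, dtype="int32")
    return _expr.const(value, dtype="int64")
```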

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       I'd have a strong preference for that, yeah.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
       I must admit that I'd appreciate if there were more commentary to the 
typing changes here.
   - In my opinion (and I could be wrong), it would be helpful to have a view 
what kind of types `input_types` and `inputs` can have and have a single place 
where we do implicit type promotion. I had hoped `_pytorch_promote_types` could 
be that.
   - If `_pytorch_promote_types` doesn't do the job, maybe we can comment why 
it isn't. Also why is this particular apparently particular elementwise ops as 
opposed to amending `_pytorch_promote_types`?
   
   I know this looks like I'm asking for busywork when you're mostly interested 
in getting a particular to work, but I have the impression that we would want 
to avoid ad hoc type workarounds as much as possible if we want to avoid having 
subtle bugs whenever someone uses something outside what our unit tests catch.
   

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -364,7 +438,11 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
+        try:
+            k = int(_infer_value(inputs[1], {}).asnumpy().tolist())
+            k = _expr.const(k)
+        except Exception:
+            k = inputs[1]

Review comment:
       The int is not needed here?
   Also it might be worth trying to avoid `try: ... except Exception:` during 
non-error-processing in favour of `if isinstance(....): ... else:`.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), 
axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       For which types do we want to do this (or alternatively which can go 
straight through)?

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), 
axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")

Review comment:
       int32 here and index_size limit 2**63-1 feels strange.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -429,25 +507,56 @@ def _impl(inputs, input_types):
 
     return _impl
 
+def _full_impl(data, fill_value, dtype):
+    size = []
+    need_reshape = False
+    new_shape = []
+    for dim in data:
+        if isinstance(dim, _expr.Expr):
+            if isinstance(dim, _expr.Constant):
+                dim = int(dim.data.asnumpy())
+                if isinstance(size, list):
+                    size.append(dim)
+                new_shape.append(dim)
+            else:
+                try:
+                    dim = int(_infer_value(dim, {}).asnumpy())

Review comment:
       Here, too, maybe avoid: `try: .. except:`
   (there are more places, I didn't flag them all, but I think they should be 
all changed to use plain `if`).

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -364,7 +438,11 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
+        try:
+            k = int(_infer_value(inputs[1], {}).asnumpy().tolist())
+            k = _expr.const(k)
+        except Exception:
+            k = inputs[1]

Review comment:
       Still, I would prefer looking at what the type of `inputs[1]` is and 
have an `if`. We should at least know which types are good to leave as is (the 
current except block).

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
       I think we would eventually want to look at using type propagation more.
   However, the issue here is that PyTorch's default dtype for integral tensors 
is int64. I don't think we should be hacking around that, really, because we're 
bound to end up with cases where int64 is the right thing to have. If I 
understood the discussions on the forum correctly, the idea was to downcast 64 
bit indexing to 32 based if it is considered safe.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), 
axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       I'd have a strong preference for that, yeah.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
       I must admit that I'd appreciate if there were more commentary to the 
typing changes here.
   - In my opinion (and I could be wrong), it would be helpful to have a view 
what kind of types `input_types` and `inputs` can have and have a single place 
where we do implicit type promotion. I had hoped `_pytorch_promote_types` could 
be that.
   - If `_pytorch_promote_types` doesn't do the job, maybe we can comment why 
it isn't. Also why is this particular apparently particular elementwise ops as 
opposed to amending `_pytorch_promote_types`?
   
   I know this looks like I'm asking for busywork when you're mostly interested 
in getting a particular to work, but I have the impression that we would want 
to avoid ad hoc type workarounds as much as possible if we want to avoid having 
subtle bugs whenever someone uses something outside what our unit tests catch.
   

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -364,7 +438,11 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
+        try:
+            k = int(_infer_value(inputs[1], {}).asnumpy().tolist())
+            k = _expr.const(k)
+        except Exception:
+            k = inputs[1]

Review comment:
       The int is not needed here?
   Also it might be worth trying to avoid `try: ... except Exception:` during 
non-error-processing in favour of `if isinstance(....): ... else:`.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), 
axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       For which types do we want to do this (or alternatively which can go 
straight through)?

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), 
axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")

Review comment:
       int32 here and index_size limit 2**63-1 feels strange.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -429,25 +507,56 @@ def _impl(inputs, input_types):
 
     return _impl
 
+def _full_impl(data, fill_value, dtype):
+    size = []
+    need_reshape = False
+    new_shape = []
+    for dim in data:
+        if isinstance(dim, _expr.Expr):
+            if isinstance(dim, _expr.Constant):
+                dim = int(dim.data.asnumpy())
+                if isinstance(size, list):
+                    size.append(dim)
+                new_shape.append(dim)
+            else:
+                try:
+                    dim = int(_infer_value(dim, {}).asnumpy())

Review comment:
       Here, too, maybe avoid: `try: .. except:`
   (there are more places, I didn't flag them all, but I think they should be 
all changed to use plain `if`).

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -364,7 +438,11 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
+        try:
+            k = int(_infer_value(inputs[1], {}).asnumpy().tolist())
+            k = _expr.const(k)
+        except Exception:
+            k = inputs[1]

Review comment:
       Still, I would prefer looking at what the type of `inputs[1]` is and 
have an `if`. We should at least know which types are good to leave as is (the 
current except block).

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
       I think we would eventually want to look at using type propagation more.
   However, the issue here is that PyTorch's default dtype for integral tensors 
is int64. I don't think we should be hacking around that, really, because we're 
bound to end up with cases where int64 is the right thing to have. If I 
understood the discussions on the forum correctly, the idea was to downcast 64 
bit indexing to 32 based if it is considered safe.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), 
axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       I'd have a strong preference for that, yeah.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
       I must admit that I'd appreciate if there were more commentary to the 
typing changes here.
   - In my opinion (and I could be wrong), it would be helpful to have a view 
what kind of types `input_types` and `inputs` can have and have a single place 
where we do implicit type promotion. I had hoped `_pytorch_promote_types` could 
be that.
   - If `_pytorch_promote_types` doesn't do the job, maybe we can comment why 
it isn't. Also why is this particular apparently particular elementwise ops as 
opposed to amending `_pytorch_promote_types`?
   
   I know this looks like I'm asking for busywork when you're mostly interested 
in getting a particular to work, but I have the impression that we would want 
to avoid ad hoc type workarounds as much as possible if we want to avoid having 
subtle bugs whenever someone uses something outside what our unit tests catch.
   

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -364,7 +438,11 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
+        try:
+            k = int(_infer_value(inputs[1], {}).asnumpy().tolist())
+            k = _expr.const(k)
+        except Exception:
+            k = inputs[1]

Review comment:
       The int is not needed here?
   Also it might be worth trying to avoid `try: ... except Exception:` during 
non-error-processing in favour of `if isinstance(....): ... else:`.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), 
axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       For which types do we want to do this (or alternatively which can go 
straight through)?

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), 
axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")

Review comment:
       int32 here and index_size limit 2**63-1 feels strange.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -429,25 +507,56 @@ def _impl(inputs, input_types):
 
     return _impl
 
+def _full_impl(data, fill_value, dtype):
+    size = []
+    need_reshape = False
+    new_shape = []
+    for dim in data:
+        if isinstance(dim, _expr.Expr):
+            if isinstance(dim, _expr.Constant):
+                dim = int(dim.data.asnumpy())
+                if isinstance(size, list):
+                    size.append(dim)
+                new_shape.append(dim)
+            else:
+                try:
+                    dim = int(_infer_value(dim, {}).asnumpy())

Review comment:
       Here, too, maybe avoid: `try: .. except:`
   (there are more places, I didn't flag them all, but I think they should be 
all changed to use plain `if`).

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -364,7 +438,11 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
+        try:
+            k = int(_infer_value(inputs[1], {}).asnumpy().tolist())
+            k = _expr.const(k)
+        except Exception:
+            k = inputs[1]

Review comment:
       Still, I would prefer looking at what the type of `inputs[1]` is and 
have an `if`. We should at least know which types are good to leave as is (the 
current except block).

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
       I think we would eventually want to look at using type propagation more.
   However, the issue here is that PyTorch's default dtype for integral tensors 
is int64. I don't think we should be hacking around that, really, because we're 
bound to end up with cases where int64 is the right thing to have. If I 
understood the discussions on the forum correctly, the idea was to downcast 64 
bit indexing to 32 based if it is considered safe.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), 
axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       I'd have a strong preference for that, yeah.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
       I must admit that I'd appreciate if there were more commentary to the 
typing changes here.
   - In my opinion (and I could be wrong), it would be helpful to have a view 
what kind of types `input_types` and `inputs` can have and have a single place 
where we do implicit type promotion. I had hoped `_pytorch_promote_types` could 
be that.
   - If `_pytorch_promote_types` doesn't do the job, maybe we can comment why 
it isn't. Also why is this particular apparently particular elementwise ops as 
opposed to amending `_pytorch_promote_types`?
   
   I know this looks like I'm asking for busywork when you're mostly interested 
in getting a particular to work, but I have the impression that we would want 
to avoid ad hoc type workarounds as much as possible if we want to avoid having 
subtle bugs whenever someone uses something outside what our unit tests catch.
   

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -364,7 +438,11 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
+        try:
+            k = int(_infer_value(inputs[1], {}).asnumpy().tolist())
+            k = _expr.const(k)
+        except Exception:
+            k = inputs[1]

Review comment:
       The int is not needed here?
   Also it might be worth trying to avoid `try: ... except Exception:` during 
non-error-processing in favour of `if isinstance(....): ... else:`.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), 
axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], 
{}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       For which types do we want to do this (or alternatively which can go 
straight through)?

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], 
{}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), 
axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")

Review comment:
       int32 here and index_size limit 2**63-1 feels strange.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -429,25 +507,56 @@ def _impl(inputs, input_types):
 
     return _impl
 
+def _full_impl(data, fill_value, dtype):
+    size = []
+    need_reshape = False
+    new_shape = []
+    for dim in data:
+        if isinstance(dim, _expr.Expr):
+            if isinstance(dim, _expr.Constant):
+                dim = int(dim.data.asnumpy())
+                if isinstance(size, list):
+                    size.append(dim)
+                new_shape.append(dim)
+            else:
+                try:
+                    dim = int(_infer_value(dim, {}).asnumpy())

Review comment:
       Here, too, maybe avoid: `try: .. except:`
   (there are more places, I didn't flag them all, but I think they should be 
all changed to use plain `if`).

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -364,7 +438,11 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
+        try:
+            k = int(_infer_value(inputs[1], {}).asnumpy().tolist())
+            k = _expr.const(k)
+        except Exception:
+            k = inputs[1]

Review comment:
       Still, I would prefer looking at what the type of `inputs[1]` is and 
have an `if`. We should at least know which types are good to leave as is (the 
current except block).

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
       I think we would eventually want to look at using type propagation more.
   However, the issue here is that PyTorch's default dtype for integral tensors 
is int64. I don't think we should be hacking around that, really, because we're 
bound to end up with cases where int64 is the right thing to have. If I 
understood the discussions on the forum correctly, the idea was to downcast 64 
bit indexing to 32 based if it is considered safe.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       I'd have a strong preference for that, yeah.

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       For which types do we want to do this (or, alternatively, which types can go
   straight through)?
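
   For what it's worth, one possible split, again only as a sketch (the grouping is my guess at an answer, using the names from the hunk above):

```python
# make the accepted types explicit instead of catching exceptions
if isinstance(inputs[3], _expr.Constant):
    # end index is statically known
    target_end = int(inputs[3].data.asnumpy())
elif isinstance(inputs[3], _expr.Expr):
    # data-dependent end index: keep it symbolic for the dynamic path
    target_end = inputs[3]
else:
    # plain Python int from the TorchScript graph
    target_end = int(inputs[3])
```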

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

