t-vi commented on a change in pull request #6449:
URL: https://github.com/apache/incubator-tvm/pull/6449#discussion_r486806914



##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -127,8 +128,22 @@ def _is_quantized_tensor(data, prelude):
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
-        data0, data1 = _pytorch_promote_types(inputs[:2], input_types[:2])
-        return get_relay_op(name)(data0, data1)
+        dtype0, dtype1 = input_types[:2]
+        if isinstance(inputs[0], _expr.Expr):
+            dtype0 = _infer_type(inputs[0]).checked_type.dtype
+        if isinstance(inputs[1], _expr.Expr):
+            dtype1 = _infer_type(inputs[1]).checked_type.dtype
+

Review comment:
       I must admit that I'd appreciate more commentary on the typing changes
here.
   - In my opinion (and I could be wrong), it would be helpful to have an overview
of what types `input_types` and `inputs` can have, and a single place where we do
implicit type promotion. I had hoped `_pytorch_promote_types` could be that.
   - If `_pytorch_promote_types` doesn't do the job, maybe we can add a comment
explaining why. Also, why is this apparently specific to elementwise ops, as
opposed to amending `_pytorch_promote_types`?
   
   I know this looks like I'm asking for busywork when you're mostly interested
in getting a particular model to work, but I have the impression that we should
avoid ad hoc type workarounds as much as possible if we don't want subtle bugs
to appear whenever someone uses something outside what our unit tests catch.
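To make the "single place for implicit promotion" idea concrete, here is a minimal, simplified sketch of a centralized helper. `promote_dtypes` and its helpers are hypothetical (this is not what `_pytorch_promote_types` does today); it approximates PyTorch's rules for tensors mixed with Python scalars and deliberately ignores zero-dim tensors, signed/unsigned mixing, and complex types.

```python
from typing import List, Sequence, Union

Operand = Union[str, bool, int, float]  # str = tensor dtype, others = Python scalars

# Rank of each dtype category; a higher category always wins.
_RANK = {"bool": 0, "int": 1, "float": 2}
# dtype a Python scalar contributes when it raises the result category
# (PyTorch's defaults: int64 for integers, float32 for floating point).
_SCALAR_DEFAULT = {"bool": "bool", "int": "int64", "float": "float32"}


def _category(dtype: str) -> str:
    if dtype == "bool":
        return "bool"
    if dtype.startswith(("int", "uint")):
        return "int"
    if dtype.startswith("float"):
        return "float"
    raise ValueError(f"unsupported dtype: {dtype}")


def _bits(dtype: str) -> int:
    # "float32" -> 32, "int8" -> 8; "bool" has no digits and counts as 1.
    digits = "".join(ch for ch in dtype if ch.isdigit())
    return int(digits) if digits else 1


def _scalar_category(value: Operand) -> str:
    # Check bool first: in Python, bool is a subclass of int.
    if isinstance(value, bool):
        return "bool"
    if isinstance(value, int):
        return "int"
    if isinstance(value, float):
        return "float"
    raise TypeError(f"unsupported operand: {value!r}")


def promote_dtypes(operands: Sequence[Operand]) -> str:
    """Common dtype for tensor dtypes (given as strings) mixed with Python scalars."""
    tensor_dtypes: List[str] = [op for op in operands if isinstance(op, str)]
    scalars = [op for op in operands if not isinstance(op, str)]

    result = None
    if tensor_dtypes:
        # Among tensors, the highest category wins, then the widest dtype in it.
        top = max(_RANK[_category(d)] for d in tensor_dtypes)
        same_cat = [d for d in tensor_dtypes if _RANK[_category(d)] == top]
        result = max(same_cat, key=_bits)

    # A scalar only changes the result if it belongs to a higher category.
    for value in scalars:
        cat = _scalar_category(value)
        if result is None or _RANK[cat] > _RANK[_category(result)]:
            result = _SCALAR_DEFAULT[cat]

    if result is None:
        raise ValueError("promote_dtypes needs at least one operand")
    return result
```

Each elementwise converter could then call one helper like this instead of re-deriving dtypes locally, and the rules could be unit-tested in isolation.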
   

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -364,7 +438,11 @@ def _impl(inputs, input_types):
 def _topk():
     def _impl(inputs, input_types):
         data = inputs[0]
-        k = int(inputs[1])
+        try:
+            k = int(_infer_value(inputs[1], {}).asnumpy().tolist())
+            k = _expr.const(k)
+        except Exception:
+            k = inputs[1]

Review comment:
       Is the `int` needed here?
   Also, it might be worth avoiding `try: ... except Exception:` for non-error
control flow, in favour of `if isinstance(...): ... else:`.
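A minimal sketch of the `if isinstance(...)` style, using toy stand-ins (`Const`, `Var`, `Add`, `has_free_vars`, `fold`, and `normalize_k` are all hypothetical, not TVM API). The idea is to decide up front whether a value can be folded, instead of attempting the fold and catching whatever exception falls out. In the real frontend the free-variable test might map to `tvm.relay.analysis.free_vars`, but that mapping is an assumption.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Const:
    """Hypothetical stand-in for a constant node."""
    value: int


@dataclass
class Var:
    """Hypothetical stand-in for a free input whose value is unknown at compile time."""
    name: str


@dataclass
class Add:
    """Hypothetical stand-in for a call node combining two sub-expressions."""
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Const, Var, Add]


def has_free_vars(expr: Expr) -> bool:
    if isinstance(expr, Var):
        return True
    if isinstance(expr, Const):
        return False
    return has_free_vars(expr.lhs) or has_free_vars(expr.rhs)


def fold(expr: Expr) -> int:
    # Callers check has_free_vars() first, so no Var should reach here.
    if isinstance(expr, Const):
        return expr.value
    if isinstance(expr, Add):
        return fold(expr.lhs) + fold(expr.rhs)
    raise TypeError(f"cannot fold {expr!r}")


def normalize_k(k: Union[int, Expr]) -> Union[int, Expr]:
    """Return a Python int when k is statically known, else k unchanged."""
    if isinstance(k, int):
        return k
    if not has_free_vars(k):
        return fold(k)
    return k  # genuinely dynamic: keep the expression
```

With this shape, a real bug inside `fold` surfaces as an exception instead of being silently swallowed and treated as "dynamic".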

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")
+
         if isinstance(inputs[3], str) and inputs[3].isdigit():
-            end[dim] = min(end[dim], int(inputs[3]))
+            target_end = int(inputs[3])
         else:
-            if isinstance(inputs[3], _expr.Call):
-                target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+            if isinstance(inputs[3], _expr.Expr):
+                try:
+                    target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+                except Exception:

Review comment:
       For which types do we want to do this (or alternatively which can go 
straight through)?
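One way to answer this in code is to enumerate the accepted types explicitly and fail loudly on anything else. A minimal sketch, where `resolve_end` and `DynamicExpr` are hypothetical names; it assumes (as `index_size_limit` suggests) that TorchScript records 2**63 - 1 as the end of an open-ended slice, and it ignores negative indices and stride.

```python
INT64_MAX = 2**63 - 1  # assumed: end value TorchScript records for x[a:]


class DynamicExpr:
    """Hypothetical stand-in for an expression only known at runtime."""

    def __init__(self, name: str) -> None:
        self.name = name


def resolve_end(end, dim_size: int):
    """Return a static int end clamped to dim_size, or the dynamic expression as-is."""
    if isinstance(end, str):
        # The frontend code above accepts digit strings; reject anything else loudly.
        if not end.isdigit():
            raise ValueError(f"non-numeric slice end: {end!r}")
        end = int(end)
    if isinstance(end, int):
        # Clamping also turns the INT64_MAX sentinel into the real dim size.
        return min(end, dim_size)
    if isinstance(end, DynamicExpr):
        return end  # goes straight through to the dynamic code path
    raise TypeError(f"unexpected slice end type: {type(end).__name__}")
```

Any type that reaches the final `raise` is one the converter was never designed for, which is easier to diagnose than a broad `except Exception:` fallback.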

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -274,38 +295,91 @@ def _impl(inputs, input_types):
 
 def _slice():
     def _impl(inputs, input_types):
+        index_size_limit = 2**63 - 1
         data = inputs[0]
-        strides = []
+        dshape = _infer_shape(data)
+        ndim = len(dshape)
+        end = []
+        for dim in dshape:
+            if isinstance(dim, tvm.tir.Any):
+                end = _op.shape_of(data)
+                break
+            end.append(int(dim))
 
-        if isinstance(data, _expr.Expr):
-            inferred_shape = _infer_shape(data)
-            end = []
-            for infer in inferred_shape:
-                end.append(int(infer))
-            if isinstance(data, _expr.Var):
-                end = inferred_shape
-                end = list(end)
-        else:
-            end = data.shape
-
-        begin = [0] * len(end)
+        begin = [0] * ndim
         dim = int(inputs[1])
+        stride = int(inputs[4])
         if isinstance(inputs[2], _expr.Call):
-            begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            try:
+                begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
+            except Exception:
+                begin[dim] = inputs[2]
         else:
             begin[dim] = int(inputs[2])
 
+        # Process begin
+        if not isinstance(begin[dim], int):
+            tmp = []
+            for b in begin:
+                if isinstance(b, int):
+                    tmp.append(_op.expand_dims(_expr.const(b, "int64"), axis=0))
+                else:
+                    tmp.append(_op.cast(_op.expand_dims(b, axis=0), "int64"))
+            begin = _op.concatenate(tmp, axis=0)
+            btype = _infer_type(begin).checked_type.dtype
+            if str(btype) != "int32":
+                begin = _op.cast(begin, "int32")

Review comment:
       Casting to int32 here while `index_size_limit` is 2**63 - 1 feels strange.
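To make the mismatch concrete: assuming a C-style truncating cast (NumPy's `astype(np.int32)` behaves this way on integer arrays), a value equal to `index_size_limit` that ends up in an int32 tensor keeps only its low 32 bits and becomes -1. A pure-Python sketch of that reinterpretation (`wrap_to_int32` is a hypothetical helper):

```python
INT64_MAX = 2**63 - 1  # same value as index_size_limit in the hunk


def wrap_to_int32(value: int) -> int:
    """Keep the low 32 bits of value and read them as a two's-complement int32."""
    return ((value + 2**31) % 2**32) - 2**31


print(wrap_to_int32(INT64_MAX))  # -1
```

In Python-style slicing an end of -1 means "stop before the last element", so a sentinel meaning "to the end" would silently drop one element.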

##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -429,25 +507,56 @@ def _impl(inputs, input_types):
 
     return _impl
 
+def _full_impl(data, fill_value, dtype):
+    size = []
+    need_reshape = False
+    new_shape = []
+    for dim in data:
+        if isinstance(dim, _expr.Expr):
+            if isinstance(dim, _expr.Constant):
+                dim = int(dim.data.asnumpy())
+                if isinstance(size, list):
+                    size.append(dim)
+                new_shape.append(dim)
+            else:
+                try:
+                    dim = int(_infer_value(dim, {}).asnumpy())

Review comment:
       Here, too, maybe avoid `try: ... except:`
   (there are more places; I didn't flag them all, but I think they should all
be changed to use a plain `if`).
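For `_full_impl`, the same idea could look like the following minimal sketch: classify each dimension by type and bail out to the dynamic path when one is unknown, rather than trying `_infer_value` and catching the failure. `Const`, `Var`, and `static_shape` are hypothetical stand-ins, not TVM API.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Union


@dataclass
class Const:
    """Hypothetical stand-in for a constant dimension node."""
    value: int


@dataclass
class Var:
    """Hypothetical stand-in for a dimension only known at runtime."""
    name: str


Dim = Union[int, Const, Var]


def static_shape(dims: Sequence[Dim]) -> Optional[List[int]]:
    """Return the shape as plain ints if every dim is known, else None."""
    shape: List[int] = []
    for dim in dims:
        if isinstance(dim, int):
            shape.append(dim)
        elif isinstance(dim, Const):
            shape.append(dim.value)
        elif isinstance(dim, Var):
            return None  # caller should take the dynamic-shape path instead
        else:
            raise TypeError(f"unexpected dim type: {type(dim).__name__}")
    return shape
```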




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

