reminisce commented on a change in pull request #16660: [WIP] [Numpy] TVM implementation for binary ops
URL: https://github.com/apache/incubator-mxnet/pull/16660#discussion_r343827251
 
 

 ##########
 File path: contrib/tvmop/core/umath.py
 ##########
 @@ -120,3 +121,327 @@ def _compute_binary_scalar_logic(op, dtype, ndim):
           **_bin_scalar_logic_cpu_attrs)(_binary_logic_cpu)
     defop(name='{}_gpu'.format(op_name), op=op_name,
           **_bin_scalar_logic_gpu_attrs)(_binary_logic_gpu)
+
+
+_bin_cpu_attrs_base = {
+    'target': 'cpu',
+    'dtype': AllTypes,
+    'ndim': [5],
+    'req': ['kWriteTo', 'kAddTo'],
+    'attrs': ['req'],
+}
+
+_bin_gpu_attrs_base = {
+    'target': 'gpu',
+    'dtype': ["float32", "float64", "uint8", "int8", "int32", "int64"],
+    'ndim': [5],
+    'req': ['kWriteTo', 'kAddTo'],
+    'attrs': ['req'],
+}
+
+def _binary_cpu(compute_func, op, dtype, ndim, req):
+    s, a, b, old, new = compute_func(op, dtype, ndim, req)
+    axes = [axis for axis in new.op.axis]
+    fused = s[new].fuse(*axes)
+    s[new].parallel(fused)
+    return s, [a, b, old, new]
+
+
+def _binary_gpu(compute_func, op, dtype, ndim, req):
+    s, a, b, old, new = compute_func(op, dtype, ndim, req)
+    axes = [axis for axis in new.op.axis]
+    fused = s[new].fuse(*axes)
+    bx, tx = s[new].split(fused, factor=64)
+    s[new].bind(bx, tvm.thread_axis('blockIdx.x'))
+    s[new].bind(tx, tvm.thread_axis('threadIdx.x'))
+    return s, [a, b, old, new]
+
+_bin_op_map = {
+    'multiply': lambda a, b: a * b,
+    'add': lambda a, b: a + b,
+}
+
+def _compute_binary(op, dtype, ndim, req):
+    op = _bin_op_map[op]
+    a = tvm.placeholder([tvm.var() for _ in range(ndim)], dtype=dtype, name='a')
+    b = tvm.placeholder([tvm.var() for _ in range(ndim)], dtype=dtype, name='b')
+    c = tvm.compute([tvm.var() for _ in range(ndim)],
+                    lambda *idx: op(a[idx], b[idx]), name='c')
+    old, new = assign_by_req(c, req)
+    s = tvm.create_schedule(new.op)
+    s[c].compute_inline()
+    return s, a, b, old, new
+
+_bin_cpu_attrs = {
+    **_bin_cpu_attrs_base,
+    'compute_func': _compute_binary,
+    'auto_broadcast': True,
+}
+
+_bin_gpu_attrs = {
+    **_bin_gpu_attrs_base,
+    'compute_func': _compute_binary,
+    'auto_broadcast': True,
+}
+
+# register binary element-wise ops with broadcasting supported
+for op_name in _bin_op_map.keys():
+    defop(name='{}_cpu'.format(op_name), op=op_name, **_bin_cpu_attrs)(_binary_cpu)
+    defop(name='{}_gpu'.format(op_name), op=op_name, **_bin_gpu_attrs)(_binary_gpu)
+
+
+_bin_scalar_op_map = {
+    'multiply_scalar': lambda a, b: a * b.astype(a.dtype),
+    'add_scalar': lambda a, b: a + b.astype(a.dtype),
+}
+
+
+def _compute_binary_scalar(op, dtype, ndim, req):
+    op = _bin_scalar_op_map[op]
+    a = tvm.placeholder([tvm.var() for _ in range(ndim)], name='a', dtype=dtype)
+    b = tvm.var('b', dtype='float64')
+    c = tvm.compute([tvm.var() for _ in range(ndim)],
+                    lambda *idx: op(a[idx], b), name='c')
+    old, new = assign_by_req(c, req)
+    s = tvm.create_schedule(new.op)
+    s[c].compute_inline()
+    return s, a, b, old, new
+
+
+_bin_scalar_cpu_attrs = {
+    **_bin_cpu_attrs_base,
+    'compute_func': _compute_binary_scalar,
+}
+
+_bin_scalar_gpu_attrs = {
+    **_bin_gpu_attrs_base,
+    'compute_func': _compute_binary_scalar,
+}
+
+for op_name in _bin_scalar_op_map.keys():
+    defop(name='{}_cpu'.format(op_name), op=op_name,
+            **_bin_scalar_cpu_attrs)(_binary_cpu)
+    defop(name='{}_gpu'.format(op_name), op=op_name,
+            **_bin_scalar_gpu_attrs)(_binary_gpu)
+
+
+_bin_backward_cpu_attrs_base = {
+    'dtype': AllTypes,
+    'output': [0, 1],
+    'reduce1st': [0, 1],
+    'req': ['kWriteTo', 'kAddTo'],
+    'attrs': ["output", "reduce1st", "req"],
 
 Review comment:
   Rename `reduce1st` to `reduce1st_dim` for better readability?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to