szha closed pull request #10083: [TENSOR] Fix DLTensor conversion for int64 URL: https://github.com/apache/incubator-mxnet/pull/10083
This pull request was merged from a forked repository. Because GitHub does not display the original diff of a merged pull request from a fork, the diff is reproduced below for the sake of provenance: diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h index 59c1eacb2c5..6f604a5bb8d 100755 --- a/include/mxnet/tensor_blob.h +++ b/include/mxnet/tensor_blob.h @@ -322,16 +322,19 @@ class TBlob { private: static DLDataType DTypeTransform(int type_flag) { - static std::unordered_map<int, DLDataType> - MSHADOW_DTYPE_TO_DLPACK_DTYPE = { - {0, {2, 32, 1}}, // Float32 - {1, {2, 64, 1}}, // Float64 - {2, {2, 16, 1}}, // Float16 - {3, {1, 8, 1}}, // UInt8 - {4, {0, 32, 1}}, // Int32 - {5, {0, 8, 1}} // Int8 - }; - return MSHADOW_DTYPE_TO_DLPACK_DTYPE[type_flag]; + switch (type_flag) { + case mshadow::kFloat32: return DLDataType{kDLFloat, 32, 1}; + case mshadow::kFloat64: return DLDataType{kDLFloat, 64, 1}; + case mshadow::kFloat16: return DLDataType{kDLFloat, 16, 1}; + case mshadow::kUint8: return DLDataType{kDLUInt, 8, 1}; + case mshadow::kInt32: return DLDataType{kDLInt, 32, 1}; + case mshadow::kInt8: return DLDataType{kDLInt, 8, 1}; + case mshadow::kInt64: return DLDataType{kDLInt, 64, 1}; + default: { + LOG(FATAL) << "Unknown type_flag=" << type_flag; + return DLDataType(); + } + } } inline void SetDLTensor(int dev_mask, int dev_id) { diff --git a/tests/python/gpu/test_tvm_bridge.py b/tests/python/gpu/test_tvm_bridge.py index 292b9d91e5f..69a713d6a28 100644 --- a/tests/python/gpu/test_tvm_bridge.py +++ b/tests/python/gpu/test_tvm_bridge.py @@ -30,13 +30,13 @@ def test_tvm_bridge(): logging.warn("TVM bridge test skipped because TVM is missing...") return - def check(target): + def check(target, dtype): shape = (20,) scale = tvm.var("scale", dtype="float32") - x = tvm.placeholder(shape) - y = tvm.placeholder(shape) + x = tvm.placeholder(shape, dtype=dtype) + y = tvm.placeholder(shape, 
dtype=dtype) z = tvm.compute(shape, lambda i: x[i] + y[i]) - zz = tvm.compute(shape, lambda *i: z(*i) * scale) + zz = tvm.compute(shape, lambda *i: z(*i) * scale.astype(dtype)) ctx = mx.gpu(0) if target == "cuda" else mx.cpu(0) target = tvm.target.create(target) @@ -47,17 +47,18 @@ def check(target): # get a mxnet version mxf = tvm.contrib.mxnet.to_mxnet_func(f, const_loc=[0, 1]) - xx = mx.nd.uniform(shape=shape, ctx=ctx) - yy = mx.nd.uniform(shape=shape, ctx=ctx) - zz = mx.nd.empty(shape=shape, ctx=ctx) + xx = mx.nd.uniform(shape=shape, ctx=ctx).astype(dtype) + yy = mx.nd.uniform(shape=shape, ctx=ctx).astype(dtype) + zz = mx.nd.empty(shape=shape, ctx=ctx).astype(dtype) # invoke myf: this runs in mxnet engine mxf(xx, yy, zz, 10.0) np.testing.assert_allclose( zz.asnumpy(), (xx.asnumpy() + yy.asnumpy()) * 10) - check("llvm") - check("cuda") - + for tgt in ["llvm", "cuda"]: + for dtype in ["int8", "uint8", "int64", + "float32", "float64"]: + check(tgt, dtype) if __name__ == "__main__": ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services