This is an automated email from the ASF dual-hosted git repository.

apeforest pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 05f3ae1  Large Index Support for Slice (#15593)
05f3ae1 is described below

commit 05f3ae13a39835c4690eb5d9567f457be84189e3
Author: Rohit Kumar Srivastava <srivastava....@osu.edu>
AuthorDate: Mon Aug 12 21:23:31 2019 -0700

    Large Index Support for Slice (#15593)
    
    * Adding Large Index Support for slice operator
    
    * adding changes to fix py2 related error in CI/CD
    
    * fixing base.py
    
    * rearrange system call and slower Feature() call
    
    * refactoring c_api, c_symbolic_api, c_api_common
    
    * templatizing code
    
    * caching results of runtime features and minor refactoring
    
    * fixing local caching in ndarray shape
---
 include/mxnet/c_api.h               | 156 +++++++++++++++++++++++++--
 include/mxnet/c_predict_api.h       |   4 +-
 python/mxnet/base.py                |   1 +
 python/mxnet/ndarray/ndarray.py     |  49 ++++++---
 python/mxnet/symbol/symbol.py       |  75 ++++++++-----
 src/c_api/c_api.cc                  | 111 ++++++++++++-------
 src/c_api/c_api_common.h            |  14 +--
 src/c_api/c_api_executor.cc         |  12 +--
 src/c_api/c_api_ndarray.cc          |  10 +-
 src/c_api/c_api_profile.cc          |   2 +-
 src/c_api/c_api_symbolic.cc         | 206 ++++++++++++++++++++++++++----------
 src/imperative/imperative.cc        |   4 +-
 src/imperative/imperative_utils.h   |   2 +-
 src/operator/tensor/matrix_op-inl.h |  42 ++++----
 src/operator/tensor/slice-inl.h     |   6 +-
 tests/nightly/test_large_array.py   |   1 -
 tests/nightly/test_large_vector.py  |  37 +++++++
 17 files changed, 549 insertions(+), 183 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 20b2aa2..5ab10b6 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -55,7 +55,9 @@ extern "C" {
 #endif
 
 /*! \brief manually define unsigned int */
-typedef unsigned int mx_uint;
+typedef uint32_t mx_uint;
+/*! \brief manually define 64-bit int */
+typedef int64_t mx_int64;
 /*! \brief manually define float */
 typedef float mx_float;
 /*! \brief data type to store dim size */
@@ -572,6 +574,13 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
                               int dtype,
                               NDArrayHandle *out);
 
+MXNET_DLL int MXNDArrayCreateEx64(const mx_int64 *shape,
+                                  int ndim,
+                                  int dev_type,
+                                  int dev_id,
+                                  int delay_alloc,
+                                  int dtype,
+                                  NDArrayHandle *out);
 
 /*!
  * \brief create an empty sparse NDArray with specified shape and data type
@@ -603,6 +612,19 @@ MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
                                       const mx_uint *aux_shape,
                                       NDArrayHandle *out);
 
+MXNET_DLL int MXNDArrayCreateSparseEx64(int storage_type,
+                                        const mx_int64 *shape,
+                                        int ndim,
+                                        int dev_type,
+                                        int dev_id,
+                                        int delay_alloc,
+                                        int dtype,
+                                        mx_uint num_aux,
+                                        int *aux_type,
+                                        int *aux_ndims,
+                                        const mx_int64 *aux_shape,
+                                        NDArrayHandle *out);
+
 /*!
  * \brief create a NDArray handle that is loaded from raw bytes.
  * \param buf the head of the raw bytes
@@ -650,6 +672,12 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
                             mx_uint *out_name_size,
                             const char*** out_names);
 
+MXNET_DLL int MXNDArrayLoad64(const char* fname,
+                              mx_int64 *out_size,
+                              NDArrayHandle** out_arr,
+                              mx_int64 *out_name_size,
+                              const char*** out_names);
+
 /*!
  * \brief Load list / dictionary of narrays from file content loaded into memory.
  * This will load a list of ndarrays in a similar
@@ -665,11 +693,18 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
-                            size_t size,
-                            mx_uint *out_size,
-                            NDArrayHandle** out_arr,
-                            mx_uint *out_name_size,
-                            const char*** out_names);
+                                      size_t size,
+                                      mx_uint *out_size,
+                                      NDArrayHandle** out_arr,
+                                      mx_uint *out_name_size,
+                                      const char*** out_names);
+
+MXNET_DLL int MXNDArrayLoadFromBuffer64(const void *ndarray_buffer,
+                                        size_t size,
+                                        mx_int64 *out_size,
+                                        NDArrayHandle** out_arr,
+                                        mx_int64 *out_name_size,
+                                        const char*** out_names);
 
 /*!
  * \brief Perform a synchronize copy from a continugous CPU memory region.
@@ -809,6 +844,11 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
 MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
                                 mx_uint *out_dim,
                                 const mx_uint **out_pdata);
+
+MXNET_DLL int MXNDArrayGetShape64(NDArrayHandle handle,
+                                  int *out_dim,
+                                  const int64_t **out_pdata);
+
 /*!
  * \brief get the shape of the array
  * \param handle the handle to the narray
@@ -819,6 +859,11 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
 MXNET_DLL int MXNDArrayGetShapeEx(NDArrayHandle handle,
                                   int *out_dim,
                                   const int **out_pdata);
+
+MXNET_DLL int MXNDArrayGetShapeEx64(NDArrayHandle handle,
+                                    int *out_dim,
+                                    const mx_int64 **out_pdata);
+
 /*!
  * \brief get the content of the data in NDArray
  * \param handle the handle to the ndarray
@@ -902,6 +947,10 @@ MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
                                   mx_uint i,
                                   int *out_type);
 
+MXNET_DLL int MXNDArrayGetAuxType64(NDArrayHandle handle,
+                                    mx_int64 i,
+                                    int *out_type);
+
 /*!
  * \brief Get a deep copy of the ith aux data blob
  * in the form of an NDArray of default storage type.
@@ -911,6 +960,10 @@ MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
                                      mx_uint i,
                                      NDArrayHandle *out);
 
+MXNET_DLL int MXNDArrayGetAuxNDArray64(NDArrayHandle handle,
+                                       mx_int64 i,
+                                       NDArrayHandle *out);
+
 /*!
  * \brief Get a deep copy of the data blob
  * in the form of an NDArray of default storage type.
@@ -966,6 +1019,10 @@ MXNET_DLL int MXNDArrayGetGradState(NDArrayHandle handle, int *out);
  */
 MXNET_DLL int MXListFunctions(mx_uint *out_size,
                               FunctionHandle **out_array);
+
+MXNET_DLL int MXListFunctions64(mx_int64 *out_size,
+                                FunctionHandle **out_array);
+
 /*!
  * \brief get the function handle by name
  * \param name the name of the function
@@ -1233,6 +1290,10 @@ MXNET_DLL int MXInvokeCachedOpEx(CachedOpHandle handle,
  */
 MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
                                const char ***out_array);
+
+MXNET_DLL int MXListAllOpNames64(mx_int64 *out_size,
+                                 const char ***out_array);
+
 /*!
  * \brief list all the available AtomicSymbolEntry
  * \param out_size the size of returned array
@@ -1242,6 +1303,9 @@ MXNET_DLL int MXListAllOpNames(mx_uint *out_size,
 MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
                                                AtomicSymbolCreator **out_array);
 
+MXNET_DLL int MXSymbolListAtomicSymbolCreators64(mx_int64 *out_size,
+                                                 AtomicSymbolCreator **out_array);
+
 /*!
  * \brief Get the name of an atomic symbol.
  * \param creator the AtomicSymbolCreator.
@@ -1454,6 +1518,11 @@ MXNET_DLL int MXSymbolListAttrShallow(SymbolHandle symbol,
 MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
                                     mx_uint *out_size,
                                     const char ***out_str_array);
+
+MXNET_DLL int MXSymbolListArguments64(SymbolHandle symbol,
+                                      size_t *out_size,
+                                      const char ***out_str_array);
+
 /*!
  * \brief List returns in the symbol.
  * \param symbol the symbol
@@ -1465,6 +1534,10 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
                                   mx_uint *out_size,
                                   const char ***out_str_array);
 
+MXNET_DLL int MXSymbolListOutputs64(SymbolHandle symbol,
+                                    size_t *out_size,
+                                    const char ***out_str_array);
+
 /*!
  * \brief Get number of outputs of the symbol.
  * \param symbol The symbol
@@ -1472,7 +1545,7 @@ MXNET_DLL int MXSymbolListOutputs(SymbolHandle symbol,
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXSymbolGetNumOutputs(SymbolHandle symbol,
-                                     mx_uint *output_count);
+                                    mx_uint *output_count);
 
 /*!
  * \brief Get a symbol that contains all the internals.
@@ -1511,6 +1584,11 @@ MXNET_DLL int MXSymbolGetOutput(SymbolHandle symbol,
 MXNET_DLL int MXSymbolListAuxiliaryStates(SymbolHandle symbol,
                                           mx_uint *out_size,
                                           const char ***out_str_array);
+
+MXNET_DLL int MXSymbolListAuxiliaryStates64(SymbolHandle symbol,
+                                            size_t *out_size,
+                                            const char ***out_str_array);
+
 /*!
  * \brief Compose the symbol on other symbols.
  *
@@ -1582,6 +1660,22 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym,
                                  const mx_uint ***aux_shape_data,
                                  int *complete);
 
+MXNET_DLL int MXSymbolInferShape64(SymbolHandle sym,
+                                   mx_uint num_args,
+                                   const char** keys,
+                                   const mx_int64 *arg_ind_ptr,
+                                   const mx_int64 *arg_shape_data,
+                                   size_t *in_shape_size,
+                                   const int **in_shape_ndim,
+                                   const mx_int64 ***in_shape_data,
+                                   size_t *out_shape_size,
+                                   const int **out_shape_ndim,
+                                   const mx_int64 ***out_shape_data,
+                                   size_t *aux_shape_size,
+                                   const int **aux_shape_ndim,
+                                   const mx_int64 ***aux_shape_data,
+                                   int *complete);
+
 /*!
  * \brief infer shape of unknown input shapes given the known one.
  *  The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
@@ -1619,6 +1713,23 @@ MXNET_DLL int MXSymbolInferShapeEx(SymbolHandle sym,
                                    const int **aux_shape_ndim,
                                    const int ***aux_shape_data,
                                    int *complete);
+
+MXNET_DLL int MXSymbolInferShapeEx64(SymbolHandle sym,
+                                     mx_uint num_args,
+                                     const char** keys,
+                                     const mx_int64 *arg_ind_ptr,
+                                     const mx_int64 *arg_shape_data,
+                                     size_t *in_shape_size,
+                                     const int **in_shape_ndim,
+                                     const mx_int64 ***in_shape_data,
+                                     size_t *out_shape_size,
+                                     const int **out_shape_ndim,
+                                     const mx_int64 ***out_shape_data,
+                                     size_t *aux_shape_size,
+                                     const int **aux_shape_ndim,
+                                     const mx_int64 ***aux_shape_data,
+                                     int *complete);
+
 /*!
  * \brief DEPRECATED. Use MXSymbolInferShapePartialEx instead.
  * partially infer shape of unknown input shapes given the known one.
@@ -1660,6 +1771,21 @@ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym,
                                         const mx_uint ***aux_shape_data,
                                         int *complete);
 
+MXNET_DLL int MXSymbolInferShapePartial64(SymbolHandle sym,
+                                          mx_uint num_args,
+                                          const char** keys,
+                                          const mx_int64 *arg_ind_ptr,
+                                          const mx_int64 *arg_shape_data,
+                                          size_t *in_shape_size,
+                                          const int **in_shape_ndim,
+                                          const mx_int64 ***in_shape_data,
+                                          size_t *out_shape_size,
+                                          const int **out_shape_ndim,
+                                          const mx_int64 ***out_shape_data,
+                                          size_t *aux_shape_size,
+                                          const int **aux_shape_ndim,
+                                          const mx_int64 ***aux_shape_data,
+                                          int *complete);
 
 /*!
  * \brief partially infer shape of unknown input shapes given the known one.
@@ -1701,6 +1827,22 @@ MXNET_DLL int MXSymbolInferShapePartialEx(SymbolHandle sym,
                                           const int ***aux_shape_data,
                                           int *complete);
 
+MXNET_DLL int MXSymbolInferShapePartialEx64(SymbolHandle sym,
+                                            mx_uint num_args,
+                                            const char** keys,
+                                            const mx_int64 *arg_ind_ptr,
+                                            const mx_int64 *arg_shape_data,
+                                            size_t *in_shape_size,
+                                            const int **in_shape_ndim,
+                                            const mx_int64 ***in_shape_data,
+                                            size_t *out_shape_size,
+                                            const int **out_shape_ndim,
+                                            const mx_int64 ***out_shape_data,
+                                            size_t *aux_shape_size,
+                                            const int **aux_shape_ndim,
+                                            const mx_int64 ***aux_shape_data,
+                                            int *complete);
+
 /*!
  * \brief infer type of unknown input types given the known one.
  *  The types are packed into a CSR matrix represented by arg_ind_ptr and arg_type_data
diff --git a/include/mxnet/c_predict_api.h b/include/mxnet/c_predict_api.h
index 18bec62..c79baa4 100644
--- a/include/mxnet/c_predict_api.h
+++ b/include/mxnet/c_predict_api.h
@@ -42,7 +42,9 @@ extern "C" {
 #endif
 
 /*! \brief manually define unsigned int */
-typedef unsigned int mx_uint;
+typedef uint32_t mx_uint;
+/*! \brief manually define 64-bit int */
+typedef int64_t mx_int64;
 /*! \brief manually define float */
 typedef float mx_float;
 /*! \brief handle to Predictor */
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 848c36f..dd5fcf0 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -215,6 +215,7 @@ _LIB = _load_lib()
 # type definitions
 mx_int = ctypes.c_int
 mx_uint = ctypes.c_uint
+mx_int64 = ctypes.c_int64
 mx_float = ctypes.c_float
 mx_float_p = ctypes.POINTER(mx_float)
 mx_real_t = _np.float32
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 171ba0a5..5f03c65 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -33,11 +33,13 @@ import ctypes
 import warnings
 import operator
 from functools import reduce # pylint: disable=redefined-builtin
+import sys
 import numpy as np
 from ..base import _LIB, numeric_types, integer_types
 from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
-from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int
+from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64
 from ..base import ctypes2buffer
+from ..runtime import Features
 from ..context import Context, current_context
 from . import _internal
 from . import op
@@ -105,6 +107,14 @@ _NDARRAY_UNSUPPORTED_INDEXING = -1
 _NDARRAY_BASIC_INDEXING = 0
 _NDARRAY_ADVANCED_INDEXING = 1
 
+# Caching whether MXNet was built with INT64 support or not
+_INT64_TENSOR_SIZE_ENABLED = None
+
+def _int64_enabled():
+    global _INT64_TENSOR_SIZE_ENABLED
+    if _INT64_TENSOR_SIZE_ENABLED is None:
+        _INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE')
+    return _INT64_TENSOR_SIZE_ENABLED
 
 def _new_empty_handle():
     """Returns a new empty handle.
@@ -132,14 +142,24 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
         A new empty `NDArray` handle.
     """
     hdl = NDArrayHandle()
-    check_call(_LIB.MXNDArrayCreateEx(
-        c_array_buf(mx_uint, native_array('I', shape)),
-        mx_uint(len(shape)),
-        ctypes.c_int(ctx.device_typeid),
-        ctypes.c_int(ctx.device_id),
-        ctypes.c_int(int(delay_alloc)),
-        ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
-        ctypes.byref(hdl)))
+    if sys.version_info[0] > 2 and _int64_enabled():
+        check_call(_LIB.MXNDArrayCreateEx64(
+            c_array_buf(mx_int64, native_array('q', shape)),
+            ctypes.c_int(len(shape)),
+            ctypes.c_int(ctx.device_typeid),
+            ctypes.c_int(ctx.device_id),
+            ctypes.c_int(int(delay_alloc)),
+            ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
+            ctypes.byref(hdl)))
+    else:
+        check_call(_LIB.MXNDArrayCreateEx(
+            c_array_buf(mx_uint, native_array('I', shape)),
+            mx_uint(len(shape)),
+            ctypes.c_int(ctx.device_typeid),
+            ctypes.c_int(ctx.device_id),
+            ctypes.c_int(int(delay_alloc)),
+            ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
+            ctypes.byref(hdl)))
     return hdl
 
 
@@ -2218,9 +2238,14 @@ fixed-size items.
         (2L, 3L, 4L)
         """
         ndim = mx_int()
-        pdata = ctypes.POINTER(mx_int)()
-        check_call(_LIB.MXNDArrayGetShapeEx(
-            self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
+        if _int64_enabled():
+            pdata = ctypes.POINTER(mx_int64)()
+            check_call(_LIB.MXNDArrayGetShapeEx64(
+                self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
+        else:
+            pdata = ctypes.POINTER(mx_int)()
+            check_call(_LIB.MXNDArrayGetShapeEx(
+                self.handle, ctypes.byref(ndim), ctypes.byref(pdata)))
         if ndim.value == -1:
             return None
         else:
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 542b379..3ac4417 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -29,17 +29,17 @@ from array import array
 import ctypes
 import warnings
 from numbers import Number
-
+import sys
 import numpy as _numpy  # pylint: disable=relative-import
 
 from ..attribute import AttrScope
 from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array
-from ..base import mx_uint, py_str, string_types, integer_types, mx_int
+from ..base import mx_uint, py_str, string_types, integer_types, mx_int, mx_int64
 from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
 from ..base import check_call, MXNetError, NotImplementedForSymbol
 from ..context import Context, current_context
 from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP
-from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
+from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _int64_enabled
 from ..ndarray import _ndarray_cls
 from ..executor import Executor
 from . import _internal
@@ -1207,34 +1207,59 @@ class Symbol(SymbolBase):
             keys = c_str_array(str_keys)
         arg_shape_size = mx_uint()
         arg_shape_ndim = ctypes.POINTER(mx_int)()
-        arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
         out_shape_size = mx_uint()
         out_shape_ndim = ctypes.POINTER(mx_int)()
-        out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
         aux_shape_size = mx_uint()
         aux_shape_ndim = ctypes.POINTER(mx_int)()
-        aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
         complete = ctypes.c_int()
-        if partial:
-            infer_func = _LIB.MXSymbolInferShapePartialEx
+        if sys.version_info[0] > 2 and _int64_enabled():
+            arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
+            out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
+            aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int64))()
+            if partial:
+                infer_func = _LIB.MXSymbolInferShapePartialEx64
+            else:
+                infer_func = _LIB.MXSymbolInferShapeEx64
+            check_call(infer_func(
+                self.handle,
+                mx_uint(len(indptr) - 1),
+                keys,
+                c_array_buf(mx_int64, array('q', indptr)),
+                c_array_buf(mx_int64, array('q', sdata)),
+                ctypes.byref(arg_shape_size),
+                ctypes.byref(arg_shape_ndim),
+                ctypes.byref(arg_shape_data),
+                ctypes.byref(out_shape_size),
+                ctypes.byref(out_shape_ndim),
+                ctypes.byref(out_shape_data),
+                ctypes.byref(aux_shape_size),
+                ctypes.byref(aux_shape_ndim),
+                ctypes.byref(aux_shape_data),
+                ctypes.byref(complete)))
         else:
-            infer_func = _LIB.MXSymbolInferShapeEx
-        check_call(infer_func(
-            self.handle,
-            mx_uint(len(indptr) - 1),
-            keys,
-            c_array_buf(mx_uint, array('I', indptr)),
-            c_array_buf(mx_int, array('i', sdata)),
-            ctypes.byref(arg_shape_size),
-            ctypes.byref(arg_shape_ndim),
-            ctypes.byref(arg_shape_data),
-            ctypes.byref(out_shape_size),
-            ctypes.byref(out_shape_ndim),
-            ctypes.byref(out_shape_data),
-            ctypes.byref(aux_shape_size),
-            ctypes.byref(aux_shape_ndim),
-            ctypes.byref(aux_shape_data),
-            ctypes.byref(complete)))
+            arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
+            out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
+            aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
+            if partial:
+                infer_func = _LIB.MXSymbolInferShapePartialEx
+            else:
+                infer_func = _LIB.MXSymbolInferShapeEx
+            check_call(infer_func(
+                self.handle,
+                mx_uint(len(indptr) - 1),
+                keys,
+                c_array_buf(mx_uint, array('I', indptr)),
+                c_array_buf(mx_int, array('i', sdata)),
+                ctypes.byref(arg_shape_size),
+                ctypes.byref(arg_shape_ndim),
+                ctypes.byref(arg_shape_data),
+                ctypes.byref(out_shape_size),
+                ctypes.byref(out_shape_ndim),
+                ctypes.byref(out_shape_data),
+                ctypes.byref(aux_shape_size),
+                ctypes.byref(aux_shape_ndim),
+                ctypes.byref(aux_shape_data),
+                ctypes.byref(complete)))
         if complete.value != 0:
             arg_shapes = [tuple(arg_shape_data[i][:arg_shape_ndim[i]])
                           if arg_shape_ndim[i] >= 0 else None
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 1764845..c2b80b3 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -67,7 +67,7 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
                                    const char ***arg_type_infos,
                                    const char ***arg_descriptions,
                                    const char **return_type) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
 
   API_BEGIN();
   *name = e->name.c_str();
@@ -189,6 +189,19 @@ int MXNDArrayCreateNone(NDArrayHandle *out) {
   API_END();
 }
 
+template<typename DataType, typename dimtype>
+void CreateNDArray(const DataType* shape,
+                   dimtype ndim,
+                   int dev_type,
+                   int dev_id,
+                   int delay_alloc,
+                   int dtype,
+                   NDArrayHandle* out) {
+  *out = new NDArray(mxnet::TShape(shape, shape + ndim),
+                     Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
+                     delay_alloc != 0, dtype);
+}
+
 int MXNDArrayCreate(const mx_uint *shape,
                     mx_uint ndim,
                     int dev_type,
@@ -196,41 +209,48 @@ int MXNDArrayCreate(const mx_uint *shape,
                     int delay_alloc,
                     NDArrayHandle *out) {
   API_BEGIN();
-  *out = new NDArray(
-      mxnet::TShape(shape, shape + ndim),
-      Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
-      delay_alloc != 0);
+  *out = new NDArray(mxnet::TShape(shape, shape + ndim),
+                     Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
+                     delay_alloc != 0);
+  API_END();
+}
+
+int MXNDArrayCreateEx64(const mx_int64 *shape,
+                        int ndim,
+                        int dev_type,
+                        int dev_id,
+                        int delay_alloc,
+                        int dtype,
+                        NDArrayHandle *out) {
+  API_BEGIN();
+  CreateNDArray<mx_int64, int>(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out);
   API_END();
 }
 
 int MXNDArrayCreateEx(const mx_uint *shape,
-                    mx_uint ndim,
-                    int dev_type,
-                    int dev_id,
-                    int delay_alloc,
-                    int dtype,
-                    NDArrayHandle *out) {
+                      mx_uint ndim,
+                      int dev_type,
+                      int dev_id,
+                      int delay_alloc,
+                      int dtype,
+                      NDArrayHandle *out) {
   API_BEGIN();
-  *out = new NDArray(
-      mxnet::TShape(shape, shape + ndim),
-      Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
-      delay_alloc != 0,
-      dtype);
+  CreateNDArray<mx_uint, mx_uint>(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out);
   API_END();
 }
 
 int MXNDArrayCreateSparseEx(int storage_type,
-                    const mx_uint *shape,
-                    mx_uint ndim,
-                    int dev_type,
-                    int dev_id,
-                    int delay_alloc,
-                    int dtype,
-                    mx_uint num_aux,
-                    int *aux_type,
-                    mx_uint *aux_ndims,
-                    const mx_uint *aux_shape,
-                    NDArrayHandle *out) {
+                            const mx_uint *shape,
+                            mx_uint ndim,
+                            int dev_type,
+                            int dev_id,
+                            int delay_alloc,
+                            int dtype,
+                            mx_uint num_aux,
+                            int *aux_type,
+                            mx_uint *aux_ndims,
+                            const mx_uint *aux_shape,
+                            NDArrayHandle *out) {
   API_BEGIN();
   std::vector<int> aux_types;
   mxnet::ShapeVector aux_shapes;
@@ -269,7 +289,7 @@ int MXNDArrayLoadFromRawBytes(const void *buf,
 int MXNDArraySaveRawBytes(NDArrayHandle handle,
                           size_t *out_size,
                           const char **out_buf) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   ret->ret_str.resize(0);
   dmlc::MemoryStringStream strm(&ret->ret_str);
@@ -365,7 +385,7 @@ int MXNDArrayLoad(const char* fname,
                   NDArrayHandle** out_arr,
                   mx_uint *out_name_size,
                   const char*** out_names) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   ret->ret_vec_str.clear();
   API_BEGIN();
   std::vector<NDArray> data;
@@ -397,7 +417,7 @@ int MXNDArrayLoadFromBuffer(const void *ndarray_buffer,
                             NDArrayHandle** out_arr,
                             mx_uint *out_name_size,
                             const char*** out_names) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   ret->ret_vec_str.clear();
   API_BEGIN();
   CHECK_NOTNULL(ndarray_buffer);
@@ -521,7 +541,7 @@ int MXNDArrayGetStorageType(NDArrayHandle handle,
 int MXNDArrayGetShape(NDArrayHandle handle,
                       mx_uint *out_dim,
                       const mx_uint **out_pdata) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   NDArray *arr = static_cast<NDArray*>(handle);
   if (!arr->is_none()) {
@@ -537,12 +557,10 @@ int MXNDArrayGetShape(NDArrayHandle handle,
   API_END();
 }
 
-int MXNDArrayGetShapeEx(NDArrayHandle handle,
-                        int *out_dim,
-                        const int **out_pdata) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
-  API_BEGIN();
-  NDArray *arr = static_cast<NDArray*>(handle);
+template<typename dtype>
+inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim,
+                     MXAPIThreadLocalEntry<dtype>* ret) {
+  NDArray* arr = static_cast<NDArray*>(handle);
   if (!arr->is_none()) {
     mxnet::TShape s = arr->shape();
     if (!Imperative::Get()->is_np_shape()) {
@@ -550,7 +568,7 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle,
     }
     *out_dim = s.ndim();
     if (s.ndim() >= 0) {
-      std::vector<int> &buffer = ret->arg_shape_buffer_ex;
+      std::vector<dtype> &buffer = ret->arg_shape_buffer_ex;
       buffer.resize(s.ndim());
       mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data());
       *out_pdata = buffer.data();
@@ -562,6 +580,23 @@ int MXNDArrayGetShapeEx(NDArrayHandle handle,
       *out_dim = 0;
     }
   }
+}
+
+int MXNDArrayGetShapeEx(NDArrayHandle handle,
+                        int *out_dim,
+                        const int **out_pdata) {
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
+  API_BEGIN();
+  GetShape<int>(handle, out_pdata, out_dim, ret);
+  API_END();
+}
+
+int MXNDArrayGetShapeEx64(NDArrayHandle handle,
+                          int *out_dim,
+                          const mx_int64 **out_pdata) {
+  MXAPIThreadLocalEntry<int64_t> *ret = MXAPIThreadLocalStore<int64_t>::Get();
+  API_BEGIN();
+  GetShape<mx_int64>(handle, out_pdata, out_dim, ret);
   API_END();
 }
 
diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index 233acc8..93fcff0 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -57,6 +57,7 @@
 using namespace mxnet;
 
 /*! \brief entry to to easily hold returning information */
+template<typename dtype = int>
 struct MXAPIThreadLocalEntry {
   /*! \brief result holder for returning string */
   std::string ret_str;
@@ -81,11 +82,11 @@ struct MXAPIThreadLocalEntry {
   /*! \brief result holder for returning shape pointer */
   std::vector<const mx_uint*> arg_shape_data, out_shape_data, aux_shape_data;
   /*! \brief result holder for returning shape pointer */
-  std::vector<const int*> arg_shape_data_ex, out_shape_data_ex, 
aux_shape_data_ex;
+  std::vector<const dtype*> arg_shape_data_ex, out_shape_data_ex, 
aux_shape_data_ex;
   /*! \brief uint32_t buffer for returning shape pointer */
   std::vector<uint32_t> arg_shape_buffer, out_shape_buffer, aux_shape_buffer;
   /*! \brief uint32_t buffer for returning shape pointer */
-  std::vector<int> arg_shape_buffer_ex, out_shape_buffer_ex, 
aux_shape_buffer_ex;
+  std::vector<dtype> arg_shape_buffer_ex, out_shape_buffer_ex, 
aux_shape_buffer_ex;
   /*! \brief bool buffer */
   std::vector<bool> save_inputs, save_outputs;
   // DEPRECATED. Use SetupShapeArrayReturnWithBufferEx instead.
@@ -111,8 +112,8 @@ struct MXAPIThreadLocalEntry {
   inline static void SetupShapeArrayReturnWithBufferEx(
       const mxnet::ShapeVector &shapes,
       std::vector<int> *ndim,
-      std::vector<const int*> *data,
-      std::vector<int> *buffer) {
+      std::vector<const dtype*> *data,
+      std::vector<dtype> *buffer) {
     ndim->resize(shapes.size());
     data->resize(shapes.size());
     size_t size = 0;
@@ -122,7 +123,7 @@ struct MXAPIThreadLocalEntry {
       }
     }
     buffer->resize(size);
-    int *ptr = buffer->data();
+    dtype* ptr = buffer->data();
     for (size_t i = 0; i < shapes.size(); ++i) {
       ndim->at(i) = shapes[i].ndim();
       data->at(i) = ptr;
@@ -134,7 +135,8 @@ struct MXAPIThreadLocalEntry {
 };
 
 // define the threadlocal store.
-typedef dmlc::ThreadLocalStore<MXAPIThreadLocalEntry> MXAPIThreadLocalStore;
+template<typename dtype = int>
+using MXAPIThreadLocalStore = 
dmlc::ThreadLocalStore<MXAPIThreadLocalEntry<dtype>>;
 
 namespace mxnet {
 // copy attributes from inferred vector back to the vector of each type.
diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc
index ebe3f17..31b74b5 100644
--- a/src/c_api/c_api_executor.cc
+++ b/src/c_api/c_api_executor.cc
@@ -32,7 +32,7 @@
 
 int MXExecutorPrint(ExecutorHandle handle, const char **out_str) {
   Executor *exec = static_cast<Executor*>(handle);
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   std::ostringstream os;
   exec->Print(os);
@@ -78,7 +78,7 @@ int MXExecutorBackwardEx(ExecutorHandle handle,
 int MXExecutorOutputs(ExecutorHandle handle,
                       mx_uint *out_size,
                       NDArrayHandle **out) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   Executor *exec = static_cast<Executor*>(handle);
   std::vector<NDArray> heads = exec->outputs();
@@ -252,7 +252,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle,
                          NDArrayHandle** aux_states,
                          ExecutorHandle shared_exec_handle,
                          ExecutorHandle* out) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(symbol_handle);
 
@@ -586,7 +586,7 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle,
                            NDArrayHandle** aux_states,
                            ExecutorHandle shared_exec_handle,
                            ExecutorHandle* out) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(symbol_handle);
 
@@ -870,7 +870,7 @@ int MXExecutorReshape(int partial_shaping,
                       ExecutorHandle *out) {
   Executor* new_exec = nullptr;
 
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   *out = nullptr;  // ensure we can know whether to free executor on early 
abort
   // create shape map for in_args and aux_states
@@ -961,7 +961,7 @@ int MXExecutorReshapeEx(int partial_shaping,
                         ExecutorHandle *out) {
   Executor* new_exec = nullptr;
 
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   *out = nullptr;  // ensure we can know whether to free executor on early 
abort
   // create shape map for in_args and aux_states
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index c9c6000..4546659 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -87,7 +87,7 @@ void MXImperativeInvokeImpl(AtomicSymbolCreator creator,
                             const char **param_keys,
                             const char **param_vals) {
   const nnvm::Op* op = static_cast<nnvm::Op*>(creator);
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
 
   nnvm::NodeAttrs attrs = imperative::ParseAttrs(op, num_inputs, num_params,
                                                  param_keys, param_vals);
@@ -138,7 +138,7 @@ int MXImperativeInvokeEx(AtomicSymbolCreator creator,
                          const char **param_keys,
                          const char **param_vals,
                          const int **out_stypes) {  // outputs storage types
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   MXImperativeInvokeImpl(creator, num_inputs, inputs, num_outputs, outputs,
                          num_params, param_keys, param_vals);
@@ -194,7 +194,7 @@ int MXInvokeCachedOp(CachedOpHandle handle,
                      NDArrayHandle *inputs,
                      int *num_outputs,
                      NDArrayHandle **outputs) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
 
   API_BEGIN();
   CachedOpPtr op = *static_cast<CachedOpPtr*>(handle);
@@ -238,7 +238,7 @@ int MXInvokeCachedOpEx(CachedOpHandle handle,
                        int *num_outputs,
                        NDArrayHandle **outputs,
                        const int **out_stypes) {  // outputs storage types
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   int err = MXInvokeCachedOp(handle, num_inputs, inputs, num_outputs, outputs);
   if (err != 0) return err;
   API_BEGIN();
@@ -331,7 +331,7 @@ int MXAutogradBackwardEx(mx_uint num_output,
                          int is_train,
                          NDArrayHandle **grad_handles,
                          int **grad_stypes) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
 
   std::vector<NDArray*> outputs, ograds, variables;
diff --git a/src/c_api/c_api_profile.cc b/src/c_api/c_api_profile.cc
index cec7028..5eb219a 100644
--- a/src/c_api/c_api_profile.cc
+++ b/src/c_api/c_api_profile.cc
@@ -312,7 +312,7 @@ int MXAggregateProfileStatsPrint(const char **out_str, int 
reset) {
 
 int MXAggregateProfileStatsPrintEx(const char **out_str, int reset, int 
format, int sort_by,
                                   int ascending) {
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
     CHECK_NOTNULL(out_str);
     profiler::Profiler *profiler = profiler::Profiler::Get();
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 020c0d1..e8d59d9 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -98,7 +98,7 @@ int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
                                 const char **return_type) {
   static auto& map_key_var_args = 
nnvm::Op::GetAttr<std::string>("key_var_num_args");
   const Op* op = static_cast<Op*>(creator);
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   ret->ret_str.resize(0);
 
   if (map_key_var_args.count(op) != 0) {
@@ -203,7 +203,7 @@ int MXSymbolGetAttr(SymbolHandle symbol,
                     const char** out,
                     int* success) {
   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   if (s->GetAttr(key, &(ret->ret_str))) {
     *out = (ret->ret_str).c_str();
@@ -251,7 +251,7 @@ int MXSymbolListAttr(SymbolHandle symbol,
                      mx_uint *out_size,
                      const char*** out) {
   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   std::vector<std::tuple<std::string, std::string, std::string> > attr =
       s->ListAttrsRecursive();
@@ -281,7 +281,7 @@ int MXSymbolListAttrShallow(SymbolHandle symbol,
                             mx_uint *out_size,
                             const char*** out) {
   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   std::unordered_map<std::string, std::string> attr =
       s->ListAttrs(static_cast<nnvm::Symbol::ListAttrOption>(1));  // NOLINT(*)
@@ -360,7 +360,7 @@ int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle 
**input_arr, int *inp
   std::vector<nnvm::Symbol *> input_syms = mxnet::GetInputSymbols(*s);
   *input_size = input_syms.size();
 
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   ret->ret_handles.clear();
   ret->ret_handles.reserve(*input_size);
   for (int i = 0; i < *input_size; ++i) 
ret->ret_handles.push_back(input_syms[i]);
@@ -405,7 +405,7 @@ int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle 
**input_symbols,
     }
     *input_size = input_syms.size();
 
-    MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+    MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
     ret->ret_handles.clear();
     ret->ret_handles.reserve(*input_size);
     for (int i = 0; i < *input_size; ++i) 
ret->ret_handles.push_back(input_syms[i]);
@@ -464,7 +464,7 @@ int MXSymbolSaveToFile(SymbolHandle symbol, const char 
*fname) {
 
 int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json) {
   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(symbol);
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   ret->ret_str = nnvm::pass::SaveJSON(Symbol2Graph(*s));
   *out_json = ret->ret_str.c_str();
@@ -528,7 +528,7 @@ int MXSymbolInferShape(SymbolHandle sym,
                        const mx_uint ***aux_shape_data,
                        int *complete) {
   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   nnvm::Graph g = Symbol2Graph(*s);
   mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), 
mxnet::TShape());
@@ -565,11 +565,11 @@ int MXSymbolInferShape(SymbolHandle sym,
            &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes));
 
   // copy data back
-  MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBuffer(ret->arg_shapes,
+  MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->arg_shapes,
       &(ret->arg_shape_ndim), &(ret->arg_shape_data), 
&(ret->arg_shape_buffer));
-  MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBuffer(ret->out_shapes,
+  MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->out_shapes,
       &(ret->out_shape_ndim), &(ret->out_shape_data), 
&(ret->out_shape_buffer));
-  MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBuffer(ret->aux_shapes,
+  MXAPIThreadLocalEntry<>::SetupShapeArrayReturnWithBuffer(ret->aux_shapes,
       &(ret->aux_shape_ndim), &(ret->aux_shape_data), 
&(ret->aux_shape_buffer));
   *in_shape_size = static_cast<mx_uint>(ret->arg_shapes.size());
   *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim);
@@ -585,76 +585,149 @@ int MXSymbolInferShape(SymbolHandle sym,
   API_END();
 }
 
-int MXSymbolInferShapeEx(SymbolHandle sym,
-                         mx_uint num_args,
-                         const char** keys,
-                         const mx_uint *arg_ind_ptr,
-                         const int *arg_shape_data,
-                         mx_uint *in_shape_size,
-                         const int **in_shape_ndim,
-                         const int ***in_shape_data,
-                         mx_uint *out_shape_size,
-                         const int **out_shape_ndim,
-                         const int ***out_shape_data,
-                         mx_uint *aux_shape_size,
-                         const int **aux_shape_ndim,
-                         const int ***aux_shape_data,
-                         int *complete) {
-  nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
-  API_BEGIN();
+template<typename dtype, typename stype, typename itype>
+inline void SymbolInferShape(const char** keys,
+                             mx_uint num_args,
+                             const dtype* arg_shape_data,
+                             const itype* arg_ind_ptr,
+                             const int** in_shape_ndim,
+                             const dtype*** in_shape_data,
+                             const int** out_shape_ndim,
+                             const dtype*** out_shape_data,
+                             const int** aux_shape_ndim,
+                             const dtype*** aux_shape_data,
+                             nnvm::Symbol* s,
+                             MXAPIThreadLocalEntry<dtype>* ret,
+                             stype* in_shape_size,
+                             stype* out_shape_size,
+                             stype* aux_shape_size,
+                             int* complete) {
   nnvm::Graph g = Symbol2Graph(*s);
   mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), 
mxnet::TShape());
   if (keys == nullptr && num_args != 0) {
-    std::vector<uint32_t> read_only_args = 
mxnet::ReadOnlyArgIndices(g.indexed_graph());
+    std::vector < uint32_t > read_only_args = 
mxnet::ReadOnlyArgIndices(g.indexed_graph());
     CHECK_LE(num_args, read_only_args.size());
     for (mx_uint i = 0; i < num_args; ++i) {
-      arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(
-          arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]);
+      arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(arg_shape_data + 
arg_ind_ptr[i],
+                                                           arg_shape_data + 
arg_ind_ptr[i + 1]);
     }
   } else {
     std::unordered_map<std::string, mxnet::TShape> kwargs;
     for (mx_uint i = 0; i < num_args; ++i) {
-      kwargs[keys[i]] = mxnet::ShapeTypeCast(
-          arg_shape_data + arg_ind_ptr[i], arg_shape_data + arg_ind_ptr[i+1]);
+      kwargs[keys[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i],
+                                             arg_shape_data + arg_ind_ptr[i + 
1]);
     }
     mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_shapes, 
"InferShape");
   }
-
   try {
     g = mxnet::exec::InferShape(std::move(g), std::move(arg_shapes), 
"__shape__");
-  } catch (const mxnet::op::InferShapeError &err) {
+  } catch (const mxnet::op::InferShapeError& err) {
     throw dmlc::Error(err.msg);
   }
-
   // if use legacy shape definition, need to convert numpy shape to legacy 
shape
   mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
   if (!Imperative::Get()->is_np_shape()) {
     common::ConvertToLegacyShape(&shapes);
   }
-
   // copy back
-  CopyAttr(g.indexed_graph(), shapes,
-           &(ret->arg_shapes), &(ret->out_shapes), &(ret->aux_shapes));
-
+  CopyAttr(g.indexed_graph(), shapes, &(ret->arg_shapes), &(ret->out_shapes), 
&(ret->aux_shapes));
   // copy data back
-  MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes,
-      &(ret->arg_shape_ndim_ex), &(ret->arg_shape_data_ex), 
&(ret->arg_shape_buffer_ex));
-  MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->out_shapes,
-      &(ret->out_shape_ndim_ex), &(ret->out_shape_data_ex), 
&(ret->out_shape_buffer_ex));
-  MXAPIThreadLocalEntry::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes,
-      &(ret->aux_shape_ndim_ex), &(ret->aux_shape_data_ex), 
&(ret->aux_shape_buffer_ex));
-  *in_shape_size = static_cast<mx_uint>(ret->arg_shapes.size());
+  
MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->arg_shapes,
+                                                                  
&(ret->arg_shape_ndim_ex),
+                                                                  
&(ret->arg_shape_data_ex),
+                                                                  
&(ret->arg_shape_buffer_ex));
+  
MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->out_shapes,
+                                                                  
&(ret->out_shape_ndim_ex),
+                                                                  
&(ret->out_shape_data_ex),
+                                                                  
&(ret->out_shape_buffer_ex));
+  
MXAPIThreadLocalEntry<dtype>::SetupShapeArrayReturnWithBufferEx(ret->aux_shapes,
+                                                                  
&(ret->aux_shape_ndim_ex),
+                                                                  
&(ret->aux_shape_data_ex),
+                                                                  
&(ret->aux_shape_buffer_ex));
+  *in_shape_size = static_cast<stype>(ret->arg_shapes.size());
   *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim_ex);
   *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex);
-  *out_shape_size = static_cast<mx_uint>(ret->out_shapes.size());
+  *out_shape_size = static_cast<stype>(ret->out_shapes.size());
   *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim_ex);
   *out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex);
-  *aux_shape_size = static_cast<mx_uint>(ret->aux_shapes.size());
+  *aux_shape_size = static_cast<stype>(ret->aux_shapes.size());
   *aux_shape_ndim = dmlc::BeginPtr(ret->aux_shape_ndim_ex);
   *aux_shape_data = dmlc::BeginPtr(ret->aux_shape_data_ex);
   // mark complete
   *complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
+}
+
+int MXSymbolInferShapeEx(SymbolHandle sym,
+                         mx_uint num_args,
+                         const char** keys,
+                         const mx_uint *arg_ind_ptr,
+                         const int *arg_shape_data,
+                         mx_uint *in_shape_size,
+                         const int **in_shape_ndim,
+                         const int ***in_shape_data,
+                         mx_uint *out_shape_size,
+                         const int **out_shape_ndim,
+                         const int ***out_shape_data,
+                         mx_uint *aux_shape_size,
+                         const int **aux_shape_ndim,
+                         const int ***aux_shape_data,
+                         int *complete) {
+  nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
+  API_BEGIN();
+  SymbolInferShape<int, mx_uint, mx_uint>(keys,
+                                              num_args,
+                                              arg_shape_data,
+                                              arg_ind_ptr,
+                                              in_shape_ndim,
+                                              in_shape_data,
+                                              out_shape_ndim,
+                                              out_shape_data,
+                                              aux_shape_ndim,
+                                              aux_shape_data,
+                                              s,
+                                              ret,
+                                              in_shape_size,
+                                              out_shape_size,
+                                              aux_shape_size,
+                                              complete);
+  API_END();
+}
+
+int MXSymbolInferShapeEx64(SymbolHandle sym,
+                           mx_uint num_args,
+                           const char** keys,
+                           const int64_t *arg_ind_ptr,
+                           const int64_t *arg_shape_data,
+                           size_t *in_shape_size,
+                           const int **in_shape_ndim,
+                           const int64_t ***in_shape_data,
+                           size_t *out_shape_size,
+                           const int **out_shape_ndim,
+                           const int64_t ***out_shape_data,
+                           size_t *aux_shape_size,
+                           const int **aux_shape_ndim,
+                           const int64_t ***aux_shape_data,
+                           int *complete) {
+  nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
+  MXAPIThreadLocalEntry<int64_t> *ret = MXAPIThreadLocalStore<int64_t>::Get();
+  API_BEGIN();
+  SymbolInferShape<int64_t, size_t, int64_t>(keys,
+                                                 num_args,
+                                                 arg_shape_data,
+                                                 arg_ind_ptr,
+                                                 in_shape_ndim,
+                                                 in_shape_data,
+                                                 out_shape_ndim,
+                                                 out_shape_data,
+                                                 aux_shape_ndim,
+                                                 aux_shape_data,
+                                                 s,
+                                                 ret,
+                                                 in_shape_size,
+                                                 out_shape_size,
+                                                 aux_shape_size,
+                                                 complete);
   API_END();
 }
 
@@ -673,7 +746,7 @@ int MXSymbolInferShapePartial(SymbolHandle sym,
                               const mx_uint **aux_shape_ndim,
                               const mx_uint ***aux_shape_data,
                               int *complete) {
-  int succ;
+  int succ = 0;
   *complete = 1;
   return MXSymbolInferShape(sym, num_args, keys,
                             arg_ind_ptr, arg_shape_data,
@@ -698,7 +771,7 @@ int MXSymbolInferShapePartialEx(SymbolHandle sym,
                                 const int **aux_shape_ndim,
                                 const int ***aux_shape_data,
                                 int *complete) {
-  int succ;
+  int succ = 0;
   *complete = 1;
   return MXSymbolInferShapeEx(sym, num_args, keys,
                               arg_ind_ptr, arg_shape_data,
@@ -708,6 +781,31 @@ int MXSymbolInferShapePartialEx(SymbolHandle sym,
                               &succ);
 }
 
+int MXSymbolInferShapePartialEx64(SymbolHandle sym,
+                                  mx_uint num_args,
+                                  const char** keys,
+                                  const int64_t *arg_ind_ptr,
+                                  const int64_t *arg_shape_data,
+                                  size_t *in_shape_size,
+                                  const int **in_shape_ndim,
+                                  const int64_t ***in_shape_data,
+                                  size_t *out_shape_size,
+                                  const int **out_shape_ndim,
+                                  const int64_t ***out_shape_data,
+                                  size_t *aux_shape_size,
+                                  const int **aux_shape_ndim,
+                                  const int64_t ***aux_shape_data,
+                                  int *complete) {
+  int succ = 0;
+  *complete = 1;
+  return MXSymbolInferShapeEx64(sym, num_args, keys,
+                                arg_ind_ptr, arg_shape_data,
+                                in_shape_size, in_shape_ndim, in_shape_data,
+                                out_shape_size, out_shape_ndim, out_shape_data,
+                                aux_shape_size, aux_shape_ndim, aux_shape_data,
+                                &succ);
+}
+
 int MXSymbolInferType(SymbolHandle sym,
                       mx_uint num_args,
                       const char** keys,
@@ -720,7 +818,7 @@ int MXSymbolInferType(SymbolHandle sym,
                       const int **aux_type_data,
                       int *complete) {
   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   API_BEGIN();
   nnvm::Graph g = Symbol2Graph(*s);
   nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1);
@@ -764,7 +862,7 @@ int MXSymbolInferTypePartial(SymbolHandle sym,
                              mx_uint *aux_type_size,
                              const int **aux_type_data,
                              int *complete) {
-  int succ;
+  int succ = 0;
   *complete = 1;
   return MXSymbolInferType(sym, num_args, keys,
                             arg_type_data,
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index c00021c..e6a177e 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -48,7 +48,7 @@ OpStatePtr Imperative::InvokeOp(
   using namespace imperative;
   static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
   static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
 
   const nnvm::Op *op = attrs.op;
 
@@ -197,7 +197,7 @@ void Imperative::RecordOp(
     const OpStatePtr& state,
     std::vector<bool>* p_save_inputs,
     std::vector<bool>* p_save_outputs) {
-  MXAPIThreadLocalEntry *local_buff = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *local_buff = MXAPIThreadLocalStore<>::Get();
 
   for (auto output : outputs) {
     CHECK(AGInfo::IsNone(*output))
diff --git a/src/imperative/imperative_utils.h 
b/src/imperative/imperative_utils.h
index 477139f..b0476fc 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -104,7 +104,7 @@ inline void SetShapeType(const Context& ctx,
   static auto& infershape = 
nnvm::Op::GetAttr<mxnet::FInferShape>("FInferShape");
   static auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
   static auto& inferstorage = 
nnvm::Op::GetAttr<FInferStorageType>("FInferStorageType");
-  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
+  MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
   // infer shape
   mxnet::ShapeVector& in_shapes  = ret->arg_shapes;
   in_shapes.clear();
diff --git a/src/operator/tensor/matrix_op-inl.h 
b/src/operator/tensor/matrix_op-inl.h
index 96c86c4..611dd72 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -668,9 +668,9 @@ void SliceEx(const nnvm::NodeAttrs& attrs,
 
 template<int ndim>
 inline void GetIndexRange(const mxnet::TShape& dshape,
-                          const mxnet::Tuple<dmlc::optional<int>>& param_begin,
-                          const mxnet::Tuple<dmlc::optional<int>>& param_end,
-                          const mxnet::Tuple<dmlc::optional<int>>& param_step,
+                          const mxnet::Tuple<dmlc::optional<index_t>>& 
param_begin,
+                          const mxnet::Tuple<dmlc::optional<index_t>>& 
param_end,
+                          const mxnet::Tuple<dmlc::optional<index_t>>& 
param_step,
                           common::StaticArray<index_t, ndim>* begin,
                           common::StaticArray<index_t, ndim>* end,
                           common::StaticArray<index_t, ndim>* step) {
@@ -1033,8 +1033,8 @@ void SliceAssignOpForward(const nnvm::NodeAttrs& attrs,
 
 struct SliceAssignScalarParam : public dmlc::Parameter<SliceAssignScalarParam> 
{
   double scalar;
-  mxnet::Tuple<dmlc::optional<int>> begin, end;
-  mxnet::Tuple<dmlc::optional<int>> step;
+  mxnet::Tuple<dmlc::optional<index_t>> begin, end;
+  mxnet::Tuple<dmlc::optional<index_t>> step;
   DMLC_DECLARE_PARAMETER(SliceAssignScalarParam) {
     DMLC_DECLARE_FIELD(scalar)
     .set_default(0)
@@ -1044,7 +1044,7 @@ struct SliceAssignScalarParam : public 
dmlc::Parameter<SliceAssignScalarParam> {
     DMLC_DECLARE_FIELD(end)
     .describe("ending indices for the slice operation, supports negative 
indices.");
     DMLC_DECLARE_FIELD(step)
-    .set_default(mxnet::Tuple<dmlc::optional<int>>())
+    .set_default(mxnet::Tuple<dmlc::optional<index_t>>())
     .describe("step for the slice operation, supports negative values.");
   }
 };
@@ -1346,12 +1346,12 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
 inline void SliceLikeInferRanges(const mxnet::TShape& dshape,
                                  const mxnet::TShape& fshape,
                                  const mxnet::Tuple<int>& axes,
-                                 mxnet::Tuple<dmlc::optional<int>>* 
param_begin,
-                                 mxnet::Tuple<dmlc::optional<int>>* param_end,
-                                 mxnet::Tuple<dmlc::optional<int>>* 
param_step) {
-  std::vector<dmlc::optional<int>> pb(dshape.ndim());
-  std::vector<dmlc::optional<int>> pe(dshape.ndim());
-  std::vector<dmlc::optional<int>> ps(dshape.ndim());
+                                 mxnet::Tuple<dmlc::optional<index_t>>* 
param_begin,
+                                 mxnet::Tuple<dmlc::optional<index_t>>* 
param_end,
+                                 mxnet::Tuple<dmlc::optional<index_t>>* 
param_step) {
+  std::vector<dmlc::optional<index_t>> pb(dshape.ndim());
+  std::vector<dmlc::optional<index_t>> pe(dshape.ndim());
+  std::vector<dmlc::optional<index_t>> ps(dshape.ndim());
   if (axes.ndim() == 0) {
     for (int i = 0; i < dshape.ndim(); ++i) {
       pb[i] = 0;
@@ -1375,9 +1375,9 @@ inline void SliceLikeInferRanges(const mxnet::TShape& dshape,
       ps[axis] = 1;
     }
   }
-  *param_begin = mxnet::Tuple<dmlc::optional<int>>(pb.begin(), pb.end());
-  *param_end = mxnet::Tuple<dmlc::optional<int>>(pe.begin(), pe.end());
-  *param_step = mxnet::Tuple<dmlc::optional<int>>(ps.begin(), ps.end());
+  *param_begin = mxnet::Tuple<dmlc::optional<index_t>>(pb.begin(), pb.end());
+  *param_end = mxnet::Tuple<dmlc::optional<index_t>>(pe.begin(), pe.end());
+  *param_step = mxnet::Tuple<dmlc::optional<index_t>>(ps.begin(), ps.end());
 }
 
 template<typename xpu>
@@ -1396,9 +1396,9 @@ void SliceLikeForward(const nnvm::NodeAttrs& attrs,
   const TBlob& out = outputs[0];
   const mxnet::TShape& ishape = data.shape_;
   const mxnet::TShape& from_shape = inputs[1].shape_;
-  mxnet::Tuple<dmlc::optional<int>> param_begin;
-  mxnet::Tuple<dmlc::optional<int>> param_end;
-  mxnet::Tuple<dmlc::optional<int>> param_step;
+  mxnet::Tuple<dmlc::optional<index_t>> param_begin;
+  mxnet::Tuple<dmlc::optional<index_t>> param_end;
+  mxnet::Tuple<dmlc::optional<index_t>> param_step;
   SliceLikeInferRanges(ishape, from_shape, param.axes, &param_begin, 
&param_end, &param_step);
 
   MXNET_NDIM_SWITCH(data.ndim(), ndim, {
@@ -1444,9 +1444,9 @@ void SliceLikeBackward(const nnvm::NodeAttrs& attrs,
 
   const mxnet::TShape& ishape = ograd.shape_;
   const mxnet::TShape& from_shape = outputs[1].shape_;
-  mxnet::Tuple<dmlc::optional<int>> param_begin;
-  mxnet::Tuple<dmlc::optional<int>> param_end;
-  mxnet::Tuple<dmlc::optional<int>> param_step;
+  mxnet::Tuple<dmlc::optional<index_t>> param_begin;
+  mxnet::Tuple<dmlc::optional<index_t>> param_end;
+  mxnet::Tuple<dmlc::optional<index_t>> param_step;
   SliceLikeInferRanges(ishape, from_shape, param.axes, &param_begin, 
&param_end, &param_step);
 
   MXNET_NDIM_SWITCH(ograd.ndim(), ndim, {
diff --git a/src/operator/tensor/slice-inl.h b/src/operator/tensor/slice-inl.h
index 78a2bd8..7450e46 100644
--- a/src/operator/tensor/slice-inl.h
+++ b/src/operator/tensor/slice-inl.h
@@ -34,15 +34,15 @@ namespace mxnet {
 namespace op {
 
 struct SliceParam : public dmlc::Parameter<SliceParam> {
-  mxnet::Tuple<dmlc::optional<int>> begin, end;
-  mxnet::Tuple<dmlc::optional<int>> step;
+  mxnet::Tuple<dmlc::optional<index_t>> begin, end;
+  mxnet::Tuple<dmlc::optional<index_t>> step;
   DMLC_DECLARE_PARAMETER(SliceParam) {
     DMLC_DECLARE_FIELD(begin)
     .describe("starting indices for the slice operation, supports negative 
indices.");
     DMLC_DECLARE_FIELD(end)
     .describe("ending indices for the slice operation, supports negative 
indices.");
     DMLC_DECLARE_FIELD(step)
-    .set_default(mxnet::Tuple<dmlc::optional<int>>())
+    .set_default(mxnet::Tuple<dmlc::optional<index_t>>())
     .describe("step for the slice operation, supports negative values.");
   }
   bool operator==(const SliceParam& other) const {
diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index 0df481a..9dc29eb 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -24,7 +24,6 @@ from tests.python.unittest.common import with_seed
 # dimension constants
 MEDIUM_X = 10000
 LARGE_X = 100000000
-LARGE_Y = 50000000
 SMALL_Y = 50
 LARGE_SIZE = LARGE_X * SMALL_Y
 
diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py
new file mode 100644
index 0000000..8c030f5
--- /dev/null
+++ b/tests/nightly/test_large_vector.py
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d
+from mxnet import gluon, nd
+from tests.python.unittest.common import with_seed
+
+# dimension constants
+LARGE_X = 5000000000
+MEDIUM_X = 1000000000
+
+
+def test_slice():
+    a = nd.ones(LARGE_X)
+    res = nd.slice(a, begin=(LARGE_X - MEDIUM_X), end=LARGE_X)
+    assert res.shape[0] == MEDIUM_X
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()

Reply via email to