wkcn commented on a change in pull request #14733: [MXNET-1398] Enable zero-copy from numpy to MXNet NDArray
URL: https://github.com/apache/incubator-mxnet/pull/14733#discussion_r279600607
 
 

 ##########
 File path: python/mxnet/ndarray/ndarray.py
 ##########
 @@ -4115,3 +4115,108 @@ def from_dlpack(dlpack):
     # delete the deleter of the old dlpack
     ctypes.pythonapi.PyCapsule_SetDestructor(dlpack, None)
     return NDArray(handle=handle)
+
+class DLContext(ctypes.Structure):
+    _fields_ = [("device_type", ctypes.c_int),
+                ("device_id", ctypes.c_int)]
+
+
+class DLDataType(ctypes.Structure):
+    _fields_ = [("type_code", ctypes.c_uint8),
+                ("bits", ctypes.c_uint8),
+                ("lanes", ctypes.c_uint16)]
+    TYPE_MAP = {
+        "int32": (0, 32, 1),
+        "int64": (0, 64, 1),
+        "bool": (1, 1, 1),
+        "uint32": (1, 32, 1),
+        "uint64": (1, 64, 1),
+        "float32": (2, 32, 1),
+        "float64": (2, 64, 1),
+    }
+
+
+class DLTensor(ctypes.Structure):
+    _fields_ = [("data", ctypes.c_void_p),
+                ("ctx", DLContext),
+                ("ndim", ctypes.c_int),
+                ("dtype", DLDataType),
+                ("shape", ctypes.POINTER(ctypes.c_int64)),
+                ("strides", ctypes.POINTER(ctypes.c_int64)),
+                ("byte_offset", ctypes.c_uint64)]
+
+class DLManagedTensor(ctypes.Structure):
+    pass
+
+
+DeleterFunc = ctypes.CFUNCTYPE(None, ctypes.POINTER(DLManagedTensor))
+
+
+DLManagedTensor._fields_ = [("dl_tensor", DLTensor),           # pylint: disable=protected-access
+                            ("manager_ctx", ctypes.c_void_p),
+                            ("deleter", DeleterFunc)]
+
+
+@DeleterFunc
+def dl_managed_tensor_deleter(dl_managed_tensor_handle):
+    void_p = dl_managed_tensor_handle.contents.manager_ctx
+    pyobj = ctypes.cast(void_p, ctypes.py_object)
+    ctypes.pythonapi.Py_DecRef(pyobj)
+
+
+def from_numpy(ndarray, zero_copy=True):
+    """Returns an MXNet's NDArray backed by Numpy's ndarray.
+
+    Parameters
+    ----------
+    ndarray: numpy.ndarray
+        input data
+
+    zero_copy: bool
+        Whether we use DLPack's zero-copy conversion to convert to MXNet's NDArray.
+        This is only available for c-contiguous arrays, i.e. array.flags[C_CONTIGUOUS] == True.
+
+    Returns
+    -------
+    NDArray
+        a NDArray backed by a dlpack tensor
+
+    """
+
+    def _make_manager_ctx(obj):
+        pyobj = ctypes.py_object(obj)
+        void_p = ctypes.c_void_p.from_buffer(pyobj)
+        ctypes.pythonapi.Py_IncRef(pyobj)
+        return void_p
+
+    def _make_dl_tensor(array):
+        if str(array.dtype) not in DLDataType.TYPE_MAP:
+            raise ValueError(str(array.dtype) + " is not supported.")
+        dl_tensor = DLTensor()
+        dl_tensor.data = array.ctypes.data_as(ctypes.c_void_p)
+        dl_tensor.ctx = DLContext(1, 0)
+        dl_tensor.ndim = array.ndim
+        dl_tensor.dtype = DLDataType.TYPE_MAP[str(array.dtype)]
+        dl_tensor.shape = array.ctypes.shape_as(ctypes.c_int64)
+        dl_tensor.strides = None
+        dl_tensor.byte_offset = 0
+        return dl_tensor
+
+    def _make_dl_managed_tensor(array):
+        c_obj = DLManagedTensor()
+        c_obj.dl_tensor = _make_dl_tensor(array)
+        c_obj.manager_ctx = _make_manager_ctx(array)
+        c_obj.deleter = dl_managed_tensor_deleter
+        return c_obj
+
+    if not zero_copy:
+        return array(ndarray, dtype=ndarray.dtype)
+
+    if not ndarray.flags['C_CONTIGUOUS']:
+        raise ValueError("Only c-contiguous arrays are supported for zero-copy")
+    c_obj = _make_dl_managed_tensor(ndarray)
 
 Review comment:
   Consider the following case:
   ```python
   # the variable `a` is a `mx.nd.NDArray` object
   a = mx.nd.array([1,2,3])
   b = a.asnumpy(zero_copy=True)
   #  users call `asnumpy(zero_copy=True)` twice
   c = a.asnumpy(zero_copy=True)
   c[:] = 3
   print(a)
   print(b)
   
   a[:] = 5
   a.wait_to_read()
   print(b)
   print(c)
   ```
   It's necessary to keep co-ownership.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to