Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/341#discussion_r11570397
  
    --- Diff: python/pyspark/mllib/_common.py ---
    @@ -55,159 +103,222 @@ def _serialize_double_vector(v):
         >>> array_equal(y, array([1.0, 2.0, 3.0]))
         True
         """
    -    if type(v) != ndarray:
    -        raise TypeError("_serialize_double_vector called on a %s; "
    -                "wanted ndarray" % type(v))
    -    """complex is only datatype that can't be converted to float64"""
    -    if issubdtype(v.dtype, complex):
    +    v = _convert_vector(v)
    +    if type(v) == ndarray:
    +        return _serialize_dense_vector(v)
    +    elif type(v) == SparseVector:
    +        return _serialize_sparse_vector(v)
    +    else:
             raise TypeError("_serialize_double_vector called on a %s; "
    -                "wanted ndarray" % type(v))
    -    if v.dtype != float64:
    -        v = v.astype(float64)
    +                "wanted ndarray or SparseVector" % type(v))
    +
    +
    +def _serialize_dense_vector(v):
    +    """Serialize a dense vector given as a NumPy array."""
         if v.ndim != 1:
             raise TypeError("_serialize_double_vector called on a %ddarray; "
                     "wanted a 1darray" % v.ndim)
    +    if v.dtype != float64:
    +        if numpy.issubdtype(v.dtype, numpy.complex):
    +            raise TypeError("_serialize_double_vector called on an ndarray of %s; "
    +                    "wanted ndarray of float64" % v.dtype)
    +        v = v.astype(float64)
         length = v.shape[0]
    -    ba = bytearray(16 + 8*length)
    -    header = ndarray(shape=[2], buffer=ba, dtype="int64")
    -    header[0] = 1
    -    header[1] = length
    -    copyto(ndarray(shape=[length], buffer=ba, offset=16,
    -            dtype="float64"), v)
    +    ba = bytearray(5 + 8 * length)
    +    ba[0] = DENSE_VECTOR_MAGIC
    +    length_bytes = ndarray(shape=[1], buffer=ba, offset=1, dtype=int32)
    +    length_bytes[0] = length
    +    copyto(ndarray(shape=[length], buffer=ba, offset=5, dtype=float64), v)
    +    return ba
    +
    +
    +def _serialize_sparse_vector(v):
    +    """Serialize a pyspark.mllib.linalg.SparseVector."""
    +    nonzeros = len(v.indices)
    +    ba = bytearray(9 + 12 * nonzeros)
    +    ba[0] = SPARSE_VECTOR_MAGIC
    +    header = ndarray(shape=[2], buffer=ba, offset=1, dtype=int32)
    +    header[0] = v.size
    +    header[1] = nonzeros
    +    copyto(ndarray(shape=[nonzeros], buffer=ba, offset=9, dtype=int32), v.indices)
    +    values_offset = 9 + 4 * nonzeros
    +    copyto(ndarray(shape=[nonzeros], buffer=ba, offset=values_offset, dtype=float64), v.values)
    --- End diff --
    
    ditto.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to