GitHub user BryanCutler commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19459#discussion_r145796514
  
    --- Diff: python/pyspark/sql/session.py ---
    @@ -414,6 +415,73 @@ def _createFromLocal(self, data, schema):
             data = [schema.toInternal(row) for row in data]
             return self._sc.parallelize(data), schema
     
    +    def _createFromPandasWithArrow(self, pdf, schema):
    +        """
    +        Create a DataFrame from a given pandas.DataFrame by slicing it 
into partitions, converting
    +        to Arrow data, then sending to the JVM to parallelize. If a schema 
is passed in, the
    +        data types will be used to coerce the data in Pandas to Arrow 
conversion.
    +        """
    +        from pyspark.serializers import ArrowSerializer
    +        from pyspark.sql.types import from_arrow_schema, to_arrow_type, 
_cast_pandas_series_type
    +        import pyarrow as pa
    +
    +        # Slice the DataFrame into batches
    +        step = -(-len(pdf) // self.sparkContext.defaultParallelism)  # 
round int up
    +        pdf_slices = (pdf[start:start + step] for start in xrange(0, 
len(pdf), step))
    +
    +        if schema is None or isinstance(schema, list):
    +            batches = [pa.RecordBatch.from_pandas(pdf_slice, 
preserve_index=False)
    +                       for pdf_slice in pdf_slices]
    +
    +            # There will be at least 1 batch after slicing the 
pandas.DataFrame
    +            schema_from_arrow = from_arrow_schema(batches[0].schema)
    +
    +            # If passed schema as a list of names then rename fields
    +            if isinstance(schema, list):
    +                fields = []
    +                for i, field in enumerate(schema_from_arrow):
    +                    field.name = schema[i]
    +                    fields.append(field)
    +                schema = StructType(fields)
    +            else:
    +                schema = schema_from_arrow
    +        else:
    +            batches = []
    +            for i, pdf_slice in enumerate(pdf_slices):
    +
    +                # convert to series to pyarrow.Arrays to use mask when 
creating Arrow batches
    +                arrs = []
    +                names = []
    +                for c, (_, series) in enumerate(pdf_slice.iteritems()):
    +                    field = schema[c]
    +                    names.append(field.name)
    +                    t = to_arrow_type(field.dataType)
    +                    try:
    +                        # NOTE: casting is not necessary with Arrow >= 0.7
    +                        
arrs.append(pa.Array.from_pandas(_cast_pandas_series_type(series, t),
    +                                                         
mask=series.isnull(), type=t))
    +                    except ValueError as e:
    +                        warnings.warn("Arrow will not be used in 
createDataFrame: %s" % str(e))
    +                        return None
    +                batches.append(pa.RecordBatch.from_arrays(arrs, names))
    +
    +                # Verify schema of first batch, return None if not equal 
and fallback without Arrow
    +                if i == 0:
    +                    schema_from_arrow = 
from_arrow_schema(batches[i].schema)
    +                    if schema != schema_from_arrow:
    +                        warnings.warn("Arrow will not be used in 
createDataFrame.\n" +
    --- End diff --
    
    Thanks @ueshin! This does simplify things quite a bit, which I like. My 
only concerns are that we rely on Arrow/Pandas to raise an error somewhere 
during the casting in order to fall back, and that the fields in the Arrow record 
batches get arbitrary names (they don't use the schema names). What are your 
thoughts, @HyukjinKwon?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to