Github user BryanCutler commented on a diff in the pull request: https://github.com/apache/spark/pull/19459#discussion_r145291702 --- Diff: python/pyspark/sql/session.py --- @@ -414,6 +415,39 @@ def _createFromLocal(self, data, schema): data = [schema.toInternal(row) for row in data] return self._sc.parallelize(data), schema + def _createFromPandasWithArrow(self, pdf, schema): + """ + Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting + to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the + data types will be used to coerce the data in Pandas to Arrow conversion. + """ + from pyspark.serializers import ArrowSerializer + from pyspark.sql.types import from_arrow_schema, to_arrow_schema + import pyarrow as pa + + # Slice the DataFrame into batches + step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up + pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step)) + arrow_schema = to_arrow_schema(schema) if schema is not None else None + batches = [pa.RecordBatch.from_pandas(pdf_slice, schema=arrow_schema, preserve_index=False) + for pdf_slice in pdf_slices] + + # Verify schema, there will be at least 1 batch from pandas.DataFrame + schema_from_arrow = from_arrow_schema(batches[0].schema) + if schema is not None and schema != schema_from_arrow: + raise ValueError("Supplied schema does not match result from Arrow\nsupplied: " + + "%s\n!=\nfrom Arrow: %s" % (str(schema), str(schema_from_arrow))) --- End diff -- @ueshin and @HyukjinKwon after thinking about what to do when the schema is not equal, I have some concerns: 1. Fallback to `createDataFrame` without Arrow - I implemented this and works fine, but there is no logging in python (afaik) so my concern is that it does this silently and causes bad performance and the user will not know why. 2. 
Cast types using `astype` similar to `ArrowPandasSerializer.dump_stream` - The issue I see with that is if there are null values and ints have been promoted to floats, this works fine in `dump_stream` because we are working with pd.Series and pyarrow allows us to pass a validity mask, which ignores the filled values. There aren't options to pass in masks for pd.DataFrames, so I believe it will try to interpret whatever fill values are there and cause an error. I can look into this more though.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org