This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 6fea291  [SPARK-31186][PYSPARK][SQL] toPandas should not fail on duplicate column names
6fea291 is described below

commit 6fea291762af3e802cb4c237bdad51ebf5d7152c
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Fri Mar 27 12:10:30 2020 +0900

    [SPARK-31186][PYSPARK][SQL] toPandas should not fail on duplicate column names
    
    ### What changes were proposed in this pull request?
    
    When the `toPandas` API is used on a DataFrame with duplicate column names, e.g. produced by operators like join, we see an error like:
    
    ```
    ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
    ```
    
    This patch fixes the error in `toPandas` API.
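    
    For context, with duplicate column names pandas' name-based indexing (`pdf[name]`) returns a
    two-column DataFrame instead of a Series, so checks like `pdf[name].isnull().any()` no longer
    reduce to a single boolean. Below is a minimal standalone pandas sketch of the issue and of the
    positional-access approach the patch relies on (illustrative only; the column name `v` and the
    data are made up, not taken from the original report):
    
    ```
    import numpy as np
    import pandas as pd

    # A frame with two columns that share the name "v", as a join can produce.
    pdf = pd.DataFrame([[1, 4], [2, 5], [3, 6]], columns=["v", "v"])

    # Name-based access is ambiguous: it yields a DataFrame, not a Series, so
    # `pdf["v"].isnull().any()` returns a per-column Series and using it in an
    # `if` raises "The truth value of a Series is ambiguous".
    print(type(pdf["v"]))        # <class 'pandas.core.frame.DataFrame'>

    # Positional access always yields a single Series, which is why duplicated
    # names are handled through `iloc`.
    print(type(pdf.iloc[:, 0]))  # <class 'pandas.core.series.Series'>

    # Rebuilding the frame with `insert(..., allow_duplicates=True)` keeps both
    # columns while letting each one be cast independently.
    out = pd.DataFrame()
    out.insert(0, "v", pdf.iloc[:, 0].astype(np.int32, copy=False), allow_duplicates=True)
    out.insert(1, "v", pdf.iloc[:, 1].astype(np.int32, copy=False), allow_duplicates=True)
    ```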
    
    ### Why are the changes needed?
    
    To make `toPandas` work on DataFrames with duplicate column names.
    
    ### Does this PR introduce any user-facing change?
    
    Yes. Previously, calling the `toPandas` API on a DataFrame with duplicate column names would fail. After this patch, it produces the correct result.
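    
    A sketch of the behavior change, assuming a local `SparkSession` (the data and column name
    below are illustrative; the non-Arrow conversion path changed by this patch is the default
    in this branch):
    
    ```
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # Two single-column DataFrames whose columns share the name "v".
    df1 = spark.range(1, 4).selectExpr("id AS v")
    df2 = spark.range(1, 4).selectExpr("id AS v")
    joined = df1.crossJoin(df2)  # resulting schema: v, v (duplicate names)

    # Before this patch: raises "ValueError: The truth value of a Series is ambiguous".
    # After this patch: returns a pandas DataFrame with both `v` columns intact.
    pdf = joined.toPandas()
    print(pdf.dtypes)
    ```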
    
    ### How was this patch tested?
    
    Unit test.
    
    Closes #28025 from viirya/SPARK-31186.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
    (cherry picked from commit 559d3e4051500d5c49e9a7f3ac33aac3de19c9c6)
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 python/pyspark/sql/pandas/conversion.py    | 48 +++++++++++++++++++++++-------
 python/pyspark/sql/tests/test_dataframe.py | 18 +++++++++++
 2 files changed, 56 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 8548cd2..47cf8bb 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -21,6 +21,7 @@ if sys.version >= '3':
     xrange = range
 else:
     from itertools import izip as zip
+from collections import Counter
 
 from pyspark import since
 from pyspark.rdd import _load_from_socket
@@ -131,9 +132,16 @@ class PandasConversionMixin(object):
 
         # Below is toPandas without Arrow optimization.
         pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+        column_counter = Counter(self.columns)
+
+        dtype = [None] * len(self.schema)
+        for fieldIdx, field in enumerate(self.schema):
+            # For duplicate column name, we use `iloc` to access it.
+            if column_counter[field.name] > 1:
+                pandas_col = pdf.iloc[:, fieldIdx]
+            else:
+                pandas_col = pdf[field.name]
 
-        dtype = {}
-        for field in self.schema:
             pandas_type = PandasConversionMixin._to_corrected_pandas_type(field.dataType)
             # SPARK-21766: if an integer field is nullable and has null values, it can be
             # inferred by pandas as float column. Once we convert the column with NaN back
@@ -141,16 +149,36 @@ class PandasConversionMixin(object):
             # float type, not the corrected type from the schema in this case.
             if pandas_type is not None and \
                 not(isinstance(field.dataType, IntegralType) and field.nullable and
-                    pdf[field.name].isnull().any()):
-                dtype[field.name] = pandas_type
+                    pandas_col.isnull().any()):
+                dtype[fieldIdx] = pandas_type
             # Ensure we fall back to nullable numpy types, even when whole column is null:
-            if isinstance(field.dataType, IntegralType) and pdf[field.name].isnull().any():
-                dtype[field.name] = np.float64
-            if isinstance(field.dataType, BooleanType) and pdf[field.name].isnull().any():
-                dtype[field.name] = np.object
+            if isinstance(field.dataType, IntegralType) and pandas_col.isnull().any():
+                dtype[fieldIdx] = np.float64
+            if isinstance(field.dataType, BooleanType) and pandas_col.isnull().any():
+                dtype[fieldIdx] = np.object
+
+        df = pd.DataFrame()
+        for index, t in enumerate(dtype):
+            column_name = self.schema[index].name
+
+            # For duplicate column name, we use `iloc` to access it.
+            if column_counter[column_name] > 1:
+                series = pdf.iloc[:, index]
+            else:
+                series = pdf[column_name]
+
+            if t is not None:
+                series = series.astype(t, copy=False)
+
+            # `insert` API makes copy of data, we only do it for Series of duplicate column names.
+            # `pdf.iloc[:, index] = pdf.iloc[:, index]...` doesn't always work because `iloc` could
+            # return a view or a copy depending by context.
+            if column_counter[column_name] > 1:
+                df.insert(index, column_name, series, allow_duplicates=True)
+            else:
+                df[column_name] = series
 
-        for f, t in dtype.items():
-            pdf[f] = pdf[f].astype(t, copy=False)
+        pdf = df
 
         if timezone is None:
             return pdf
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index d738449..d9dcbc0 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -529,6 +529,24 @@ class DataFrameTests(ReusedSQLTestCase):
         self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')
 
+    @unittest.skipIf(not have_pandas, pandas_requirement_message)
+    def test_to_pandas_on_cross_join(self):
+        import numpy as np
+
+        sql = """
+        select t1.*, t2.* from (
+          select explode(sequence(1, 3)) v
+        ) t1 left join (
+          select explode(sequence(1, 3)) v
+        ) t2
+        """
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            df = self.spark.sql(sql)
+            pdf = df.toPandas()
+            types = pdf.dtypes
+            self.assertEquals(types.iloc[0], np.int32)
+            self.assertEquals(types.iloc[1], np.int32)
+
     @unittest.skipIf(have_pandas, "Required Pandas was found.")
     def test_to_pandas_required_pandas_not_found(self):
         with QuietTest(self.sc):


