This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3b0d0aa4dcf [SPARK-40990][PYTHON] DataFrame creation from 2d NumPy array with arbitrary columns
3b0d0aa4dcf is described below

commit 3b0d0aa4dcf0c7436fc44cbee93c61728a514cdf
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Wed Nov 2 10:06:14 2022 +0900

    [SPARK-40990][PYTHON] DataFrame creation from 2d NumPy array with arbitrary columns
    
    ### What changes were proposed in this pull request?
    Support DataFrame creation from a 2d NumPy array with an arbitrary number of columns.
    
    ### Why are the changes needed?
    Currently, DataFrame creation from a 2d ndarray works only when the array has exactly 2 columns. We should provide complete support for DataFrame creation from a 2d ndarray.
    
    As part of [SPARK-39405](https://issues.apache.org/jira/browse/SPARK-39405).
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    Before
    ```py
    >>> spark.createDataFrame(np.array([[1], [2]])).dtypes
    Traceback (most recent call last):
    ...
        raise ValueError(f"Shape of passed values is {passed}, indices imply {implied}")
    ValueError: Shape of passed values is (2, 1), indices imply (2, 2)
    
    >>> spark.createDataFrame(np.array([[1, 1, 1], [2, 2, 2]])).dtypes
    Traceback (most recent call last):
    ...
        raise ValueError(f"Shape of passed values is {passed}, indices imply {implied}")
    ValueError: Shape of passed values is (2, 3), indices imply (2, 2)
    ```
    
    After
    ```py
    >>> spark.createDataFrame(np.array([[1], [2]])).dtypes
    [('value', 'bigint')]
    
    >>> spark.createDataFrame(np.array([[1, 1, 1], [2, 2, 2]])).dtypes
    [('_1', 'bigint'), ('_2', 'bigint'), ('_3', 'bigint')]
    ```
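    
    For reference, a minimal sketch of behavior this patch leaves unchanged (assuming a running SparkSession `spark` and `import numpy as np`; the outputs follow from the naming rule in the `session.py` diff below):
    ```py
    >>> spark.createDataFrame(np.array([1, 2])).dtypes            # 1d input
    [('value', 'bigint')]
    
    >>> spark.createDataFrame(np.array([[1, 2], [3, 4]])).dtypes  # two columns
    [('_1', 'bigint'), ('_2', 'bigint')]
    ```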
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #38473 from xinrong-meng/ncol_ndarr.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/session.py          |  7 ++++++-
 python/pyspark/sql/tests/test_arrow.py | 14 +++++++++-----
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index f248afa3d83..ee04c94cbd5 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1139,7 +1139,12 @@ class SparkSession(SparkConversionMixin):
             require_minimum_pandas_version()
             if data.ndim not in [1, 2]:
                 raise ValueError("NumPy array input should be of 1 or 2 dimensions.")
-            column_names = ["value"] if data.ndim == 1 else ["_1", "_2"]
+
+            if data.ndim == 1 or data.shape[1] == 1:
+                column_names = ["value"]
+            else:
+                column_names = ["_%s" % i for i in range(1, data.shape[1] + 1)]
+
             if schema is None and not self._jconf.arrowPySparkEnabled():
                 # Construct `schema` from `np.dtype` of the input NumPy array
                 # TODO: Apply the logic below when self._jconf.arrowPySparkEnabled() is True
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index fdba431726c..6083f31ac81 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -188,8 +188,10 @@ class ArrowTests(ReusedSQLTestCase):
         return (
             [np.array([1, 2]).astype(t) for t in int_dtypes]
             + [np.array([0.1, 0.2]).astype(t) for t in float_dtypes]
-            + [np.array([[1, 2], [3, 4]]).astype(t) for t in int_dtypes]
-            + [np.array([[0.1, 0.2], [0.3, 0.4]]).astype(t) for t in float_dtypes]
+            + [np.array([[1], [2]]).astype(t) for t in int_dtypes]
+            + [np.array([[0.1], [0.2]]).astype(t) for t in float_dtypes]
+            + [np.array([[1, 1, 1], [2, 2, 2]]).astype(t) for t in int_dtypes]
+            + [np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]]).astype(t) for t in float_dtypes]
         )
 
     def test_toPandas_fallback_enabled(self):
@@ -510,9 +512,11 @@ class ArrowTests(ReusedSQLTestCase):
 
     def test_createDataFrame_with_ndarray(self):
         dtypes = ["tinyint", "smallint", "int", "bigint", "float", "double"]
-        expected_dtypes = [[("value", t)] for t in dtypes] + [
-            [("_1", t), ("_2", t)] for t in dtypes
-        ]
+        expected_dtypes = (
+            [[("value", t)] for t in dtypes]
+            + [[("value", t)] for t in dtypes]
+            + [[("_1", t), ("_2", t), ("_3", t)] for t in dtypes]
+        )
         arrs = self.create_np_arrs
 
         for arr, dtypes in zip(arrs, expected_dtypes):
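
For illustration, a minimal standalone sketch of the column-naming rule this commit adds to `session.py` (the helper name `_ndarray_column_names` is hypothetical, not part of the PySpark API; expected outputs shown as comments):

```py
import numpy as np

def _ndarray_column_names(data: np.ndarray) -> list:
    # Naming rule from the diff above: a 1d input, or a 2d input with a single
    # column, maps to one "value" column; a 2d input with n > 1 columns maps
    # to "_1" .. "_n".
    if data.ndim not in [1, 2]:
        raise ValueError("NumPy array input should be of 1 or 2 dimensions.")
    if data.ndim == 1 or data.shape[1] == 1:
        return ["value"]
    return ["_%s" % i for i in range(1, data.shape[1] + 1)]

print(_ndarray_column_names(np.array([1, 2])))                  # ['value']
print(_ndarray_column_names(np.array([[1], [2]])))              # ['value']
print(_ndarray_column_names(np.array([[1, 1, 1], [2, 2, 2]])))  # ['_1', '_2', '_3']
```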


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
