This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 30190edb2df7 [SPARK-46955][PS] Implement `Frame.to_stata`
30190edb2df7 is described below

commit 30190edb2df7e6cc15a5db7b070cd9dde11e2106
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Feb 5 09:14:05 2024 +0800

    [SPARK-46955][PS] Implement `Frame.to_stata`
    
    ### What changes were proposed in this pull request?
    Implement `Frame.to_stata` (exposed as `DataFrame.to_stata` in pandas API on Spark).
    
    ### Why are the changes needed?
    for Pandas parity
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    ```
    In [5]: df = pd.DataFrame({'animal': ['falcon', 'parrot', 'falcon', 'parrot'], 'speed': [350, 18, 361, 15]})
    
    In [6]: psdf = ps.from_pandas(df)
    
    In [7]: df.to_stata('/tmp/animals_1.dta')
    
    In [8]: psdf.to_stata('/tmp/animals_2.dta')
    
    In [9]: pd.read_stata('/tmp/animals_1.dta')
    Out[9]:
       index  animal  speed
    0      0  falcon    350
    1      1  parrot     18
    2      2  falcon    361
    3      3  parrot     15
    
    In [10]: pd.read_stata('/tmp/animals_2.dta')
    Out[10]:
       index  animal  speed
    0      0  falcon    350
    1      1  parrot     18
    2      2  falcon    361
    3      3  parrot     15
    ```
    
    ### How was this patch tested?
    Added unit tests (`test_stata.py` and its Spark Connect parity test).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44996 from zhengruifeng/ps_to_stata.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 dev/sparktestsupport/modules.py                    |  2 +
 .../docs/source/reference/pyspark.pandas/frame.rst |  1 +
 python/pyspark/pandas/frame.py                     | 82 ++++++++++++++++++++++
 python/pyspark/pandas/missing/frame.py             |  1 -
 .../pandas/tests/connect/io/test_parity_stata.py   | 42 +++++++++++
 python/pyspark/pandas/tests/io/test_stata.py       | 67 ++++++++++++++++++
 6 files changed, 194 insertions(+), 1 deletion(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 233dcf4e54b6..2ed2144fa64b 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -817,6 +817,7 @@ pyspark_pandas = Module(
         "pyspark.pandas.tests.io.test_io",
         "pyspark.pandas.tests.io.test_csv",
         "pyspark.pandas.tests.io.test_feather",
+        "pyspark.pandas.tests.io.test_stata",
         "pyspark.pandas.tests.io.test_dataframe_conversion",
         "pyspark.pandas.tests.io.test_dataframe_spark_io",
         "pyspark.pandas.tests.io.test_series_conversion",
@@ -1299,6 +1300,7 @@ pyspark_pandas_connect_part3 = Module(
         "pyspark.pandas.tests.connect.io.test_parity_io",
         "pyspark.pandas.tests.connect.io.test_parity_csv",
         "pyspark.pandas.tests.connect.io.test_parity_feather",
+        "pyspark.pandas.tests.connect.io.test_parity_stata",
         "pyspark.pandas.tests.connect.io.test_parity_dataframe_conversion",
         "pyspark.pandas.tests.connect.io.test_parity_dataframe_spark_io",
         "pyspark.pandas.tests.connect.io.test_parity_series_conversion",
diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst b/python/docs/source/reference/pyspark.pandas/frame.rst
index 564ddb607a19..336fd262f611 100644
--- a/python/docs/source/reference/pyspark.pandas/frame.rst
+++ b/python/docs/source/reference/pyspark.pandas/frame.rst
@@ -284,6 +284,7 @@ Serialization / IO / Conversion
    DataFrame.to_spark
    DataFrame.to_string
    DataFrame.to_feather
+   DataFrame.to_stata
    DataFrame.to_json
    DataFrame.to_dict
    DataFrame.to_excel
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 3b3565f7ea9f..e857344a6098 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -2683,6 +2683,88 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             self._to_internal_pandas(), self.to_feather, pd.DataFrame.to_feather, args
         )
 
+    def to_stata(
+        self,
+        path: Union[str, IO[str]],
+        *,
+        convert_dates: Optional[Dict] = None,
+        write_index: bool = True,
+        byteorder: Optional[str] = None,
+        time_stamp: Optional[datetime.datetime] = None,
+        data_label: Optional[str] = None,
+        variable_labels: Optional[Dict] = None,
+        version: Optional[int] = 114,
+        convert_strl: Optional[Sequence[Name]] = None,
+        compression: str = "infer",
+        storage_options: Optional[Dict] = None,
+        value_labels: Optional[Dict] = None,
+    ) -> None:
+        """
+        Export DataFrame object to Stata dta format.
+
+        .. note:: This method should only be used if the resulting DataFrame is expected
+                  to be small, as all the data is loaded into the driver's memory.
+
+        .. versionadded:: 4.0.0
+
+        Parameters
+        ----------
+        path : str, path object, or buffer
+            String, path object (implementing ``os.PathLike[str]``), or file-like
+            object implementing a binary ``write()`` function.
+        convert_dates : dict
+            Dictionary mapping columns containing datetime types to stata internal format to use
+            when writing the dates. Options are 'tc', 'td', 'tm', 'tw', 'th', 'tq', 'ty'.
+            Column can be either an integer or a name. Datetime columns that do not have a
+            conversion type specified will be converted to 'tc'. Raises NotImplementedError if a
+            datetime column has timezone information.
+        write_index : bool
+            Write the index to Stata dataset.
+        byteorder : str
+            Can be ">", "<", "little", or "big". default is `sys.byteorder`.
+        time_stamp : datetime
+            A datetime to use as file creation date. Default is the current time.
+        data_label : str, optional
+            A label for the data set.  Must be 80 characters or smaller.
+        variable_labels : dict
+            Dictionary containing columns as keys and variable labels as
+            values. Each label must be 80 characters or smaller.
+        version : {114, 117, 118, 119, None}, default 114
+            Version to use in the output dta file. Set to None to let pandas decide between 118 or
+            119 formats depending on the number of columns in the frame. Version 114 can be read
+            by Stata 10 and later. Version 117 can be read by Stata 13 or later. Version 118 is
+            supported in Stata 14 and later. Version 119 is supported in Stata 15 and later.
+            Version 114 limits string variables to 244 characters or fewer while versions 117 and
+            later allow strings with lengths up to 2,000,000 characters. Versions 118 and 119
+            support Unicode characters, and version 119 supports more than 32,767 variables.
+        convert_strl : list, optional
+            List of column names to convert to string columns to Stata StrL
+            format. Only available if version is 117.  Storing strings in the
+            StrL format can produce smaller dta files if strings have more than
+            8 characters and values are repeated.
+        compression : str, default 'infer'
+            For on-the-fly compression of the output data. If 'infer' and `path` is path-like,
+            pandas detects the compression from the file extension.
+        storage_options : dict, optional
+            Extra options that make sense for a particular storage connection, e.g. host, port,
+            username, password, etc.
+        value_labels : dict of dicts
+            Dictionary containing columns as keys and dictionaries of column value to labels as
+            values. Labels for a single variable must be 32,000 characters or smaller.
+
+        Examples
+        --------
+        >>> df = ps.DataFrame({'animal': ['falcon', 'parrot', 'falcon', 'parrot'],
+        ...                    'speed': [350, 18, 361, 15]})
+        >>> df.to_stata('animals.dta')  # doctest: +SKIP
+        """
+        # Make sure locals() call is at the top of the function so we don't capture local variables.
+        args = locals()
+
+        return validate_arguments_and_invoke_function(
+            self._to_internal_pandas(), self.to_stata, pd.DataFrame.to_stata, args
+        )
+
     def transpose(self) -> "DataFrame":
         """
         Transpose index and columns.
diff --git a/python/pyspark/pandas/missing/frame.py b/python/pyspark/pandas/missing/frame.py
index fdb6cec7c0f9..bdfa7574dc3d 100644
--- a/python/pyspark/pandas/missing/frame.py
+++ b/python/pyspark/pandas/missing/frame.py
@@ -44,7 +44,6 @@ class MissingPandasLikeDataFrame:
     set_axis = _unsupported_function("set_axis")
     to_period = _unsupported_function("to_period")
     to_sql = _unsupported_function("to_sql")
-    to_stata = _unsupported_function("to_stata")
     to_timestamp = _unsupported_function("to_timestamp")
     tz_convert = _unsupported_function("tz_convert")
     tz_localize = _unsupported_function("tz_localize")
diff --git a/python/pyspark/pandas/tests/connect/io/test_parity_stata.py b/python/pyspark/pandas/tests/connect/io/test_parity_stata.py
new file mode 100644
index 000000000000..d7a74d7399a5
--- /dev/null
+++ b/python/pyspark/pandas/tests/connect/io/test_parity_stata.py
@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.io.test_stata import StataMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils
+
+
+class StataParityTests(
+    StataMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+    TestUtils,
+):
+    pass  # No body needed: the Stata test cases are inherited from StataMixin.
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.connect.io.test_parity_stata import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/io/test_stata.py b/python/pyspark/pandas/tests/io/test_stata.py
new file mode 100644
index 000000000000..6fe7cf13513c
--- /dev/null
+++ b/python/pyspark/pandas/tests/io/test_stata.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
+
+
+class StataMixin:
+    """Shared ``to_stata`` round-trip test, mixed into the concrete test classes."""
+
+    @property
+    def pdf(self):
+        return pd.DataFrame(
+            {"animal": ["falcon", "parrot", "falcon", "parrot"], "speed": [350, 18, 361, 15]}
+        )
+
+    @property
+    def psdf(self):
+        return ps.from_pandas(self.pdf)
+
+    def test_to_stata(self):
+        with self.temp_dir() as dirpath:
+            path1 = f"{dirpath}/file1.dta"
+            path2 = f"{dirpath}/file2.dta"
+
+            self.pdf.to_stata(path1)
+            self.psdf.to_stata(path2)
+
+            # Files written by pandas and by pandas-on-Spark must read back identically.
+            self.assert_eq(pd.read_stata(path1), pd.read_stata(path2))
+
+
+class StataTests(
+    StataMixin,
+    PandasOnSparkTestCase,
+    TestUtils,
+):
+    pass  # No body needed: the Stata test cases are inherited from StataMixin.
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.io.test_stata import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to