This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 30190edb2df7 [SPARK-46955][PS] Implement `Frame.to_stata` 30190edb2df7 is described below commit 30190edb2df7e6cc15a5db7b070cd9dde11e2106 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Mon Feb 5 09:14:05 2024 +0800 [SPARK-46955][PS] Implement `Frame.to_stata` ### What changes were proposed in this pull request? Implement `Frame.to_stata` ### Why are the changes needed? for Pandas parity ### Does this PR introduce _any_ user-facing change? yes ``` In [5]: df = pd.DataFrame({'animal': ['falcon', 'parrot', 'falcon', 'parrot'], 'speed': [350, 18, 361, 15]}) In [6]: psdf = ps.from_pandas(df) In [7]: df.to_stata('/tmp/animals_1.dta') In [8]: psdf.to_stata('/tmp/animals_2.dta') In [9]: pd.read_stata('/tmp/animals_1.dta') Out[9]: index animal speed 0 0 falcon 350 1 1 parrot 18 2 2 falcon 361 3 3 parrot 15 In [10]: pd.read_stata('/tmp/animals_2.dta') Out[10]: index animal speed 0 0 falcon 350 1 1 parrot 18 2 2 falcon 361 3 3 parrot 15 ``` ### How was this patch tested? added ut ### Was this patch authored or co-authored using generative AI tooling? no Closes #44996 from zhengruifeng/ps_to_stata. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- dev/sparktestsupport/modules.py | 2 + .../docs/source/reference/pyspark.pandas/frame.rst | 1 + python/pyspark/pandas/frame.py | 82 ++++++++++++++++++++++ python/pyspark/pandas/missing/frame.py | 1 - .../pandas/tests/connect/io/test_parity_stata.py | 42 +++++++++++ python/pyspark/pandas/tests/io/test_stata.py | 67 ++++++++++++++++++ 6 files changed, 194 insertions(+), 1 deletion(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 233dcf4e54b6..2ed2144fa64b 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -817,6 +817,7 @@ pyspark_pandas = Module( "pyspark.pandas.tests.io.test_io", "pyspark.pandas.tests.io.test_csv", "pyspark.pandas.tests.io.test_feather", + "pyspark.pandas.tests.io.test_stata", "pyspark.pandas.tests.io.test_dataframe_conversion", "pyspark.pandas.tests.io.test_dataframe_spark_io", "pyspark.pandas.tests.io.test_series_conversion", @@ -1299,6 +1300,7 @@ pyspark_pandas_connect_part3 = Module( "pyspark.pandas.tests.connect.io.test_parity_io", "pyspark.pandas.tests.connect.io.test_parity_csv", "pyspark.pandas.tests.connect.io.test_parity_feather", + "pyspark.pandas.tests.connect.io.test_parity_stata", "pyspark.pandas.tests.connect.io.test_parity_dataframe_conversion", "pyspark.pandas.tests.connect.io.test_parity_dataframe_spark_io", "pyspark.pandas.tests.connect.io.test_parity_series_conversion", diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst b/python/docs/source/reference/pyspark.pandas/frame.rst index 564ddb607a19..336fd262f611 100644 --- a/python/docs/source/reference/pyspark.pandas/frame.rst +++ b/python/docs/source/reference/pyspark.pandas/frame.rst @@ -284,6 +284,7 @@ Serialization / IO / Conversion DataFrame.to_spark DataFrame.to_string DataFrame.to_feather + DataFrame.to_stata DataFrame.to_json DataFrame.to_dict DataFrame.to_excel diff --git 
a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 3b3565f7ea9f..e857344a6098 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -2683,6 +2683,88 @@ defaultdict(<class 'list'>, {'col..., 'col...})] self._to_internal_pandas(), self.to_feather, pd.DataFrame.to_feather, args ) + def to_stata( + self, + path: Union[str, IO[str]], + *, + convert_dates: Optional[Dict] = None, + write_index: bool = True, + byteorder: Optional[str] = None, + time_stamp: Optional[datetime.datetime] = None, + data_label: Optional[str] = None, + variable_labels: Optional[Dict] = None, + version: Optional[int] = 114, + convert_strl: Optional[Sequence[Name]] = None, + compression: str = "infer", + storage_options: Optional[str] = None, + value_labels: Optional[Dict] = None, + ) -> None: + """ + Export DataFrame object to Stata dta format. + + .. note:: This method should only be used if the resulting DataFrame is expected + to be small, as all the data is loaded into the driver's memory. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + path : str, path object, or buffer + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``write()`` function. + convert_dates : dict + Dictionary mapping columns containing datetime types to stata + internal format to use when writing the dates. Options are 'tc', + 'td', 'tm', 'tw', 'th', 'tq', 'ty'. Column can be either an integer + or a name. Datetime columns that do not have a conversion type + specified will be converted to 'tc'. Raises NotImplementedError if + a datetime column has timezone information. + write_index : bool + Write the index to Stata dataset. + byteorder : str + Can be ">", "<", "little", or "big". default is `sys.byteorder`. + time_stamp : datetime + A datetime to use as file creation date. Default is the current + time. + data_label : str, optional + A label for the data set. Must be 80 characters or smaller. 
+ variable_labels : dict + Dictionary containing columns as keys and variable labels as + values. Each label must be 80 characters or smaller. + version : {114, 117, 118, 119, None}, default 114 + Version to use in the output dta file. Set to None to let pandas + decide between 118 or 119 formats depending on the number of + columns in the frame. Version 114 can be read by Stata 10 and + later. Version 117 can be read by Stata 13 or later. Version 118 + is supported in Stata 14 and later. Version 119 is supported in + Stata 15 and later. Version 114 limits string variables to 244 + characters or fewer while versions 117 and later allow strings + with lengths up to 2,000,000 characters. Versions 118 and 119 + support Unicode characters, and version 119 supports more than + 32,767 variables. + convert_strl : list, optional + List of column names to convert to string columns to Stata StrL + format. Only available if version is 117. Storing strings in the + StrL format can produce smaller dta files if strings have more than + 8 characters and values are repeated. + value_labels : dict of dicts + Dictionary containing columns as keys and dictionaries of column value + to labels as values. Labels for a single variable must be 32,000 + characters or smaller. + + Examples + -------- + >>> df = ps.DataFrame({'animal': ['falcon', 'parrot', 'falcon', 'parrot'], + ... 'speed': [350, 18, 361, 15]}) + >>> df.to_stata('animals.dta') # doctest: +SKIP + """ + # Make sure locals() call is at the top of the function so we don't capture local variables. + args = locals() + + return validate_arguments_and_invoke_function( + self._to_internal_pandas(), self.to_stata, pd.DataFrame.to_stata, args + ) + + def transpose(self) -> "DataFrame": + """ + Transpose index and columns. 
diff --git a/python/pyspark/pandas/missing/frame.py b/python/pyspark/pandas/missing/frame.py index fdb6cec7c0f9..bdfa7574dc3d 100644 --- a/python/pyspark/pandas/missing/frame.py +++ b/python/pyspark/pandas/missing/frame.py @@ -44,7 +44,6 @@ class MissingPandasLikeDataFrame: set_axis = _unsupported_function("set_axis") to_period = _unsupported_function("to_period") to_sql = _unsupported_function("to_sql") - to_stata = _unsupported_function("to_stata") to_timestamp = _unsupported_function("to_timestamp") tz_convert = _unsupported_function("tz_convert") tz_localize = _unsupported_function("tz_localize") diff --git a/python/pyspark/pandas/tests/connect/io/test_parity_stata.py b/python/pyspark/pandas/tests/connect/io/test_parity_stata.py new file mode 100644 index 000000000000..d7a74d7399a5 --- /dev/null +++ b/python/pyspark/pandas/tests/connect/io/test_parity_stata.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from pyspark.pandas.tests.io.test_stata import StataMixin +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils, TestUtils + + +class StataParityTests( + StataMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, + TestUtils, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.connect.io.test_parity_stata import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/io/test_stata.py b/python/pyspark/pandas/tests/io/test_stata.py new file mode 100644 index 000000000000..6fe7cf13513c --- /dev/null +++ b/python/pyspark/pandas/tests/io/test_stata.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +import pandas as pd + +from pyspark import pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils + + +class StataMixin: + @property + def pdf(self): + return pd.DataFrame( + {"animal": ["falcon", "parrot", "falcon", "parrot"], "speed": [350, 18, 361, 15]} + ) + + @property + def psdf(self): + return ps.from_pandas(self.pdf) + + def test_to_stata(self): + with self.temp_dir() as dirpath: + path1 = f"{dirpath}/file1.dta" + path2 = f"{dirpath}/file2.dta" + + self.pdf.to_stata(path1) + self.psdf.to_stata(path2) + + self.assert_eq( + pd.read_stata(path1), + pd.read_stata(path2), + ) + + +class StataTests( + StataMixin, + PandasOnSparkTestCase, + TestUtils, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.io.test_stata import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org