This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 481f9866f5f5 [SPARK-55303][PYTHON][TESTS] Extract GoldenFileTestMixin for type coercion golden file tests
481f9866f5f5 is described below
commit 481f9866f5f5a41ad08e72f2a3820b0620927580
Author: Yicong-Huang <[email protected]>
AuthorDate: Thu Feb 5 11:06:44 2026 +0800
[SPARK-55303][PYTHON][TESTS] Extract GoldenFileTestMixin for type coercion golden file tests
### What changes were proposed in this pull request?
Extract common golden file testing utilities into `GoldenFileTestMixin` in
`python/pyspark/testing/goldenutils.py`, and simplify the four type coercion
test files to use this mixin.
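For context, a minimal sketch of the new test shape (the mixin, its helper
methods, and the required base-class order come from this diff; the
`MyCoercionTests` class and its tiny DataFrame are hypothetical):

```python
import pandas as pd

from pyspark.testing.goldenutils import GoldenFileTestMixin
from pyspark.testing.sqlutils import ReusedSQLTestCase


class MyCoercionTests(GoldenFileTestMixin, ReusedSQLTestCase):
    # The mixin must be listed first so its setUpClass()/tearDownClass()
    # wrap the Spark session with the timezone setup; the mixin's
    # __init_subclass__ hook raises TypeError otherwise.

    def test_golden(self):
        golden_csv = "golden_my_coercion.csv"
        # clean_result() strips newlines/tabs so rows round-trip through TSV
        results = pd.DataFrame({"result": [self.clean_result("1\n@int")]})
        if self.is_generating_golden():  # SPARK_GENERATE_GOLDEN_FILES=1
            self.save_golden(results, golden_csv, golden_md="golden_my_coercion.md")
        else:
            golden = self.load_golden_csv(golden_csv)
            self.assertEqual(list(golden["result"]), list(results["result"]))
```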
### Why are the changes needed?
Reduce duplicated code across four test files and provide a reusable
framework for future golden file tests.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Regenerated all golden files with `SPARK_GENERATE_GOLDEN_FILES=1` and
verified tests pass.
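For reference, regeneration can be driven per test file; a sketch, where the
env var and file path come from this diff but the direct-invocation style
assumes running from the repo root with PySpark on the path:

```bash
SPARK_GENERATE_GOLDEN_FILES=1 python python/pyspark/sql/tests/coercion/test_python_udf_input_type.py
```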
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54084 from Yicong-Huang/SPARK-55303/refactor/extract-golden-file-test-util.
Authored-by: Yicong-Huang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../tests/coercion/test_pandas_udf_input_type.py | 62 +----
.../tests/coercion/test_pandas_udf_return_type.py | 75 +-----
.../tests/coercion/test_python_udf_input_type.py | 62 +----
.../tests/coercion/test_python_udf_return_type.py | 64 +-----
python/pyspark/testing/goldenutils.py | 254 +++++++++++++++++++++
5 files changed, 297 insertions(+), 220 deletions(-)
diff --git a/python/pyspark/sql/tests/coercion/test_pandas_udf_input_type.py b/python/pyspark/sql/tests/coercion/test_pandas_udf_input_type.py
index cd3880e6c9dd..64377f2df698 100644
--- a/python/pyspark/sql/tests/coercion/test_pandas_udf_input_type.py
+++ b/python/pyspark/sql/tests/coercion/test_pandas_udf_input_type.py
@@ -18,7 +18,6 @@
from decimal import Decimal
import datetime
import os
-import time
import unittest
from pyspark.sql.functions import pandas_udf
@@ -51,6 +50,7 @@ from pyspark.testing.utils import (
numpy_requirement_message,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.goldenutils import GoldenFileTestMixin
if have_numpy:
import numpy as np
@@ -73,29 +73,7 @@ if have_pandas:
or LooseVersion(np.__version__) < LooseVersion("2.0.0"),
pandas_requirement_message or pyarrow_requirement_message or numpy_requirement_message,
)
-class PandasUDFInputTypeTests(ReusedSQLTestCase):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
- # Synchronize default timezone between Python and Java
- cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
- tz = "America/Los_Angeles"
- os.environ["TZ"] = tz
- time.tzset()
-
- cls.sc.environment["TZ"] = tz
- cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
- @classmethod
- def tearDownClass(cls):
- del os.environ["TZ"]
- if cls.tz_prev is not None:
- os.environ["TZ"] = cls.tz_prev
- time.tzset()
-
- super().tearDownClass()
-
+class PandasUDFInputTypeTests(GoldenFileTestMixin, ReusedSQLTestCase):
@property
def prefix(self):
return "golden_pandas_udf_input_type_coercion"
@@ -265,27 +243,20 @@ class PandasUDFInputTypeTests(ReusedSQLTestCase):
self._compare_or_generate_golden(golden_file, test_name)
def _compare_or_generate_golden(self, golden_file, test_name):
- testing = os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "?") != "1"
+ generating = self.is_generating_golden()
golden_csv = os.path.join(os.path.dirname(__file__), f"{golden_file}.csv")
golden_md = os.path.join(os.path.dirname(__file__), f"{golden_file}.md")
golden = None
- if testing:
- golden = pd.read_csv(
- golden_csv,
- sep="\t",
- index_col=0,
- dtype="str",
- na_filter=False,
- engine="python",
- )
+ if not generating:
+ golden = self.load_golden_csv(golden_csv)
results = []
- for idx, (test_name, spark_type, data_func) in enumerate(self.test_cases):
+ for idx, (case_name, spark_type, data_func) in enumerate(self.test_cases):
input_df = data_func(spark_type).repartition(1)
input_data = [row["value"] for row in input_df.collect()]
- result = [test_name, spark_type.simpleString(), str(input_data)]
+ result = [case_name, self.repr_type(spark_type), str(input_data)]
try:
@@ -319,15 +290,15 @@ class PandasUDFInputTypeTests(ReusedSQLTestCase):
result.append(f"✗ {str(e)}")
# Clean up exception message to remove newlines and extra whitespace
- result = [r.replace("\n", " ").replace("\r", " ").replace("\t", " ") for r in result]
+ result = [self.clean_result(r) for r in result]
error_msg = None
- if testing and result != list(golden.iloc[idx]):
+ if not generating and result != list(golden.iloc[idx]):
error_msg = f"line mismatch: expects {list(golden.iloc[idx])} but got {result}"
results.append((result, error_msg))
- if testing:
+ if not generating:
errs = []
for _, err in results:
if err is not None:
@@ -340,18 +311,7 @@ class PandasUDFInputTypeTests(ReusedSQLTestCase):
columns=["Test Case", "Spark Type", "Spark Value", "Python Type", "Python Value"],
)
- # generating the CSV file as the golden file
- new_golden.to_csv(golden_csv, sep="\t", header=True, index=True)
-
- try:
- # generating the GitHub flavored Markdown file
- # package tabulate is required
- new_golden.to_markdown(golden_md, index=True, tablefmt="github")
- except Exception as e:
- print(
- f"{test_name} return type coercion: "
- f"fail to write the markdown file due to {e}!"
- )
+ self.save_golden(new_golden, golden_csv, golden_md)
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py b/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py
index f0c81adb1010..71fd7b14daa4 100644
--- a/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py
+++ b/python/pyspark/sql/tests/coercion/test_pandas_udf_return_type.py
@@ -19,7 +19,6 @@ import concurrent.futures
from decimal import Decimal
import itertools
import os
-import time
import unittest
from pyspark.sql.functions import pandas_udf
@@ -51,6 +50,7 @@ from pyspark.testing.utils import (
numpy_requirement_message,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.goldenutils import GoldenFileTestMixin
if have_numpy:
import numpy as np
@@ -73,36 +73,14 @@ if have_pandas:
or LooseVersion(np.__version__) < LooseVersion("2.0.0"),
pandas_requirement_message or pyarrow_requirement_message or numpy_requirement_message,
)
-class PandasUDFReturnTypeTests(ReusedSQLTestCase):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
- # Synchronize default timezone between Python and Java
- cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
- tz = "America/Los_Angeles"
- os.environ["TZ"] = tz
- time.tzset()
-
- cls.sc.environment["TZ"] = tz
- cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
- @classmethod
- def tearDownClass(cls):
- del os.environ["TZ"]
- if cls.tz_prev is not None:
- os.environ["TZ"] = cls.tz_prev
- time.tzset()
-
- super().tearDownClass()
-
+class PandasUDFReturnTypeTests(GoldenFileTestMixin, ReusedSQLTestCase):
@property
def prefix(self):
return "golden_pandas_udf_return_type_coercion"
@property
def test_data(self):
- data = [
+ return [
[None, None],
[True, False],
list("ab"),
@@ -131,7 +109,6 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
pd.Categorical(["A", "B"]),
pd.DataFrame({"_1": [1, 2]}),
]
- return data
@property
def test_types(self):
@@ -153,19 +130,9 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
StructType([StructField("_1", IntegerType())]),
]
- def repr_type(self, spark_type):
- return spark_type.simpleString()
-
def repr_value(self, value):
- v_str = value.to_json() if isinstance(value, pd.DataFrame) else str(value)
- v_str = v_str.replace(chr(10), " ")
- v_str = v_str[:32]
- if isinstance(value, np.ndarray):
- return f"{v_str}@ndarray[{value.dtype.name}]"
- elif isinstance(value, pd.DataFrame):
- simple_schema = ", ".join([f"{t} {d.name}" for t, d in value.dtypes.items()])
- return f"{v_str}@Dataframe[{simple_schema}]"
- return f"{v_str}@{type(value).__name__}"
+ # Use extended pandas value representation
+ return self.repr_pandas_value(value)
def test_str_repr(self):
self.assertEqual(
@@ -189,21 +156,14 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
self._compare_or_generate_golden(golden_file, test_name)
def _compare_or_generate_golden(self, golden_file, test_name):
- testing = os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "?") != "1"
+ generating = self.is_generating_golden()
golden_csv = os.path.join(os.path.dirname(__file__), f"{golden_file}.csv")
golden_md = os.path.join(os.path.dirname(__file__), f"{golden_file}.md")
golden = None
- if testing:
- golden = pd.read_csv(
- golden_csv,
- sep="\t",
- index_col=0,
- dtype="str",
- na_filter=False,
- engine="python",
- )
+ if not generating:
+ golden = self.load_golden_csv(golden_csv)
def work(arg):
spark_type, value = arg
@@ -231,10 +191,10 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
result = "X"
# Clean up exception message to remove newlines and extra whitespace
- result = result.replace("\n", " ").replace("\r", " ").replace("\t", " ")
+ result = self.clean_result(result)
err = None
- if testing:
+ if not generating:
expected = golden.loc[str_t, str_v]
if expected != result:
err = f"{str_v} => {spark_type} expects {expected} but got {result}"
@@ -250,7 +210,7 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
)
)
- if testing:
+ if not generating:
errs = []
for _, _, _, err in results:
if err is not None:
@@ -270,18 +230,7 @@ class PandasUDFReturnTypeTests(ReusedSQLTestCase):
for str_t, str_v, res, _ in results:
new_golden.loc[str_t, str_v] = res
- # generating the CSV file as the golden file
- new_golden.to_csv(golden_csv, sep="\t", header=True, index=True)
-
- try:
- # generating the GitHub flavored Markdown file
- # package tabulate is required
- new_golden.to_markdown(golden_md, index=True, tablefmt="github")
- except Exception as e:
- print(
- f"{test_name} return type coercion: "
- f"fail to write the markdown file due to {e}!"
- )
+ self.save_golden(new_golden, golden_csv, golden_md)
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/coercion/test_python_udf_input_type.py b/python/pyspark/sql/tests/coercion/test_python_udf_input_type.py
index 6897647d6bea..f0afb84b4361 100644
--- a/python/pyspark/sql/tests/coercion/test_python_udf_input_type.py
+++ b/python/pyspark/sql/tests/coercion/test_python_udf_input_type.py
@@ -18,7 +18,6 @@
from decimal import Decimal
import datetime
import os
-import time
import unittest
from pyspark.sql.functions import udf
@@ -51,6 +50,7 @@ from pyspark.testing.utils import (
numpy_requirement_message,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.goldenutils import GoldenFileTestMixin
if have_numpy:
import numpy as np
@@ -73,29 +73,7 @@ if have_pandas:
or LooseVersion(np.__version__) < LooseVersion("2.0.0"),
pandas_requirement_message or pyarrow_requirement_message or numpy_requirement_message,
)
-class UDFInputTypeTests(ReusedSQLTestCase):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
- # Synchronize default timezone between Python and Java
- cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
- tz = "America/Los_Angeles"
- os.environ["TZ"] = tz
- time.tzset()
-
- cls.sc.environment["TZ"] = tz
- cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
- @classmethod
- def tearDownClass(cls):
- del os.environ["TZ"]
- if cls.tz_prev is not None:
- os.environ["TZ"] = cls.tz_prev
- time.tzset()
-
- super().tearDownClass()
-
+class UDFInputTypeTests(GoldenFileTestMixin, ReusedSQLTestCase):
@property
def prefix(self):
return "golden_python_udf_input_type_coercion"
@@ -289,27 +267,20 @@ class UDFInputTypeTests(ReusedSQLTestCase):
self._compare_or_generate_golden(golden_file, test_name)
def _compare_or_generate_golden(self, golden_file, test_name):
- testing = os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "?") != "1"
+ generating = self.is_generating_golden()
golden_csv = os.path.join(os.path.dirname(__file__), f"{golden_file}.csv")
golden_md = os.path.join(os.path.dirname(__file__), f"{golden_file}.md")
golden = None
- if testing:
- golden = pd.read_csv(
- golden_csv,
- sep="\t",
- index_col=0,
- dtype="str",
- na_filter=False,
- engine="python",
- )
+ if not generating:
+ golden = self.load_golden_csv(golden_csv)
results = []
- for idx, (test_name, spark_type, data_func) in enumerate(self.test_cases):
+ for idx, (case_name, spark_type, data_func) in enumerate(self.test_cases):
input_df = data_func(spark_type).repartition(1)
input_data = [row["value"] for row in input_df.collect()]
- result = [test_name, spark_type.simpleString(), str(input_data)]
+ result = [case_name, self.repr_type(spark_type), str(input_data)]
try:
@@ -350,15 +321,15 @@ class UDFInputTypeTests(ReusedSQLTestCase):
result.append(f"✗ {str(e)}")
# Clean up exception message to remove newlines and extra whitespace
- result = [r.replace("\n", " ").replace("\r", " ").replace("\t", " ") for r in result]
+ result = [self.clean_result(r) for r in result]
error_msg = None
- if testing and result != list(golden.iloc[idx]):
+ if not generating and result != list(golden.iloc[idx]):
error_msg = f"line mismatch: expects {list(golden.iloc[idx])} but got {result}"
results.append((result, error_msg))
- if testing:
+ if not generating:
errs = []
for _, err in results:
if err is not None:
@@ -371,18 +342,7 @@ class UDFInputTypeTests(ReusedSQLTestCase):
columns=["Test Case", "Spark Type", "Spark Value", "Python Type", "Python Value"],
)
- # generating the CSV file as the golden file
- new_golden.to_csv(golden_csv, sep="\t", header=True, index=True)
-
- try:
- # generating the GitHub flavored Markdown file
- # package tabulate is required
- new_golden.to_markdown(golden_md, index=True, tablefmt="github")
- except Exception as e:
- print(
- f"{test_name} return type coercion: "
- f"fail to write the markdown file due to {e}!"
- )
+ self.save_golden(new_golden, golden_csv, golden_md)
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/coercion/test_python_udf_return_type.py b/python/pyspark/sql/tests/coercion/test_python_udf_return_type.py
index 08ba2f809f50..e3b9939fa51f 100644
--- a/python/pyspark/sql/tests/coercion/test_python_udf_return_type.py
+++ b/python/pyspark/sql/tests/coercion/test_python_udf_return_type.py
@@ -22,7 +22,6 @@ from decimal import Decimal
import itertools
import os
import re
-import time
import unittest
from pyspark.sql import Row
@@ -55,6 +54,7 @@ from pyspark.testing.utils import (
numpy_requirement_message,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.goldenutils import GoldenFileTestMixin
if have_numpy:
import numpy as np
@@ -77,29 +77,7 @@ if have_pandas:
or LooseVersion(np.__version__) < LooseVersion("2.0.0"),
pandas_requirement_message or pyarrow_requirement_message or numpy_requirement_message,
)
-class UDFReturnTypeTests(ReusedSQLTestCase):
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
-
- # Synchronize default timezone between Python and Java
- cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
- tz = "America/Los_Angeles"
- os.environ["TZ"] = tz
- time.tzset()
-
- cls.sc.environment["TZ"] = tz
- cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
- @classmethod
- def tearDownClass(cls):
- del os.environ["TZ"]
- if cls.tz_prev is not None:
- os.environ["TZ"] = cls.tz_prev
- time.tzset()
-
- super().tearDownClass()
-
+class UDFReturnTypeTests(GoldenFileTestMixin, ReusedSQLTestCase):
@property
def prefix(self):
return "golden_python_udf_return_type_coercion"
@@ -144,12 +122,6 @@ class UDFReturnTypeTests(ReusedSQLTestCase):
StructType([StructField("_1", IntegerType())]),
]
- def repr_type(self, spark_type):
- return spark_type.simpleString()
-
- def repr_value(self, value):
- return f"{str(value)}@{type(value).__name__}"
-
def test_str_repr(self):
self.assertEqual(
len(self.test_types),
@@ -196,21 +168,14 @@ class UDFReturnTypeTests(ReusedSQLTestCase):
self._compare_or_generate_golden(golden_file, test_name)
def _compare_or_generate_golden(self, golden_file, test_name):
- testing = os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "?") != "1"
+ generating = self.is_generating_golden()
golden_csv = os.path.join(os.path.dirname(__file__), f"{golden_file}.csv")
golden_md = os.path.join(os.path.dirname(__file__), f"{golden_file}.md")
golden = None
- if testing:
- golden = pd.read_csv(
- golden_csv,
- sep="\t",
- index_col=0,
- dtype="str",
- na_filter=False,
- engine="python",
- )
+ if not generating:
+ golden = self.load_golden_csv(golden_csv)
def work(arg):
spark_type, value = arg
@@ -228,10 +193,10 @@ class UDFReturnTypeTests(ReusedSQLTestCase):
result = "X"
# Clean up exception message to remove newlines and extra whitespace
- result = result.replace("\n", " ").replace("\r", " ").replace("\t", " ")
+ result = self.clean_result(result)
err = None
- if testing:
+ if not generating:
expected = golden.loc[str_t, str_v]
if expected != result:
err = f"{str_v} => {spark_type} expects {expected} but got {result}"
@@ -247,7 +212,7 @@ class UDFReturnTypeTests(ReusedSQLTestCase):
)
)
- if testing:
+ if not generating:
errs = []
for _, _, _, err in results:
if err is not None:
@@ -267,18 +232,7 @@ class UDFReturnTypeTests(ReusedSQLTestCase):
for str_t, str_v, res, _ in results:
new_golden.loc[str_t, str_v] = res
- # generating the CSV file as the golden file
- new_golden.to_csv(golden_csv, sep="\t", header=True, index=True)
-
- try:
- # generating the GitHub flavored Markdown file
- # package tabulate is required
- new_golden.to_markdown(golden_md, index=True, tablefmt="github")
- except Exception as e:
- print(
- f"{test_name} return type coercion: "
- f"fail to write the markdown file due to {e}!"
- )
+ self.save_golden(new_golden, golden_csv, golden_md)
if __name__ == "__main__":
diff --git a/python/pyspark/testing/goldenutils.py b/python/pyspark/testing/goldenutils.py
new file mode 100644
index 000000000000..ecb253689e97
--- /dev/null
+++ b/python/pyspark/testing/goldenutils.py
@@ -0,0 +1,254 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, Optional
+import os
+import time
+
+import pandas as pd
+
+try:
+ import numpy as np
+
+ have_numpy = True
+except ImportError:
+ have_numpy = False
+
+
+class GoldenFileTestMixin:
+ """
+ Mixin class providing utilities for golden file based testing.
+
+ Golden files are CSV files that store expected test results. This mixin provides:
+ - Timezone setup/teardown for deterministic results
+ - Golden file read/write with SPARK_GENERATE_GOLDEN_FILES env var support
+ - Result string cleaning utilities
+
+ To regenerate golden files, set SPARK_GENERATE_GOLDEN_FILES=1 before running tests.
+
+ Usage:
+ class MyTest(GoldenFileTestMixin, ReusedSQLTestCase):
+ def test_something(self):
+ # Use helper methods from mixin
+ if self.is_generating_golden():
+ self.save_golden(df, golden_csv, golden_md)
+ else:
+ golden = self.load_golden_csv(golden_csv)
+ # compare results with golden
+ """
+
+ _tz_prev: Optional[str] = None
+
+ def __init_subclass__(cls, **kwargs):
+ """Verify correct inheritance order at class definition time."""
+ super().__init_subclass__(**kwargs)
+ # Check that GoldenFileTestMixin comes before any class with setUpClass in MRO.
+ # This ensures setup_timezone() will be called after Spark session is created.
+ # Correct: class MyTest(GoldenFileTestMixin, ReusedSQLTestCase)
+ # Incorrect: class MyTest(ReusedSQLTestCase, GoldenFileTestMixin)
+ for base in cls.__mro__:
+ if base is GoldenFileTestMixin:
+ break
+ # If we find a class with setUpClass before GoldenFileTestMixin, that's wrong
+ if base is not cls and hasattr(base, "setUpClass") and "setUpClass" in base.__dict__:
+ raise TypeError(
+ f"{cls.__name__} has incorrect inheritance order. "
+ f"GoldenFileTestMixin must be listed BEFORE
{base.__name__}. "
+ f"Use: class {cls.__name__}(GoldenFileTestMixin,
{base.__name__}, ...)"
+ )
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ """Setup test class with timezone configuration."""
+ super().setUpClass()
+ cls.setup_timezone()
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ """Teardown test class and restore timezone."""
+ cls.teardown_timezone()
+ super().tearDownClass()
+
+ @classmethod
+ def setup_timezone(cls, tz: str = "America/Los_Angeles") -> None:
+ """
+ Setup timezone for deterministic test results.
+ Synchronizes timezone between Python and Java.
+ """
+ cls._tz_prev = os.environ.get("TZ", None)
+ os.environ["TZ"] = tz
+ time.tzset()
+
+ cls.sc.environment["TZ"] = tz
+ cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+ @classmethod
+ def teardown_timezone(cls) -> None:
+ """Restore original timezone."""
+ if "TZ" in os.environ:
+ del os.environ["TZ"]
+ if cls._tz_prev is not None:
+ os.environ["TZ"] = cls._tz_prev
+ time.tzset()
+
+ @staticmethod
+ def is_generating_golden() -> bool:
+ """Check if we are generating golden files (vs testing against
them)."""
+ return os.environ.get("SPARK_GENERATE_GOLDEN_FILES", "0") == "1"
+
+ @staticmethod
+ def load_golden_csv(golden_csv: str, use_index: bool = True) -> "pd.DataFrame":
+ """
+ Load golden file from CSV.
+
+ Parameters
+ ----------
+ golden_csv : str
+ Path to the golden CSV file.
+ use_index : bool
+ If True, use first column as index.
+ If False, don't use index.
+
+ Returns
+ -------
+ pd.DataFrame
+ The loaded golden data with string dtype.
+ """
+ return pd.read_csv(
+ golden_csv,
+ sep="\t",
+ index_col=0 if use_index else None,
+ dtype="str",
+ na_filter=False,
+ engine="python",
+ )
+
+ @staticmethod
+ def save_golden(df: "pd.DataFrame", golden_csv: str, golden_md: Optional[str] = None) -> None:
+ """
+ Save DataFrame as golden file (CSV and optionally Markdown).
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ The DataFrame to save.
+ golden_csv : str
+ Path to save the CSV file.
+ golden_md : str, optional
+ Path to save the Markdown file. Requires tabulate package.
+ """
+ df.to_csv(golden_csv, sep="\t", header=True, index=True)
+
+ if golden_md is not None:
+ try:
+ df.to_markdown(golden_md, index=True, tablefmt="github")
+ except Exception as e:
+ import warnings
+
+ warnings.warn(
+ f"Failed to write markdown file {golden_md}: {e}. "
+ "Install 'tabulate' package to generate markdown files."
+ )
+
+ @staticmethod
+ def repr_type(t: Any) -> str:
+ """
+ Convert a type to string representation.
+
+ Handles different type representations:
+ - Spark DataType: uses simpleString() (e.g., "int", "string", "array<int>")
+ - Python type: uses __name__ (e.g., "int", "str", "list")
+ - Other: uses str()
+
+ Parameters
+ ----------
+ t : Any
+ The type to represent. Can be Spark DataType or Python type.
+
+ Returns
+ -------
+ str
+ String representation of the type.
+ """
+ # Check if it's a Spark DataType (has simpleString method)
+ if hasattr(t, "simpleString"):
+ return t.simpleString()
+ # Check if it's a Python type
+ elif isinstance(t, type):
+ return t.__name__
+ else:
+ return str(t)
+
+ @classmethod
+ def repr_value(cls, value: Any, max_len: int = 32) -> str:
+ """
+ Convert Python value to string representation for golden file.
+
+ Default format: "value_str@type_name"
+ Subclasses can override this method for custom representations.
+
+ Parameters
+ ----------
+ value : Any
+ The Python value to represent.
+ max_len : int, default 32
+ Maximum length for the value string portion.
+
+ Returns
+ -------
+ str
+ String representation in format "value@type".
+ """
+ v_str = str(value)[:max_len]
+ return f"{v_str}@{type(value).__name__}"
+
+ @classmethod
+ def repr_pandas_value(cls, value: Any, max_len: int = 32) -> str:
+ """
+ Convert Python/Pandas value to string representation for golden file.
+
+ Extended version that handles pandas DataFrame and numpy ndarray specially.
+
+ Parameters
+ ----------
+ value : Any
+ The Python value to represent.
+ max_len : int, default 32
+ Maximum length for the value string portion.
+
+ Returns
+ -------
+ str
+ String representation in format "value@type[dtype]".
+ """
+ if isinstance(value, pd.DataFrame):
+ v_str = value.to_json()
+ else:
+ v_str = str(value)
+ v_str = v_str.replace("\n", " ")[:max_len]
+
+ if have_numpy and isinstance(value, np.ndarray):
+ return f"{v_str}@ndarray[{value.dtype.name}]"
+ elif isinstance(value, pd.DataFrame):
+ simple_schema = ", ".join([f"{t} {d.name}" for t, d in value.dtypes.items()])
+ return f"{v_str}@Dataframe[{simple_schema}]"
+ return f"{v_str}@{type(value).__name__}"
+
+ @staticmethod
+ def clean_result(result: str) -> str:
+ """Clean result string by removing newlines and extra whitespace."""
+ return result.replace("\n", " ").replace("\r", " ").replace("\t", " ")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]