This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6ad80fd0368 [SPARK-46084][PS] Refactor data type casting operation for Categorical type
6ad80fd0368 is described below

commit 6ad80fd036834e7291c572335e318096781a7ae4
Author: Haejoon Lee <haejoon....@databricks.com>
AuthorDate: Thu Nov 23 22:05:52 2023 -0800

    [SPARK-46084][PS] Refactor data type casting operation for Categorical type
    
    ### What changes were proposed in this pull request?
    
    The PR proposes to refactor the data type casting operation, especially `DataTypeOps.astype`, for the Categorical type.
    
    ### Why are the changes needed?
    
    To improve performance, debuggability, and readability by using official APIs: we can leverage the PySpark functions `coalesce` and `create_map` instead of implementing the mapping in Python from scratch.
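
    As a minimal sketch of the technique (a standalone example, not the patch code itself; the DataFrame and column names are illustrative):

        from itertools import chain

        from pyspark.sql import SparkSession
        from pyspark.sql import functions as F

        spark = SparkSession.builder.getOrCreate()
        df = spark.createDataFrame([("a",), ("b",), ("z",)], ["value"])
        categories = ["a", "b", "c"]

        # Build a literal MAP<category, code> column: map("a", 0, "b", 1, "c", 2).
        kvs = chain(*[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)])
        map_scol = F.create_map(*kvs)

        # Look up each value; keys absent from the map yield NULL,
        # which coalesce replaces with the sentinel code -1.
        df.select(F.coalesce(map_scol[F.col("value")], F.lit(-1)).alias("code")).show()
        # Expected codes: 0 for "a", 1 for "b", -1 for the unseen "z".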
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it's an internal optimization.
    
    ### How was this patch tested?
    
    The existing CI should pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43993 from itholic/refactor_cat.
    
    Authored-by: Haejoon Lee <haejoon....@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 python/pyspark/pandas/data_type_ops/base.py | 26 ++++++--------------------
 1 file changed, 6 insertions(+), 20 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 4f57aa65be7..5a4cd7a1eb0 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -18,6 +18,7 @@
 import numbers
 from abc import ABCMeta
 from typing import Any, Optional, Union
+from itertools import chain
 
 import numpy as np
 import pandas as pd
@@ -129,26 +130,11 @@ def _as_categorical_type(
         if len(categories) == 0:
             scol = F.lit(-1)
         else:
-            scol = F.lit(-1)
-            if isinstance(
-                index_ops._internal.spark_type_for(index_ops._internal.column_labels[0]), BinaryType
-            ):
-                from pyspark.sql.functions import base64
-
-                stringified_column = base64(index_ops.spark.column)
-                for code, category in enumerate(categories):
-                    # Convert each category to base64 before comparison
-                    base64_category = F.base64(F.lit(category))
-                    scol = F.when(stringified_column == base64_category, F.lit(code)).otherwise(
-                        scol
-                    )
-            else:
-                stringified_column = F.format_string("%s", index_ops.spark.column)
-
-                for code, category in enumerate(categories):
-                    scol = F.when(stringified_column == F.lit(category), F.lit(code)).otherwise(
-                        scol
-                    )
+            kvs = chain(
+                *[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)]
+            )
+            map_scol = F.create_map(*kvs)
+            scol = F.coalesce(map_scol[index_ops.spark.column], F.lit(-1))
 
         return index_ops._with_new_scol(
             scol.cast(spark_type),
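
For reference, a hedged end-to-end sketch of the code path this hunk serves (assuming a local pandas-on-Spark session; the values are illustrative):

    import pyspark.pandas as ps
    from pandas import CategoricalDtype

    psser = ps.Series(["a", "b", "z"])
    # astype with a CategoricalDtype goes through DataTypeOps.astype and
    # _as_categorical_type; values outside the given categories map to code -1.
    cat_ser = psser.astype(CategoricalDtype(categories=["a", "b", "c"]))
    print(cat_ser.cat.codes.tolist())  # [0, 1, -1]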


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
