Jean-Marc Montanier created SPARK-30397: -------------------------------------------
Summary: [pyspark] Writer applied to custom model changes type of keys' dict from int to str Key: SPARK-30397 URL: https://issues.apache.org/jira/browse/SPARK-30397 Project: Spark Issue Type: Bug Components: PySpark Affects Versions: 2.4.4 Reporter: Jean-Marc Montanier Hello, I have a custom model that I'm trying to persist. Within this custom model there is a Python dict mapping from int to int. When the model is saved (with write().save('path')), the keys of the dict are modified from int to str. You can find below a code snippet to reproduce the issue: {code:python} #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ @author: Jean-Marc Montanier @date: 2019/12/31 """ from pyspark.sql import SparkSession from pyspark import keyword_only from pyspark.ml import Pipeline, PipelineModel from pyspark.ml import Estimator, Model from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasInputCol, HasOutputCol from pyspark.sql.types import IntegerType from pyspark.sql.functions import udf spark = SparkSession \ .builder \ .appName("ImputeNormal") \ .getOrCreate() class CustomFit(Estimator, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable, ): @keyword_only def __init__(self, inputCol="inputCol", outputCol="outputCol"): super(CustomFit, self).__init__() self._setDefault(inputCol="inputCol", outputCol="outputCol") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, inputCol="inputCol", outputCol="outputCol"): """ setParams(self, inputCol="inputCol", outputCol="outputCol") """ kwargs = self._input_kwargs self._set(**kwargs) return self def _fit(self, data): inputCol = self.getInputCol() outputCol = self.getOutputCol() categories = data.where(data[inputCol].isNotNull()) \ .groupby(inputCol) \ .count() \ .orderBy("count", ascending=False) \ .limit(2) categories = dict(categories.toPandas().set_index(inputCol)["count"]) for cat in 
categories: categories[cat] = int(categories[cat]) return CustomModel(categories=categories, input_col=inputCol, output_col=outputCol) class CustomModel(Model, DefaultParamsReadable, DefaultParamsWritable): input_col = Param(Params._dummy(), "input_col", "Name of the input column") output_col = Param(Params._dummy(), "output_col", "Name of the output column") categories = Param(Params._dummy(), "categories", "Top categories") def __init__(self, categories: dict = None, input_col="input_col", output_col="output_col"): super(CustomModel, self).__init__() self._set(categories=categories, input_col=input_col, output_col=output_col) def get_output_col(self) -> str: """ output_col getter :return: """ return self.getOrDefault(self.output_col) def get_input_col(self) -> str: """ input_col getter :return: """ return self.getOrDefault(self.input_col) def get_categories(self): """ categories getter :return: """ return self.getOrDefault(self.categories) def _transform(self, data): input_col = self.get_input_col() output_col = self.get_output_col() categories = self.get_categories() def get_cat(val): if val is None: return -1 if val not in categories: return -1 return int(categories[val]) get_cat_udf = udf(get_cat, IntegerType()) df = data.withColumn(output_col, get_cat_udf(input_col)) return df def test_without_write(): fit_df = spark.createDataFrame([[10]] * 5 + [[11]] * 4 + [[12]] * 3 + [[None]] * 2, ['input']) custom_fit = CustomFit(inputCol='input', outputCol='output') pipeline = Pipeline(stages=[custom_fit]) pipeline_model = pipeline.fit(fit_df) print("Categories: {}".format(pipeline_model.stages[0].get_categories())) transform_df = spark.createDataFrame([[10]] * 2 + [[11]] * 2 + [[12]] * 2 + [[None]] * 2, ['input']) test = pipeline_model.transform(transform_df) test.show() # This output is the expected output def test_with_write(): fit_df = spark.createDataFrame([[10]] * 5 + [[11]] * 4 + [[12]] * 3 + [[None]] * 2, ['input']) custom_fit = CustomFit(inputCol='input', 
outputCol='output') pipeline = Pipeline(stages=[custom_fit]) pipeline_model = pipeline.fit(fit_df) print("Categories: {}".format(pipeline_model.stages[0].get_categories())) pipeline_model.write().save('tmp') loaded_model = PipelineModel.load('tmp') # We can see that the type of the keys is now str instead of int print("Categories: {}".format(loaded_model.stages[0].get_categories())) transform_df = spark.createDataFrame([[10]] * 2 + [[11]] * 2 + [[12]] * 2 + [[None]] * 2, ['input']) test = loaded_model.transform(transform_df) test.show() # We can see that the output does not match the expected output if __name__ == "__main__": test_without_write() test_with_write() {code} -- This message was sent by Atlassian Jira (v8.3.4#803005) --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org For additional commands, e-mail: issues-h...@spark.apache.org