[ https://issues.apache.org/jira/browse/SPARK-30397?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Hyukjin Kwon updated SPARK-30397:
---------------------------------
    Component/s: ML

> [pyspark] Writer applied to custom model changes type of keys' dict from int to str
> -----------------------------------------------------------------------------------
>
>                 Key: SPARK-30397
>                 URL: https://issues.apache.org/jira/browse/SPARK-30397
>             Project: Spark
>          Issue Type: Bug
>          Components: ML, PySpark
>    Affects Versions: 2.4.4
>            Reporter: Jean-Marc Montanier
>            Priority: Major
>
> Hello,
>
> I have a custom model that I'm trying to persist. Within this custom model there is a Python dict mapping from int to int. When the model is saved (with write().save('path')), the keys of the dict are changed from int to str.
>
> You can find below a code snippet to reproduce the issue:
> {code:python}
> #!/usr/bin/env python3
> # -*- coding: utf-8 -*-
> """
> @author: Jean-Marc Montanier
> @date: 2019/12/31
> """
>
> from pyspark.sql import SparkSession
> from pyspark import keyword_only
> from pyspark.ml import Pipeline, PipelineModel
> from pyspark.ml import Estimator, Model
> from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
> from pyspark.ml.param import Param, Params
> from pyspark.ml.param.shared import HasInputCol, HasOutputCol
> from pyspark.sql.types import IntegerType
> from pyspark.sql.functions import udf
>
> spark = SparkSession \
>     .builder \
>     .appName("ImputeNormal") \
>     .getOrCreate()
>
>
> class CustomFit(Estimator,
>                 HasInputCol,
>                 HasOutputCol,
>                 DefaultParamsReadable,
>                 DefaultParamsWritable,
>                 ):
>     @keyword_only
>     def __init__(self, inputCol="inputCol", outputCol="outputCol"):
>         super(CustomFit, self).__init__()
>         self._setDefault(inputCol="inputCol", outputCol="outputCol")
>         kwargs = self._input_kwargs
>         self.setParams(**kwargs)
>
>     @keyword_only
>     def setParams(self, inputCol="inputCol", outputCol="outputCol"):
>         """
>         setParams(self, inputCol="inputCol", outputCol="outputCol")
>         """
>         kwargs = self._input_kwargs
>         self._set(**kwargs)
>         return self
>
>     def _fit(self, data):
>         inputCol = self.getInputCol()
>         outputCol = self.getOutputCol()
>
>         # Keep the two most frequent non-null values of the input column.
>         categories = data.where(data[inputCol].isNotNull()) \
>             .groupby(inputCol) \
>             .count() \
>             .orderBy("count", ascending=False) \
>             .limit(2)
>         categories = dict(categories.toPandas().set_index(inputCol)["count"])
>         for cat in categories:
>             categories[cat] = int(categories[cat])
>
>         return CustomModel(categories=categories,
>                            input_col=inputCol,
>                            output_col=outputCol)
>
>
> class CustomModel(Model,
>                   DefaultParamsReadable,
>                   DefaultParamsWritable):
>     input_col = Param(Params._dummy(), "input_col", "Name of the input column")
>     output_col = Param(Params._dummy(), "output_col", "Name of the output column")
>     categories = Param(Params._dummy(), "categories", "Top categories")
>
>     def __init__(self, categories: dict = None, input_col="input_col", output_col="output_col"):
>         super(CustomModel, self).__init__()
>         self._set(categories=categories, input_col=input_col, output_col=output_col)
>
>     def get_output_col(self) -> str:
>         """
>         output_col getter
>         """
>         return self.getOrDefault(self.output_col)
>
>     def get_input_col(self) -> str:
>         """
>         input_col getter
>         """
>         return self.getOrDefault(self.input_col)
>
>     def get_categories(self):
>         """
>         categories getter
>         """
>         return self.getOrDefault(self.categories)
>
>     def _transform(self, data):
>         input_col = self.get_input_col()
>         output_col = self.get_output_col()
>         categories = self.get_categories()
>
>         def get_cat(val):
>             if val is None:
>                 return -1
>             if val not in categories:
>                 return -1
>             return int(categories[val])
>
>         get_cat_udf = udf(get_cat, IntegerType())
>         df = data.withColumn(output_col,
>                              get_cat_udf(input_col))
>         return df
>
>
> def test_without_write():
>     fit_df = spark.createDataFrame([[10]] * 5 + [[11]] * 4 + [[12]] * 3 + [[None]] * 2, ['input'])
>     custom_fit = CustomFit(inputCol='input', outputCol='output')
>     pipeline = Pipeline(stages=[custom_fit])
>     pipeline_model = pipeline.fit(fit_df)
>     print("Categories: {}".format(pipeline_model.stages[0].get_categories()))
>
>     transform_df = spark.createDataFrame([[10]] * 2 + [[11]] * 2 + [[12]] * 2 + [[None]] * 2, ['input'])
>     test = pipeline_model.transform(transform_df)
>     test.show()  # This output is the expected output
>
>
> def test_with_write():
>     fit_df = spark.createDataFrame([[10]] * 5 + [[11]] * 4 + [[12]] * 3 + [[None]] * 2, ['input'])
>     custom_fit = CustomFit(inputCol='input', outputCol='output')
>     pipeline = Pipeline(stages=[custom_fit])
>     pipeline_model = pipeline.fit(fit_df)
>     print("Categories: {}".format(pipeline_model.stages[0].get_categories()))
>
>     pipeline_model.write().save('tmp')
>     loaded_model = PipelineModel.load('tmp')
>     # We can see that the type of the keys is now str instead of int
>     print("Categories: {}".format(loaded_model.stages[0].get_categories()))
>
>     transform_df = spark.createDataFrame([[10]] * 2 + [[11]] * 2 + [[12]] * 2 + [[None]] * 2, ['input'])
>     test = loaded_model.transform(transform_df)
>     test.show()  # We can see that the output does not match the expected output
>
>
> if __name__ == "__main__":
>     test_without_write()
>     test_with_write()
> {code}
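>
> A likely explanation (my assumption from reading pyspark.ml.util, not a confirmed diagnosis): DefaultParamsWriter appears to persist param values as JSON metadata, and JSON object keys are always strings, so the int keys of the categories dict come back as str after the save/load round trip. The sketch below reproduces the behaviour with plain json and shows a possible workaround; the cast_keys helper is hypothetical, not part of the Spark API:
> {code:python}
> import json
>
> # JSON objects can only have string keys, so a JSON round trip
> # (which seems to be how DefaultParamsWriter stores param values)
> # turns int dict keys into str keys.
> categories = {10: 5, 11: 4}
> restored = json.loads(json.dumps(categories))
> print(restored)  # {'10': 5, '11': 4}
>
> # Hypothetical workaround: cast the keys back to int after loading,
> # e.g. inside CustomModel.get_categories().
> def cast_keys(d):
>     return {int(k): v for k, v in d.items()}
>
> print(cast_keys(restored))  # {10: 5, 11: 4}
> {code}
> With such a cast in get_categories(), the get_cat udf in _transform would again find the int input values in the dict after PipelineModel.load.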