This is an automated email from the ASF dual-hosted git repository. weichenxu123 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 4e0bd3c5717 [SPARK-43097][FOLLOW-UP][ML] Improve logistic regression model saving 4e0bd3c5717 is described below commit 4e0bd3c571758f441d56d0341d6c4f506fdb1565 Author: Weichen Xu <weichen...@databricks.com> AuthorDate: Sat Jun 17 22:30:39 2023 +0800 [SPARK-43097][FOLLOW-UP][ML] Improve logistic regression model saving ### What changes were proposed in this pull request? Improve logistic regression model saving: the current master code saves the core PyTorch model, which includes only the "Linear" layer. To make the saved PyTorch model easier to use on its own, without pyspark, this change appends a "softmax" layer to the torch model before saving it. ### Why are the changes needed? To improve the PyTorch model saved by `LogisticRegressionModel.save`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? A unit test was added that loads the saved torch model with PyTorch alone and compares its output with the model's predicted probabilities. Closes #41629 from WeichenXu123/improve-lor-model-save. Authored-by: Weichen Xu <weichen...@databricks.com> Signed-off-by: Weichen Xu <weichen...@databricks.com> --- python/pyspark/mlv2/classification.py | 9 +++++++-- python/pyspark/mlv2/tests/test_classification.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mlv2/classification.py b/python/pyspark/mlv2/classification.py index 0fcded0d769..fe0d76837f9 100644 --- a/python/pyspark/mlv2/classification.py +++ b/python/pyspark/mlv2/classification.py @@ -331,10 +331,15 @@ class LogisticRegressionModel(PredictionModel, _LogisticRegressionParams, ModelR return self.__class__.__name__ + ".torch" def _save_core_model(self, path: str) -> None: - torch.save(self.torch_model, path) + lor_torch_model = torch_nn.Sequential( + self.torch_model, + torch_nn.Softmax(dim=1), + ) + torch.save(lor_torch_model, path) def _load_core_model(self, path: str) -> None: - self.torch_model = torch.load(path) + lor_torch_model = torch.load(path) + self.torch_model = lor_torch_model[0] def 
_get_extra_metadata(self) -> Dict[str, Any]: return { diff --git a/python/pyspark/mlv2/tests/test_classification.py b/python/pyspark/mlv2/tests/test_classification.py index 159862ef5f6..7f7d43b9cc8 100644 --- a/python/pyspark/mlv2/tests/test_classification.py +++ b/python/pyspark/mlv2/tests/test_classification.py @@ -167,10 +167,29 @@ class ClassificationTestsMixin: ) model = estimator.fit(training_dataset) + model_predictions = model.transform(eval_df1.toPandas()) + assert model.uid == estimator.uid local_model_path = os.path.join(tmp_dir, "model") model.saveToLocal(local_model_path) + + # test saved torch model can be loaded by pytorch solely + lor_torch_model = torch.load( + os.path.join(local_model_path, "LogisticRegressionModel.torch") + ) + + with torch.inference_mode(): + torch_infer_result = lor_torch_model( + torch.tensor(np.stack(list(eval_df1.toPandas().features)), dtype=torch.float32) + ).numpy() + + np.testing.assert_allclose( + np.stack(list(model_predictions.probability)), + torch_infer_result, + rtol=1e-4, + ) + loaded_model = LORV2Model.loadFromLocal(local_model_path) assert loaded_model.numFeatures == 2 assert loaded_model.numClasses == 2 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org