Repository: ignite Updated Branches: refs/heads/master e539dfc19 -> 85bfcc733
IGNITE-10803: [ML] Add prototype LogReg loading from PMML format This closes #5744 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/85bfcc73 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/85bfcc73 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/85bfcc73 Branch: refs/heads/master Commit: 85bfcc7331d013d8738ab71b8564a67c84e45cf6 Parents: e539dfc Author: zaleslaw <zaleslaw....@gmail.com> Authored: Thu Dec 27 16:17:00 2018 +0300 Committer: Yury Babak <yba...@gridgain.com> Committed: Thu Dec 27 16:17:00 2018 +0300 ---------------------------------------------------------------------- examples/pom.xml | 25 +++++ .../LogRegFromSparkThroughPMMLExample.java | 108 +++++++++++++++++++ .../src/main/resources/models/spark/iris.pmml | 30 ++++++ 3 files changed, 163 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/85bfcc73/examples/pom.xml ---------------------------------------------------------------------- diff --git a/examples/pom.xml b/examples/pom.xml index 429ec79..6320a0f 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -122,6 +122,31 @@ <version>${javassist.version}</version> <scope>test</scope> </dependency> + <!-- https://mvnrepository.com/artifact/org.jpmml/pmml-model --> + <dependency> + <groupId>org.jpmml</groupId> + <artifactId>pmml-model</artifactId> + <version>1.4.7</version> + </dependency> + + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-core</artifactId> + <version>2.7.3</version> + </dependency> + + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-databind</artifactId> + <version>2.7.3</version> + </dependency> + + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-annotations</artifactId> + <version>2.7.3</version> + </dependency> + </dependencies> <properties> http://git-wip-us.apache.org/repos/asf/ignite/blob/85bfcc73/examples/src/main/java/org/apache/ignite/examples/ml/inference/LogRegFromSparkThroughPMMLExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/LogRegFromSparkThroughPMMLExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/LogRegFromSparkThroughPMMLExample.java new file mode 100644 index 0000000..30a4498 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/LogRegFromSparkThroughPMMLExample.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.inference; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import javax.xml.bind.JAXBException; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.ml.math.primitives.vector.Vector; +import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector; +import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel; +import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator; +import org.apache.ignite.ml.selection.scoring.metric.Accuracy; +import org.apache.ignite.ml.util.MLSandboxDatasets; +import org.apache.ignite.ml.util.SandboxMLCache; +import org.dmg.pmml.PMML; +import org.dmg.pmml.regression.RegressionModel; +import org.dmg.pmml.regression.RegressionTable; +import org.jpmml.model.PMMLUtil; +import org.xml.sax.SAXException; + +/** + * Run logistic regression model loaded from PMML file. The PMML file was generated by Spark MLLib toPMML operator. + * <p> + * Code in this example launches Ignite grid and fills the cache with test data points (based on the + * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set"></a>Iris dataset</a>).</p> + * <p> + * You can change the test data used in this example and re-run it to explore this algorithm further.</p> + */ +public class LogRegFromSparkThroughPMMLExample { + /** Run example. */ + public static void main(String[] args) throws FileNotFoundException { + System.out.println(); + System.out.println(">>> Logistic regression model loaded from PMML over partitioned dataset usage example started."); + // Start ignite grid. + try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite) + .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS); + + LogisticRegressionModel mdl = PMMLParser.load("examples/src/main/resources/models/spark/iris.pmml"); + + System.out.println(">>> Logistic regression model: " + mdl); + + double accuracy = BinaryClassificationEvaluator.evaluate( + dataCache, + mdl, + (k, v) -> v.copyOfRange(1, v.size()), + (k, v) -> v.get(0), + new Accuracy<>() + ); + + System.out.println("\n>>> Accuracy " + accuracy); + System.out.println("\n>>> Test Error " + (1 - accuracy)); + } + } + + /** Util class to build the LogReg model. */ + private static class PMMLParser { + /** + * @param path Path. + */ + public static LogisticRegressionModel load(String path) { + try (InputStream is = new FileInputStream(new File(path))) { + PMML pmml = PMMLUtil.unmarshal(is); + + RegressionModel logRegMdl = (RegressionModel)pmml.getModels().get(0); + + RegressionTable regTbl = logRegMdl.getRegressionTables().get(0); + + Vector coefficients = new DenseVector(regTbl.getNumericPredictors().size()); + + for (int i = 0; i < regTbl.getNumericPredictors().size(); i++) + coefficients.set(i, regTbl.getNumericPredictors().get(i).getCoefficient()); + + double interceptor = regTbl.getIntercept(); + + return new LogisticRegressionModel(coefficients, interceptor); + } + catch (IOException | JAXBException | SAXException e) { + e.printStackTrace(); + } + + return null; + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/85bfcc73/examples/src/main/resources/models/spark/iris.pmml ---------------------------------------------------------------------- diff --git a/examples/src/main/resources/models/spark/iris.pmml b/examples/src/main/resources/models/spark/iris.pmml new file mode 100644 index 0000000..78f310d --- /dev/null +++ b/examples/src/main/resources/models/spark/iris.pmml @@ -0,0 +1,30 @@ +<?xml version="1.0" encoding="UTF-8" standalone="yes"?> +<PMML xmlns="http://www.dmg.org/PMML-4_2" version="4.2"> + <Header description="logistic regression"> + <Application name="Apache Spark MLlib" version="2.2.0"/> + <Timestamp>2018-12-25T15:09:09</Timestamp> + </Header> + <DataDictionary numberOfFields="5"> + <DataField name="field_0" optype="continuous" dataType="double"/> + <DataField name="field_1" optype="continuous" dataType="double"/> + <DataField name="field_2" optype="continuous" dataType="double"/> + <DataField name="field_3" optype="continuous" dataType="double"/> + <DataField name="target" optype="categorical" dataType="string"/> + </DataDictionary> + <RegressionModel modelName="logistic regression" functionName="classification" normalizationMethod="logit"> + <MiningSchema> + <MiningField name="field_0" usageType="active"/> + <MiningField name="field_1" usageType="active"/> + <MiningField name="field_2" usageType="active"/> + <MiningField name="field_3" usageType="active"/> + <MiningField name="target" usageType="target"/> + </MiningSchema> + <RegressionTable intercept="0.0" targetCategory="1"> + <NumericPredictor name="field_0" coefficient="5.84520630732407"/> + <NumericPredictor name="field_1" coefficient="-19.36222130270906"/> + <NumericPredictor name="field_2" coefficient="5.66074235971065"/> + <NumericPredictor name="field_3" coefficient="16.110585062151788"/> + </RegressionTable> + <RegressionTable intercept="-0.0" targetCategory="0"/> + </RegressionModel> +</PMML>