Github user jomach commented on a diff in the pull request: https://github.com/apache/spark/pull/7842#discussion_r144642055 --- Diff: mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLTreeModelUtils.scala --- @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.pmml.export + +import scala.collection.mutable +import scala.collection.JavaConverters._ + +import org.dmg.pmml.{Node => PMMLNode, Value => PMMLValue, _} + +import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType} +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} + +private[mllib] object PMMLTreeModelUtils { + + val FieldNamePrefix = "field_" + + def toPMMLTree(dtModel: DecisionTreeModel, modelName: String): (TreeModel, List[DataField]) = { + + val miningFunctionType = dtModel.algo match { + case Algo.Classification => MiningFunctionType.CLASSIFICATION + case Algo.Regression => MiningFunctionType.REGRESSION + } + + val treeModel = new TreeModel() + .setModelName(modelName) + .setFunctionName(miningFunctionType) + .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT) + + var (rootNode, miningFields, dataFields, classes) = buildStub(dtModel.topNode, dtModel.algo) + + // adding predicted classes for classification and target field for regression for completeness + dtModel.algo match { + + case Algo.Classification => + miningFields = miningFields :+ new MiningField() + .setName(FieldName.create("class")) + .setUsageType(FieldUsageType.PREDICTED) + + val dataField = new DataField() + .setName(FieldName.create("class")) + .setOpType(OpType.CATEGORICAL) + .addValues(classes: _*) + .setDataType(DataType.DOUBLE) + + dataFields = dataFields :+ dataField + + case Algo.Regression => + val targetField = FieldName.create("target") + val dataField = new DataField(targetField, OpType.CONTINUOUS, DataType.DOUBLE) + dataFields = dataFields :+ dataField + + miningFields = miningFields :+ new MiningField() + .setName(targetField) + .setUsageType(FieldUsageType.TARGET) + + } + + val miningSchema = new MiningSchema().addMiningFields(miningFields: _*) + + treeModel.setNode(rootNode).setMiningSchema(miningSchema) + + (treeModel, dataFields) + } + + 
/** Build a pmml tree stub given the root mllib node. */ + private def buildStub(rootDTNode: Node, algo: Algo): + (PMMLNode, List[MiningField], List[DataField], List[PMMLValue]) = { + + val miningFields = mutable.MutableList[MiningField]() + val dataFields = mutable.HashMap[String, DataField]() + val classes = mutable.MutableList[Double]() + + def buildStubInternal(rootNode: Node, predicate: Predicate): PMMLNode = { + + // get rootPMML node for the MLLib node + val rootPMMLNode = new PMMLNode() + .setId(rootNode.id.toString) + .setScore(rootNode.predict.predict.toString) + .setPredicate(predicate) + + var leftPredicate: Predicate = new True() + var rightPredicate: Predicate = new True() + + if (rootNode.split.isDefined) { + val fieldName = FieldName.create(FieldNamePrefix + rootNode.split.get.feature) + val dataField = getDataField(rootNode, fieldName).get + + if (dataFields.get(dataField.getName.getValue).isEmpty) { + dataFields.put(dataField.getName.getValue, dataField) + miningFields += new MiningField() + .setName(dataField.getName) + .setUsageType(FieldUsageType.ACTIVE) + + } else if (dataField.getOpType != OpType.CONTINUOUS) { + appendCategories( + dataFields.get(dataField.getName.getValue).get, + dataField.getValues.asScala.toList) + } + + leftPredicate = getPredicate(rootNode, Some(dataField.getName), true) + rightPredicate = getPredicate(rootNode, Some(dataField.getName), false) + } + // if left node exist, add the node + if (rootNode.leftNode.isDefined) { + val leftNode = buildStubInternal(rootNode.leftNode.get, leftPredicate) + rootPMMLNode.addNodes(leftNode) + } + // if right node exist, add the node + if (rootNode.rightNode.isDefined) { + val rightNode = buildStubInternal(rootNode.rightNode.get, rightPredicate) + rootPMMLNode.addNodes(rightNode) + } + + // add to the list of classes + if (rootNode.isLeaf && (algo == Algo.Classification)) { + classes += rootNode.predict.predict + } + + rootPMMLNode + } + + val pmmlTreeRootNode = 
buildStubInternal(rootDTNode, new True()) + + val pmmlValues = classes.toList.distinct.map(doubleVal => new PMMLValue(doubleVal.toString)) + --- End diff -- Remove the blank line.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org