Merge branch 'feature/feature_selection' of https://github.com/amaya382/hivemall into feature_selection
# Conflicts: # core/src/main/java/hivemall/utils/hadoop/HiveUtils.java # core/src/main/java/hivemall/utils/math/StatsUtils.java # spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala # spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala # spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala # spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/67ba9631 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/67ba9631 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/67ba9631 Branch: refs/heads/JIRA-22/pr-385 Commit: 67ba9631af3c231b7abd145134d17237b6aca0a5 Parents: 69496fa ce4a489 Author: myui <yuin...@gmail.com> Authored: Mon Nov 21 18:19:45 2016 +0900 Committer: myui <yuin...@gmail.com> Committed: Mon Nov 21 18:19:45 2016 +0900 ---------------------------------------------------------------------- .../hivemall/ftvec/selection/ChiSquareUDF.java | 155 ++++++++ .../ftvec/selection/SignalNoiseRatioUDAF.java | 349 +++++++++++++++++++ .../hivemall/tools/array/SelectKBestUDF.java | 143 ++++++++ .../tools/matrix/TransposeAndDotUDAF.java | 213 +++++++++++ .../java/hivemall/utils/hadoop/HiveUtils.java | 22 +- .../java/hivemall/utils/math/StatsUtils.java | 91 +++++ .../ftvec/selection/ChiSquareUDFTest.java | 80 +++++ .../selection/SignalNoiseRatioUDAFTest.java | 348 ++++++++++++++++++ .../tools/array/SelectKBeatUDFTest.java | 65 ++++ .../tools/matrix/TransposeAndDotUDAFTest.java | 58 +++ resources/ddl/define-all-as-permanent.hive | 20 ++ resources/ddl/define-all.hive | 20 ++ resources/ddl/define-all.spark | 20 ++ resources/ddl/define-udfs.td.hql | 4 + .../apache/spark/sql/hive/GroupedDataEx.scala | 21 ++ .../org/apache/spark/sql/hive/HivemallOps.scala | 18 + .../spark/sql/hive/HivemallOpsSuite.scala | 100 ++++++ .../spark/sql/hive/HivemallGroupedDataset.scala | 25 ++ .../org/apache/spark/sql/hive/HivemallOps.scala | 20 ++ .../spark/sql/hive/HivemallOpsSuite.scala | 103 ++++++ 20 files changed, 1873 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --cc core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index d8b1aef,c752188..8188b7a --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@@ -242,10 -240,16 +242,20 @@@ public final class HiveUtils return category == Category.LIST; } + public static boolean isMapOI(@Nonnull final ObjectInspector oi) { + return oi.getCategory() == Category.MAP; + } + + public static boolean isNumberListOI(@Nonnull final ObjectInspector oi) { + return isListOI(oi) + && isNumberOI(((ListObjectInspector) oi).getListElementObjectInspector()); + } + + public static boolean isNumberListListOI(@Nonnull final ObjectInspector oi) { + return isListOI(oi) + && isNumberListOI(((ListObjectInspector) oi).getListElementObjectInspector()); + } + public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) { return typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/core/src/main/java/hivemall/utils/math/StatsUtils.java ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala ---------------------------------------------------------------------- diff --cc spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala index fd4da64,2482c62..8f78a7f --- a/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala +++ b/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/GroupedDataEx.scala @@@ -267,13 -266,25 +267,34 @@@ final class GroupedDataEx protected[sql } /** + * @see hivemall.ftvec.trans.OnehotEncodingUDAF + */ + def onehot_encoding(features: String*): DataFrame = { + val udaf = HiveUDAFFunction( + new HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"), + features.map(df.col(_).expr), + isUDAFBridgeRequired = false) ++ ++ /** + * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF + */ + def snr(X: String, Y: String): DataFrame = { + val udaf = HiveUDAFFunction( + new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"), + Seq(X, Y).map(df.col(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Seq(Alias(udaf, udaf.prettyString)())) + } + + /** + * @see hivemall.tools.matrix.TransposeAndDotUDAF + */ + def transpose_and_dot(X: String, Y: String): DataFrame = { + val udaf = HiveUDAFFunction( + new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"), + Seq(X, Y).map(df.col(_).expr), + isUDAFBridgeRequired = false) .toAggregateExpression() toDF(Seq(Alias(udaf, udaf.prettyString)())) } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-1.6/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala ---------------------------------------------------------------------- diff --cc spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala index 901056d,c7016c0..c231105 --- a/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala +++ b/spark/spark-1.6/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@@ -534,30 -570,63 +575,89 @@@ final class HivemallOpsSuite extends Hi assert(row4(0).getDouble(1) ~== 0.25) } + test("user-defined aggregators for ftvec.trans") { + import hiveContext.implicits._ + + val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10), + (1, "seahawk", "bird", 101), (1, "wasp", "insect", 3), (1, "wasp", "insect", 9), + (1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9)) + .toDF("col0", "cat1", "cat2", "cat3") + + val row00 = df0.groupby($"col0").onehot_encoding("cat1") + val row01 = df0.groupby($"col0").onehot_encoding("cat1", "cat2", "cat3") + + val result000 = row00.collect()(0).getAs[Row](1).getAs[Map[String, Int]](0) + val result01 = row01.collect()(0).getAs[Row](1) + val result010 = result01.getAs[Map[String, Int]](0) + val result011 = result01.getAs[Map[String, Int]](1) + val result012 = result01.getAs[Map[String, Int]](2) + + assert(result000.keySet === Set("seahawk", "cat", "human", "wasp", "dog")) + assert(result000.values.toSet === Set(1, 2, 3, 4, 5)) + assert(result010.keySet === Set("seahawk", "cat", "human", "wasp", "dog")) + assert(result010.values.toSet === Set(1, 2, 3, 4, 5)) + assert(result011.keySet === Set("bird", "insect", "mammal")) + assert(result011.values.toSet === Set(6, 7, 8)) + assert(result012.keySet === Set(1, 3, 9, 10, 101)) + assert(result012.values.toSet === Set(9, 10, 11, 12, 13)) ++ + test("user-defined aggregators for ftvec.selection") { + import hiveContext.implicits._ + + // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest + // binary class + // +-----------------+-------+ + // | features | class | + // +-----------------+-------+ + // | 5.1,3.5,1.4,0.2 | 0 | + // | 4.9,3.0,1.4,0.2 | 0 | + // | 4.7,3.2,1.3,0.2 | 0 | + // | 7.0,3.2,4.7,1.4 | 1 | + // | 6.4,3.2,4.5,1.5 | 1 | + // | 6.9,3.1,4.9,1.5 | 1 | + // +-----------------+-------+ + val df0 = Seq( + (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)), + (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)), + (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1))) + .toDF("c0", "arg0", "arg1") + val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect + (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + + // multiple class + // +-----------------+-------+ + // | features | class | + // +-----------------+-------+ + // | 5.1,3.5,1.4,0.2 | 0 | + // | 4.9,3.0,1.4,0.2 | 0 | + // | 7.0,3.2,4.7,1.4 | 1 | + // | 6.4,3.2,4.5,1.5 | 1 | + // | 6.3,3.3,6.0,2.5 | 2 | + // | 5.8,2.7,5.1,1.9 | 2 | + // +-----------------+-------+ + val df1 = Seq( + (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)), + (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)), + (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1))) + .toDF("c0", "arg0", "arg1") + val row1 = df1.groupby($"c0").snr("arg0", "arg1").collect + (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + } + + test("user-defined aggregators for tools.matrix") { + import hiveContext.implicits._ + + // | 1 2 3 |T | 5 6 7 | + // | 3 4 5 | * | 7 8 9 | + val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9))) + .toDF("c0", "arg0", "arg1") + + // if use checkAnswer here, fail for some reason, maybe type? but it's okay on spark-2.0 + assert(df0.groupby($"c0").transpose_and_dot("arg0", "arg1").collect() === + Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))))) } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala ---------------------------------------------------------------------- diff --cc spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala index 8ac7185,0000000..73757f6 mode 100644,000000..100644 --- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala +++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallGroupedDataset.scala @@@ -1,277 -1,0 +1,302 @@@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.hive + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.RelationalGroupedDataset +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.plans.logical.Pivot +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper +import org.apache.spark.sql.types._ + +/** + * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * + * @groupname ensemble + * @groupname ftvec.trans + * @groupname evaluation + */ +final class HivemallGroupedDataset(groupBy: RelationalGroupedDataset) { + + /** + * @see hivemall.ensemble.bagging.VotedAvgUDAF + * @group ensemble + */ + def voted_avg(weight: String): DataFrame = { + // checkType(weight, NumericType) + val udaf = HiveUDAFFunction( + "voted_avg", + new HiveFunctionWrapper("hivemall.ensemble.bagging.WeightVotedAvgUDAF"), + Seq(weight).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.ensemble.bagging.WeightVotedAvgUDAF + * @group ensemble + */ + def weight_voted_avg(weight: String): DataFrame = { + // checkType(weight, NumericType) + val udaf = HiveUDAFFunction( + "weight_voted_avg", + new HiveFunctionWrapper("hivemall.ensemble.bagging.WeightVotedAvgUDAF"), + Seq(weight).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.ensemble.ArgminKLDistanceUDAF + * @group ensemble + */ + def argmin_kld(weight: String, conv: String): DataFrame = { + // checkType(weight, NumericType) + // checkType(conv, NumericType) + val udaf = HiveUDAFFunction( + "argmin_kld", + new HiveFunctionWrapper("hivemall.ensemble.ArgminKLDistanceUDAF"), + Seq(weight, conv).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.ensemble.MaxValueLabelUDAF" + * @group ensemble + */ + def max_label(score: String, label: String): DataFrame = { + // checkType(score, NumericType) + checkType(label, StringType) + val udaf = HiveUDAFFunction( + "max_label", + new HiveFunctionWrapper("hivemall.ensemble.MaxValueLabelUDAF"), + Seq(score, label).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.ensemble.MaxRowUDAF + * @group ensemble + */ + def maxrow(score: String, label: String): DataFrame = { + // checkType(score, NumericType) + checkType(label, StringType) + val udaf = HiveUDAFFunction( + "maxrow", + new HiveFunctionWrapper("hivemall.ensemble.MaxRowUDAF"), + Seq(score, label).map(df.col(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.smile.tools.RandomForestEnsembleUDAF + * @group ensemble + */ + def rf_ensemble(predict: String): DataFrame = { + // checkType(predict, NumericType) + val udaf = HiveUDAFFunction( + "rf_ensemble", + new HiveFunctionWrapper("hivemall.smile.tools.RandomForestEnsembleUDAF"), + Seq(predict).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.ftvec.trans.OnehotEncodingUDAF + * @group ftvec.trans + */ + def onehot_encoding(cols: String*): DataFrame = { + val udaf = HiveUDAFFunction( + "onehot_encoding", + new HiveFunctionWrapper("hivemall.ftvec.trans.OnehotEncodingUDAF"), + cols.map(df.col(_).expr), + isUDAFBridgeRequired = false) + .toAggregateExpression() + toDF(Seq(Alias(udaf, udaf.prettyName)())) + } + + /** + * @see hivemall.evaluation.MeanAbsoluteErrorUDAF + * @group evaluation + */ + def mae(predict: String, target: String): DataFrame = { + checkType(predict, FloatType) + checkType(target, FloatType) + val udaf = HiveUDAFFunction( + "mae", + new HiveFunctionWrapper("hivemall.evaluation.MeanAbsoluteErrorUDAF"), + Seq(predict, target).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.evaluation.MeanSquareErrorUDAF + * @group evaluation + */ + def mse(predict: String, target: String): DataFrame = { + checkType(predict, FloatType) + checkType(target, FloatType) + val udaf = HiveUDAFFunction( + "mse", + new HiveFunctionWrapper("hivemall.evaluation.MeanSquaredErrorUDAF"), + Seq(predict, target).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.evaluation.RootMeanSquareErrorUDAF + * @group evaluation + */ + def rmse(predict: String, target: String): DataFrame = { + checkType(predict, FloatType) + checkType(target, FloatType) + val udaf = HiveUDAFFunction( + "rmse", + new HiveFunctionWrapper("hivemall.evaluation.RootMeanSquaredErrorUDAF"), + Seq(predict, target).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * @see hivemall.evaluation.FMeasureUDAF + * @group evaluation + */ + def f1score(predict: String, target: String): DataFrame = { + // checkType(target, ArrayType(IntegerType)) + // checkType(predict, ArrayType(IntegerType)) + val udaf = HiveUDAFFunction( + "f1score", + new HiveFunctionWrapper("hivemall.evaluation.FMeasureUDAF"), + Seq(predict, target).map(df.col(_).expr), + isUDAFBridgeRequired = true) + .toAggregateExpression() + toDF((Alias(udaf, udaf.prettyName)() :: Nil).toSeq) + } + + /** + * [[RelationalGroupedDataset]] has the three values as private fields, so, to inject Hivemall + * aggregate functions, we fetch them via Java Reflections. + */ + private val df = getPrivateField[DataFrame]("org$apache$spark$sql$RelationalGroupedDataset$$df") + private val groupingExprs = getPrivateField[Seq[Expression]]("groupingExprs") + private val groupType = getPrivateField[RelationalGroupedDataset.GroupType]("groupType") + + private def getPrivateField[T](name: String): T = { + val field = groupBy.getClass.getDeclaredField(name) + field.setAccessible(true) + field.get(groupBy).asInstanceOf[T] + } + + private def toDF(aggExprs: Seq[Expression]): DataFrame = { + val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { + groupingExprs ++ aggExprs + } else { + aggExprs + } + + val aliasedAgg = aggregates.map(alias) + + groupType match { + case RelationalGroupedDataset.GroupByType => + Dataset.ofRows( + df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.RollupType => + Dataset.ofRows( + df.sparkSession, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.CubeType => + Dataset.ofRows( + df.sparkSession, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) + case RelationalGroupedDataset.PivotType(pivotCol, values) => + val aliasedGrps = groupingExprs.map(alias) + Dataset.ofRows( + df.sparkSession, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan)) + } + } + + private def alias(expr: Expression): NamedExpression = expr match { + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyName)() + } + + private def checkType(colName: String, expected: DataType) = { + val dataType = df.resolve(colName).dataType + if (dataType != expected) { + throw new AnalysisException( + s""""$colName" must be $expected, however it is $dataType""") + } + } +} + +object HivemallGroupedDataset { + + /** + * Implicitly inject the [[HivemallGroupedDataset]] into [[RelationalGroupedDataset]]. + */ + implicit def relationalGroupedDatasetToHivemallOne( + groupBy: RelationalGroupedDataset): HivemallGroupedDataset = { + new HivemallGroupedDataset(groupBy) ++ ++ /** ++ * @see hivemall.ftvec.selection.SignalNoiseRatioUDAF ++ */ ++ def snr(X: String, Y: String): DataFrame = { ++ val udaf = HiveUDAFFunction( ++ "snr", ++ new HiveFunctionWrapper("hivemall.ftvec.selection.SignalNoiseRatioUDAF"), ++ Seq(X, Y).map(df.col(_).expr), ++ isUDAFBridgeRequired = false) ++ .toAggregateExpression() ++ toDF(Seq(Alias(udaf, udaf.prettyName)())) ++ } ++ ++ /** ++ * @see hivemall.tools.matrix.TransposeAndDotUDAF ++ */ ++ def transpose_and_dot(X: String, Y: String): DataFrame = { ++ val udaf = HiveUDAFFunction( ++ "transpose_and_dot", ++ new HiveFunctionWrapper("hivemall.tools.matrix.TransposeAndDotUDAF"), ++ Seq(X, Y).map(df.col(_).expr), ++ isUDAFBridgeRequired = false) ++ .toAggregateExpression() ++ toDF(Seq(Alias(udaf, udaf.prettyName)())) + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/67ba9631/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala ---------------------------------------------------------------------- diff --cc spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala index a093e07,8446677..8bea975 --- a/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala +++ b/spark/spark-2.0/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@@ -1,31 -1,28 +1,37 @@@ /* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. */ - package org.apache.spark.sql.hive +import org.apache.spark.sql.{AnalysisException, Column, Row} +import org.apache.spark.sql.functions +import org.apache.spark.sql.hive.HivemallGroupedDataset._ +import org.apache.spark.sql.hive.HivemallOps._ +import org.apache.spark.sql.hive.HivemallUtils._ +import org.apache.spark.sql.types._ +import org.apache.spark.test.{HivemallFeatureQueryTest, TestUtils, VectorQueryTest} +import org.apache.spark.test.TestDoubleWrapper._ + import org.apache.spark.sql.hive.HivemallOps._ + import org.apache.spark.sql.hive.HivemallUtils._ + import org.apache.spark.sql.types._ + import org.apache.spark.sql.{AnalysisException, Column, Row, functions} + import org.apache.spark.test.TestDoubleWrapper._ + import org.apache.spark.test.{HivemallFeatureQueryTest, TestUtils, VectorQueryTest} final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { @@@ -636,30 -685,63 +681,88 @@@ assert(row4(0).getDouble(1) ~== 0.25) } + test("user-defined aggregators for ftvec.trans") { + import hiveContext.implicits._ + + val df0 = Seq((1, "cat", "mammal", 9), (1, "dog", "mammal", 10), (1, "human", "mammal", 10), + (1, "seahawk", "bird", 101), (1, "wasp", "insect", 3), (1, "wasp", "insect", 9), + (1, "cat", "mammal", 101), (1, "dog", "mammal", 1), (1, "human", "mammal", 9)) + .toDF("col0", "cat1", "cat2", "cat3") + val row00 = df0.groupBy($"col0").onehot_encoding("cat1") + val row01 = df0.groupBy($"col0").onehot_encoding("cat1", "cat2", "cat3") + + val result000 = row00.collect()(0).getAs[Row](1).getAs[Map[String, Int]](0) + val result01 = row01.collect()(0).getAs[Row](1) + val result010 = result01.getAs[Map[String, Int]](0) + val result011 = result01.getAs[Map[String, Int]](1) + val result012 = result01.getAs[Map[String, Int]](2) + + assert(result000.keySet === Set("seahawk", "cat", "human", "wasp", "dog")) + assert(result000.values.toSet === Set(1, 2, 3, 4, 5)) + assert(result010.keySet === Set("seahawk", "cat", "human", "wasp", "dog")) + assert(result010.values.toSet === Set(1, 2, 3, 4, 5)) + assert(result011.keySet === Set("bird", "insect", "mammal")) + assert(result011.values.toSet === Set(6, 7, 8)) + assert(result012.keySet === Set(1, 3, 9, 10, 101)) + assert(result012.values.toSet === Set(9, 10, 11, 12, 13)) ++ + test("user-defined aggregators for ftvec.selection") { + import hiveContext.implicits._ + + // see also hivemall.ftvec.selection.SignalNoiseRatioUDAFTest + // binary class + // +-----------------+-------+ + // | features | class | + // +-----------------+-------+ + // | 5.1,3.5,1.4,0.2 | 0 | + // | 4.9,3.0,1.4,0.2 | 0 | + // | 4.7,3.2,1.3,0.2 | 0 | + // | 7.0,3.2,4.7,1.4 | 1 | + // | 6.4,3.2,4.5,1.5 | 1 | + // | 6.9,3.1,4.9,1.5 | 1 | + // +-----------------+-------+ + val df0 = Seq( + (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0)), + (1, Seq(4.7, 3.2, 1.3, 0.2), Seq(1, 0)), (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1)), + (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1)), (1, Seq(6.9, 3.1, 4.9, 1.5), Seq(0, 1))) + .toDF("c0", "arg0", "arg1") + val row0 = df0.groupby($"c0").snr("arg0", "arg1").collect + (row0(0).getAs[Seq[Double]](1), Seq(4.38425236, 0.26390002, 15.83984511, 26.87005769)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + + // multiple class + // +-----------------+-------+ + // | features | class | + // +-----------------+-------+ + // | 5.1,3.5,1.4,0.2 | 0 | + // | 4.9,3.0,1.4,0.2 | 0 | + // | 7.0,3.2,4.7,1.4 | 1 | + // | 6.4,3.2,4.5,1.5 | 1 | + // | 6.3,3.3,6.0,2.5 | 2 | + // | 5.8,2.7,5.1,1.9 | 2 | + // +-----------------+-------+ + val df1 = Seq( + (1, Seq(5.1, 3.5, 1.4, 0.2), Seq(1, 0, 0)), (1, Seq(4.9, 3.0, 1.4, 0.2), Seq(1, 0, 0)), + (1, Seq(7.0, 3.2, 4.7, 1.4), Seq(0, 1, 0)), (1, Seq(6.4, 3.2, 4.5, 1.5), Seq(0, 1, 0)), + (1, Seq(6.3, 3.3, 6.0, 2.5), Seq(0, 0, 1)), (1, Seq(5.8, 2.7, 5.1, 1.9), Seq(0, 0, 1))) + .toDF("c0", "arg0", "arg1") + val row1 = df1.groupby($"c0").snr("arg0", "arg1").collect + (row1(0).getAs[Seq[Double]](1), Seq(8.43181818, 1.32121212, 42.94949495, 33.80952381)) + .zipped + .foreach((actual, expected) => assert(actual ~== expected)) + } + + test("user-defined aggregators for tools.matrix") { + import hiveContext.implicits._ + + // | 1 2 3 |T | 5 6 7 | + // | 3 4 5 | * | 7 8 9 | + val df0 = Seq((1, Seq(1, 2, 3), Seq(5, 6, 7)), (1, Seq(3, 4, 5), Seq(7, 8, 9))) + .toDF("c0", "arg0", "arg1") + + checkAnswer(df0.groupby($"c0").transpose_and_dot("arg0", "arg1"), + Seq(Row(1, Seq(Seq(26.0, 30.0, 34.0), Seq(38.0, 44.0, 50.0), Seq(50.0, 58.0, 66.0))))) } }