DRILL-842: Aggregate function for correlation coefficient calculation
Project: http://git-wip-us.apache.org/repos/asf/incubator-drill/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-drill/commit/e170cf0b Tree: http://git-wip-us.apache.org/repos/asf/incubator-drill/tree/e170cf0b Diff: http://git-wip-us.apache.org/repos/asf/incubator-drill/diff/e170cf0b Branch: refs/heads/master Commit: e170cf0b302e8f6b625e0f96a66d70da680c594e Parents: 7fe8a15 Author: Yash Sharma <[email protected]> Authored: Tue Jun 17 16:25:01 2014 -0700 Committer: Jacques Nadeau <[email protected]> Committed: Tue Jun 17 16:25:01 2014 -0700 ---------------------------------------------------------------------- exec/java-exec/src/main/codegen/config.fmpp | 6 +- .../src/main/codegen/data/CorrelationTypes.tdd | 43 +++++ .../src/main/codegen/data/CovarTypes.tdd | 66 ++++++++ .../templates/CorrelationTypeFunctions.java | 156 +++++++++++++++++++ .../codegen/templates/CovarTypeFunctions.java | 149 ++++++++++++++++++ .../exec/fn/impl/TestAggregateFunction.java | 39 +++-- .../src/test/resources/covariance_input.json | 8 + .../resources/functions/test_covariance.json | 85 ++++++++++ 8 files changed, 534 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-drill/blob/e170cf0b/exec/java-exec/src/main/codegen/config.fmpp ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/main/codegen/config.fmpp b/exec/java-exec/src/main/codegen/config.fmpp index 5c0d03b..d00e24a 100644 --- a/exec/java-exec/src/main/codegen/config.fmpp +++ b/exec/java-exec/src/main/codegen/config.fmpp @@ -25,6 +25,8 @@ data: { decimalaggrtypes1: tdd(../data/DecimalAggrTypes1.tdd), aggrtypes2: tdd(../data/AggrTypes2.tdd), aggrtypes3: tdd(../data/AggrTypes3.tdd), + covarTypes: tdd(../data/CovarTypes.tdd), + corrTypes: tdd(../data/CorrelationTypes.tdd), date: tdd(../data/DateTypes.tdd), extract: tdd(../data/ExtractTypes.tdd), parser: 
tdd(../data/Parser.tdd), @@ -36,6 +38,4 @@ data: { } freemarkerLinks: { includes: includes/ -} - - +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-drill/blob/e170cf0b/exec/java-exec/src/main/codegen/data/CorrelationTypes.tdd ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/main/codegen/data/CorrelationTypes.tdd b/exec/java-exec/src/main/codegen/data/CorrelationTypes.tdd new file mode 100644 index 0000000..cc6d2a5 --- /dev/null +++ b/exec/java-exec/src/main/codegen/data/CorrelationTypes.tdd @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +{ + correlationTypes: [ + {className: "Correlation", funcName: "corr", aliasName: "correlation", types: [ + {inputType: "BigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableBigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "Int", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "SmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableSmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "TinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableTinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt4", outputType: "Float8", movingAverageType: "Float8", 
movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "Float4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableFloat4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "Float8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableFloat8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"} + ] + } + ] +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-drill/blob/e170cf0b/exec/java-exec/src/main/codegen/data/CovarTypes.tdd ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/main/codegen/data/CovarTypes.tdd b/exec/java-exec/src/main/codegen/data/CovarTypes.tdd new file mode 100644 index 0000000..d0ec489 --- /dev/null +++ b/exec/java-exec/src/main/codegen/data/CovarTypes.tdd @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +{ + covarianceTypes: [ + {className: "CoVarianceSample", funcName: "covar_samp", aliasName: "covariance", types: [ + {inputType: "BigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableBigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "Int", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "SmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableSmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "TinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableTinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt2", outputType: "Float8", movingAverageType: "Float8", 
movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "Float4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableFloat4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "Float8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableFloat8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"} + ] + }, + {className: "CoVariancePopulation", funcName: "covar_pop", aliasName: "", types: [ + {inputType: "BigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableBigInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "Int", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "SmallInt", outputType: 
"Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableSmallInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "TinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableTinyInt", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt1", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt2", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "UInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableUInt8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "Float4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableFloat4", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "Float8", outputType: "Float8", movingAverageType: "Float8", 
movingDeviationType: "Float8", countRunningType: "BigInt"}, + {inputType: "NullableFloat8", outputType: "Float8", movingAverageType: "Float8", movingDeviationType: "Float8", countRunningType: "BigInt"} + ] + } + ] +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-drill/blob/e170cf0b/exec/java-exec/src/main/codegen/templates/CorrelationTypeFunctions.java ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/main/codegen/templates/CorrelationTypeFunctions.java b/exec/java-exec/src/main/codegen/templates/CorrelationTypeFunctions.java new file mode 100644 index 0000000..19f9c59 --- /dev/null +++ b/exec/java-exec/src/main/codegen/templates/CorrelationTypeFunctions.java @@ -0,0 +1,156 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +<@pp.dropOutputFile /> + + + +<#list corrTypes.correlationTypes as aggrtype> +<@pp.changeOutputFile name="/org/apache/drill/exec/expr/fn/impl/gaggr/${aggrtype.className}Functions.java" /> + +<#include "/@includes/license.ftl" /> + +<#-- A utility class that is used to generate java code for corr/correlation aggr functions --> + +/* + * This class is automatically generated from CorrelationTypes.tdd using FreeMarker. + */ + +package org.apache.drill.exec.expr.fn.impl.gaggr; + +import org.apache.drill.exec.expr.DrillAggFunc; +import org.apache.drill.exec.expr.annotations.FunctionTemplate; +import org.apache.drill.exec.expr.annotations.FunctionTemplate.NullHandling; +import org.apache.drill.exec.expr.annotations.FunctionTemplate.FunctionScope; +import org.apache.drill.exec.expr.annotations.Output; +import org.apache.drill.exec.expr.annotations.Param; +import org.apache.drill.exec.expr.annotations.Workspace; +import org.apache.drill.exec.expr.holders.BigIntHolder; +import org.apache.drill.exec.expr.holders.NullableBigIntHolder; +import org.apache.drill.exec.expr.holders.IntHolder; +import org.apache.drill.exec.expr.holders.NullableIntHolder; +import org.apache.drill.exec.expr.holders.SmallIntHolder; +import org.apache.drill.exec.expr.holders.NullableSmallIntHolder; +import org.apache.drill.exec.expr.holders.TinyIntHolder; +import org.apache.drill.exec.expr.holders.NullableTinyIntHolder; +import org.apache.drill.exec.expr.holders.UInt1Holder; +import org.apache.drill.exec.expr.holders.NullableUInt1Holder; +import org.apache.drill.exec.expr.holders.UInt2Holder; +import org.apache.drill.exec.expr.holders.NullableUInt2Holder; +import org.apache.drill.exec.expr.holders.UInt4Holder; +import org.apache.drill.exec.expr.holders.NullableUInt4Holder; +import org.apache.drill.exec.expr.holders.UInt8Holder; +import org.apache.drill.exec.expr.holders.NullableUInt8Holder; +import org.apache.drill.exec.record.RecordBatch; +import 
org.apache.drill.exec.expr.holders.NullableFloat8Holder; +import org.apache.drill.exec.expr.holders.NullableFloat4Holder; +import org.apache.drill.exec.expr.holders.Float8Holder; +import org.apache.drill.exec.expr.holders.Float4Holder; + +@SuppressWarnings("unused") + +public class ${aggrtype.className}Functions { + static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(${aggrtype.className}Functions.class); + +<#list aggrtype.types as type> + +@FunctionTemplate(names = {"${aggrtype.funcName}", "${aggrtype.aliasName}"}, scope = FunctionTemplate.FunctionScope.POINT_AGGREGATE) +public static class ${type.inputType}${aggrtype.className} implements DrillAggFunc{ + + @Param ${type.inputType}Holder xIn; + @Param ${type.inputType}Holder yIn; + + @Workspace ${type.movingAverageType}Holder xMean; + @Workspace ${type.movingAverageType}Holder yMean; + @Workspace ${type.movingAverageType}Holder xyMean; + + @Workspace ${type.movingAverageType}Holder xDev; + @Workspace ${type.movingAverageType}Holder yDev; + + @Workspace ${type.movingDeviationType}Holder covar; + + @Workspace ${type.countRunningType}Holder count; + @Output ${type.outputType}Holder out; + + public void setup(RecordBatch b) { + xMean = new ${type.movingAverageType}Holder(); + yMean = new ${type.movingAverageType}Holder(); + xyMean = new ${type.movingDeviationType}Holder(); + xDev = new ${type.movingDeviationType}Holder(); + yDev = new ${type.movingDeviationType}Holder(); + count = new ${type.countRunningType}Holder(); + covar = new ${type.movingDeviationType}Holder(); + + // Initialize the workspace variables + xMean.value = 0; + yMean.value = 0; + xyMean.value = 0; + xDev.value = 0; + yDev.value = 0; + count.value = 1; + covar.value = 0; + } + + @Override + public void add() { + <#if type.inputType?starts_with("Nullable")> + sout: { + if (xIn.isSet == 0 || yIn.isSet == 0) { + // processing nullable input and the value is null, so don't do anything... 
+ break sout; + } + </#if> + + // compute covariance + double xOldMean = xMean.value, yOldMean = yMean.value; + + xMean.value += ((xIn.value - xMean.value) / count.value); + yMean.value += ((yIn.value - yMean.value) / count.value); + + xDev.value += (xIn.value - xOldMean) * (xIn.value - xMean.value); + yDev.value += (yIn.value - yOldMean) * (yIn.value - yMean.value); + + xyMean.value += ((xIn.value * yIn.value) - xyMean.value) / count.value; + count.value++; + <#if type.inputType?starts_with("Nullable")> + } // end of sout block + </#if> + } + + @Override + public void output() { + double xVariance = (xDev.value / (count.value - 1)); + double yVariance = (yDev.value / (count.value - 1)); + double xyCovariance = (xyMean.value - (xMean.value * yMean.value)); + + out.value = xyCovariance / Math.sqrt((xVariance * yVariance)); + } + + @Override + public void reset() { + xMean.value = 0; + yMean.value = 0; + xyMean.value = 0; + count.value = 1; + covar.value = 0; + } +} + + +</#list> +} +</#list> \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-drill/blob/e170cf0b/exec/java-exec/src/main/codegen/templates/CovarTypeFunctions.java ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/main/codegen/templates/CovarTypeFunctions.java b/exec/java-exec/src/main/codegen/templates/CovarTypeFunctions.java new file mode 100644 index 0000000..b8131c2 --- /dev/null +++ b/exec/java-exec/src/main/codegen/templates/CovarTypeFunctions.java @@ -0,0 +1,149 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +<@pp.dropOutputFile /> + + + +<#list covarTypes.covarianceTypes as aggrtype> +<@pp.changeOutputFile name="/org/apache/drill/exec/expr/fn/impl/gaggr/${aggrtype.className}Functions.java" /> + +<#include "/@includes/license.ftl" /> + +<#-- A utility class that is used to generate java code for covariance functions --> + +/* + * This class is automatically generated from CovarTypes.tdd using FreeMarker. + */ + +package org.apache.drill.exec.expr.fn.impl.gaggr; + +import org.apache.drill.exec.expr.DrillAggFunc; +import org.apache.drill.exec.expr.annotations.FunctionTemplate; +import org.apache.drill.exec.expr.annotations.FunctionTemplate.NullHandling; +import org.apache.drill.exec.expr.annotations.FunctionTemplate.FunctionScope; +import org.apache.drill.exec.expr.annotations.Output; +import org.apache.drill.exec.expr.annotations.Param; +import org.apache.drill.exec.expr.annotations.Workspace; +import org.apache.drill.exec.expr.holders.BigIntHolder; +import org.apache.drill.exec.expr.holders.NullableBigIntHolder; +import org.apache.drill.exec.expr.holders.IntHolder; +import org.apache.drill.exec.expr.holders.NullableIntHolder; +import org.apache.drill.exec.expr.holders.SmallIntHolder; +import org.apache.drill.exec.expr.holders.NullableSmallIntHolder; +import org.apache.drill.exec.expr.holders.TinyIntHolder; +import org.apache.drill.exec.expr.holders.NullableTinyIntHolder; +import org.apache.drill.exec.expr.holders.UInt1Holder; +import org.apache.drill.exec.expr.holders.NullableUInt1Holder; +import org.apache.drill.exec.expr.holders.UInt2Holder; +import 
org.apache.drill.exec.expr.holders.NullableUInt2Holder; +import org.apache.drill.exec.expr.holders.UInt4Holder; +import org.apache.drill.exec.expr.holders.NullableUInt4Holder; +import org.apache.drill.exec.expr.holders.UInt8Holder; +import org.apache.drill.exec.expr.holders.NullableUInt8Holder; +import org.apache.drill.exec.record.RecordBatch; +import org.apache.drill.exec.expr.holders.NullableFloat8Holder; +import org.apache.drill.exec.expr.holders.NullableFloat4Holder; +import org.apache.drill.exec.expr.holders.Float8Holder; +import org.apache.drill.exec.expr.holders.Float4Holder; + +@SuppressWarnings("unused") + +public class ${aggrtype.className}Functions { + static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(${aggrtype.className}Functions.class); + +<#list aggrtype.types as type> + +<#if aggrtype.aliasName == ""> +@FunctionTemplate(name = "${aggrtype.funcName}", scope = FunctionTemplate.FunctionScope.POINT_AGGREGATE) +<#else> +@FunctionTemplate(names = {"${aggrtype.funcName}", "${aggrtype.aliasName}"}, scope = FunctionTemplate.FunctionScope.POINT_AGGREGATE) +</#if> + +public static class ${type.inputType}${aggrtype.className} implements DrillAggFunc{ + + @Param ${type.inputType}Holder xIn; + @Param ${type.inputType}Holder yIn; + + @Workspace ${type.movingAverageType}Holder xMean; + @Workspace ${type.movingAverageType}Holder yMean; + @Workspace ${type.movingAverageType}Holder xyMean; + + @Workspace ${type.movingDeviationType}Holder covar; + + @Workspace ${type.countRunningType}Holder count; + @Output ${type.outputType}Holder out; + + public void setup(RecordBatch b) { + xMean = new ${type.movingAverageType}Holder(); + yMean = new ${type.movingAverageType}Holder(); + xyMean = new ${type.movingDeviationType}Holder(); + count = new ${type.countRunningType}Holder(); + covar = new ${type.movingDeviationType}Holder(); + + // Initialize the workspace variables + xMean.value = 0; + yMean.value = 0; + xyMean.value = 0; + count.value = 1; + 
covar.value = 0; + } + + @Override + public void add() { + <#if type.inputType?starts_with("Nullable")> + sout: { + if (xIn.isSet == 0 || yIn.isSet == 0) { + // processing nullable input and the value is null, so don't do anything... + break sout; + } + </#if> + + // compute covariance + xMean.value += ((xIn.value - xMean.value) / count.value); + yMean.value += ((yIn.value - yMean.value) / count.value); + + xyMean.value += ((xIn.value * yIn.value) - xyMean.value) / count.value; + count.value++; + <#if type.inputType?starts_with("Nullable")> + } // end of sout block + </#if> + } + + @Override + public void output() { + <#if aggrtype.funcName == "covar_pop"> + out.value = (xyMean.value - (xMean.value * yMean.value)); + <#elseif aggrtype.funcName == "covar_samp"> + out.value = (xyMean.value - (xMean.value * yMean.value))*(count.value - 1)/(count.value - 2); + </#if> + } + + @Override + public void reset() { + xMean.value = 0; + yMean.value = 0; + xyMean.value = 0; + count.value = 1; + covar.value = 0; + } +} + + +</#list> +} +</#list> \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-drill/blob/e170cf0b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunction.java ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunction.java b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunction.java index ffb372d..5e57dc7 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunction.java +++ b/exec/java-exec/src/test/java/org/apache/drill/exec/fn/impl/TestAggregateFunction.java @@ -38,8 +38,9 @@ import com.google.common.base.Charsets; import com.google.common.io.Files; public class TestAggregateFunction extends PopUnitTestBase { - @Test - public void testSortDate() throws Exception { + + public void runTest(Object[] values, String planPath, String dataPath) throws 
Throwable { + try (RemoteServiceSet serviceSet = RemoteServiceSet.getLocalServiceSet(); Drillbit bit = new Drillbit(CONFIG, serviceSet); DrillClient client = new DrillClient(CONFIG, serviceSet.getCoordinator())) { @@ -48,27 +49,17 @@ public class TestAggregateFunction extends PopUnitTestBase { bit.run(); client.connect(); List<QueryResultBatch> results = client.runQuery(QueryType.PHYSICAL, - Files.toString(FileUtils.getResourceAsFile("/functions/test_stddev_variance.json"), Charsets.UTF_8) - .replace("#{TEST_FILE}", "/simple_stddev_variance_input.json")); + Files.toString(FileUtils.getResourceAsFile(planPath), Charsets.UTF_8).replace("#{TEST_FILE}", dataPath)); RecordBatchLoader batchLoader = new RecordBatchLoader(bit.getContext().getAllocator()); QueryResultBatch batch = results.get(0); assertTrue(batchLoader.load(batch.getHeader().getDef(), batch.getData())); - Double values[] = {2.0d, - 2.138089935299395d, - 2.138089935299395d, - 4.0d, - 4.571428571428571d, - 4.571428571428571d}; - int i = 0; for (VectorWrapper<?> v : batchLoader) { - - ValueVector.Accessor accessor = v.getValueVector().getAccessor(); - System.out.println(accessor.getObject(0)); - assertEquals((accessor.getObject(0)), values[i++]); + ValueVector.Accessor accessor = v.getValueVector().getAccessor(); + assertEquals(values[i++], (accessor.getObject(0))); } batchLoader.clear(); @@ -77,4 +68,22 @@ public class TestAggregateFunction extends PopUnitTestBase { } } } + + @Test + public void testSortDate() throws Throwable { + String planPath = "/functions/test_stddev_variance.json"; + String dataPath = "/simple_stddev_variance_input.json"; + Double expectedValues[] = {2.0d, 2.138089935299395d, 2.138089935299395d, 4.0d, 4.571428571428571d, 4.571428571428571d}; + + runTest(expectedValues, planPath, dataPath); + } + + @Test + public void testCovarianceCorrelation() throws Throwable { + String planPath = "/functions/test_covariance.json"; + String dataPath = "/covariance_input.json"; + Double 
expectedValues[] = {4.571428571428571d, 4.857142857142857d, -6.000000000000002d, 4.0d , 4.25d, -5.250000000000002d, 1.0d, 0.9274260335029677d, -1.0000000000000004d}; + + runTest(expectedValues, planPath, dataPath); } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-drill/blob/e170cf0b/exec/java-exec/src/test/resources/covariance_input.json ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/test/resources/covariance_input.json b/exec/java-exec/src/test/resources/covariance_input.json new file mode 100644 index 0000000..112aeb4 --- /dev/null +++ b/exec/java-exec/src/test/resources/covariance_input.json @@ -0,0 +1,8 @@ +{"A" : 2.0, "B" : 2.0, "C" : 1.0, "D" : 8.0} +{"A" : 4.0, "B" : 4.0, "C" : 2.0, "D" : 7.0} +{"A" : 4.0, "B" : 4.0, "C" : 3.0, "D" : 6.0} +{"A" : 4.0, "B" : 4.0, "C" : 4.0, "D" : 5.0} +{"A" : 5.0, "B" : 5.0, "C" : 5.0, "D" : 4.0} +{"A" : 5.0, "B" : 5.0, "C" : 6.0, "D" : 3.0} +{"A" : 7.0, "B" : 7.0, "C" : 7.0, "D" : 2.0} +{"A" : 9.0, "B" : 9.0, "C" : 8.0, "D" : 1.0} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-drill/blob/e170cf0b/exec/java-exec/src/test/resources/functions/test_covariance.json ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/test/resources/functions/test_covariance.json b/exec/java-exec/src/test/resources/functions/test_covariance.json new file mode 100644 index 0000000..3572090 --- /dev/null +++ b/exec/java-exec/src/test/resources/functions/test_covariance.json @@ -0,0 +1,85 @@ +{ + "head" : { + "version" : 1, + "generator" : { + "type" : "org.apache.drill.exec.planner.logical.DrillImplementor", + "info" : "" + }, + "type" : "APACHE_DRILL_PHYSICAL", + "resultMode" : "EXEC" + }, + graph:[ + { + @id:1, + pop:"fs-scan", + format: {type: "json"}, + storage:{type: "file", connection: "classpath:///"}, + files:["#{TEST_FILE}"] + }, + { + "pop" : "project", + "@id" : 2, + 
"exprs" : [ { + "ref" : "`A`", + "expr" : "`A`" + }, + { + "ref" : "`B`", + "expr" : "`B`" + }, + { + "ref" : "`C`", + "expr" : "`C`" + }, + { + "ref" : "`D`", + "expr" : "`D`" + } ], + "child" : 1 + }, { + "pop" : "streaming-aggregate", + "@id" : 3, + "child" : 2, + "keys" : [ ], + "exprs" : [ { + "ref" : "`EXPR$1`", + "expr" : "covar_samp(`A`, `B`) " + }, + { + "ref" : "`EXPR$2`", + "expr" : "covar_samp(`A`, `C`) " + }, + { + "ref" : "`EXPR$3`", + "expr" : "covar_samp(`C`, `D`) " + }, + { + "ref" : "`EXPR$4`", + "expr" : "covar_pop(`A`, `B`) " + }, + { + "ref" : "`EXPR$5`", + "expr" : "covar_pop(`A`, `C`) " + }, + { + "ref" : "`EXPR$6`", + "expr" : "covar_pop(`C`, `D`) " + }, + { + "ref" : "`EXPR$7`", + "expr" : "corr(`A`, `B`) " + }, + { + "ref" : "`EXPR$7`", + "expr" : "corr(`A`, `C`) " + }, + { + "ref" : "`EXPR$8`", + "expr" : "corr(`C`, `D`) " + } ] + }, { + "pop" : "screen", + "@id" : 4, + "child" : 3 + } ] +} \ No newline at end of file
