Folks,
I'd like to add weighted statistics to PostgreSQL. While the included
weighted_avg() is trivial to calculate using existing machinery, the
included weighted_stddev_*() functions are not.
I've only done the float8 versions, but if we decide to move forward,
I'd be delighted to add the rest of the numeric types and maybe others
as make sense.
What say?
Cheers,
David.
--
David Fetter <[email protected]> http://fetter.org/
Phone: +1 415 235 3778 AIM: dfetter666 Yahoo!: dfetter
Skype: davidfetter XMPP: [email protected]
Remember to vote!
Consider donating to Postgres: http://www.postgresql.org/about/donate
diff --git a/doc/src/sgml/func.sgml b/doc/src/sgml/func.sgml
index 4d482ec..2174594 100644
--- a/doc/src/sgml/func.sgml
+++ b/doc/src/sgml/func.sgml
@@ -12443,6 +12443,29 @@ NULL baz</literallayout>(3 rows)</entry>
<row>
<entry>
<indexterm>
+ <primary>weighted_average</primary>
+ </indexterm>
+ <indexterm>
+ <primary>weighted_avg</primary>
+ </indexterm>
+ <function>weighted_avg(<replaceable class="parameter">value
expression</replaceable>, <replaceable class="parameter">weight
expression</replaceable>)</function>
+ </entry>
+ <entry>
+ <type>smallint</type>, <type>int</type>,
+ <type>bigint</type>, <type>real</type>, <type>double
+ precision</type>, <type>numeric</type>, or <type>interval</type>
+ </entry>
+ <entry>
+ <type>numeric</type> for any integer-type argument,
+ <type>double precision</type> for a floating-point argument,
+ otherwise the same as the argument data type
+ </entry>
+ <entry>the average (arithmetic mean) of all input values, weighted by
the input weights</entry>
+ </row>
+
+ <row>
+ <entry>
+ <indexterm>
<primary>bit_and</primary>
</indexterm>
<function>bit_and(<replaceable
class="parameter">expression</replaceable>)</function>
@@ -13086,6 +13109,29 @@ SELECT xmlagg(x) FROM (SELECT x FROM test ORDER BY y
DESC) AS tab;
<row>
<entry>
<indexterm>
+ <primary>weighted standard deviation</primary>
+ <secondary>population</secondary>
+ </indexterm>
+ <indexterm>
+ <primary>weighted_stddev_pop</primary>
+ </indexterm>
+ <function>weighted_stddev_pop(<replaceable class="parameter">value
expression</replaceable>, <replaceable class="parameter">weight
expression</replaceable>)</function>
+ </entry>
+ <entry>
+ <type>smallint</type>, <type>int</type>,
+ <type>bigint</type>, <type>real</type>, <type>double
+ precision</type>, or <type>numeric</type>
+ </entry>
+ <entry>
+ <type>double precision</type> for floating-point arguments,
+ otherwise <type>numeric</type>
+ </entry>
+ <entry>weighted population standard deviation of the input values</entry>
+ </row>
+
+ <row>
+ <entry>
+ <indexterm>
<primary>standard deviation</primary>
<secondary>sample</secondary>
</indexterm>
@@ -13109,6 +13155,29 @@ SELECT xmlagg(x) FROM (SELECT x FROM test ORDER BY y
DESC) AS tab;
<row>
<entry>
<indexterm>
+ <primary>weighted standard deviation</primary>
+ <secondary>sample</secondary>
+ </indexterm>
+ <indexterm>
+ <primary>weighted_stddev_samp</primary>
+ </indexterm>
+ <function>weighted_stddev_samp(<replaceable class="parameter">value
expression</replaceable>, <replaceable class="parameter">weight
expression</replaceable>)</function>
+ </entry>
+ <entry>
+ <type>smallint</type>, <type>int</type>,
+ <type>bigint</type>, <type>real</type>, <type>double
+ precision</type>, or <type>numeric</type>
+ </entry>
+ <entry>
+ <type>double precision</type> for floating-point arguments,
+ otherwise <type>numeric</type>
+ </entry>
+ <entry>weighted sample standard deviation of the input values</entry>
+ </row>
+
+ <row>
+ <entry>
+ <indexterm>
<primary>variance</primary>
</indexterm>
<function>variance</function>(<replaceable
class="parameter">expression</replaceable>)
diff --git a/src/backend/utils/adt/float.c b/src/backend/utils/adt/float.c
index 4e927d8..533ce0a 100644
--- a/src/backend/utils/adt/float.c
+++ b/src/backend/utils/adt/float.c
@@ -1774,6 +1774,7 @@ setseed(PG_FUNCTION_ARGS)
* float8_accum - accumulate for AVG(), variance
aggregates, etc.
* float4_accum - same, but input data is float4
* float8_avg - produce final result for
float AVG()
+ * float8_weighted_avg - produce final result for float
WEIGHTED_AVG()
* float8_var_samp - produce final result for float
VAR_SAMP()
* float8_var_pop - produce final result for float
VAR_POP()
* float8_stddev_samp - produce final result for float
STDDEV_SAMP()
@@ -1929,6 +1930,28 @@ float8_avg(PG_FUNCTION_ARGS)
}
Datum
+float8_weighted_avg(PG_FUNCTION_ARGS)
+{
+ ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
+ float8 *transvalues;
+ float8 N,
+ sumWX,
+ sumW;
+
+ transvalues = check_float8_array(transarray, "float8_weighted_avg", 6);
+ N = transvalues[0];
+ sumW = transvalues[1];
+ sumWX = transvalues[5];
+
+ if (N < 1.0)
+ PG_RETURN_NULL();
+
+ CHECKFLOATVAL(N, isinf(1.0/sumW) || isinf(sumWX), true);
+
+ PG_RETURN_FLOAT8(sumWX/sumW);
+}
+
+Datum
float8_var_pop(PG_FUNCTION_ARGS)
{
ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
@@ -2467,6 +2490,119 @@ float8_regr_intercept(PG_FUNCTION_ARGS)
PG_RETURN_FLOAT8(numeratorXXY / numeratorX);
}
+/*
+ * ===================
+ * WEIGHTED AGGREGATES
+ * ===================
+ *
+ * The transition datatype for these aggregates is a 4-element array
+ * of float8, holding the values N, sum(W), sum(W*X), and sum(W*X*X)
+ * in that order.
+ *
+ * First, an accumulator function for those we can't pirate from the
+ * other accumulators. This accumulator function takes out some of
+ * the rounding error inherent in the general one.
+ * https://en.wikipedia.org/wiki/Standard_deviation#Rapid_calculation_methods
+ *
+ * It consists of a four-element array which includes:
+ *
+ * N, the number of non-zero-weighted values seen thus far,
+ * W, the running sum of weights,
+ * A, an intermediate value used in the calculation, and
+ * Q, another intermediate value.
+ *
+ */
+Datum
+float8_weighted_accum(PG_FUNCTION_ARGS)
+{
+ ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
+ float8 newvalX = PG_GETARG_FLOAT8(1);
+ float8 newvalW = PG_GETARG_FLOAT8(2);
+ float8 *transvalues;
+ float8 N,
+ W,
+ A,
+ Q;
+
+ transvalues = check_float8_array(transarray,
"float8_weighted_stddev_accum", 4);
+
+ if (newvalW <= 0.0) /* We only care about positive weights */
+ PG_RETURN_NULL();
+
+ N = transvalues[0];
+ W = transvalues[1];
+ A = transvalues[2];
+ Q = transvalues[3];
+
+ N += 1.0;
+ CHECKFLOATVAL(N, isinf(transvalues[0]), true);
+ W += newvalW;
+ CHECKFLOATVAL(W, isinf(transvalues[1]) || isinf(newvalW), true);
+ A += newvalW * ( newvalX - transvalues[2] ) / W;
+ CHECKFLOATVAL(A, isinf(newvalW) || isinf(transvalues[2]) ||
isinf(1.0/W), true);
+ Q += newvalW * (newvalX - transvalues[2]) * (newvalX - A);
+ CHECKFLOATVAL(A, isinf(newvalX - transvalues[3]) || isinf(newvalX - A)
|| isinf(1.0/W), true);
+
+ if (AggCheckCallContext(fcinfo, NULL)) /* Update in place is safe in
Agg context */
+ {
+ transvalues[0] = N;
+ transvalues[1] = W;
+ transvalues[2] = A;
+ transvalues[3] = Q;
+
+ PG_RETURN_ARRAYTYPE_P(transarray);
+ }
+ else /* You do not need to call this directly. */
+ ereport(ERROR,
+ (errmsg("float8_weighted_accum called outside
agg context")));
+}
+
+Datum
+float8_weighted_stddev_samp(PG_FUNCTION_ARGS)
+{
+ ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
+ float8 *transvalues;
+ float8 N,
+ W,
+ /* Skip A. Not used in the calculation */
+ Q;
+
+ transvalues = check_float8_array(transarray,
"float8_weighted_stddev_samp", 4);
+ N = transvalues[0];
+ W = transvalues[1];
+ Q = transvalues[3];
+
+ if (N < 2.0) /* Must have at least two samples to get a stddev */
+ PG_RETURN_NULL();
+
+ PG_RETURN_FLOAT8(
+ sqrt(
+ N * Q /
+ ( (N-1) * W )
+ )
+ );
+}
+
+Datum
+float8_weighted_stddev_pop(PG_FUNCTION_ARGS)
+{
+ ArrayType *transarray = PG_GETARG_ARRAYTYPE_P(0);
+ float8 *transvalues;
+ float8 N,
+ W,
+ /* Skip A. Not used in the calculation */
+ Q;
+
+ transvalues = check_float8_array(transarray,
"float8_weighted_stddev_pop", 4);
+ N = transvalues[0];
+ W = transvalues[1];
+ Q = transvalues[3];
+
+ if (N < 2.0) /* Must have at least two samples to get a stddev */
+ PG_RETURN_NULL();
+
+ PG_RETURN_FLOAT8( sqrt( Q / W ) );
+}
/*
* ====================================
diff --git a/src/include/catalog/pg_aggregate.h
b/src/include/catalog/pg_aggregate.h
index dd6079f..6d2f9d4 100644
--- a/src/include/catalog/pg_aggregate.h
+++ b/src/include/catalog/pg_aggregate.h
@@ -133,6 +133,7 @@ DATA(insert ( 2103 n 0 numeric_avg_accum numeric_avg
numeric_avg_accum numeric_a
DATA(insert ( 2104 n 0 float4_accum float8_avg -
- -
f f 0 1022 0 0
0 "{0,0,0}" _null_ ));
DATA(insert ( 2105 n 0 float8_accum float8_avg -
- -
f f 0 1022 0 0
0 "{0,0,0}" _null_ ));
DATA(insert ( 2106 n 0 interval_accum interval_avg interval_accum
interval_accum_inv interval_avg f f 0 1187
0 1187 0 "{0 second,0 second}" "{0 second,0 second}" ));
+DATA(insert ( 3998 n 0 float8_regr_accum float8_weighted_avg
- - -
f f 0 1022 0 0 0 "{0,0,0,0,0,0}"
_null_ ));
/* sum */
DATA(insert ( 2107 n 0 int8_avg_accum numeric_poly_sum
int8_avg_accum int8_avg_accum_inv numeric_poly_sum f f 0 2281 48
2281 48 _null_ _null_ ));
@@ -225,6 +226,7 @@ DATA(insert ( 2726 n 0 int2_accum numeric_poly_stddev_pop
int2_accum int2_accum_
DATA(insert ( 2727 n 0 float4_accum float8_stddev_pop -
- -
f f 0 1022 0 0 0
"{0,0,0}" _null_ ));
DATA(insert ( 2728 n 0 float8_accum float8_stddev_pop -
- -
f f 0 1022 0 0 0
"{0,0,0}" _null_ ));
DATA(insert ( 2729 n 0 numeric_accum numeric_stddev_pop
numeric_accum numeric_accum_inv numeric_stddev_pop f f 0
2281 128 2281 128 _null_ _null_ ));
+DATA(insert ( 4066 n 0 float8_weighted_accum
float8_weighted_stddev_pop - -
- f f 0
1022 0 0 0 "{0,0,0,0}" _null_ ));
/* stddev_samp */
DATA(insert ( 2712 n 0 int8_accum numeric_stddev_samp
int8_accum int8_accum_inv numeric_stddev_samp
f f 0 2281 128 2281 128 _null_ _null_ ));
@@ -232,6 +234,7 @@ DATA(insert ( 2713 n 0 int4_accum
numeric_poly_stddev_samp int4_accum int4_accum
DATA(insert ( 2714 n 0 int2_accum numeric_poly_stddev_samp
int2_accum int2_accum_inv numeric_poly_stddev_samp f f 0 2281
48 2281 48 _null_ _null_ ));
DATA(insert ( 2715 n 0 float4_accum float8_stddev_samp -
- -
f f 0 1022 0 0 0
"{0,0,0}" _null_ ));
DATA(insert ( 2716 n 0 float8_accum float8_stddev_samp -
- -
f f 0 1022 0 0 0
"{0,0,0}" _null_ ));
+DATA(insert ( 4083 n 0 float8_weighted_accum
float8_weighted_stddev_samp - -
- f f 0
1022 0 0 0 "{0,0,0,0}" _null_ ));
DATA(insert ( 2717 n 0 numeric_accum numeric_stddev_samp
numeric_accum numeric_accum_inv numeric_stddev_samp f f 0
2281 128 2281 128 _null_ _null_ ));
/* stddev: historical Postgres syntax for stddev_samp */
diff --git a/src/include/catalog/pg_proc.h b/src/include/catalog/pg_proc.h
index f688454..83c4b64 100644
--- a/src/include/catalog/pg_proc.h
+++ b/src/include/catalog/pg_proc.h
@@ -2502,6 +2502,12 @@ DESCR("join selectivity of case-insensitive regex
non-match");
/* Aggregate-related functions */
DATA(insert OID = 1830 ( float8_avg PGNSP PGUID 12 1 0 0 0 f f f f t f i
s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_avg _null_ _null_
_null_ ));
DESCR("aggregate final function");
+DATA(insert OID = 3997 ( float8_weighted_avg PGNSP PGUID 12
1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_
float8_weighted_avg _null_ _null_ _null_ ));
+DESCR("aggregate final function");
+DATA(insert OID = 4099 ( float8_weighted_stddev_pop PGNSP
PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_
_null_ float8_weighted_stddev_pop _null_ _null_ _null_ ));
+DESCR("aggregate final function");
+DATA(insert OID = 4100 ( float8_weighted_stddev_samp PGNSP
PGUID 12 1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_
_null_ float8_weighted_stddev_samp _null_ _null_ _null_ ));
+DESCR("aggregate final function");
DATA(insert OID = 2512 ( float8_var_pop PGNSP PGUID 12 1 0 0 0 f f f f t f
i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_var_pop _null_
_null_ _null_ ));
DESCR("aggregate final function");
DATA(insert OID = 1831 ( float8_var_samp PGNSP PGUID 12 1 0 0 0 f f f f t f
i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_ float8_var_samp _null_
_null_ _null_ ));
@@ -2585,6 +2591,8 @@ DATA(insert OID = 2805 ( int8inc_float8_float8
PGNSP PGUID 12 1 0 0 0 f f f f
DESCR("aggregate transition function");
DATA(insert OID = 2806 ( float8_regr_accum PGNSP PGUID 12
1 0 0 0 f f f f t f i s 3 0 1022 "1022 701 701" _null_ _null_ _null_ _null_
_null_ float8_regr_accum _null_ _null_ _null_ ));
DESCR("aggregate transition function");
+DATA(insert OID = 3999 ( float8_weighted_accum PGNSP
PGUID 12 1 0 0 0 f f f f t f i s 3 0 1022 "1022 701 701" _null_ _null_ _null_
_null_ _null_ float8_weighted_accum _null_ _null_ _null_ ));
+DESCR("aggregate transition function");
DATA(insert OID = 2807 ( float8_regr_sxx PGNSP PGUID 12
1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_
float8_regr_sxx _null_ _null_ _null_ ));
DESCR("aggregate final function");
DATA(insert OID = 2808 ( float8_regr_syy PGNSP PGUID 12
1 0 0 0 f f f f t f i s 1 0 701 "1022" _null_ _null_ _null_ _null_ _null_
float8_regr_syy _null_ _null_ _null_ ));
@@ -3229,6 +3237,8 @@ DATA(insert OID = 2104 ( avg
PGNSP PGUID 12 1 0 0 0 t f f f f f i s 1 0 701
DESCR("the average (arithmetic mean) as float8 of all float4 values");
DATA(insert OID = 2105 ( avg PGNSP PGUID 12 1 0 0 0
t f f f f f i s 1 0 701 "701" _null_ _null_ _null_ _null_ _null_
aggregate_dummy _null_ _null_ _null_ ));
DESCR("the average (arithmetic mean) as float8 of all float8 values");
+DATA(insert OID = 3998 ( weighted_avg PGNSP PGUID 12 1 0 0 0 t f f f
f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_
aggregate_dummy _null_ _null_ _null_ ));
+DESCR("the weighted average (arithmetic mean) as float8 of all float8 values");
DATA(insert OID = 2106 ( avg PGNSP PGUID 12 1 0 0 0
t f f f f f i s 1 0 1186 "1186" _null_ _null_ _null_ _null_ _null_
aggregate_dummy _null_ _null_ _null_ ));
DESCR("the average (arithmetic mean) as interval of all interval values");
@@ -3389,6 +3399,8 @@ DATA(insert OID = 2728 ( stddev_pop PGNSP
PGUID 12 1 0 0 0 t f f f f f i s 1 0
DESCR("population standard deviation of float8 input values");
DATA(insert OID = 2729 ( stddev_pop PGNSP PGUID 12 1 0 0 0 t f f f
f f i s 1 0 1700 "1700" _null_ _null_ _null_ _null_ _null_ aggregate_dummy
_null_ _null_ _null_ ));
DESCR("population standard deviation of numeric input values");
+DATA(insert OID = 4066 ( weighted_stddev_pop PGNSP PGUID 12 1 0 0 0
t f f f f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_
aggregate_dummy _null_ _null_ _null_ ));
+DESCR("population weighted standard deviation of float8 input values");
DATA(insert OID = 2712 ( stddev_samp PGNSP PGUID 12 1 0 0 0 t f f f
f f i s 1 0 1700 "20" _null_ _null_ _null_ _null_ _null_ aggregate_dummy _null_
_null_ _null_ ));
DESCR("sample standard deviation of bigint input values");
@@ -3402,6 +3414,8 @@ DATA(insert OID = 2716 ( stddev_samp PGNSP
PGUID 12 1 0 0 0 t f f f f f i s 1
DESCR("sample standard deviation of float8 input values");
DATA(insert OID = 2717 ( stddev_samp PGNSP PGUID 12 1 0 0 0 t f f f
f f i s 1 0 1700 "1700" _null_ _null_ _null_ _null_ _null_ aggregate_dummy
_null_ _null_ _null_ ));
DESCR("sample standard deviation of numeric input values");
+DATA(insert OID = 4083 ( weighted_stddev_samp PGNSP PGUID 12 1 0 0 0
t f f f f f i s 2 0 701 "701 701" _null_ _null_ _null_ _null_ _null_
aggregate_dummy _null_ _null_ _null_ ));
+DESCR("sample weighted standard deviation of float8 input values");
DATA(insert OID = 2154 ( stddev PGNSP PGUID 12 1 0 0 0
t f f f f f i s 1 0 1700 "20" _null_ _null_ _null_ _null_ _null_
aggregate_dummy _null_ _null_ _null_ ));
DESCR("historical alias for stddev_samp");
diff --git a/src/include/utils/builtins.h b/src/include/utils/builtins.h
index fc1679e..333d538 100644
--- a/src/include/utils/builtins.h
+++ b/src/include/utils/builtins.h
@@ -413,8 +413,12 @@ extern Datum radians(PG_FUNCTION_ARGS);
extern Datum drandom(PG_FUNCTION_ARGS);
extern Datum setseed(PG_FUNCTION_ARGS);
extern Datum float8_accum(PG_FUNCTION_ARGS);
+extern Datum float8_weighted_accum(PG_FUNCTION_ARGS);
extern Datum float4_accum(PG_FUNCTION_ARGS);
extern Datum float8_avg(PG_FUNCTION_ARGS);
+extern Datum float8_weighted_avg(PG_FUNCTION_ARGS);
+extern Datum float8_weighted_stddev_pop(PG_FUNCTION_ARGS);
+extern Datum float8_weighted_stddev_samp(PG_FUNCTION_ARGS);
extern Datum float8_var_pop(PG_FUNCTION_ARGS);
extern Datum float8_var_samp(PG_FUNCTION_ARGS);
extern Datum float8_stddev_pop(PG_FUNCTION_ARGS);
diff --git a/src/test/regress/expected/aggregates.out
b/src/test/regress/expected/aggregates.out
index de826b5..a19fd1d 100644
--- a/src/test/regress/expected/aggregates.out
+++ b/src/test/regress/expected/aggregates.out
@@ -247,6 +247,18 @@ SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest;
653.62895538751 | 871.505273850014
(1 row)
+SELECT weighted_avg(a, b) FROM aggtest;
+ weighted_avg
+------------------
+ 55.5553072763149
+(1 row)
+
+SELECT weighted_stddev_pop(a, b), weighted_stddev_samp(a, b) FROM aggtest;
+ weighted_stddev_pop | weighted_stddev_samp
+---------------------+----------------------
+ 24.3364627240769 | 28.1013266097382
+(1 row)
+
SELECT corr(b, a) FROM aggtest;
corr
-------------------
diff --git a/src/test/regress/sql/aggregates.sql
b/src/test/regress/sql/aggregates.sql
index 8d501dc..77b6102 100644
--- a/src/test/regress/sql/aggregates.sql
+++ b/src/test/regress/sql/aggregates.sql
@@ -60,6 +60,8 @@ SELECT regr_avgx(b, a), regr_avgy(b, a) FROM aggtest;
SELECT regr_r2(b, a) FROM aggtest;
SELECT regr_slope(b, a), regr_intercept(b, a) FROM aggtest;
SELECT covar_pop(b, a), covar_samp(b, a) FROM aggtest;
+SELECT weighted_avg(a, b) FROM aggtest;
+SELECT weighted_stddev_pop(a, b), weighted_stddev_samp(a, b) FROM aggtest;
SELECT corr(b, a) FROM aggtest;
SELECT count(four) AS cnt_1000 FROM onek;
--
Sent via pgsql-hackers mailing list ([email protected])
To make changes to your subscription:
http://www.postgresql.org/mailpref/pgsql-hackers