GitHub user njayaram2 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/225#discussion_r163652456
--- Diff: src/ports/postgres/modules/knn/knn.py_in ---
@@ -178,11 +183,38 @@ def knn(schema_madlib, point_source,
point_column_name, point_id,
if label_column_type in ['boolean', 'integer', 'text']:
is_classification = True
cast_to_int = '::INTEGER'
+ if weighted_avg:
+ pred_out = ",sum( {label_col_temp} *
1/dist)/sum(1/dist)".format(
+ label_col_temp=label_col_temp)
+ else:
+ pred_out = ", avg({label_col_temp})".format(
+ label_col_temp=label_col_temp)
- pred_out = ", avg({label_col_temp})".format(**locals())
if is_classification:
- pred_out = (", {schema_madlib}.mode({label_col_temp})"
- ).format(**locals())
+ if weighted_avg:
+ # This view is to calculate the max value of sum of
the 1/distance grouped by label and Id.
+ # And this max value will be the prediction for the
+ # classification model.
+ view_def = (" WITH vw "
+ " AS (SELECT {test_id_temp} ,"
+ " max(data_sum) data_dist "
+ " FROM (SELECT {test_id_temp}, "
+ " sum(1 / dist) data_sum"
+ " FROM pg_temp.{interim_table} "
+ " GROUP BY {test_id_temp}, "
+ " {label_col_temp}) a "
+ " GROUP BY {test_id_temp}
)").format(**locals())
+ # This join is needed to get the max value of predicion
+ # calculated above
+ view_join = (" JOIN vw AS knn_vw "
+ "ON knn_temp.{test_id_temp} =
knn_vw.{test_id_temp}").format(
+ test_id_temp=test_id_temp)
+ view_grp_by = ", knn_vw.data_dist "
+ pred_out = ", knn_vw.data_dist"
+ else:
+ pred_out = (", {schema_madlib}.mode({label_col_temp})"
+ ).format(**locals())
+
--- End diff --
This query string could be written as a single triple-quoted `"""..."""`
literal instead of many concatenated `"..."` pieces.
---