Github user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/225#discussion_r162369682
--- Diff: src/ports/postgres/modules/knn/knn.py_in ---
@@ -167,22 +169,31 @@ def knn(schema_madlib, point_source,
point_column_name, point_id,
knn_neighbors = ""
label_out = ""
cast_to_int = ""
+ k_neighbours = ""
+ k_neighbours_unnest = ""
if output_neighbors:
knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
"knn_temp.dist ASC) AS k_nearest_neighbours ")
+ k_neighbours = ", array_agg(distinct k_neighbours) AS
k_nearest_neighbours"
+ k_neighbours_unnest = ", unnest(k_nearest_neighbours) as
k_neighbours"
if label_column_name:
is_classification = False
label_column_type = get_expr_type(
label_column_name, point_source).lower()
if label_column_type in ['boolean', 'integer', 'text']:
is_classification = True
cast_to_int = '::INTEGER'
-
- pred_out = ", avg({label_col_temp})".format(**locals())
+ if weighted_avg:
+ pred_out = ",sum( {label_col_temp} *
1/dist)/sum(1/dist)".format(**locals())
+ else:
+ pred_out = ", avg({label_col_temp})".format(**locals())
if is_classification:
- pred_out = (", {schema_madlib}.mode({label_col_temp})"
- ).format(**locals())
+ if weighted_avg:
+ pred_out = ",sum(1/dist)".format(**locals())
--- End diff --
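Nothing to format here: the string `",sum(1/dist)"` contains no `{...}` placeholders, so the `.format(**locals())` call is unnecessary and can be dropped.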
---