GitHub user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/225#discussion_r162369645

--- Diff: src/ports/postgres/modules/knn/knn.py_in ---
@@ -167,22 +169,31 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
     knn_neighbors = ""
     label_out = ""
     cast_to_int = ""
+    k_neighbours = ""
+    k_neighbours_unnest = ""
     if output_neighbors:
         knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
                          "knn_temp.dist ASC) AS k_nearest_neighbours ")
+        k_neighbours = ", array_agg(distinct k_neighbours) AS k_nearest_neighbours"
+        k_neighbours_unnest = ", unnest(k_nearest_neighbours) as k_neighbours"
     if label_column_name:
         is_classification = False
         label_column_type = get_expr_type(
             label_column_name, point_source).lower()
         if label_column_type in ['boolean', 'integer', 'text']:
             is_classification = True
             cast_to_int = '::INTEGER'
-
-        pred_out = ", avg({label_col_temp})".format(**locals())
+        if weighted_avg:
+            pred_out = ",sum( {label_col_temp} * 1/dist)/sum(1/dist)".format(**locals())
--- End diff --

We should avoid `**locals()` when the format list is so short. This string substitutes a single name, so pass it explicitly, e.g. `.format(label_col_temp=label_col_temp)`. That makes the dependency visible at the call site.
---