GitHub user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/225#discussion_r162371982
--- Diff: src/ports/postgres/modules/knn/knn.py_in ---
@@ -211,23 +222,43 @@ def knn(schema_madlib, point_source,
point_column_name, point_id,
) {y_temp_table}
WHERE {y_temp_table}.r <= {k_val}
""".format(**locals()))
-
- plpy.execute(
- """
- CREATE TABLE {output_table} AS
- SELECT {test_id_temp} AS id, {test_column_name}
- {pred_out}
- {knn_neighbors}
- FROM pg_temp.{interim_table} AS knn_temp
- JOIN
- {test_source} AS knn_test ON
- knn_temp.{test_id_temp} = knn_test.{test_id}
- GROUP BY {test_id_temp} , {test_column_name}
- """.format(**locals()))
-
- plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
+ if weighted_avg and is_classification:
+ plpy.execute(
+ """
+ CREATE TABLE {output_table} AS
+ SELECT id, {test_column_name} ,max(prediction) as prediction
+ {k_neighbours}
+ FROM
+ ( SELECT {test_id_temp} AS id, {test_column_name}
+ {pred_out}
+ {knn_neighbors}
+ FROM pg_temp.{interim_table} AS knn_temp
+ JOIN
+ {test_source} AS knn_test ON
+ knn_temp.{test_id_temp} = knn_test.{test_id}
+ GROUP BY {test_id_temp} ,
+ {test_column_name}, {label_col_temp})
+ a {k_neighbours_unnest}
+ GROUP BY id, {test_column_name}
+ """.format(**locals()))
+ else:
+ plpy.execute(
+ """
+ CREATE TABLE {output_table} AS
+ SELECT {test_id_temp} AS id, {test_column_name}
--- End diff --
This query is the same as the subquery in the query above. Let's rewrite this
so there is a single query, and add a wrapper around it if `weighted_avg` is
set. We also need comments describing what's happening here: I had to read it
a couple of times to understand the flow for weighted_avg regression.
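
One possible shape for that refactoring, as a minimal sketch only: the inner
query is built once, and the `max(prediction)` wrapper is added around it only
when needed. The names `build_output_sql`, `INNER_TEMPLATE`, `WRAPPER_TEMPLATE`
and the `wrap` flag are hypothetical and not part of MADlib; the placeholders
mirror those in the diff, and the sample values are made up for illustration.

```python
# Hypothetical sketch, not the MADlib implementation.

# Inner query: one row per test point (plus the label column when wrapped).
INNER_TEMPLATE = """
    SELECT {test_id_temp} AS id, {test_column_name}
           {pred_out}
           {knn_neighbors}
    FROM pg_temp.{interim_table} AS knn_temp
    JOIN {test_source} AS knn_test
      ON knn_temp.{test_id_temp} = knn_test.{test_id}
    GROUP BY {group_by}
"""

# Wrapper: collapses the per-label rows to one prediction per test point.
WRAPPER_TEMPLATE = """
    SELECT id, {test_column_name}, max(prediction) AS prediction
           {k_neighbours}
    FROM ({inner_sql}) a {k_neighbours_unnest}
    GROUP BY id, {test_column_name}
"""


def build_output_sql(params, wrap=False):
    """Return a single CREATE TABLE statement; wrap the inner query if asked."""
    group_cols = [params["test_id_temp"], params["test_column_name"]]
    if wrap:
        # The wrapped variant in the diff also groups by the label column.
        group_cols.append(params["label_col_temp"])
    select_sql = INNER_TEMPLATE.format(group_by=", ".join(group_cols), **params)
    if wrap:
        select_sql = WRAPPER_TEMPLATE.format(inner_sql=select_sql, **params)
    return "CREATE TABLE {output_table} AS {select_sql}".format(
        select_sql=select_sql, **params)


# Illustrative values only; none of these come from the pull request.
example_params = {
    "output_table": "knn_out", "test_id_temp": "test_id_tmp",
    "test_column_name": "data", "label_col_temp": "label_tmp",
    "pred_out": ", sum(w) AS prediction", "knn_neighbors": "",
    "k_neighbours": "", "k_neighbours_unnest": "",
    "interim_table": "knn_interim", "test_source": "knn_test_data",
    "test_id": "id",
}
print(build_output_sql(example_params, wrap=True))
```

In `knn.py_in` the returned string would presumably be passed to one
`plpy.execute` call, so the `if`/`else` only decides whether the wrapper is
applied rather than duplicating the whole statement.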
---