GitHub user iyerr3 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/225#discussion_r162371982

    --- Diff: src/ports/postgres/modules/knn/knn.py_in ---
    @@ -211,23 +222,43 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
             ) {y_temp_table}
             WHERE {y_temp_table}.r <= {k_val}
             """.format(**locals()))
    -
    -    plpy.execute(
    -        """
    -        CREATE TABLE {output_table} AS
    -        SELECT {test_id_temp} AS id, {test_column_name}
    -        {pred_out}
    -        {knn_neighbors}
    -        FROM pg_temp.{interim_table} AS knn_temp
    -        JOIN
    -        {test_source} AS knn_test ON
    -        knn_temp.{test_id_temp} = knn_test.{test_id}
    -        GROUP BY {test_id_temp} , {test_column_name}
    -        """.format(**locals()))
    -
    -    plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
    +    if weighted_avg and is_classification:
    +        plpy.execute(
    +            """
    +            CREATE TABLE {output_table} AS
    +            SELECT id, {test_column_name} ,max(prediction) as prediction
    +            {k_neighbours}
    +            FROM
    +            ( SELECT {test_id_temp} AS id, {test_column_name}
    +            {pred_out}
    +            {knn_neighbors}
    +            FROM pg_temp.{interim_table} AS knn_temp
    +            JOIN
    +            {test_source} AS knn_test ON
    +            knn_temp.{test_id_temp} = knn_test.{test_id}
    +            GROUP BY {test_id_temp} ,
    +            {test_column_name}, {label_col_temp})
    +            a {k_neighbours_unnest}
    +            GROUP BY id, {test_column_name}
    +            """.format(**locals()))
    +    else:
    +        plpy.execute(
    +            """
    +            CREATE TABLE {output_table} AS
    +            SELECT {test_id_temp} AS id, {test_column_name}
    --- End diff --

This query is essentially the same as the subquery inside the `weighted_avg` query above. Let's rewrite it as a single query and add a wrapper only when `weighted_avg` is set. We also need comments describing what's happening here: I had to read this a couple of times to understand the flow for `weighted_avg` regression.
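A minimal sketch of the "single query plus wrapper" structure suggested in the comment, written in Python to match `knn.py_in`. The helper name `build_knn_output_sql` is hypothetical and this is not the PR's actual implementation; the format placeholders are copied from the diff, and fragments such as `pred_out` and `k_neighbours` are assumed to carry their own leading commas, as they appear to in the diff. The helper only builds the SQL string; in the module the result would be passed to `plpy.execute`.

```python
def build_knn_output_sql(output_table, test_id_temp, test_column_name,
                         pred_out, knn_neighbors, interim_table,
                         test_source, test_id, label_col_temp,
                         k_neighbours, k_neighbours_unnest,
                         weighted_avg, is_classification):
    """Build the CREATE TABLE statement for the k-NN output table.

    Hypothetical helper (not part of MADlib): the base SELECT is written
    once, and weighted classification only adds an outer wrapper.
    Fragments such as pred_out are assumed to start with a leading comma.
    """
    use_wrapper = weighted_avg and is_classification

    # Weighted classification needs one inner row per (test point, label),
    # so the label column joins the inner GROUP BY in that case only.
    extra_group_by = ", " + label_col_temp if use_wrapper else ""

    # Base query shared by every path: join the interim neighbor table back
    # to the test points and aggregate per test point.
    base_query = """
        SELECT {test_id_temp} AS id, {test_column_name}
               {pred_out}
               {knn_neighbors}
        FROM pg_temp.{interim_table} AS knn_temp
        JOIN {test_source} AS knn_test
          ON knn_temp.{test_id_temp} = knn_test.{test_id}
        GROUP BY {test_id_temp}, {test_column_name}{extra_group_by}
        """.format(**locals())

    if use_wrapper:
        # Outer aggregation copied from the diff: collapse the per-label rows
        # back to one row per test point and attach the neighbor columns.
        select_query = """
        SELECT id, {test_column_name}, max(prediction) AS prediction
               {k_neighbours}
        FROM ({base_query}) a {k_neighbours_unnest}
        GROUP BY id, {test_column_name}
        """.format(**locals())
    else:
        select_query = base_query

    return "CREATE TABLE {0} AS {1}".format(output_table, select_query)
```

Writing the base SELECT once keeps the weighted-classification path and the other paths from drifting apart: the only branch-specific pieces are the extra `GROUP BY` column and the outer wrapper, which is also the natural place for the explanatory comments requested above.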