GitHub user iyerr3 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/225#discussion_r162371982
  
    --- Diff: src/ports/postgres/modules/knn/knn.py_in ---
    @@ -211,23 +222,43 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
                         ) {y_temp_table}
                 WHERE {y_temp_table}.r <= {k_val}
                 """.format(**locals()))
    -
    -        plpy.execute(
    -            """
    -            CREATE TABLE {output_table} AS
    -                SELECT {test_id_temp} AS id, {test_column_name}
    -                    {pred_out}
    -                    {knn_neighbors}
    -                FROM pg_temp.{interim_table} AS knn_temp
    -                    JOIN
    -                    {test_source} AS knn_test ON
    -                    knn_temp.{test_id_temp} = knn_test.{test_id}
    -                GROUP BY {test_id_temp} , {test_column_name}
    -            """.format(**locals()))
    -
    -        plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
    +        if weighted_avg and is_classification:
    +            plpy.execute(
    +                """
    +                CREATE TABLE {output_table} AS
    +                    SELECT id, {test_column_name} ,max(prediction) as prediction
    +                        {k_neighbours}
    +                    FROM
    +                        ( SELECT {test_id_temp} AS id, {test_column_name}
    +                                {pred_out}
    +                                {knn_neighbors}
    +                            FROM pg_temp.{interim_table} AS knn_temp
    +                                JOIN
    +                                {test_source} AS knn_test ON
    +                                knn_temp.{test_id_temp} = knn_test.{test_id}
    +                            GROUP BY {test_id_temp} ,
    +                                {test_column_name}, {label_col_temp})
    +                            a {k_neighbours_unnest}
    +                    GROUP BY id, {test_column_name}
    +                """.format(**locals()))
    +        else:
    +            plpy.execute(
    +                """
    +                CREATE TABLE {output_table} AS
    +                    SELECT {test_id_temp} AS id, {test_column_name}
    --- End diff --
    
    This query is the same as the subquery in the query above. Let's rewrite
    this so the query is built only once, and add a wrapper around it when
    `weighted_avg` is set. We also need comments describing what's happening
    here; I had to read this a couple of times to understand the flow for
    `weighted_avg` regression.
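
A rough sketch of the shape suggested above, assuming the shared SELECT can be assembled once as a string and wrapped only in the `weighted_avg` case. The helper name `build_knn_output_sql` and its parameters are hypothetical and not part of MADlib; it only builds SQL text, so the result would still have to go through `plpy.execute`.

```python
def build_knn_output_sql(output_table, inner_select, group_by_cols,
                         test_column_name, wrap_for_weighted_avg=False,
                         k_neighbours="", k_neighbours_unnest=""):
    """Hypothetical helper (not MADlib API): assemble the CREATE TABLE text.

    inner_select  -- the shared SELECT ... FROM ... JOIN ... without GROUP BY
    group_by_cols -- columns to group the shared query by
    """
    # Shared query: written once and used by both code paths.
    inner = "{0} GROUP BY {1}".format(inner_select, ", ".join(group_by_cols))

    if not wrap_for_weighted_avg:
        # Plain case: the shared query is the whole output table definition.
        return "CREATE TABLE {0} AS {1}".format(output_table, inner)

    # weighted_avg case: wrap the shared query and keep the largest
    # prediction for each test point.
    return ("CREATE TABLE {out} AS "
            "SELECT id, {col}, max(prediction) AS prediction {kn} "
            "FROM ({inner}) a {unnest} "
            "GROUP BY id, {col}").format(
                out=output_table, col=test_column_name, kn=k_neighbours,
                inner=inner, unnest=k_neighbours_unnest)
```

With this shape the two branches differ only in the grouping columns and the optional wrapper, so the comments the review asks for can sit next to the single place where each piece is built.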


---
