Github user iyerr3 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/225#discussion_r162369682
  
    --- Diff: src/ports/postgres/modules/knn/knn.py_in ---
    @@ -167,22 +169,31 @@ def knn(schema_madlib, point_source, point_column_name, point_id,
             knn_neighbors = ""
             label_out = ""
             cast_to_int = ""
    +        k_neighbours = ""
    +        k_neighbours_unnest = ""
     
             if output_neighbors:
                 knn_neighbors = (", array_agg(knn_temp.train_id ORDER BY "
                                  "knn_temp.dist ASC) AS k_nearest_neighbours ")
    +            k_neighbours = ", array_agg(distinct k_neighbours) AS k_nearest_neighbours"
    +            k_neighbours_unnest = ", unnest(k_nearest_neighbours) as k_neighbours"
             if label_column_name:
                 is_classification = False
                 label_column_type = get_expr_type(
                     label_column_name, point_source).lower()
                 if label_column_type in ['boolean', 'integer', 'text']:
                     is_classification = True
                     cast_to_int = '::INTEGER'
    -
    -            pred_out = ", avg({label_col_temp})".format(**locals())
    +            if weighted_avg:
    +                pred_out = ",sum( {label_col_temp} * 1/dist)/sum(1/dist)".format(**locals())
    +            else:
    +                pred_out = ", avg({label_col_temp})".format(**locals())
                 if is_classification:
    -                pred_out = (", {schema_madlib}.mode({label_col_temp})"
    -                            ).format(**locals())
    +                if weighted_avg:
    +                    pred_out = ",sum(1/dist)".format(**locals())
    --- End diff --
    
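    Nothing to format here: `",sum(1/dist)"` contains no `{}` placeholders, so the `.format(**locals())` call is a no-op and can be dropped.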


---

Reply via email to