iyerr3 commented on a change in pull request #352: Feature/kd tree knn
URL: https://github.com/apache/madlib/pull/352#discussion_r256159891
 
 

 ##########
 File path: src/ports/postgres/modules/knn/knn.py_in
 ##########
 @@ -124,16 +137,377 @@ def knn_validate_src(schema_madlib, point_source, 
point_column_name, point_id,
             """.format(fn_dist=fn_dist, profunc=profunc))[0]['output']
 
         if is_invalid_func or (fn_dist not in dist_functions):
-            plpy.error("KNN error: Distance function has invalid signature "
-                       "or is not a simple function.")
-
+            plpy.error("KNN error: Distance function ({0}) has invalid 
signature "
+                       "or is not a simple function.".format(fn_dist))
+    if depth <= 0:
+        plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
+                   "than 0.".format(depth))
+    if leaf_nodes <= 0:
+        plpy.error("kNN Error: leaf_nodes={0} is an invalid value, must be 
greater "
+                   "than 0.".format(leaf_nodes))
+    if pow(2,depth) <= leaf_nodes:
+        plpy.error("kNN Error: depth={0}, leaf_nodes={1} is not valid. "
+                   "The leaf_nodes value must be lower than 
2^depth".format(depth, leaf_nodes))
     return k
 # 
------------------------------------------------------------------------------
 
 
def build_kd_tree(schema_madlib, source_table, output_table, point_column_name,
                  depth, r_id, dim, **kwargs):
    """
        KD-tree function to create a partitioning for KNN
        Args:
            @param schema_madlib        Name of the Madlib Schema
            @param source_table         Training data table
            @param output_table         Name of the table to store kd tree
            @param point_column_name    Name of the column with training data
                                        or expression that evaluates to a
                                        numeric array
            @param depth                Depth of the kd tree
            @param r_id                 Name of the region id column
            @param dim                  Name of the dimension column
        Returns:
            The SQL "WHEN <predicate> THEN <leaf id>" clause (string) that maps
            a point to the id of the leaf region containing it.
        Side effects (tables created):
            <output_table>          source rows plus the region-id column r_id
            <output_table>_tree     all split cutoffs as one DOUBLE PRECISION[]
            <output_table>_centers  per-region centroid of the training points
    """
    with MinWarning("error"):

        validate_kd_tree(source_table, output_table, point_column_name, depth)
        n_features = num_features(source_table, point_column_name)

        # Each entry of 'clauses' is a SQL predicate describing one tree
        # node's region. clauses[0] covers the whole space ('1=1'); for every
        # node visited we append its two children, so after the loop the last
        # 2^depth entries are the leaf regions in left-to-right order.
        clauses = [' 1=1 ']
        cutoffs = []  # median split values collected in level order
        centers_table = add_postfix(output_table, "_centers")
        clause_counter = 0
        current_feature = 1  # 1-based array index of the splitting dimension
        for curr_level in range(depth):
            for curr_leaf in range(pow(2,curr_level)):
                clause = clauses[clause_counter]
                # Median of the current dimension, restricted to the rows
                # that fall inside this node's region.
                cutoff_sql = """
                    SELECT percentile_disc(0.5)
                           WITHIN GROUP (
                            ORDER BY ({point_column_name})[{current_feature}]
                           ) AS cutoff
                    FROM {source_table}
                    WHERE {clause}
                    """.format(**locals())

                cutoff = plpy.execute(cutoff_sql)[0]['cutoff']
                # Empty region: substitute the literal "NULL" so the child
                # predicates (x < NULL / x >= NULL) match no rows.
                cutoff = cutoff if cutoff is not None else "NULL"
                clause_counter += 1

                cutoffs.append(cutoff)
                # Left child: strictly below the median; right child: at or
                # above it. Both inherit the parent's predicate.
                clauses.append(clause +
                               "AND ({point_column_name})[{current_feature}]"
                               " < {cutoff} ".format(**locals()))
                clauses.append(clause +
                               "AND ({point_column_name})[{current_feature}]"
                               " >= {cutoff} ".format(**locals()))
            # Cycle through the feature dimensions round-robin, one per level.
            current_feature = current_feature % n_features + 1

        # Persist the level-order cutoffs as a single array literal; this is
        # the serialized form of the tree read back by knn_kd_tree().
        output_table_tree = add_postfix(output_table, "_tree")
        plpy.execute("CREATE TABLE {0} AS "
                     "SELECT ('{{ {1} }}')::DOUBLE PRECISION[] AS tree".
                     format(output_table_tree,
                            " ,".join(map(str, cutoffs))))

        n_leaves = pow(2,depth)
        # Leaf regions are the last n_leaves clauses; number them 0..n_leaves-1.
        case_when_clause = ' '.join(["WHEN {0} THEN {1}::INTEGER".format(cond, i)
                                     for i, cond in enumerate(clauses[-n_leaves:])])
        # Tag every training row with the id of the leaf region it falls in.
        output_sql = """
            CREATE TABLE {output_table} AS
                SELECT *, CASE {case_when_clause} END AS {r_id}
                FROM {source_table}""".format(**locals())
        plpy.execute(output_sql)

        plpy.execute("DROP TABLE IF EXISTS {0}".format(centers_table))
        # Per-region centroid: element-wise sum of the points scaled by
        # 1/count, i.e. the mean vector of each leaf region.
        centers_sql = """
            CREATE TABLE {centers_table} AS
                SELECT {r_id}, {schema_madlib}.array_scalar_mult(
                        {schema_madlib}.sum({point_column_name}):: DOUBLE PRECISION[],
                        (1.0/count(*))::DOUBLE PRECISION) AS __center__
                FROM {output_table}
                GROUP BY {r_id}
            """.format(**locals())
        plpy.execute(centers_sql)
        return case_when_clause
+# 
------------------------------------------------------------------------------
+
def validate_kd_tree(source_table, output_table, point_column_name, depth):
    """Validate the arguments of build_kd_tree; raises via plpy on failure.

    Checks that the source table exists, that neither the output table nor
    its companion "<output_table>_tree" table already exists, that the point
    column/expression is valid and of numeric-array type, and that depth is
    a positive integer.
    """
    input_tbl_valid(source_table, 'kd_tree')
    output_tbl_valid(output_table, 'kd_tree')
    tree_table = output_table+"_tree"
    output_tbl_valid(tree_table, 'kd_tree')

    valid_point_col = is_var_valid(source_table, point_column_name)
    _assert(valid_point_col,
            "kd_tree error: {0} is an invalid column name or expression for "
            "point_column_name param".format(point_column_name))
    _assert(is_valid_psql_type(get_expr_type(point_column_name, source_table),
                               NUMERIC | ONLY_ARRAY),
            "kNN Error: Feature column or expression '{0}' in train table is not"
            " an array.".format(point_column_name))
    if not depth > 0:
        plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
                   "than 0.".format(depth))
+# 
------------------------------------------------------------------------------
+
+def knn_kd_tree(schema_madlib, kd_out, point_source, point_column_name, 
point_id,
+             label_column_name, test_source, test_column_name, test_id,
+             interim_table, in_k, output_neighbors, fn_dist, weighted_avg,
+             leaf_nodes, r_id, dim, label_out, comma_label_out_alias,
+             label_name, train, train_id, dist_inverse, test_id_temp,
+             case_when_clause, **kwargs):
+    """
+        KNN function to find the K Nearest neighbours using kd tree
+        Args:
+            @param schema_madlib        Name of the Madlib Schema
+            @param kd_out               Name of the kd tree table
+            @param point_source         Training data table
+            @param point_column_name    Name of the column with training data
+                                        or expression that evaluates to a
+                                        numeric array
+            @param point_id             Name of the column having ids of data
+                                        point in train data table
+                                        points.
+            @param label_column_name    Name of the column with labels/values
+                                        of training data points.
+            @param test_source          Name of the table containing the test
+                                        data points.
+            @param test_column_name     Name of the column with testing data
+                                        points or expression that evaluates to 
a
+                                        numeric array
+            @param test_id              Name of the column having ids of data
+                                        points in test data table.
+            @param interim_table        Name of the table to store interim
+                                        results.
+            @param in_k                 default: 1. Number of nearest
+                                        neighbors to consider
+            @param output_neighbours    Outputs the list of k-nearest neighbors
+                                        that were used in the voting/averaging.
+            @param fn_dist              Distance metrics function. Default is
+                                        squared_dist_norm2. Following functions
+                                        are supported :
+                                        dist_norm1 , 
dist_norm2,squared_dist_norm2,
+                                        dist_angle , dist_tanimoto
+                                        Or user defined function with signature
+                                        DOUBLE PRECISION[] x, DOUBLE 
PRECISION[] y -> DOUBLE PRECISION
+            @param weighted_avg         Calculates the Regression or 
classication of k-NN using
+                                        the weighted average method.
+            @param leaf_nodes           Number of leaf nodes to explore
+            @param r_id                 Name of the region id column
+            @param dim                  Name of the dimension column
+            Following parameters are passed to ensure the interim table has
+            identical features to non-kd-tree implementation
+            @param label_out
+            @param comma_label_out_alias
+            @param label_name
+            @param train
+            @param train_id
+            @param dist_inverse
+            @param test_id_temp
+            @param case_when_clause
+    """
+    with MinWarning("error"):
+
+        tree_model = add_postfix(kd_out, "_tree")
+        centers_table = add_postfix(kd_out, "_centers")
+        n_features = num_features(test_source, test_column_name)
+
+        tree = plpy.execute("SELECT * FROM {0}".format(tree_model))[0]['tree']
+        # 'tree' contains only non-leaf nodes,
+        # hence 'n_leaves' is always 1 more than len(tree)
+        n_leaves = len(tree)+1
+
+        depth = int(log(n_leaves, 2))
+
+        # The borders table will have two rows for each dimension (Upper & 
lower)
+        # even if a dimension does not have a branch.
+        # Its borders will be -Inf, Inf
+
+        # The first leaf_note is itself,
+        # we expand to n-1 nodes out of 2 * n_features borders
+        quant = float(leaf_nodes) / n_leaves
+
+        test_view = unique_string("test_view")
+        t_col_name = unique_string("t_col_name")
+        plpy.execute("DROP VIEW IF EXISTS {test_view}".format(**locals()))
+
+        test_view_sql = """
+            CREATE VIEW {test_view} AS
+                SELECT {test_id},
+                       {test_column_name}::DOUBLE PRECISION[] AS {t_col_name},
+                       CASE
+                       {case_when_clause}
+                       END AS {r_id}
+                FROM {test_source}""".format(**locals())
+        plpy.execute(test_view_sql)
+
+        if leaf_nodes > 1:
+            ext_test_view = unique_string("ext_test_view")
+            ext_test_view_sql = """
+                CREATE VIEW {ext_test_view} AS
+                    SELECT {test_id},
+                           {t_col_name},
+                           {centers_table}.{r_id}
+                    FROM {test_view} INNER JOIN
 
 Review comment:
   I'm confused about the goal of the subquery below. My understanding is that we
need `k` leaf nodes sorted in ascending order of distance to the centers —
wouldn't `row_number` with `order by` be sufficient for that (similar to how `k`
points are picked below)?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to