reductionista commented on a change in pull request #496: URL: https://github.com/apache/madlib/pull/496#discussion_r421143072
########## File path: src/ports/postgres/modules/dbscan/dbscan.py_in ########## @@ -0,0 +1,331 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import plpy + +from utilities.control import MinWarning +from utilities.utilities import _assert +from utilities.utilities import unique_string +from utilities.utilities import add_postfix +from utilities.utilities import NUMERIC, ONLY_ARRAY +from utilities.utilities import is_valid_psql_type +from utilities.utilities import is_platform_pg +from utilities.validate_args import input_tbl_valid, output_tbl_valid +from utilities.validate_args import is_var_valid +from utilities.validate_args import cols_in_tbl_valid +from utilities.validate_args import get_expr_type +from utilities.validate_args import get_algorithm_name +from graph.wcc import wcc + +BRUTE_FORCE = 'brute_force' +KD_TREE = 'kd_tree' + +def dbscan(schema_madlib, source_table, output_table, id_column, expr_point, eps, min_samples, metric, algorithm, **kwargs): + + with MinWarning("warning"): + + min_samples = 5 if not min_samples else min_samples + metric = 'squared_dist_norm2' if not metric else metric + algorithm = 'brute' if not algorithm else algorithm + + algorithm = get_algorithm_name(algorithm, 
BRUTE_FORCE, + [BRUTE_FORCE, KD_TREE], 'DBSCAN') + + _validate_dbscan(schema_madlib, source_table, output_table, id_column, + expr_point, eps, min_samples, metric, algorithm) + + dist_src_sql = '' if is_platform_pg() else 'DISTRIBUTED BY (__src__)' + dist_id_sql = '' if is_platform_pg() else 'DISTRIBUTED BY ({0})'.format(id_column) + dist_reach_sql = '' if is_platform_pg() else 'DISTRIBUTED BY (__reachable_id__)' + + # Calculate pairwise distances + distance_table = unique_string(desp='distance_table') + plpy.execute("DROP TABLE IF EXISTS {0}".format(distance_table)) + + sql = """ + CREATE TABLE {distance_table} AS + SELECT __src__, __dest__ FROM ( + SELECT __t1__.{id_column} AS __src__, + __t2__.{id_column} AS __dest__, + {schema_madlib}.{metric}( + __t1__.{expr_point}, __t2__.{expr_point}) AS __dist__ + FROM {source_table} AS __t1__, {source_table} AS __t2__) q1 + WHERE __dist__ < {eps} + {dist_src_sql} + """.format(**locals()) + plpy.execute(sql) + + # Find core points + core_points_table = unique_string(desp='core_points_table') + plpy.execute("DROP TABLE IF EXISTS {0}".format(core_points_table)) + sql = """ + CREATE TABLE {core_points_table} AS + SELECT * FROM (SELECT __src__ AS {id_column}, count(*) AS __count__ + FROM {distance_table} GROUP BY __src__) q1 + WHERE __count__ >= {min_samples} + {dist_id_sql} + """.format(**locals()) + plpy.execute(sql) + + # Find the connections between core points to form the clusters + core_edge_table = unique_string(desp='core_edge_table') + plpy.execute("DROP TABLE IF EXISTS {0}".format(core_edge_table)) + sql = """ + CREATE TABLE {core_edge_table} AS + SELECT __src__, __dest__ + FROM {distance_table} AS __t1__, (SELECT array_agg({id_column}) AS arr + FROM {core_points_table}) __t2__ + WHERE __t1__.__src__ = ANY(arr) AND __t1__.__dest__ = ANY(arr) + {dist_src_sql} Review comment: Somehow a comment I left here got dropped when I submitted. 
I was just wondering, maybe we should add `AND __t1__.__src__ != __t1__.__dest__` to the WHERE clause? I don't know that it hurts anything to have an extra 0-length edge for each core point, but it may add some minor extra work for wcc to do with them in there. Or, maybe the filter should go earlier, where the distance is computed, instead? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
