orhankislal commented on a change in pull request #505:
URL: https://github.com/apache/madlib/pull/505#discussion_r463230892
##########
File path: src/ports/postgres/modules/dbscan/dbscan.py_in
##########
@@ -26,81 +26,297 @@ from utilities.utilities import add_postfix
 from utilities.utilities import NUMERIC, ONLY_ARRAY
 from utilities.utilities import is_valid_psql_type
 from utilities.utilities import is_platform_pg
+from utilities.utilities import num_features
+from utilities.utilities import get_seg_number
 from utilities.validate_args import input_tbl_valid, output_tbl_valid
 from utilities.validate_args import is_var_valid
 from utilities.validate_args import cols_in_tbl_valid
 from utilities.validate_args import get_expr_type
 from utilities.validate_args import get_algorithm_name
 from graph.wcc import wcc
+from math import log
+from math import floor
+from math import sqrt
+
+from scipy.spatial import distance
+
+try:
+    from rtree import index
+except ImportError:
+    RTREE_ENABLED=0
+else:
+    RTREE_ENABLED=1
+
 BRUTE_FORCE = 'brute_force'
 KD_TREE = 'kd_tree'
+DEFAULT_MIN_SAMPLES = 5
+DEFAULT_KD_DEPTH = 3
+DEFAULT_METRIC = 'squared_dist_norm2'
-def dbscan(schema_madlib, source_table, output_table, id_column, expr_point, eps, min_samples, metric, algorithm, **kwargs):
+def dbscan(schema_madlib, source_table, output_table, id_column, expr_point,
+           eps, min_samples, metric, algorithm, depth, **kwargs):
     with MinWarning("warning"):
-        min_samples = 5 if not min_samples else min_samples
-        metric = 'squared_dist_norm2' if not metric else metric
-        algorithm = 'brute' if not algorithm else algorithm
+        min_samples = DEFAULT_MIN_SAMPLES if not min_samples else min_samples
+        metric = DEFAULT_METRIC if not metric else metric
+        algorithm = BRUTE_FORCE if not algorithm else algorithm
+        depth = DEFAULT_KD_DEPTH if not depth else depth
         algorithm = get_algorithm_name(algorithm, BRUTE_FORCE,
             [BRUTE_FORCE, KD_TREE], 'DBSCAN')
         _validate_dbscan(schema_madlib, source_table, output_table, id_column,
-                         expr_point, eps, min_samples, metric, algorithm)
+                         expr_point, eps, min_samples, metric, algorithm, depth)
         dist_src_sql = '' if is_platform_pg() else 'DISTRIBUTED BY (__src__)'
         dist_id_sql = '' if is_platform_pg() else 'DISTRIBUTED BY ({0})'.format(id_column)
         dist_reach_sql = '' if is_platform_pg() else 'DISTRIBUTED BY (__reachable_id__)'
+        dist_leaf_sql = '' if is_platform_pg() else 'DISTRIBUTED BY (__leaf_id__)'
-        # Calculate pairwise distances
+        core_points_table = unique_string(desp='core_points_table')
+        core_edge_table = unique_string(desp='core_edge_table')
         distance_table = unique_string(desp='distance_table')
-        plpy.execute("DROP TABLE IF EXISTS {0}".format(distance_table))
+        plpy.execute("DROP TABLE IF EXISTS {0}, {1}, {2}".format(
+            core_points_table, core_edge_table, distance_table))
+        source_view = unique_string(desp='source_view')
+        plpy.execute("DROP VIEW IF EXISTS {0}".format(source_view))
         sql = """
-            CREATE TABLE {distance_table} AS
-            SELECT __src__, __dest__ FROM (
-                SELECT __t1__.{id_column} AS __src__,
-                       __t2__.{id_column} AS __dest__,
-                       {schema_madlib}.{metric}(
-                           __t1__.{expr_point}, __t2__.{expr_point}) AS __dist__
-                FROM {source_table} AS __t1__, {source_table} AS __t2__
-                WHERE __t1__.{id_column} != __t2__.{id_column}) q1
-            WHERE __dist__ < {eps}
-            {dist_src_sql}
+            CREATE VIEW {source_view} AS
+            SELECT {id_column}, {expr_point} AS __expr_point__
+            FROM {source_table}
             """.format(**locals())
         plpy.execute(sql)
+        expr_point = '__expr_point__'
+
+        if algorithm == KD_TREE:
+            cur_source_table, border_table1, border_table2 = dbscan_kd(
+                schema_madlib, source_view, id_column, expr_point, eps,
+                min_samples, metric, depth)
+
kd_join_clause = "AND __t1__.__leaf_id__ = __t2__.__leaf_id__ " + + sql = """ + SELECT count(*), __leaf_id__ FROM {cur_source_table} GROUP BY __leaf_id__ + """.format(**locals()) + result = plpy.execute(sql) + rt_counts_dict = {} + for i in result: + rt_counts_dict[i['__leaf_id__']] = int(i['count']) + rt_counts_list = [] + for i in sorted(rt_counts_dict): + rt_counts_list.append(rt_counts_dict[i]) + + leaf_id_start = pow(2,depth)-1 + + find_core_points_table = unique_string(desp='find_core_points_table') + rt_edge_table = unique_string(desp='rt_edge_table') + rt_core_points_table = unique_string(desp='rt_core_points_table') + border_core_points_table = unique_string(desp='border_core_points_table') + border_edge_table = unique_string(desp='border_edge_table') + plpy.execute("DROP TABLE IF EXISTS {0}, {1}, {2}, {3}, {4}".format( + find_core_points_table, rt_edge_table, rt_core_points_table, + border_core_points_table, border_edge_table)) + + sql = """ + CREATE TABLE {find_core_points_table} AS + SELECT __leaf_id__, + {schema_madlib}.find_core_points( {id_column}, + {expr_point}::DOUBLE PRECISION[], + {eps}, + {min_samples}, + '{metric}', + ARRAY{rt_counts_list}, + __leaf_id__ + ) + FROM {cur_source_table} GROUP BY __leaf_id__ + {dist_leaf_sql} + """.format(**locals()) + plpy.execute(sql) + + sql = """ + CREATE TABLE {rt_edge_table} AS + SELECT (unpacked_2d).src AS __src__, (unpacked_2d).dest AS __dest__ + FROM ( + SELECT {schema_madlib}.unpack_2d(find_core_points) AS unpacked_2d + FROM {find_core_points_table} + ) q1 + WHERE (unpacked_2d).src NOT IN (SELECT {id_column} FROM {border_table1}) + {dist_src_sql} + """.format(**locals()) + plpy.execute(sql) + + sql = """ + CREATE TABLE {rt_core_points_table} AS + SELECT DISTINCT(__src__) AS {id_column} FROM {rt_edge_table} + """.format(**locals()) + plpy.execute(sql) + + # # Start border + sql = """ + CREATE TABLE {border_edge_table} AS + SELECT __src__, __dest__ FROM ( + SELECT __t1__.{id_column} AS __src__, + __t2__.{id_column} AS __dest__, + {schema_madlib}.{metric}( + __t1__.{expr_point}, __t2__.{expr_point}) AS __dist__ + FROM {border_table1} AS __t1__, {border_table2} AS __t2__)q1 + WHERE __dist__ < {eps} + """.format(**locals()) + plpy.execute(sql) + + sql = """ + CREATE TABLE {border_core_points_table} AS + SELECT * FROM ( + SELECT __src__ AS {id_column}, count(*) AS __count__ + FROM {border_edge_table} GROUP BY __src__) q1 + WHERE __count__ >= {min_samples} + {dist_id_sql} Review comment: That doesn't work since you need to calculate the `__count__` per group before using it in the having clause. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org