orhankislal commented on a change in pull request #505:
URL: https://github.com/apache/madlib/pull/505#discussion_r463230892



##########
File path: src/ports/postgres/modules/dbscan/dbscan.py_in
##########
@@ -26,81 +26,297 @@ from utilities.utilities import add_postfix
 from utilities.utilities import NUMERIC, ONLY_ARRAY
 from utilities.utilities import is_valid_psql_type
 from utilities.utilities import is_platform_pg
+from utilities.utilities import num_features
+from utilities.utilities import get_seg_number
 from utilities.validate_args import input_tbl_valid, output_tbl_valid
 from utilities.validate_args import is_var_valid
 from utilities.validate_args import cols_in_tbl_valid
 from utilities.validate_args import get_expr_type
 from utilities.validate_args import get_algorithm_name
 from graph.wcc import wcc
 
+from math import log
+from math import floor
+from math import sqrt
+
+from scipy.spatial import distance
+
+try:
+    from rtree import index
+except ImportError:
+    RTREE_ENABLED=0
+else:
+    RTREE_ENABLED=1
+
 BRUTE_FORCE = 'brute_force'
 KD_TREE = 'kd_tree'
+DEFAULT_MIN_SAMPLES = 5
+DEFAULT_KD_DEPTH = 3
+DEFAULT_METRIC = 'squared_dist_norm2'
 
-def dbscan(schema_madlib, source_table, output_table, id_column, expr_point, 
eps, min_samples, metric, algorithm, **kwargs):
+def dbscan(schema_madlib, source_table, output_table, id_column, expr_point,
+           eps, min_samples, metric, algorithm, depth, **kwargs):
 
     with MinWarning("warning"):
 
-        min_samples = 5 if not min_samples else min_samples
-        metric = 'squared_dist_norm2' if not metric else metric
-        algorithm = 'brute' if not algorithm else algorithm
+        min_samples = DEFAULT_MIN_SAMPLES if not min_samples else min_samples
+        metric = DEFAULT_METRIC if not metric else metric
+        algorithm = BRUTE_FORCE if not algorithm else algorithm
+        depth = DEFAULT_KD_DEPTH if not depth else depth
 
         algorithm = get_algorithm_name(algorithm, BRUTE_FORCE,
             [BRUTE_FORCE, KD_TREE], 'DBSCAN')
 
         _validate_dbscan(schema_madlib, source_table, output_table, id_column,
-                         expr_point, eps, min_samples, metric, algorithm)
+                         expr_point, eps, min_samples, metric, algorithm, 
depth)
 
         dist_src_sql = ''  if is_platform_pg() else 'DISTRIBUTED BY (__src__)'
         dist_id_sql = ''  if is_platform_pg() else 'DISTRIBUTED BY 
({0})'.format(id_column)
         dist_reach_sql = ''  if is_platform_pg() else 'DISTRIBUTED BY 
(__reachable_id__)'
+        dist_leaf_sql = ''  if is_platform_pg() else 'DISTRIBUTED BY 
(__leaf_id__)'
 
-        # Calculate pairwise distances
+        core_points_table = unique_string(desp='core_points_table')
+        core_edge_table = unique_string(desp='core_edge_table')
         distance_table = unique_string(desp='distance_table')
-        plpy.execute("DROP TABLE IF EXISTS {0}".format(distance_table))
+        plpy.execute("DROP TABLE IF EXISTS {0}, {1}, {2}".format(
+            core_points_table, core_edge_table, distance_table))
 
+        source_view = unique_string(desp='source_view')
+        plpy.execute("DROP VIEW IF EXISTS {0}".format(source_view))
         sql = """
-            CREATE TABLE {distance_table} AS
-            SELECT __src__, __dest__ FROM (
-                SELECT  __t1__.{id_column} AS __src__,
-                        __t2__.{id_column} AS __dest__,
-                        {schema_madlib}.{metric}(
-                            __t1__.{expr_point}, __t2__.{expr_point}) AS 
__dist__
-                FROM {source_table} AS __t1__, {source_table} AS __t2__
-                WHERE __t1__.{id_column} != __t2__.{id_column}) q1
-            WHERE __dist__ < {eps}
-            {dist_src_sql}
+            CREATE VIEW {source_view} AS
+            SELECT {id_column}, {expr_point} AS __expr_point__
+            FROM {source_table}
             """.format(**locals())
         plpy.execute(sql)
+        expr_point = '__expr_point__'
+
+        if algorithm == KD_TREE:
+            cur_source_table, border_table1, border_table2 = dbscan_kd(
+                schema_madlib, source_view, id_column, expr_point, eps,
+                min_samples, metric, depth)
+
+            kd_join_clause = "AND __t1__.__leaf_id__ = __t2__.__leaf_id__ "
+
+            sql = """
+                SELECT count(*), __leaf_id__ FROM {cur_source_table} GROUP BY 
__leaf_id__
+                """.format(**locals())
+            result = plpy.execute(sql)
+            rt_counts_dict = {}
+            for i in result:
+                rt_counts_dict[i['__leaf_id__']] = int(i['count'])
+            rt_counts_list = []
+            for i in sorted(rt_counts_dict):
+                rt_counts_list.append(rt_counts_dict[i])
+
+            leaf_id_start = pow(2,depth)-1
+
+            find_core_points_table = 
unique_string(desp='find_core_points_table')
+            rt_edge_table = unique_string(desp='rt_edge_table')
+            rt_core_points_table = unique_string(desp='rt_core_points_table')
+            border_core_points_table = 
unique_string(desp='border_core_points_table')
+            border_edge_table = unique_string(desp='border_edge_table')
+            plpy.execute("DROP TABLE IF EXISTS {0}, {1}, {2}, {3}, {4}".format(
+                find_core_points_table, rt_edge_table, rt_core_points_table,
+                border_core_points_table, border_edge_table))
+
+            sql = """
+            CREATE TABLE {find_core_points_table} AS
+            SELECT __leaf_id__,
+                   {schema_madlib}.find_core_points( {id_column},
+                                               {expr_point}::DOUBLE 
PRECISION[],
+                                               {eps},
+                                               {min_samples},
+                                               '{metric}',
+                                               ARRAY{rt_counts_list},
+                                               __leaf_id__
+                                               )
+            FROM {cur_source_table} GROUP BY __leaf_id__
+            {dist_leaf_sql}
+            """.format(**locals())
+            plpy.execute(sql)
+
+            sql = """
+            CREATE TABLE {rt_edge_table} AS
+            SELECT (unpacked_2d).src AS __src__, (unpacked_2d).dest AS __dest__
+            FROM (
+                SELECT {schema_madlib}.unpack_2d(find_core_points) AS 
unpacked_2d
+                FROM {find_core_points_table}
+                ) q1
+            WHERE (unpacked_2d).src NOT IN (SELECT {id_column} FROM 
{border_table1})
+            {dist_src_sql}
+            """.format(**locals())
+            plpy.execute(sql)
+
+            sql = """
+                CREATE TABLE {rt_core_points_table} AS
+                SELECT DISTINCT(__src__) AS {id_column} FROM {rt_edge_table}
+                """.format(**locals())
+            plpy.execute(sql)
+
+            # # Start border
+            sql = """
+                CREATE TABLE {border_edge_table} AS
+                SELECT __src__, __dest__ FROM (
+                    SELECT  __t1__.{id_column} AS __src__,
+                            __t2__.{id_column} AS __dest__,
+                            {schema_madlib}.{metric}(
+                                __t1__.{expr_point}, __t2__.{expr_point}) AS 
__dist__
+                    FROM {border_table1} AS __t1__, {border_table2} AS 
__t2__)q1
+                WHERE __dist__ < {eps}
+                """.format(**locals())
+            plpy.execute(sql)
+
+            sql = """
+                CREATE TABLE {border_core_points_table} AS
+                SELECT * FROM (
+                    SELECT __src__ AS {id_column}, count(*) AS __count__
+                    FROM {border_edge_table} GROUP BY __src__) q1
+                WHERE __count__ >= {min_samples}
+                {dist_id_sql}

Review comment:
       That doesn't work, since you need to calculate the `__count__` per group 
before using it in the `HAVING` clause.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to