Github user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/218#discussion_r157773474
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,322 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_cols
+from utilities.utilities import py_list_to_sql_string
+
+m4_changequote(`<!', `!>')
+
def balance_sample(schema_madlib, source_table, output_table, class_col,
                   class_sizes, output_table_size, grouping_cols,
                   with_replacement, **kwargs):
    """
    Balance sampling function.

    Creates output_table holding a class-balanced sample of source_table.
    Only the 'undersample' strategy (compared case-insensitively) is
    implemented: every class is reduced to the row count of the smallest
    class.

    Args:
        @param schema_madlib     MADlib schema name (not used in this body).
        @param source_table      Input table name.
        @param output_table      Output table name; created by this function.
        @param class_col         Name of the column containing the class to
                                 be balanced.
        @param class_sizes       Sampling strategy. Only 'undersample' is
                                 supported; any other value is rejected.
        @param output_table_size Only forwarded to _validate_strs here.
        @param grouping_cols     Only forwarded to _validate_strs here.
        @param with_replacement  (Default: FALSE) The sampling method. When
                                 true, the minority class is copied in full
                                 and every larger class is sampled with
                                 replacement; otherwise every class is
                                 sampled without replacement.
    """
    with MinWarning("warning"):
        class_counts = unique_string(desp='class_counts')

        _validate_strs(source_table, output_table, class_col, class_sizes,
                       output_table_size, grouping_cols)
        # Reject unsupported strategies before any database object is
        # created. Previously such values fell through every branch and the
        # function returned without ever creating output_table.
        _assert(class_sizes.lower() == 'undersample',
                "Balance sample: unsupported class_sizes value '{0}'. "
                "Only 'undersample' is currently supported."
                .format(class_sizes))

        source_cols = get_cols(source_table)
        source_table_columns = ','.join(source_cols)
        # Source columns qualified with the f1 alias used in the
        # with-replacement join below.
        f1_columns = ','.join('f1.' + col for col in source_cols)

        # Frequency table: one row per distinct class value with its count.
        plpy.execute("""
            CREATE VIEW {class_counts} AS (
                SELECT
                    {class_col} AS classes,
                    count(*) AS class_count
                FROM {source_table}
                GROUP BY {class_col})
            """.format(**locals()))

        if not with_replacement:
            # Random undersample without replacement: shuffle the rows of
            # each class and keep the first MIN(class_count) of them.
            output_sql = """
                CREATE TABLE {output_table} AS (
                    SELECT
                        {source_table_columns}
                    FROM
                        (SELECT
                            *,
                            row_number() OVER(PARTITION BY
                                {class_col} ORDER BY random())
                                AS __row_no
                        FROM {source_table}) AS foo
                    WHERE __row_no <=
                        (SELECT
                            MIN(class_count)
                        FROM {class_counts}))
                """.format(**locals())
        else:
            # Random undersample with replacement.

            # Number the rows within each class as 1..class_count so that
            # sampled row numbers can be joined back to concrete rows.
            classwise_row_numbering_sql = """
                SELECT
                    *,
                    row_number() OVER(PARTITION BY {class_col})
                        AS __row_no
                FROM
                    {source_table}
                """.format(**locals())

            # For each class larger than the minority class, draw
            # minority_class_size independent row numbers. random() is in
            # [0, 1), so floor(random() * n) + 1 is uniform over 1..n. The
            # earlier (random() * (n - 1) + 1)::int rounded to nearest,
            # which gave row 1 and row n half the probability of the rest.
            random_minorityclass_size_sample_number_gen_sql = """
                SELECT
                    classes,
                    generate_series(1, minority_class_size) AS _i,
                    (floor(random() * {class_counts}.class_count) + 1)::int
                        AS __row_no
                FROM
                    (SELECT
                        min(class_count) AS minority_class_size
                    FROM {class_counts})
                    AS foo,
                    {class_counts}
                WHERE {class_counts}.class_count != minority_class_size
                """.format(**locals())

            # Match the sampled row numbers with the numbered source rows.
            # The select list uses f1-qualified columns so that a source
            # column named classes or _i cannot clash with f2's columns.
            # NOTE(review): a NULL class value never satisfies the
            # equality below (nor the IN test in minorityclass_set), so a
            # non-minority NULL class yields all-NULL rows from the RIGHT
            # JOIN -- confirm whether NULL classes should be rejected.
            undersample_otherclass_set = """
                SELECT
                    {f1_columns}
                FROM
                    ({classwise_row_numbering_sql}) AS f1
                RIGHT JOIN
                    ({random_minorityclass_size_sample_number_gen_sql})
                    AS f2
                ON (f1.__row_no = f2.__row_no) AND
                   (f1.{class_col} = f2.classes)
                """.format(**locals())

            # Every row of the class(es) with the minimum row count.
            minorityclass_set = """
                SELECT
                    {source_table_columns}
                FROM {source_table}
                WHERE {class_col} IN
                    (SELECT
                        classes AS minority_class
                    FROM {class_counts}
                    WHERE class_count IN
                        (SELECT min(class_count) FROM {class_counts}))
                """.format(**locals())

            # Combine the minority class(es) with the undersampled classes.
            output_sql = """
                CREATE TABLE {output_table} AS (
                    SELECT {source_table_columns}
                    FROM
                        ({minorityclass_set}) AS a
                    UNION ALL
                    ({undersample_otherclass_set}))
                """.format(**locals())

        plpy.execute(output_sql)

        plpy.execute("DROP VIEW IF EXISTS {0}".format(class_counts))
+
+def _validate_strs (source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols):
+
+ _assert(source_table and source_table.strip().lower() not in ('null',
''),
--- End diff --
Do we have any common functions that check the validity of the source
table and the output table? I would expect this validation to be the
same across modules.
---