Github user njayaram2 commented on a diff in the pull request: https://github.com/apache/madlib/pull/223#discussion_r162515193 --- Diff: src/ports/postgres/modules/sample/balance_sample.py_in --- @@ -0,0 +1,994 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file EXCEPT in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import math +import plpy +import re +from collections import defaultdict +from fractions import Fraction +from utilities.control import MinWarning +from utilities.utilities import _assert +from utilities.utilities import unique_string +from utilities.validate_args import table_exists +from utilities.validate_args import columns_exist_in_table +from utilities.validate_args import table_is_empty +from utilities.validate_args import get_cols +from utilities.utilities import py_list_to_sql_string + + +m4_changequote(`<!', `!>') + +def balance_sample(schema_madlib, source_table, output_table, class_col, + class_sizes, output_table_size, grouping_cols, with_replacement, **kwargs): + + """ + Balance sampling function + Args: + @param source_table Input table name. + @param output_table Output table name. + @param class_col Name of the column containing the class to be + balanced. + @param class_size Parameter to define the size of the different + class values. 
+ @param output_table_size Desired size of the output data set. + @param grouping_cols The columns that define the grouping. + @param with_replacement The sampling method. + + """ + with MinWarning("warning"): + + class_counts = unique_string(desp='class_counts') + desired_sample_per_class = unique_string(desp='desired_sample_per_class') + desired_counts = unique_string(desp='desired_counts') + + if not class_sizes or class_sizes.strip().lower() in ('null', ''): + class_sizes = 'uniform' + + _validate_strs(source_table, output_table, class_col, class_sizes, + output_table_size, grouping_cols, with_replacement) + + source_table_columns = ','.join(get_cols(source_table)) + grp_by = "GROUP BY {0}".format(class_col) + + _create_frequency_distribution(class_counts, source_table, class_col) + temp_views = [class_counts] + + if class_sizes.lower() == 'undersample' and not with_replacement: + """ + Random undersample without replacement. + Randomly order the rows and give a unique (per class) + identifier to each one. + Select rows that have identifiers under the target limit. + """ + _undersampling_with_no_replacement(source_table, output_table, class_col, + class_sizes, output_table_size, grouping_cols, with_replacement, + class_counts, source_table_columns) + + _delete_temp_views(temp_views) + return + + """ + Create views for true and desired sample sizes of classes + """ + """ + include_unsampled_classes tracks whether unsampled classes are desired or not. + include_unsampled_classes is always true in output_table_size Null cases but changes given values of desired sample class sizes in the comma-delimited class_sizes parameter. 
+ """ + include_unsampled_classes = True + sampling_with_comma_delimited_class_sizes = class_sizes.find(':') > 0 + + if sampling_with_comma_delimited_class_sizes: + """ + Compute sample sizes based on + comman-delimited list of class_sizes + and/or output_table_size + """ + class_sizes, include_unsampled_classes = _validate_format_and_values(class_sizes, source_table, + class_col, output_table_size, class_counts, include_unsampled_classes) + + """ + Only valid condition for sampling is desired_sample_sizes <= output_table_size + """ + temp_views.extend(_create_desired_and_actual_sampling_views(class_counts, + desired_sample_per_class, desired_counts + , source_table, output_table, class_col + , class_sizes, output_table_size, include_unsampled_classes)) + + if class_sizes.lower() == 'uniform': + """ + Compute sample sizes based on + uniform distribution of class sizes + """ + temp_views.extend(_compute_uniform_class_sizes( + class_counts, desired_sample_per_class, desired_counts + , source_table, output_table, class_col, class_sizes, + output_table_size)) + + oversampling_specific_classes = False + desired_undersample_class_sizes = defaultdict(str) + + if sampling_with_comma_delimited_class_sizes or class_sizes.lower() == 'uniform': + + oversampling_specific_classes = plpy.execute(""" + SELECT * FROM {desired_sample_per_class} + WHERE category = 'oversample' + """.format(**locals())).nrows() > 0 + if oversampling_specific_classes: + with_replacement = True + + undersampling_res = plpy.execute(""" + SELECT array_agg(classes::text || ':' || sample_class_size::text) + as undersample_set FROM {desired_sample_per_class} + WHERE category = 'undersample' + """.format(**locals())) + if undersampling_res.nrows() > 0 and undersampling_res[0]['undersample_set'] is not None: + for val in undersampling_res[0]['undersample_set']: + desired_undersample_class_sizes[val.split(':')[0]] = val.split(':')[1] + + if class_sizes.lower() == 'oversample': + """ + oversampling with 
replacement + """ + with_replacement = True + func_name = 'max' + + if class_sizes.lower() == 'undersample' and with_replacement: + """ + Undersampling with replacement. + """ + func_name = 'min' + + if with_replacement: + """ + Random sample with replacement. + Undersample will have func_name set to min + Oversample will have func_name set to max. + """ + """ + Create row identifiers for each row wrt the class + """ + classwise_row_numbering_sql = """ + SELECT + *, + row_number() OVER(PARTITION BY {class_col}) + AS __row_no + FROM + {source_table} + """.format(**locals()) + if oversampling_specific_classes: + select_oversample_classes = """ WHERE {class_col}::text in + (SELECT classes + FROM {desired_sample_per_class} + WHERE category like 'oversample') + """.format(**locals()) + classwise_row_numbering_sql += select_oversample_classes + + """ + Create independent random values + for each class that has a different row count than the target + """ + if oversampling_specific_classes: + random_targetclass_size_sample_number_gen_sql = """ + SELECT + {desired_sample_per_class}.classes, + generate_series(1, sample_class_size::int) AS _i, + ((random()*({class_counts}.class_count-1)+1)::int) + AS __row_no + FROM + {class_counts}, + {desired_sample_per_class} + WHERE + {desired_sample_per_class}.classes = {class_counts}.classes + AND category like 'oversample' + """.format(**locals()) + else: + random_targetclass_size_sample_number_gen_sql = """ + SELECT + classes, + generate_series(1, target_class_size::int) AS _i, + ((random()*({class_counts}.class_count-1)+1)::int) + AS __row_no + FROM + (SELECT + {func_name}(class_count) AS target_class_size + FROM {class_counts}) + AS foo, + {class_counts} + WHERE {class_counts}.class_count != target_class_size + """.format(**locals()) + + """ + Match random values with the row identifiers + """ + sample_otherclass_set = """ + SELECT + {source_table_columns} + FROM + ({classwise_row_numbering_sql}) AS f1 + RIGHT JOIN + 
({random_targetclass_size_sample_number_gen_sql}) AS + f2 + ON (f1.__row_no = f2.__row_no) AND + (f1.{class_col}::text = f2.classes) + """.format(**locals()) + + if not oversampling_specific_classes: + """ + Find classes with target number of rows + """ + targetclass_set = """ + SELECT + {source_table_columns} + FROM {source_table} + WHERE {class_col}::text IN + (SELECT + classes AS target_class + FROM {class_counts} + WHERE class_count in + (SELECT {func_name}(class_count) FROM {class_counts})) + """.format(**locals()) + + """ + Combine target and other sampled classes + """ + output_sql = """ + CREATE TABLE {output_table} AS ( + SELECT {source_table_columns} + FROM + ({targetclass_set}) AS a + UNION ALL + ({sample_otherclass_set})) + """.format(**locals()) + plpy.execute(output_sql) + + _delete_temp_views(temp_views) + return + + """ + Unsampled classes + """ + nosample_classset_sql = """ + SELECT + {source_table_columns} + FROM {source_table} + WHERE {class_col}::text IN + (SELECT + classes + FROM {desired_sample_per_class} + WHERE category like 'nosample') + """.format(**locals()) + """ + Union all Undersampled classes + """ + undersampling_classset_sql = '' + if len(desired_undersample_class_sizes) > 0: + undersampling_classset_sql = ' UNION ALL'.join(""" + (SELECT {source_table_columns} + FROM {source_table} + WHERE {class_col} = '{clas}' + ORDER BY random() + LIMIT {limit_bound}) + """.format(source_table_columns=source_table_columns, + source_table=source_table, + class_col=class_col, + limit_bound=clas_limit, + clas=clas) for clas, clas_limit in desired_undersample_class_sizes.iteritems()) + undersampling_classset_sql = ' UNION ALL ' + undersampling_classset_sql + + """ + Union all Oversampled classes + """ + oversampling_specific_classes_classset_sql = '' + if oversampling_specific_classes: + oversampling_specific_classes_classset_sql = """ + UNION ALL + ({sample_otherclass_set}) + """.format(**locals()) + + if (oversampling_specific_classes or 
len(desired_undersample_class_sizes) > 0): + """ + Combine all sampled and/or unsampled classes + """ + if not include_unsampled_classes: + nosample_classset_sql.replace('nosample', '') + + output_sql = """ + CREATE TABLE {output_table} AS ( + SELECT {source_table_columns} + FROM + ({nosample_classset_sql}) AS a + {oversampling_specific_classes_classset_sql} + {undersampling_classset_sql}) + """.format(**locals()) + + plpy.execute(output_sql) + + _delete_temp_views(temp_views) + return + +""" + Delete all temp views +""" +def _delete_temp_views(temp_views): + for temp_view in temp_views: + plpy.execute("DROP VIEW IF EXISTS {0} cascade".format(temp_view)) + return + +""" + Random undersample without replacement. +""" +def _undersampling_with_no_replacement(source_table, output_table, class_col, + class_sizes, output_table_size, grouping_cols, with_replacement, + class_counts, source_table_columns): + + distinct_class_labels = plpy.execute(""" + SELECT array_agg(DISTINCT {class_col}::text) AS labels + FROM {source_table} + """.format(**locals()))[0]['labels'] + + limit_bound = plpy.execute(""" + SELECT MIN(class_count)::int AS min + FROM {class_counts}""".format(**locals()))[0]['min'] + + minority_class = plpy.execute(""" + SELECT array_agg(classes::text) as minority_class + FROM {class_counts} + WHERE class_count = {limit_bound} + """.format(**locals()))[0]['minority_class'] + + distinct_class_labels = [cl for cl in distinct_class_labels + if cl not in minority_class] + + foo_table = unique_string(desp='foo') + start_output_qry = """ --- End diff -- Changed the query to have a single query for all.
---