Github user orhankislal commented on a diff in the pull request: https://github.com/apache/madlib/pull/223#discussion_r161864354 --- Diff: src/ports/postgres/modules/sample/balance_sample.py_in --- @@ -0,0 +1,994 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file EXCEPT in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import math +import plpy +import re +from collections import defaultdict +from fractions import Fraction +from utilities.control import MinWarning +from utilities.utilities import _assert +from utilities.utilities import unique_string +from utilities.validate_args import table_exists +from utilities.validate_args import columns_exist_in_table +from utilities.validate_args import table_is_empty +from utilities.validate_args import get_cols +from utilities.utilities import py_list_to_sql_string + + +m4_changequote(`<!', `!>') + +def balance_sample(schema_madlib, source_table, output_table, class_col, + class_sizes, output_table_size, grouping_cols, with_replacement, **kwargs): + + """ + Balance sampling function + Args: + @param source_table Input table name. + @param output_table Output table name. + @param class_col Name of the column containing the class to be + balanced. + @param class_size Parameter to define the size of the different + class values. + @param output_table_size Desired size of the output data set. + @param grouping_cols The columns columns that defines the grouping. + @param with_replacement The sampling method. + + """ + with MinWarning("warning"): + + class_counts = unique_string(desp='class_counts') + desired_sample_per_class = unique_string(desp='desired_sample_per_class') + desired_counts = unique_string(desp='desired_counts') + + if not class_sizes or class_sizes.strip().lower() in ('null', ''): + class_sizes = 'uniform' + + _validate_strs(source_table, output_table, class_col, class_sizes, + output_table_size, grouping_cols, with_replacement) + + source_table_columns = ','.join(get_cols(source_table)) + grp_by = "GROUP BY {0}".format(class_col) + + _create_frequency_distribution(class_counts, source_table, class_col) + temp_views = [class_counts] + + if class_sizes.lower() == 'undersample' and not with_replacement: + """ + Random undersample without replacement. + Randomly order the rows and give a unique (per class) + identifier to each one. + Select rows that have identifiers under the target limit. + """ + _undersampling_with_no_replacement(source_table, output_table, class_col, + class_sizes, output_table_size, grouping_cols, with_replacement, + class_counts, source_table_columns) + + _delete_temp_views(temp_views) + return + + """ + Create views for true and desired sample sizes of classes + """ + """ + include_unsampled_classes tracks is unsampled classes are desired or not. + include_unsampled_classes is always true in output_table_size Null cases but changes given values of desired sample class sizes in comma-delimited classsize paramter. + """ + include_unsampled_classes = True + sampling_with_comma_delimited_class_sizes = class_sizes.find(':') > 0 + + if sampling_with_comma_delimited_class_sizes: + """ + Compute sample sizes based on + comman-delimited list of class_sizes + and/or output_table_size + """ + class_sizes, include_unsampled_classes = _validate_format_and_values(class_sizes, source_table, + class_col, output_table_size, class_counts, include_unsampled_classes) + + """ + Only valid condition for sampling is desired_sample_sizes <= output_table_size + """ + temp_views.extend(_create_desired_and_actual_sampling_views(class_counts, + desired_sample_per_class, desired_counts + , source_table, output_table, class_col + , class_sizes, output_table_size, include_unsampled_classes)) + + if class_sizes.lower() == 'uniform': + """ + Compute sample sizes based on + uniform distribution of class sizes + """ + temp_views.extend(_compute_uniform_class_sizes( + class_counts, desired_sample_per_class, desired_counts + , source_table, output_table, class_col, class_sizes, + output_table_size)) + + oversampling_specific_classes = False + desired_undersample_class_sizes = defaultdict(str) + + if sampling_with_comma_delimited_class_sizes or class_sizes.lower() == 'uniform': + + oversampling_specific_classes = plpy.execute(""" + SELECT * FROM {desired_sample_per_class} + WHERE category = 'oversample' + """.format(**locals())).nrows() > 0 + if oversampling_specific_classes: + with_replacement = True + + undersampling_res = plpy.execute(""" + SELECT array_agg(classes::text || ':' || sample_class_size::text) + as undersample_set FROM {desired_sample_per_class} + WHERE category = 'undersample' + """.format(**locals())) + if undersampling_res.nrows() > 0 and undersampling_res[0]['undersample_set'] is not None: + for val in undersampling_res[0]['undersample_set']: + desired_undersample_class_sizes[val.split(':')[0]] = val.split(':')[1] + + if class_sizes.lower() == 'oversample': + """ + oversampling with replacement + """ + with_replacement = True + func_name = 'max' + + if class_sizes.lower() == 'undersample' and with_replacement: + """ + Undersampling with replacement. + """ + func_name = 'min' + + if with_replacement: + """ + Random sample with replacement. + Undersample will have func_name set to min + Oversample will have func_name set to max. + """ + """ + Create row identifiers for each row wrt the class + """ + classwise_row_numbering_sql = """ + SELECT + *, + row_number() OVER(PARTITION BY {class_col}) + AS __row_no + FROM + {source_table} + """.format(**locals()) + if oversampling_specific_classes: + select_oversample_classes = """ WHERE {class_col}::text in + (SELECT classes + FROM {desired_sample_per_class} + WHERE category like 'oversample') + """.format(**locals()) + classwise_row_numbering_sql += select_oversample_classes + + """ + Create independent random values + for each class that has a different row count than the target + """ + if oversampling_specific_classes: + random_targetclass_size_sample_number_gen_sql = """ + SELECT + {desired_sample_per_class}.classes, + generate_series(1, sample_class_size::int) AS _i, + ((random()*({class_counts}.class_count-1)+1)::int) + AS __row_no + FROM + {class_counts}, + {desired_sample_per_class} + WHERE + {desired_sample_per_class}.classes = {class_counts}.classes + AND category like 'oversample' + """.format(**locals()) + else: + random_targetclass_size_sample_number_gen_sql = """ + SELECT + classes, + generate_series(1, target_class_size::int) AS _i, + ((random()*({class_counts}.class_count-1)+1)::int) + AS __row_no + FROM + (SELECT + {func_name}(class_count) AS target_class_size + FROM {class_counts}) + AS foo, + {class_counts} + WHERE {class_counts}.class_count != target_class_size + """.format(**locals()) + + """ + Match random values with the row identifiers + """ + sample_otherclass_set = """ + SELECT + {source_table_columns} + FROM + ({classwise_row_numbering_sql}) AS f1 + RIGHT JOIN + ({random_targetclass_size_sample_number_gen_sql}) AS + f2 + ON (f1.__row_no = f2.__row_no) AND + (f1.{class_col}::text = f2.classes) + """.format(**locals()) + + if not oversampling_specific_classes: + """ + Find classes with target number of rows + """ + targetclass_set = """ + SELECT + {source_table_columns} + FROM {source_table} + WHERE {class_col}::text IN + (SELECT + classes AS target_class + FROM {class_counts} + WHERE class_count in + (SELECT {func_name}(class_count) FROM {class_counts})) + """.format(**locals()) + + """ + Combine target and other sampled classes + """ + output_sql = """ + CREATE TABLE {output_table} AS ( + SELECT {source_table_columns} + FROM + ({targetclass_set}) AS a + UNION ALL + ({sample_otherclass_set})) + """.format(**locals()) + plpy.execute(output_sql) + + _delete_temp_views(temp_views) + return + + """ + Unsampled classes + """ + nosample_classset_sql = """ + SELECT + {source_table_columns} + FROM {source_table} + WHERE {class_col}::text IN + (SELECT + classes + FROM {desired_sample_per_class} + WHERE category like 'nosample') + """.format(**locals()) + """ + Union all Undersampled classes + """ + undersampling_classset_sql = '' + if len(desired_undersample_class_sizes) > 0: + undersampling_classset_sql = ' UNION ALL'.join(""" + (SELECT {source_table_columns} + FROM {source_table} + WHERE {class_col} = '{clas}' + ORDER BY random() + LIMIT {limit_bound}) + """.format(source_table_columns=source_table_columns, + source_table=source_table, + class_col=class_col, + limit_bound=clas_limit, + clas=clas) for clas, clas_limit in desired_undersample_class_sizes.iteritems()) + undersampling_classset_sql = ' UNION ALL ' + undersampling_classset_sql + + """ + Union all Oversampled classes + """ + oversampling_specific_classes_classset_sql = '' + if oversampling_specific_classes: + oversampling_specific_classes_classset_sql = """ + UNION ALL + ({sample_otherclass_set}) + """.format(**locals()) + + if (oversampling_specific_classes or len(desired_undersample_class_sizes) > 0): + """ + Combine all sampled and/or unsampled classes + """ + if not include_unsampled_classes: + nosample_classset_sql.replace('nosample', '') + + output_sql = """ + CREATE TABLE {output_table} AS ( + SELECT {source_table_columns} + FROM + ({nosample_classset_sql}) AS a + {oversampling_specific_classes_classset_sql} + {undersampling_classset_sql}) + """.format(**locals()) + + plpy.execute(output_sql) + + _delete_temp_views(temp_views) + return + +""" + Delete all temp views +""" +def _delete_temp_views(temp_views): + for temp_view in temp_views: + plpy.execute("DROP VIEW IF EXISTS {0} cascade".format(temp_view)) + return + +""" + Random undersample without replacement. +""" +def _undersampling_with_no_replacement(source_table, output_table, class_col, + class_sizes, output_table_size, grouping_cols, with_replacement, + class_counts, source_table_columns): + + distinct_class_labels = plpy.execute(""" + SELECT array_agg(DISTINCT {class_col}::text) AS labels + FROM {source_table} + """.format(**locals()))[0]['labels'] + + limit_bound = plpy.execute(""" + SELECT MIN(class_count)::int AS min + FROM {class_counts}""".format(**locals()))[0]['min'] + + minority_class = plpy.execute(""" + SELECT array_agg(classes::text) as minority_class + FROM {class_counts} + WHERE class_count = {limit_bound} + """.format(**locals()))[0]['minority_class'] + + distinct_class_labels = [cl for cl in distinct_class_labels + if cl not in minority_class] + + foo_table = unique_string(desp='foo') + start_output_qry = """ + SELECT {source_table_columns} + FROM ( + SELECT {source_table_columns} + FROM {source_table} + WHERE {class_col}::text = '{dcl}' + ORDER BY random() + LIMIT {limit_bound} + ) AS {foo_table} + UNION """.format(dcl=distinct_class_labels[0], **locals()) + + union_qry = ' UNION '.join(""" + (SELECT {source_table_columns} + FROM {source_table} + WHERE {class_col}::text = '{val}' + ORDER BY random() + LIMIT {limit_bound}) + """.format(source_table_columns=source_table_columns, + source_table=source_table, + class_col=class_col, + limit_bound=limit_bound, + val=val) for val in distinct_class_labels[1:]) + + min_class_tuple = "('" + "','".join([str(a) for a in minority_class]) + "')" + + minority_sql = """ UNION + SELECT {source_table_columns} + FROM {source_table} + WHERE {class_col} IN {min_class_tuple} """.format(**locals()) + + output_sql = """ + CREATE TABLE {output_table} AS ( + {start_output_qry} + {union_qry} + {minority_sql} )""".format(**locals()) + plpy.execute(output_sql) + +""" + Captures cases where classes are specified multiple time in comma-delimited string. + e.g. class_sizes = '5:6,5:4' +""" +class UniqueDict(dict): + def __setitem__(self, classkey, class_size): + if classkey not in self: + """ + float(class_size).is_integer() ensures only whole numbers are added as + class's sample size + """ + if (class_size > 0.0 or class_size < 1.0) or (class_size >= 1.0 and float(class_size).is_integer()): + dict.__setitem__(self, classkey, class_size) + else: + plpy.error("Sample: Sample size should be a fraction between (0.0,1.0) or a whole number greater than 1") + else: + plpy.error("Sample: Repeated classes in class_sizes") + +""" + Check if the class sizes (passed as strings) are alphanumeric, float or whole numbers. + Error out on alphanumeric values. + e.g. class_sizes = '3:a,5:b + Returns type of the size as int or float. +""" +def _check_value_type(value): + try: + float(value) + except ValueError: + plpy.error("Sample: Specify either fractions (0.0,1.0) or whole numbers for class sample size.") + + if re.match("^\d+?\.\d+?$", value) is not None and not float(value).is_integer(): + valueType = float + elif (re.match(r"[-+]?\d+$", value) is not None): + valueType = int + return valueType +""" + Check if classes are present in class_col +""" +def _validate_classes(all_classes, source_table, class_col): + + nonexisting_classes = plpy.execute(""" + SELECT + unnest(ARRAY[{all_classes}]) + EXCEPT + SELECT + distinct({class_col}::text) + FROM {source_table} + """.format(**locals())) + + if nonexisting_classes.nrows() > 0: + plpy.error("""Sample: Specified classes do not exist in + {class_col}""".format(**locals())) + +""" + Checks the format and values of classes and their respective sizes specified in comman-delimited string. + 1. Checks if the classes specified are present in the source table. + 2. Checks for total sample size to be between (0.0,1.0] + 3. Checks for output_table_size < total desired size of the classes + specified in class_sizes + 4. Checks for cases when only fractions or whole numbers specified in sample sizes + 5. Checks for value is a whole number greater then 1 + 6. Checks if same class is specified multiple times in class_sizes + 7. Checks is classes are present in class_col +""" +def _validate_format_and_values(class_sizes, source_table, class_col, + output_table_size, class_counts, include_unsampled_classes): + + class_sizes_arr = class_sizes.split(',') + + cs_dict = UniqueDict(defaultdict()) + + numeric_value_sum = 0 + fraction_value_sum = 0.0 + + for x in class_sizes_arr: + class_and_size = x.split(':') + valueType = _check_value_type(class_and_size[1]) + # Following error type is invalidated by Frank as of 29th Dec. + # if _check_value_type(val[1].strip(), valueType) != valueType: + #plpy.error("Sample: Specify either fractions or whole number values for ALL classes.") + fraction_value_sum += valueType(class_and_size[1]) if valueType == float else 0.0 + numeric_value_sum += valueType(class_and_size[1]) if valueType != float else 0 + cs_dict[class_and_size[0].strip()] = valueType(class_and_size[1]) + + """ + Check to see if specified classes are present in the class_col + """ + all_classes = str(cs_dict.keys())[1:-1] + _validate_classes(all_classes, source_table, class_col) + + """ + Error out if fraction_value_sum is greater than 1.0 or when fraction sum is 1.0 and other classes with whole numbers as class sizes are also specified. + """ + if fraction_value_sum > 1.0 or (fraction_value_sum == 1.0 and numeric_value_sum != 0): + plpy.error("""Sample: Fraction sum < 1.0, when any other class is also specified as class_name:class_size-in-whole-numbers. Fraction sum can be at most 1.0. + """.format(**locals())) + + total_table_size = plpy.execute(""" + SELECT + count(*) AS total + FROM {source_table} + """.format(**locals()))[0]['total'] + + """ + Compute class sizes when no Fractions are specified in class_sizes + """ + if Fraction(fraction_value_sum) == Fraction(0.0): + if (not output_table_size): + # Sample remaining classes uniformly + return class_sizes, include_unsampled_classes + + if output_table_size < numeric_value_sum: + plpy.error("""Sample: Output table size ({output_table_size}) must be more than total specified sample size i.e. {numeric_value_sum}""".format(**locals())) + + if output_table_size == numeric_value_sum: + # Do not sample other classes + return class_sizes, not include_unsampled_classes + + # Sample remaining classes uniformly with target table size + return class_sizes, include_unsampled_classes + + """ + Compute class sizes when only Fractions are specified in class_sizes, which also sum to 1.0 + """ + if Fraction(fraction_value_sum) == Fraction(1.0): + # Do not sample other classes + return _compute_class_sizes(cs_dict, total_table_size)[0], not include_unsampled_classes + + """ + Compute sample classs size when both fractions and whole numbers are mentioned in class_size comma-delimited string + """ + if Fraction(fraction_value_sum) > Fraction(0.0): + + sum_remaining_class_samples = plpy.execute(""" + SELECT sum({class_counts}.class_count) AS remaining_classes FROM {class_counts} + WHERE classes not IN ({all_classes}) + """.format(**locals())) + """ + When output_table_size is Null. Use following example logic to compute desired sample sizes. + + Suppose male=.4,output_table_size= NULL and letâs say there are 2 other categorical values female=10M, other=1M + Use the following logic to calculate 'computed' output_table_size x + .4x + 10M + 1M = x + where x = computed output_table_size. Here x = 18.3M + """ + if (not output_table_size): + y = 1.0 - fraction_value_sum + + if sum_remaining_class_samples.nrows() > 0: + numeric_value_sum += sum_remaining_class_samples[0]['remaining_classes'] + + if numeric_value_sum == 0: + # A rare case happens WHEN there is ONLY one class present in class_col + return _compute_class_sizes(cs_dict, total_table_size)[0], not include_unsampled_classes + + # Compute total_desired_sample_size as x + x = math.ceil(float(numeric_value_sum) / y) + class_size, _ = _compute_class_sizes(cs_dict, x) + return class_size, include_unsampled_classes + + """ + When output_table_size is given, + compute the total_desired_sample_size and perform checks to ensure validity of class_size with total_desired_sample_size + """ + output_table_size = float(output_table_size) + class_size, total_desired_sample_size = _compute_class_sizes(cs_dict, output_table_size) + + ## Cases when total desired sample size > specified output table size + if total_desired_sample_size > output_table_size: + plpy.error("""Sample: Output table size ({output_table_size}) must be more than total desired sample size i.e {total_desired_sample_size}""".format(**locals())) + + if total_desired_sample_size == output_table_size: + # Do not sample other classes + return class_size, not include_unsampled_classes + + if total_desired_sample_size < output_table_size: + # Sample other classes uniformly + return class_size, include_unsampled_classes + +""" + Proportions of class sizes are multiplied by output_table_size to get desired whole number value for class sizes. + A total_desired_size is sum of all whole number class sizes +""" +def _compute_class_sizes(cs_dict, x): + class_size = '' + total_desired_size = 0 + for clas, class_val in cs_dict.iteritems(): + if float(class_val).is_integer(): + total_desired_size += class_val + class_size += str(clas) + ':' + str(class_val) + ',' + else: + class_val = int(round(class_val * x, 0)) + total_desired_size += class_val + class_size += str(clas) + ':' + str(class_val) + ',' + + return class_size[:-1], total_desired_size + +""" + Create view to store class counts of classes in class_col +""" +def _create_frequency_distribution(class_counts, source_table, class_col): + + grp_by = "GROUP BY {0}".format(class_col) + plpy.execute(""" CREATE VIEW {class_counts} AS ( + SELECT + {class_col}::text AS classes, + count(*) AS class_count + FROM {source_table} + {grp_by}) + """.format(**locals())) + +""" + Create view desired_sample_per_class which contains the desired sampling for each class. + """ +def _create_desired_and_actual_sampling_views(class_counts, desired_sample_per_class, desired_counts + , source_table, output_table, class_col, class_sizes, output_table_size, include_unsampled_classes): + + # Split class_sizes + if len(class_sizes.split(',')) > 1: + values = '\',\''.join(class_sizes.split(',')) + else: + values = class_sizes.split(',')[0] + + create_desired_counts_sql = """ + CREATE VIEW {desired_counts} AS ( + SELECT val[1]::text AS desired_classes, + val[2]::int AS desired_values + FROM + (SELECT regexp_split_to_array(unnest(ARRAY['{values}']),E':') AS val) + AS foo + ) + """.format(**locals()) + temp_views = [desired_counts] + plpy.execute(create_desired_counts_sql) + + if not include_unsampled_classes: + """ + No extra/other classes desired, other than the ones + specified in the class_sizes parameter + """ + desired_sample_per_class_sql = """ + CREATE VIEW {desired_sample_per_class} AS ( + SELECT + desired_classes AS classes, + desired_values::int AS sample_class_size, + 1::int as actual_class_size, + CASE + WHEN (desired_values - class_count) = 0 THEN 'nosample' + WHEN (desired_values - class_count) > 0 THEN 'oversample' ELSE 'undersample' + END AS category + FROM + {class_counts} + JOIN + {desired_counts} + ON ({class_counts}.classes = {desired_counts}.desired_classes) + ) + """.format(**locals()) + else: + """ + If output_table_size is specified, + uniformly sample remaining classes. + """ + remaining_class_samples = '' + if output_table_size: + + remaining_class_samples = _sample_remaining_classes_uniformly_per_output_table_size( + class_counts, desired_counts, output_table_size) + + desired_sample_per_class_sql = """ + CREATE VIEW {desired_sample_per_class} AS ( + SELECT + classes, + desired_values::int AS sample_class_size, + class_count AS actual_class_size, + CASE + WHEN (desired_values - class_count) = 0 THEN 'nosample' + WHEN (desired_values - class_count) > 0 THEN 'oversample' ELSE 'undersample' + END AS category + FROM ( + SELECT + classes, + CASE (SELECT 1 WHERE desired_values is NULL) + WHEN 1 THEN extra_values + ELSE desired_values + END AS desired_values, + class_count + FROM + {class_counts} + LEFT JOIN + {desired_counts} + ON ({class_counts}.classes = {desired_counts}.desired_classes) + LEFT JOIN + {remaining_class_samples} + ON ({class_counts}.classes = {remaining_class_samples}.extra_classes) + ) AS foo + ) + """.format(**locals()) + else: + """ + If output_table_size is not specified, + return remaining classes as is. + """ + desired_sample_per_class_sql = """ + CREATE VIEW {desired_sample_per_class} AS ( + SELECT + classes, + desired_values::int AS sample_class_size, + class_count as actual_class_size, + CASE + WHEN (desired_values - class_count) = 0 THEN 'nosample' + WHEN (desired_values - class_count) > 0 THEN 'oversample' ELSE 'undersample' + END AS category + FROM ( + SELECT + classes, + CASE (SELECT 1 WHERE desired_values is NULL) + WHEN 1 THEN class_count + ELSE desired_values + END AS desired_values, + class_count + FROM + {class_counts} + LEFT JOIN + {desired_counts} + ON ({class_counts}.classes = {desired_counts}.desired_classes) + ) AS foo + ) + """.format(**locals()) + if remaining_class_samples: + temp_views.append(remaining_class_samples) + + plpy.execute(desired_sample_per_class_sql) + temp_views.append(desired_sample_per_class) + return temp_views + +""" + Sample the remaining classes in output_table_size - (total calculated class sizes). +""" +def _sample_remaining_classes_uniformly_per_output_table_size(class_counts + , desired_counts, output_table_size): + + remaining_class_samples = unique_string(desp='remaining_class_samples') + foo_table = unique_string(desp='foo') + remaining_class_samples_sql = """ + CREATE VIEW {remaining_class_samples} as ( + SELECT + extra_classes, + ceil(((desired_total_size - desired_total_class_size)::FLOAT8)/((total_classes - desired_distinct_class_count)::FLOAT8)) AS extra_values + FROM + (SELECT + {output_table_size}::float8 AS desired_total_size) + AS constant_output_table_size, + (SELECT + sum(desired_values) AS desired_total_class_size, + count(*) AS desired_distinct_class_count + FROM {desired_counts}) + AS user_desired_class_sizes, + (SELECT + --sum(class_count) as total_size, + count(*) AS total_classes + FROM {class_counts}) + AS actual_table_frequencies, + (SELECT --- End diff -- This SELECT can be combined with the FROM subquery
---