Github user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/230#discussion_r165734791
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,748 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file EXCEPT in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+m4_changequote(`<!', `!>')
+
+import math
+
+if __name__ != "__main__":
+ import plpy
+ from utilities.control import MinWarning
+ from utilities.utilities import _assert
+ from utilities.utilities import extract_keyvalue_params
+ from utilities.utilities import unique_string
+ from utilities.validate_args import columns_exist_in_table
+ from utilities.validate_args import get_cols
+ from utilities.validate_args import table_exists
+ from utilities.validate_args import table_is_empty
+else:
+ # Used only for Unit Testing
+ # FIXME: repeating a function from utilities that is needed by the unit test.
+ # This should be removed once a unittest framework is used for testing.
+ import random
+ import time
+
+ def unique_string(desp='', **kwargs):
+ """
+ Generate random remporary names for temp table and other names.
+ It has a SQL interface so both SQL and Python functions can call
it.
+ """
+ r1 = random.randint(1, 100000000)
+ r2 = int(time.time())
+ r3 = int(time.time()) % random.randint(1, 100000000)
+ u_string = "__madlib_temp_" + desp + str(r1) + "_" + str(r2) + "_"
+ str(r3) + "__"
+ return u_string
+#
------------------------------------------------------------------------------
+
+UNIFORM = 'uniform'
+UNDERSAMPLE = 'undersample'
+OVERSAMPLE = 'oversample'
+NOSAMPLE = 'nosample'
+
+NEW_ID_COLUMN = '__madlib_id__'
+NULL_IDENTIFIER = '__madlib_null_id__'
+
def _get_frequency_distribution(source_table, class_col):
    """ Count the rows that belong to each class level.

        @param source_table: Name of the table (or view) to query
        @param class_col: Name of the column holding the class levels
        @returns:
            Dict. Class level -> number of rows. Levels are cast to text in
            SQL and stripped of surrounding whitespace; falsy values
            returned by the query (e.g. for a NULL level) are used as keys
            unchanged.
    """
    count_query = """
        SELECT {class_col}::text AS classes,
               count(*) AS class_count
        FROM {source_table}
        GROUP BY {class_col}
    """.format(class_col=class_col, source_table=source_table)
    counts_by_level = {}
    for row in plpy.execute(count_query):
        raw_level = row['classes']
        # Only non-empty text can be stripped; anything falsy is kept as is
        key = raw_level.strip() if raw_level else raw_level
        counts_by_level[key] = row['class_count']
    return counts_by_level
+
+
def _validate_and_get_sampling_strategy(sampling_strategy_str,
                                        output_table_size,
                                        supported_strategies=None,
                                        default=UNIFORM):
    """ Resolve the class_sizes input param to a sampling strategy name.

        @param sampling_strategy_str: The sampling strategy specified by the
                    user (class_sizes param). Any prefix (case-insensitive,
                    at least 3 characters) of a supported strategy is
                    accepted.
        @param output_table_size: Requested size of the output table. Must
                    be unset for the oversample and undersample strategies.
        @param supported_strategies: Strategy names tried in order for the
                    prefix match. Falls back to
                    [UNIFORM, UNDERSAMPLE, OVERSAMPLE] when empty.
        @param default: Strategy used when sampling_strategy_str is empty
        @returns:
            Str. The resolved strategy name (one of UNIFORM, UNDERSAMPLE,
            OVERSAMPLE unless a '=' is present). Default is UNIFORM.
    """
    if sampling_strategy_str:
        if len(sampling_strategy_str) < 3:
            # UNIFORM and UNDERSAMPLE share a common prefix, so fewer than
            # 3 characters cannot identify a strategy
            plpy.error("Sample: Invalid class_sizes parameter")

        if not supported_strategies:
            supported_strategies = [UNIFORM, UNDERSAMPLE, OVERSAMPLE]

        # The first supported strategy starting with the user's text wins
        requested_prefix = sampling_strategy_str.lower()
        for candidate in supported_strategies:
            if candidate.startswith(requested_prefix):
                sampling_strategy_str = candidate
                break
        else:
            plpy.error("Sample: Invalid class_sizes parameter: "
                       "{0}. Supported class_size parameters are ({1})"
                       .format(sampling_strategy_str,
                               ','.join(sorted(supported_strategies))))
    else:
        sampling_strategy_str = default

    resolved = sampling_strategy_str.lower()
    is_named_strategy = resolved in (UNIFORM, UNDERSAMPLE, OVERSAMPLE)
    _assert(is_named_strategy or (sampling_strategy_str.find('=') > 0),
            "Sample: Invalid class size ({sampling_strategy_str}).".format(
                sampling_strategy_str=sampling_strategy_str))

    _assert(resolved != 'oversample' or not output_table_size,
            "Sample: Cannot set output_table_size with oversampling.")

    _assert(resolved != 'undersample' or not output_table_size,
            "Sample: Cannot set output_table_size with undersampling.")

    return sampling_strategy_str
+#
------------------------------------------------------------------------------
+
+
def _choose_strategy(actual_count, desired_count):
    """ Pick the sampling strategy for one class level.

        @param actual_count: Actual number of samples for some level
        @param desired_count: Desired number of samples for the level
        @returns:
            Str. Sampling strategy string (either UNDERSAMPLE or OVERSAMPLE)
    """
    # Fewer rows than desired -> OVERSAMPLE; otherwise -> UNDERSAMPLE.
    #
    # The actual == desired case could return the input rows untouched, but
    # that would rule out bootstrapping (same number of rows, drawn with
    # replacement). It is therefore mapped to UNDERSAMPLE, the only strategy
    # that supports both 'with' and 'without' replacement (OVERSAMPLE is
    # always with replacement, NOSAMPLE always without).
    return OVERSAMPLE if actual_count < desired_count else UNDERSAMPLE
+# -------------------------------------------------------------------------
+
def _get_target_level_counts(sampling_strategy_str, desired_level_counts,
                             actual_level_counts, output_table_size):
    """ Compute, per class level, how many rows to draw and how to draw them.

        @param sampling_strategy_str: one of [UNIFORM, UNDERSAMPLE,
                    OVERSAMPLE, None]. This is 'None' only if the class
                    sizes are user-defined, i.e., a comma separated list of
                    class level and desired row count pairs.
        @param desired_level_counts: Dict that is defined and populated only
                    when sampling_strategy_str is None.
        @param actual_level_counts: Dict of the class levels and the number
                    of rows of each of them in the input table
        @param output_table_size: Size of the desired output table (NULL or
                    Integer)

        @returns:
            Dict. Class level -> (number of samples to be drawn, sampling
            strategy to be used for that level). Counts are integers.
    """
    target_level_counts = {}
    if not sampling_strategy_str:
        # The user has provided a desired count for one or more levels.
        # Counts for the rest of the levels depend on 'output_table_size':
        #   NULL      -> unspecified levels keep their actual counts
        #   <Integer> -> the remaining row count is divided uniformly
        #                among the unspecified levels
        for each_level, desired_count in desired_level_counts.items():
            sample_strategy = _choose_strategy(
                actual_level_counts[each_level], desired_count)
            target_level_counts[each_level] = (desired_count,
                                               sample_strategy)

        remaining_levels = (set(actual_level_counts.keys()) -
                            set(desired_level_counts.keys()))
        if output_table_size:
            remaining_rows = (output_table_size -
                              sum(desired_level_counts.values()))
            # Without the remaining_levels check, specifying a size for
            # every level together with a larger output_table_size divides
            # by zero.
            if remaining_rows > 0 and remaining_levels:
                # int(): math.ceil returns a float on Python 2, but this is
                # a row count
                rows_per_level = int(math.ceil(float(remaining_rows) /
                                               len(remaining_levels)))
                for each_level in remaining_levels:
                    sample_strategy = _choose_strategy(
                        actual_level_counts[each_level], rows_per_level)
                    target_level_counts[each_level] = (rows_per_level,
                                                       sample_strategy)
        else:
            # When output_table_size is unspecified, rows of the remaining
            # class levels are taken from the input table as is. This is
            # the NOSAMPLE strategy.
            for each_level in remaining_levels:
                target_level_counts[each_level] = (
                    actual_level_counts[each_level], NOSAMPLE)
    else:
        def ceil_of_mean(numbers):
            return int(math.ceil(float(sum(numbers)) /
                                 max(len(numbers), 1)))

        # UNIFORM: all level counts equal (size from output_table_size if
        #          given, else the rounded-up mean of the actual counts)
        # UNDERSAMPLE: all level counts equal to the minimum count
        # OVERSAMPLE: all level counts equal to the maximum count
        size_function = {UNDERSAMPLE: min,
                         OVERSAMPLE: max,
                         UNIFORM: ceil_of_mean
                         }[sampling_strategy_str]
        if sampling_strategy_str == UNIFORM and output_table_size:
            # Ignore actual counts for computing target sizes
            # if output_table_size is specified
            target_size_per_level = int(math.ceil(
                float(output_table_size) / len(actual_level_counts)))
        else:
            target_size_per_level = size_function(
                actual_level_counts.values())
        for each_level, actual_count in actual_level_counts.items():
            sample_strategy = _choose_strategy(actual_count,
                                               target_size_per_level)
            target_level_counts[each_level] = (target_size_per_level,
                                               sample_strategy)
    return target_level_counts
+
+# -------------------------------------------------------------------------
+
+
def _get_sampling_strategy_specific_dict(target_class_sizes):
    """ Split the per-level targets into one dict per sampling strategy.

        @param target_class_sizes: Dict. Class level -> (count, strategy)
        @returns:
            Tuple of three dicts (undersample, oversample, nosample), each
            mapping a class level to the number of samples to be drawn.
            Any strategy other than UNDERSAMPLE or OVERSAMPLE lands in the
            nosample dict.
    """
    by_strategy = {UNDERSAMPLE: {}, OVERSAMPLE: {}}
    nosample_levels = {}
    for level, (count, strategy) in target_class_sizes.items():
        by_strategy.get(strategy, nosample_levels)[level] = count
    return (by_strategy[UNDERSAMPLE], by_strategy[OVERSAMPLE],
            nosample_levels)
+#
------------------------------------------------------------------------------
+
+
+def _get_nosample_subquery(source_table, class_col, nosample_levels):
+ """ Return the subquery for fetching all rows as is from the input
table
+ for specific class levels.
+ """
+ if not nosample_levels:
+ return ''
+ subquery = """
+ SELECT *
+ FROM {0}
+ WHERE {1} in ({2}) OR {1} IS NULL
+ """.format(source_table, class_col,
+ ','.join(["'{0}'".format(level)
+ for level in nosample_levels if level]))
+ return subquery
+#
------------------------------------------------------------------------------
+
+
def _get_without_replacement_subquery(schema_madlib, source_table,
                                      source_table_columns, class_col,
                                      actual_level_counts,
                                      desired_level_counts):
    """ Return the subquery for sampling without replacement for specific
        class levels.

        Rows of each class level are numbered in random order and only the
        first 'desired count' rows of every level are kept.

        @param schema_madlib: Not used by this function
        @param source_table: Name of the table (or view) to sample from
        @param source_table_columns: Comma-separated column list to select
        @param class_col: Name of the column holding the class levels
        @param actual_level_counts: Not used by this function
        @param desired_level_counts: Dict. Class level -> number of rows to
                    draw. A None key stands for the NULL class level.
        @returns:
            Str. A SELECT statement, or '' when desired_level_counts is
            empty.
    """
    if not desired_level_counts:
        return ''
    class_col_tmp = unique_string()
    row_number_col = unique_string()
    desired_count_col = unique_string()

    null_value_string = "'{0}'".format(NULL_IDENTIFIER)

    def _level_literal(level):
        # Only a real NULL level maps to the placeholder; the empty string
        # is a legitimate level. Embedded single quotes are doubled so the
        # literal stays valid SQL.
        if level is None:
            return null_value_string
        return "'{0}'::text".format(
            "{0}".format(level).replace("'", "''"))

    desired_level_counts_str = "VALUES " + ','.join(
        "({0}, {1})".format(_level_literal(level), count)
        for level, count in desired_level_counts.items())

    subquery = """
        SELECT {source_table_columns}
        FROM
            (
                SELECT {source_table_columns},
                       row_number() OVER (PARTITION BY {class_col}
                                          ORDER BY random()) AS {row_number_col},
                       {desired_count_col}
                FROM
                    (
                        SELECT {source_table_columns},
                               {desired_count_col}
                        FROM
                            {source_table} s,
                            ({desired_level_counts_str})
                                q({class_col_tmp}, {desired_count_col})
                        WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_level_val}')
                    ) q2
            ) q3
        WHERE {row_number_col} <= {desired_count_col}
    """.format(source_table_columns=source_table_columns,
               class_col=class_col,
               row_number_col=row_number_col,
               desired_count_col=desired_count_col,
               source_table=source_table,
               desired_level_counts_str=desired_level_counts_str,
               class_col_tmp=class_col_tmp,
               null_level_val=NULL_IDENTIFIER)
    return subquery
+#
------------------------------------------------------------------------------
+
+
def _get_with_replacement_subquery(schema_madlib, source_table,
                                   source_table_columns, class_col,
                                   actual_level_counts,
                                   desired_level_counts):
    """ Return the query for sampling with replacement for specific class
        levels (always used for oversampling, and used for undersampling if
        with_replacement flag is set to TRUE).

        q1 produces, for every class level, 'desired count' rows that each
        carry a random row number between 1 and the level's actual count.
        q2 numbers the source rows within each class level. Joining the two
        on level and row number picks the sampled rows.

        @param schema_madlib: Not used by this function
        @param source_table: Name of the table (or view) to sample from
        @param source_table_columns: Comma-separated column list to select
        @param class_col: Name of the column holding the class levels
        @param actual_level_counts: Dict. Class level -> number of rows
                    present in source_table
        @param desired_level_counts: Dict. Class level -> number of rows to
                    draw. A None key stands for the NULL class level.
        @returns:
            Str. A SELECT statement, or '' when desired_level_counts is
            empty.
    """
    if not desired_level_counts:
        return ''

    class_col_tmp = unique_string()
    desired_count_col = unique_string()
    actual_count_col = unique_string()
    q1_row_no = unique_string()
    q2_row_no = unique_string()

    null_value_string = "'{0}'".format(NULL_IDENTIFIER)

    def _level_literal(level):
        # Only a real NULL level maps to the placeholder; the empty string
        # is a legitimate level. Embedded single quotes are doubled so the
        # literal stays valid SQL.
        if level is None:
            return null_value_string
        return "'{0}'::text".format(
            "{0}".format(level).replace("'", "''"))

    desired_and_actual_level_counts = "VALUES " + ','.join(
        "({0}, {1}, {2})".format(_level_literal(level), count,
                                 actual_level_counts[level])
        for level, count in desired_level_counts.items())

    # floor(random() * n) + 1 is uniform over 1..n. Rounding
    # random()*(n-1)+1 to the nearest integer would pick rows 1 and n only
    # half as often as the rows in between.
    subquery = """
        SELECT {source_table_columns}
        FROM
            (
                SELECT
                    {class_col_tmp},
                    generate_series(1, {desired_count_col}::int) AS _i,
                    ((floor(random() * {actual_count_col}) + 1)::int) AS {q1_row_no}
                FROM
                    ({desired_and_actual_level_counts})
                        q({class_col_tmp}, {desired_count_col}, {actual_count_col})
            ) q1,
            (
                SELECT
                    *,
                    row_number() OVER(PARTITION BY {class_col}) AS {q2_row_no}
                FROM
                    {source_table}
            ) q2
        WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_level_val}') AND
              q1.{q1_row_no} = q2.{q2_row_no}
    """.format(source_table_columns=source_table_columns,
               class_col_tmp=class_col_tmp,
               desired_count_col=desired_count_col,
               actual_count_col=actual_count_col,
               q1_row_no=q1_row_no,
               q2_row_no=q2_row_no,
               desired_and_actual_level_counts=desired_and_actual_level_counts,
               class_col=class_col,
               source_table=source_table,
               null_level_val=NULL_IDENTIFIER)
    return subquery
+#
------------------------------------------------------------------------------
+
def balance_sample(schema_madlib, source_table, output_table, class_col,
                   class_sizes, output_table_size, grouping_cols,
                   with_replacement, keep_null, **kwargs):
    """
    Balance sampling function

    Creates output_table from rows sampled out of source_table so that the
    class levels of class_col follow the requested sizes. The output gets
    an extra row-number column named by NEW_ID_COLUMN.

    Args:
        @param schema_madlib      Schema name, passed on to the subquery
                                  builders.
        @param source_table       Input table name.
        @param output_table       Output table name.
        @param class_col          Name of the column containing the class to
                                  be balanced.
        @param class_sizes        Parameter to define the size of the
                                  different class values: either a strategy
                                  name or 'level=count' pairs. Defaults to
                                  UNIFORM when empty.
        @param output_table_size  Desired size of the output data set.
        @param grouping_cols      The columns that define the grouping
                                  (must be None; see _validate_strs).
        @param with_replacement   The sampling method. Default is False.
        @param keep_null          Flag to include rows with class level
                                  values NULL. Default is False.
        @param kwargs             Not used.
    """
    with MinWarning("warning"):
        # set all default values
        if not class_sizes:
            class_sizes = UNIFORM
        with_replacement = bool(with_replacement)
        keep_null = bool(keep_null)
        if class_sizes:
            class_sizes = class_sizes.strip()

        _validate_strs(source_table, output_table, class_col,
                       output_table_size, grouping_cols)
        source_table_columns = ','.join(get_cols(source_table))

        new_source_table = source_table
        # If keep_null=False, create a view of the input table ignoring NULL
        # values for class levels.
        if not keep_null:
            new_source_table = unique_string(desp='source_table')
            plpy.execute("""
                CREATE VIEW {new_source_table} AS
                    SELECT * FROM {source_table}
                    WHERE {class_col} IS NOT NULL
                """.format(new_source_table=new_source_table,
                           source_table=source_table,
                           class_col=class_col))
        actual_level_counts = _get_frequency_distribution(new_source_table,
                                                          class_col)
        # class_sizes can be of two forms:
        #   1. A string describing sampling strategy (as described in
        #       _validate_and_get_sampling_strategy).
        #       In this case, 'sampling_strategy_str' is set to one of
        #       [UNIFORM, UNDERSAMPLE, OVERSAMPLE]
        #   2. Class sizes for all (or a subset) of the class levels
        #       In this case, sampling_strategy_str = None and
        #       parsed_class_sizes is used for the sampling.
        parsed_class_sizes = extract_keyvalue_params(class_sizes,
                                                     allow_duplicates=False,
                                                     lower_case_names=False)
        if not parsed_class_sizes:
            sampling_strategy_str = _validate_and_get_sampling_strategy(
                class_sizes, output_table_size)
        else:
            sampling_strategy_str = None
            try:
                all_levels = actual_level_counts.keys()
                for each_level, each_class_size in parsed_class_sizes.items():
                    _assert(each_level in all_levels,
                            "Sample: Invalid class value specified ({0})".
                            format(each_level))
                    each_class_size = int(each_class_size)
                    _assert(each_class_size >= 1,
                            "Sample: Class size has to be greater than zero")
                    parsed_class_sizes[each_level] = each_class_size

            except (TypeError, ValueError):
                # int() raises ValueError for non-numeric text (e.g. 'a=b')
                # and TypeError for unsupported types; report both the same
                # way.
                plpy.error("Sample: Invalid value for class_sizes ({0})".
                           format(class_sizes))

        # Get the number of rows to be sampled for each class level, based on
        # the input table, class_sizes, and output_table_size params. This
        # also includes info about the resulting sampling strategy, i.e., one
        # of UNDERSAMPLE, OVERSAMPLE, or NOSAMPLE for each level.
        target_class_sizes = _get_target_level_counts(sampling_strategy_str,
                                                      parsed_class_sizes,
                                                      actual_level_counts,
                                                      output_table_size)

        undersample_level_dict, oversample_level_dict, nosample_level_dict = \
            _get_sampling_strategy_specific_dict(target_class_sizes)

        # Get subqueries for each sampling strategy, so that they can be used
        # together in one big query.
        nosample_subquery = _get_nosample_subquery(
            new_source_table, class_col, nosample_level_dict.keys())
        oversample_subquery = _get_with_replacement_subquery(
            schema_madlib, new_source_table, source_table_columns, class_col,
            actual_level_counts, oversample_level_dict)
        if with_replacement:
            undersample_subquery = _get_with_replacement_subquery(
                schema_madlib, new_source_table, source_table_columns,
                class_col, actual_level_counts, undersample_level_dict)
        else:
            undersample_subquery = _get_without_replacement_subquery(
                schema_madlib, new_source_table, source_table_columns,
                class_col, actual_level_counts, undersample_level_dict)

        # Merge the three subqueries using a UNION ALL clause; empty
        # subqueries (strategies with no levels) are left out.
        union_all_subquery = ' UNION ALL '.join(
            ['({0})'.format(subquery)
             for subquery in [undersample_subquery, oversample_subquery,
                              nosample_subquery]
             if subquery])

        final_query = """
            CREATE TABLE {output_table} AS
                SELECT row_number() OVER() AS {new_col_name}, *
                FROM (
                    {union_all_subquery}
                ) union_query
            """.format(output_table=output_table,
                       new_col_name=NEW_ID_COLUMN,
                       union_all_subquery=union_all_subquery)
        plpy.execute(final_query)
        if not keep_null:
            plpy.execute("DROP VIEW {0}".format(new_source_table))
+
+
def _validate_strs(source_table, output_table, class_col, output_table_size,
                   grouping_cols):
    """ Validate the arguments of balance_sample; each failed check is
        reported through _assert.

        @param source_table: Must name an existing, non-empty table that
                    contains class_col and has no column named
                    NEW_ID_COLUMN
        @param output_table: Must be given and must not exist yet
        @param class_col: Must be given and be a column of source_table
        @param output_table_size: Must be unset or greater than zero
        @param grouping_cols: Must be None (grouping is not supported)
    """
    source_ok = source_table and table_exists(source_table)
    _assert(source_ok,
            "Sample: Source table ({source_table}) does not exist.".format(
                source_table=source_table))
    _assert(not table_is_empty(source_table),
            "Sample: Source table ({source_table}) is empty.".format(
                source_table=source_table))

    _assert(output_table, "Sample: Output table name is missing.")
    _assert(not table_exists(output_table),
            "Sample: Output table ({output_table}) already exists.".format(
                output_table=output_table))

    _assert(class_col, "Sample: Class column name is missing.")
    missing_col_msg = ("Sample: Class column ({class_col}) does not exist in"
                       " table ({source_table}).").format(
                           class_col=class_col, source_table=source_table)
    _assert(columns_exist_in_table(source_table, [class_col]),
            missing_col_msg)

    # The output table gets a new id column, so the name must be free
    id_clash_msg = ("Sample: Please ensure the source table ({0})"
                    " does not contain a column named {1}").format(
                        source_table, NEW_ID_COLUMN)
    _assert(not columns_exist_in_table(source_table, [NEW_ID_COLUMN]),
            id_clash_msg)

    size_ok = (not output_table_size) or (output_table_size > 0)
    _assert(size_ok,
            "Sample: Invalid output table size ({output_table_size}).".format(
                output_table_size=output_table_size))

    _assert(grouping_cols is None,
            "grouping_cols is not supported at the moment.")
+
+
+def balance_sample_help(schema_madlib, message, **kwargs):
+ """
+ Help function for balance_sample
+
+ Args:
+ @param schema_madlib
+ @param message: string, Help message string
+ @param kwargs
+
+ Returns:
+ String. Help/usage information
+ """
+ if not message:
+ help_string = """
+-----------------------------------------------------------------------
+ SUMMARY
+-----------------------------------------------------------------------
+Given a table with varying set of records for each class label,
+this function will create an output table with a varying types (by
+default: uniform) of sampling distributions of each class label. It is
+possible to use with or without replacement sampling methods, specify
+different proportions of each class, multiple grouping columns and/or
+output table size.
+
+For more details on function usage:
+ SELECT {schema_madlib}.balance_sample('usage');
+ SELECT {schema_madlib}.balance_sample('example');
+ """
+ elif message.lower() in ['usage', 'help', '?']:
+ help_string = """
+
+Given a table, stratified sampling returns a proportion of records for
+each group (strata). It is possible to use with or without replacement
+sampling methods, specify a set of target columns, and assume the
+whole table is a single strata.
+
+----------------------------------------------------------------------------
+ USAGE
+----------------------------------------------------------------------------
+
+ SELECT {schema_madlib}.balance_sample(
+ source_table TEXT, -- Input table name.
+ output_table TEXT, -- Output table name.
+ class_col TEXT, -- Name of column containing the class to
be
+ -- balanced.
+ class_size TEXT, -- (Default: NULL) Parameter to define the
size
+ -- of the different class values.
+ output_table_size INTEGER, -- (Default: NULL) Desired size of the
output
+ -- data set.
+ grouping_cols TEXT, -- (Default: NULL) The columns columns that
+ -- defines the grouping.
+ with_replacement BOOLEAN -- (Default: FALSE) The sampling method.
+ keep_null BOOLEAN -- (Default: FALSE) Consider class levels
with
+ NULL values or not.
+
+If class_size is NULL, the source table is uniformly sampled.
+
+If output_table_size is NULL, the resulting output table size will depend
on
+the settings for the ‘class_size’ parameter. It is ignored if ‘class_size’
+parameter is set to either ‘oversample’ or ‘undersample’.
+
+If grouping_cols is NULL, the whole table is treated as a single group and
+sampled accordingly.
+
+If with_replacement is TRUE, each sample is independent (the same row may
+be selected in the sample set more than once). Else (if with_replacement
+is FALSE), a row can be selected at most once.
+);
+
+The output_table would contain the required number of samples, along with a
+new column named __madlib_id__, that contain unique numbers for all
+sampled rows.
+"""
+ elif message.lower() in ("example", "examples"):
+ help_string = """
+----------------------------------------------------------------------------
+ EXAMPLES
+----------------------------------------------------------------------------
+
+-- Create an input table
+DROP TABLE IF EXISTS test;
+
+CREATE TABLE test(
+ id1 INTEGER,
+ id2 INTEGER,
+ gr1 INTEGER,
+ gr2 INTEGER
+);
+
+INSERT INTO test VALUES
+(1,0,1,1),
+(2,0,1,1),
+(3,0,1,1),
+(4,0,1,1),
+(5,0,1,1),
+(6,0,1,1),
+(7,0,1,1),
+(8,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(0,1,1,2),
+(0,2,1,2),
+(0,3,1,2),
+(0,4,1,2),
+(0,5,1,2),
+(0,6,1,2),
+(10,10,2,2),
+(20,20,2,2),
+(30,30,2,2),
+(40,40,2,2),
+(50,50,2,2),
+(60,60,2,2),
+(70,70,2,2)
+;
+
+-- Sample without replacement
+DROP TABLE IF EXISTS out;
+SELECT balance_sample('test', 'out', 'gr1', 'undersample', NULL, NULL,
FALSE);
+SELECT * FROM out;
+
+--- Sample with replacement
+DROP TABLE IF EXISTS out_sr2;
+SELECT balance_sample('test', 'out', 'gr1', 'undersample', NULL, NULL,
TRUE);
+SELECT * FROM out;
+"""
+ else:
+ help_string = "No such option. Use {schema_madlib}.graph_sssp()"
--- End diff --
This should be `{schema_madlib}.balance_sample()`
---