Github user kaknikhil commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/230#discussion_r165530126
  
    --- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
    @@ -0,0 +1,748 @@
    +# coding=utf-8
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one
    +# or more contributor license agreements.  See the NOTICE file
    +# distributed with this work for additional information
    +# regarding copyright ownership.  The ASF licenses this file
    +# to you under the Apache License, Version 2.0 (the
    +# "License"); you may not use this file EXCEPT in compliance
    +# with the License.  You may obtain a copy of the License at
    +#
    +#   http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing,
    +# software distributed under the License is distributed on an
    +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +# KIND, either express or implied.  See the License for the
    +# specific language governing permissions and limitations
    +# under the License.
    +
    +m4_changequote(`<!', `!>')
    +
    +import math
    +
    +if __name__ != "__main__":
    +    import plpy
    +    from utilities.control import MinWarning
    +    from utilities.utilities import _assert
    +    from utilities.utilities import extract_keyvalue_params
    +    from utilities.utilities import unique_string
    +    from utilities.validate_args import columns_exist_in_table
    +    from utilities.validate_args import get_cols
    +    from utilities.validate_args import table_exists
    +    from utilities.validate_args import table_is_empty
    +else:
    +    # Used only for Unit Testing
    +    # FIXME: repeating a function from utilities that is needed by the unit test.
    +    # This should be removed once a unittest framework is used for testing.
    +    import random
    +    import time
    +
    +    def unique_string(desp='', **kwargs):
    +        """
    +        Generate random temporary names for temp table and other names.
    +        It has a SQL interface so both SQL and Python functions can call it.
    +        """
    +        r1 = random.randint(1, 100000000)
    +        r2 = int(time.time())
    +        r3 = int(time.time()) % random.randint(1, 100000000)
    +        u_string = "__madlib_temp_" + desp + str(r1) + "_" + str(r2) + "_" 
+ str(r3) + "__"
    +        return u_string
    +# 
------------------------------------------------------------------------------
    +
    +UNIFORM = 'uniform'
    +UNDERSAMPLE = 'undersample'
    +OVERSAMPLE = 'oversample'
    +NOSAMPLE = 'nosample'
    +
    +NEW_ID_COLUMN = '__madlib_id__'
    +NULL_IDENTIFIER = '__madlib_null_id__'
    +
    +def _get_frequency_distribution(source_table, class_col):
    +    """ Returns a dict containing the number of rows associated with each 
class
    +        level. Each class level value is converted to a string using 
::text.
    +    """
    +    query_result = plpy.execute("""
    +                    SELECT {class_col}::text AS classes,
    +                           count(*) AS class_count
    +                    FROM {source_table}
    +                    GROUP BY {class_col}
    +                 """.format(**locals()))
    +    actual_level_counts = {}
    +    for each_row in query_result:
    +        level = each_row['classes']
    +        if level:
    +            level = level.strip()
    +        actual_level_counts[level] = each_row['class_count']
    +    return actual_level_counts
    +
    +
    +def _validate_and_get_sampling_strategy(sampling_strategy_str, 
output_table_size,
    +                            supported_strategies=None, default=UNIFORM):
    +    """ Returns the sampling strategy based on the class_sizes input param.
    +        @param sampling_strategy_str The sampling strategy specified by the
    +                                         user (class_sizes param)
    +        @returns:
    +            Str. One of [UNIFORM, UNDERSAMPLE, OVERSAMPLE]. Default is 
UNIFORM.
    +    """
    +    if not sampling_strategy_str:
    +        sampling_strategy_str = default
    +    else:
    +        if len(sampling_strategy_str) < 3:
    +            # Require at least 3 characters since UNIFORM and UNDERSAMPLE 
have
    +            # common prefix substring
    +            plpy.error("Sample: Invalid class_sizes parameter")
    +
    +        if not supported_strategies:
    +            supported_strategies = [UNIFORM, UNDERSAMPLE, OVERSAMPLE]
    +        try:
    +            # allow user to specify a prefix substring of
    +            # supported strategies.
    +            sampling_strategy_str = next(x for x in supported_strategies
    +                                         if 
x.startswith(sampling_strategy_str.lower()))
    +        except StopIteration:
    +            # next() raises StopIteration if no element is found
    +            plpy.error("Sample: Invalid class_sizes parameter: "
    +                       "{0}. Supported class_size parameters are ({1})"
    +                       .format(sampling_strategy_str, 
','.join(sorted(supported_strategies))))
    +
    +    _assert(sampling_strategy_str.lower() in (UNIFORM, UNDERSAMPLE, 
OVERSAMPLE) or
    +            (sampling_strategy_str.find('=') > 0),
    +            "Sample: Invalid class size 
({sampling_strategy_str}).".format(**locals()))
    +
    +    _assert(not(sampling_strategy_str.lower() == 'oversample' and 
output_table_size),
    +            "Sample: Cannot set output_table_size with oversampling.")
    +
    +    _assert(not(sampling_strategy_str.lower() == 'undersample' and 
output_table_size),
    +            "Sample: Cannot set output_table_size with undersampling.")
    +
    +    return sampling_strategy_str
    +# 
------------------------------------------------------------------------------
    +
    +
    +def _choose_strategy(actual_count, desired_count):
    +    """ Choose sampling strategy by comparing actual and desired sample 
counts
    +
    +    @param actual_count: Actual number of samples for some level
    +    @param desired_count: Desired number of samples for the level
    +    @returns:
    +        Str. Sampling strategy string (either UNDERSAMPLE or OVERSAMPLE)
    +    """
    +    # OVERSAMPLE when the actual count is less than the desired count
    +    # UNDERSAMPLE when the actual count is more than the desired count
    +
    +    # If the actual count for a class level is the same as desired count, 
then
    +    # we could potentially return the input rows as is.  This, however,
    +    # precludes the case of bootstrapping (i.e. returning same  number of 
rows
    +    # but after sampling with replacement).  Hence, we treat the 
actual=desired
    +    # as UNDERSAMPLE.  It's specifically set to UNDERSAMPLE since it 
provides
    +    # both 'with' and 'without' replacement  (OVERSAMPLE is always with
    +    # replacement and NOSAMPLE is always without replacement)
    +    if actual_count < desired_count:
    +        return OVERSAMPLE
    +    else:
    +        return UNDERSAMPLE
    +# -------------------------------------------------------------------------
    +
    +def _get_target_level_counts(sampling_strategy_str, desired_level_counts,
    +                             actual_level_counts, output_table_size):
    +    """
    +    @param sampling_strategy_str: one of [UNIFORM, UNDERSAMPLE, 
OVERSAMPLE, None].
    +                               This is 'None' only if this is 
user-defined, i.e.,
    +                               a comma separated list of class levels and 
number of
    +                               rows desired pairs.
    +    @param desired_level_counts: Dict that is defined and populated only 
when
    +                                    sampling_strategy_str is None.
    +    @param actual_level_counts: Dict of various class levels and number of 
rows
    +                                  in each of them in the input table
    +    @param output_table_size: Size of the desired output table (NULL or 
Integer)
    +
    +    @returns:
    +        Dict. Number of samples to be drawn, and the sampling strategy to 
be
    +              used for each class level.
    +    """
    +    target_level_counts = {}
    +    if not sampling_strategy_str:
    +        # This case implies user has provided a desired count for one or 
more
    +        # levels. Counts for the rest of the levels depend on 
'output_table_size'.
    +        #   if 'output_table_size' = NULL, unspecified level counts remain 
as is
    +        #   if 'output_table_size' = <Integer>, divide remaining row count
    +        #                             uniformly among unspecified level 
counts
    +        for each_level, desired_count in desired_level_counts.items():
    +            sample_strategy = 
_choose_strategy(actual_level_counts[each_level],
    +                                               desired_count)
    +            target_level_counts[each_level] = (desired_count, 
sample_strategy)
    +
    +        remaining_levels = (set(actual_level_counts.keys()) -
    +                            set(desired_level_counts.keys()))
    +        if output_table_size:
    +            # Uniformly distribute across the remaining class levels
    +            remaining_rows = output_table_size - 
sum(desired_level_counts.values())
    +            if remaining_rows > 0:
    +                rows_per_level = math.ceil(float(remaining_rows) /
    +                                           len(remaining_levels))
    +                for each_level in remaining_levels:
    +                    sample_strategy = _choose_strategy(
    +                        actual_level_counts[each_level], rows_per_level)
    +                    target_level_counts[each_level] = (rows_per_level,
    +                                                       sample_strategy)
    +        else:
    +            # When output_table_size is unspecified, rows from the input 
table
    +            # are sampled as is for remaining class levels. This is same 
as the
    +            # NOSAMPLE strategy.
    +            for each_level in remaining_levels:
    +                target_level_counts[each_level] = 
(actual_level_counts[each_level],
    +                                                    NOSAMPLE)
    +    else:
    +        def ceil_of_mean(numbers):
    +            return math.ceil(float(sum(numbers)) / max(len(numbers), 1))
    +
    +        # UNIFORM: Ensure all level counts are same (size determined by 
output_table_size)
    +        # UNDERSAMPLE: Ensure all level counts are same as the minimum 
count
    +        # OVERSAMPLE: Ensure all level counts are same as the maximum count
    +        size_function = {UNDERSAMPLE: min,
    +                         OVERSAMPLE: max,
    +                         UNIFORM: ceil_of_mean
    +                         }[sampling_strategy_str]
    +        if sampling_strategy_str == UNIFORM and output_table_size:
    +            # Ignore actual counts for computing target sizes
    +            # if output_table_size is specified
    +            target_size_per_level = math.ceil(float(output_table_size) /
    +                                              len(actual_level_counts))
    +        else:
    +            target_size_per_level = 
size_function(actual_level_counts.values())
    +        for each_level, actual_count in actual_level_counts.items():
    +            sample_strategy = _choose_strategy(actual_count, 
target_size_per_level)
    +            target_level_counts[each_level] = (target_size_per_level,
    +                                               sample_strategy)
    +    return target_level_counts
    +
    +# -------------------------------------------------------------------------
    +
    +
    +def _get_sampling_strategy_specific_dict(target_class_sizes):
    +    """ Return three dicts, one each for undersampling, oversampling, and
    +        nosampling. The dict contains the number of samples to be drawn for
    +        each class level.
    +    """
    +    undersample_level_dict = {}
    +    oversample_level_dict = {}
    +    nosample_level_dict = {}
    +    for level, (count, strategy) in target_class_sizes.items():
    +        if strategy == UNDERSAMPLE:
    +            chosen_strategy = undersample_level_dict
    +        elif strategy == OVERSAMPLE:
    +            chosen_strategy = oversample_level_dict
    +        else:
    +            chosen_strategy = nosample_level_dict
    +        chosen_strategy[level] = count
    +    return (undersample_level_dict, oversample_level_dict, 
nosample_level_dict)
    +# 
------------------------------------------------------------------------------
    +
    +
    +def _get_nosample_subquery(source_table, class_col, nosample_levels):
    +    """ Return the subquery for fetching all rows as is from the input 
table
    +        for specific class levels.
    +    """
    +    if not nosample_levels:
    +        return ''
    +    subquery = """
    +                SELECT *
    +                FROM {0}
    +                WHERE {1} in ({2}) OR {1} IS NULL
    +            """.format(source_table, class_col,
    +                       ','.join(["'{0}'".format(level)
    +                                for level in nosample_levels if level]))
    +    return subquery
    +# 
------------------------------------------------------------------------------
    +
    +
    +def _get_without_replacement_subquery(schema_madlib, source_table,
    +                                      source_table_columns, class_col,
    +                                      actual_level_counts, 
desired_level_counts):
    +    """ Return the subquery for sampling without replacement for specific
    +        class levels.
    +    """
    +    if not desired_level_counts:
    +        return ''
    +    class_col_tmp = unique_string()
    +    row_number_col = unique_string()
    +    desired_count_col = unique_string()
    +
    +    null_value_string = "'{0}'".format(NULL_IDENTIFIER)
    +
    +    desired_level_counts_str = "VALUES " + \
    +            ','.join("({0}, {1})".
    +            format("'{0}'::text".format(k) if k else null_value_string, v)
    +            for k, v in desired_level_counts.items())
    +    subquery = """
    +            SELECT {source_table_columns}
    +            FROM
    +                (
    +                    SELECT {source_table_columns},
    +                           row_number() OVER (PARTITION BY {class_col} 
ORDER BY random()) AS {row_number_col},
    +                           {desired_count_col}
    +                    FROM
    +                    (
    +                        SELECT {source_table_columns},
    +                               {desired_count_col}
    +                        FROM
    +                            {source_table} s,
    +                            ({desired_level_counts_str})
    +                                q({class_col_tmp}, {desired_count_col})
    +                        WHERE {class_col_tmp} = 
coalesce({class_col}::text, '{null_level_val}')
    +                    ) q2
    +                ) q3
    +            WHERE {row_number_col} <= {desired_count_col}
    +        """.format(null_level_val=NULL_IDENTIFIER, **locals())
    +    return subquery
    +# 
------------------------------------------------------------------------------
    +
    +
    +def _get_with_replacement_subquery(schema_madlib, source_table,
    +                                   source_table_columns, class_col,
    +                                   actual_level_counts, 
desired_level_counts):
    +    """ Return the query for sampling with replacement for specific class
    +        levels (always used for oversampling, and used for undersampling if
    +        with_replacement flag is set to TRUE).
    +    """
    +    if not desired_level_counts:
    +        return ''
    +
    +    class_col_tmp = unique_string()
    +    desired_count_col = unique_string()
    +    actual_count_col = unique_string()
    +    q1_row_no = unique_string()
    +    q2_row_no = unique_string()
    +
    +    null_value_string = "'{0}'".format(NULL_IDENTIFIER)
    +
    +    desired_and_actual_level_counts = "VALUES " + \
    +    ','.join("({0}, {1}, {2})".
    +             format("'{0}'::text".format(k) if k else null_value_string,
    +                v, actual_level_counts[k])
    +             for k, v in desired_level_counts.items())
    +    subquery = """
    +            SELECT {source_table_columns}
    +            FROM
    +                (
    +                    SELECT
    +                         {class_col_tmp},
    +                         generate_series(1, {desired_count_col}::int) AS 
_i,
    +                         ((random()*({actual_count_col}-1)+1)::int) AS 
{q1_row_no}
    +                    FROM
    +                        ({desired_and_actual_level_counts})
    +                            q({class_col_tmp}, {desired_count_col}, 
{actual_count_col})
    +                ) q1,
    +                (
    +                    SELECT
    +                        *,
    +                        row_number() OVER(PARTITION BY {class_col}) AS 
{q2_row_no}
    +                    FROM
    +                         {source_table}
    +                ) q2
    +            WHERE {class_col_tmp} = coalesce({class_col}::text, 
'{null_level_val}') AND
    +                  q1.{q1_row_no} = q2.{q2_row_no}
    +        """.format(null_level_val=NULL_IDENTIFIER, **locals())
    +    return subquery
    +# 
------------------------------------------------------------------------------
    +
    +def balance_sample(schema_madlib, source_table, output_table, class_col,
    +                   class_sizes, output_table_size, grouping_cols,
    +                   with_replacement, keep_null, **kwargs):
    +    """
    +    Balance sampling function
    +    Args:
    +        @param source_table       Input table name.
    +        @param output_table       Output table name.
    +        @param class_col          Name of the column containing the class 
to be
    +                                  balanced.
    +        @param class_sizes        Parameter to define the size of the 
different
    +                                  class values.
    +        @param output_table_size  Desired size of the output data set.
    +        @param grouping_cols      The columns that define the grouping.
    +        @param with_replacement   The sampling method.
    +        @param keep_null          Flag to include rows with class level 
values
    +                                  NULL. Default is False.
    +
    +    """
    +    with MinWarning("warning"):
    +
    +        desired_sample_per_class = 
unique_string(desp='desired_sample_per_class')
    +        desired_counts = unique_string(desp='desired_counts')
    +
    +        # set all default values
    +        if not class_sizes:
    +            class_sizes = UNIFORM
    +        if not with_replacement:
    +            with_replacement = False
    +        keep_null = False if not keep_null else True
    +        if class_sizes:
    +            class_sizes = class_sizes.strip()
    +
    +        _validate_strs(source_table, output_table, class_col,
    +                       output_table_size, grouping_cols)
    +        source_table_columns = ','.join(get_cols(source_table))
    +
    +        new_source_table = source_table
    +        # If keep_null=False, create a view of the input table ignoring 
NULL
    +        # values for class levels.
    +        if not keep_null:
    +            new_source_table = unique_string(desp='source_table')
    +            plpy.execute("""
    +                        CREATE VIEW {new_source_table} AS
    +                        SELECT * FROM {source_table}
    +                        WHERE {class_col} IS NOT NULL
    +                    """.format(**locals()))
    +        actual_level_counts = _get_frequency_distribution(new_source_table,
    +                                                          class_col)
    +        # class_sizes can be of two forms:
    +        #   1. A string describing sampling strategy (as described in
    +        #       _validate_and_get_sampling_strategy).
    +        #       In this case, 'sampling_strategy_str' is set to one of
    +        #       [UNIFORM, UNDERSAMPLE, OVERSAMPLE]
    +        #   2. Class sizes for all (or a subset) of the class levels
    +        #       In this case, sampling_strategy_str = None and 
parsed_class_sizes
    +        #       is used for the sampling.
    +        parsed_class_sizes = extract_keyvalue_params(class_sizes,
    +                                                     
allow_duplicates=False,
    +                                                     
lower_case_names=False)
    +        if not parsed_class_sizes:
    +            sampling_strategy_str = 
_validate_and_get_sampling_strategy(class_sizes,
    +                                        output_table_size)
    +        else:
    +            sampling_strategy_str = None
    +            try:
    +                all_levels = actual_level_counts.keys()
    +                for each_level, each_class_size in 
parsed_class_sizes.items():
    +                    _assert(each_level in all_levels,
    +                            "Sample: Invalid class value specified ({0})".
    +                                       format(each_level))
    +                    each_class_size = int(each_class_size)
    +                    _assert(each_class_size >= 1,
    +                            "Sample: Class size has to be greater than 
zero")
    +                    parsed_class_sizes[each_level] = each_class_size
    +
    +            except TypeError:
    +                plpy.error("Sample: Invalid value for class_sizes ({0})".
    +                           format(class_sizes))
    +
    +        # Get the number of rows to be sampled for each class level, based 
on
    +        # the input table, class_sizes, and output_table_size params. This 
also
    +        # includes info about the resulting sampling strategy, i.e., one of
    +        # UNDERSAMPLE, OVERSAMPLE, or NOSAMPLE for each level.
    +        target_class_sizes = 
_get_target_level_counts(sampling_strategy_str,
    +                                                      parsed_class_sizes,
    +                                                      actual_level_counts,
    +                                                      output_table_size)
    +
    +        undersample_level_dict, oversample_level_dict, nosample_level_dict 
= \
    +            _get_sampling_strategy_specific_dict(target_class_sizes)
    +
    +        # Get subqueries for each sampling strategy, so that they can be 
used
    +        # together in one big query.
    +        nosample_subquery = _get_nosample_subquery(
    +            new_source_table, class_col, nosample_level_dict.keys())
    +        oversample_subquery = _get_with_replacement_subquery(
    +            schema_madlib, new_source_table, source_table_columns, 
class_col,
    +            actual_level_counts, oversample_level_dict)
    +        if with_replacement:
    +            undersample_subquery = _get_with_replacement_subquery(
    +                schema_madlib, new_source_table, source_table_columns, 
class_col,
    +                actual_level_counts, undersample_level_dict)
    +        else:
    +            undersample_subquery = _get_without_replacement_subquery(
    +                schema_madlib, new_source_table, source_table_columns, 
class_col,
    +                actual_level_counts, undersample_level_dict)
    +
    +        # Merge the three subqueries using a UNION ALL clause.
    +        union_all_subquery = ' UNION ALL '.join(
    +            ['({0})'.format(subquery)
    +             for subquery in [undersample_subquery, oversample_subquery, 
nosample_subquery]
    +             if subquery])
    +
    +        final_query = """
    +                CREATE TABLE {output_table} AS
    +                SELECT row_number() OVER() AS {new_col_name}, *
    +                FROM (
    +                    {union_all_subquery}
    +                ) union_query
    +            """.format(new_col_name=NEW_ID_COLUMN, **locals())
    +        plpy.execute(final_query)
    +        if not keep_null:
    +            plpy.execute("DROP VIEW {0}".format(new_source_table))
    +
    +
    +def _validate_strs(source_table, output_table, class_col, 
output_table_size,
    +                    grouping_cols):
    +    _assert(source_table and table_exists(source_table),
    +            "Sample: Source table ({source_table}) does not 
exist.".format(**locals()))
    +    _assert(not table_is_empty(source_table),
    +            "Sample: Source table ({source_table}) is 
empty.".format(**locals()))
    +
    +    _assert(output_table,
    +            "Sample: Output table name is missing.".format(**locals()))
    +    _assert(not table_exists(output_table),
    +            "Sample: Output table ({output_table}) already 
exists.".format(**locals()))
    +
    +    _assert(class_col,
    +            "Sample: Class column name is missing.".format(**locals()))
    +    _assert(columns_exist_in_table(source_table, [class_col]),
    +            ("""Sample: Class column ({class_col}) does not exist in""" +
    +             """ table ({source_table}).""").format(**locals()))
    +
    +    _assert(not columns_exist_in_table(source_table, [NEW_ID_COLUMN]),
    +            ("""Sample: Please ensure the source table ({0})""" +
    +             """ does not contain a column named 
{1}""").format(source_table, NEW_ID_COLUMN))
    +
    +    _assert((not output_table_size) or (output_table_size > 0),
    +            "Sample: Invalid output table size 
({output_table_size}).".format(
    +            **locals()))
    +
    +    _assert(grouping_cols is None,
    +            "grouping_cols is not supported at the moment."
    +            .format(**locals()))
    +
    +
    +def balance_sample_help(schema_madlib, message, **kwargs):
    +    """
    +    Help function for balance_sample
    +
    +    Args:
    +        @param schema_madlib
    +        @param message: string, Help message string
    +        @param kwargs
    +
    +    Returns:
    +        String. Help/usage information
    +    """
    +    if not message:
    +        help_string = """
    +-----------------------------------------------------------------------
    +                            SUMMARY
    +-----------------------------------------------------------------------
    +Given a table with varying set of records for each class label,
    +this function will create an output table with varying types (by
    +default: uniform) of sampling distributions of each class label. It is
    +possible to use with or without replacement sampling methods, specify
    +different proportions of each class, multiple grouping columns and/or
    +output table size.
    +
    +For more details on function usage:
    +    SELECT {schema_madlib}.balance_sample('usage');
    +    SELECT {schema_madlib}.balance_sample('example');
    +            """
    +    elif message.lower() in ['usage', 'help', '?']:
    +        help_string = """
    +
    +Given a table, stratified sampling returns a proportion of records for
    +each group (strata). It is possible to use with or without replacement
    +sampling methods, specify a set of target columns, and assume the
    +whole table is a single strata.
    +
    
+----------------------------------------------------------------------------
    +                            USAGE
    
+----------------------------------------------------------------------------
    +
    + SELECT {schema_madlib}.balance_sample(
    +    source_table      TEXT,     -- Input table name.
    +    output_table      TEXT,     -- Output table name.
    +    class_col         TEXT,     -- Name of column containing the class to 
be
    +                                -- balanced.
    +    class_size        TEXT,     -- (Default: NULL) Parameter to define the 
size
    +                                -- of the different class values.
    +    output_table_size INTEGER,  -- (Default: NULL) Desired size of the 
output
    +                                -- data set.
    +    grouping_cols     TEXT,     -- (Default: NULL) The columns that
    +                                -- define the grouping.
    +    with_replacement  BOOLEAN   -- (Default: FALSE) The sampling method.
    +    keep_null         BOOLEAN   -- (Default: FALSE) Consider class levels 
with
    +                                    NULL values or not.
    +
    +If class_size is NULL, the source table is uniformly sampled.
    +
    +If output_table_size is NULL, the resulting output table size will depend 
on
    +the settings for the ‘class_size’ parameter. It is ignored if 
‘class_size’
    +parameter is set to either ‘oversample’ or ‘undersample’.
    +
    +If grouping_cols is NULL, the whole table is treated as a single group and
    +sampled accordingly.
    +
    +If with_replacement is TRUE, each sample is independent (the same row may
    +be selected in the sample set more than once). Else (if with_replacement
    +is FALSE), a row can be selected at most once.
    +);
    +
    +The output_table would contain the required number of samples, along with a
    +new column named __madlib_id__, that contains unique numbers for all
    +sampled rows.
    +"""
    +    elif message.lower() in ("example", "examples"):
    +        help_string = """
    
+----------------------------------------------------------------------------
    +                                EXAMPLES
    
+----------------------------------------------------------------------------
    +
    +-- Create an input table
    +DROP TABLE IF EXISTS test;
    +
    +CREATE TABLE test(
    +    id1 INTEGER,
    +    id2 INTEGER,
    +    gr1 INTEGER,
    +    gr2 INTEGER
    +);
    +
    +INSERT INTO test VALUES
    +(1,0,1,1),
    +(2,0,1,1),
    +(3,0,1,1),
    +(4,0,1,1),
    +(5,0,1,1),
    +(6,0,1,1),
    +(7,0,1,1),
    +(8,0,1,1),
    +(9,0,1,1),
    +(9,0,1,1),
    +(9,0,1,1),
    +(9,0,1,1),
    +(0,1,1,2),
    +(0,2,1,2),
    +(0,3,1,2),
    +(0,4,1,2),
    +(0,5,1,2),
    +(0,6,1,2),
    +(10,10,2,2),
    +(20,20,2,2),
    +(30,30,2,2),
    +(40,40,2,2),
    +(50,50,2,2),
    +(60,60,2,2),
    +(70,70,2,2)
    +;
    +
    +-- Sample without replacement
    +DROP TABLE IF EXISTS out;
    +SELECT balance_sample('test', 'out', 'gr1', 'undersample', NULL, NULL, 
FALSE);
    +SELECT * FROM out;
    +
    +-- Sample with replacement
    +DROP TABLE IF EXISTS out_sr2;
    +SELECT balance_sample('test', 'out_sr2', 'gr1', 'undersample', NULL, NULL, TRUE);
    +SELECT * FROM out_sr2;
    +"""
    +    else:
    +        help_string = "No such option. Use {schema_madlib}.balance_sample()"
    +
    +    return help_string.format(schema_madlib=schema_madlib)
    +
    +
    +import unittest
    +
    +
    +class UtilitiesTestCase(unittest.TestCase):
    +    """
    +        Comment out "import plpy" and replace plpy.error calls with appropriate
    +        Python Exceptions to successfully run the test cases
    +    """
    +
    +    def setUp(self):
    +        self.input_class_level_counts1 = {'a': 20, 'b': 30, 'c': 25}
    +        self.level1a = 'a'
    +        self.level1a_cnt1 = 15
    +        self.level1a_cnt2 = 25
    +        self.level1a_cnt3 = 20
    +
    +        self.sampling_strategy_str0 = ''
    +        self.sampling_strategy_str1 = 'uniform'
    +        self.sampling_strategy_str2 = 'oversample'
    +        self.sampling_strategy_str3 = 'undersample'
    +        self.user_specified_class_size0 = ''
    +        self.user_specified_class_size1 = {'a': 25, 'b': 25}
    +        self.user_specified_class_size2 = {'b': 25}
    +        self.user_specified_class_size3 = {'a': 30}
    +        self.output_table_size1 = None
    +        self.output_table_size2 = 60
    +        # self.input_class_level_counts2 = {'a':100, 'b':100, 'c':100}
    +
    +    def test__choose_strategy(self):
    +        self.assertEqual(UNDERSAMPLE, _choose_strategy(35, 25))
    +        self.assertEqual(OVERSAMPLE, _choose_strategy(15, 25))
    +        self.assertEqual(UNDERSAMPLE, _choose_strategy(25, 25))
    +
    +    def test__get_target_level_counts(self):
    +        # Test cases for user defined class level samples, without output 
table size
    +        self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 
'c': (25, NOSAMPLE)},
    +                         
_get_target_level_counts(self.sampling_strategy_str0,
    +                                                  
self.user_specified_class_size1,
    +                                                  
self.input_class_level_counts1,
    +                                                  self.output_table_size1))
    +        self.assertEqual({'a': (20, NOSAMPLE), 'b': (25, UNDERSAMPLE), 
'c': (25, NOSAMPLE)},
    +                         
_get_target_level_counts(self.sampling_strategy_str0,
    +                                                  
self.user_specified_class_size2,
    +                                                  
self.input_class_level_counts1,
    +                                                  self.output_table_size1))
    +        self.assertEqual({'a': (30, OVERSAMPLE), 'b': (30, NOSAMPLE), 'c': 
(25, NOSAMPLE)},
    +                         
_get_target_level_counts(self.sampling_strategy_str0,
    +                                                  
self.user_specified_class_size3,
    +                                                  
self.input_class_level_counts1,
    +                                                  self.output_table_size1))
    +        # Test cases for user defined class level samples, with output 
table size
    +        self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 
'c': (10, UNDERSAMPLE)},
    +                         
_get_target_level_counts(self.sampling_strategy_str0,
    +                                                  
self.user_specified_class_size1,
    +                                                  
self.input_class_level_counts1,
    +                                                  self.output_table_size2))
    +        self.assertEqual({'a': (18, UNDERSAMPLE), 'b': (25, UNDERSAMPLE), 
'c': (18, UNDERSAMPLE)},
    +                         
_get_target_level_counts(self.sampling_strategy_str0,
    +                                                  
self.user_specified_class_size2,
    +                                                  
self.input_class_level_counts1,
    +                                                  self.output_table_size2))
    +        self.assertEqual({'a': (30, OVERSAMPLE), 'b': (15, UNDERSAMPLE), 
'c': (15, UNDERSAMPLE)},
    +                         
_get_target_level_counts(self.sampling_strategy_str0,
    +                                                  
self.user_specified_class_size3,
    +                                                  
self.input_class_level_counts1,
    +                                                  self.output_table_size2))
    +        # Test cases for UNIFORM, OVERSAMPLE, and UNDERSAMPLE without any 
output table size
    +        self.assertEqual({'a': (25, OVERSAMPLE), 'b': (25, UNDERSAMPLE), 
'c': (25, UNDERSAMPLE)},
    +                         
_get_target_level_counts(self.sampling_strategy_str1,
    +                                                  
self.user_specified_class_size0,
    +                                                  
self.input_class_level_counts1,
    +                                                  self.output_table_size1))
    +        self.assertEqual({'a': (30, OVERSAMPLE), 'b': (30, UNDERSAMPLE), 
'c': (30, OVERSAMPLE)},
    +                         
_get_target_level_counts(self.sampling_strategy_str2,
    +                                                  
self.user_specified_class_size0,
    +                                                  
self.input_class_level_counts1,
    +                                                  self.output_table_size1))
    +        self.assertEqual({'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 
'c': (20, UNDERSAMPLE)},
    +                         
_get_target_level_counts(self.sampling_strategy_str3,
    +                                                  
self.user_specified_class_size0,
    +                                                  
self.input_class_level_counts1,
    +                                                  self.output_table_size1))
    +        # Test cases for UNIFORM with output table size
    +        self.assertEqual({'a': (20, UNDERSAMPLE), 'b': (20, UNDERSAMPLE), 
'c': (20, UNDERSAMPLE)},
    +                         
_get_target_level_counts(self.sampling_strategy_str1,
    +                                                  
self.user_specified_class_size0,
    +                                                  
self.input_class_level_counts1,
    +                                                  self.output_table_size2))
    +
    +    def test__get_sampling_strategy_specific_dict(self):
    +        # Test cases for getting sampling strategy specific counts
    +        target_level_counts_1 = {'a': (25, OVERSAMPLE), 'b': (25, 
UNDERSAMPLE), 'c': (25, NOSAMPLE)}
    +        target_level_counts_2 = {'a': (25, OVERSAMPLE), 'b': (25, 
UNDERSAMPLE)}
    +        target_level_counts_3 = {'a': (25, OVERSAMPLE), 'b': (25, 
NOSAMPLE), 'c': (25, NOSAMPLE)}
    +        self.assertEqual(({'b': 25}, {'a': 25}, {'c': 25}),
    +                         
_get_sampling_strategy_specific_dict(target_level_counts_1))
    +        self.assertEqual(({'b': 25}, {'a': 25}, {}),
    +                         
_get_sampling_strategy_specific_dict(target_level_counts_2))
    +        self.assertEqual(({}, {'a': 25}, {'c': 25, 'b': 25}),
    +                         
_get_sampling_strategy_specific_dict(target_level_counts_3))
    +
    +
    +if __name__ == '__main__':
    +    unittest.main()
    +    # print(_get_with_replacement_subquery('madlib', 
'madlibtestdata.log_ornstein_wi',
    --- End diff --
    
    Can we remove this commented-out code?


---

Reply via email to