Github user orhankislal commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/223#discussion_r161865238
  
    --- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
    @@ -0,0 +1,994 @@
    +# coding=utf-8
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one
    +# or more contributor license agreements.  See the NOTICE file
    +# distributed with this work for additional information
    +# regarding copyright ownership.  The ASF licenses this file
    +# to you under the Apache License, Version 2.0 (the
    +# "License"); you may not use this file EXCEPT in compliance
    +# with the License.  You may obtain a copy of the License at
    +#
    +#   http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing,
    +# software distributed under the License is distributed on an
    +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +# KIND, either express or implied.  See the License for the
    +# specific language governing permissions and limitations
    +# under the License.
    +import math
    +import plpy
    +import re
    +from collections import defaultdict
    +from fractions import Fraction
    +from utilities.control import MinWarning
    +from utilities.utilities import _assert
    +from utilities.utilities import unique_string
    +from utilities.validate_args import table_exists
    +from utilities.validate_args import columns_exist_in_table
    +from utilities.validate_args import table_is_empty
    +from utilities.validate_args import get_cols
    +from utilities.utilities import py_list_to_sql_string
    +
    +
    +m4_changequote(`<!', `!>')
    +
    +def balance_sample(schema_madlib, source_table, output_table, class_col,
    +    class_sizes, output_table_size, grouping_cols, with_replacement, 
**kwargs):
    +
    +    """
    +    Balance sampling function
    +    Args:
    +        @param source_table       Input table name.
    +        @param output_table       Output table name.
    +        @param class_col          Name of the column containing the class 
to be
    +                                  balanced.
    +        @param class_sizes        Parameter to define the size of the different
    +                                  class values.
    +        @param output_table_size  Desired size of the output data set.
    +        @param grouping_cols      The columns that define the grouping.
    +        @param with_replacement   The sampling method.
    +
    +    """
    +    with MinWarning("warning"):
    +
    +        class_counts = unique_string(desp='class_counts')
    +        desired_sample_per_class = 
unique_string(desp='desired_sample_per_class')
    +        desired_counts = unique_string(desp='desired_counts')
    +
    +        if not class_sizes or class_sizes.strip().lower() in ('null', ''):
    +            class_sizes = 'uniform'
    +
    +        _validate_strs(source_table, output_table, class_col, class_sizes,
    +            output_table_size, grouping_cols, with_replacement)
    +
    +        source_table_columns = ','.join(get_cols(source_table))
    +        grp_by = "GROUP BY {0}".format(class_col)
    +
    +        _create_frequency_distribution(class_counts, source_table, 
class_col)
    +        temp_views = [class_counts]
    +
    +        if class_sizes.lower() == 'undersample' and not with_replacement:
    +            """
    +                Random undersample without replacement.
    +                Randomly order the rows and give a unique (per class)
    +                identifier to each one.
    +                Select rows that have identifiers under the target limit.
    +            """
    +            _undersampling_with_no_replacement(source_table, output_table, 
class_col,
    +                    class_sizes, output_table_size, grouping_cols, 
with_replacement,
    +                    class_counts, source_table_columns)
    +
    +            _delete_temp_views(temp_views)
    +            return
    +
    +        """
    +            Create views for true and desired sample sizes of classes
    +        """
    +        """
    +            include_unsampled_classes tracks whether unsampled classes are desired or not.
    +            include_unsampled_classes is always true when output_table_size is NULL, but may change depending on the desired class sizes given in the comma-delimited class_sizes parameter.
    +        """
    +        include_unsampled_classes = True
    +        sampling_with_comma_delimited_class_sizes = class_sizes.find(':') 
> 0
    +
    +        if sampling_with_comma_delimited_class_sizes:
    +            """
    +                Compute sample sizes based on
    +                comma-delimited list of class_sizes
    +                and/or output_table_size
    +            """
    +            class_sizes, include_unsampled_classes = 
_validate_format_and_values(class_sizes, source_table,
    +                            class_col, output_table_size, class_counts, 
include_unsampled_classes)
    +
    +            """
    +                 Only valid condition for sampling is desired_sample_sizes 
<= output_table_size
    +            """
    +            
temp_views.extend(_create_desired_and_actual_sampling_views(class_counts,
    +                    desired_sample_per_class, desired_counts
    +                    , source_table, output_table, class_col
    +                    , class_sizes, output_table_size, 
include_unsampled_classes))
    +
    +        if class_sizes.lower() == 'uniform':
    +            """
    +                Compute sample sizes based on
    +                uniform distribution of class sizes
    +            """
    +            temp_views.extend(_compute_uniform_class_sizes(
    +                class_counts, desired_sample_per_class, desired_counts
    +                , source_table, output_table, class_col, class_sizes,
    +                output_table_size))
    +
    +        oversampling_specific_classes = False
    +        desired_undersample_class_sizes = defaultdict(str)
    +
    +        if sampling_with_comma_delimited_class_sizes or 
class_sizes.lower() == 'uniform':
    +
    +            oversampling_specific_classes = plpy.execute("""
    +                SELECT * FROM {desired_sample_per_class}
    +                WHERE category = 'oversample'
    +                """.format(**locals())).nrows() > 0
    +            if oversampling_specific_classes:
    +                with_replacement = True
    +
    +            undersampling_res = plpy.execute("""
    +                SELECT array_agg(classes::text || ':' || 
sample_class_size::text)
    +                        as undersample_set FROM {desired_sample_per_class}
    +                WHERE category = 'undersample'
    +                """.format(**locals()))
    +            if undersampling_res.nrows() > 0 and 
undersampling_res[0]['undersample_set'] is not None:
    +                for val in undersampling_res[0]['undersample_set']:
    +                    desired_undersample_class_sizes[val.split(':')[0]] = 
val.split(':')[1]
    +
    +        if class_sizes.lower() == 'oversample':
    +            """
    +                oversampling with replacement
    +            """
    +            with_replacement = True
    +            func_name = 'max'
    +
    +        if class_sizes.lower() == 'undersample' and with_replacement:
    +            """
    +                Undersampling with replacement.
    +            """
    +            func_name = 'min'
    +
    +        if with_replacement:
    +            """
    +                Random sample with replacement.
    +                Undersample will have func_name set to min
    +                Oversample will have func_name set to max.
    +            """
    +            """
    +                Create row identifiers for each row wrt the class
    +            """
    +            classwise_row_numbering_sql = """
    +                SELECT
    +                    *,
    +                    row_number() OVER(PARTITION BY {class_col})
    +                    AS __row_no
    +                FROM
    +                    {source_table}
    +                """.format(**locals())
    +            if oversampling_specific_classes:
    +                select_oversample_classes = """ WHERE {class_col}::text in
    +                            (SELECT classes
    +                            FROM {desired_sample_per_class}
    +                            WHERE category like 'oversample')
    +                """.format(**locals())
    +                classwise_row_numbering_sql += select_oversample_classes
    +
    +            """
    +                Create independent random values
    +                for each class that has a different row count than the 
target
    +            """
    +            if oversampling_specific_classes:
    +                random_targetclass_size_sample_number_gen_sql = """
    +                SELECT
    +                    {desired_sample_per_class}.classes,
    +                    generate_series(1, sample_class_size::int) AS _i,
    +                    ((random()*({class_counts}.class_count-1)+1)::int)
    +                    AS __row_no
    +                FROM
    +                    {class_counts},
    +                    {desired_sample_per_class}
    +                WHERE
    +                {desired_sample_per_class}.classes = {class_counts}.classes
    +                AND category like 'oversample'
    +                """.format(**locals())
    +            else:
    +                random_targetclass_size_sample_number_gen_sql = """
    +                SELECT
    +                    classes,
    +                    generate_series(1, target_class_size::int) AS _i,
    +                    ((random()*({class_counts}.class_count-1)+1)::int)
    +                    AS __row_no
    +                FROM
    +                    (SELECT
    +                        {func_name}(class_count) AS target_class_size
    +                    FROM {class_counts})
    +                        AS foo,
    +                        {class_counts}
    +                    WHERE {class_counts}.class_count != target_class_size
    +                """.format(**locals())
    +
    +            """
    +                Match random values with the row identifiers
    +            """
    +            sample_otherclass_set = """
    +                SELECT
    +                     {source_table_columns}
    +                FROM
    +                    ({classwise_row_numbering_sql}) AS f1
    +                RIGHT JOIN
    +                    ({random_targetclass_size_sample_number_gen_sql}) AS
    +                        f2
    +                ON (f1.__row_no = f2.__row_no) AND
    +                (f1.{class_col}::text = f2.classes)
    +                """.format(**locals())
    +
    +            if not oversampling_specific_classes:
    +                """
    +                    Find classes with target number of rows
    +                """
    +                targetclass_set = """
    +                    SELECT
    +                        {source_table_columns}
    +                    FROM {source_table}
    +                    WHERE {class_col}::text IN
    +                        (SELECT
    +                            classes AS target_class
    +                         FROM {class_counts}
    +                         WHERE class_count in
    +                            (SELECT {func_name}(class_count) FROM 
{class_counts}))
    +                    """.format(**locals())
    +
    +                """
    +                    Combine target and other sampled classes
    +                """
    +                output_sql = """
    +                    CREATE TABLE {output_table} AS (
    +                        SELECT {source_table_columns}
    +                        FROM
    +                            ({targetclass_set}) AS a
    +                        UNION ALL
    +                            ({sample_otherclass_set}))
    +                    """.format(**locals())
    +                plpy.execute(output_sql)
    +
    +                _delete_temp_views(temp_views)
    +                return
    +
    +        """
    +            Unsampled classes
    +        """
    +        nosample_classset_sql = """
    +            SELECT
    +                {source_table_columns}
    +            FROM {source_table}
    +            WHERE {class_col}::text IN
    +                (SELECT
    +                    classes
    +                 FROM {desired_sample_per_class}
    +                 WHERE category like 'nosample')
    +                    """.format(**locals())
    +        """
    +            Union all Undersampled classes
    +        """
    +        undersampling_classset_sql = ''
    +        if len(desired_undersample_class_sizes) > 0:
    +            undersampling_classset_sql = ' UNION ALL'.join("""
    +                (SELECT {source_table_columns}
    +                FROM {source_table}
    +                WHERE {class_col} = '{clas}'
    +                ORDER BY random()
    +                LIMIT {limit_bound})
    +                """.format(source_table_columns=source_table_columns,
    +                            source_table=source_table,
    +                            class_col=class_col,
    +                            limit_bound=clas_limit,
    +                            clas=clas) for clas, clas_limit in 
desired_undersample_class_sizes.iteritems())
    +            undersampling_classset_sql = ' UNION ALL ' + 
undersampling_classset_sql
    +
    +        """
    +            Union all Oversampled classes
    +        """
    +        oversampling_specific_classes_classset_sql = ''
    +        if oversampling_specific_classes:
    +            oversampling_specific_classes_classset_sql = """
    +                    UNION ALL
    +                        ({sample_otherclass_set})
    +                        """.format(**locals())
    +
    +        if (oversampling_specific_classes or 
len(desired_undersample_class_sizes) > 0):
    +            """
    +                Combine all sampled and/or unsampled classes
    +            """
    +            if not include_unsampled_classes:
    +                nosample_classset_sql.replace('nosample', '')
    +
    +            output_sql = """
    +                CREATE TABLE {output_table} AS (
    +                    SELECT {source_table_columns}
    +                    FROM
    +                        ({nosample_classset_sql}) AS a
    +                    {oversampling_specific_classes_classset_sql}
    +                    {undersampling_classset_sql})
    +                """.format(**locals())
    +
    +            plpy.execute(output_sql)
    +
    +        _delete_temp_views(temp_views)
    +    return
    +
    +"""
    +    Delete all temp views
    +"""
    +def _delete_temp_views(temp_views):
    +    for temp_view in temp_views:
    +            plpy.execute("DROP VIEW IF EXISTS {0} 
cascade".format(temp_view))
    +    return
    +
    +"""
    +    Random undersample without replacement.
    +"""
    +def _undersampling_with_no_replacement(source_table, output_table, 
class_col,
    +        class_sizes, output_table_size, grouping_cols, with_replacement,
    +        class_counts, source_table_columns):
    +
    +    distinct_class_labels = plpy.execute("""
    +        SELECT array_agg(DISTINCT {class_col}::text) AS labels
    +        FROM {source_table}
    +        """.format(**locals()))[0]['labels']
    +
    +    limit_bound = plpy.execute("""
    +        SELECT MIN(class_count)::int AS min
    +        FROM {class_counts}""".format(**locals()))[0]['min']
    +
    +    minority_class = plpy.execute("""
    +        SELECT array_agg(classes::text) as minority_class
    +        FROM {class_counts}
    +        WHERE class_count = {limit_bound}
    +        """.format(**locals()))[0]['minority_class']
    +
    +    distinct_class_labels = [cl for cl in distinct_class_labels
    +                                if cl not in minority_class]
    +
    +    foo_table = unique_string(desp='foo')
    +    start_output_qry = """
    +        SELECT {source_table_columns}
    +        FROM (
    +            SELECT {source_table_columns}
    +            FROM {source_table}
    +            WHERE {class_col}::text = '{dcl}'
    +            ORDER BY random()
    +            LIMIT {limit_bound}
    +            ) AS {foo_table}
    +        UNION """.format(dcl=distinct_class_labels[0], **locals())
    +
    +    union_qry = ' UNION '.join("""
    +        (SELECT {source_table_columns}
    +        FROM {source_table}
    +        WHERE {class_col}::text = '{val}'
    +        ORDER BY random()
    +        LIMIT {limit_bound})
    +        """.format(source_table_columns=source_table_columns,
    +                    source_table=source_table,
    +                    class_col=class_col,
    +                    limit_bound=limit_bound,
    +                    val=val) for val in distinct_class_labels[1:])
    +
    +    min_class_tuple = "('" + "','".join([str(a) for a in minority_class]) 
+ "')"
    +
    +    minority_sql = """ UNION
    +        SELECT {source_table_columns}
    +        FROM {source_table}
    +        WHERE {class_col} IN {min_class_tuple} """.format(**locals())
    +
    +    output_sql = """
    +        CREATE TABLE {output_table} AS (
    +            {start_output_qry}
    +            {union_qry}
    +            {minority_sql} )""".format(**locals())
    +    plpy.execute(output_sql)
    +
    +"""
    +    Captures cases where classes are specified multiple times in the comma-delimited string.
    +    e.g. class_sizes = '5:6,5:4'
    +"""
    +class UniqueDict(dict):
    +    def __setitem__(self, classkey, class_size):
    +        if classkey not in self:
    +            """
    +             float(class_size).is_integer() ensures only whole numbers are 
added as
    +             class's sample size
    +            """
    +            if (class_size > 0.0 or class_size < 1.0) or (class_size >= 
1.0 and float(class_size).is_integer()):
    +                dict.__setitem__(self, classkey, class_size)
    +            else:
    +                plpy.error("Sample: Sample size should be a fraction 
between (0.0,1.0) or a whole number greater than 1")
    +        else:
    +            plpy.error("Sample: Repeated classes in class_sizes")
    +
    +"""
    +    Check whether the class sizes (passed as strings) are alphanumeric, float or whole numbers.
    +    Error out on alphanumeric values.
    +    e.g. class_sizes = '3:a,5:b'
    +    Returns type of the size as int or float.
    +"""
    +def _check_value_type(value):
    +    try:
    +        float(value)
    +    except ValueError:
    +        plpy.error("Sample: Specify either fractions (0.0,1.0) or whole 
numbers for class sample size.")
    +
    +    if re.match("^\d+?\.\d+?$", value) is not None and not 
float(value).is_integer():
    +        valueType = float
    +    elif (re.match(r"[-+]?\d+$", value) is not None):
    +        valueType = int
    +    return valueType
    +"""
    +    Check if classes are present in class_col
    +"""
    +def _validate_classes(all_classes, source_table, class_col):
    +
    +    nonexisting_classes = plpy.execute("""
    +            SELECT
    +                unnest(ARRAY[{all_classes}])
    +            EXCEPT
    +            SELECT
    +                distinct({class_col}::text)
    +            FROM {source_table}
    +        """.format(**locals()))
    +
    +    if nonexisting_classes.nrows() > 0:
    +        plpy.error("""Sample: Specified classes do not exist in
    +            {class_col}""".format(**locals()))
    +
    +"""
    +    Checks the format and values of classes and their respective sizes specified in the comma-delimited string.
    +    1. Checks if the classes specified are present in the source table.
    +    2. Checks for total sample size to be between (0.0,1.0]
    +    3. Checks for output_table_size < total desired size of the classes
    +    specified in class_sizes
    +    4. Checks for cases when only fractions or whole numbers are specified in sample sizes
    +    5. Checks that the value is a whole number greater than 1
    +    6. Checks if same class is specified multiple times in class_sizes
    +    7. Checks if classes are present in class_col
    +"""
    +def _validate_format_and_values(class_sizes, source_table, class_col,
    +                    output_table_size, class_counts, 
include_unsampled_classes):
    +
    +    class_sizes_arr = class_sizes.split(',')
    +
    +    cs_dict = UniqueDict(defaultdict())
    +
    +    numeric_value_sum = 0
    +    fraction_value_sum = 0.0
    +
    +    for x in class_sizes_arr:
    +        class_and_size = x.split(':')
    +        valueType = _check_value_type(class_and_size[1])
    +        # Following error type is invalidated by Frank as of 29th Dec.
    +        # if _check_value_type(val[1].strip(), valueType) != valueType:
    +            #plpy.error("Sample: Specify either fractions or whole number 
values for ALL classes.")
    +        fraction_value_sum += valueType(class_and_size[1]) if valueType == 
float else 0.0
    +        numeric_value_sum += valueType(class_and_size[1]) if valueType != 
float else 0
    +        cs_dict[class_and_size[0].strip()] = valueType(class_and_size[1])
    +
    +    """
    +        Check to see if specified classes are present in the class_col
    +    """
    +    all_classes = str(cs_dict.keys())[1:-1]
    +    _validate_classes(all_classes, source_table, class_col)
    +
    +    """
    +        Error out if fraction_value_sum is greater than 1.0 or when 
fraction sum is 1.0 and other classes with whole numbers as class sizes are 
also specified.
    +    """
    +    if fraction_value_sum > 1.0 or (fraction_value_sum == 1.0 and 
numeric_value_sum != 0):
    +        plpy.error("""Sample: Fraction sum < 1.0, when any other class is 
also specified as class_name:class_size-in-whole-numbers. Fraction sum can be 
at most 1.0.
    +            """.format(**locals()))
    +
    +    total_table_size = plpy.execute("""
    +            SELECT
    +                count(*) AS total
    +            FROM {source_table}
    +        """.format(**locals()))[0]['total']
    +
    +    """
    +        Compute class sizes when no Fractions are specified in class_sizes
    +    """
    +    if Fraction(fraction_value_sum) == Fraction(0.0):
    +        if (not output_table_size):
    +            # Sample remaining classes uniformly
    +            return class_sizes, include_unsampled_classes
    +
    +        if output_table_size < numeric_value_sum:
    +            plpy.error("""Sample: Output table size ({output_table_size}) 
must be more than total specified sample size i.e. 
{numeric_value_sum}""".format(**locals()))
    +
    +        if output_table_size == numeric_value_sum:
    +            # Do not sample other classes
    +            return class_sizes, not include_unsampled_classes
    +
    +        # Sample remaining classes uniformly with target table size
    +        return class_sizes, include_unsampled_classes
    +
    +    """
    +        Compute class sizes when only Fractions are specified in 
class_sizes, which also sum to 1.0
    +    """
    +    if Fraction(fraction_value_sum) == Fraction(1.0):
    +        # Do not sample other classes
    +        return _compute_class_sizes(cs_dict, total_table_size)[0], not 
include_unsampled_classes
    +
    +    """
    +        Compute sample class sizes when both fractions and whole numbers are given in the class_sizes comma-delimited string
    +    """
    +    if Fraction(fraction_value_sum) > Fraction(0.0):
    +
    +        sum_remaining_class_samples = plpy.execute("""
    +                SELECT sum({class_counts}.class_count) AS 
remaining_classes FROM {class_counts}
    +                WHERE classes not IN ({all_classes})
    +                """.format(**locals()))
    +        """
    +            When output_table_size is Null. Use following example logic to 
compute desired sample sizes.
    +
    +            Suppose male=.4,output_table_size= NULL and let’s say there 
are 2 other categorical values female=10M, other=1M
    +            Use the following logic to calculate 'computed' 
output_table_size x
    +                     .4x + 10M + 1M = x
    +            where x = computed output_table_size. Here x = 18.3M
    +        """
    +        if (not output_table_size):
    +            y = 1.0 - fraction_value_sum
    +
    +            if sum_remaining_class_samples.nrows() > 0:
    +                numeric_value_sum += 
sum_remaining_class_samples[0]['remaining_classes']
    +
    +            if numeric_value_sum == 0:
    +                # A rare case happens WHEN there is ONLY one class present 
in class_col
    +                return _compute_class_sizes(cs_dict, total_table_size)[0], 
not include_unsampled_classes
    +
    +            # Compute total_desired_sample_size as x
    +            x = math.ceil(float(numeric_value_sum) / y)
    +            class_size, _ = _compute_class_sizes(cs_dict, x)
    +            return class_size, include_unsampled_classes
    +
    +        """
    +            When output_table_size is given,
    +            compute the total_desired_sample_size and perform checks to 
ensure validity of class_size with total_desired_sample_size
    +        """
    +        output_table_size = float(output_table_size)
    +        class_size, total_desired_sample_size = 
_compute_class_sizes(cs_dict, output_table_size)
    +
    +        ## Cases when total desired sample size > specified output table 
size
    +        if total_desired_sample_size > output_table_size:
    +            plpy.error("""Sample: Output table size ({output_table_size}) 
must be more than total desired sample size i.e 
{total_desired_sample_size}""".format(**locals()))
    +
    +        if total_desired_sample_size == output_table_size:
    +            # Do not sample other classes
    +            return class_size, not include_unsampled_classes
    +
    +        if total_desired_sample_size < output_table_size:
    +            # Sample other classes uniformly
    +            return class_size, include_unsampled_classes
    +
    +"""
    +    Proportions of class sizes are multiplied by output_table_size to get 
desired whole number value for class sizes.
    +    A total_desired_size is sum of all whole number class sizes
    +"""
    +def _compute_class_sizes(cs_dict, x):
    +    class_size = ''
    +    total_desired_size = 0
    +    for clas, class_val in cs_dict.iteritems():
    +        if float(class_val).is_integer():
    +            total_desired_size += class_val
    +            class_size += str(clas) + ':' + str(class_val) + ','
    +        else:
    +            class_val = int(round(class_val * x, 0))
    +            total_desired_size += class_val
    +            class_size += str(clas) + ':' + str(class_val) + ','
    +
    +    return class_size[:-1], total_desired_size
    +
    +"""
    +    Create view to store class counts of classes in class_col
    +"""
    +def _create_frequency_distribution(class_counts, source_table, class_col):
    +
    +    grp_by = "GROUP BY {0}".format(class_col)
    +    plpy.execute(""" CREATE VIEW {class_counts} AS (
    +                 SELECT
    +                    {class_col}::text AS classes,
    +                    count(*) AS class_count
    +                 FROM {source_table}
    +                    {grp_by})
    +             """.format(**locals()))
    +
    +"""
    +    Create view desired_sample_per_class which contains the desired 
sampling for each class.
    + """
    +def _create_desired_and_actual_sampling_views(class_counts, 
desired_sample_per_class, desired_counts
    +            , source_table, output_table, class_col, class_sizes, 
output_table_size, include_unsampled_classes):
    +
    +    # Split class_sizes
    +    if len(class_sizes.split(',')) > 1:
    +        values = '\',\''.join(class_sizes.split(','))
    +    else:
    +        values = class_sizes.split(',')[0]
    +
    +    create_desired_counts_sql = """
    +        CREATE VIEW {desired_counts} AS (
    +        SELECT val[1]::text AS desired_classes,
    +               val[2]::int AS desired_values
    +        FROM
    +            (SELECT regexp_split_to_array(unnest(ARRAY['{values}']),E':') 
AS val)
    +        AS foo
    +        )
    +    """.format(**locals())
    +    temp_views = [desired_counts]
    +    plpy.execute(create_desired_counts_sql)
    +
    +    if not include_unsampled_classes:
    +        """
    +            No extra/other classes desired, other than the ones
    +            specified in the class_sizes parameter
    +        """
    +        desired_sample_per_class_sql = """
    +            CREATE VIEW {desired_sample_per_class} AS (
    +                SELECT
    +                    desired_classes AS classes,
    +                    desired_values::int AS sample_class_size,
    +                    1::int as actual_class_size,
    +                    CASE
    +                      WHEN (desired_values - class_count) = 0 THEN 
'nosample'
    +                      WHEN (desired_values - class_count) > 0 THEN 
'oversample' ELSE 'undersample'
    +                    END AS category
    +            FROM
    +                {class_counts}
    +                JOIN
    +                {desired_counts}
    +                ON ({class_counts}.classes = 
{desired_counts}.desired_classes)
    +            )
    +            """.format(**locals())
    +    else:
    +        """
    +            If output_table_size is specified,
    +            uniformly sample remaining classes.
    +        """
    +        remaining_class_samples = ''
    +        if output_table_size:
    +
    +            remaining_class_samples = 
_sample_remaining_classes_uniformly_per_output_table_size(
    +                class_counts, desired_counts, output_table_size)
    +
    +            desired_sample_per_class_sql = """
    +            CREATE VIEW {desired_sample_per_class} AS (
    +                SELECT
    +                        classes,
    +                        desired_values::int AS sample_class_size,
    +                        class_count AS actual_class_size,
    +                        CASE
    +                          WHEN (desired_values - class_count) = 0 THEN 
'nosample'
    +                          WHEN (desired_values - class_count) > 0 THEN 
'oversample' ELSE 'undersample'
    +                        END AS category
    +                FROM (
    +                    SELECT
    +                            classes,
    +                            CASE (SELECT 1 WHERE desired_values is NULL)
    +                                WHEN 1 THEN extra_values
    +                                ELSE desired_values
    +                            END AS desired_values,
    +                            class_count
    +                    FROM
    +                        {class_counts}
    +                        LEFT JOIN
    +                        {desired_counts}
    +                        ON ({class_counts}.classes = 
{desired_counts}.desired_classes)
    +                        LEFT JOIN
    +                        {remaining_class_samples}
    +                        ON ({class_counts}.classes = 
{remaining_class_samples}.extra_classes)
    +                        ) AS foo
    +                )
    +            """.format(**locals())
    +        else:
    +            """
    +                If output_table_size is not specified,
    +                return remaining classes as is.
    +            """
    +            desired_sample_per_class_sql = """
    +                CREATE VIEW {desired_sample_per_class} AS (
    +                    SELECT
    +                            classes,
    +                            desired_values::int AS sample_class_size,
    +                            class_count as actual_class_size,
    +                            CASE
    +                              WHEN (desired_values - class_count) = 0 THEN 
'nosample'
    +                              WHEN (desired_values - class_count) > 0 THEN 
'oversample' ELSE 'undersample'
    +                            END AS category
    +                    FROM (
    +                        SELECT
    +                                classes,
    +                                CASE (SELECT 1 WHERE desired_values is 
NULL)
    +                                    WHEN 1 THEN class_count
    +                                    ELSE desired_values
    +                                END AS desired_values,
    +                                class_count
    +                        FROM
    +                            {class_counts}
    +                             LEFT JOIN
    +                            {desired_counts}
    +                            ON ({class_counts}.classes = 
{desired_counts}.desired_classes)
    +                            ) AS foo
    +                    )
    +                """.format(**locals())
    +        if remaining_class_samples:
    +            temp_views.append(remaining_class_samples)
    +
    +    plpy.execute(desired_sample_per_class_sql)
    +    temp_views.append(desired_sample_per_class)
    +    return temp_views
    +
    +"""
    +    Sample the remaining classes in output_table_size - (total calculated 
class sizes).
    +"""
    +def _sample_remaining_classes_uniformly_per_output_table_size(class_counts
    +                , desired_counts, output_table_size):
    +
    +    remaining_class_samples = unique_string(desp='remaining_class_samples')
    +    foo_table = unique_string(desp='foo')
    +    remaining_class_samples_sql = """
    +        CREATE VIEW {remaining_class_samples} as (
    --- End diff --
    
    Maybe we should read these counts and handle the simple computations in the
    Python layer instead of creating a view. It would be easier to read IMHO.


---

Reply via email to