Github user njayaram2 commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/223#discussion_r162515193
  
    --- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
    @@ -0,0 +1,994 @@
    +# coding=utf-8
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one
    +# or more contributor license agreements.  See the NOTICE file
    +# distributed with this work for additional information
    +# regarding copyright ownership.  The ASF licenses this file
    +# to you under the Apache License, Version 2.0 (the
    +# "License"); you may not use this file EXCEPT in compliance
    +# with the License.  You may obtain a copy of the License at
    +#
    +#   http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing,
    +# software distributed under the License is distributed on an
    +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +# KIND, either express or implied.  See the License for the
    +# specific language governing permissions and limitations
    +# under the License.
    +import math
    +import plpy
    +import re
    +from collections import defaultdict
    +from fractions import Fraction
    +from utilities.control import MinWarning
    +from utilities.utilities import _assert
    +from utilities.utilities import unique_string
    +from utilities.validate_args import table_exists
    +from utilities.validate_args import columns_exist_in_table
    +from utilities.validate_args import table_is_empty
    +from utilities.validate_args import get_cols
    +from utilities.utilities import py_list_to_sql_string
    +
    +
    +m4_changequote(`<!', `!>')
    +
    +def balance_sample(schema_madlib, source_table, output_table, class_col,
    +    class_sizes, output_table_size, grouping_cols, with_replacement, 
**kwargs):
    +
    +    """
    +    Balance sampling function
    +    Args:
    +        @param source_table       Input table name.
    +        @param output_table       Output table name.
    +        @param class_col          Name of the column containing the class 
to be
    +                                  balanced.
    +        @param class_size         Parameter to define the size of the 
different
    +                                  class values.
    +        @param output_table_size  Desired size of the output data set.
    +        @param grouping_cols      The columns columns that defines the 
grouping.
    +        @param with_replacement   The sampling method.
    +
    +    """
    +    with MinWarning("warning"):
    +
    +        class_counts = unique_string(desp='class_counts')
    +        desired_sample_per_class = 
unique_string(desp='desired_sample_per_class')
    +        desired_counts = unique_string(desp='desired_counts')
    +
    +        if not class_sizes or class_sizes.strip().lower() in ('null', ''):
    +            class_sizes = 'uniform'
    +
    +        _validate_strs(source_table, output_table, class_col, class_sizes,
    +            output_table_size, grouping_cols, with_replacement)
    +
    +        source_table_columns = ','.join(get_cols(source_table))
    +        grp_by = "GROUP BY {0}".format(class_col)
    +
    +        _create_frequency_distribution(class_counts, source_table, 
class_col)
    +        temp_views = [class_counts]
    +
    +        if class_sizes.lower() == 'undersample' and not with_replacement:
    +            """
    +                Random undersample without replacement.
    +                Randomly order the rows and give a unique (per class)
    +                identifier to each one.
    +                Select rows that have identifiers under the target limit.
    +            """
    +            _undersampling_with_no_replacement(source_table, output_table, 
class_col,
    +                    class_sizes, output_table_size, grouping_cols, 
with_replacement,
    +                    class_counts, source_table_columns)
    +
    +            _delete_temp_views(temp_views)
    +            return
    +
    +        """
    +            Create views for true and desired sample sizes of classes
    +        """
    +        """
    +            include_unsampled_classes tracks is unsampled classes are 
desired or not.
    +            include_unsampled_classes is always true in output_table_size 
Null cases but changes given values of desired sample class sizes in 
comma-delimited classsize paramter.
    +        """
    +        include_unsampled_classes = True
    +        sampling_with_comma_delimited_class_sizes = class_sizes.find(':') 
> 0
    +
    +        if sampling_with_comma_delimited_class_sizes:
    +            """
    +                Compute sample sizes based on
    +                comman-delimited list of class_sizes
    +                and/or output_table_size
    +            """
    +            class_sizes, include_unsampled_classes = 
_validate_format_and_values(class_sizes, source_table,
    +                            class_col, output_table_size, class_counts, 
include_unsampled_classes)
    +
    +            """
    +                 Only valid condition for sampling is desired_sample_sizes 
<= output_table_size
    +            """
    +            
temp_views.extend(_create_desired_and_actual_sampling_views(class_counts,
    +                    desired_sample_per_class, desired_counts
    +                    , source_table, output_table, class_col
    +                    , class_sizes, output_table_size, 
include_unsampled_classes))
    +
    +        if class_sizes.lower() == 'uniform':
    +            """
    +                Compute sample sizes based on
    +                uniform distribution of class sizes
    +            """
    +            temp_views.extend(_compute_uniform_class_sizes(
    +                class_counts, desired_sample_per_class, desired_counts
    +                , source_table, output_table, class_col, class_sizes,
    +                output_table_size))
    +
    +        oversampling_specific_classes = False
    +        desired_undersample_class_sizes = defaultdict(str)
    +
    +        if sampling_with_comma_delimited_class_sizes or 
class_sizes.lower() == 'uniform':
    +
    +            oversampling_specific_classes = plpy.execute("""
    +                SELECT * FROM {desired_sample_per_class}
    +                WHERE category = 'oversample'
    +                """.format(**locals())).nrows() > 0
    +            if oversampling_specific_classes:
    +                with_replacement = True
    +
    +            undersampling_res = plpy.execute("""
    +                SELECT array_agg(classes::text || ':' || 
sample_class_size::text)
    +                        as undersample_set FROM {desired_sample_per_class}
    +                WHERE category = 'undersample'
    +                """.format(**locals()))
    +            if undersampling_res.nrows() > 0 and 
undersampling_res[0]['undersample_set'] is not None:
    +                for val in undersampling_res[0]['undersample_set']:
    +                    desired_undersample_class_sizes[val.split(':')[0]] = 
val.split(':')[1]
    +
    +        if class_sizes.lower() == 'oversample':
    +            """
    +                oversampling with replacement
    +            """
    +            with_replacement = True
    +            func_name = 'max'
    +
    +        if class_sizes.lower() == 'undersample' and with_replacement:
    +            """
    +                Undersampling with replacement.
    +            """
    +            func_name = 'min'
    +
    +        if with_replacement:
    +            """
    +                Random sample with replacement.
    +                Undersample will have func_name set to min
    +                Oversample will have func_name set to max.
    +            """
    +            """
    +                Create row identifiers for each row wrt the class
    +            """
    +            classwise_row_numbering_sql = """
    +                SELECT
    +                    *,
    +                    row_number() OVER(PARTITION BY {class_col})
    +                    AS __row_no
    +                FROM
    +                    {source_table}
    +                """.format(**locals())
    +            if oversampling_specific_classes:
    +                select_oversample_classes = """ WHERE {class_col}::text in
    +                            (SELECT classes
    +                            FROM {desired_sample_per_class}
    +                            WHERE category like 'oversample')
    +                """.format(**locals())
    +                classwise_row_numbering_sql += select_oversample_classes
    +
    +            """
    +                Create independent random values
    +                for each class that has a different row count than the 
target
    +            """
    +            if oversampling_specific_classes:
    +                random_targetclass_size_sample_number_gen_sql = """
    +                SELECT
    +                    {desired_sample_per_class}.classes,
    +                    generate_series(1, sample_class_size::int) AS _i,
    +                    ((random()*({class_counts}.class_count-1)+1)::int)
    +                    AS __row_no
    +                FROM
    +                    {class_counts},
    +                    {desired_sample_per_class}
    +                WHERE
    +                {desired_sample_per_class}.classes = {class_counts}.classes
    +                AND category like 'oversample'
    +                """.format(**locals())
    +            else:
    +                random_targetclass_size_sample_number_gen_sql = """
    +                SELECT
    +                    classes,
    +                    generate_series(1, target_class_size::int) AS _i,
    +                    ((random()*({class_counts}.class_count-1)+1)::int)
    +                    AS __row_no
    +                FROM
    +                    (SELECT
    +                        {func_name}(class_count) AS target_class_size
    +                    FROM {class_counts})
    +                        AS foo,
    +                        {class_counts}
    +                    WHERE {class_counts}.class_count != target_class_size
    +                """.format(**locals())
    +
    +            """
    +                Match random values with the row identifiers
    +            """
    +            sample_otherclass_set = """
    +                SELECT
    +                     {source_table_columns}
    +                FROM
    +                    ({classwise_row_numbering_sql}) AS f1
    +                RIGHT JOIN
    +                    ({random_targetclass_size_sample_number_gen_sql}) AS
    +                        f2
    +                ON (f1.__row_no = f2.__row_no) AND
    +                (f1.{class_col}::text = f2.classes)
    +                """.format(**locals())
    +
    +            if not oversampling_specific_classes:
    +                """
    +                    Find classes with target number of rows
    +                """
    +                targetclass_set = """
    +                    SELECT
    +                        {source_table_columns}
    +                    FROM {source_table}
    +                    WHERE {class_col}::text IN
    +                        (SELECT
    +                            classes AS target_class
    +                         FROM {class_counts}
    +                         WHERE class_count in
    +                            (SELECT {func_name}(class_count) FROM 
{class_counts}))
    +                    """.format(**locals())
    +
    +                """
    +                    Combine target and other sampled classes
    +                """
    +                output_sql = """
    +                    CREATE TABLE {output_table} AS (
    +                        SELECT {source_table_columns}
    +                        FROM
    +                            ({targetclass_set}) AS a
    +                        UNION ALL
    +                            ({sample_otherclass_set}))
    +                    """.format(**locals())
    +                plpy.execute(output_sql)
    +
    +                _delete_temp_views(temp_views)
    +                return
    +
    +        """
    +            Unsampled classes
    +        """
    +        nosample_classset_sql = """
    +            SELECT
    +                {source_table_columns}
    +            FROM {source_table}
    +            WHERE {class_col}::text IN
    +                (SELECT
    +                    classes
    +                 FROM {desired_sample_per_class}
    +                 WHERE category like 'nosample')
    +                    """.format(**locals())
    +        """
    +            Union all Undersampled classes
    +        """
    +        undersampling_classset_sql = ''
    +        if len(desired_undersample_class_sizes) > 0:
    +            undersampling_classset_sql = ' UNION ALL'.join("""
    +                (SELECT {source_table_columns}
    +                FROM {source_table}
    +                WHERE {class_col} = '{clas}'
    +                ORDER BY random()
    +                LIMIT {limit_bound})
    +                """.format(source_table_columns=source_table_columns,
    +                            source_table=source_table,
    +                            class_col=class_col,
    +                            limit_bound=clas_limit,
    +                            clas=clas) for clas, clas_limit in 
desired_undersample_class_sizes.iteritems())
    +            undersampling_classset_sql = ' UNION ALL ' + 
undersampling_classset_sql
    +
    +        """
    +            Union all Oversampled classes
    +        """
    +        oversampling_specific_classes_classset_sql = ''
    +        if oversampling_specific_classes:
    +            oversampling_specific_classes_classset_sql = """
    +                    UNION ALL
    +                        ({sample_otherclass_set})
    +                        """.format(**locals())
    +
    +        if (oversampling_specific_classes or 
len(desired_undersample_class_sizes) > 0):
    +            """
    +                Combine all sampled and/or unsampled classes
    +            """
    +            if not include_unsampled_classes:
    +                nosample_classset_sql.replace('nosample', '')
    +
    +            output_sql = """
    +                CREATE TABLE {output_table} AS (
    +                    SELECT {source_table_columns}
    +                    FROM
    +                        ({nosample_classset_sql}) AS a
    +                    {oversampling_specific_classes_classset_sql}
    +                    {undersampling_classset_sql})
    +                """.format(**locals())
    +
    +            plpy.execute(output_sql)
    +
    +        _delete_temp_views(temp_views)
    +    return
    +
    +"""
    +    Delete all temp views
    +"""
    +def _delete_temp_views(temp_views):
    +    for temp_view in temp_views:
    +            plpy.execute("DROP VIEW IF EXISTS {0} 
cascade".format(temp_view))
    +    return
    +
    +"""
    +    Random undersample without replacement.
    +"""
    +def _undersampling_with_no_replacement(source_table, output_table, 
class_col,
    +        class_sizes, output_table_size, grouping_cols, with_replacement,
    +        class_counts, source_table_columns):
    +
    +    distinct_class_labels = plpy.execute("""
    +        SELECT array_agg(DISTINCT {class_col}::text) AS labels
    +        FROM {source_table}
    +        """.format(**locals()))[0]['labels']
    +
    +    limit_bound = plpy.execute("""
    +        SELECT MIN(class_count)::int AS min
    +        FROM {class_counts}""".format(**locals()))[0]['min']
    +
    +    minority_class = plpy.execute("""
    +        SELECT array_agg(classes::text) as minority_class
    +        FROM {class_counts}
    +        WHERE class_count = {limit_bound}
    +        """.format(**locals()))[0]['minority_class']
    +
    +    distinct_class_labels = [cl for cl in distinct_class_labels
    +                                if cl not in minority_class]
    +
    +    foo_table = unique_string(desp='foo')
    +    start_output_qry = """
    --- End diff --
    
    Changed the query to have a single query for all.


---

Reply via email to