This is an automated email from the ASF dual-hosted git repository.

domino pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 85eba29  DL: Improve performance of mini-batch preprocessor (#467)
85eba29 is described below

commit 85eba2968b5651b6242a398df569b1f3f6412703
Author: Domino Valdano <dom...@apache.org>
AuthorDate: Thu Jan 9 18:06:56 2020 -0800

    DL: Improve performance of mini-batch preprocessor (#467)
    
    JIRA MADLIB-1334
    
    This commit adds the following optimizations to the minibatch preprocessor:
    
    - Skip normalization if the normalizing constant is 1.0
    
    - Split the batching query so that buffer_id's are generated (based on
    row_id) without moving any data around. Previously, calling `ROW_NUMBER()
    OVER()` to add row_id's to the table gathered all of the data on the
    master node before numbering the rows, which for large datasets took
    most of the total runtime.
    
    - Separate the JOIN (used for even distribution) and the conversion to
    bytea out of the batching query. This avoids VMEM limit issues.
    
    - num_buffers gets rounded up to the nearest multiple of num_segments so
    that buffers are distributed evenly across the segments (see the sketch
    after this list).
    
    - Add a new C function `array_to_bytea()` to convert an array to bytea,
    along with some tests for it.  This is much faster than the Python
    version we were using and speeds up the query significantly.
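    
    A rough standalone sketch of the buffer arithmetic described above (not
    code from this commit; the row count, requested buffer size, and segment
    count below are made up for illustration):
    
        # num_buffers is rounded up to the nearest multiple of num_segments,
        # then the effective buffer size is recomputed so that
        # buffer_id = row_id / buffer_size spreads rows evenly.
        from math import ceil
    
        def plan_buffers(num_rows, requested_buffer_size, num_segments):
            num_buffers = num_segments * int(ceil(
                float(num_rows) / requested_buffer_size / num_segments))
            buffer_size = int(ceil(float(num_rows) / num_buffers))
            return num_buffers, buffer_size
    
        num_buffers, buffer_size = plan_buffers(17, 5, 3)   # -> (6, 3)
        buffer_ids = [row_id // buffer_size for row_id in range(17)]
        # buffer_ids: [0,0,0, 1,1,1, 2,2,2, 3,3,3, 4,4,4, 5,5]  (6 buffers)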
    
    Additionally, this commit adds a new function `plpy_execute_debug()`
    to the utilities module that prints the EXPLAIN plan and execution time
    of a specific query, for debugging.
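    
    For example, inside a MADlib PL/Python module (i.e. where `plpy` is in
    scope) a query could be inspected roughly like this; the table name is
    hypothetical:
    
        from utilities.utilities import plpy_execute_debug
    
        # Prints the SQL text, its EXPLAIN plan, and how long it took to run.
        plpy_execute_debug("SELECT COUNT(*) FROM my_batched_table")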
    
    Co-authored-by: Ekta Khanna <ekha...@pivotal.io>
---
 methods/array_ops/src/pg_gp/array_ops.c            |  30 ++
 methods/array_ops/src/pg_gp/array_ops.sql_in       |  12 +
 methods/array_ops/src/pg_gp/test/array_ops.sql_in  |  71 +++
 .../deep_learning/input_data_preprocessor.py_in    | 520 +++++++++++++++------
 .../deep_learning/madlib_keras_helper.py_in        |   1 -
 .../deep_learning/madlib_keras_validator.py_in     |   3 +-
 .../test/input_data_preprocessor.sql_in            | 136 ++++--
 .../test/madlib_keras_cifar.setup.sql_in           |   4 +-
 .../unit_tests/test_input_data_preprocessor.py_in  |   8 +
 .../utilities/minibatch_preprocessing.py_in        |   7 +-
 .../postgres/modules/utilities/utilities.py_in     |  21 +
 11 files changed, 617 insertions(+), 196 deletions(-)

diff --git a/methods/array_ops/src/pg_gp/array_ops.c b/methods/array_ops/src/pg_gp/array_ops.c
index 48880a6..a842a60 100644
--- a/methods/array_ops/src/pg_gp/array_ops.c
+++ b/methods/array_ops/src/pg_gp/array_ops.c
@@ -2107,3 +2107,33 @@ General_Array_to_Cumulative_Array(
 
     return pgarray;
 }
+
+PG_FUNCTION_INFO_V1(array_to_bytea);
+Datum
+array_to_bytea(PG_FUNCTION_ARGS)
+{
+    ArrayType *a = PG_GETARG_ARRAYTYPE_P(0);
+    Oid element_type = ARR_ELEMTYPE(a);
+    TypeCacheEntry * TI;
+    int data_length, nitems, items_avail;
+
+    data_length = VARSIZE(a) - ARR_DATA_OFFSET(a);
+    nitems = ArrayGetNItems(ARR_NDIM(a), ARR_DIMS(a));
+    TI = lookup_type_cache(element_type, TYPECACHE_CMP_PROC_FINFO);
+    items_avail = (data_length / TI->typlen);
+
+    if (nitems > items_avail) {
+        elog(ERROR, "Unexpected end of array:  expected %d elements but 
received only %d",  nitems,  data_length);
+    } else if (nitems < items_avail) {
+        elog(WARNING, "to_bytea(): Ignoring %d extra elements after end of 
%d-element array!", items_avail - nitems, nitems);
+        data_length = (nitems * TI->typlen);
+    }
+
+    bytea *ba = palloc(VARHDRSZ + data_length);
+
+    SET_VARSIZE(ba, VARHDRSZ + data_length);
+
+    memcpy(((char *)ba) + VARHDRSZ, ARR_DATA_PTR(a), data_length);
+
+    PG_RETURN_BYTEA_P(ba);
+}
diff --git a/methods/array_ops/src/pg_gp/array_ops.sql_in b/methods/array_ops/src/pg_gp/array_ops.sql_in
index e1aa368..c1ec853 100644
--- a/methods/array_ops/src/pg_gp/array_ops.sql_in
+++ b/methods/array_ops/src/pg_gp/array_ops.sql_in
@@ -733,3 +733,15 @@ ORDER BY 1,2;
         """.format(schema_madlib='MADLIB_SCHEMA')
 $$ LANGUAGE PLPYTHONU IMMUTABLE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
+
+m4_changequote(<!, !>)
+m4_ifelse(__PORT__ __DBMS_VERSION_MAJOR__, <!GREENPLUM 4!>,,
+<!
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.array_to_bytea(ANYARRAY)
+RETURNS BYTEA
+AS
+'MODULE_PATHNAME', 'array_to_bytea'
+LANGUAGE C IMMUTABLE
+!>)
+m4_changequote(`,')
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
diff --git a/methods/array_ops/src/pg_gp/test/array_ops.sql_in b/methods/array_ops/src/pg_gp/test/array_ops.sql_in
index b05d0b7..511564f 100644
--- a/methods/array_ops/src/pg_gp/test/array_ops.sql_in
+++ b/methods/array_ops/src/pg_gp/test/array_ops.sql_in
@@ -7,6 +7,9 @@
 --    all objects created in the default schema will be cleaned-up outside.
 ---------------------------------------------------------------------------
 
+m4_include(`SQLCommon.m4')
+m4_changequote(`<!', `!>')
+
 ---------------------------------------------------------------------------
 -- Setup:
 ---------------------------------------------------------------------------
@@ -307,3 +310,71 @@ FROM (
     unnest_2d_tbl05_groundtruth t2
     USING (id,unnest_row_id)
 ) t3;
+
+-- TESTING array_to_bytea() function - skip for gpdb 4.3
+m4_ifelse(__PORT__ __DBMS_VERSION_MAJOR__, <!GREENPLUM 4!>,,
+<!
+
+-- create input table ( n = 3 x 5 x 7 dim SMALLINT[], r =  2 x 3 x 5 dim REAL[] )
+DROP TABLE IF EXISTS array_input_tbl;
+CREATE TABLE array_input_tbl (id SMALLINT, n SMALLINT[], r REAL[]);
+INSERT INTO array_input_tbl SELECT generate_series(1, 10);
+SELECT id, count(*), array_agg(n) from (select id, unnest(n) as n from array_input_tbl) u group by id order by id;
+UPDATE array_input_tbl SET
+    n=array_fill(2*id, ARRAY[3, 5, 7]),
+    r=array_fill(id + 0.4, array[2, 3, 5]);
+
+-- create flattened input table
+DROP TABLE IF EXISTS flat_array_input_tbl;
+CREATE TABLE flat_array_input_tbl (id SMALLINT, n SMALLINT[], n_length SMALLINT, r REAL[], r_length SMALLINT);
+INSERT INTO flat_array_input_tbl
+    SELECT n.id, n, n_length, r, r_length
+    FROM
+    (
+        SELECT id, array_agg(n) AS n, 2*COUNT(*) AS n_length
+        FROM
+        (
+            SELECT id, unnest(n) AS n FROM array_input_tbl
+        ) n_values
+        GROUP BY id
+    ) n
+    JOIN
+    (
+        SELECT id, array_agg(r) AS r, 4*COUNT(*) AS r_length
+        FROM
+        (
+            SELECT id, unnest(r) AS r FROM array_input_tbl
+        ) r_values
+        GROUP BY id
+    ) r
+    USING (id);
+
+CREATE TABLE bytea_tbl AS SELECT id, array_to_bytea(n) AS n, array_to_bytea(r) AS r FROM array_input_tbl;
+
+    -- verify lengths of BYTEA output are correct for SMALLINT & REAL arrays
+    SELECT assert(
+        length(o.n) = i.n_length AND length(o.r) = i.r_length,
+        'array_to_bytea() returned incorrect lengths:\n' ||
+        '   Expected length(n) = ' || n_length::TEXT || ', got ' || length(o.n) ||
+        '   Expected ' || r_length::TEXT || ', got ' || length(o.r)
+    )
+    FROM flat_array_input_tbl i JOIN bytea_tbl o USING (id);
+
+    -- convert BYTEA back to flat arrays of SMALLINT's & REAL's
+
+    CREATE TABLE array_output_tbl AS
+    SELECT
+        id,
+        convert_bytea_to_smallint_array(n) AS n,
+        convert_bytea_to_real_array(r) AS r
+    FROM bytea_tbl;
+
+    -- verify that data in above table matches flattened input table exactly
+    SELECT assert(
+        i.n = o.n AND i.r = o.r,
+        'output of array_to_bytea() does not convert back to flattened input'
+    )
+    FROM flat_array_input_tbl i JOIN array_output_tbl o USING (id);
+!>)
+
+m4_changequote(,)
diff --git a/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in b/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
index 757a5bc..351e6a5 100644
--- a/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
+++ b/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
@@ -29,6 +29,8 @@ from internal.db_utils import get_distinct_col_levels
 from internal.db_utils import quote_literal
 from internal.db_utils import get_product_of_dimensions
 from utilities.minibatch_preprocessing import MiniBatchBufferSizeCalculator
+from utilities.control import OptimizerControl
+from utilities.control import HashaggControl
 from utilities.utilities import _assert
 from utilities.utilities import add_postfix
 from utilities.utilities import is_platform_pg
@@ -46,6 +48,7 @@ from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import get_expr_type
 
 from madlib_keras_helper import *
+import time
 
 NUM_CLASSES_COLNAME = "num_classes"
 
@@ -59,9 +62,9 @@ class InputDataPreprocessorDL(object):
         self.dependent_varname = dependent_varname
         self.independent_varname = independent_varname
         self.buffer_size = buffer_size
-        self.normalizing_const = normalizing_const if normalizing_const is not None else DEFAULT_NORMALIZING_CONST
+        self.normalizing_const = normalizing_const
         self.num_classes = num_classes
-        self.distribution_rules = distribution_rules if distribution_rules else DEFAULT_GPU_CONFIG
+        self.distribution_rules = distribution_rules if distribution_rules else 'all_segments'
         self.module_name = module_name
         self.output_summary_table = None
         self.dependent_vartype = None
@@ -73,7 +76,6 @@ class InputDataPreprocessorDL(object):
         ## Validating input args prior to using them in _set_validate_vartypes()
         self._validate_args()
         self._set_validate_vartypes()
-        self.num_of_buffers = self._get_num_buffers()
         self.dependent_levels = None
         # The number of padded zeros to include in 1-hot vector
         self.padding_size = 0
@@ -199,160 +201,382 @@ class InputDataPreprocessorDL(object):
             3) One-hot encodes the dependent variable.
             4) Minibatches the one-hot encoded dependent variable.
         """
+        # setup for 1-hot encoding
         self._set_one_hot_encoding_variables()
-        # Create a temp table that has independent var normalized.
-        norm_tbl = unique_string(desp='normalized')
-
-        # Always one-hot encode the dependent var. For now, we are assuming
-        # that input_preprocessor_dl will be used only for deep
-        # learning and mostly for classification. So make a strong
-        # assumption that it is only for classification, so one-hot
-        # encode the dep var, unless it's already a numeric array in
-        # which case we assume it's already one-hot encoded.
-        one_hot_dep_var_array_expr = \
-            self.get_one_hot_encoded_dep_var_expr()
-        order_by_clause = " ORDER BY RANDOM() " if order_by_random else ""
-        scalar_mult_sql = """
-            CREATE TEMP TABLE {norm_tbl} AS
-            SELECT {self.schema_madlib}.array_scalar_mult(
-                {self.independent_varname}::{FLOAT32_SQL_TYPE}[],
-                (1/{self.normalizing_const})::{FLOAT32_SQL_TYPE}) AS x_norm,
-                {one_hot_dep_var_array_expr} AS y,
-                row_number() over() AS row_id
-            FROM {self.source_table} {order_by_clause}
-            """.format(FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE, **locals())
-        plpy.execute(scalar_mult_sql)
 
+        # Generate random strings for TEMP tables
         series_tbl = unique_string(desp='series')
         dist_key_tbl = unique_string(desp='dist_key')
-        dep_shape_col = add_postfix(
-            MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL, "_shape")
-        ind_shape_col = add_postfix(
-            MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, "_shape")
+        normalized_tbl = unique_string(desp='normalized_table')
+        batched_table = unique_string(desp='batched_table')
+
+        # Used later in locals() for formatting queries
+        x=MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
+        y=MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
+        float32=FLOAT32_SQL_TYPE
+        dep_shape_col = add_postfix(y, "_shape")
+        ind_shape_col = add_postfix(x, "_shape")
 
         ind_shape = self._get_independent_var_shape()
         ind_shape = ','.join([str(i) for i in ind_shape])
         dep_shape = self._get_dependent_var_shape()
         dep_shape = ','.join([str(i) for i in dep_shape])
 
+        one_hot_dep_var_array_expr = self.get_one_hot_encoded_dep_var_expr()
+
+        # skip normalization step if normalizing_const = 1.0
+        if self.normalizing_const and (self.normalizing_const < 0.999999 or self.normalizing_const > 1.000001):
+            rescale_independent_var = """{self.schema_madlib}.array_scalar_mult(
+                                         {self.independent_varname}::{float32}[],
+                                         (1/{self.normalizing_const})::{float32})
+                                      """.format(**locals())
+        else:
+            self.normalizing_const = DEFAULT_NORMALIZING_CONST
+            rescale_independent_var = "{self.independent_varname}::{float32}[]".format(**locals())
+
+        # It's important that we shuffle all rows before batching for fit(), but
+        #  we can skip that for predict()
+        order_by_clause = " ORDER BY RANDOM()" if order_by_random else ""
+
+        # This query template will be used later in pg & gp specific code paths,
+        #  where {make_buffer_id} and {dist_by_buffer_id} are filled in
+        batching_query = """
+            CREATE TEMP TABLE {batched_table} AS SELECT
+                {{make_buffer_id}} buffer_id,
+                {self.schema_madlib}.agg_array_concat(
+                    ARRAY[x_norm::{float32}[]]) AS {x},
+                {self.schema_madlib}.agg_array_concat(
+                    ARRAY[y]) AS {y},
+                COUNT(*) AS count
+            FROM {normalized_tbl}
+            GROUP BY buffer_id
+            {{dist_by_buffer_id}}
+        """.format(**locals())
+
+        # This query template will be used later in pg & gp specific code paths,
+        #  where {dist_key_col_comma} and {dist_by_dist_key} will be filled in
+        bytea_query = """
+            CREATE TABLE {self.output_table} AS SELECT
+                {{dist_key_col_comma}}
+                {self.schema_madlib}.array_to_bytea({x}) AS {x},
+                {self.schema_madlib}.array_to_bytea({y}) AS {y},
+                ARRAY[count,{ind_shape}]::SMALLINT[] AS {ind_shape_col},
+                ARRAY[count,{dep_shape}]::SMALLINT[] AS {dep_shape_col},
+                buffer_id
+            FROM {batched_table}
+            {{dist_by_dist_key}}
+        """.format(**locals())
+
         if is_platform_pg():
+            # used later for writing summary table
+            self.distribution_rules = '$__madlib__$all_segments$__madlib__$'
+
+            #
+            # For postgres, we just need 3 simple queries:
+            #   1-hot-encode/normalize + batching + bytea conversion
+            #
+
+            # see note in gpdb code branch (lower down) on
+            # 1-hot-encoding of dependent var
+            one_hot_sql = """
+                CREATE TEMP TABLE {normalized_tbl} AS SELECT
+                    (ROW_NUMBER() OVER({order_by_clause}) - 1)::INTEGER as row_id,
+                    {rescale_independent_var} AS x_norm,
+                    {one_hot_dep_var_array_expr} AS y
+                FROM {self.source_table}
+            """.format(**locals())
+
+            plpy.execute(one_hot_sql)
+
+            self.buffer_size = self._get_buffer_size(1)
+
+            # Used to format query templates with locals()
+            make_buffer_id = 'row_id / {0} AS '.format(self.buffer_size)
+            dist_by_dist_key = ''
+            dist_by_buffer_id = ''
+            dist_key_col_comma = ''
+
+            # Disable hashagg since large number of arrays being concatenated
+            # could result in excessive memory usage.
+            with HashaggControl(False):
+                # Batch rows with GROUP BY
+                plpy.execute(batching_query.format(**locals()))
+
+            plpy.execute("DROP TABLE {0}".format(normalized_tbl))
+
+            # Convert to BYTEA and output final (permanent table)
+            plpy.execute(bytea_query.format(**locals()))
+
+            plpy.execute("DROP TABLE {0}".format(batched_table))
+
+            self._create_output_summary_table()
+
+            return
+
+        # Done with postgres, rest is all for gpdb
+        #
+        # This gpdb code path is far more complex, and depends on
+        #   how the user wishes to distribute the data.  Even if
+        #   it's to be spread evenly across all segments, we still
+        #   need to do some extra work to ensure that happens.
+
+        if self.distribution_rules == 'all_segments':
+            all_segments = True
             self.distribution_rules = '$__madlib__$all_segments$__madlib__$'
-            distributed_by_clause = ''
-            dist_key_clause = ''
-            join_clause = ''
-            dist_key_comma = ''
+            num_segments = get_seg_number()
         else:
-            dist_key = DISTRIBUTION_KEY_COLNAME
-            # Create large temp table such that there is atleast 1 row on each segment
-            # Using 999999 would distribute data(atleast 1 row on each segment) for
-            # a cluster as large as 20000
-            query = """
-                    CREATE TEMP TABLE {series_tbl}
-                    AS
-                    SELECT generate_series(0, 999999) {dist_key}
-                    DISTRIBUTED BY ({dist_key})
-                """.format(**locals())
-            plpy.execute(query)
-            distributed_by_clause= ' DISTRIBUTED BY ({dist_key}) '.format(**locals())
-            dist_key_comma = dist_key + ' ,'
-            gpu_join_clause = """JOIN {dist_key_tbl} ON
-                ({self.gpu_config})[b.buffer_id%{num_segments}+1] = {dist_key_tbl}.id
-                """
-
-            if self.distribution_rules == 'gpu_segments':
-                gpu_info_table = unique_string(desp='gpu_info')
-                plpy.execute("""
-                    SELECT {self.schema_madlib}.gpu_configuration('{gpu_info_table}')
-                """.format(**locals()))
-                gpu_query = """
-                    SELECT array_agg(DISTINCT(hostname)) as gpu_config
-                    FROM {gpu_info_table}
-                """.format(**locals())
-                gpu_query_result = plpy.execute(gpu_query)[0]['gpu_config']
-                if not gpu_query_result:
-                   plpy.error("{self.module_name}: No GPUs configured on hosts.".format(self=self))
-
-                gpu_config_hostnames = "ARRAY{0}".format(gpu_query_result)
-                # find hosts with gpus
-                get_segment_query = """
-                    SELECT array_agg(content) as segment_ids,
-                           array_agg(dbid) as dbid,
-                           count(*) as count
-                    FROM gp_segment_configuration
-                    WHERE content != -1 AND role = 'p'
-                    AND hostname=ANY({gpu_config_hostnames})
-                """.format(**locals())
-                segment_ids_result = plpy.execute(get_segment_query)[0]
-                plpy.execute("DROP TABLE IF EXISTS {0}".format(gpu_info_table))
-
-                self.gpu_config = "ARRAY{0}".format(sorted(segment_ids_result['segment_ids']))
-                self.distribution_rules = "ARRAY{0}".format(sorted(segment_ids_result['dbid']))
-
-                num_segments = segment_ids_result['count']
-                where_clause = "WHERE 
gp_segment_id=ANY({self.gpu_config})".format(**locals())
-                join_clause = gpu_join_clause.format(**locals())
-
-            elif self.distribution_rules == DEFAULT_GPU_CONFIG:
-
-                self.distribution_rules = '$__madlib__$all_segments$__madlib__$'
-                where_clause = ''
-                num_segments = get_seg_number()
-                join_clause = 'JOIN {dist_key_tbl} ON (b.buffer_id%{num_segments})= {dist_key_tbl}.id'.format(**locals())
-
-            else:  # Read from a table with dbids to distribute the data
-
-                self._validate_distribution_table()
-                gpu_query = """
-                    SELECT array_agg(content) as gpu_config,
-                           array_agg(gp_segment_configuration.dbid) as dbid
-                    FROM {self.distribution_rules} JOIN gp_segment_configuration
-                    ON {self.distribution_rules}.dbid = gp_segment_configuration.dbid
-                """.format(**locals())
-                gpu_query_result = plpy.execute(gpu_query)[0]
-                self.gpu_config = "ARRAY{0}".format(sorted(gpu_query_result['gpu_config']))
-                where_clause = "WHERE gp_segment_id=ANY({self.gpu_config})".format(**locals())
-                num_segments = plpy.execute("SELECT count(*) as count FROM {self.distribution_rules}".format(**locals()))[0]['count']
-                join_clause = gpu_join_clause.format(**locals())
-                self.distribution_rules = "ARRAY{0}".format(sorted(gpu_query_result['dbid']))
-
-            dist_key_query = """
-                    CREATE TEMP TABLE {dist_key_tbl} AS
-                    SELECT gp_segment_id AS id, min({dist_key}) AS {dist_key}
-                    FROM {series_tbl}
-                    {where_clause}
-                    GROUP BY gp_segment_id
-            """
-            plpy.execute(dist_key_query.format(**locals()))
-
-        # Create the mini-batched output table
+            all_segments = False
+
+        if self.distribution_rules == 'gpu_segments':
+            gpu_info_table = unique_string(desp='gpu_info')
+            plpy.execute("""
+                SELECT {self.schema_madlib}.gpu_configuration('{gpu_info_table}')
+            """.format(**locals()))
+            gpu_query = """
+                SELECT array_agg(DISTINCT(hostname)) as gpu_config
+                FROM {gpu_info_table}
+            """.format(**locals())
+            gpu_query_result = plpy.execute(gpu_query)[0]['gpu_config']
+            if not gpu_query_result:
+               plpy.error("{self.module_name}: No GPUs configured on hosts.".format(self=self))
+
+            gpu_config_hostnames = "ARRAY{0}".format(gpu_query_result)
+            # find hosts with gpus
+            get_segment_query = """
+                SELECT array_agg(content) as segment_ids,
+                       array_agg(dbid) as dbid,
+                       count(*) as count
+                FROM gp_segment_configuration
+                WHERE content != -1 AND role = 'p'
+                AND hostname=ANY({gpu_config_hostnames})
+            """.format(**locals())
+            segment_ids_result = plpy.execute(get_segment_query)[0]
+            plpy.execute("DROP TABLE IF EXISTS {0}".format(gpu_info_table))
+
+            self.gpu_config = "ARRAY{0}".format(sorted(segment_ids_result['segment_ids']))
+            self.distribution_rules = "ARRAY{0}".format(sorted(segment_ids_result['dbid']))
+
+            num_segments = segment_ids_result['count']
+
+        elif not all_segments:  # Read from a table with dbids to distribute the data
+            self._validate_distribution_table()
+            gpu_query = """
+                SELECT array_agg(content) as gpu_config,
+                       array_agg(gp_segment_configuration.dbid) as dbid
+                FROM {self.distribution_rules} JOIN gp_segment_configuration
+                ON {self.distribution_rules}.dbid = gp_segment_configuration.dbid
+            """.format(**locals())
+            gpu_query_result = plpy.execute(gpu_query)[0]
+            self.gpu_config = "ARRAY{0}".format(sorted(gpu_query_result['gpu_config']))
+            num_segments = plpy.execute("SELECT count(*) as count FROM {self.distribution_rules}".format(**locals()))[0]['count']
+            self.distribution_rules = "ARRAY{0}".format(sorted(gpu_query_result['dbid']))
+
+        join_key = 't.buffer_id % {num_segments}'.format(**locals())
+
+        if not all_segments:
+            join_key = '({self.gpu_config})[{join_key} + 1]'.format(**locals())
+
+        # Create large temp table such that there is at least 1 row on each segment
+        # Using 999999 would distribute data (at least 1 row on each segment) for
+        # a cluster as large as 20000
+        dist_key_col = DISTRIBUTION_KEY_COLNAME
+        query = """
+            CREATE TEMP TABLE {series_tbl} AS
+                SELECT generate_series(0, 999999) {dist_key_col}
+                DISTRIBUTED BY ({dist_key_col})
+            """.format(**locals())
+
+        plpy.execute(query)
+
+        # Used in locals() to format queries, including template queries
+        #  bytea_query & batching_query defined in section common to
+        #  pg & gp (very beginning of this function)
+        dist_by_dist_key = 'DISTRIBUTED BY ({dist_key_col})'.format(**locals())
+        dist_by_buffer_id = 'DISTRIBUTED BY (buffer_id)'
+        dist_key_col_comma = dist_key_col + ' ,'
+        make_buffer_id = ''
+
+        dist_key_query = """
+                CREATE TEMP TABLE {dist_key_tbl} AS
+                SELECT min({dist_key_col}) AS {dist_key_col}
+                FROM {series_tbl}
+                GROUP BY gp_segment_id
+                DISTRIBUTED BY ({dist_key_col})
+        """.format(**locals())
+
+        plpy.execute(dist_key_query)
+
+        plpy.execute("DROP TABLE {0}".format(series_tbl))
+
+        # Always one-hot encode the dependent var. For now, we are assuming
+        # that input_preprocessor_dl will be used only for deep
+        # learning and mostly for classification. So make a strong
+        # assumption that it is only for classification, so one-hot
+        # encode the dep var, unless it's already a numeric array in
+        # which case we assume it's already one-hot encoded.
+
+        # While 1-hot-encoding is done, we also normalize the independent
+        # var and randomly shuffle the rows on each segment.  (The dist key
+        # we're adding avoids any rows moving between segments.  This may
+        # make things slightly less random, but helps with speed--probably
+        # a safe tradeoff to make.)
+
+        norm_tbl = unique_string(desp='norm_table')
+
+        one_hot_sql = """
+            CREATE TEMP TABLE {norm_tbl} AS
+            SELECT {dist_key_col},
+                {rescale_independent_var} AS x_norm,
+                {one_hot_dep_var_array_expr} AS y
+            FROM {self.source_table} s JOIN {dist_key_tbl} AS d
+                ON (s.gp_segment_id = d.gp_segment_id)
+            {order_by_clause}
+            DISTRIBUTED BY ({dist_key_col})
+        """.format(**locals())
+        plpy.execute(one_hot_sql)
+
+        rows_per_seg_tbl = unique_string(desp='rows_per_seg')
+        start_rows_tbl = unique_string(desp='start_rows')
+
+        #  Generate rows_per_segment table; this small table will
+        #  just have one row on each segment containing the number
+        #  of rows on that segment in the norm_tbl
+        sql = """
+            CREATE TEMP TABLE {rows_per_seg_tbl} AS SELECT
+                COUNT(*) as rows_per_seg,
+                {dist_key_col}
+            FROM {norm_tbl}
+            GROUP BY {dist_key_col}
+            DISTRIBUTED BY ({dist_key_col})
+        """.format(**locals())
+
+        plpy.execute(sql)
+
+        #  Generate start_rows_tbl from rows_per_segment table.
+        #  This assigns a start_row number for each segment based on
+        #  the sum of all rows in previous segments.  These will be
+        #  added to the row numbers within each segment to get an
+        #  absolute index into the table.  All of this is to accomplish
+        #  the equivalent of ROW_NUMBER() OVER() on the whole table,
+        #  but this way is much faster because we don't have to do an
+        #  N:1 Gather Motion (moving entire table to a single segment
+        #  and scanning through it).
+        #
         sql = """
-            CREATE TABLE {self.output_table} AS
-            SELECT {dist_key_comma}
-                   {self.schema_madlib}.convert_array_to_bytea({x}) AS {x},
-                   {self.schema_madlib}.convert_array_to_bytea({y}) AS {y},
-                   ARRAY[count,{ind_shape}]::SMALLINT[] AS {ind_shape_col},
-                   ARRAY[count,{dep_shape}]::SMALLINT[] AS {dep_shape_col},
-                   buffer_id
-            FROM
-            (
-                SELECT
-                    {self.schema_madlib}.agg_array_concat(
-                        ARRAY[{norm_tbl}.x_norm::{FLOAT32_SQL_TYPE}[]]) AS {x},
-                    {self.schema_madlib}.agg_array_concat(
-                        ARRAY[{norm_tbl}.y]) AS {y},
-                    ({norm_tbl}.row_id%{self.num_of_buffers})::smallint AS buffer_id,
-                    count(*) AS count
-                FROM {norm_tbl}
-                GROUP BY buffer_id
-            ) b
-            {join_clause}
-            {distributed_by_clause}
-            """.format(x=MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL,
-                       y=MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL,
-                       FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE,
-                       **locals())
+            CREATE TEMP TABLE {start_rows_tbl} AS SELECT
+                {dist_key_col},
+                SUM(rows_per_seg) OVER (ORDER BY gp_segment_id) - rows_per_seg AS start_row
+            FROM {rows_per_seg_tbl}
+            DISTRIBUTED BY ({dist_key_col})
+        """.format(**locals())
+
         plpy.execute(sql)
-        plpy.execute("DROP TABLE IF EXISTS {0}, {1}, {2}".format(norm_tbl, 
series_tbl, dist_key_tbl))
+
+        plpy.execute("DROP TABLE {0}".format(rows_per_seg_tbl))
+
+        self.buffer_size = self._get_buffer_size(num_segments)
+
+        # The query below assigns slot_id's to each row within
+        #  a segment, computes a row_id by adding start_row for
+        #  that segment to it, then divides by buffer_size to make
+        #  this into a buffer_id
+        # ie:
+        #  buffer_id = row_id / buffer_size
+        #     row_id = start_row + slot_id
+        #    slot_id = ROW_NUMBER() OVER(PARTITION BY <dist key>)::INTEGER
+        #
+        #   Instead of partitioning by gp_segment_id itself, we
+        # use __dist_key__ col instead.  This is the same partition,
+        # since there's a 1-to-1 mapping between the columns; but
+        # using __dist_key__ avoids an extra Redistribute Motion.
+        #
+        # Note: even though the ordering of these two columns is
+        #  different, this doesn't matter as each segment is being
+        #  numbered separately (only the start_row is different,
+        #  and those are fixed to the correct segments by the JOIN
+        #  condition.)
+
+        sql = """
+        CREATE TEMP TABLE {normalized_tbl} AS SELECT
+            {dist_key_col},
+            x_norm,
+            y,
+            (ROW_NUMBER() OVER( PARTITION BY {dist_key_col} ))::INTEGER as slot_id,
+            ((start_row +
+               (ROW_NUMBER() OVER( PARTITION BY {dist_key_col} ) - 1)
+             )::INTEGER / {self.buffer_size}
+            ) AS buffer_id
+        FROM {norm_tbl} JOIN {start_rows_tbl}
+            USING ({dist_key_col})
+        ORDER BY buffer_id
+        DISTRIBUTED BY (slot_id)
+        """.format(**locals())
+
+        plpy.execute(sql)   # label buffer_id's
+
+        # A note on DISTRIBUTED BY (slot_id) in above query:
+        #
+        #     In the next query, we'll be doing the actual batching.  Due
+        #  to the GROUP BY, gpdb will Redistribute on buffer_id.  We could
+        #  avoid this by using DISTRIBUTED BY (buffer_id) in the above
+        #  (buffer-labelling) query.  But this also causes the GROUP BY
+        #  to use single-stage GroupAgg instead of multistage GroupAgg,
+        #  which for unknown reasons is *much* slower and often runs out
+        #  of VMEM unless it's set very high!
+
+        plpy.execute("DROP TABLE {norm_tbl}, 
{start_rows_tbl}".format(**locals()))
+
+        # Disable optimizer (ORCA) for platforms that use it
+        # since we want to use a groupagg instead of hashagg
+        with OptimizerControl(False), HashaggControl(False):
+            # Run actual batching query
+            plpy.execute(batching_query.format(**locals()))
+
+        plpy.execute("DROP TABLE {0}".format(normalized_tbl))
+
+        if not all_segments: # remove any segments we don't plan to use
+            sql = """
+                DELETE FROM {dist_key_tbl}
+                    WHERE NOT gp_segment_id = ANY({self.gpu_config})
+            """.format(**locals())
+
+        plpy.execute("ANALYZE {dist_key_tbl}".format(**locals()))
+        plpy.execute("ANALYZE {batched_table}".format(**locals()))
+
+        # Redistribute from buffer_id to dist_key
+        #
+        #  This has to be separate from the batching query, because
+        #   we found that adding DISTRIBUTED BY (dist_key) to that
+        #   query causes it to run out of VMEM on large datasets such
+        #   as places100.  Possibly this is because the memory available
+        #   for GroupAgg has to be shared with an extra slice if they
+        #   are part of the same query.
+        #
+        #  We also tried adding this to the BYTEA conversion query, but
+        #   that resulted in slower performance than just keeping it
+        #   separate.
+        #
+        sql = """CREATE TEMP TABLE {batched_table}_dist_key AS
+                    SELECT {dist_key_col}, t.*
+                        FROM {batched_table} t
+                            JOIN {dist_key_tbl} d
+                                ON {join_key} = d.gp_segment_id
+                            DISTRIBUTED BY ({dist_key_col})
+              """.format(**locals())
+
+        # match buffer_id's with dist_keys
+        plpy.execute(sql)
+
+        sql = """DROP TABLE {batched_table}, {dist_key_tbl};
+                 ALTER TABLE {batched_table}_dist_key RENAME TO {batched_table}
+              """.format(**locals())
+        plpy.execute(sql)
+
+        # Convert batched table to BYTEA and output as final (permanent) table
+        plpy.execute(bytea_query.format(**locals()))
+
+        plpy.execute("DROP TABLE {0}".format(batched_table))
+
         # Create summary table
         self._create_output_summary_table()
 
@@ -405,7 +629,8 @@ class InputDataPreprocessorDL(object):
             _assert(self.buffer_size > 0,
                     "{0}: The buffer size has to be a "
                     "positive integer or NULL.".format(self.module_name))
-        _assert(self.normalizing_const > 0,
+        if self.normalizing_const is not None:
+            _assert(self.normalizing_const > 0,
                 "{0}: The normalizing constant has to be a "
                 "positive integer or NULL.".format(self.module_name))
 
@@ -442,16 +667,17 @@ class InputDataPreprocessorDL(object):
         return get_distinct_col_levels(table, dependent_varname,
             dependent_vartype, include_nulls=True)
 
-    def _get_num_buffers(self):
+    def _get_buffer_size(self, num_segments):
         num_rows_in_tbl = plpy.execute("""
                 SELECT count(*) AS cnt FROM {0}
             """.format(self.source_table))[0]['cnt']
         buffer_size_calculator = MiniBatchBufferSizeCalculator()
         indepdent_var_dim = get_product_of_dimensions(self.source_table,
             self.independent_varname)
-        self.buffer_size = buffer_size_calculator.calculate_default_buffer_size(
-            self.buffer_size, num_rows_in_tbl, indepdent_var_dim)
-        return ceil((1.0 * num_rows_in_tbl) / self.buffer_size)
+        buffer_size = buffer_size_calculator.calculate_default_buffer_size(
+            self.buffer_size, num_rows_in_tbl, indepdent_var_dim, num_segments)
+        num_buffers = num_segments * ceil((1.0 * num_rows_in_tbl) / buffer_size / num_segments)
+        return int(ceil(num_rows_in_tbl / num_buffers))
 
 class ValidationDataPreprocessorDL(InputDataPreprocessorDL):
     def __init__(self, schema_madlib, source_table, output_table,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index 40ae56e..6e006d5 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -51,7 +51,6 @@ FLOAT32_SQL_TYPE = 'REAL'
 SMALLINT_SQL_TYPE = 'SMALLINT'
 
 DEFAULT_NORMALIZING_CONST = 1.0
-DEFAULT_GPU_CONFIG = 'all_segments'
 GP_SEGMENT_ID_COLNAME = "gp_segment_id"
 INTERNAL_GPU_CONFIG = '__internal_gpu_config__'
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 37a2e25..11730cf 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -31,7 +31,6 @@ from madlib_keras_helper import NORMALIZING_CONST_COLNAME
 from madlib_keras_helper import DISTRIBUTION_KEY_COLNAME
 from madlib_keras_helper import METRIC_TYPE_COLNAME
 from madlib_keras_helper import INTERNAL_GPU_CONFIG
-from madlib_keras_helper import DEFAULT_GPU_CONFIG
 from madlib_keras_helper import query_model_configs
 
 from utilities.minibatch_validation import validate_bytea_var_for_minibatch
@@ -234,7 +233,7 @@ class InputValidator:
         gpu_config = plpy.execute(
             "SELECT {0} FROM {1}".format(INTERNAL_GPU_CONFIG, summary_table)
             )[0][INTERNAL_GPU_CONFIG]
-        if gpu_config == DEFAULT_GPU_CONFIG:
+        if gpu_config == 'all_segments':
             _assert(0 not in accessible_gpus_for_seg,
                 "{0} error: Host(s) are missing gpus.".format(module_name))
         else:
diff --git a/src/ports/postgres/modules/deep_learning/test/input_data_preprocessor.sql_in b/src/ports/postgres/modules/deep_learning/test/input_data_preprocessor.sql_in
index d8c6798..7c6c5c3 100644
--- a/src/ports/postgres/modules/deep_learning/test/input_data_preprocessor.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/input_data_preprocessor.sql_in
@@ -19,6 +19,7 @@
  *
 *//* ----------------------------------------------------------------------- */
 m4_include(`SQLCommon.m4')
+m4_changequote(`<!', `!>')
 
 DROP TABLE IF EXISTS data_preprocessor_input;
 CREATE TABLE data_preprocessor_input(id serial, x double precision[], label TEXT);
@@ -49,20 +50,72 @@ SELECT training_preprocessor_dl(
   'x',
   5);
 
-SELECT assert(count(*)=4, 'Incorrect number of buffers in data_preprocessor_input_batch.')
+-- Divide two numbers and round up to the nearest integer
+CREATE FUNCTION divide_roundup(numerator NUMERIC, denominator NUMERIC)
+RETURNS INTEGER AS
+$$
+    SELECT (ceil($1 / $2)::INTEGER);
+$$ LANGUAGE SQL;
+
+-- num_buffers_calc() represents the num_buffers value that should be
+--  calculated by the preprocessor.
+-- For postgres, just need total rows / buffer_size rounded up.
+-- For greenplum, we take that result, and round up to the nearest multiple
+--   of num_segments.
+CREATE FUNCTION num_buffers_calc(rows_in_tbl INTEGER, buffer_size INTEGER)
+RETURNS INTEGER AS
+$$
+m4_ifdef(<!__POSTGRESQL__!>,
+    <! SELECT divide_roundup($1, $2); !>,
+    <! SELECT (COUNT(*)::INTEGER) * divide_roundup(divide_roundup($1, $2), 
COUNT(*)) FROM gp_segment_configuration
+                                                WHERE role = 'p' AND content 
!= -1; !>
+)
+$$ LANGUAGE SQL;
+
+--  num_buffers() represents the actual number of buffers expected to
+--      be returned in the output table.
+--   For postgres, this should always be the same as num_buffers_calc()
+--      (as long as rows_in_tbl > 0, which should be validated elsewhere)
+--   For greenplum, this can be less than num_buffers_calc() in
+--     the special case where there is only one row per buffer.  In
+--      that case, the buffers in the output table will be equal to
+--      the number of rows in the input table.  This can only happen
+--      if rows_in_tbl < num_segments and is the only case where the
+--      number of buffers on each segment will not be exactly equal
+CREATE FUNCTION num_buffers(rows_in_tbl INTEGER, buffer_size INTEGER)
+RETURNS INTEGER AS
+$$
+    SELECT LEAST(num_buffers_calc($1, $2), $1);
+$$ LANGUAGE SQL;
+
+CREATE FUNCTION buffer_size(rows_in_tbl INTEGER, requested_buffer_size INTEGER)
+RETURNS INTEGER AS
+$$
+  SELECT divide_roundup($1, num_buffers($1, $2));
+$$ LANGUAGE SQL;
+
+SELECT assert(COUNT(*) = num_buffers(17, 5),
+    'Incorrect number of buffers in data_preprocessor_input_batch.')
 FROM data_preprocessor_input_batch;
 
-SELECT assert(independent_var_shape[2]=6, 'Incorrect buffer size.')
+SELECT assert(independent_var_shape[2]=6, 'Incorrect image shape ' || independent_var_shape[2])
 FROM data_preprocessor_input_batch WHERE buffer_id=0;
 
-SELECT assert(independent_var_shape[1]=5, 'Incorrect buffer size.')
-FROM data_preprocessor_input_batch WHERE buffer_id=1;
+SELECT assert(independent_var_shape[1]=buffer_size, 'Incorrect buffer size ' || independent_var_shape[1])
+FROM (SELECT buffer_size(17, 5) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=0;
 
-SELECT assert(independent_var_shape[1]=4, 'Incorrect buffer size.')
-FROM data_preprocessor_input_batch WHERE buffer_id=3;
+SELECT assert(independent_var_shape[1]=buffer_size, 'Incorrect buffer size ' || independent_var_shape[1])
+FROM (SELECT buffer_size(17, 5) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=1;
+
+SELECT assert(independent_var_shape[1]=buffer_size, 'Incorrect buffer size ' || independent_var_shape[1])
+FROM (SELECT buffer_size(17, 5) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=2;
+
+SELECT assert(total_images = 17, 'Incorrect total number of images! Last buffer has incorrect size?')
+FROM (SELECT SUM(independent_var_shape[1]) AS total_images FROM data_preprocessor_input_batch) a;
+
+SELECT assert(octet_length(independent_var) = buffer_size*6*4, 'Incorrect buffer length ' || octet_length(independent_var)::TEXT)
+FROM (SELECT buffer_size(17, 5) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=0;
 
-SELECT assert(octet_length(independent_var) = 96, 'Incorrect buffer size')
-FROM data_preprocessor_input_batch WHERE buffer_id=0;
 
 DROP TABLE IF EXISTS validation_out, validation_out_summary;
 SELECT validation_preprocessor_dl(
@@ -73,20 +126,21 @@ SELECT validation_preprocessor_dl(
   'data_preprocessor_input_batch',
   5);
 
-SELECT assert(count(*)=4, 'Incorrect number of buffers in validation_out.')
+SELECT assert(COUNT(*) = num_buffers(17, 5),
+    'Incorrect number of buffers in validation_out.')
 FROM validation_out;
 
-SELECT assert(independent_var_shape[2]=6, 'Incorrect buffer size.')
+SELECT assert(independent_var_shape[2]=6, 'Incorrect image shape.')
 FROM data_preprocessor_input_batch WHERE buffer_id=0;
 
-SELECT assert(independent_var_shape[1]=5, 'Incorrect buffer size.')
-FROM data_preprocessor_input_batch WHERE buffer_id=1;
+SELECT assert(independent_var_shape[1]=buffer_size, 'Incorrect buffer size.')
+FROM (SELECT buffer_size(17, 5) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=1;
 
-SELECT assert(independent_var_shape[1]=4, 'Incorrect buffer size.')
-FROM data_preprocessor_input_batch WHERE buffer_id=3;
+SELECT assert(total_images = 17, 'Incorrect total number of images! Last buffer has incorrect size?')
+FROM (SELECT SUM(independent_var_shape[1]) AS total_images FROM data_preprocessor_input_batch) a;
 
-SELECT assert(octet_length(independent_var) = 96, 'Incorrect buffer size')
-FROM validation_out WHERE buffer_id=0;
+SELECT assert(octet_length(independent_var) = buffer_size*6*4, 'Incorrect buffer length')
+FROM (SELECT buffer_size(17, 5) buffer_size) a, validation_out WHERE buffer_id=0;
 
 DROP TABLE IF EXISTS data_preprocessor_input_batch, data_preprocessor_input_batch_summary;
 SELECT training_preprocessor_dl(
@@ -96,7 +150,6 @@ SELECT training_preprocessor_dl(
   'x');
 
 -- Test data is evenly distributed across all segments (GPDB only)
-m4_changequote(`<!', `!>')
 m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
 DROP TABLE IF EXISTS data_preprocessor_input_batch, data_preprocessor_input_batch_summary;
 SELECT training_preprocessor_dl(
@@ -109,11 +162,10 @@ SELECT training_preprocessor_dl(
 -- This test expects that total number of images(17 for input table data_preprocessor_input)
 -- are equally distributed across all segments.
 -- Therefore, after preprocessing seg0 will have 17/(# of segs) buffers.
-SELECT assert(count(*)=(SELECT ceil(17.0/count(*)) from gp_segment_configuration WHERE role = 'p' and content != -1), 'Even distribution of buffers failed.')
-FROM data_preprocessor_input_batch
-WHERE gp_segment_id = 0;
+SELECT gp_segment_id, assert((SELECT divide_roundup(17, count(*)) from gp_segment_configuration WHERE role = 'p' and content != -1) - COUNT(*) <= 1, 'Even distribution of buffers failed. Seeing ' || count(*) || ' buffers.')
+    FROM data_preprocessor_input_batch GROUP BY 1;
 SELECT assert(__internal_gpu_config__ = 'all_segments', 'Missing column in summary table')
-FROM data_preprocessor_input_batch_summary;
+    FROM data_preprocessor_input_batch_summary;
 
 -- Test validation data is evenly distributed across all segments (GPDB only)
 DROP TABLE IF EXISTS validation_out, validation_out_summary;
@@ -124,9 +176,8 @@ SELECT validation_preprocessor_dl(
   'x',
   'data_preprocessor_input_batch',
   1);
-SELECT assert(count(*)=(SELECT ceil(17.0/count(*)) from gp_segment_configuration WHERE role = 'p' and content != -1), 'Even distribution of validation buffers failed.')
-FROM validation_out
-WHERE gp_segment_id = 0;
+SELECT gp_segment_id, assert((SELECT divide_roundup(17, count(*)) from gp_segment_configuration WHERE role = 'p' and content != -1) - COUNT(*) <= 1, 'Even distribution of buffers failed. Seeing ' || count(*) || ' buffers.')
+    FROM validation_out GROUP BY 1;
 SELECT assert(__internal_gpu_config__ = 'all_segments', 'Missing column in validation summary table')
 FROM validation_out_summary;
 
@@ -208,8 +259,8 @@ SELECT assert(relative_error(MAX(x),46.6) < 0.00001, 'Independent var not normal
 SELECT assert(dependent_var_shape[2] = 16, 'Incorrect one-hot encode dimension with num_classes') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
 
-SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
-FROM data_preprocessor_input_batch WHERE buffer_id=0;
+SELECT assert(octet_length(independent_var) = buffer_size*6*4, 'Incorrect buffer length')
+FROM (SELECT buffer_size(17, 4) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=0;
 
 -- Test summary table
 SELECT assert
@@ -220,13 +271,14 @@ SELECT assert
         independent_varname = 'x' AND
         dependent_vartype   = 'integer' AND
         class_values        = '{-6,-3,-1,0,2,3,4,5,6,7,8,9,10,12,NULL,NULL}' AND
-        buffer_size         = 4 AND  -- we sort the class values in python
+        summary.buffer_size = a.buffer_size AND  -- we sort the class values in python
         normalizing_const   = 5 AND
         pg_typeof(normalizing_const) = 'real'::regtype AND
         num_classes         = 16 AND
         distribution_rules  = 'all_segments',
         'Summary Validation failed. Actual:' || __to_char(summary)
-        ) from (select * from data_preprocessor_input_batch_summary) summary;
+        ) FROM (SELECT buffer_size(17, 4) buffer_size) a,
+          (SELECT * FROM data_preprocessor_input_batch_summary) summary;
 
 --- Test output data type
 SELECT assert(pg_typeof(independent_var) = 'bytea'::regtype, 'Wrong independent_var type') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
@@ -286,8 +338,8 @@ SELECT assert(pg_typeof(dependent_var) = 'bytea'::regtype, 'One-hot encode doesn
 SELECT assert(dependent_var_shape[2] = 2, 'Incorrect one-hot encode dimension') FROM
    data_preprocessor_input_batch WHERE buffer_id = 0;
 
-SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
-FROM data_preprocessor_input_batch WHERE buffer_id=0;
+SELECT assert(octet_length(independent_var) = buffer_size*6*4, 'Incorrect buffer length')
+FROM (SELECT buffer_size(17, 4) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=0;
 
 SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST((convert_bytea_to_smallint_array(dependent_var))[1:2]) as y FROM data_preprocessor_input_batch) a WHERE buffer_id = 0;
 SELECT assert (dependent_vartype   = 'boolean' AND
@@ -328,8 +380,8 @@ SELECT assert(pg_typeof(dependent_var) = 'bytea'::regtype, 'One-hot encode doesn
 SELECT assert(dependent_var_shape[2] = 3, 'Incorrect one-hot encode dimension') FROM
    data_preprocessor_input_batch WHERE buffer_id = 0;
 
-SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
-FROM data_preprocessor_input_batch WHERE buffer_id=0;
+SELECT assert(octet_length(independent_var) = buffer_size*6*4, 'Incorrect buffer length')
+FROM (SELECT buffer_size(17, 4) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=0;
 
 SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST((convert_bytea_to_smallint_array(dependent_var))[1:3]) as y FROM data_preprocessor_input_batch) a WHERE buffer_id = 0;
 SELECT assert (dependent_vartype   = 'text' AND
@@ -363,8 +415,8 @@ SELECT training_preprocessor_dl(
 SELECT assert(pg_typeof(dependent_var) = 'bytea'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
 SELECT assert(dependent_var_shape[2] = 3, 'Incorrect one-hot encode dimension') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
-FROM data_preprocessor_input_batch WHERE buffer_id=0;
+SELECT assert(octet_length(independent_var) = buffer_size*6*4, 'Incorrect buffer length')
+FROM (SELECT buffer_size(17, 4) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=0;
 SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST((convert_bytea_to_smallint_array(dependent_var))[1:3]) as y FROM data_preprocessor_input_batch) a WHERE buffer_id = 0;
 SELECT assert (dependent_vartype   = 'double precision' AND
                class_values        = '{4.0,4.2,5.0}' AND
@@ -385,8 +437,8 @@ SELECT assert(pg_typeof(dependent_var) = 'bytea'::regtype, 'One-hot encode doesn
 SELECT assert(dependent_var_shape[2] = 2, 'Incorrect one-hot encode dimension') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
 
-SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
-FROM data_preprocessor_input_batch WHERE buffer_id=0;
+SELECT assert(octet_length(independent_var) = buffer_size*6*4, 'Incorrect buffer length')
+FROM (SELECT buffer_size(17, 4) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=0;
 
 SELECT assert(relative_error(SUM(y), SUM(y4)) < 0.000001, 'Incorrect one-hot encode value') FROM (SELECT UNNEST(convert_bytea_to_smallint_array(dependent_var)) AS y FROM data_preprocessor_input_batch) a, (SELECT UNNEST(y4) as y4 FROM data_preprocessor_input) b;
 SELECT assert (dependent_vartype   = 'double precision[]' AND
@@ -419,8 +471,8 @@ SELECT assert(pg_typeof(dependent_var) = 'bytea'::regtype, 'One-hot encode doesn
 SELECT assert(dependent_var_shape[2] = 2, 'Incorrect one-hot encode dimension') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
 
-SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
-FROM data_preprocessor_input_batch WHERE buffer_id=0;
+SELECT assert(octet_length(independent_var) = buffer_size*6*4, 'Incorrect buffer length')
+FROM (SELECT buffer_size(17, 4) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=0;
 
 SELECT assert(relative_error(SUM(y), SUM(y5)) < 0.000001, 'Incorrect one-hot encode value') FROM (SELECT UNNEST(convert_bytea_to_smallint_array(dependent_var)) AS y FROM data_preprocessor_input_batch) a, (SELECT UNNEST(y5) as y5 FROM data_preprocessor_input) b;
 SELECT assert (dependent_vartype   = 'integer[]' AND
@@ -473,8 +525,8 @@ SELECT assert
 SELECT assert(dependent_var_shape[2] = 5, 'Incorrect one-hot encode dimension') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
 
-SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
-FROM data_preprocessor_input_batch WHERE buffer_id=0;
+SELECT assert(octet_length(independent_var) = buffer_size*6*4, 'Incorrect buffer length')
+FROM (SELECT buffer_size(17, 4) buffer_size) a, data_preprocessor_input_batch WHERE buffer_id=0;
 
 -- The same tests, but for validation.
 DROP TABLE IF EXISTS data_preprocessor_input_validation_null;
@@ -541,7 +593,7 @@ SELECT assert
 
 SELECT assert(dependent_var_shape[2] = 3, 'Incorrect one-hot encode dimension') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(octet_length(independent_var) = 24, 'Incorrect buffer size')
+SELECT assert(octet_length(independent_var) = 24, 'Incorrect buffer length')
 FROM data_preprocessor_input_batch WHERE buffer_id=0;
 -- NULL is treated as a class label, so it should show '1' for the
 -- first index
@@ -570,7 +622,7 @@ SELECT assert
 
 SELECT assert(dependent_var_shape[2] = 3, 'Incorrect one-hot encode dimension') FROM
   validation_out_batch WHERE buffer_id = 0;
-SELECT assert(octet_length(independent_var) = 24, 'Incorrect buffer size')
+SELECT assert(octet_length(independent_var) = 24, 'Incorrect buffer length')
 FROM data_preprocessor_input_batch WHERE buffer_id=0;
 -- NULL is treated as a class label, so it should show '1' for the
 -- first index
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in
index 1f3a24f..7c9ad5e 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in
@@ -24,8 +24,8 @@
 DROP TABLE IF EXISTS cifar_10_sample;
 CREATE TABLE cifar_10_sample(id INTEGER, y SMALLINT, y_text TEXT, imgpath TEXT, x  REAL[]);
 COPY cifar_10_sample FROM STDIN DELIMITER '|';
-1|0|'cat'|'0/img0.jpg'|{{{202,204,199},{202,204,199},{204,206,201},{206,208,203},{208,210,205},{209,211,206},{210,212,207},{212,214,210},{213,215,212},{215,217,214},{216,218,215},{216,218,215},{215,217,214},{216,218,215},{216,218,215},{216,218,214},{217,219,214},{217,219,214},{218,220,215},{218,219,214},{216,217,212},{217,218,213},{218,219,214},{214,215,209},{213,214,207},{212,213,206},{211,212,205},{209,210,203},{208,209,202},{207,208,200},{205,206,199},{203,204,198}},{{206,208,203},{20
 [...]
-2|1|'dog'|'0/img2.jpg'|{{{126,118,110},{122,115,108},{126,119,111},{127,119,109},{130,122,111},{130,122,111},{132,124,113},{133,125,114},{130,122,111},{132,124,113},{134,126,115},{131,123,112},{131,123,112},{134,126,115},{133,125,114},{136,128,117},{137,129,118},{137,129,118},{136,128,117},{131,123,112},{130,122,111},{132,124,113},{132,124,113},{132,124,113},{129,122,110},{127,121,109},{127,121,109},{125,119,107},{124,118,106},{124,118,106},{120,114,102},{117,111,99}},{{122,115,107},{119
 [...]
+0|0|'cat'|'0/img0.jpg'|{{{202,204,199},{202,204,199},{204,206,201},{206,208,203},{208,210,205},{209,211,206},{210,212,207},{212,214,210},{213,215,212},{215,217,214},{216,218,215},{216,218,215},{215,217,214},{216,218,215},{216,218,215},{216,218,214},{217,219,214},{217,219,214},{218,220,215},{218,219,214},{216,217,212},{217,218,213},{218,219,214},{214,215,209},{213,214,207},{212,213,206},{211,212,205},{209,210,203},{208,209,202},{207,208,200},{205,206,199},{203,204,198}},{{206,208,203},{20
 [...]
+1|1|'dog'|'0/img2.jpg'|{{{126,118,110},{122,115,108},{126,119,111},{127,119,109},{130,122,111},{130,122,111},{132,124,113},{133,125,114},{130,122,111},{132,124,113},{134,126,115},{131,123,112},{131,123,112},{134,126,115},{133,125,114},{136,128,117},{137,129,118},{137,129,118},{136,128,117},{131,123,112},{130,122,111},{132,124,113},{132,124,113},{132,124,113},{129,122,110},{127,121,109},{127,121,109},{125,119,107},{124,118,106},{124,118,106},{120,114,102},{117,111,99}},{{122,115,107},{119
 [...]
 \.
 
 DROP TABLE IF EXISTS cifar_10_sample_batched;
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_input_data_preprocessor.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_input_data_preprocessor.py_in
index f21176c..d2e14cd 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_input_data_preprocessor.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_input_data_preprocessor.py_in
@@ -61,6 +61,8 @@ class InputPreProcessorDLTestCase(unittest.TestCase):
         self.module = deep_learning.input_data_preprocessor
         import utilities.minibatch_preprocessing
         self.util_module = utilities.minibatch_preprocessing
+        import utilities.control
+        self.control_module = utilities.control
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]'])
         self.module.validate_module_input_params = Mock()
         self.module.get_distinct_col_levels = Mock(return_value = [0,22,100])
@@ -70,6 +72,9 @@ class InputPreProcessorDLTestCase(unittest.TestCase):
 
     def test_input_preprocessor_dl_executes_query(self):
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]'])
+        self.control_module.OptimizerControl.__enter__ = Mock()
+        self.control_module.OptimizerControl.optimizer_control = True
+        self.control_module.OptimizerControl.optimizer_enabled = True
         preprocessor_obj = self.module.InputDataPreprocessorDL(
             self.default_schema_madlib,
             "input",
@@ -85,6 +90,9 @@ class InputPreProcessorDLTestCase(unittest.TestCase):
 
     def test_input_preprocessor_null_buffer_size_executes_query(self):
         self.module.get_expr_type = Mock(side_effect = ['integer[]', 'integer[]'])
+        self.control_module.OptimizerControl.__enter__ = Mock()
+        self.control_module.OptimizerControl.optimizer_control = True
+        self.control_module.OptimizerControl.optimizer_enabled = True
         preprocessor_obj = self.module.InputDataPreprocessorDL(
             self.default_schema_madlib,
             "input",
diff --git a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
index e03bf44..c25463c 100644
--- a/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
@@ -457,10 +457,13 @@ class MiniBatchBufferSizeCalculator:
     @staticmethod
     def calculate_default_buffer_size(buffer_size,
                                       avg_num_rows_processed,
-                                      independent_var_dimension):
+                                      independent_var_dimension,
+                                      num_of_segments=None):
         if buffer_size is not None:
             return buffer_size
-        num_of_segments = get_seg_number()
+
+        if num_of_segments is None:
+            num_of_segments = get_seg_number()
 
         default_buffer_size = min(75000000.0/independent_var_dimension,
                                     float(avg_num_rows_processed)/num_of_segments)
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 687566a..12b5205 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -20,6 +20,27 @@ import plpy
 
 m4_changequote(`<!', `!>')
 
+def plpy_execute_debug(sql, *args, **kwargs):
+    """ Replace plpy.execute(sql, ...) with
+        plpy_execute_debug(sql, ...) to debug
+        a query.  Shows the query itself, the
+        EXPLAIN of it, and how long the query
+        takes to execute.
+    """
+    plpy.info(sql)  # Print sql command
+
+    # Print EXPLAIN of sql command
+    res = plpy.execute("EXPLAIN " + sql, *args)
+    for r in res:
+        plpy.info(r['QUERY PLAN'])
+
+    # Run actual sql command, with timing
+    start = time.time()
+    plpy.execute(sql, *args)
+
+    # Print how long execution of query took
+    plpy.info("Query took {0}s".format(time.time() - start))
+
 def has_function_properties():
     """ __HAS_FUNCTION_PROPERTIES__ variable defined during configure """
     return m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>)
