orhankislal commented on a change in pull request #467: DL: Improve performance of mini-batch preprocessor
URL: https://github.com/apache/madlib/pull/467#discussion_r362661848
 
 

 ##########
 File path: 
src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
 ##########
 @@ -235,124 +281,324 @@ class InputDataPreprocessorDL(object):
         dep_shape = self._get_dependent_var_shape()
         dep_shape = ','.join([str(i) for i in dep_shape])
 
+        one_hot_dep_var_array_expr = self.get_one_hot_encoded_dep_var_expr()
+
+        # skip normalization step if normalizing_const = 1.0
+        if self.normalizing_const and (self.normalizing_const < 0.999999 or 
self.normalizing_const > 1.000001):
+            rescale_independent_var = 
"""{self.schema_madlib}.array_scalar_mult(
+                                         
{self.independent_varname}::{FLOAT32_SQL_TYPE}[],
+                                         
(1/{self.normalizing_const})::{FLOAT32_SQL_TYPE})
+                                      
""".format(FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE, **locals())
+        else:
+            self.normalizing_const = DEFAULT_NORMALIZING_CONST
+            rescale_independent_var = 
"{self.independent_varname}::{FLOAT32_SQL_TYPE}[]".format(FLOAT32_SQL_TYPE=FLOAT32_SQL_TYPE,
  **locals())
+
+        # It's important that we shuffle all rows before batching for fit(), 
but
+        #  we can skip that for predict()
+        order_by_clause = " ORDER BY RANDOM()" if order_by_random else ""
+
         if is_platform_pg():
+            # used later for writing summary table
+            self.distribution_rules = '$__madlib__$all_segments$__madlib__$'
+
+            #
+            # For postgres, we just need 3 simple queries:
+            #   1-hot-encode/normalize + batching + bytea conversion
+            #
+
+            # see note in gpdb code branch (lower down) on
+            # 1-hot-encoding of dependent var
+            one_hot_sql = """
+                CREATE TEMP TABLE {normalized_tbl} AS SELECT
+                    (ROW_NUMBER() OVER({order_by_clause}) - 1)::INTEGER as 
row_id,
+                    {rescale_independent_var} AS x_norm,
+                    {one_hot_dep_var_array_expr} AS y
+                FROM {self.source_table}
+            """.format(**locals())
+
+            plpy_execute(one_hot_sql)
+
+            self.buffer_size = self._get_buffer_size(1)
+
+            make_buffer_id = 'row_id / {0} AS '.format(self.buffer_size)
+
+            dist_by_buffer_id = ''
+            self.run_batch_rows_query(locals())
+            plpy.execute("DROP TABLE {0}".format(normalized_tbl))
+
+            dist_by_dist_key = ''
+            dist_key_col_comma = ''
+            self.convert_to_bytea(locals())
 
 Review comment:
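  Passing every local variable (via `locals()`) into `run_batch_rows_query()` and
  `convert_to_bytea()` looks somewhat dangerous: those helpers implicitly depend on
  whichever names happen to be in scope, so renaming or removing a local here can
  silently break them. We should at least add a comment to warn future developers.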

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to