kaknikhil commented on a change in pull request #361: Minibatch Preprocessor DL: Add optional num_classes param.
URL: https://github.com/apache/madlib/pull/361#discussion_r271052671
 
 

 ##########
 File path: src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
 ##########
 @@ -363,21 +365,70 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
 
         self._validate_args()
         self.num_of_buffers = self._get_num_buffers()
-        self.to_one_hot_encode = True
         if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
             self.dependent_levels = None
         else:
             self.dependent_levels = get_distinct_col_levels(
-                self.source_table, self.dependent_varname, self.dependent_vartype)
+                self.source_table, self.dependent_varname,
+                self.dependent_vartype, exclude_nulls=False)
+            # if any class level was NULL in sql, that would show up as
+            # None in self.dependent_levels. Replace all None with NULL
+            # in the list.
+            self.dependent_levels = ['NULL' if level is None else level
+                for level in self.dependent_levels]
+            self._validate_num_classes()
+
+    def _validate_num_classes(self):
+        if self.num_classes is not None and \
+            self.num_classes < len(self.dependent_levels):
+            plpy.error("{0}: Invalid num_classes value specified. It must "\
+                "be equal to or greater than distinct class values found "\
+                "in table ({1}).".format(
+                    self.module_name, len(self.dependent_levels)))
+
+    def get_dep_var_array_expr(self):
+        """
+        :param dependent_varname: Name of the dependent variable
+        :param num_classes: Number of class values to consider in 1-hot
+        :return:
+            This function returns a tuple of
+            1. A string with transformed dependent varname depending on it's type
+            2. All the distinct dependent class levels encoded as a string
+
+            If dep_type == numeric[] , do not encode
+                    1. dependent_varname = rings
+                        transformed_value = ARRAY[rings]
+                    2. dependent_varname = ARRAY[a, b, c]
+                        transformed_value = ARRAY[a, b, c]
+            else if dep_type in ("text", "boolean"), encode:
+                    3. dependent_varname = rings (encoding)
+                        transformed_value = ARRAY[rings=1, rings=2, rings=3]
+        """
+        # Assuming the input NUMERIC[] is already one_hot_encoded,
+        # so casting to INTEGER[]
+        if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
+            return self.dependent_varname + '::INTEGER[]'
 
 Review comment:
   Consider replacing the `+` with a `.format`

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to