Github user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/291#discussion_r204607008
--- Diff: src/ports/postgres/modules/utilities/transform_vec_cols.py_in ---
@@ -0,0 +1,513 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+from control import MinWarning
+from internal.db_utils import is_col_1d_array
+from internal.db_utils import quote_literal
+from utilities import _assert
+from utilities import add_postfix
+from utilities import ANY_ARRAY
+from utilities import is_psql_boolean_type
+from utilities import is_psql_char_type
+from utilities import is_psql_numeric_type
+from utilities import is_valid_psql_type
+from utilities import py_list_to_sql_string
+from utilities import split_quoted_delimited_str
+from validate_args import is_var_valid
+from validate_args import get_cols
+from validate_args import get_cols_and_types
+from validate_args import get_expr_type
+from validate_args import input_tbl_valid
+from validate_args import output_tbl_valid
+from validate_args import table_exists
+
+class vec_cols_helper:
+ def __init__(self):
+ self.all_cols = None
+
+    def get_cols_as_list(self, cols_to_process, source_table=None,
+                         exclude_cols=None):
+ """
+ Get a list of columns based on the value of cols_to_process
+ Args:
+            @param cols_to_process: str, Either a * or a comma-separated
+                                    list of col names
+            @param source_table: str, optional. Source table name
+            @param exclude_cols: str, optional. Comma-separated list of the
+                                 col(s) to exclude from the source table,
+                                 only used if cols_to_process is *
+ Returns:
+ A list of column names (or an empty list)
+ """
+ # If cols_to_process is empty/None, return empty list
+ if not cols_to_process:
+ return []
+ if cols_to_process.strip() != "*":
+            # If cols_to_process is a comma-separated list of names, return
+            # the list of column names in cols_to_process.
+            return [col for col in split_quoted_delimited_str(cols_to_process)
+                    if col not in split_quoted_delimited_str(exclude_cols)]
+ if source_table:
+ if not self.all_cols:
+ self.all_cols = get_cols(source_table)
+ return [col for col in self.all_cols
+ if col not in split_quoted_delimited_str(exclude_cols)]
+ return []
+
+ def get_type_class(self, arg):
+ if is_psql_numeric_type(arg):
+ return "double precision"
+ elif is_psql_char_type(arg):
+ return "text"
+ else:
+ return arg
+
+class vec2cols:
+ def __init__(self):
+ self.get_cols_helper = vec_cols_helper()
+ self.module_name = self.__class__.__name__
+
+    def validate_args(self, source_table, output_table, vector_col,
+                      feature_names, cols_to_output):
+ """
+ Validate args for vec2cols
+ """
+ input_tbl_valid(source_table, self.module_name)
+ output_tbl_valid(output_table, self.module_name)
+ is_var_valid(source_table, cols_to_output)
+ is_var_valid(source_table, vector_col)
+        _assert(is_valid_psql_type(get_expr_type(vector_col, source_table),
+                                   ANY_ARRAY),
+                "{0}: vector_col should refer to an array.".format(self.module_name))
+        _assert(is_col_1d_array(source_table, vector_col),
+                "{0}: vector_col must be a 1-dimensional array.".format(self.module_name))
+
+    def get_names_for_split_output_cols(self, source_table, vector_col,
+                                        feature_names):
+ """
+ Get list of names for the newly-split columns to include in the
+ output table.
+ Args:
+ @param: source_table, str. Source table
+ @param: vector_col, str. Column name containing the array input
+            @param: feature_names, list. Python list of the feature names to
+                    use for the split elements in the vector_col array
+ """
+ query = """
+ SELECT array_upper({0}, 1) AS n_x
+ FROM {1}
+ LIMIT 1
+ """.format(vector_col, source_table)
+ result = plpy.execute(query)[0]["n_x"]
+ if not result:
+            plpy.error('{0}: Column to split ({1}) must not be an empty array'
+                       .format(self.module_name, vector_col))
+ if not feature_names:
+ # Create custom col names for output columns, with prefix "f".
+ feature_names = ["f{0}".format(i+1) for i in range(result)]
+ else:
+            # Check if the array dimension is equal to the number of col
+            # names specified in feature_names.
+            _assert(result == len(feature_names),
+                    "{0}: Mismatch between size of vector_col and number of "
+                    "cols in feature_names.".format(self.module_name))
+ return feature_names
+
+ def validate_output_cols(self, features_to_unnest, cols_to_keep):
+        # If the output table would have more than 1600 columns, error out
+        # since creating the table would fail
+        MAX_OUTPUT_COLUMN_COUNT = 1600
+        _assert(len(features_to_unnest) + len(cols_to_keep) < MAX_OUTPUT_COLUMN_COUNT,
+                "{0}: The output exceeds the max number of columns that "
+                "can be created ({1})".format(self.module_name,
+                                              MAX_OUTPUT_COLUMN_COUNT))
+        # Check if newly created col names have the same name as existing cols
+        duplicate_col_names = set(features_to_unnest).intersection(set(cols_to_keep))
+ _assert(len(duplicate_col_names) == 0,
+ "{0}: Conflicting column names. Column names in source "
+ "table cannot be {1}".format(self.module_name,
+ list(duplicate_col_names)))
+
+ def vec2cols(self, schema_madlib, source_table, output_table,
+ vector_col, feature_names, cols_to_output, **kwargs):
+ """
+        Split up a column of array entries into multiple columns, each column
+        corresponding to one array position
+        Args:
+            @param: schema_madlib, str. The schema with madlib installed
+            @param: source_table, str. The source table
+            @param: output_table, str. The output table
+            @param: vector_col, str. The column with array entries to split up
+            @param: feature_names, list. Python list of the feature names to
+                    use for the split elements in the vector_col array
+            @param: cols_to_output, str. Comma-separated list of the columns
+                    in the source_table to include in the output_table
+ """
+        self.validate_args(source_table, output_table, vector_col,
+                           feature_names, cols_to_output)
+
+        # Get names of columns to use for the split vector_col
+        features_to_unnest = self.get_names_for_split_output_cols(
+            source_table, vector_col, feature_names)
+        cols_to_keep = self.get_cols_helper.get_cols_as_list(cols_to_output,
+                                                             source_table)
+
+ self.validate_output_cols(features_to_unnest, cols_to_keep)
+
+        # Construct the output query and populate the output table with all
+        # the correct parameters
+        select_new_cols = ', '.join(
+            ["{0}[{1}] AS {2}".format(vector_col, i + 1, features_to_unnest[i])
+             for i in range(len(features_to_unnest))])
+        cols_from_src_table = ', '.join(cols_to_keep) + ', ' if cols_to_keep else ''
+ query = """
+ CREATE TABLE {output_table} AS
+ SELECT {cols_from_src_table} {select_new_cols}
+ FROM {source_table}
+ """.format(**locals())
+ plpy.execute(query)
+
+ def vec2cols_help_message(self, schema_madlib, message, **kwargs):
+ """
+ Help message for vec2cols function
+ """
+ summary_string = """
+-----------------------------------------------------------------------------------
+ SUMMARY
+-----------------------------------------------------------------------------------
+Functionality: Vector to Columns
+
+The MADlib vec2cols function enables the user to split up a single column
+into multiple columns, given that the input column contains array entries.
+For example, if the input column contained ARRAY[1, 2, 3] in one of its
+rows, the output table will contain 3 different columns, one for each
+element of the array.
+
+For more details on function usage:
+ SELECT {schema_madlib}.vec2cols('usage');
+
+For a small example on using the function:
+ SELECT {schema_madlib}.vec2cols('example');
+ """.format(schema_madlib=schema_madlib)
+
+ usage_string = """
+-----------------------------------------------------------------------------------
+ USAGE
+-----------------------------------------------------------------------------------
+SELECT {schema_madlib}.vec2cols(
+    'source_table',  -- str, Name of the source table that contains the data
+    'output_table',  -- str, Name of the output view or table
+    'vector_col',    -- str, Name of the array entry column to be split
+    'feature_names', -- array, Optional parameter to provide a text array of
+                     -- the feature names for the newly split columns (if not
+                     -- provided, default names f1, f2, ... will be used)
+    'cols_to_output' -- str, Optional parameter to specify any other columns
+                     -- in the source_table to include in the output_table
+                     -- (default none of them, also supports '*' as input)
+ """.format(schema_madlib=schema_madlib)
+
+ example_string = """
+-----------------------------------------------------------------------------------
+ EXAMPLE
+-----------------------------------------------------------------------------------
+-- Create an input data set:
+
+DROP TABLE IF EXISTS golf CASCADE;
+CREATE TABLE golf (
+ id integer NOT NULL,
+ "OUTLOOK" text,
+ temperature double precision,
+ humidity double precision,
+ "Temp_Humidity" double precision[],
+ clouds_airquality text[],
+ windy boolean,
+ class text,
+ observation_weight double precision
+);
+INSERT INTO golf VALUES
+(1, 'sunny', 85, 85, ARRAY[85, 85], ARRAY['none', 'unhealthy'], 'false', 'Don''t Play', 5.0),
+(2, 'sunny', 80, 90, ARRAY[80, 90], ARRAY['none', 'moderate'], 'true', 'Don''t Play', 5.0),
+(3, 'overcast', 83, 78, ARRAY[83, 78], ARRAY['low', 'moderate'], 'false', 'Play', 1.5),
+(4, 'rain', 70, 96, ARRAY[70, 96], ARRAY['low', 'moderate'], 'false', 'Play', 1.0),
+(5, 'rain', 68, 80, ARRAY[68, 80], ARRAY['medium', 'good'], 'false', 'Play', 1.0),
+(6, 'rain', 65, 70, ARRAY[65, 70], ARRAY['low', 'unhealthy'], 'true', 'Don''t Play', 1.0),
+(7, 'overcast', 64, 65, ARRAY[64, 65], ARRAY['medium', 'moderate'], 'true', 'Play', 1.5),
+(8, 'sunny', 72, 95, ARRAY[72, 95], ARRAY['high', 'unhealthy'], 'false', 'Don''t Play', 5.0),
+(9, 'sunny', 69, 70, ARRAY[69, 70], ARRAY['high', 'good'], 'false', 'Play', 5.0),
+(10, 'rain', 75, 80, ARRAY[75, 80], ARRAY['medium', 'good'], 'false', 'Play', 1.0),
+(11, 'sunny', 75, 70, ARRAY[75, 70], ARRAY['none', 'good'], 'true', 'Play', 5.0),
+(12, 'overcast', 72, 90, ARRAY[72, 90], ARRAY['medium', 'moderate'], 'true', 'Play', 1.5),
+(13, 'overcast', 81, 75, ARRAY[81, 75], ARRAY['medium', 'moderate'], 'false', 'Play', 1.5),
+(14, 'rain', 71, 80, ARRAY[71, 80], ARRAY['low', 'unhealthy'], 'true', 'Don''t Play', 1.0);
+
+-- Call the vec2cols function on the 'clouds_airquality' column, to split it up
+
+DROP TABLE IF EXISTS output_table;
+SELECT {schema_madlib}.vec2cols(
+    'golf',              -- source table
+    'output_table',      -- output table
+    'clouds_airquality', -- column with array entries to split
+    ARRAY['a', 'b'],     -- feature_names array (will use 'a' to name the
+                         -- first new column, and 'b' for the second)
+    '"OUTLOOK", id'      -- columns to keep from source table (as a
+                         -- comma-separated list)
+);
+
+SELECT * FROM output_table ORDER BY id;
+ OUTLOOK | id | a | b
+----------+----+--------+-----------
+ sunny | 1 | none | unhealthy
+ sunny | 2 | none | moderate
+ overcast | 3 | low | moderate
+ rain | 4 | low | moderate
+ rain | 5 | medium | good
+ rain | 6 | low | unhealthy
+ overcast | 7 | medium | moderate
+ sunny | 8 | high | unhealthy
+ sunny | 9 | high | good
+ rain | 10 | medium | good
+ sunny | 11 | none | good
+ overcast | 12 | medium | moderate
+ overcast | 13 | medium | moderate
+ rain | 14 | low | unhealthy
+(14 rows)
+""".format(schema_madlib=schema_madlib)
+
+ if not message:
+ return summary_string
+ elif message.lower() in ('usage', 'help', '?'):
+ return usage_string
+ elif message.lower() in ('example', 'examples'):
+ return example_string
+ else:
+ return """
+No such option. Use "SELECT {schema_madlib}.vec2cols()" for help.
+ """.format(schema_madlib=schema_madlib)
+
+class cols2vec:
+ def __init__(self):
+ self.get_cols_helper = vec_cols_helper()
+ self.module_name = self.__class__.__name__
+
+    def validate_args(self, source_table, output_table, list_of_features,
+                      list_of_features_to_exclude, cols_to_output):
+ """
+ Function to validate input parameters
+ """
+ input_tbl_valid(source_table, self.module_name)
+ output_tbl_valid(output_table, self.module_name)
+
+        _assert(list_of_features and list_of_features.strip(),
+                "{0}: List of features cannot be empty".format(self.module_name))
+ if list_of_features.strip() != '*':
+ is_var_valid(source_table, list_of_features)
+
+ if list_of_features_to_exclude:
+ if list_of_features_to_exclude.strip() == "*":
+ plpy.error("{0}: Cannot exclude all columns from being "
+ "features".format(self.module_name))
+ elif list_of_features.strip() != '*':
+                plpy.info("{0} NOTICE: will exclude given column(s) even though "
+                          "list of features was not *".format(self.module_name))
+
+ is_var_valid(source_table, list_of_features_to_exclude)
+ is_var_valid(source_table, cols_to_output)
+
+    def get_and_validate_feature_types(self, source_table, features_to_nest):
+ """
+            This function will validate and return the appropriate type to
+            cast the final SQL array. Will fail if feature types do not
+            belong to the same group (numeric, text, etc.) or if any type
+            is an array.
+
+            If all features to nest are of the same type, we just return
+            that type and do not cast. Else, if they only constitute
+            integers and smallints, we cast everything to an integer. Else,
+            if they are all part of the same type group, we return the most
+            comprehensive type in that group. Else, throw an error.
+ """
+ all_cols_and_types = get_cols_and_types(source_table)
+        distinct_types = set([col_type[1] for col_type in all_cols_and_types
+                              if col_type[0] in features_to_nest])
+ for expr_type in distinct_types:
+ _assert(not is_valid_psql_type(expr_type, ANY_ARRAY),
+ "{0}: Feature columns to nest cannot be of type array"
+ .format(self.module_name))
+ if len(distinct_types) > 1:
+ if distinct_types == {'integer', 'smallint'}:
--- End diff --
Does allowing the underlying platform to make this decision lead to unknown
consequences?
Wouldn't PostgreSQL be better suited to make these decisions and throw an
error if the columns can't be nested into an array?
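
For reference, a minimal psql sketch of the behavior being suggested (the
table and column names below are hypothetical, not from this patch):
PostgreSQL's own type resolution for an ARRAY constructor already promotes
compatible types and rejects incompatible ones.

    -- integer and smallint columns resolve to integer[] with no explicit cast
    SELECT ARRAY[int_col, smallint_col] FROM some_table;

    -- mixing type categories fails with an error along the lines of
    -- "ARRAY types double precision and text cannot be matched"
    SELECT ARRAY[num_col, text_col] FROM some_table;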
---