GitHub user iyerr3 commented on a diff in the pull request:
https://github.com/apache/madlib/pull/291#discussion_r201878219
--- Diff: src/ports/postgres/modules/utilities/vec2cols.py_in ---
@@ -0,0 +1,266 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+from control import MinWarning
+from internal.db_utils import is_col_1d_array
+from utilities import _assert
+from utilities import ANY_ARRAY
+from utilities import is_valid_psql_type
+from utilities import py_list_to_sql_string
+from utilities import split_quoted_delimited_str
+from validate_args import cols_in_tbl_valid
+from validate_args import get_cols
+from validate_args import get_expr_type
+from validate_args import input_tbl_valid
+from validate_args import output_tbl_valid
+from validate_args import table_exists
+
def get_cols_to_keep_list(cols_to_output, source_table=None, vector_col=None):
    """
    Translate the user-supplied cols_to_output argument into a list of
    column names.

    Args:
        @param cols_to_output: str, either '*' or a comma separated list of
                               column names. None or '' means no columns.
        @param source_table: str, optional. Table whose columns are listed
                             when cols_to_output is '*'.
        @param vector_col: str, optional. Column left out of the '*'
                           expansion.

    Returns:
        A list of column names. The list is empty when cols_to_output is
        empty/None, or when it is '*' but source_table or vector_col is
        not supplied.
    """
    if not cols_to_output:
        return []

    wants_all_cols = cols_to_output.strip() == "*"
    if not wants_all_cols:
        # Explicit list of names: delegate splitting to the shared helper.
        return split_quoted_delimited_str(cols_to_output)

    # '*' can only be expanded when both the table to read columns from and
    # the column to exclude are known; otherwise nothing is returned.
    if not (source_table and vector_col):
        return []

    table_cols = get_cols(source_table)
    return [name for name in table_cols if name != vector_col]
+
def validate_args(source_table, out_table, vector_col, feature_names,
                  cols_to_output):
    """
    Validate the arguments passed to vec2cols.

    Performs these checks in order, each delegated to a helper:
      1. source_table is a valid input table (input_tbl_valid).
      2. out_table is a valid output table name (output_tbl_valid).
      3. Every explicitly named column in cols_to_output, plus vector_col,
         exists in source_table (cols_in_tbl_valid).
      4. vector_col has an array type (_assert on is_valid_psql_type).
      5. vector_col is a 1-dimensional array (_assert on is_col_1d_array).

    NOTE(review): feature_names is accepted but not checked here. Confirm
    whether it should be validated.
    """
    module_name = 'vec2cols'
    input_tbl_valid(source_table, module_name)
    output_tbl_valid(out_table, module_name)

    # No source_table/vector_col is passed, so '*' expands to [] and only
    # explicitly named columns (plus vector_col) are existence-checked.
    explicit_cols = get_cols_to_keep_list(cols_to_output)
    cols_in_tbl_valid(source_table, explicit_cols + [vector_col], module_name)

    vector_type = get_expr_type(vector_col, source_table)
    _assert(is_valid_psql_type(vector_type, ANY_ARRAY),
            "vec2cols: vector_col should refer to an array.")
    _assert(is_col_1d_array(source_table, vector_col),
            "vec2cols: vector_col must be a 1-dimensional array.")
+
+def get_names_for_split_output_cols(source_table, vector_col,
feature_names):
+ """
+ Get list of names for the newly-split columns to include in the
+ output table.
+ Args:
+ @param: source_table, str. Source table
+ @param: vector_col, str. Column name containing the array input
+ @param: feature_names, list. Python list of the feature names to
+ use for the split elements in the vector_col array
+ """
+ query = """
--- End diff --
I'm assuming this was meant to use the `is_col_1d_array` function?
---