Repository: incubator-madlib
Updated Branches:
  refs/heads/master 8e2778a39 -> 18b8486ca
DT/RF: Allow expressions in feature list JIRA: MADLIB-1087 Changes: - Add numeric as a continuous type - Get data type of features from an expression instead of the table column names - Update to allow expressions in the feature list Closes #129 Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/18b8486c Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/18b8486c Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/18b8486c Branch: refs/heads/master Commit: 18b8486ca651a218ea4b84c4b2876e30cd189e33 Parents: 8e2778a Author: Rahul Iyer <ri...@apache.org> Authored: Tue May 2 12:39:52 2017 -0700 Committer: Rahul Iyer <ri...@apache.org> Committed: Wed May 10 15:56:14 2017 -0700 ---------------------------------------------------------------------- .../recursive_partitioning/decision_tree.py_in | 37 ++++++++------------ .../recursive_partitioning/random_forest.py_in | 36 ++++++++++--------- .../test/random_forest.sql_in | 33 ++++++++--------- .../modules/utilities/validate_args.py_in | 2 +- 4 files changed, 52 insertions(+), 56 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/18b8486c/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in index dbf7db7..d3ca9b2 100644 --- a/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in +++ b/src/ports/postgres/modules/recursive_partitioning/decision_tree.py_in @@ -137,13 +137,6 @@ def _get_features_to_use(schema_madlib, training_table_name, # ------------------------------------------------------------ -def _dict_get_quoted(input_dict, col_name): - """Return value 
from dict where key could be quoted or unquoted name""" - return input_dict.get( - col_name, input_dict.get(unquote_ident(col_name))) -# ------------------------------------------------------------------------- - - def _classify_features(feature_to_type, features): """ Returns 1) an array of categorical features (all casted to string) @@ -157,17 +150,16 @@ def _classify_features(feature_to_type, features): cat_types = int_types + text_types + boolean_types ordered_cat_types = int_types - cat_features = [c for c in features - if _dict_get_quoted(feature_to_type, c) in cat_types] - ordered_cat_features = [c for c in features if _dict_get_quoted( - feature_to_type, c) in ordered_cat_types] + cat_features = [c for c in features if feature_to_type[c] in cat_types] + ordered_cat_features = [c for c in features + if feature_to_type[c] in ordered_cat_types] cat_features_set = set(cat_features) # continuous types - 'real' is cast to 'double precision' for uniformity - con_types = ['real', 'float8', 'double precision'] + con_types = ['real', 'float8', 'double precision', 'numeric'] con_features = [c for c in features if (c not in cat_features_set and - _dict_get_quoted(feature_to_type, c) in con_types)] + feature_to_type[c] in con_types)] # In order to be able to form an array, all categorical variables # will be cast into TEXT type, but GPDB cannot cast a boolean @@ -175,7 +167,7 @@ def _classify_features(feature_to_type, features): # need special treatment: cast them into integers before casting # into text. 
boolean_cats = [c for c in features - if _dict_get_quoted(feature_to_type, c) in boolean_types] + if feature_to_type[c] in boolean_types] return cat_features, ordered_cat_features, con_features, boolean_cats # ------------------------------------------------------------ @@ -481,13 +473,12 @@ def get_grouping_array_str(table_name, grouping_cols, qualifier=None): else: qualifier_str = '' - all_cols_types = dict(get_cols_and_types(table_name)) grouping_cols_list = [col.strip() for col in grouping_cols.split(',')] - grouping_cols_and_types = [(col, _dict_get_quoted(all_cols_types, col)) - for col in grouping_cols_list] + grouping_cols_and_types = [(c, get_expr_type(c, table_name)) + for c in grouping_cols_list] grouping_array_str = 'array_to_string(array[' + \ ','.join("(case when " + col + " then 'True' else 'False' end)::text" - if col_type == 'boolean' else '(' + qualifier_str + col + ')::text' + if col_type.lower() == 'boolean' else '(' + qualifier_str + col + ')::text' for col, col_type in grouping_cols_and_types) + "]::text[], ',')" return grouping_array_str # ------------------------------------------------------------------------------ @@ -519,7 +510,9 @@ def _build_tree(schema_madlib, is_classification, split_criterion, with MinWarning(msg_level): plpy.notice("Building tree for cross validation") tree_states, bins, dep_list, n_rows = _get_tree_states(**locals()) - all_cols_types = dict(get_cols_and_types(training_table_name)) + all_cols_types = dict([(f, get_expr_type(f, training_table_name)) + for f in cat_features + con_features]) + n_all_rows = plpy.execute("select count(*) from " + training_table_name )[0]['count'] cp = grp_key_to_cp.values()[0] @@ -604,7 +597,8 @@ def tree_train(schema_madlib, training_table_name, output_table_name, "Decision tree error: No feature is selected for the model.") # 2) - all_cols_types = dict(get_cols_and_types(training_table_name)) + all_cols_types = dict([(f, get_expr_type(f, training_table_name)) + for f in features]) 
cat_features, ordered_cat_features, con_features, boolean_cats = \ _classify_features(all_cols_types, features) # get all rows @@ -1529,8 +1523,7 @@ def _create_summary_table( "$dep_list$") else: dep_list_str = "NULL" - indep_type = ', '.join(_dict_get_quoted(all_cols_types, col) - for col in cat_features + con_features) + indep_type = ', '.join(all_cols_types[c] for c in cat_features + con_features) dep_type = _get_dep_type(training_table_name, dependent_variable) cat_features_str = ','.join(cat_features) con_features_str = ','.join(con_features) http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/18b8486c/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in index 1226591..4b6f2d6 100644 --- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in +++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in @@ -13,16 +13,17 @@ from math import sqrt from utilities.control import MinWarning from utilities.control import EnableOptimizer from utilities.control import EnableHashagg -from utilities.validate_args import get_cols_and_types -from utilities.validate_args import is_var_valid -from utilities.validate_args import input_tbl_valid -from utilities.validate_args import output_tbl_valid -from utilities.validate_args import cols_in_tbl_valid from utilities.utilities import _assert from utilities.utilities import unique_string from utilities.utilities import add_postfix from utilities.utilities import split_quoted_delimited_str from utilities.utilities import extract_keyvalue_params +from utilities.validate_args import get_cols_and_types +from utilities.validate_args import is_var_valid +from utilities.validate_args import input_tbl_valid +from utilities.validate_args import 
output_tbl_valid +from utilities.validate_args import cols_in_tbl_valid +from utilities.validate_args import get_expr_type from decision_tree import _tree_train_using_bins from decision_tree import _tree_train_grps_using_bins @@ -34,7 +35,6 @@ from decision_tree import _is_dep_categorical from decision_tree import _get_n_and_deplist from decision_tree import _classify_features from decision_tree import _get_filter_str -from decision_tree import _dict_get_quoted from decision_tree import _get_display_header from decision_tree import get_feature_str # ------------------------------------------------------------ @@ -272,8 +272,7 @@ def forest_train( with EnableHashagg(False): # we disable hashagg since large number of groups could # result in excessive memory usage. - ################################################################## - #### set default values + # set default values if grouping_cols is not None and grouping_cols.strip() == '': grouping_cols = None num_trees = 100 if num_trees is None else num_trees @@ -328,7 +327,8 @@ def forest_train( "Random forest error: Number of features to be selected " "is more than the actual number of features.") - all_cols_types = dict(get_cols_and_types(training_table_name)) + all_cols_types = dict([(f, get_expr_type(f, training_table_name)) + for f in features]) cat_features, ordered_cat_features, con_features, boolean_cats = \ _classify_features(all_cols_types, features) @@ -349,10 +349,10 @@ def forest_train( dep_col_str = ("CASE WHEN " + dependent_variable + " THEN 'True' ELSE 'False' END") if is_bool else dependent_variable dep = ("(CASE " + - "\n ".join([ - "WHEN ({dep_col})::text = $${c}$$ THEN {i}".format( - dep_col=dep_col_str, c=c, i=i) - for i, c in enumerate(dep_list)]) + + "\n ". + join(["WHEN ({dep_col})::text = $${c}$$ THEN {i}". 
+ format(dep_col=dep_col_str, c=c, i=i) + for i, c in enumerate(dep_list)]) + "\nEND)") dep_n_levels = len(dep_list) else: @@ -388,12 +388,12 @@ def forest_train( bins['grp_key_cat'] = [''] else: grouping_cols_list = [col.strip() for col in grouping_cols.split(',')] - grouping_cols_and_types = [(col, _dict_get_quoted(all_cols_types, col)) + grouping_cols_and_types = [(col, get_expr_type(col, training_table_name)) for col in grouping_cols_list] grouping_array_str = ( "array_to_string(array[" + ','.join("(case when " + col + " then 'True' else 'False' end)::text" - if col_type == 'boolean' else '(' + col + ')::text' + if col_type.lower() == 'boolean' else '(' + col + ')::text' for col, col_type in grouping_cols_and_types) + "], ',')") grouping_cols_str = ('' if grouping_cols is None @@ -1198,8 +1198,9 @@ def _create_summary_table(**kwargs): else: kwargs['dep_list_str'] = "NULL" - kwargs['indep_type'] = ', '.join(_dict_get_quoted(kwargs['all_cols_types'], col) - for col in kwargs['cat_features'] + kwargs['con_features']) + kwargs['indep_type'] = ', '.join(kwargs['all_cols_types'][col] + for col in kwargs['cat_features'] + + kwargs['con_features']) kwargs['dep_type'] = _get_dep_type(kwargs['training_table_name'], kwargs['dependent_variable']) kwargs['cat_features_str'] = ','.join(kwargs['cat_features']) @@ -1408,6 +1409,7 @@ def _validate_get_tree(model, gid, sample_id): # ------------------------------------------------------------ + def forest_predict_help_message(schema_madlib, message, **kwargs): if not message: help_string = """ http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/18b8486c/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in index e91c71a..086e74b 100644 --- 
a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in +++ b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in @@ -4,25 +4,26 @@ CREATE TABLE dt_golf ( "OUTLOOK" text, temperature double precision, humidity double precision, + cont_features double precision[], windy boolean, class text ) ; -INSERT INTO dt_golf (id,"OUTLOOK",temperature,humidity,windy,class) VALUES -(1, 'sunny', 85, 85, false, 'Don''t Play'), -(2, 'sunny', 80, 90, true, 'Don''t Play'), -(3, 'overcast', 83, 78, false, 'Play'), -(4, 'rain', 70, 96, false, 'Play'), -(5, 'rain', 68, 80, false, 'Play'), -(6, 'rain', 65, 70, true, 'Don''t Play'), -(7, 'overcast', 64, 65, true, 'Play'), -(8, 'sunny', 72, 95, false, 'Don''t Play'), -(9, 'sunny', 69, 70, false, 'Play'), -(10, 'rain', 75, 80, false, 'Play'), -(11, 'sunny', 75, 70, true, 'Play'), -(12, 'overcast', 72, 90, true, 'Play'), -(13, 'overcast', 81, 75, false, 'Play'), -(14, 'rain', 71, 80, true, 'Don''t Play'); +INSERT INTO dt_golf (id,"OUTLOOK",temperature,humidity,cont_features,windy,class) VALUES +(1, 'sunny', 85, 85,ARRAY[85, 85], false, 'Don''t Play'), +(2, 'sunny', 80, 90,ARRAY[80, 90], true, 'Don''t Play'), +(3, 'overcast', 83, 78,ARRAY[83, 78], false, 'Play'), +(4, 'rain', 70, 96,ARRAY[70, 96], false, 'Play'), +(5, 'rain', 68, 80,ARRAY[68, 80], false, 'Play'), +(6, 'rain', 65, 70,ARRAY[65, 70], true, 'Don''t Play'), +(7, 'overcast', 64, 65,ARRAY[64, 65], true, 'Play'), +(8, 'sunny', 72, 95,ARRAY[72, 95], false, 'Don''t Play'), +(9, 'sunny', 69, 70,ARRAY[69, 70], false, 'Play'), +(10, 'rain', 75, 80,ARRAY[75, 80], false, 'Play'), +(11, 'sunny', 75, 70,ARRAY[75, 70], true, 'Play'), +(12, 'overcast', 72, 90,ARRAY[72, 90], true, 'Play'), +(13, 'overcast', 81, 75,ARRAY[81, 75], false, 'Play'), +(14, 'rain', 71, 80,ARRAY[71, 80], true, 'Don''t Play'); ------------------------------------------------------------------------- DROP TABLE IF EXISTS train_output, train_output_summary, train_output_group; 
@@ -31,7 +32,7 @@ SELECT forest_train( 'train_output'::TEXT, -- output model table 'id'::TEXT, -- id column 'class'::TEXT, -- response - 'windy, temperature'::TEXT, -- features + 'windy, cont_features[1]'::TEXT, -- features NULL::TEXT, -- exclude columns NULL::TEXT, -- no grouping 5, -- num of trees http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/18b8486c/src/ports/postgres/modules/utilities/validate_args.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in index 5832124..532326d 100644 --- a/src/ports/postgres/modules/utilities/validate_args.py_in +++ b/src/ports/postgres/modules/utilities/validate_args.py_in @@ -359,7 +359,7 @@ def get_expr_type(expr, tbl): FROM {1} LIMIT 1 """.format(expr, tbl))[0]['type'] - return expr_type.upper() + return expr_type.lower() # -------------------------------------------------------------------------