Repository: madlib
Updated Branches:
  refs/heads/master 20f95b33b -> e2534e44e


DT/RF: Don't eliminate single-level cat variable

JIRA: MADLIB-1258

When DT/RF is run with grouping, a subset of the groups could eliminate a
categorical variable, leading to multiple issues downstream, including
invalid importance values and incorrect prediction. This commit keeps all
categorical variables (even if a variable contains just one level). The
accumulator state uses additional space during tree_train for such a
categorical variable, even though the variable is never consumed by the
tree. This inefficiency is still preferred since it yields clean code and
error-free prediction/importance reporting.

Additional changes:
- get_expr_type (validate_args.py) has been updated to return the types of
  multiple expressions at the same time. This avoids issuing a separate
  query for each expression, thus saving time. A usage sketch follows the
  validate_args.py_in diff at the end of this message.
- Cat features are no longer stored per tree (in the grouping case), since
  the features are now consistent across trees.

Closes #301

Co-authored-by: Nandish Jayaram <njaya...@apache.org>
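For context, the failure mode can be sketched with hypothetical data (the
snippet below is illustrative Python, not MADlib code): computing levels per
group and dropping single-level columns leaves different categorical feature
lists in different groups.

    rows = [('g1', 'red'), ('g1', 'blue'),   # two levels in group g1
            ('g2', 'red'), ('g2', 'red')]    # a single level in group g2

    levels = {}
    for grp, color in rows:
        levels.setdefault(grp, set()).add(color)

    # Old behavior: a categorical column with one level in a group was
    # dropped for that group only.
    old_cat_features = {grp: (['color'] if len(lvls) > 1 else [])
                        for grp, lvls in levels.items()}
    assert old_cat_features == {'g1': ['color'], 'g2': []}   # inconsistent

    # New behavior: 'color' is kept for every group, so importance vectors
    # and prediction expressions line up across groups.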
Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/e2534e44
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/e2534e44
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/e2534e44

Branch: refs/heads/master
Commit: e2534e44ea36aedec843a3a7c48236d0e1104e2c
Parents: 20f95b3
Author: Rahul Iyer <ri...@apache.org>
Authored: Thu Jul 26 12:17:58 2018 -0700
Committer: Rahul Iyer <ri...@apache.org>
Committed: Wed Aug 1 12:51:13 2018 -0700

----------------------------------------------------------------------
 src/modules/recursive_partitioning/DT_impl.hpp  |  91 ++++----
 .../recursive_partitioning/decision_tree.cpp    |  21 +-
 .../recursive_partitioning/decision_tree.py_in  | 217 +++++++++----------
 .../recursive_partitioning/random_forest.py_in  | 120 +++++-----
 .../test/decision_tree.sql_in                   |  83 +++----
 .../test/random_forest.sql_in                   |  46 ++--
 .../modules/utilities/validate_args.py_in       |  49 +++--
 7 files changed, 319 insertions(+), 308 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/e2534e44/src/modules/recursive_partitioning/DT_impl.hpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/DT_impl.hpp b/src/modules/recursive_partitioning/DT_impl.hpp
index 69bdc88..75e4ce4 100644
--- a/src/modules/recursive_partitioning/DT_impl.hpp
+++ b/src/modules/recursive_partitioning/DT_impl.hpp
@@ -518,6 +518,7 @@ DecisionTree<Container>::expand(const Accumulator &state,
                 double gain = impurityGain(
                     state.cat_stats.row(stats_i).
                         segment(fv_index, sps * 2), sps);
+
                 if (gain > max_impurity_gain){
                     max_impurity_gain = gain;
                     max_feat = f;
@@ -665,21 +666,29 @@ DecisionTree<Container>::pickSurrogates(
     // 1. Compute the max count and corresponding split threshold for
     //      each categorical and continuous feature
+
     ColumnVector cat_max_thres = ColumnVector::Zero(n_cats);
     ColumnVector cat_max_count = ColumnVector::Zero(n_cats);
     IntegerVector cat_max_is_reverse = IntegerVector::Zero(n_cats);
     Index prev_cum_levels = 0;
     for (Index each_cat=0; each_cat < n_cats; each_cat++){
         Index n_levels = state.cat_levels_cumsum(each_cat) - prev_cum_levels;
-        Index max_label;
-        (cat_stats_counts.row(stats_i).segment(
-            prev_cum_levels * 2, n_levels * 2)).maxCoeff(&max_label);
-        cat_max_thres(each_cat) = static_cast<double>(max_label / 2);
-        cat_max_count(each_cat) =
-            cat_stats_counts(stats_i, prev_cum_levels*2 + max_label);
-        // every odd col is for reverse, hence i % 2 == 1 for reverse index i
-        cat_max_is_reverse(each_cat) = (max_label % 2 == 1) ? 1 : 0;
-        prev_cum_levels = state.cat_levels_cumsum(each_cat);
+        if (n_levels > 0){
+            Index max_label;
+            (cat_stats_counts.row(stats_i).segment(
+                prev_cum_levels * 2, n_levels * 2)).maxCoeff(&max_label);
+
+            // For each split, there are two stats =>
+            //   max_label / 2 gives the split index. A floor
+            //   operation is unnecessary since the threshold will yield
+            //   the same results for n and n+0.5.
+            cat_max_thres(each_cat) = static_cast<double>(max_label / 2);
+            cat_max_count(each_cat) =
+                cat_stats_counts(stats_i, prev_cum_levels*2 + max_label);
+            // every odd col is for reverse, hence i % 2 == 1 for reverse index i
+            cat_max_is_reverse(each_cat) = max_label % 2;
+            prev_cum_levels = state.cat_levels_cumsum(each_cat);
+        }
     }

     ColumnVector con_max_thres = ColumnVector::Zero(n_cons);
@@ -800,7 +809,7 @@ DecisionTree<Container>::expand_by_sampling(const Accumulator &state,
     std::random_shuffle(cat_con_feature_indices,
                         cat_con_feature_indices + total_cat_con_features, rvt);
     // if a leaf node exists, compute the gain in impurity for each split
-    // pick split with maximum gain and update node with split value
+    // pick split with maximum gain and update node with split value
     int max_feat = -1;
     Index max_bin = -1;
     bool max_is_cat = false;
@@ -831,7 +840,6 @@ DecisionTree<Container>::expand_by_sampling(const Accumulator &state,
                     }
                 }
-
             } else { //f >= state.n_cat.features
                 //continuous feature
                 f -= state.n_cat_features;
@@ -854,40 +862,41 @@ DecisionTree<Container>::expand_by_sampling(const Accumulator &state,
                 }
            }

-            // Create and update child nodes if splitting current
-            uint64_t true_count = statCount(max_stats.segment(0, sps));
-            uint64_t false_count = statCount(max_stats.segment(sps, sps));
-            uint64_t total_count = statCount(predictions.row(current));
-
-            if (max_impurity_gain > 0 &&
-                    shouldSplit(total_count, true_count, false_count,
+            bool is_leaf_split = FALSE;
+            if (max_impurity_gain > 0){
+                // Create and update child nodes if splitting current
+                uint64_t true_count = statCount(max_stats.segment(0, sps));
+                uint64_t false_count = statCount(max_stats.segment(sps, sps));
+                uint64_t total_count = statCount(predictions.row(current));
+                if (shouldSplit(total_count, true_count, false_count,
                                 min_split, min_bucket, max_depth)) {

-                double max_threshold;
-                if (max_is_cat)
-                    max_threshold = static_cast<double>(max_bin);
-                else
-                    max_threshold = con_splits(max_feat, max_bin);
-
-                if (children_not_allocated) {
-                    // allocate the memory for child nodes if not allocated already
-                    incrementInPlace();
-                    children_not_allocated = false;
-                }
-
-                children_wont_split &=
-                    updatePrimarySplit(
-                        current, static_cast<int>(max_feat),
-                        max_threshold, max_is_cat,
-                        min_split,
-                        max_stats.segment(0, sps),   // true_stats
-                        max_stats.segment(sps, sps)  // false_stats
-                        );
+                    is_leaf_split = TRUE;
+                    double max_threshold;
+                    if (max_is_cat)
+                        max_threshold = static_cast<double>(max_bin);
+                    else
+                        max_threshold = con_splits(max_feat, max_bin);
+
+                    if (children_not_allocated) {
+                        // allocate the memory for child nodes if not allocated already
+                        incrementInPlace();
+                        children_not_allocated = false;
+                    }

-            } else {
+                    children_wont_split &=
+                        updatePrimarySplit(
+                            current, static_cast<int>(max_feat),
+                            max_threshold, max_is_cat,
+                            min_split,
+                            max_stats.segment(0, sps),   // true_stats
+                            max_stats.segment(sps, sps)  // false_stats
+                            );
+                }  // if shouldSplit
+            }  // if max_impurity_gain > 0
+            if (not is_leaf_split)
                 feature_indices(current) = FINISHED_LEAF;
-            }
-        } // if leaf exists
+        } // if leaf is in_process
     } // for each leaf

     // return true if tree expansion is finished


http://git-wip-us.apache.org/repos/asf/madlib/blob/e2534e44/src/modules/recursive_partitioning/decision_tree.cpp
----------------------------------------------------------------------
diff --git a/src/modules/recursive_partitioning/decision_tree.cpp b/src/modules/recursive_partitioning/decision_tree.cpp
index 351fced..d249946 100644
--- a/src/modules/recursive_partitioning/decision_tree.cpp
+++ b/src/modules/recursive_partitioning/decision_tree.cpp
@@ -123,19 +123,22 @@ compute_leaf_stats_transition::run(AnyType & args){
         return args[0];
     }

-    // cat_levels size = n_cat_features
+    // cat_levels.size = n_cat_features
     NativeIntegerVector cat_levels;
     if (args[6].isNull()){
         cat_levels.rebind(this->allocateArray<int>(0));
     } else {
-        MutableNativeIntegerVector xx_cat = args[6].getAs<MutableNativeIntegerVector>();
-        for (Index i = 0; i < xx_cat.size(); i++)
-            xx_cat[i] -= 1;  // ignore the last level since a split
-                             // like 'var <= last level' would move all rows to
-                             // a one side. Such a split will always be ignored
-                             // when selecting the best split.
-        cat_levels.rebind(xx_cat.memoryHandle(), xx_cat.size());
+        MutableNativeIntegerVector n_levels_per_cat =
+            args[6].getAs<MutableNativeIntegerVector>();
+        for (Index i = 0; i < n_levels_per_cat.size(); i++){
+            n_levels_per_cat[i] -= 1;
+            // ignore the last level since a split
+            // like 'var <= last level' would move all rows to
+            // one side. Such a split will always be ignored
+            // when selecting the best split.
+        }
+        cat_levels.rebind(n_levels_per_cat.memoryHandle(), n_levels_per_cat.size());
     }

     // con_splits size = num_con_features x num_bins
@@ -203,7 +206,6 @@ compute_leaf_stats_transition::run(AnyType & args){
             current_sum += cat_levels(i);
             state.cat_levels_cumsum(i) = current_sum;
         }
-
     }

     state << MutableLevelState::tuple_type(dt,
                                            cat_features,
                                            con_features,
@@ -257,7 +259,6 @@ dt_apply::run(AnyType & args){
         return_code = TERMINATED;   // indicates termination due to error
     }
-
     AnyType output_tuple;
     output_tuple << dt.storage()
                  << return_code
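The max_label bookkeeping in the C++ hunks above relies on the accumulator
packing two stats columns per candidate split. A small decoding sketch
(illustrative values, assuming the flattened layout
[split0_fwd, split0_rev, split1_fwd, split1_rev, ...]):

    max_label = 5                  # position of the max count in the packed row
    split_index = max_label // 2   # -> 2: integer division, no explicit floor needed
    is_reverse = max_label % 2     # -> 1: odd positions hold the reverse split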
END". + # false = 0, true = 1 + # This order is maintained in dep_list since + # _get_n_and_deplist returns a sorted list + dep_var_str = ("(CASE WHEN {0} THEN 1 ELSE 0 END)". format(dependent_variable)) + else: - dep_col_str = dependent_variable - dep_var_str = ("(CASE " + - "\n\t\t".join(["WHEN ({dep_col})::text = $${c}$$ THEN {i}". - format(dep_col=dep_col_str, c=c, i=i) - for i, c in enumerate(dep_list)]) + - "\nEND)") + dep_var_str = ("(CASE " + + "\n\t\t".join(["WHEN ({0})::text = $${1}$$ THEN {2}". + format(dependent_variable, str(c), i) + for i, c in enumerate(dep_list)]) + + "\nEND)") else: - if split_criterion.lower().strip() != "mse": - plpy.warning("Decision tree: Using MSE as split criterion as it " - "is the only one supported for regression trees.") n_rows = long(plpy.execute("SELECT count(*)::bigint " "FROM {src} " "WHERE {filter}". @@ -268,7 +275,6 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion, dep_list = [] dep_n_levels = len(dep_list) if dep_list else 1 - cat_features_info_table = unique_string() if not grouping_cols: # non-grouping case # 3) Find the splitting bins, one dict containing two arrays: @@ -288,12 +294,9 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion, tree = _tree_train_using_bins(**locals()) tree['grp_key'] = '' tree['cp'] = grp_key_to_cp[tree['grp_key']] - tree['cat_features'] = cat_features - tree['con_features'] = con_features tree_states = [tree] else: grouping_array_str = get_grouping_array_str(training_table_name, grouping_cols) - with OptimizerControl(False): # we disable optimizer (ORCA) for platforms that use it # since ORCA doesn't provide an easy way to disable hashagg @@ -325,8 +328,6 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion, tree['cp'] = grp_key_to_cp.values()[0] else: tree['cp'] = grp_key_to_cp[grp_key] - tree['cat_features'] = bins['grp_to_cat_features'][grp_key] - tree['con_features'] = bins['con_features'] # 5) prune the tree using provided 'cp' value and produce a list of # cp values if cross-validation is required (cp_list = [] if not) @@ -345,9 +346,8 @@ def _get_tree_states(schema_madlib, is_classification, split_criterion, importance_vectors = _compute_var_importance( schema_madlib, tree, - len(tree['cat_features']), len(tree['con_features'])) + len(cat_features), len(con_features)) tree.update(**importance_vectors) - return tree_states, bins, dep_list, n_rows, cat_features_info_table # ------------------------------------------------------------------------- @@ -357,18 +357,18 @@ def get_grouping_array_str(table_name, grouping_cols, qualifier=None): Args: @param grouping_cols: list, List of columns used as grouping columns """ - if qualifier: - qualifier_str = qualifier + "." - else: - qualifier_str = '' + def _col_to_text(col, col_type): + qualifier_str = qualifier + "." 
if qualifier else '' + if is_psql_boolean_type(col_type): + return "(case when {0} then 'true' else 'false' end)::text".format(col) + else: + return '({0}{1})::text'.format(qualifier_str, col) grouping_cols_list = [col.strip() for col in grouping_cols.split(',')] grouping_cols_and_types = [(c, get_expr_type(c, table_name)) for c in grouping_cols_list] - grouping_array_str = 'array_to_string(array[' + \ - ','.join("(case when " + col + " then 'True' else 'False' end)::text" - if col_type.lower() == 'boolean' else '(' + qualifier_str + col + ')::text' - for col, col_type in grouping_cols_and_types) + "]::text[], ',')" + grouping_array_str = "array_to_string(array[{0}]::text[], ',')".format( + ','.join(_col_to_text(col, col_type) for col, col_type in grouping_cols_and_types)) return grouping_array_str # ------------------------------------------------------------------------------ @@ -516,6 +516,7 @@ def tree_train(schema_madlib, training_table_name, output_table_name, is_classification, dep_is_bool = _is_dep_categorical( training_table_name, dependent_variable) + split_criterion = _validate_split_criterion(split_criterion, is_classification) # 4) Build the tree with provided cp value compute_cp_list = (n_folds > 1) @@ -574,11 +575,14 @@ def _get_n_and_deplist(training_table_name, dependent_variable, filter_null): @brief Query the database for the total number of rows and levels of dependent variable if the dependent variable is categorical. + + Note: The deplist is sorted in the array_agg which is necessary to ensure + false = 0 and true = 1 for boolean dependent variable. """ sql = """ SELECT sum(n_rows) as n_rows, - array_agg(dep) as dep + array_agg(dep ORDER BY dep) as dep FROM ( SELECT count(*) as n_rows, @@ -700,7 +704,7 @@ def _get_bins(schema_madlib, training_table_name, # variable levels to integers, and keep this mapping in the memory. if len(cat_features) > 0: if is_classification: - # For classifications + # For classification the dependent variable is encoded as an integer order_fun = ("{madlib}._dst_compute_entropy({dep}, {n})". 
format(madlib=schema_madlib, dep=dependent_variable, @@ -745,18 +749,12 @@ def _get_bins(schema_madlib, training_table_name, {union_null_proxy} ) s ) s1 - WHERE array_upper(levels, 1) > 1 """.format(training_table_name=training_table_name, filter_str=filter_str, union_null_proxy=union_null_proxy) - all_col_expressions = {} - for col in cat_features: - if col in boolean_cats: - all_col_expressions[col] = ("(CASE WHEN " + col + - " THEN 'True' ELSE 'False' END)") - else: - all_col_expressions[col] = col + all_col_expressions = dict(zip(cat_features, explicit_bool_to_text( + training_table_name, cat_features, schema_madlib))) sql_all_cats = ' UNION '.join( sql_cat_levels.format( @@ -765,14 +763,6 @@ def _get_bins(schema_madlib, training_table_name, order_fun=expr if col_name in ordered_cat_features else order_fun) for col_name, expr in all_col_expressions.items()) all_levels = plpy.execute(sql_all_cats) - - if len(all_levels) != len(cat_features): - use_cat_features = [row['colname'] for row in all_levels] - cat_features = [feature for feature in cat_features - if feature in use_cat_features] - plpy.warning("Decision tree warning: Categorical columns with only " - "one value are dropped from the tree model.") - col_to_row = dict((row['colname'], i) for i, row in enumerate(all_levels)) return dict( @@ -992,21 +982,10 @@ def _get_bins_grps( ) s GROUP BY grp_key ) s1 - where array_upper(levels, 1) > 1 """.format(**locals()) - all_col_expressions = {} - for col in cat_features: - if col in boolean_cats: - all_col_expressions[col] = ("(CASE WHEN " + col + - " THEN 'True' ELSE 'False' END)") - else: - if null_proxy is not None: - all_col_expressions[col] = ("COALESCE({0}::TEXT, '{1}')". - format(col, null_proxy)) - else: - all_col_expressions[col] = col - + all_col_expressions = dict(zip(cat_features, explicit_bool_to_text( + training_table_name, cat_features, schema_madlib))) sql_all_cats = ' UNION ALL '.join( sql_cat_levels.format( col=expr, @@ -1017,13 +996,6 @@ def _get_bins_grps( all_levels = list(plpy.execute(sql_all_cats)) all_levels.sort(key=itemgetter('grp_key')) - use_cat_features = set([row['colname'] for row in all_levels]) - if len(use_cat_features) != len(cat_features): - plpy.warning("Decision tree warning: Categorical columns with only " - "one value are dropped from the tree model.") - cat_features = [feature for feature in cat_features - if feature in use_cat_features] - # grp_col_to_levels is a list of tuples (pairs) with # first value = group value, # second value = a dict mapping a categorical column to its levels in data @@ -1079,7 +1051,7 @@ def _get_bins_grps( grp_key_cat=grp_key_cat, grouping_array_str=grouping_array_str, grp_to_col_to_levels=grp_to_col_to_levels, - grp_to_cat_features=grp_to_cat_features) + ) # ------------------------------------------------------------ @@ -1115,7 +1087,9 @@ def _create_cat_features_info_table(cat_features_info_table, bins): cat_levels = [quote_literal(each_level) for sublist in cat_levels for each_level in sublist] - cat_levels_str = py_list_to_sql_string(cat_levels, 'text', long_format=True) + cat_levels_str = py_list_to_sql_string(cat_levels, + 'text', + long_format=True) else: # this is the case if no categorical features present cat_names_str = cat_n_levels_str = cat_levels_str = "NULL" @@ -1151,7 +1125,7 @@ def _create_cat_features_info_table(cat_features_info_table, bins): # ------------------------------------------------------------------------------ -def get_feature_str(schema_madlib, boolean_cats, +def 
get_feature_str(schema_madlib, source_table, cat_features, con_features, levels_str, n_levels_str, null_proxy=None): @@ -1163,14 +1137,16 @@ def get_feature_str(schema_madlib, boolean_cats, # (1 to N). The unique value will be mapped to -1 indicating an # unknown/missing value in the underlying layers. null_val = unique_string() if null_proxy is None else null_proxy + + # Cast boolean column to text: requires a special cast expression for + # platforms where __HAS_BOOL_TO_TEXT_CAST__ is not enabled + patched_cat_features = explicit_bool_to_text(source_table, + cat_features, + schema_madlib) cat_features_cast = [] - for col in cat_features: - if col in boolean_cats: - cat_features_cast.append( - "(CASE WHEN " + col + " THEN 'True' ELSE 'False' END)::text") - else: - cat_features_cast.append( - "(coalesce({0}::text, '{1}'))::text".format(col, null_val)) + for col in patched_cat_features: + cat_features_cast.append( + "(coalesce(({0})::text, '{1}'))::text".format(col, null_val)) cat_features_str = ("{0}._map_catlevel_to_int(array[" + ", ".join(cat_features_cast) + "], {1}, {2}, {3})" @@ -1207,7 +1183,7 @@ def _one_step(schema_madlib, training_table_name, cat_features, # XXX cat_feature_str contains $5 and $2, and a SQL function bytea8 = schema_madlib + '.bytea8' cat_features_str, con_features_str = get_feature_str(schema_madlib, - boolean_cats, + training_table_name, cat_features, con_features, "$3", "$2", @@ -1303,7 +1279,7 @@ def _one_step_for_grps( cat_levels_in_text = unique_string() cat_features_str, con_features_str = get_feature_str( - schema_madlib, boolean_cats, cat_features, con_features, + schema_madlib, training_table_name, cat_features, con_features, cat_levels_in_text, cat_n_levels, null_proxy) train_apply_func = """ @@ -1578,14 +1554,19 @@ def _create_summary_table( output_table_summary = add_postfix(output_table_name, "_summary") # dependent variables + dep_type = _get_dep_type(training_table_name, dependent_variable) if dep_list: - dep_list_str = ("$dep_list$" + - ','.join('"{0}"'.format(str(dep)) for dep in dep_list) + - "$dep_list$") + if is_psql_boolean_type(dep_type): + # Special handling for boolean since Python booleans start with + # capitals (i.e False instead of false) + # Note: dep_list is sorted, hence 'false' will be first + dep_list_str = "'false, true'" + else: + dep_list_str = '$__dep_list__${0}$__dep_list__$'.format( + ','.join(map(str, dep_list))) else: dep_list_str = "NULL" indep_type = ', '.join(all_cols_types[c] for c in cat_features + con_features) - dep_type = _get_dep_type(training_table_name, dependent_variable) independent_varnames = ','.join(cat_features + con_features) cat_features_str = ','.join(cat_features) con_features_str = ','.join(con_features) @@ -1626,8 +1607,8 @@ def _create_summary_table( {n_rows_skipped}::integer AS total_rows_skipped, {dep_list_str}::text AS dependent_var_levels, '{dep_type}'::text AS dependent_var_type, - {cp_str} AS input_cp, '{indep_type}'::text AS independent_var_types, + {cp_str} AS input_cp, {n_folds}::integer AS n_folds, {null_proxy_str}::text AS null_proxy """.format(**locals()) @@ -1704,7 +1685,6 @@ def tree_predict(schema_madlib, model, source, output, pred_type='response', 'prob' gives the probability of the classes in a classification tree. For regression tree, only type='response' is defined. 
- Returns: None @@ -1728,9 +1708,9 @@ def tree_predict(schema_madlib, model, source, output, pred_type='response', "that were used during training".format(source)) id_col_name = summary_elements["id_col_name"] dep_varname = summary_elements["dependent_varname"] - dep_levels = summary_elements["dependent_var_levels"] - is_classification = summary_elements["is_classification"] + dep_levels = split_quoted_delimited_str(summary_elements["dependent_var_levels"]) dep_type = summary_elements['dependent_var_type'] + is_classification = summary_elements["is_classification"] # optional variables, default value is None grouping_cols_str = summary_elements.get("grouping_cols") null_proxy = summary_elements.get('null_proxy') @@ -1740,12 +1720,12 @@ def tree_predict(schema_madlib, model, source, output, pred_type='response', if value == 'boolean']) cat_features_str, con_features_str = get_feature_str( - schema_madlib, boolean_cats, cat_features, con_features, + schema_madlib, source, cat_features, con_features, "m.cat_levels_in_text", "m.cat_n_levels", null_proxy) if use_existing_tables and table_exists(output): - plpy.execute("truncate " + output) - header = "INSERT INTO " + output + " " + plpy.execute("TRUNCATE " + output) + header = "INSERT INTO " + output use_fold = 'WHERE k = ' + str(k) else: header = "CREATE TABLE " + output + " AS " @@ -1775,23 +1755,27 @@ def tree_predict(schema_madlib, model, source, output, pred_type='response', {use_fold} """ else: - if dep_type.lower() == "boolean": + if is_psql_boolean_type(dep_type): # some platforms don't have text to boolean cast. We manually check the string. - dep_cast_str = ("(case {pred_name} when 'True' then " - "True else False end)::BOOLEAN as {pred_name}") + dep_cast_str = ("(case {pred_name} when 'true' then true " + " when 'false' then false " + "end)::BOOLEAN as {pred_name}") else: dep_cast_str = "{pred_name}::{dep_type}" + dep_levels_array_str = py_list_to_sql_string(map(quote_literal, dep_levels), + 'TEXT', + long_format=True) if pred_type == "response": sql = header + """ SELECT - {id_col_name} - , %s + {id_col_name}, + %s FROM ( SELECT - {id_col_name} - , - ($sql${{ {dep_levels} }}$sql$::varchar[])[ - {schema_madlib}._predict_dt_response ( + {id_col_name}, + -- _predict_dt_response returns 0-based indexing. + -- Hence the "+ 1" (DB by default uses 1-based indexing) + (%s)[{schema_madlib}._predict_dt_response ( tree, {cat_features_str}::INTEGER[], {con_features_str}::DOUBLE PRECISION[]) + 1]::TEXT @@ -1799,27 +1783,26 @@ def tree_predict(schema_madlib, model, source, output, pred_type='response', FROM {source} as s {join_str} {model} as m {using_str} {use_fold} ) q - """ % (dep_cast_str) + """ % (dep_cast_str, dep_levels_array_str) else: - intermediate_col = unique_string() + temp_col = unique_string() score_format = ', \n'.join([ - '{interim}[{j}] as "estimated_prob_{c}"'. 
- format(j=i+1, c=c.strip(' "'), interim=intermediate_col) - for i, c in enumerate(split_quoted_delimited_str(dep_levels))]) + '{0}[{1}] as "estimated_prob_{2}"'.format(temp_col, i, c.strip(' "')) + for i, c in enumerate(dep_levels, start=1)]) sql = header + """ SELECT {id_col_name}, - {score_format} + %s FROM ( SELECT {id_col_name}, {schema_madlib}._predict_dt_prob(tree, {cat_features_str}::INTEGER[], {con_features_str}::DOUBLE PRECISION[]) - AS {intermediate_col} + AS {temp_col} FROM {source} as s {join_str} {model} as m {using_str} {use_fold} ) q - """ + """ % (score_format) sql = sql.format(**locals()) with MinWarning('warning'): with OptimizerControl(False): @@ -2181,7 +2164,7 @@ def _xvalidate(schema_madlib, tree_states, training_table_name, output_table_nam tree['pruned_depth'] = 0 importance_vectors = _compute_var_importance( schema_madlib, tree, - len(tree['cat_features']), len(tree['con_features'])) + len(cat_features), len(con_features)) tree.update(**importance_vectors) plpy.execute("DROP TABLE {group_to_param_list_table}".format(**locals())) http://git-wip-us.apache.org/repos/asf/madlib/blob/e2534e44/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in index 4d74872..d3b3f09 100644 --- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in +++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in @@ -18,12 +18,12 @@ from utilities.control import HashaggControl from utilities.utilities import _assert from utilities.utilities import add_postfix from utilities.utilities import extract_keyvalue_params +from utilities.utilities import is_psql_boolean_type from utilities.utilities import py_list_to_sql_string from utilities.utilities import split_quoted_delimited_str from utilities.utilities import unique_string from utilities.validate_args import cols_in_tbl_valid -from utilities.validate_args import get_cols_and_types from utilities.validate_args import get_expr_type from utilities.validate_args import input_tbl_valid from utilities.validate_args import is_var_valid @@ -34,13 +34,13 @@ from decision_tree import _tree_train_grps_using_bins from decision_tree import _get_bins from decision_tree import _get_bins_grps from decision_tree import _get_features_to_use -from decision_tree import _get_dep_type from decision_tree import _is_dep_categorical from decision_tree import _get_n_and_deplist from decision_tree import _classify_features from decision_tree import _get_filter_str from decision_tree import _get_display_header from decision_tree import get_feature_str +from decision_tree import get_grouping_array_str from decision_tree import _compute_var_importance from decision_tree import _create_cat_features_info_table # ------------------------------------------------------------ @@ -324,7 +324,7 @@ def forest_train( _assert(bool(features), "Random forest error: No feature is selected for the model.") - is_classification, is_bool = _is_dep_categorical( + is_classification, dep_is_bool = _is_dep_categorical( training_table_name, dependent_variable) split_criterion = 'gini' if is_classification else 'mse' @@ -352,24 +352,26 @@ def forest_train( n_rows, dep_list = _get_n_and_deplist(training_table_name, dependent_variable, filter_null) + dep_n_levels = len(dep_list) _assert(n_rows > 0, "Random forest error: There 
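To make the dependent-variable mapping above concrete, here is a sketch
(hypothetical level values and column name) of the CASE expression built for
a non-boolean classification response:

    dep_list = ["Don't Play", 'Play']   # sorted levels from _get_n_and_deplist
    dependent_variable = 'class'        # hypothetical column name
    dep_var_str = ("(CASE " +
                   "\n\t\t".join("WHEN ({0})::text = $${1}$$ THEN {2}"
                                 .format(dependent_variable, c, i)
                                 for i, c in enumerate(dep_list)) +
                   "\nEND)")
    # -> (CASE WHEN (class)::text = $$Don't Play$$ THEN 0
    #          WHEN (class)::text = $$Play$$ THEN 1
    #    END)
    # For a boolean response the expression is simply
    # (CASE WHEN class THEN 1 ELSE 0 END): false maps to 0, true to 1.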
http://git-wip-us.apache.org/repos/asf/madlib/blob/e2534e44/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
index 4d74872..d3b3f09 100644
--- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
+++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in
@@ -18,12 +18,12 @@
 from utilities.control import HashaggControl
 from utilities.utilities import _assert
 from utilities.utilities import add_postfix
 from utilities.utilities import extract_keyvalue_params
+from utilities.utilities import is_psql_boolean_type
 from utilities.utilities import py_list_to_sql_string
 from utilities.utilities import split_quoted_delimited_str
 from utilities.utilities import unique_string

 from utilities.validate_args import cols_in_tbl_valid
-from utilities.validate_args import get_cols_and_types
 from utilities.validate_args import get_expr_type
 from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import is_var_valid
@@ -34,13 +34,13 @@
 from decision_tree import _tree_train_grps_using_bins
 from decision_tree import _get_bins
 from decision_tree import _get_bins_grps
 from decision_tree import _get_features_to_use
-from decision_tree import _get_dep_type
 from decision_tree import _is_dep_categorical
 from decision_tree import _get_n_and_deplist
 from decision_tree import _classify_features
 from decision_tree import _get_filter_str
 from decision_tree import _get_display_header
 from decision_tree import get_feature_str
+from decision_tree import get_grouping_array_str
 from decision_tree import _compute_var_importance
 from decision_tree import _create_cat_features_info_table
 # ------------------------------------------------------------
@@ -324,7 +324,7 @@ def forest_train(
     _assert(bool(features),
             "Random forest error: No feature is selected for the model.")

-    is_classification, is_bool = _is_dep_categorical(
+    is_classification, dep_is_bool = _is_dep_categorical(
         training_table_name, dependent_variable)
     split_criterion = 'gini' if is_classification else 'mse'
@@ -352,24 +352,26 @@ def forest_train(
         n_rows, dep_list = _get_n_and_deplist(training_table_name,
                                               dependent_variable,
                                               filter_null)
+        dep_n_levels = len(dep_list)
         _assert(n_rows > 0,
                 "Random forest error: There should be at least one "
                 "data point for each class where all features are non NULL")
-        dep_list.sort()
-        dep_col_str = ("CASE WHEN " + dependent_variable +
-                       " THEN 'True' ELSE 'False' END") if is_bool else dependent_variable
-        dep = ("(CASE " +
-               "\n    ".
-               join(["WHEN ({dep_col})::text = $${c}$$ THEN {i}".
-                     format(dep_col=dep_col_str, c=c, i=i)
-                     for i, c in enumerate(dep_list)]) +
-               "\nEND)")
-        dep_n_levels = len(dep_list)
+        if dep_is_bool:
+            # false = 0, true = 1
+            # This order is maintained in dep_list since
+            # _get_n_and_deplist returns a sorted list
+            dep = ("(CASE WHEN {0} THEN 1 ELSE 0 END)".
+                   format(dependent_variable))
+        else:
+            dep = ("(CASE " +
+                   "\n\t\t".join(["WHEN ({0})::text = $${1}$$ THEN {2}".
+                                  format(dependent_variable, c, i)
+                                  for i, c in enumerate(dep_list)]) +
+                   "\nEND)")
     else:
         n_rows = plpy.execute(
-            "SELECT count(*) FROM {source_table} where {filter_null}".
-            format(source_table=training_table_name,
-                   filter_null=filter_null))[0]['count']
+            "SELECT count(*) FROM {0} WHERE {1}".
+            format(training_table_name, filter_null))[0]['count']
         dep = dependent_variable
         dep_n_levels = 1
         dep_list = None
@@ -397,15 +399,8 @@ def forest_train(
         cat_features = bins['cat_features']
         bins['grp_key_cat'] = ['']
     else:
-        grouping_cols_list = [col.strip() for col in grouping_cols.split(',')]
-        grouping_cols_and_types = [(col, get_expr_type(col, training_table_name))
-                                   for col in grouping_cols_list]
-        grouping_array_str = (
-            "array_to_string(array[" +
-            ','.join("(case when " + col + " then 'True' else 'False' end)::text"
-                     if col_type.lower() == 'boolean' else '(' + col + ')::text'
-                     for col, col_type in grouping_cols_and_types) +
-            "], ',')")
+        grouping_array_str = get_grouping_array_str(
+            training_table_name, grouping_cols)
         grouping_cols_str = ('' if grouping_cols is None
                              else grouping_cols + ",")
         sql_grp_key_to_grp_cols = """
@@ -430,7 +425,8 @@ def forest_train(
                               con_features, num_bins, dep,
                               boolean_cats, grouping_cols,
                               grouping_array_str, n_rows,
-                              is_classification, dep_n_levels, filter_null, null_proxy)
+                              is_classification, dep_n_levels,
+                              filter_null, null_proxy)
         cat_features = bins['cat_features']

         # a table for getting information of cat features for each group
@@ -558,8 +554,6 @@ def forest_train(
                     num_random_features, max_n_surr, null_proxy)
                 tree['grp_key'] = ''
-                tree['cat_features'] = cat_features
-                tree['con_features'] = con_features
                 if importance:
                     tree.update(_compute_var_importance(
                         schema_madlib, tree,
@@ -584,14 +578,12 @@ def forest_train(
                 #  stop calculating that group further.
                 for tree in tree_states:
                     grp_key = tree['grp_key']
-                    tree['cat_features'] = bins['grp_to_cat_features'][grp_key]
-                    tree['con_features'] = bins['con_features']
                     tree_terminated[grp_key] = tree['finished']
                     if importance:
                         importance_vectors = _compute_var_importance(
                             schema_madlib, tree,
-                            len(tree['cat_features']),
-                            len(tree['con_features']))
+                            len(cat_features),
+                            len(con_features))
                         tree.update(**importance_vectors)

             _insert_into_result_table(
@@ -602,7 +594,7 @@ def forest_train(
             schema_madlib, output_table_name, cat_features_info_table,
             con_splits_table, oob_prediction_table, oob_view, sample_id,
             id_col_name, cat_features, con_features,
-            boolean_cats, grouping_cols, grp_key_to_grp_cols, dep,
+            training_table_name, grouping_cols, grp_key_to_grp_cols, dep,
             num_permutations, is_classification, importance, num_bins,
             filter_null, null_proxy)
@@ -686,7 +678,7 @@ def forest_predict(schema_madlib, model, source, output,
     id_col_name = summary_elements["id_col_name"]
     grouping_cols = summary_elements.get("grouping_cols")  # optional, default = None
     dep_varname = summary_elements["dependent_varname"]
-    dep_levels = summary_elements["dependent_var_levels"]
+    dep_levels = split_quoted_delimited_str(summary_elements["dependent_var_levels"])
     is_classification = summary_elements["is_classification"]
     dep_type = summary_elements['dependent_var_type']
     null_proxy = summary_elements.get('null_proxy')  # optional, default = None
@@ -695,12 +687,8 @@ def forest_predict(schema_madlib, model, source, output,
     _assert(is_classification or pred_type == 'response',
             "Random forest error: pred_type cannot be 'prob' for regression model.")

-    # find which columns are of type boolean
-    boolean_cats = set([key for key, value in get_cols_and_types(source)
-                        if value == 'boolean'])
-
     cat_features_str, con_features_str = get_feature_str(
-        schema_madlib, boolean_cats, cat_features, con_features,
+        schema_madlib, source, cat_features, con_features,
         "cat_levels_in_text", "cat_n_levels", null_proxy)

     pred_name = ('"prob_{0}"' if pred_type == "prob" else
@@ -712,17 +700,21 @@ def forest_predict(schema_madlib, model, source, output,
     if not is_classification:
         majority_pred_expression = "AVG(aggregated_prediction)"
     else:
-        majority_pred_expression = """($sql${{ {dep_levels} }}$sql$::varchar[])[
-            {schema_madlib}.mode(aggregated_prediction + 1)]::TEXT
-            """.format(**locals())
-
-    if dep_type.lower() == "boolean":
+        dep_levels_array_str = py_list_to_sql_string(map(quote_literal, dep_levels),
+                                                     'TEXT',
+                                                     long_format=True)
+        majority_pred_expression = (
+            "({0})[{1}.mode(aggregated_prediction + 1)]::TEXT".
+            format(dep_levels_array_str, schema_madlib))
+
+    if is_psql_boolean_type(dep_type):
         # some platforms don't have text to boolean cast. We manually check the string.
-        majority_pred_cast_str = ("(case {majority_pred_expression} when 'True' then "
-                                  "True else False end)::BOOLEAN as {pred_name}")
+        majority_pred_cast_str = ("(case {majority_pred_expression} "
+                                  "     when 'true' then true "
+                                  "     when 'false' then false "
+                                  " end)::BOOLEAN AS {pred_name}")
     else:
-        majority_pred_cast_str = "{majority_pred_expression}::{dep_type} as {pred_name}"
-
+        majority_pred_cast_str = "({majority_pred_expression})::{dep_type} AS {pred_name}"
     majority_pred_cast_str = majority_pred_cast_str.format(**locals())
     num_trees_grown = plpy.execute(
         "SELECT count(DISTINCT sample_id) FROM {0}".format(model))[0]['count']
@@ -733,8 +725,7 @@ def forest_predict(schema_madlib, model, source, output,
                 SELECT
                     {id_col_name},
                     {majority_pred_cast_str}
-                FROM
-                (
+                FROM (
                     SELECT
                         {id_col_name},
                         {schema_madlib}._predict_dt_response(
@@ -753,12 +744,12 @@ def forest_predict(schema_madlib, model, source, output,
                 GROUP BY {id_col_name}
             """.format(**locals())
     else:
-        len_dep_levels = len(split_quoted_delimited_str(dep_levels))
+        len_dep_levels = len(dep_levels)
         normalized_majority_pred = unique_string()
         score_format = ', \n'.join([
             '{temp}[{j}] as "estimated_prob_{c}"'.
             format(j=i+1, c=c.strip(' "'), temp=normalized_majority_pred)
-            for i, c in enumerate(split_quoted_delimited_str(dep_levels))])
+            for i, c in enumerate(dep_levels)])
         sql_prediction = """
             CREATE TABLE {output} AS
@@ -922,12 +913,12 @@ def get_tree(schema_madlib, model_table, gid, sample_id,
 def _calculate_oob_prediction(
         schema_madlib, model_table, cat_features_info_table, con_splits_table,
         oob_prediction_table, oob_view, sample_id, id_col_name, cat_features,
-        con_features, boolean_cats, grouping_cols, grp_key_to_grp_cols, dep,
+        con_features, source_table, grouping_cols, grp_key_to_grp_cols, dep,
         num_permutations, is_classification, importance, num_bins, filter_null,
         null_proxy=None):
     """Calculate prediction for out-of-bag sample"""
     cat_features_str, con_features_str = get_feature_str(
-        schema_madlib, boolean_cats, cat_features, con_features,
+        schema_madlib, source_table, cat_features, con_features,
         "cat_levels_in_text", "cat_n_levels", null_proxy)

     join_str = "," if grouping_cols is None else "JOIN"
@@ -1227,19 +1218,22 @@ def _calculate_oob_error(schema_madlib, oob_prediction_table, oob_error_table,
 def _create_summary_table(**kwargs):
     kwargs['features'] = ','.join(kwargs['cat_features'] +
                                   kwargs['con_features'])
+    kwargs['dep_type'] = get_expr_type(kwargs['dependent_variable'],
+                                       kwargs['training_table_name'])
     if kwargs['dep_list']:
-        kwargs['dep_list_str'] = (
-            "$dep_list$" +
-            ','.join('"{0}"'.format(str(dep)) for dep in kwargs['dep_list']) +
-            "$dep_list$")
+        if is_psql_boolean_type(kwargs['dep_type']):
+            # Special handling for boolean since Python booleans start with
+            # capitals (i.e. False instead of false)
+            # Note: dep_list is sorted, hence 'false' will be first
+            kwargs['dep_list_str'] = "'false, true'"
+        else:
+            kwargs['dep_list_str'] = '$__dep_list__${0}$__dep_list__$'.format(
+                ','.join(map(str, kwargs['dep_list'])))
     else:
         kwargs['dep_list_str'] = "NULL"
     kwargs['indep_type'] = ', '.join(kwargs['all_cols_types'][col]
-                                     for col in kwargs['cat_features'] +
-                                     kwargs['con_features'])
-    kwargs['dep_type'] = _get_dep_type(kwargs['training_table_name'],
-                                       kwargs['dependent_variable'])
+                                     for col in (kwargs['cat_features'] +
+                                                 kwargs['con_features']))
     kwargs['cat_features_str'] = ','.join(kwargs['cat_features'])
     kwargs['con_features_str'] = ','.join(kwargs['con_features'])
     if kwargs['grouping_cols']:
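A note on the prediction decoding shared by both modules above: the C++
layer returns a 0-based index into the sorted level list, while SQL arrays
are 1-based, hence the "+ 1" in the generated queries. Sketch with
hypothetical values:

    dep_levels = ["Don't Play", 'Play']   # sorted levels from the summary table
    predicted = 1                         # 0-based index from _predict_dt_response
    print(dep_levels[predicted])          # 'Play'
    # SQL equivalent: (ARRAY[$$Don't Play$$, $$Play$$]::TEXT[])
    #                     [_predict_dt_response(...) + 1]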
http://git-wip-us.apache.org/repos/asf/madlib/blob/e2534e44/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in b/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
index dee3e32..74bc518 100644
--- a/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
+++ b/src/ports/postgres/modules/recursive_partitioning/test/decision_tree.sql_in
@@ -243,37 +243,39 @@ CREATE TABLE dt_golf (
     "Cont_features" double precision[],
     cat_features text[],
     windy boolean,
+    windy2 boolean,
     class text
 ) ;

-INSERT INTO dt_golf (id,"OUTLOOK",temperature,humidity,"Cont_features",cat_features, windy,class) VALUES
-(1, 'sunny', 85, 85,ARRAY[85, 85], ARRAY['a', 'b'], false, 'Don''t Play'),
-(2, 'sunny', 80, 90, ARRAY[80, 90], ARRAY['a', 'b'], true, 'Don''t Play'),
-(3, 'overcast', 83, 78, ARRAY[83, 78], ARRAY['a', 'b'], false, 'Play'),
-(4, 'rain', 70, NULL, ARRAY[70, 96], ARRAY['a', 'b'], false, 'Play'),
-(5, 'rain', 68, 80, ARRAY[68, 80], ARRAY['a', 'b'], false, 'Play'),
-(6, 'rain', NULL, 70, ARRAY[65, 70], ARRAY['a', 'b'], true, 'Don''t Play'),
-(7, 'overcast', 64, 65, ARRAY[64, 65], ARRAY['c', 'b'], NULL , 'Play'),
-(8, 'sunny', 72, 95, ARRAY[72, 95], ARRAY['a', 'b'], false, 'Don''t Play'),
-(9, 'sunny', 69, 70, ARRAY[69, 70], ARRAY['a', 'b'], false, 'Play'),
-(10, 'rain', 75, 80, ARRAY[75, 80], ARRAY['a', 'b'], false, 'Play'),
-(11, 'sunny', 75, 70, ARRAY[75, 70], ARRAY['a', 'd'], true, 'Play'),
-(12, 'overcast', 72, 90, ARRAY[72, 90], ARRAY['c', 'b'], NULL, 'Play'),
-(13, 'overcast', 81, 75, ARRAY[81, 75], ARRAY['a', 'b'], false, 'Play'),
-(15, NULL, 81, 75, ARRAY[81, 75], ARRAY['a', 'b'], false, 'Play'),
-(16, 'overcast', NULL, 75, ARRAY[81, 75], ARRAY['a', 'd'], false, 'Play'),
-(14, 'rain', 71, 80, ARRAY[71, 80], ARRAY['c', 'b'], true, 'Don''t Play');
+INSERT INTO dt_golf (id,"OUTLOOK",temperature,humidity,"Cont_features",cat_features, windy, windy2,class) VALUES
+(1, 'sunny', 85, 85,ARRAY[85, 85], ARRAY['a', 'b'], false, false, 'Don''t Play'),
+(2, 'sunny', 80, 90, ARRAY[80, 90], ARRAY['a', 'b'], true, false, 'Don''t Play'),
+(3, 'overcast', 83, 78, ARRAY[83, 78], ARRAY['a', 'b'], false, false, 'Play'),
+(4, 'rain', 70, NULL, ARRAY[70, 96], ARRAY['a', 'b'], false, false, 'Play'),
+(5, 'rain', 68, 80, ARRAY[68, 80], ARRAY['a', 'b'], false, false, 'Play'),
+(6, 'rain', NULL, 70, ARRAY[65, 70], ARRAY['a', 'b'], true, false, 'Don''t Play'),
+(7, 'overcast', 64, 65, ARRAY[64, 65], ARRAY['c', 'b'], NULL, NULL, 'Play'),
+(8, 'sunny', 72, 95, ARRAY[72, 95], ARRAY['a', 'b'], false, false, 'Don''t Play'),
+(9, 'sunny', 69, 70, ARRAY[69, 70], ARRAY['a', 'b'], false, false, 'Play'),
+(10, 'rain', 75, 80, ARRAY[75, 80], ARRAY['a', 'b'], false, false, 'Play'),
+(11, 'sunny', 75, 70, ARRAY[75, 70], ARRAY['a', 'd'], true, false, 'Play'),
+(12, 'overcast', 72, 90, ARRAY[72, 90], ARRAY['c', 'b'], NULL, NULL, 'Play'),
+(13, 'overcast', 81, 75, ARRAY[81, 75], ARRAY['a', 'b'], false, false, 'Play'),
+(15, NULL, 81, 75, ARRAY[81, 75], ARRAY['a', 'b'], false, false, 'Play'),
+(16, 'overcast', NULL, 75, ARRAY[81, 75], ARRAY['a', 'd'], false, false, 'Play'),
+(14, 'rain', 71, 80, ARRAY[71, 80], ARRAY['c', 'b'], true, false, 'Don''t Play');

 update dt_golf set id_2 = id % 2;

 -------------------------------------------------------------------------
 -- no grouping, with cross_validation
+-- also adding a categorical with just a single level (windy2)
 DROP TABLE IF EXISTS train_output, train_output_summary, train_output_cv;
 SELECT tree_train('dt_golf'::text,         -- source table
                   'train_output'::text,    -- output model table
                   'id'::text,              -- id column
                   'temperature::double precision'::text,           -- response
-                  'humidity, windy, "Cont_features"'::text,   -- features
+                  'humidity, windy2, "Cont_features"'::text,   -- features
                   NULL::text,        -- exclude columns
                   'gini'::text,      -- split criterion
                   NULL::text,        -- no grouping
@@ -282,38 +284,46 @@ SELECT tree_train('dt_golf'::text,         -- source table
                   6::integer,        -- min split
                   2::integer,        -- min bucket
                   3::integer,        -- number of bins per continuous variable
-                  'cp=0.01, n_folds=2'          -- cost-complexity pruning parameter
+                  'cp=0.01, n_folds=2',         -- cost-complexity pruning parameter
+                  'null_as_category=True'
                   );
 SELECT _print_decision_tree(tree) from train_output;
 SELECT tree_display('train_output', False);
 SELECT impurity_var_importance FROM train_output;
 SELECT * FROM train_output_cv;
+SELECT * FROM train_output_summary;

 -------------------------------------------------------------------------
 -- grouping
 DROP TABLE IF EXISTS train_output, train_output_summary, predict_output;
 SELECT tree_train('dt_golf'::text,         -- source table
-                  'train_output'::text,    -- output model table
-                  'id'::text,              -- id column
-                  'temperature::double precision'::text,           -- response
-                  '"OUTLOOK", humidity, windy, cat_features'::text,   -- features
-                  NULL::text,        -- exclude columns
-                  'gini'::text,      -- split criterion
-                  'class'::text,     -- grouping
-                  NULL::text,        -- no weights
-                  10::integer,       -- max depth
-                  6::integer,        -- min split
-                  2::integer,        -- min bucket
-                  3::integer,        -- number of bins per continuous variable
-                  'cp=0.01'          -- cost-complexity pruning parameter
-                  );
+                  'train_output'::text,    -- output model table
+                  'id'::text,              -- id column
+                  'temperature::double precision'::text,           -- response
+                  '"OUTLOOK", humidity, windy, cat_features'::text,   -- features
+                  NULL::text,        -- exclude columns
+                  'gini'::text,      -- split criterion
+                  'class'::text,     -- grouping
+                  NULL::text,        -- no weights
+                  10::integer,       -- max depth
+                  2::integer,        -- min split
+                  1::integer,        -- min bucket
+                  3::integer,        -- number of bins per continuous variable
+                  'cp=0',            -- cost-complexity pruning parameter
+                  'max_surrogates=2'
+                  );
 SELECT _print_decision_tree(tree) from train_output;
 SELECT tree_display('train_output', FALSE);
-SELECT * FROM train_output;
-SELECT tree_display('train_output', False);
+-- cat_features[2] has a single level. The cat_n_levels is in order of the
+-- input categorical features.
+-- In this case "OUTLOOK" = 1, windy = 2, cat_features[1] = 3, cat_features[2] = 4
+SELECT assert(cat_n_levels[4] = 1,
              'Categorical features with single level not being retained.')
+FROM train_output
+WHERE class = E'Don\'t Play';

 -- testing tree_predict with a category not present in training table
 CREATE TABLE dt_golf2 as
@@ -321,7 +331,7 @@ SELECT * FROM dt_golf UNION
 SELECT 15 as id, 1 as id_2, 'humid' as "OUTLOOK", 71 as temperature, 80 as humidity,
        ARRAY[90, 90] as "Cont_features", ARRAY['b', 'c'] as cat_features,
-       true as windy, 'Don''t Play' as class;
+       true as windy, false as windy2, 'Don''t Play' as class;
 \x off
 SELECT * FROM dt_golf2;
 SELECT tree_predict('train_output', 'dt_golf2', 'predict_output');
@@ -332,7 +342,6 @@
 JOIN dt_golf2 USING (id);

 \x on
-select * from train_output;
 select * from train_output_summary;

 -------------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/e2534e44/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
index ecb24c7..364e459 100644
--- a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
+++ b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in
@@ -40,7 +40,7 @@ SELECT forest_train(
     NULL,                 -- exclude columns
     NULL,                 -- no grouping
     5,                    -- num of trees
-    NULL,                 -- num of random features
+    2,                    -- num of random features
     TRUE,                 -- importance
     1,                    -- num_permutations
     10,                   -- max depth
@@ -64,13 +64,13 @@ SELECT forest_train(
     'train_output',       -- output model table
     'id',                 -- id column
     'temperature::double precision',    -- response
-    'humidity, cat_features, windy, "Cont_features"',   -- features
+    'cat_features, windy, "Cont_features"',             -- features
     NULL,                 -- exclude columns
     'class',              -- grouping
     5,                    -- num of trees
-    NULL,                 -- num of random features
+    5,                    -- num of random features
     TRUE,                 -- importance
-    20,                    -- num_permutations
+    20,                   -- num_permutations
     10,                   -- max depth
     1,                    -- min split
     1,                    -- min bucket
@@ -234,7 +234,7 @@ SELECT forest_train(
     NULL,                 -- exclude columns
     NULL,                 -- no grouping
     5,                    -- num of trees
-    1,                    -- num of random features
+    NULL,                 -- num of random features
     TRUE,                 -- importance
     1,                    -- num_permutations
     10,                   -- max depth
@@ -287,21 +287,21 @@ INSERT INTO rf_gr_test (id,gr,f1,f2,f3,cl) VALUES

 DROP TABLE IF EXISTS train_output, train_output_summary, train_output_group;
 SELECT forest_train(
-                 'rf_gr_test',         -- source table
-                 'train_output',       -- output model table
-                 'id',                 -- id column
-                 'cl',                 -- response
-                 'f1, f2',             -- features
-                 NULL,                 -- exclude columns
-                 'gr',                 -- grouping
-                 2,                    -- num of trees
-                 1,                    -- num of random features
-                 TRUE,                 -- importance
-                 1,                    -- num_permutations
-                 10,                   -- max depth
-                 1,                    -- min split
-                 1,                    -- min bucket
-                 2,                    -- number of bins per continuous variable
-                 'max_surrogates=0',
-                 FALSE
-                 );
+    'rf_gr_test',         -- source table
+    'train_output',       -- output model table
+    'id',                 -- id column
+    'cl',                 -- response
+    'f1, f2',             -- features
+    NULL,                 -- exclude columns
+    'gr',                 -- grouping
+    2,                    -- num of trees
+    1,                    -- num of random features
+    TRUE,                 -- importance
+    1,                    -- num_permutations
+    10,                   -- max depth
+    1,                    -- min split
+    1,                    -- min bucket
+    2,                    -- number of bins per continuous variable
+    'max_surrogates=0',
+    FALSE
+    );
http://git-wip-us.apache.org/repos/asf/madlib/blob/e2534e44/src/ports/postgres/modules/utilities/validate_args.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in
index 4296491..f7f79e9 100644
--- a/src/ports/postgres/modules/utilities/validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/validate_args.py_in
@@ -1,6 +1,8 @@
+from collections import Iterable
 import plpy
 import re
 import string
+from types import StringTypes

 # Postgresql naming restrictions
 """
@@ -361,8 +363,8 @@ def get_cols_and_types(tbl):
 # -------------------------------------------------------------------------


-def get_expr_type(expr, tbl):
-    """ Return the type of an expression run on a given table
+def get_expr_type(expressions, tbl):
+    """ Return the types of multiple expressions run on a given table

     Note: this
     Args:
@@ -371,17 +373,33 @@ def get_expr_type(expr, tbl):
     Returns:
         str
     """
-    expr_type = plpy.execute("""
-        SELECT pg_typeof({0}) AS type
+    # FIXME: Below transformation exists to ensure backwards compatibility.
+    # Remove this when all callers have been modified to pass an Iterable 'expressions'
+    if (isinstance(expressions, StringTypes) or
+            not isinstance(expressions, Iterable)):
+        expressions = [expressions]
+        input_was_scalar = True
+    else:
+        input_was_scalar = False
+
+    pg_type_expressions = ["pg_typeof({0})".format(e) for e in expressions]
+    expr_types = plpy.execute("""
+        SELECT {0} as all_types
         FROM {1}
         LIMIT 1
-        """.format(expr, tbl))
-    if not expr_type:
-        plpy.error("Unable to get type of expression ({0}). "
+        """.format("ARRAY[{0}]".format(','.join(pg_type_expressions)),
+                   tbl))
+    if not expr_types:
+        plpy.error("Unable to get type of expression ({0}). "
                    "Table {1} may not contain any valid tuples".
-                   format(expr, tbl))
-    return expr_type[0]['type'].lower()
-# -------------------------------------------------------------------------
+                   format(expressions, tbl))
+    expr_types = expr_types[0]["all_types"]
+    if input_was_scalar:
+        # output should be same form as input
+        return expr_types[0]
+    else:
+        return expr_types
+# -------------------------------------------------------------------------


 def columns_exist_in_table(tbl, cols, schema_madlib="madlib"):
@@ -518,14 +536,11 @@ def explicit_bool_to_text(tbl, cols, schema_madlib):
     Patch madlib.bool_to_text for columns that are of type boolean.
     """
     m4_ifdef(<!__HAS_BOOL_TO_TEXT_CAST__!>, <!return cols!>, <!!>)
-    col_to_type = dict(get_cols_and_types(tbl))
     patched = []
-    for col in cols:
-        if col not in col_to_type:
-            plpy.error("Column ({col}) does not exist "
-                       "in table ({tbl})".format(col=col, tbl=tbl))
-        if col_to_type[col] == 'boolean':
-            patched.append(schema_madlib + ".bool_to_text(" + col + ")")
+    col_types = get_expr_type(cols, tbl)
+    for col, col_type in zip(cols, col_types):
+        if col_type == 'boolean':
+            patched.append("{0}.bool_to_text({1})".format(schema_madlib, col))
         else:
             patched.append(col)
     return patched
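Usage sketch for the updated get_expr_type (table and column names are
illustrative, taken from the dt_golf test fixture above): one round trip now
returns the types of several expressions, while a scalar argument still
returns a scalar for backwards compatibility.

    # One query instead of N:
    types = get_expr_type(['windy', 'humidity'], 'dt_golf')   # e.g. ['boolean', 'integer']
    # Scalar input keeps the old call pattern working:
    dep_type = get_expr_type('temperature', 'dt_golf')        # e.g. 'integer'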