Repository: incubator-madlib Updated Branches: refs/heads/master 18b8486ca -> ec60b83d2
RF: Filter NULL dependent values in OOB JIRA: MADLIB-1097 Added `filter_null` string obtained from decision_tree.py into the OOB view to exclude rows that have NULL dependent values. Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/9b45ecaa Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/9b45ecaa Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/9b45ecaa Branch: refs/heads/master Commit: 9b45ecaaadb9e0d4999dc49e72df8a97cb7692d2 Parents: 18b8486 Author: Rahul Iyer <ri...@apache.org> Authored: Wed May 3 17:07:55 2017 -0700 Committer: Rahul Iyer <ri...@apache.org> Committed: Wed May 10 15:56:57 2017 -0700 ---------------------------------------------------------------------- .../recursive_partitioning/random_forest.py_in | 24 ++++++++++++-------- .../test/random_forest.sql_in | 14 +++++++----- 2 files changed, 23 insertions(+), 15 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/9b45ecaa/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in index 4b6f2d6..1b5ad88 100644 --- a/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in +++ b/src/ports/postgres/modules/recursive_partitioning/random_forest.py_in @@ -450,7 +450,8 @@ def forest_train( bins['cat_origin']]) con_splits_table = unique_string() - _create_con_splits_table(schema_madlib, con_splits_table, grouping_cols, grp_key_to_grp_cols, bins) + _create_con_splits_table(schema_madlib, con_splits_table, + grouping_cols, grp_key_to_grp_cols, bins) ################################################################## # create views and tables for training (growing) of trees @@ -600,7 +601,8 @@ def forest_train( con_splits_table, oob_prediction_table, oob_view, sample_id, id_col_name, cat_features, con_features, boolean_cats, grouping_cols, grp_key_to_grp_cols, dep, - num_permutations, is_classification, importance, num_bins) + num_permutations, is_classification, importance, + num_bins, filter_null) ################################################################### # evaluating and summerizing random forest @@ -626,9 +628,9 @@ def forest_train( # calculated, otherwise we use an empty table which will be used later # for an outer join. if importance: - _calculate_variable_importance(schema_madlib, - oob_prediction_table, is_classification, - importance_table, len(cat_features), len(con_features)) + _calculate_variable_importance( + schema_madlib, oob_prediction_table, is_classification, + importance_table, len(cat_features), len(con_features)) _create_group_table(schema_madlib, output_table_name, oob_error_table, importance_table, @@ -926,7 +928,7 @@ def _calculate_oob_prediction( schema_madlib, model_table, cat_features_info_table, con_splits_table, oob_prediction_table, oob_view, sample_id, id_col_name, cat_features, con_features, boolean_cats, grouping_cols, grp_key_to_grp_cols, dep, - num_permutations, is_classification, importance, num_bins): + num_permutations, is_classification, importance, num_bins, filter_null): """Calculate predication for out-of-bag sample""" cat_features_str, con_features_str = get_feature_str( @@ -1045,6 +1047,7 @@ def _calculate_oob_prediction( LEFT OUTER JOIN -- empty if variable importance is disabled {oob_var_dist_view} USING (gid) + WHERE {filter_null} """.format(**locals()) plpy.notice("sql_oob_predict : " + str(sql_oob_predict)) plpy.execute(sql_oob_predict) @@ -1091,11 +1094,14 @@ def _create_con_splits_table(schema_madlib, con_splits_table, grouping_cols, # ------------------------------------------------------------------------------ -def _calculate_variable_importance(schema_madlib, oob_prediction_table, - is_classification, importance_table, n_cat, n_con): +def _calculate_variable_importance( + schema_madlib, oob_prediction_table, is_classification, + importance_table, n_cat, n_con): if not is_classification: + # squared error score_expression = "-((oob_prediction - dep)^2)".format(**locals()) else: + # misclassification score_expression = """ CASE WHEN dep = oob_prediction::integer THEN 1. @@ -1200,7 +1206,7 @@ def _create_summary_table(**kwargs): kwargs['indep_type'] = ', '.join(kwargs['all_cols_types'][col] for col in kwargs['cat_features'] + - kwargs['con_features']) + kwargs['con_features']) kwargs['dep_type'] = _get_dep_type(kwargs['training_table_name'], kwargs['dependent_variable']) kwargs['cat_features_str'] = ','.join(kwargs['cat_features']) http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/9b45ecaa/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in index 086e74b..f3ad93c 100644 --- a/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in +++ b/src/ports/postgres/modules/recursive_partitioning/test/random_forest.sql_in @@ -13,16 +13,18 @@ INSERT INTO dt_golf (id,"OUTLOOK",temperature,humidity,cont_features,windy,class (1, 'sunny', 85, 85,ARRAY[85, 85], false, 'Don''t Play'), (2, 'sunny', 80, 90,ARRAY[80, 90], true, 'Don''t Play'), (3, 'overcast', 83, 78,ARRAY[83, 78], false, 'Play'), -(4, 'rain', 70, 96,ARRAY[70, 96], false, 'Play'), +(4, 'rain', 70, NULL,ARRAY[70, 96], false, 'Play'), (5, 'rain', 68, 80,ARRAY[68, 80], false, 'Play'), -(6, 'rain', 65, 70,ARRAY[65, 70], true, 'Don''t Play'), -(7, 'overcast', 64, 65,ARRAY[64, 65], true, 'Play'), +(6, 'rain', NULL, 70,ARRAY[65, 70], true, 'Don''t Play'), +(7, 'overcast', 64, 65,ARRAY[64, 65],NULL, 'Play'), (8, 'sunny', 72, 95,ARRAY[72, 95], false, 'Don''t Play'), (9, 'sunny', 69, 70,ARRAY[69, 70], false, 'Play'), (10, 'rain', 75, 80,ARRAY[75, 80], false, 'Play'), (11, 'sunny', 75, 70,ARRAY[75, 70], true, 'Play'), -(12, 'overcast', 72, 90,ARRAY[72, 90], true, 'Play'), +(12, 'overcast', 72, 90,ARRAY[72, 90], NULL, 'Play'), (13, 'overcast', 81, 75,ARRAY[81, 75], false, 'Play'), +(15, NULL, 81, 75,ARRAY[81, 75], false, 'Play'), +(16, 'overcast', NULL, 75,ARRAY[81, 75], false, 'Play'), (14, 'rain', 71, 80,ARRAY[71, 80], true, 'Don''t Play'); ------------------------------------------------------------------------- @@ -116,7 +118,7 @@ SELECT forest_train( 'dt_golf'::TEXT, -- source table 'train_output'::TEXT, -- output model table 'id'::TEXT, -- id column - 'temperature::double precision'::TEXT, -- response + 'temperature::double precision'::TEXT, -- response 'class, temperature, windy'::TEXT, -- features NULL::TEXT, -- exclude columns NULL::TEXT, -- no grouping @@ -150,7 +152,7 @@ SELECT forest_train( 'temperature::double precision'::TEXT, -- response 'humidity'::TEXT, -- features NULL::TEXT, -- exclude columns - 'class,windy', -- grouping + 'class', -- grouping 5, -- num of trees 1, -- num of random features TRUE::BOOLEAN, -- importance