GitHub user jingyimei commented on a diff in the pull request: https://github.com/apache/madlib/pull/289#discussion_r201501481 --- Diff: src/ports/postgres/modules/recursive_partitioning/random_forest.py_in --- @@ -1291,38 +1300,64 @@ def _create_group_table( schema_madlib, output_table_name, oob_error_table, importance_table, cat_features_info_table, grp_key_to_grp_cols, grouping_cols, tree_terminated): - """ Ceate the group table for random forest""" + """ Create the group table for random forest""" + + cat_var_importance_str = '' + con_var_importance_str = '' + impurity_var_importance_str = '' + left_join_importance_table_str = '' + join_impurity_table_str = '' + + if importance_table: + impurity_var_importance_table_name = unique_string(desp='impurity') + plpy.execute(""" + CREATE TEMP TABLE {impurity_var_importance_table_name} AS + SELECT + gid, + {schema_madlib}.array_avg(impurity_var_importance, False) AS impurity_var_importance + FROM {output_table_name} + GROUP BY gid + """.format(**locals())) + + cat_var_importance_str = ", cat_var_importance AS oob_cat_var_importance," + con_var_importance_str = "con_var_importance AS oob_con_var_importance," + impurity_var_importance_str = "impurity_var_importance" + left_join_importance_table_str = """LEFT OUTER JOIN {importance_table} + USING (gid)""".format(importance_table=importance_table) + join_impurity_table_str = """JOIN {impurity_var_importance_table_name} USING (gid)""".format(impurity_var_importance_table_name=impurity_var_importance_table_name) + grouping_cols_str = ('' if grouping_cols is None else grouping_cols + ",") group_table_name = add_postfix(output_table_name, "_group") + sql_create_group_table = """ CREATE TABLE {group_table_name} AS SELECT gid, {grouping_cols_str} - grp_finished as success, + grp_finished AS success, cat_n_levels, cat_levels_in_text, - oob_error, - cat_var_importance, - con_var_importance + oob_error + {cat_var_importance_str} --- End diff -- We can make this consistent with the oob_* naming.
---