Github user jingyimei commented on a diff in the pull request:
https://github.com/apache/madlib/pull/289#discussion_r201501481
--- Diff: src/ports/postgres/modules/recursive_partitioning/random_forest.py_in ---
@@ -1291,38 +1300,64 @@ def _create_group_table(
schema_madlib, output_table_name, oob_error_table,
importance_table, cat_features_info_table, grp_key_to_grp_cols,
grouping_cols, tree_terminated):
- """ Ceate the group table for random forest"""
+ """ Create the group table for random forest"""
+
+ cat_var_importance_str = ''
+ con_var_importance_str = ''
+ impurity_var_importance_str = ''
+ left_join_importance_table_str = ''
+ join_impurity_table_str = ''
+
+ if importance_table:
+ impurity_var_importance_table_name = unique_string(desp='impurity')
+ plpy.execute("""
+ CREATE TEMP TABLE {impurity_var_importance_table_name} AS
+ SELECT
+ gid,
+ {schema_madlib}.array_avg(impurity_var_importance, False) AS impurity_var_importance
+ FROM {output_table_name}
+ GROUP BY gid
+ """.format(**locals()))
+
+ cat_var_importance_str = ", cat_var_importance AS oob_cat_var_importance,"
+ con_var_importance_str = "con_var_importance AS oob_con_var_importance,"
+ impurity_var_importance_str = "impurity_var_importance"
+ left_join_importance_table_str = """LEFT OUTER JOIN {importance_table}
+ USING (gid)""".format(importance_table=importance_table)
+ join_impurity_table_str = """JOIN {impurity_var_importance_table_name} USING (gid)""".format(impurity_var_importance_table_name=impurity_var_importance_table_name)
+
grouping_cols_str = ('' if grouping_cols is None
else grouping_cols + ",")
group_table_name = add_postfix(output_table_name, "_group")
+
sql_create_group_table = """
CREATE TABLE {group_table_name} AS
SELECT
gid,
{grouping_cols_str}
- grp_finished as success,
+ grp_finished AS success,
cat_n_levels,
cat_levels_in_text,
- oob_error,
- cat_var_importance,
- con_var_importance
+ oob_error
+ {cat_var_importance_str}
--- End diff --
We can make it consistent with oob_*
---