This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch madlib2-master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 12fb88877de26f6efd6777f6ac9fa12960e3658e
Author: Nikhil Kak <n...@vmware.com>
AuthorDate: Tue Feb 20 16:18:15 2024 -0800

    PMML: Improve dev-check tests for decision tree
    
    JIRA: MADLIB-1517
    
    This commit adds a few more decision tree pmml tests that compare
    tree_predict's output with pypmml's output
---
 .../postgres/modules/pmml/test/pmml_dt.sql_in      | 117 ++++++++++++++++++++-
 1 file changed, 113 insertions(+), 4 deletions(-)

diff --git a/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in b/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in
index c03d6f34..86d97360 100644
--- a/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in
+++ b/src/ports/postgres/modules/pmml/test/pmml_dt.sql_in
@@ -1,3 +1,10 @@
+\i m4_regexp(MADLIB_LIBRARY_PATH,
+             `\(.*\)/lib',
+              `\1/../modules/pmml/test/pmml.setup.sql_in'
+)
+
+m4_changequote(`<!'', `!>'')
+
 DROP TABLE IF EXISTS dt_golf;
 CREATE TABLE dt_golf (
     id integer NOT NULL,
@@ -24,12 +31,49 @@ INSERT INTO dt_golf (id,"OUTLOOK",temperature,humidity,windy,class) VALUES
 (13, 'overcast', 81, 75, 'false', 'Play'),
 (14, 'rain', 71, 80, 'true', 'Don''t Play');
 
+-- regression, no grouping
+DROP TABLE IF EXISTS train_output, train_output_summary;
+SELECT tree_train('dt_golf'::text,         -- source table
+                         'train_output'::text,    -- output model table
+                         'id'::text,              -- id column
+                         'temperature'::text,           -- response
+                         'humidity, windy'::text,   -- features
+                         NULL::text,        -- exclude columns
+                         'gini'::text,      -- split criterion
+                         NULL::text,     -- no grouping
+                         NULL::text,        -- no weights
+                         10::integer,       -- max depth
+                         3::integer,        -- min split
+                         1::integer,        -- min bucket
+                         3::integer,        -- number of bins per continuous variable
+                         'cp=0.01'          -- cost-complexity pruning parameter
+                         );
+
+SELECT _print_decision_tree(tree) from train_output;
+-- TODO: Enable these lines after the DT tree_display bug is fixed
+-- SELECT tree_display('train_output', False);
+
+DROP TABLE IF EXISTS tree_predict_output;
+SELECT tree_predict('train_output',
+                   'dt_golf',
+                   'tree_predict_output',
+                   'response');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'estimated_temperature','predicted_temperature_pmml_prediction');
+
+DROP TABLE IF EXISTS tree_predict_output;
+SELECT tree_predict('train_output',
+                   'dt_golf',
+                   'tree_predict_output',
+                   'prob');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'prob_temperature','predicted_temperature_pmml_prediction');
+
+
 -- regression, grouping
 DROP TABLE IF EXISTS train_output, train_output_summary;
 SELECT tree_train('dt_golf'::text,         -- source table
                          'train_output'::text,    -- output model table
                          'id'::text,              -- id column
-                         'temperature::double precision'::text,           -- response
+                         'temperature'::text,           -- response
                          'humidity, windy'::text,   -- features
                          NULL::text,        -- exclude columns
                          'gini'::text,      -- split criterion
@@ -46,8 +90,58 @@ SELECT _print_decision_tree(tree) from train_output;
 -- TODO: Enable these lines after the DT tree_display bug is fixed
 -- SELECT tree_display('train_output', False);
 
-SELECT pmml('train_output');
+DROP TABLE IF EXISTS tree_predict_output;
+SELECT tree_predict('train_output',
+                   'dt_golf',
+                   'tree_predict_output',
+                   'response');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'estimated_temperature','predicted_temperature_pmml_prediction');
+
+DROP TABLE IF EXISTS tree_predict_output;
+SELECT tree_predict('train_output',
+                   'dt_golf',
+                   'tree_predict_output',
+                   'prob');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'prob_temperature','predicted_temperature_pmml_prediction');
+
 -------------------------------------------------------------------------
+-- classification, no grouping
+DROP TABLE IF EXISTS train_output, train_output_summary;
+SELECT tree_train('dt_golf'::text,         -- source table
+                         'train_output'::text,    -- output model table
+                         'id'::text,              -- id column
+                         '"OUTLOOK"'::text,           -- response
+                         'humidity, windy'::text,   -- features
+                         NULL::text,        -- exclude columns
+                         'gini'::text,      -- split criterion
+                         NULL::text,     -- no grouping
+                         NULL::text,        -- no weights
+                         10::integer,       -- max depth
+                         3::integer,        -- min split
+                         1::integer,        -- min bucket
+                         3::integer,        -- number of bins per continuous variable
+                         'cp=0.01'          -- cost-complexity pruning parameter
+                         );
+
+SELECT _print_decision_tree(tree) from train_output;
+-- SELECT tree_display('train_output', False);
+
+DROP TABLE IF EXISTS tree_predict_output;
+SELECT tree_predict('train_output',
+                   'dt_golf',
+                   'tree_predict_output',
+                   'response');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'estimated_OUTLOOK','predicted_OUTLOOK_pmml_prediction');
+
+DROP TABLE IF EXISTS tree_predict_output;
+SELECT tree_predict('train_output',
+                   'dt_golf',
+                   'tree_predict_output',
+                   'prob');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'estimated_prob_overcast','probability_overcast');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'estimated_prob_rain','probability_rain');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'estimated_prob_sunny','probability_sunny');
+
 
 -- classification, grouping
 DROP TABLE IF EXISTS train_output, train_output_summary;
@@ -58,7 +152,7 @@ SELECT tree_train('dt_golf'::text,         -- source table
                          'humidity, windy'::text,   -- features
                          NULL::text,        -- exclude columns
                          'gini'::text,      -- split criterion
-                         'class'::text,     -- no grouping
+                         'class'::text,     -- grouping
                          NULL::text,        -- no weights
                          10::integer,       -- max depth
                          3::integer,        -- min split
@@ -70,6 +164,21 @@ SELECT tree_train('dt_golf'::text,         -- source table
 SELECT _print_decision_tree(tree) from train_output;
 -- SELECT tree_display('train_output', False);
 
-SELECT pmml('train_output');
+DROP TABLE IF EXISTS tree_predict_output;
+SELECT tree_predict('train_output',
+                   'dt_golf',
+                   'tree_predict_output',
+                   'response');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'estimated_OUTLOOK','predicted_OUTLOOK_pmml_prediction');
+
+DROP TABLE IF EXISTS tree_predict_output;
+SELECT tree_predict('train_output',
+                   'dt_golf',
+                   'tree_predict_output',
+                   'prob');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'estimated_prob_overcast','probability_overcast');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'estimated_prob_rain','probability_rain');
+SELECT test_pmml_output('dt_golf', 'train_output', 'tree_predict_output','id', 'estimated_prob_sunny','probability_sunny');
+
 -------------------------------------------------------------------------
 

Reply via email to