Repository: incubator-madlib Updated Branches: refs/heads/master ff1b0f883 -> 69f788662
Sample: Add function to split train/test JIRA: MADLIB-1119 Add utility to create train and test samples from an input table. This function uses the stratified sampling to create the samples. Closes #166 Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/69f78866 Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/69f78866 Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/69f78866 Branch: refs/heads/master Commit: 69f788662405bf27763d021f05bb89cc4d6a6a17 Parents: b963ac1 Author: Cooper Sloan <cooper.sl...@gmail.com> Authored: Fri Aug 18 06:23:29 2017 -0700 Committer: Rahul Iyer <ri...@apache.org> Committed: Fri Aug 18 06:28:18 2017 -0700 ---------------------------------------------------------------------- doc/mainpage.dox.in | 8 +- .../modules/sample/stratified_sample.py_in | 2 +- .../modules/sample/test/test_train_split.sql_in | 85 +++++ .../modules/sample/test_train_split.py_in | 319 +++++++++++++++++++ .../modules/sample/test_train_split.sql_in | 319 +++++++++++++++++++ 5 files changed, 730 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/69f78866/doc/mainpage.dox.in ---------------------------------------------------------------------- diff --git a/doc/mainpage.dox.in b/doc/mainpage.dox.in index e27e14a..be45369 100644 --- a/doc/mainpage.dox.in +++ b/doc/mainpage.dox.in @@ -142,12 +142,15 @@ Contains graph algorithms. @defgroup grp_wcc Weakly Connected Components @} -@defgroup grp_mdl Model Evaluation -@{Contains functions for evaluating accuracy and validation of predictive methods. @} +@defgroup grp_mdl Model Selection +@{Contains functions for model selection and model evaluation. 
@} @defgroup grp_validation Cross Validation @ingroup grp_mdl @defgroup grp_pred Prediction Metrics @ingroup grp_mdl + @defgroup grp_test_train_split Test Train Split + @ingroup grp_mdl + @defgroup grp_stats Statistics @{Contains statistics modules @} @@ -264,6 +267,7 @@ Contains graph algorithms. @defgroup grp_strs Stratified Sampling @ingroup grp_sampling + @defgroup grp_sessionize Sessionize @ingroup grp_utility_functions http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/69f78866/src/ports/postgres/modules/sample/stratified_sample.py_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/sample/stratified_sample.py_in b/src/ports/postgres/modules/sample/stratified_sample.py_in index e7762ef..0d29b41 100644 --- a/src/ports/postgres/modules/sample/stratified_sample.py_in +++ b/src/ports/postgres/modules/sample/stratified_sample.py_in @@ -167,7 +167,7 @@ def validate_strs (source_table, output_table, proportion, glist, target_cols): _assert(not table_is_empty(source_table), "Sample: Source table ({source_table}) is empty!".format(**locals())) - _assert(proportion > 0 and proportion < 1, + _assert(proportion > 0 and proportion <= 1, "Sample: Proportion isn't in the range (0,1)!") if glist is not None: http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/69f78866/src/ports/postgres/modules/sample/test/test_train_split.sql_in ---------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/sample/test/test_train_split.sql_in b/src/ports/postgres/modules/sample/test/test_train_split.sql_in new file mode 100644 index 0000000..5ae0ade --- /dev/null +++ b/src/ports/postgres/modules/sample/test/test_train_split.sql_in @@ -0,0 +1,85 @@ +/* ----------------------------------------------------------------------- *//** + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ----------------------------------------------------------------------- */
+
+DROP TABLE IF EXISTS test;
+
+CREATE TABLE test(
+    id1 INTEGER,
+    id2 INTEGER,
+    gr1 INTEGER,
+    gr2 INTEGER
+);
+
+-- 20 rows in three strata: (gr1,gr2) = (1,1) x12, (1,2) x6, (2,2) x2.
+INSERT INTO test VALUES
+(1,0,1,1),
+(2,0,1,1),
+(3,0,1,1),
+(4,0,1,1),
+(5,0,1,1),
+(6,0,1,1),
+(7,0,1,1),
+(8,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(0,1,1,2),
+(0,2,1,2),
+(0,3,1,2),
+(0,4,1,2),
+(0,5,1,2),
+(0,6,1,2),
+(10,10,2,2),
+(20,20,2,2)
+;
+
+SELECT setseed(0);
+
+-- Separate output tables, without replacement: 10% train / 20% test of 20 rows.
+DROP TABLE IF EXISTS out_train,out_test,out;
+SELECT test_train_split('test', 'out', 0.1, 0.2, NULL, 'id1,id2,gr1,gr2', FALSE, TRUE);
+SELECT assert(count(*) = 2, 'Wrong number of samples') FROM out_train;
+SELECT assert(count(*) = 4, 'Wrong number of samples') FROM out_test;
+
+-- Single output table: the 'split' column marks train (1) vs. test (0).
+DROP TABLE IF EXISTS out_train,out_test,out;
+SELECT test_train_split('test', 'out', 0.1, 0.2, NULL, 'id1,id2,gr1,gr2', FALSE, FALSE);
+SELECT assert(count(*) = 2, 'Wrong number of samples') FROM out WHERE split=1;
+SELECT assert(count(*) = 4, 'Wrong number of samples') FROM out WHERE split=0;
+
+
+-- With replacement: proportions sum to 1, so 20 rows are sampled in total.
+DROP TABLE IF EXISTS out_train,out_test,out;
+SELECT test_train_split('test', 'out', 0.5, 0.5, NULL, 'id1,id2,gr1,gr2', TRUE, FALSE);
+SELECT assert(count(*) = 20, 'Wrong number of samples') FROM out;
+
+-- Stratified on (gr1, gr2): each stratum is split evenly between train and test.
+DROP TABLE IF EXISTS out;
+SELECT test_train_split('test', 'out', 0.5, 0.5, 'gr1,gr2', 'id1,id2', TRUE, FALSE);
+SELECT assert(count(*) = 6, 'Wrong number of samples')
+FROM out WHERE gr1 = 1 AND gr2 = 1 AND split = 0;
+SELECT assert(count(*) = 6, 'Wrong number of samples')
+FROM out WHERE gr1 = 1 AND gr2 = 1 AND split = 1;
+SELECT assert(count(*) = 3, 'Wrong number of samples')
+FROM out WHERE gr1 = 1 AND gr2 = 2 AND split = 0;
+SELECT assert(count(*) = 3, 'Wrong number of samples')
+FROM out WHERE gr1 = 1 AND gr2 = 2 AND split = 1;
+SELECT assert(count(*) = 1, 'Wrong number of samples')
+FROM out WHERE gr1 = 2 AND gr2 = 2 AND split = 0;
+SELECT assert(count(*) = 1, 'Wrong number of samples')
+FROM out WHERE gr1 = 2 AND gr2 = 2 AND split = 1;
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/69f78866/src/ports/postgres/modules/sample/test_train_split.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/test_train_split.py_in b/src/ports/postgres/modules/sample/test_train_split.py_in
new file mode 100644
index 0000000..6056b2a
--- /dev/null
+++ b/src/ports/postgres/modules/sample/test_train_split.py_in
@@ -0,0 +1,319 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import extract_keyvalue_params
+from utilities.utilities import add_postfix
+from utilities.utilities import unique_string
+from utilities.utilities import split_quoted_delimited_str
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_expr_type
+from utilities.validate_args import get_cols
+from graph.graph_utils import _check_groups
+from graph.graph_utils import _grp_from_table
+
+m4_changequote(` <!', `!>')
+
+
+def _get_sql_string(s):
+    """
+    Render a Python string as a SQL string literal, or NULL when empty/None.
+
+    Embedded single quotes are doubled so the value cannot terminate the
+    literal early when it is interpolated into a query.
+    """
+    if s:
+        return "'" + s.replace("'", "''") + "'"
+    return "NULL"
+
+
+def test_train_split(schema_madlib, source_table, output_table, train_proportion,
+                     test_proportion, grouping_cols, target_cols, with_replacement,
+                     separate_output_tables, **kwargs):
+    """
+    Split a table into a training sample and a test sample.
+
+    Args:
+        @param source_table           Input table name.
+        @param output_table           Output table name.
+        @param train_proportion       The ratio of training data to the entire
+                                      input table.
+        @param test_proportion        The ratio of test data to the entire
+                                      input table. If NULL, defaults to
+                                      1 - train_proportion.
+        @param grouping_cols          (Default: NULL) The columns to distinguish
+                                      each strata.
+        @param target_cols            (Default: NULL) The columns to include in
+                                      the output.
+        @param with_replacement       (Default: FALSE) The sampling method.
+        @param separate_output_tables (Default: FALSE) Create two output tables,
+                                      <output_table>_train and <output_table>_test.
+                                      Otherwise one output table is created with
+                                      an additional column 'split' which takes the
+                                      value 0 for test and 1 for training.
+
+    Method: one stratified sample of size (train_proportion + test_proportion)
+    is drawn first. The test set is then sampled from it without replacement,
+    and the training set is the multiset remainder (EXCEPT ALL), so the two
+    sets partition that first sample.
+    """
+    with MinWarning("warning"):
+        # Check before computing the default below, which would otherwise
+        # fail with an opaque TypeError on NULL.
+        _assert(train_proportion is not None,
+                "Sample: train_proportion cannot be NULL!")
+        if test_proportion is None:
+            test_proportion = 1 - train_proportion
+        validate_strs(source_table, output_table, train_proportion,
+                      test_proportion, split_quoted_delimited_str(grouping_cols),
+                      target_cols, with_replacement, separate_output_tables)
+        grouping_cols = _get_sql_string(grouping_cols)
+        target_cols = _get_sql_string(target_cols)
+        with_replacement = with_replacement or "False"
+        strat_query = """
+            SELECT {schema_madlib}.stratified_sample(
+                '{strat_source_table}',
+                '{strat_out_table}',
+                '{strat_proportion}',
+                {strat_grouping_cols},
+                {strat_target_cols},
+                {strat_with_replacement}
+            )
+            """
+
+        # Step 1: combined train+test sample from the source table.
+        strat_out_table = unique_string()
+        plpy.execute(strat_query.format(
+            schema_madlib=schema_madlib,
+            strat_source_table=source_table,
+            strat_out_table=strat_out_table,
+            strat_proportion=train_proportion + test_proportion,
+            strat_grouping_cols=grouping_cols,
+            strat_target_cols=target_cols,
+            strat_with_replacement=with_replacement))
+
+        test_table = add_postfix(output_table, "_test")
+        train_table = add_postfix(output_table, "_train")
+        if not separate_output_tables:
+            # Intermediate tables, merged into output_table and dropped below.
+            test_table = unique_string()
+            train_table = unique_string()
+
+        # Step 2: carve the test set out of the combined sample. Always
+        # without replacement so it is a sub-multiset of strat_out_table.
+        plpy.execute(strat_query.format(
+            schema_madlib=schema_madlib,
+            strat_source_table=strat_out_table,
+            strat_out_table=test_table,
+            strat_proportion=(float(test_proportion) /
+                              (train_proportion + test_proportion)),
+            strat_grouping_cols=grouping_cols,
+            strat_target_cols=target_cols,
+            strat_with_replacement=False))
+
+        # Step 3: the training set is whatever was not drawn into the test set.
+        # EXCEPT ALL keeps duplicate rows (needed for with_replacement).
+        plpy.execute("""
+            CREATE TABLE {train_table} AS
+            SELECT * FROM {strat_out_table}
+            EXCEPT ALL
+            SELECT * FROM {test_table}
+            """.format(train_table=train_table,
+                       strat_out_table=strat_out_table,
+                       test_table=test_table))
+
+        clean_up_tables = [strat_out_table]
+        if not separate_output_tables:
+            plpy.execute("""
+                CREATE TABLE {output_table} AS
+                SELECT *, 0 AS split FROM {test_table}
+                UNION ALL
+                SELECT *, 1 AS split FROM {train_table}
+                """.format(output_table=output_table,
+                           test_table=test_table,
+                           train_table=train_table))
+            clean_up_tables += [train_table, test_table]
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(",".join(clean_up_tables)))
+    return
+
+
+def validate_strs(source_table, output_table, train_proportion, test_proportion,
+                  glist, target_cols, with_replacement,
+                  separate_output_tables=False):
+    """
+    Validate the arguments of test_train_split.
+
+    Args:
+        @param source_table           Input table name.
+        @param output_table           Output table name (the name prefix when
+                                      separate_output_tables is TRUE).
+        @param train_proportion       Training proportion, must be in (0,1).
+        @param test_proportion        Test proportion, must be in (0,1).
+        @param glist                  List of grouping column names (or None).
+        @param target_cols            Comma-separated output columns, NULL or '*'.
+        @param with_replacement       Sampling method. Kept for interface
+                                      compatibility. The proportion-sum check
+                                      applies to both methods.
+        @param separate_output_tables Whether <output_table>_train and
+                                      <output_table>_test will be created.
+    """
+    _assert(output_table and output_table.strip().lower() not in ('null', ''),
+            "Sample: Invalid output table name {output_table}!".format(**locals()))
+    if separate_output_tables:
+        # These are the tables test_train_split will actually create.
+        for out_name in (add_postfix(output_table, "_train"),
+                         add_postfix(output_table, "_test")):
+            _assert(not table_exists(out_name),
+                    "Sample: Output table ({0}) already exists!".format(out_name))
+    else:
+        _assert(not table_exists(output_table),
+                "Sample: Output table already exists!".format(**locals()))
+
+    _assert(source_table and source_table.strip().lower() not in ('null', ''),
+            "Sample: Invalid Source table name!".format(**locals()))
+    _assert(table_exists(source_table),
+            "Sample: Source table ({source_table}) is missing!".format(**locals()))
+    _assert(not table_is_empty(source_table),
+            "Sample: Source table ({source_table}) is empty!".format(**locals()))
+
+    for proportion in (train_proportion, test_proportion):
+        _assert(proportion is not None and 0 < proportion < 1,
+                "Sample: Proportions aren't in the range (0,1)!")
+    # Both sets come from one stratified_sample call with proportion
+    # train + test. stratified_sample requires that to be <= 1 for either
+    # sampling method, so reject it here with a clearer message.
+    _assert(train_proportion + test_proportion <= 1,
+            "Sample: Proportions add up to greater than 1!")
+
+    if glist is not None:
+        _assert(columns_exist_in_table(source_table, glist),
+                ("Sample: Not all columns from {glist} are present in source"
+                 " table ({source_table}).").format(**locals()))
+
+    # Value comparison: 'is' on strings tests identity, not equality.
+    if target_cols is not None and target_cols != '*':
+        tlist = split_quoted_delimited_str(target_cols)
+        _assert(columns_exist_in_table(source_table, tlist),
+                ("Sample: Not all columns from {target_cols} are present in"
+                 " source table ({source_table})").format(**locals()))
+    return
+
+
+def test_train_split_help(schema_madlib, message, **kwargs):
+    """
+    Help function for test_train_split
+
+    Args:
+        @param schema_madlib
+        @param message: string, Help message string
+        @param kwargs
+
+    Returns:
+        String. Help/usage information
+    """
+    if not message:
+        help_string = """
+-----------------------------------------------------------------------
+                            SUMMARY
+-----------------------------------------------------------------------
+
+Given a table, test_train_split returns a random sample of the
+table for testing and training. It is possible to use with or without
+replacement sampling methods, specify a set of target columns, and a
+set of grouping columns, in which case, stratified sampling will be
+performed.
+
+For more details on function usage:
+    SELECT {schema_madlib}.test_train_split('usage');
+    SELECT {schema_madlib}.test_train_split('example');
+        """
+    elif message.lower() in ['usage', 'help', '?']:
+        help_string = """
+
+Given a table, test train split returns a proportion of records for
+each group (strata). It is possible to use with or without replacement
+sampling methods, specify a set of target columns, and assume the
+whole table is a single strata.
+
+----------------------------------------------------------------------------
+                            USAGE
+----------------------------------------------------------------------------
+
+ SELECT {schema_madlib}.test_train_split(
+    source_table TEXT,        -- Name of the table containing the input data.
+    output_table TEXT,        -- Output table name.
+    train_proportion FLOAT8,  -- The ratio of train sample size to the
+                              -- number of records.
+    test_proportion FLOAT8,   -- The ratio of test sample size to the
+                              -- number of records.
+    grouping_cols TEXT,       -- (Default: NULL) The columns to distinguish
+                              -- each strata.
+    target_cols TEXT,         -- (Default: NULL) The columns to include in
+                              -- the output.
+    with_replacement BOOLEAN, -- (Default: FALSE) The sampling method.
+    separate_output_tables
+        BOOLEAN               -- (Default: FALSE) Separate the output table
+                              -- into $output_table$_train and
+                              -- $output_table$_test, otherwise, the split
+                              -- column in output_table will identify 1 for
+                              -- train set and 0 for test set.
+ );
+
+If test_proportion is NULL, it defaults to 1 - train_proportion. The sum of
+train_proportion and test_proportion must not exceed 1.
+
+If grouping_cols is NULL, the whole table is treated as a single group and
+sampled accordingly.
+
+If target_cols is NULL or '*', all of the columns will be included in the
+output table.
+
+If with_replacement is TRUE, each sample is independent (the same row may
+be selected in the sample set more than once). Else (if with_replacement
+is FALSE), a row can be selected at most once.
+"""
+    elif message.lower() in ("example", "examples"):
+        help_string = """
+----------------------------------------------------------------------------
+                                EXAMPLES
+----------------------------------------------------------------------------
+
+-- Create an input table
+DROP TABLE IF EXISTS test;
+
+CREATE TABLE test(
+    id1 INTEGER,
+    id2 INTEGER,
+    gr1 INTEGER,
+    gr2 INTEGER
+);
+
+INSERT INTO test VALUES
+(1,0,1,1),
+(2,0,1,1),
+(3,0,1,1),
+(4,0,1,1),
+(5,0,1,1),
+(6,0,1,1),
+(7,0,1,1),
+(8,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(0,1,1,2),
+(0,2,1,2),
+(0,3,1,2),
+(0,4,1,2),
+(0,5,1,2),
+(0,6,1,2),
+(10,10,2,2),
+(20,20,2,2),
+(30,30,2,2),
+(40,40,2,2),
+(50,50,2,2),
+(60,60,2,2),
+(70,70,2,2)
+;
+
+-- Sample without replacement
+DROP TABLE IF EXISTS out;
+SELECT madlib.test_train_split(
+    'test',    -- Source table
+    'out',     -- Output table
+    0.5,       -- train_proportion
+    0.5,       -- test_proportion
+    'gr1,gr2', -- Strata definition
+    'id1,id2', -- Columns to output
+    FALSE,     -- Sample without replacement
+    FALSE);    -- Do not separate output tables
+SELECT * FROM out ORDER BY split,gr1,gr2,id1,id2;
+
+-- Sample with replacement
+DROP TABLE IF EXISTS out_train, out_test;
+SELECT madlib.test_train_split(
+    'test',    -- Source table
+    'out',     -- Output table
+    0.5,       -- train_proportion
+    NULL,      -- Default = 1 - train_proportion = 0.5
+    'gr1,gr2', -- Strata definition
+    'id1,id2', -- Columns to output
+    TRUE,      -- Sample with replacement
+    TRUE);     -- Separate output tables
+SELECT * FROM out_train ORDER BY gr1,gr2,id1,id2;
+"""
+    else:
+        help_string = "No such option. Use {schema_madlib}.test_train_split()"
+
+    return help_string.format(schema_madlib=schema_madlib)
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/69f78866/src/ports/postgres/modules/sample/test_train_split.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/test_train_split.sql_in b/src/ports/postgres/modules/sample/test_train_split.sql_in
new file mode 100644
index 0000000..ba1adb3
--- /dev/null
+++ b/src/ports/postgres/modules/sample/test_train_split.sql_in
@@ -0,0 +1,319 @@
+/* ----------------------------------------------------------------------- *//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *
+ * @file test_train_split.sql_in
+ *
+ * @brief SQL functions for test train split.
+ * @date 07/19/2017
+ *
+ * @sa Given a table, test train split returns a proportion of records
+ * for each group (strata).
+ *
+ *//* ----------------------------------------------------------------------- */
+
+m4_include(`SQLCommon.m4')
+
+
+/**
+@addtogroup grp_test_train_split
+
+<div class="toc"><b>Contents</b>
+<ul>
+<li><a href="#strs">test train split</a></li>
+<li><a href="#examples">Examples</a></li>
+</ul>
+</div>
+
+@brief A method for splitting a table into training and test sets, optionally stratified by group.
+
+test_train_split is a utility to create test and
+training data sets as subsets of a single table.
+
+@anchor strs
+@par test train split
+
+<pre class="syntax">
+test_train_split( source_table,
+                  output_table,
+                  train_proportion,
+                  test_proportion,
+                  grouping_cols,
+                  target_cols,
+                  with_replacement,
+                  separate_output_tables
+                )
+</pre>
+
+\b Arguments
+<dl class="arglist">
+<dt>source_table</dt>
+<dd>TEXT. Name of the table containing the input data.</dd>
+
+<dt>output_table</dt>
+<dd>TEXT. Name of output table. A new INTEGER column on the right
+called 'split' will identify 1 for train set and 0 for test set,
+unless the 'separate_output_tables' parameter below is TRUE,
+in which case two output tables will be created using
+the 'output_table' name with the suffixes '_train' and '_test'.
+The output table contains all the columns present in the source
+table unless otherwise specified in the 'target_cols' parameter below. </dd>
+
+<dt>train_proportion</dt>
+<dd>FLOAT8 in the range (0,1). Proportion of the dataset to include
+in the train split. If the 'grouping_cols' parameter is specified below,
+each group will be sampled independently using the
+train proportion, i.e., in a stratified fashion.</dd>
+
+<dt>test_proportion</dt>
+<dd>FLOAT8 in the range (0,1). Proportion of the dataset to include
+in the test split. If NULL, it defaults to the complement of the train
+proportion (1-'train_proportion'). The sum of 'train_proportion' and
+'test_proportion' must not exceed 1. If the 'grouping_cols'
+parameter is specified below, each group will be sampled
+independently using the test proportion,
+i.e., in a stratified fashion.</dd>
+
+<dt>grouping_cols (optional)</dt>
+<dd>TEXT, default: NULL.
A single column or a list of comma-separated columns
+ that defines how to stratify. When this parameter is NULL,
+the train-test split is not stratified.</dd>
+
+<dt>target_cols (optional)</dt>
+<dd>TEXT, default NULL. A comma-separated list of columns
+to appear in the 'output_table'. If NULL or '*', all
+columns from the 'source_table' will appear in
+the 'output_table'.
+
+@anchor note
+@note
+    Do not include 'grouping_cols' in the parameter 'target_cols',
+    because they are always included in the 'output_table'.</dd>
+
+<dt>with_replacement (optional)</dt>
+<dd>BOOLEAN, default FALSE. Determines whether to sample
+with replacement or without replacement (default).
+With replacement means that it is possible that the
+same row may appear in the sample set more than once,
+including in both the train and the test set.
+Without replacement means a given row can be selected
+only once.</dd>
+
+<dt>separate_output_tables (optional)</dt>
+<dd>BOOLEAN, default FALSE. If TRUE, two output tables will be created using
+the 'output_table' name with the suffixes '_train' and '_test'.</dd>
+</dl>
+
+
+@anchor examples
+@par Examples
+
+Please note that due to the random nature of sampling, your
+results may look different from those below.
+ +-# Create an input table: +<pre class="syntax"> +DROP TABLE IF EXISTS test; +CREATE TABLE test( + id1 INTEGER, + id2 INTEGER, + gr1 INTEGER, + gr2 INTEGER +); +INSERT INTO test VALUES +(1,0,1,1), +(2,0,1,1), +(3,0,1,1), +(4,0,1,1), +(5,0,1,1), +(6,0,1,1), +(7,0,1,1), +(8,0,1,1), +(9,0,1,1), +(9,0,1,1), +(9,0,1,1), +(9,0,1,1), +(0,1,1,2), +(0,2,1,2), +(0,3,1,2), +(0,4,1,2), +(0,5,1,2), +(0,6,1,2), +(10,10,2,2), +(20,20,2,2), +(30,30,2,2), +(40,40,2,2), +(50,50,2,2), +(60,60,2,2), +(70,70,2,2); +</pre> + +-# Sample without replacement: +<pre class="syntax"> +DROP TABLE IF EXISTS out; +SELECT madlib.test_train_split( + 'test', -- Source table + 'out', -- Output table + 0.5, -- Sample proportion + 0.5, -- Sample proportion + 'gr1,gr2', -- Strata definition + 'id1,id2', -- Columns to output + FALSE, -- Sample without replacement + FALSE); -- Do not separate output tables +SELECT * FROM out ORDER BY split,gr1,gr2,id1,id2; +</pre> +<pre class="result"> + gr1 | gr2 | id1 | id2 | split +-----+-----+-----+-----+------- + 1 | 1 | 1 | 0 | 0 + 1 | 1 | 4 | 0 | 0 + 1 | 1 | 6 | 0 | 0 + 1 | 1 | 9 | 0 | 0 + 1 | 1 | 9 | 0 | 0 + 1 | 1 | 9 | 0 | 0 + 1 | 2 | 0 | 3 | 0 + 1 | 2 | 0 | 4 | 0 + 1 | 2 | 0 | 5 | 0 + 2 | 2 | 10 | 10 | 0 + 2 | 2 | 30 | 30 | 0 + 2 | 2 | 40 | 40 | 0 + 2 | 2 | 60 | 60 | 0 + 1 | 1 | 2 | 0 | 1 + 1 | 1 | 3 | 0 | 1 + 1 | 1 | 5 | 0 | 1 + 1 | 1 | 7 | 0 | 1 + 1 | 1 | 8 | 0 | 1 + 1 | 1 | 9 | 0 | 1 + 1 | 2 | 0 | 1 | 1 + 1 | 2 | 0 | 2 | 1 + 1 | 2 | 0 | 6 | 1 + 2 | 2 | 20 | 20 | 1 + 2 | 2 | 50 | 50 | 1 + 2 | 2 | 70 | 70 | 1 +(25 rows) +</pre> + +-# Sample with replacement: +<pre class="syntax"> +DROP TABLE IF EXISTS out_train, out_test; +SELECT madlib.test_train_split( + 'test', -- Source table + 'out', -- Output table + 0.5, -- train_proportion + NULL, -- Default = 1 - train_proportion = 0.5 + 'gr1,gr2', -- Strata definition + 'id1,id2', -- Columns to output + TRUE, -- Sample with replacement + TRUE); -- Separate output tables +SELECT * FROM out_train ORDER BY 
gr1,gr2,id1,id2;
+</pre>
+<pre class="result">
+ gr1 | gr2 | id1 | id2
+-----+-----+-----+-----
+   1 |   1 |   1 |   0
+   1 |   1 |   2 |   0
+   1 |   1 |   4 |   0
+   1 |   1 |   7 |   0
+   1 |   1 |   8 |   0
+   1 |   1 |   9 |   0
+   1 |   2 |   0 |   4
+   1 |   2 |   0 |   5
+   1 |   2 |   0 |   6
+   2 |   2 |  40 |  40
+   2 |   2 |  50 |  50
+   2 |   2 |  50 |  50
+(12 rows)
+</pre>
+*/
+
+-- Main entry point: all arguments explicit; implemented in test_train_split.py_in.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    source_table            TEXT,
+    output_table            TEXT,
+    train_proportion        FLOAT8,
+    test_proportion         FLOAT8,
+    grouping_cols           TEXT,
+    target_cols             TEXT,
+    with_replacement        BOOLEAN,
+    separate_output_tables  BOOLEAN
+) RETURNS VOID AS $$
+    PythonFunction(sample, test_train_split, test_train_split)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-------------------------------------------------------------------------------
+
+-- Convenience overloads: each fills in the last omitted argument with its
+-- default and delegates to the next-longer signature.
+
+-- separate_output_tables defaults to FALSE (single table with 'split' column).
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    source_table            TEXT,
+    output_table            TEXT,
+    train_proportion        FLOAT8,
+    test_proportion         FLOAT8,
+    grouping_cols           TEXT,
+    target_cols             TEXT,
+    with_replacement        BOOLEAN
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.test_train_split($1, $2, $3, $4, $5, $6, $7, FALSE);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- with_replacement defaults to FALSE.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    source_table            TEXT,
+    output_table            TEXT,
+    train_proportion        FLOAT8,
+    test_proportion         FLOAT8,
+    grouping_cols           TEXT,
+    target_cols             TEXT
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.test_train_split($1, $2, $3, $4, $5, $6, FALSE);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- target_cols defaults to NULL (all columns).
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    source_table            TEXT,
+    output_table            TEXT,
+    train_proportion        FLOAT8,
+    test_proportion         FLOAT8,
+    grouping_cols           TEXT
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.test_train_split($1, $2, $3, $4, $5, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- grouping_cols defaults to NULL (no stratification).
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    source_table            TEXT,
+    output_table            TEXT,
+    train_proportion        FLOAT8,
+    test_proportion         FLOAT8
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.test_train_split($1, $2, $3, $4, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- test_proportion defaults to NULL, which the Python layer turns into
+-- 1 - train_proportion.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    source_table            TEXT,
+    output_table            TEXT,
+    train_proportion        FLOAT8
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.test_train_split($1, $2, $3, NULL::FLOAT8);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-------------------------------------------------------------------------------
+
+-- Online help
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    message VARCHAR
+) RETURNS VARCHAR AS $$
+    PythonFunction(sample, test_train_split, test_train_split_help)
+$$ LANGUAGE plpythonu IMMUTABLE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
+
+-------------------------------------------------------------------------------
+
+-- No-argument form prints the summary help.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split()
+RETURNS VARCHAR AS $$
+    SELECT MADLIB_SCHEMA.test_train_split('');
+$$ LANGUAGE sql IMMUTABLE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
+-------------------------------------------------------------------------------