This is an automated email from the ASF dual-hosted git repository. zaleslaw pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push: new 40377b1 IGNITE-12903 Fixed ML + SQL examples (#7965) 40377b1 is described below commit 40377b109053d2e576ace7c612bf006aff9ef76d Author: Alexey Zinoviev <zaleslaw....@gmail.com> AuthorDate: Fri Jun 26 16:39:29 2020 +0300 IGNITE-12903 Fixed ML + SQL examples (#7965) * [IGNITE-12903] Fixed ML + SQL examples * [IGNITE-12903] Fixed ML + SQL examples --- ...eeClassificationTrainerSQLInferenceExample.java | 36 ++----- ...onTreeClassificationTrainerSQLTableExample.java | 109 ++++++++++++++------- .../selection/scoring/evaluator/package-info.java | 4 +- 3 files changed, 83 insertions(+), 66 deletions(-) diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java index ab2a00c..543e211 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java @@ -17,18 +17,14 @@ package org.apache.ignite.examples.ml.sql; -import java.util.HashSet; +import java.io.IOException; import java.util.List; import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; -import org.apache.ignite.IgniteCheckedException; import org.apache.ignite.Ignition; import org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.SqlFieldsQuery; import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.internal.IgniteEx; -import org.apache.ignite.internal.processors.query.h2.IgniteH2Indexing; -import org.apache.ignite.internal.util.IgniteUtils; import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer; import org.apache.ignite.ml.inference.IgniteModelStorageUtil; import org.apache.ignite.ml.sql.SQLFunctions; @@ -36,6 +32,8 @@ import org.apache.ignite.ml.sql.SqlDatasetBuilder; import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; import org.apache.ignite.ml.tree.DecisionTreeNode; +import static org.apache.ignite.examples.ml.sql.DecisionTreeClassificationTrainerSQLTableExample.loadTitanicDatasets; + /** * Example of using distributed {@link DecisionTreeClassificationTrainer} on a data stored in SQL table and inference * made as SQL select query. @@ -47,30 +45,15 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample { private static final String DUMMY_CACHE_NAME = "dummy_cache"; /** - * Training data. - */ - private static final String TRAIN_DATA_RES = "examples/src/main/resources/datasets/titanic_train.csv"; - - /** - * Test data. - */ - private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/titanic_test.csv"; - - /** * Run example. */ - public static void main(String[] args) throws IgniteCheckedException { + public static void main(String[] args) throws IOException { System.out.println(">>> Decision tree classification trainer example started."); // Start ignite grid. try (Ignite ignite = Ignition.start("examples/config/example-ignite-ml.xml")) { System.out.println(">>> Ignite grid started."); - // Use internal API to enable SQL functions disabled by default (the function CSVREAD is used below) - // TODO: IGNITE-12903 - ((IgniteH2Indexing)((IgniteEx)ignite).context().query().getIndexing()) - .distributedConfiguration().disabledFunctions(new HashSet<>()); - // Dummy cache is required to perform SQL queries. CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME) .setSqlSchema("PUBLIC") @@ -83,8 +66,8 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample { System.out.println(">>> Creating table with training data..."); cache.query(new SqlFieldsQuery("create table titanic_train (\n" + " passengerid int primary key,\n" + - " survived int,\n" + " pclass int,\n" + + " survived int,\n" + " name varchar(255),\n" + " sex varchar(255),\n" + " age float,\n" + @@ -96,14 +79,11 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample { " embarked varchar(255)\n" + ") with \"template=partitioned\";")).getAll(); - System.out.println(">>> Filling training data..."); - cache.query(new SqlFieldsQuery("insert into titanic_train select * from csvread('" + - IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll(); - System.out.println(">>> Creating table with test data..."); cache.query(new SqlFieldsQuery("create table titanic_test (\n" + " passengerid int primary key,\n" + " pclass int,\n" + + " survived int,\n" + " name varchar(255),\n" + " sex varchar(255),\n" + " age float,\n" + @@ -115,9 +95,7 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample { " embarked varchar(255)\n" + ") with \"template=partitioned\";")).getAll(); - System.out.println(">>> Filling training data..."); - cache.query(new SqlFieldsQuery("insert into titanic_test select * from csvread('" + - IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll(); + loadTitanicDatasets(ignite, cache); System.out.println(">>> Prepare trainer..."); DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0); diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java index 5fe123c..083608e 100644 --- a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java +++ b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java @@ -17,8 +17,9 @@ package org.apache.ignite.examples.ml.sql; -import java.util.HashSet; +import java.io.IOException; import java.util.List; + import org.apache.ignite.Ignite; import org.apache.ignite.IgniteCache; import org.apache.ignite.IgniteCheckedException; @@ -26,9 +27,8 @@ import org.apache.ignite.Ignition; import org.apache.ignite.cache.query.QueryCursor; import org.apache.ignite.cache.query.SqlFieldsQuery; import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.internal.IgniteEx; -import org.apache.ignite.internal.processors.query.h2.IgniteH2Indexing; -import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.examples.ml.util.MLSandboxDatasets; +import org.apache.ignite.examples.ml.util.SandboxMLCache; import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer; import org.apache.ignite.ml.math.primitives.vector.Vector; import org.apache.ignite.ml.math.primitives.vector.VectorUtils; @@ -46,30 +46,15 @@ public class DecisionTreeClassificationTrainerSQLTableExample { private static final String DUMMY_CACHE_NAME = "dummy_cache"; /** - * Training data. - */ - private static final String TRAIN_DATA_RES = "examples/src/main/resources/datasets/titanic_train.csv"; - - /** - * Test data. - */ - private static final String TEST_DATA_RES = "examples/src/main/resources/datasets/titanic_test.csv"; - - /** * Run example. */ - public static void main(String[] args) throws IgniteCheckedException { + public static void main(String[] args) throws IgniteCheckedException, IOException { System.out.println(">>> Decision tree classification trainer example started."); // Start ignite grid. try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { System.out.println(">>> Ignite grid started."); - // Use internal API to enable SQL functions disabled by default (the function CSVREAD is used below) - // TODO: IGNITE-12903 - ((IgniteH2Indexing)((IgniteEx)ignite).context().query().getIndexing()) - .distributedConfiguration().disabledFunctions(new HashSet<>()); - // Dummy cache is required to perform SQL queries. CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME) .setSqlSchema("PUBLIC"); @@ -81,8 +66,8 @@ public class DecisionTreeClassificationTrainerSQLTableExample { System.out.println(">>> Creating table with training data..."); cache.query(new SqlFieldsQuery("create table titanic_train (\n" + " passengerid int primary key,\n" + - " survived int,\n" + " pclass int,\n" + + " survived int,\n" + " name varchar(255),\n" + " sex varchar(255),\n" + " age float,\n" + @@ -94,14 +79,11 @@ public class DecisionTreeClassificationTrainerSQLTableExample { " embarked varchar(255)\n" + ") with \"template=partitioned\";")).getAll(); - System.out.println(">>> Filling training data..."); - cache.query(new SqlFieldsQuery("insert into titanic_train select * from csvread('" + - IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll(); - System.out.println(">>> Creating table with test data..."); cache.query(new SqlFieldsQuery("create table titanic_test (\n" + " passengerid int primary key,\n" + " pclass int,\n" + + " survived int,\n" + " name varchar(255),\n" + " sex varchar(255),\n" + " age float,\n" + @@ -113,9 +95,7 @@ public class DecisionTreeClassificationTrainerSQLTableExample { " embarked varchar(255)\n" + ") with \"template=partitioned\";")).getAll(); - System.out.println(">>> Filling training data..."); - cache.query(new SqlFieldsQuery("insert into titanic_test select * from csvread('" + - IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll(); + loadTitanicDatasets(ignite, cache); System.out.println(">>> Prepare trainer..."); DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0); @@ -128,6 +108,8 @@ public class DecisionTreeClassificationTrainerSQLTableExample { .labeled("survived") ); + System.out.println("Tree is here: " + mdl.toString(true)); + System.out.println(">>> Perform inference..."); try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " + "pclass, " + @@ -137,13 +119,13 @@ public class DecisionTreeClassificationTrainerSQLTableExample { "parch, " + "fare from titanic_test"))) { for (List<?> passenger : cursor) { - Vector input = VectorUtils.of(new Double[] { + Vector input = VectorUtils.of(new Double[]{ asDouble(passenger.get(0)), "male".equals(passenger.get(1)) ? 1.0 : 0.0, asDouble(passenger.get(2)), asDouble(passenger.get(3)), asDouble(passenger.get(4)), - asDouble(passenger.get(5)) + asDouble(passenger.get(5)), }); double prediction = mdl.predict(input); @@ -153,14 +135,12 @@ public class DecisionTreeClassificationTrainerSQLTableExample { } System.out.println(">>> Example completed."); - } - finally { + } finally { cache.query(new SqlFieldsQuery("DROP TABLE titanic_train")); cache.query(new SqlFieldsQuery("DROP TABLE titanic_test")); cache.destroy(); } - } - finally { + } finally { System.out.flush(); } } @@ -177,11 +157,70 @@ public class DecisionTreeClassificationTrainerSQLTableExample { return null; if (obj instanceof Number) { - Number num = (Number)obj; + Number num = (Number) obj; return num.doubleValue(); } throw new IllegalArgumentException("Object is expected to be a number [obj=" + obj + "]"); } + + /** + * Loads Titanic dataset into cache. + * + * @param ignite Ignite instance. + * @throws IOException If dataset not found. + */ + static void loadTitanicDatasets(Ignite ignite, IgniteCache<?, ?> cache) throws IOException { + + List<String> titanicDatasetRows = new SandboxMLCache(ignite).loadDataset(MLSandboxDatasets.TITANIC); + List<String> train = titanicDatasetRows.subList(0, 1000); + List<String> test = titanicDatasetRows.subList(1000, titanicDatasetRows.size()); + + insertToCache(cache, train, "titanic_train"); + insertToCache(cache, test, "titanic_test"); + } + + /** */ + private static void insertToCache(IgniteCache<?, ?> cache, List<String> train, String tableName) { + SqlFieldsQuery insertTrain = new SqlFieldsQuery("insert into " + tableName + " " + + "(passengerid, pclass, survived, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked) " + + "values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"); + + int seq = 0; + for (String s : train) { + String[] line = s.split(";"); + int pclass = parseInteger(line[0]); + int survived = parseInteger(line[1]); + String name = line[2]; + String sex = line[3]; + double age = parseDouble(line[4]); + double sibsp = parseInteger(line[5]); + double parch = parseInteger(line[6]); + String ticket = line[7]; + double fare = parseDouble(line[8]); + String cabin = line[9]; + String embarked = line[10]; + insertTrain.setArgs(seq++, pclass, survived, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked); + cache.query(insertTrain); + } + } + + /** */ + private static Integer parseInteger(String value) { + try { + return Integer.valueOf(value); + } catch (NumberFormatException e) { + return 0; + } + } + + /** */ + private static Double parseDouble(String value) { + try { + return Double.valueOf(value); + } catch (NumberFormatException e) { + return 0.0; + } + } } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java index c5cdf08..f74a607 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/evaluator/package-info.java @@ -16,7 +16,7 @@ */ /** - * <!-- Package description. --> Package for model evaluator classes. + * <!-- Package description. --> + * Package for model evaluator classes. */ - package org.apache.ignite.ml.selection.scoring.evaluator;