This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d5d1d7a  Add BERT QA Scala/Java example (#14592)
d5d1d7a is described below

commit d5d1d7ac417ed75b4c2c898d653a210edec046b9
Author: Lanking <lanking...@live.com>
AuthorDate: Fri Apr 5 09:26:25 2019 -0700

    Add BERT QA Scala/Java example (#14592)
    
    * add BertQA major code piece
    
    * add scripts and bug fixes
    
    * add integration test
    
    * address comments
    
    * address doc comments
---
 .../scala/org/apache/mxnet/javaapi/Layout.scala    |  34 +++++
 scala-package/examples/pom.xml                     |   5 +
 .../examples/scripts/infer/bert/get_bert_data.sh   |  31 +++++
 .../scripts/infer/bert/run_bert_qa_example.sh      |  27 ++++
 .../javaapi/infer/bert/BertDataParser.java         | 126 ++++++++++++++++++
 .../mxnetexamples/javaapi/infer/bert/BertQA.java   | 148 +++++++++++++++++++++
 .../mxnetexamples/javaapi/infer/bert/README.md     | 103 ++++++++++++++
 .../javaapi/infer/predictor/BertExampleTest.java   |  71 ++++++++++
 8 files changed, 545 insertions(+)

diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala
new file mode 100644
index 0000000..cfe290c
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Layout.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mxnet.javaapi
+
+/**
+  * Layout definition of DataDesc
+  * N Batch size
+  * C channels
+  * H Height
+  * W Width
+  * T sequence length
+  * __undefined__ default value of Layout
+  */
+object Layout {
+  val UNDEFINED: String = org.apache.mxnet.Layout.UNDEFINED
+  val NCHW: String = org.apache.mxnet.Layout.NCHW
+  val NTC: String = org.apache.mxnet.Layout.NTC
+  val NT: String = org.apache.mxnet.Layout.NT
+  val N: String = org.apache.mxnet.Layout.N
+}
diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml
index 07e4301..d60782f 100644
--- a/scala-package/examples/pom.xml
+++ b/scala-package/examples/pom.xml
@@ -145,5 +145,10 @@
       <artifactId>slf4j-simple</artifactId>
       <version>1.7.5</version>
     </dependency>
+    <dependency>
+      <groupId>com.google.code.gson</groupId>
+      <artifactId>gson</artifactId>
+      <version>2.8.5</version>
+    </dependency>
   </dependencies>
 </project>
diff --git a/scala-package/examples/scripts/infer/bert/get_bert_data.sh 
b/scala-package/examples/scripts/infer/bert/get_bert_data.sh
new file mode 100755
index 0000000..609aae2
--- /dev/null
+++ b/scala-package/examples/scripts/infer/bert/get_bert_data.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+# Resolve the directory three levels above this script; "$0" is quoted so spaces survive.
+MXNET_ROOT=$(cd "$(dirname "$0")/../../.."; pwd)
+# Target folder for the model files and the vocabulary.
+data_path=$MXNET_ROOT/scripts/infer/models/static-bert-qa
+# -f makes curl fail on an HTTP error instead of saving the error page as the artifact.
+if [ ! -d "$data_path" ]; then
+  mkdir -p "$data_path"
+  curl -f https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json -o "$data_path/vocab.json"
+  curl -f https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params -o "$data_path/static_bert_qa-0002.params"
+  curl -f https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json -o "$data_path/static_bert_qa-symbol.json"
+fi
diff --git a/scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh 
b/scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh
new file mode 100755
index 0000000..d8ba092
--- /dev/null
+++ b/scala-package/examples/scripts/infer/bert/run_bert_qa_example.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+# Resolve the directory five levels above this script; "$0" is quoted so spaces survive.
+MXNET_ROOT=$(cd "$(dirname "$0")/../../../../.."; pwd)
+# The '*' entries are class path wildcards for the JVM, so they must reach java unexpanded.
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*
+# "$@" forwards each argument as one word, so values containing spaces stay intact.
+java -Xmx8G -Dmxnet.traceLeakedObjects=true -cp "$CLASS_PATH" \
+       org.apache.mxnetexamples.javaapi.infer.bert.BertQA "$@"
diff --git 
a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java
 
b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java
new file mode 100644
index 0000000..440670a
--- /dev/null
+++ 
b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertDataParser.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.infer.bert;
+
+import java.io.FileReader;
+import java.util.*;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonArray;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+
+/**
+ * This is the Utility for pre-processing the data for Bert Model
+ * You can use this utility to parse Vocabulary JSON into Java Array and 
Dictionary,
+ * clean and tokenize sentences and pad the text
+ */
+public class BertDataParser {
+
+    private Map<String, Integer> token2idx;
+    private List<String> idx2token;
+
+    /**
+     * Parse the "idx_to_token" array and "token_to_idx" object of the vocabulary JSON file
+     * [PAD], [CLS], [SEP], [MASK], [UNK] are reserved tokens
+     * @param jsonFile the filePath of the vocab.json
+     * @throws Exception if the file cannot be read or lacks the expected entries
+     */
+    void parseJSON(String jsonFile) throws Exception {
+        token2idx = new HashMap<>();
+        idx2token = new ArrayList<>(); // read by position in idx2token(), so use random access
+        try (FileReader reader = new FileReader(jsonFile)) { // closed even when parsing fails
+            JsonObject jsonObject = new Gson().fromJson(reader, JsonObject.class);
+            for (JsonElement element : jsonObject.getAsJsonArray("idx_to_token")) {
+                idx2token.add(element.getAsString());
+            }
+            JsonObject preMap = jsonObject.getAsJsonObject("token_to_idx");
+            for (Map.Entry<String, JsonElement> entry : preMap.entrySet()) {
+                token2idx.put(entry.getKey(), entry.getValue().getAsInt());
+            }
+        }
+    }
+
+    /**
+     * Tokenize the input, split all kinds of whitespace and
+     * separate one trailing end of sentence symbol (. , ? !) from each word
+     * @param input The input string
+     * @return List of tokens
+     */
+    List<String> tokenizer(String input) {
+        String[] step1 = input.split("\\s+");
+        List<String> finalResult = new ArrayList<>();
+        for (String item : step1) {
+            if (item.length() != 0) {
+                // only a symbol in the last position is split off; a lone symbol stays whole
+                if (item.length() > 1 && ".,?!".indexOf(item.charAt(item.length() - 1)) >= 0) {
+                    finalResult.add(item.substring(0, item.length() - 1));
+                    finalResult.add(item.substring(item.length() - 1));
+                } else {
+                    finalResult.add(item);
+                }
+            }
+        }
+        return finalResult;
+    }
+
+    /**
+     * Pad the tokens to the required length; a longer input is returned unshortened
+     * @param tokens input tokens
+     * @param padItem things to pad at the end
+     * @param num total length after padding
+     * @return a new List of padded tokens, the input list is never modified or returned
+     */
+    <E> List<E> pad(List<E> tokens, E padItem, int num) {
+        // Always copy, so callers get the same aliasing behaviour whether or not padding occurs.
+        List<E> padded = new ArrayList<>(tokens);
+        for (int i = tokens.size(); i < num; i++) {
+            padded.add(padItem);
+        }
+        return padded;
+    }
+
+    /**
+     * Look up the vocabulary index of every token
+     * @param tokens input tokens
+     * @return List of indexes
+     */
+    List<Integer> token2idx(List<String> tokens) {
+        List<Integer> result = new ArrayList<>();
+        Iterator<String> it = tokens.iterator();
+        while (it.hasNext()) {
+            String token = it.next();
+            // tokens without a vocabulary entry are looked up under [UNK] instead
+            String key = token2idx.containsKey(token) ? token : "[UNK]";
+            result.add(token2idx.get(key));
+        }
+        return result;
+    }
+
+    /**
+     * Translate vocabulary indexes back into their tokens
+     * @param indexes List of indexes
+     * @return List of tokens
+     */
+    List<String> idx2token(List<Integer> indexes) {
+        List<String> result = new ArrayList<>();
+        for (Iterator<Integer> it = indexes.iterator(); it.hasNext(); ) {
+            result.add(idx2token.get(it.next()));
+        }
+        return result;
+    }
+}
diff --git 
a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
 
b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
new file mode 100644
index 0000000..b40a4e9
--- /dev/null
+++ 
b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.infer.bert;
+
+import org.apache.mxnet.infer.javaapi.Predictor;
+import org.apache.mxnet.javaapi.*;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+
+/**
+ * This is an example of using BERT to do the general Question and Answer inference jobs
+ * Users can provide a question with a paragraph that contains the answer to the model and
+ * the model will be able to find the best answer from the answer paragraph
+ */
+public class BertQA {
+    @Option(name = "--model-path-prefix", usage = "input model directory and 
prefix of the model")
+    private String modelPathPrefix = "/model/static_bert_qa";
+    @Option(name = "--model-epoch", usage = "Epoch number of the model")
+    private int epoch = 2;
+    @Option(name = "--model-vocab", usage = "the vocabulary used in the model")
+    private String modelVocab = "/model/vocab.json";
+    @Option(name = "--input-question", usage = "the input question")
+    private String inputQ = "When did BBC Japan start broadcasting?";
+    @Option(name = "--input-answer", usage = "the input answer")
+    private String inputA =
+        "BBC Japan was a general entertainment Channel.\n" +
+                " Which operated between December 2004 and April 2006.\n" +
+            "It ceased operations after its Japanese distributor folded.";
+    @Option(name = "--seq-length", usage = "the maximum length of the 
sequence")
+    private int seqLength = 384;
+
+    private final static Logger logger = LoggerFactory.getLogger(BertQA.class);
+    private static NDArray$ NDArray = NDArray$.MODULE$;
+
+    private static int argmax(float[] prob) {
+        int best = 0; // index 0 is the initial candidate; ties keep the earliest index
+        for (int i = 1; i < prob.length; i++) {
+            best = prob[i] > prob[best] ? i : best;
+        }
+        return best;
+    }
+
+    /**
+     * Do the post processing on the output, apply softmax to get the probabilities
+     * reshape and get the most probable start and end index
+     * @param result prediction result
+     * @param tokens word tokens
+     * @return Answers clipped from the original paragraph, empty when end precedes start
+     */
+    static List<String> postProcessing(NDArray result, List<String> tokens) {
+        NDArray[] output = NDArray.split(NDArray.new splitParam(result, 2).setAxis(2));
+        // Get the formatted logits result
+        NDArray startLogits = output[0].reshape(new int[]{0, -3});
+        NDArray endLogits = output[1].reshape(new int[]{0, -3});
+        // Get Probability distribution
+        float[] startProb = NDArray.softmax(NDArray.new softmaxParam(startLogits))[0].toArray();
+        float[] endProb = NDArray.softmax(NDArray.new softmaxParam(endLogits))[0].toArray();
+        int startIdx = argmax(startProb);
+        int endIdx = argmax(endProb);
+        // Start and end are picked independently, so the end can land before the start;
+        // subList throws IllegalArgumentException when its end index lies before its start
+        if (endIdx < startIdx) return Collections.emptyList();
+        return tokens.subList(startIdx, endIdx + 1);
+    }
+
+    /** Parse the options, build the three model inputs, run the predictor and log the answer */
+    public static void main(String[] args) throws Exception{
+        BertQA inst = new BertQA();
+        CmdLineParser parser = new CmdLineParser(inst);
+        parser.parseArgument(args);
+        BertDataParser util = new BertDataParser();
+        Context context = Context.cpu();
+        if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+                Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
+            context = Context.gpu();
+        }
+        // pre-processing - tokenize sentence
+        List<String> tokenQ = util.tokenizer(inst.inputQ.toLowerCase());
+        List<String> tokenA = util.tokenizer(inst.inputA.toLowerCase());
+        // make BERT pre-processing standard: [CLS] question [SEP] answer [SEP]
+        tokenQ.add("[SEP]");
+        tokenQ.add(0, "[CLS]");
+        tokenA.add("[SEP]");
+        // generate token types [0000...1111....0000]; built on one list so nothing is discarded
+        List<Float> QAEmbedded = util.pad(new ArrayList<Float>(), 0f, tokenQ.size());
+        QAEmbedded.addAll(util.pad(new ArrayList<Float>(), 1f, tokenA.size()));
+        List<Float> tokenTypes = util.pad(QAEmbedded, 0f, inst.seqLength);
+        tokenQ.addAll(tokenA);
+        // the valid length covers every token before the [PAD] tail, special tokens included
+        int validLength = tokenQ.size();
+        logger.info("Valid length: " + validLength);
+        List<String> tokens = util.pad(tokenQ, "[PAD]", inst.seqLength);
+        logger.info("Pre-processed tokens: " + Arrays.toString(tokenQ.toArray()));
+        // pre-processing - token to index translation
+        util.parseJSON(inst.modelVocab);
+        List<Integer> indexes = util.token2idx(tokens);
+        List<Float> indexesFloat = new ArrayList<>();
+        for (int integer : indexes) {
+            indexesFloat.add((float) integer);
+        }
+        // Preparing the input data
+        List<NDArray> inputBatch = Arrays.asList(
+                new NDArray(indexesFloat,
+                        new Shape(new int[]{1, inst.seqLength}), context),
+                new NDArray(tokenTypes,
+                        new Shape(new int[]{1, inst.seqLength}), context),
+                new NDArray(new float[] { validLength },
+                        new Shape(new int[]{1}), context)
+        );
+        // Build the model
+        List<Context> contexts = new ArrayList<>();
+        contexts.add(context);
+        List<DataDesc> inputDescs = Arrays.asList(
+                new DataDesc("data0",
+                        new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
+                new DataDesc("data1",
+                        new Shape(new int[]{1, inst.seqLength}), DType.Float32(), Layout.NT()),
+                new DataDesc("data2",
+                        new Shape(new int[]{1}), DType.Float32(), Layout.N())
+        );
+        Predictor bertQA = new Predictor(inst.modelPathPrefix, inputDescs, contexts, inst.epoch);
+        // Start prediction
+        NDArray result = bertQA.predictWithNDArray(inputBatch).get(0);
+        List<String> answer = postProcessing(result, tokens);
+        logger.info("Question: " + inst.inputQ);
+        logger.info("Answer paragraph: " + inst.inputA);
+        logger.info("Answer: " + Arrays.toString(answer.toArray()));
+    }
+}
diff --git 
a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md
 
b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md
new file mode 100644
index 0000000..7925a25
--- /dev/null
+++ 
b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/README.md
@@ -0,0 +1,103 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Run BERT QA model using Java Inference API
+
+In this tutorial, we will walk through the BERT QA model trained by MXNet. 
+Users can provide a question with a paragraph that contains the answer to the model and
+the model will be able to find the best answer from the answer paragraph.
+
+Example:
+```text
+Q: When did BBC Japan start broadcasting?
+```
+
+Answer paragraph
+```text
+BBC Japan was a general entertainment channel, which operated between December 
2004 and April 2006.
+It ceased operations after its Japanese distributor folded.
+```
+And it picked up the right one:
+```text
+A: December 2004
+```
+
+## Setup Guide
+
+### Step 1: Download the model
+
+For this tutorial, you can get the model and vocabulary by running the following bash script. This script will use `curl` to download these artifacts from AWS S3.
+
+From the `scala-package/examples/scripts/infer/bert/` folder run:
+
+```bash
+./get_bert_data.sh
+```
+
+### Step 2: Setup data path of the model
+
+### Setup Datapath and Parameters
+
+The available arguments are as follows:
+
+| Argument                      | Comments                                 |
+| ----------------------------- | ---------------------------------------- |
+| `--model-path-prefix`           | Folder path with prefix to the model 
(including json, params). |
+| `--model-vocab`                 | Vocabulary path |
+| `--model-epoch`                 | Epoch number of the model |
+| `--input-question`              | Question to ask the model |
+| `--input-answer`                | Paragraph that contains the answer |
+| `--seq-length`                  | Sequence Length of the model (384 by 
default) |
+
+### Step 3: Run Inference
+After the previous steps, you should be able to run the code using the 
following script that will pass all of the required parameters to the Infer API.
+
+From the `scala-package/examples/scripts/infer/bert/` folder run:
+
+```bash
+./run_bert_qa_example.sh --model-path-prefix 
../models/static-bert-qa/static_bert_qa \
+                         --model-vocab ../models/static-bert-qa/vocab.json \
+                         --model-epoch 2
+```
+
+## Background
+
+To learn more about how BERT works in MXNet, please follow this [MXNet Gluon 
tutorial on NLP using 
BERT](https://medium.com/apache-mxnet/gluon-nlp-bert-6a489bdd3340).
+
+The model was extracted from MXNet GluonNLP with static length settings.
+
+[Download link for the script](https://gluon-nlp.mxnet.io/_downloads/bert.zip)
+
+The original description can be found in the [MXNet GluonNLP model 
zoo](https://gluon-nlp.mxnet.io/model_zoo/bert/index.html#bert-base-on-squad-1-1).
+```bash
+python static_finetune_squad.py --optimizer adam --accumulate 2 --batch_size 6 
--lr 3e-5 --epochs 2 --gpu 0 --export
+
+```
+This script will generate `json` and `param` files that are the standard MXNet model files.
+By default, this model uses the `bert_12_768_12` model with extra layers for QA jobs.
+
+After that, to be able to use it in Java, we need to export the dictionary 
from the script to parse the text
+to actual indexes. Please add the following lines after [this 
line](https://github.com/dmlc/gluon-nlp/blob/master/scripts/bert/staticbert/static_finetune_squad.py#L262).
+```python
+import json
+json_str = vocab.to_json()
+f = open("vocab.json", "w")
+f.write(json_str)
+f.close()
+```
+This would export the token vocabulary in json format.
+Once you have these three files, you will be able to run this example without 
problems.
diff --git 
a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java
 
b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java
new file mode 100644
index 0000000..0518254
--- /dev/null
+++ 
b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/BertExampleTest.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.infer.predictor;
+
+import org.apache.mxnetexamples.Util;
+import org.apache.mxnetexamples.javaapi.infer.bert.BertQA;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+
+/**
+ * Test on BERT QA model
+ */
+public class BertExampleTest {
+    final static Logger logger = 
LoggerFactory.getLogger(BertExampleTest.class);
+    private static String modelPathPrefix = "";
+    private static String vocabPath = "";
+
+    @BeforeClass
+    public static void downloadFile() {
+        logger.info("Downloading Bert QA Model");
+        String tempDirPath = System.getProperty("java.io.tmpdir");
+        logger.info("tempDirPath: {}", tempDirPath);
+        // Fetch the symbol, params and vocabulary files into <tmpdir>/static_bert_qa
+        String baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA";
+        Util.downloadUrl(baseUrl + "/static_bert_qa-symbol.json",
+                tempDirPath + "/static_bert_qa/static_bert_qa-symbol.json", 3);
+        Util.downloadUrl(baseUrl + "/static_bert_qa-0002.params",
+                tempDirPath + "/static_bert_qa/static_bert_qa-0002.params", 3);
+        Util.downloadUrl(baseUrl + "/vocab.json",
+                tempDirPath + "/static_bert_qa/vocab.json", 3);
+        modelPathPrefix = tempDirPath + File.separator + "static_bert_qa/static_bert_qa";
+        vocabPath = tempDirPath + File.separator + "static_bert_qa/vocab.json";
+    }
+
+    @Test
+    public void testBertQA() throws Exception{
+        // main is static, so it is invoked on the class; no BertQA instance is needed
+        String Q = "When did BBC Japan start broadcasting?";
+        String A = "BBC Japan was a general entertainment Channel.\n" +
+                " Which operated between December 2004 and April 2006.\n" +
+                "It ceased operations after its Japanese distributor folded.";
+        String[] args = new String[] {
+                "--model-path-prefix", modelPathPrefix,
+                "--model-vocab", vocabPath,
+                "--model-epoch", "2",
+                "--input-question", Q,
+                "--input-answer", A,
+                "--seq-length", "384"
+        };
+        BertQA.main(args);
+    }
+}

Reply via email to