piyushghai commented on a change in pull request #14592: Add BERT QA Scala/Java example
URL: https://github.com/apache/incubator-mxnet/pull/14592#discussion_r272391955
 
 

 ##########
 File path: scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java
 ##########
 @@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.infer.bert;
+
+import org.apache.mxnet.infer.javaapi.Predictor;
+import org.apache.mxnet.javaapi.*;
+import org.kohsuke.args4j.CmdLineParser;
+import org.kohsuke.args4j.Option;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+
+/**
+ * This is an example of using BERT to run general question-answering inference jobs.
+ * Users can provide a question together with a paragraph that contains the answer, and
+ * the model will find the best answer within that paragraph.
+ */
+public class BertQA {
+    @Option(name = "--model-path-prefix", usage = "input model directory and 
prefix of the model")
+    private String modelPathPrefix = "/model/static_bert_qa";
+    @Option(name = "--model-epoch", usage = "Epoch number of the model")
+    private int epoch = 2;
+    @Option(name = "--model-vocab", usage = "the vocabulary used in the model")
+    private String modelVocab = "/model/vocab.json";
+    @Option(name = "--input-question", usage = "the input question")
+    private String inputQ = "When did BBC Japan start broadcasting?";
+    @Option(name = "--input-answer", usage = "the input answer")
+    private String inputA =
+        "BBC Japan was a general entertainment Channel.\n" +
+                " Which operated between December 2004 and April 2006.\n" +
+            "It ceased operations after its Japanese distributor folded.";
+    @Option(name = "--seq-length", usage = "the maximum length of the 
sequence")
+    private int seqLength = 384;
+
+    private final static Logger logger = LoggerFactory.getLogger(BertQA.class);
+    private static NDArray$ NDArray = NDArray$.MODULE$;
+
+    private static int argmax(float[] prob) {
+        int maxIdx = 0;
+        for (int i = 0; i < prob.length; i++) {
+            if (prob[maxIdx] < prob[i]) maxIdx = i;
+        }
+        return maxIdx;
+    }
+
+    /**
+     * Do the post processing on the output, apply softmax to get the 
probabilities
+     * reshape and get the most probable index
+     * @param result prediction result
+     * @param tokens word tokens
+     * @return Answers clipped from the original paragraph
+     */
+    static List<String> postProcessing(NDArray result, List<String> tokens) {
+        NDArray[] output = NDArray.split(
+                NDArray.new splitParam(result, 2).setAxis(2));
+        // Get the formatted logits result
+        NDArray startLogits = output[0].reshape(new int[]{0, -3});
+        NDArray endLogits = output[1].reshape(new int[]{0, -3});
+        // Get Probability distribution
+        float[] startProb = NDArray.softmax(
+                NDArray.new softmaxParam(startLogits))[0].toArray();
+        float[] endProb = NDArray.softmax(
+                NDArray.new softmaxParam(endLogits))[0].toArray();
+        int startIdx = argmax(startProb);
+        int endIdx = argmax(endProb);
+        return tokens.subList(startIdx, endIdx + 1);
+    }
+
+    public static void main(String[] args) throws Exception{
+        BertQA inst = new BertQA();
+        CmdLineParser parser = new CmdLineParser(inst);
+        parser.parseArgument(args);
+        BertDataParser util = new BertDataParser();
 
 Review comment:
   nit: rename `util` --> `dataparser`, just a more meaningful variable name :)

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to