This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch clojure-bert-sentence-pair-classification
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 7295d8c3cc3ed72ff23afc4372719382fb16c0dc
Author: gigasquid <cme...@gigasquidsoftware.com>
AuthorDate: Fri Apr 19 17:43:41 2019 -0400

    base working (although slow)
---
 .../src/bert_qa/bert_sentence_classification.clj   | 69 ++++++++++------------
 1 file changed, 30 insertions(+), 39 deletions(-)

diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj
index 053dade..7f6723e 100644
--- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj
+++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj
@@ -11,6 +11,7 @@
             [org.apache.clojure-mxnet.symbol :as sym]
             [org.apache.clojure-mxnet.module :as m]
             [org.apache.clojure-mxnet.infer :as infer]
+            [org.apache.clojure-mxnet.eval-metric :as eval-metric]
             [org.apache.clojure-mxnet.optimizer :as optimizer]
             [clojure.pprint :as pprint]
             [clojure-csv.core :as csv]
@@ -115,15 +116,7 @@
                      :dtype dtype/FLOAT32
                      :layout layout/NT}])
 
-  ;; now create the module
-  (def mod (-> (m/module model-sym {:contexts devs
-                                    :data-names ["data0" "data1" "data2"]})
-               (m/bind {:data-shapes input-descs :label-shapes label-descs})
-               (m/init-params {:arg-params arg-params :aux-params aux-params
-                               :allow-missing true})
-               (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})})))
-
-  (def base-mod (-> bert-base
+  #_(def base-mod (-> bert-base ;; #_ makes the reader skip this whole form (kept for reference)
                     (m/bind {:data-shapes input-descs})
                     (m/init-params {:arg-params arg-params :aux-params aux-params
                                     :allow-missing true})))
@@ -152,8 +145,9 @@
   (vals (select-keys (first raw-file) [3 4 0]))
   ;=> ("#1 String" "#2 String" "Quality")
   (def data-train-raw (->> raw-file
-                           (map #(vals (select-keys % [3 4 0])))
+                           (mapv #(vals (select-keys % [3 4 0]))) ;; keep columns 3, 4, 0 -> "#1 String" "#2 String" "Quality" (see header above)
                            (rest) ;;drop header
+                           (into []) ;; NOTE(review): `rest` returns a seq, so this re-realizes the rows as a vector
                            ))
   (def sample (first data-train-raw))
   (nth sample 0) ;;;sentence a
@@ -174,34 +168,31 @@
   ;;; our sample item
   (def sample-data (pre-processing (context/default-context) idx->token token->idx sample))
 
-  
-
-  ;; with a predictor
-  (defn make-predictor [ctx]
-  (let [input-descs [{:name "data0"
-                      :shape [1 seq-length]
-                      :dtype dtype/FLOAT32
-                      :layout layout/NT}
-                     {:name "data1"
-                      :shape [1 seq-length]
-                      :dtype dtype/FLOAT32
-                      :layout layout/NT}
-                     {:name "data2"
-                      :shape [1]
-                      :dtype dtype/FLOAT32
-                      :layout layout/N}]
-        factory (infer/model-factory model-path-prefix input-descs)]
-    (infer/create-predictor
-     factory
-     {:contexts [ctx]
-      :epoch 0})))
-
-  (def predictor (make-predictor (context/default-context)))
-  (def sample-result (first (infer/predict-with-ndarray predictor (:input-batch sample-data))))
-  
-
-
-  
- 
+  (def train-count (count data-train-raw))    ;=> 389
+
+  ;; now create the module
+  (def model (-> (m/module model-sym {:contexts devs
+                                      :data-names ["data0" "data1" "data2"]})
+                 (m/bind {:data-shapes input-descs :label-shapes label-descs})
+                 (m/init-params {:arg-params arg-params :aux-params aux-params
+                                 :allow-missing true})
+                 (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})})))
+
+  (def metric (eval-metric/accuracy))
+  (def num-epoch 3)
+  (def processed-datas (mapv #(pre-processing (context/default-context) idx->token token->idx %)
+                             data-train-raw))
+
+  (doseq [epoch-num (range num-epoch)]
+    (dorun (map-indexed (fn [i batch-data] ;; dorun: walk for side effects only, results are not kept
+                          (-> model
+                              (m/forward {:data (:input-batch batch-data)})
+                              (m/update-metric metric [(:label batch-data)])
+                              (m/backward)
+                              (m/update))
+                          (when (zero? (mod i 10)) ;; every 10th batch; a bare (mod i 10) is always truthy (0 is truthy)
+                            (println "Working on " i " of " train-count " acc: " (eval-metric/get metric))))
+                        processed-datas))
+    (println "result for epoch " epoch-num " is " (eval-metric/get-and-reset metric)))
 
   )

Reply via email to