This is an automated email from the ASF dual-hosted git repository. cmeier pushed a commit to branch clojure-bert-sentence-pair-classification in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 7295d8c3cc3ed72ff23afc4372719382fb16c0dc Author: gigasquid <cme...@gigasquidsoftware.com> AuthorDate: Fri Apr 19 17:43:41 2019 -0400 base working (although slow) --- .../src/bert_qa/bert_sentence_classification.clj | 69 ++++++++++------------ 1 file changed, 30 insertions(+), 39 deletions(-) diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj index 053dade..7f6723e 100644 --- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj +++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj @@ -11,6 +11,7 @@ [org.apache.clojure-mxnet.symbol :as sym] [org.apache.clojure-mxnet.module :as m] [org.apache.clojure-mxnet.infer :as infer] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] [org.apache.clojure-mxnet.optimizer :as optimizer] [clojure.pprint :as pprint] [clojure-csv.core :as csv] @@ -115,15 +116,7 @@ :dtype dtype/FLOAT32 :layout layout/NT}]) - ;; now create the module - (def mod (-> (m/module model-sym {:contexts devs - :data-names ["data0" "data1" "data2"]}) - (m/bind {:data-shapes input-descs :label-shapes label-descs}) - (m/init-params {:arg-params arg-params :aux-params aux-params - :allow-missing true}) - (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})}))) - - (def base-mod (-> bert-base + #_(def base-mod (-> bert-base (m/bind {:data-shapes input-descs}) (m/init-params {:arg-params arg-params :aux-params aux-params :allow-missing true}))) @@ -152,8 +145,9 @@ (vals (select-keys (first raw-file) [3 4 0])) ;=> ("#1 String" "#2 String" "Quality") (def data-train-raw (->> raw-file - (map #(vals (select-keys % [3 4 0]))) + (mapv #(vals (select-keys % [3 4 0]))) (rest) ;;drop header + (into []) )) (def sample (first data-train-raw)) (nth sample 0) ;;;sentence a @@ -174,34 +168,31 @@ ;;; our sample item (def sample-data (pre-processing (context/default-context) idx->token token->idx sample)) - - - ;; with a predictor - (defn make-predictor [ctx] - (let [input-descs [{:name "data0" - :shape [1 seq-length] - :dtype dtype/FLOAT32 - :layout layout/NT} - {:name "data1" - :shape [1 seq-length] - :dtype dtype/FLOAT32 - :layout layout/NT} - {:name "data2" - :shape [1] - :dtype dtype/FLOAT32 - :layout layout/N}] - factory (infer/model-factory model-path-prefix input-descs)] - (infer/create-predictor - factory - {:contexts [ctx] - :epoch 0}))) - - (def predictor (make-predictor (context/default-context))) - (def sample-result (first (infer/predict-with-ndarray predictor (:input-batch sample-data)))) - - - - - + (def train-count (count data-train-raw)) ;=> 389 + + ;; now create the module + (def model (-> (m/module model-sym {:contexts devs + :data-names ["data0" "data1" "data2"]}) + (m/bind {:data-shapes input-descs :label-shapes label-descs}) + (m/init-params {:arg-params arg-params :aux-params aux-params + :allow-missing true}) + (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})}))) + + (def metric (eval-metric/accuracy)) + (def num-epoch 3) + (def processed-datas (mapv #(pre-processing (context/default-context) idx->token token->idx %) + data-train-raw)) + + (doseq [epoch-num (range num-epoch)] + (doall (map-indexed (fn [i batch-data] + (-> model + (m/forward {:data (:input-batch batch-data)}) + (m/update-metric metric [(:label batch-data)]) + (m/backward) + (m/update)) + (when (mod i 10) + (println "Working on " i " of " train-count " acc: " (eval-metric/get metric)))) + processed-datas)) + (println "result for epoch " epoch-num " is " (eval-metric/get-and-reset metric))) )