This is an automated email from the ASF dual-hosted git repository. cmeier pushed a commit to branch clojure-bert-sentence-pair-classification in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 1c6e734dc53e16063f4c1dca18d8cb9d0fb5d2d2 Author: gigasquid <cme...@gigasquidsoftware.com> AuthorDate: Fri Apr 19 20:13:38 2019 -0400 gradients not exploding --- .../src/bert_qa/bert_sentence_classification.clj | 203 ++++++++++++--------- 1 file changed, 117 insertions(+), 86 deletions(-) diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj index b7d4425..20257d2 100644 --- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj +++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj @@ -65,16 +65,19 @@ (comment + + (do - ;;; load the pre-trained BERT model using the module api - (def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0})) - ;;; now that we have loaded the BERT model we need to attach an additional layer for classification which is a dense layer with 2 classes - (def model-sym (fine-tune-model (m/symbol bert-base) 2)) - (def arg-params (m/arg-params bert-base)) - (def aux-params (m/aux-params bert-base)) - - (def devs [(context/default-context)]) - (def input-descs [{:name "data0" +;;; load the pre-trained BERT model using the module api + + (def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0})) +;;; now that we have loaded the BERT model we need to attach an additional layer for classification which is a dense layer with 2 classes + (def model-sym (fine-tune-model (m/symbol bert-base) 2)) + (def arg-params (m/arg-params bert-base)) + (def aux-params (m/aux-params bert-base)) + + (def devs [(context/default-context)]) + (def input-descs [{:name "data0" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT} @@ -86,83 +89,111 @@ :shape [1] :dtype dtype/FLOAT32 :layout layout/N}]) - (def label-descs [{:name "softmax_label" - :shape [1 2] - :dtype dtype/FLOAT32 - :layout layout/NT}]) - - ;;; Data Preprocessing for BERT - - ;; For demonstration purpose, we use the dev set of the Microsoft Research Paraphrase Corpus dataset. The file is named ‘dev.tsv’. Let’s take a look at the raw dataset. - ;; it contains 5 columns seperated by tabs - (def raw-file (->> (string/split (slurp "dev.tsv") #"\n") - (map #(string/split % #"\t") ))) - (def raw-file (csv/parse-csv (slurp "dev.tsv") :delimiter \tab)) - (take 3 raw-file) - ;; (["Quality" "#1 ID" "#2 ID" "#1 String" "#2 String"] - ;; ["1" - ;; "1355540" - ;; "1355592" - ;; "He said the foodservice pie business doesn 't fit the company 's long-term growth strategy ." - ;; "\" The foodservice pie business does not fit our long-term growth strategy ."] - ;; ["0" - ;; "2029631" - ;; "2029565" - ;; "Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war ." - ;; "His wife said he was \" 100 percent behind George Bush \" and looked forward to using his years of training in the war ."]) - - ;;; for our task we are only interested in the 0 3rd and 4th column - (vals (select-keys (first raw-file) [3 4 0])) - ;=> ("#1 String" "#2 String" "Quality") - (def data-train-raw (->> raw-file - (mapv #(vals (select-keys % [3 4 0]))) - (rest) ;;drop header - (into []) - )) - (def sample (first data-train-raw)) - (nth sample 0) ;;;sentence a - ;=> "He said the foodservice pie business doesn 't fit the company 's long-term growth strategy ." - (nth sample 1) ;; sentence b - "\" The foodservice pie business does not fit our long-term growth strategy ." - - (nth sample 2) ; 1 means equivalent, 0 means not equivalent - ;=> "1" - - ;;; Now we need to turn these into ndarrays to make a Data Iterator - (def vocab (bert-infer/get-vocab)) - (def idx->token (:idx->token vocab)) - (def token->idx (:token->idx vocab)) + (def label-descs [{:name "softmax_label" + :shape [1 2] + :dtype dtype/FLOAT32 + :layout layout/NT}]) + +;;; Data Preprocessing for BERT + + ;; For demonstration purpose, we use the dev set of the Microsoft Research Paraphrase Corpus dataset. The file is named ‘dev.tsv’. Let’s take a look at the raw dataset. + ;; it contains 5 columns seperated by tabs + (def raw-file (->> (string/split (slurp "dev.tsv") #"\n") + (map #(string/split % #"\t") ))) + (def raw-file (csv/parse-csv (slurp "dev.tsv") :delimiter \tab)) + (take 3 raw-file) + ;; (["Quality" "#1 ID" "#2 ID" "#1 String" "#2 String"] + ;; ["1" + ;; "1355540" + ;; "1355592" + ;; "He said the foodservice pie business doesn 't fit the company 's long-term growth strategy ." + ;; "\" The foodservice pie business does not fit our long-term growth strategy ."] + ;; ["0" + ;; "2029631" + ;; "2029565" + ;; "Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war ." + ;; "His wife said he was \" 100 percent behind George Bush \" and looked forward to using his years of training in the war ."]) + +;;; for our task we are only interested in the 0 3rd and 4th column + (vals (select-keys (first raw-file) [3 4 0])) + ;=> ("#1 String" "#2 String" "Quality") + (def data-train-raw (->> raw-file + (mapv #(vals (select-keys % [3 4 0]))) + (rest) ;;drop header + (into []) + )) + (def sample (first data-train-raw)) + (nth sample 0) ;;;sentence a + ;=> "He said the foodservice pie business doesn 't fit the company 's long-term growth strategy ." + (nth sample 1) ;; sentence b + "\" The foodservice pie business does not fit our long-term growth strategy ." + + (nth sample 2) ; 1 means equivalent, 0 means not equivalent + ;=> "1" + +;;; Now we need to turn these into ndarrays to make a Data Iterator + (def vocab (bert-infer/get-vocab)) + (def idx->token (:idx->token vocab)) + (def token->idx (:token->idx vocab)) - ;;; our sample item - (def sample-data (pre-processing (context/default-context) idx->token token->idx sample)) - - (def train-count (count data-train-raw)) ;=> 389 - - ;; now create the module - (def model (-> (m/module model-sym {:contexts devs - :data-names ["data0" "data1" "data2"]}) - (m/bind {:data-shapes input-descs :label-shapes label-descs}) - (m/init-params {:arg-params arg-params :aux-params aux-params - :allow-missing true}) - (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})}))) - - (def metric (eval-metric/accuracy)) - (def num-epoch 3) - (def processed-datas (mapv #(pre-processing (context/default-context) idx->token token->idx %) - data-train-raw)) - - (doseq [epoch-num (range num-epoch)] - (doall (map-indexed (fn [i batch-data] - (-> model - (m/forward {:data (:input-batch batch-data)}) - (m/update-metric metric [(:label batch-data)]) - (m/backward) - (m/update)) - (when (mod i 10) - (println "Working on " i " of " train-count " acc: " (eval-metric/get metric)))) - processed-datas)) - (println "result for epoch " epoch-num " is " (eval-metric/get-and-reset metric))) - - ) +;;; our sample item + (def sample-data (pre-processing (context/default-context) idx->token token->idx sample)) + + (def train-count (count data-train-raw)) ;=> 389 + + ;; now create the module + + (def lr 5e-6) + + (def model (-> (m/module model-sym {:contexts devs + :data-names ["data0" "data1" "data2"]}) + (m/bind {:data-shapes input-descs :label-shapes label-descs}) + (m/init-params {:arg-params arg-params :aux-params aux-params + :allow-missing true}) + (m/init-optimizer {:optimizer (optimizer/adam {:learning-rate lr :episilon 1e-9})}))) + + (def metric (eval-metric/accuracy)) + (def num-epoch 1) + (def processed-datas (mapv #(pre-processing (context/default-context) idx->token token->idx %) + data-train-raw)) + + (doseq [epoch-num (range num-epoch)] + (let [counter (atom 0)] + (doseq [batch-data processed-datas] + (when (and (pos? @counter) + (zero? (mod @counter 20))) + (println "Working on " @counter " of " train-count " acc: " (eval-metric/get metric))) + (-> model + (m/forward {:data (:input-batch batch-data)}) + (m/update-metric metric [(:label batch-data)]) + (m/backward) + (m/update)) + (swap! counter inc)) + (println "result for epoch " epoch-num " is " (eval-metric/get-and-reset metric)))) + ) + + + #_(m/save-checkpoint model {:prefix "fine-tune-sentence-bert" :epoch 0 :save-opt-states true}) + + (def clojure-test-data (pre-processing (context/default-context) idx->token token->idx + ["Rich Hickey is the creator of the Clojure language." + "The Clojure language was Rich Hickey." "1"])) +(-> model + (m/forward {:data (:input-batch sample-data)}) + (m/outputs) + (ffirst) + (ndarray/->vec) + (zipmap [:equivalent :not-equivalent])) + +(-> model + (m/forward {:data (:input-batch clojure-test-data)}) + (m/outputs) + (ffirst) + (ndarray/->vec) + (zipmap [:equivalent :not-equivalent])) + + + + )