This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch clojure-bert-sentence-pair-classification
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 1c6e734dc53e16063f4c1dca18d8cb9d0fb5d2d2
Author: gigasquid <cme...@gigasquidsoftware.com>
AuthorDate: Fri Apr 19 20:13:38 2019 -0400

    gradients not exploding
---
 .../src/bert_qa/bert_sentence_classification.clj   | 203 ++++++++++++---------
 1 file changed, 117 insertions(+), 86 deletions(-)
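
The substance of the change is the fine-tuning setup: the hand-rolled training loop is reworked and the optimizer is switched from SGD with a 0.01 learning rate to Adam with a much smaller learning rate (5e-6), which is what keeps the gradients from exploding. As a rough before/after sketch of just the optimizer configuration (assuming the optimizer alias for org.apache.clojure-mxnet.optimizer already used in this example; the opt-before/opt-after names are only for illustration):

  ;; before this commit: plain SGD with a comparatively large learning rate
  (def opt-before (optimizer/sgd {:learning-rate 0.01 :momentum 0.9}))

  ;; after this commit: Adam with a tiny learning rate, passed to m/init-optimizer in the diff below
  (def opt-after (optimizer/adam {:learning-rate 5e-6 :epsilon 1e-9}))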

diff --git a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj
index b7d4425..20257d2 100644
--- a/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj
+++ b/contrib/clojure-package/examples/bert-qa/src/bert_qa/bert_sentence_classification.clj
@@ -65,16 +65,19 @@
 
 
 (comment
+  
+ (do
 
-  ;;; load the pre-trained BERT model using the module api
-  (def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}))
-  ;;; now that we have loaded the BERT model we need to attach an additional layer for classification which is a dense layer with 2 classes
-  (def model-sym (fine-tune-model (m/symbol bert-base) 2))
-  (def arg-params (m/arg-params bert-base))
-  (def aux-params (m/aux-params bert-base))
-
-  (def devs [(context/default-context)])
-  (def input-descs [{:name "data0"
+;;; load the pre-trained BERT model using the module api
+  
+   (def bert-base (m/load-checkpoint {:prefix model-path-prefix :epoch 0}))
+;;; now that we have loaded the BERT model we need to attach an additional layer for classification which is a dense layer with 2 classes
+   (def model-sym (fine-tune-model (m/symbol bert-base) 2))
+   (def arg-params (m/arg-params bert-base))
+   (def aux-params (m/aux-params bert-base))
+
+   (def devs [(context/default-context)])
+   (def input-descs [{:name "data0"
                       :shape [1 seq-length]
                       :dtype dtype/FLOAT32
                       :layout layout/NT}
@@ -86,83 +89,111 @@
                       :shape [1]
                       :dtype dtype/FLOAT32
                       :layout layout/N}])
-  (def label-descs [{:name "softmax_label"
-                     :shape [1 2]
-                     :dtype dtype/FLOAT32
-                     :layout layout/NT}])
-
-  ;;; Data Preprocessing for BERT
-
-  ;; For demonstration purposes, we use the dev set of the Microsoft Research Paraphrase Corpus dataset. The file is named ‘dev.tsv’. Let’s take a look at the raw dataset.
-  ;; it contains 5 columns separated by tabs
-  (def raw-file (->> (string/split (slurp "dev.tsv") #"\n")
-                     (map #(string/split % #"\t") )))
-  (def raw-file (csv/parse-csv (slurp "dev.tsv") :delimiter \tab))
-  (take 3 raw-file)
- ;; (["Quality" "#1 ID" "#2 ID" "#1 String" "#2 String"]
- ;; ["1"
- ;;  "1355540"
- ;;  "1355592"
- ;;  "He said the foodservice pie business doesn 't fit the company 's 
long-term growth strategy ."
- ;;  "\" The foodservice pie business does not fit our long-term growth 
strategy ."]
- ;; ["0"
- ;;  "2029631"
- ;;  "2029565"
- ;;  "Magnarelli said Racicot hated the Iraqi regime and looked forward to 
using his long years of training in the war ."
- ;;  "His wife said he was \" 100 percent behind George Bush \" and looked 
forward to using his years of training in the war ."])
-
-  ;;; for our task we are only interested in the 0th, 3rd, and 4th columns
-  (vals (select-keys (first raw-file) [3 4 0]))
-  ;=> ("#1 String" "#2 String" "Quality")
-  (def data-train-raw (->> raw-file
-                           (mapv #(vals (select-keys % [3 4 0])))
-                           (rest) ;;drop header
-                           (into [])
-                           ))
-  (def sample (first data-train-raw))
-  (nth sample 0) ;;;sentence a
-  ;=> "He said the foodservice pie business doesn 't fit the company 's 
long-term growth strategy ."
-  (nth sample 1) ;; sentence b
-  "\" The foodservice pie business does not fit our long-term growth strategy 
."
-
-  (nth sample 2) ; 1 means equivalent, 0 means not equivalent
-   ;=> "1"
-
-  ;;; Now we need to turn these into ndarrays to make a Data Iterator
-  (def vocab (bert-infer/get-vocab))
-  (def idx->token (:idx->token vocab))
-  (def token->idx (:token->idx vocab))
+   (def label-descs [{:name "softmax_label"
+                      :shape [1 2]
+                      :dtype dtype/FLOAT32
+                      :layout layout/NT}])
+
+;;; Data Preprocessing for BERT
+
+   ;; For demonstration purposes, we use the dev set of the Microsoft Research Paraphrase Corpus dataset. The file is named ‘dev.tsv’. Let’s take a look at the raw dataset.
+   ;; it contains 5 columns separated by tabs
+   (def raw-file (->> (string/split (slurp "dev.tsv") #"\n")
+                      (map #(string/split % #"\t") )))
+   (def raw-file (csv/parse-csv (slurp "dev.tsv") :delimiter \tab))
+   (take 3 raw-file)
+   ;; (["Quality" "#1 ID" "#2 ID" "#1 String" "#2 String"]
+   ;; ["1"
+   ;;  "1355540"
+   ;;  "1355592"
+   ;;  "He said the foodservice pie business doesn 't fit the company 's 
long-term growth strategy ."
+   ;;  "\" The foodservice pie business does not fit our long-term growth 
strategy ."]
+   ;; ["0"
+   ;;  "2029631"
+   ;;  "2029565"
+   ;;  "Magnarelli said Racicot hated the Iraqi regime and looked forward to 
using his long years of training in the war ."
+   ;;  "His wife said he was \" 100 percent behind George Bush \" and looked 
forward to using his years of training in the war ."])
+
+;;; for our task we are only interested in the 0th, 3rd, and 4th columns
+   (vals (select-keys (first raw-file) [3 4 0]))
+                                        ;=> ("#1 String" "#2 String" "Quality")
+   (def data-train-raw (->> raw-file
+                            (mapv #(vals (select-keys % [3 4 0])))
+                            (rest) ;;drop header
+                            (into [])
+                            ))
+   (def sample (first data-train-raw))
+   (nth sample 0) ;;;sentence a
+                                        ;=> "He said the foodservice pie business doesn 't fit the company 's long-term growth strategy ."
+   (nth sample 1)                       ;; sentence b
+   "\" The foodservice pie business does not fit our long-term growth strategy 
."
+
+   (nth sample 2)         ; 1 means equivalent, 0 means not equivalent
+                                        ;=> "1"
+
+;;; Now we need to turn these into ndarrays to make a Data Iterator
+   (def vocab (bert-infer/get-vocab))
+   (def idx->token (:idx->token vocab))
+   (def token->idx (:token->idx vocab))
 
   
 
-  ;;; our sample item
-  (def sample-data (pre-processing (context/default-context) idx->token token->idx sample))
-
-  (def train-count (count data-train-raw))    ;=> 389
-
-    ;; now create the module
-  (def model (-> (m/module model-sym {:contexts devs
-                                      :data-names ["data0" "data1" "data2"]})
-                 (m/bind {:data-shapes input-descs :label-shapes label-descs})
-                 (m/init-params {:arg-params arg-params :aux-params aux-params
-                                 :allow-missing true})
-                 (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})})))
-
-  (def metric (eval-metric/accuracy))
-  (def num-epoch 3)
-  (def processed-datas (mapv #(pre-processing (context/default-context) idx->token token->idx %)
-                             data-train-raw))
-
-  (doseq [epoch-num (range num-epoch)]
-    (doall (map-indexed (fn [i batch-data]
-                          (-> model
-                              (m/forward {:data (:input-batch batch-data)})
-                              (m/update-metric metric [(:label batch-data)])
-                              (m/backward)
-                              (m/update))
-                          (when (mod i 10)
-                            (println "Working on " i " of " train-count " acc: " (eval-metric/get metric))))
-                        processed-datas))
-    (println "result for epoch " epoch-num " is " (eval-metric/get-and-reset 
metric)))
-
-  )
+;;; our sample item
+   (def sample-data (pre-processing (context/default-context) idx->token token->idx sample))
+
+   (def train-count (count data-train-raw)) ;=> 389
+
+   ;; now create the module
+
+   (def lr 5e-6)
+   
+   (def model (-> (m/module model-sym {:contexts devs
+                                       :data-names ["data0" "data1" "data2"]})
+                  (m/bind {:data-shapes input-descs :label-shapes label-descs})
+                  (m/init-params {:arg-params arg-params :aux-params aux-params
+                                  :allow-missing true})
+                  (m/init-optimizer {:optimizer (optimizer/adam {:learning-rate lr :epsilon 1e-9})})))
+
+   (def metric (eval-metric/accuracy))
+   (def num-epoch 1)
+   (def processed-datas (mapv #(pre-processing (context/default-context) idx->token token->idx %)
+                              data-train-raw))
+
+   (doseq [epoch-num (range num-epoch)]
+     (let [counter (atom 0)]
+       (doseq [batch-data processed-datas]
+         (when (and (pos? @counter)
+                    (zero? (mod @counter 20)))
+           (println "Working on " @counter " of " train-count " acc: " (eval-metric/get metric)))
+         (-> model
+             (m/forward {:data (:input-batch batch-data)})
+             (m/update-metric metric [(:label batch-data)])
+             (m/backward)
+             (m/update))
+         (swap! counter inc))
+       (println "result for epoch " epoch-num " is " (eval-metric/get-and-reset metric))))
+   )
+
+
+ #_(m/save-checkpoint model {:prefix "fine-tune-sentence-bert" :epoch 0 :save-opt-states true})
+
+ (def clojure-test-data (pre-processing (context/default-context) idx->token token->idx
+                                        ["Rich Hickey is the creator of the Clojure language."
+                                         "The Clojure language was Rich Hickey." "1"]))
+(-> model
+    (m/forward {:data (:input-batch sample-data)})
+    (m/outputs)
+    (ffirst)
+    (ndarray/->vec)
+    (zipmap [:equivalent :not-equivalent]))
+
+(-> model
+    (m/forward {:data (:input-batch clojure-test-data)})
+    (m/outputs)
+    (ffirst)
+    (ndarray/->vec)
+    (zipmap [:equivalent :not-equivalent]))
+
+
+
+ )
