szha closed pull request #12535: [MXNET-954] Implementation of Structured-Self-Attentive-Sentence-Embedding
URL: https://github.com/apache/incubator-mxnet/pull/12535
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/example/self_attentive_sentence_embedding/README.md 
b/example/self_attentive_sentence_embedding/README.md
new file mode 100644
index 00000000000..3f21944fb83
--- /dev/null
+++ b/example/self_attentive_sentence_embedding/README.md
@@ -0,0 +1,107 @@
+## Implementation of Structured-Self-Attentive-Sentence-Embedding
+
+This is an implementation of the paper [A Structured Self-Attentive Sentence Embedding](https://arxiv.org/abs/1703.03130). It covers most of the details described in the paper and reproduces the review star-rating (sentiment) classification task, one of the three experiments in the original paper, using the same dataset: [The reviews of Yelp Data](https://www.kaggle.com/yelp-dataset/yelp-dataset#yelp_academic_dataset_review.json). The model structure is as follows:
+
+![Bi_LSTM_Attention](./images/Bi_LSTM_Attention.png)
+
+
+
+## Requirements
+
+1. [MXNet](https://mxnet.apache.org/)
+2. [Gluon NLP](https://gluon-nlp.mxnet.io)
+3. [Numpy](http://www.numpy.org/)
+4. [Scikit-Learn](http://scikit-learn.org/stable/)
+5. [Python3](https://www.python.org) 
+
+## Implemented
+
+1. **Attention mechanism proposed in the original paper.**
+
+
+   $$
+   A = \text{softmax}\left(W_{s2}\tanh\left(W_{s1}H^{T}\right)\right)
+   $$
+
+2. **Penalization constraint to ensure diversity of attention.** A Gluon sketch of items 1 and 2 follows this list.
+
+
+   $$
+   P = \left\| AA^{T} - I \right\|_F^2
+   $$
+
+3. **Parameter pruning proposed in the appendix of the paper.**
+
+
+
+   ![prune weights](./images/prune_weights.png)
+
+4. **Gradient clipping and learning rate decay.**
+
+5. **SoftmaxCrossEntropy with class weights.**
+
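+The sketch below shows roughly how items 1 and 2 above translate into Gluon operations. It loosely mirrors `SelfAttention` in `code/models.py`; the shapes and sample values here are illustrative only:
+
+```python
+import mxnet as mx
+from mxnet import nd
+from mxnet.gluon import nn
+
+att_unit, att_hops = 350, 2
+
+# W_s1 and W_s2 from the formula above, realized as Dense layers that keep
+# the time axis (flatten=False), as in code/models.py.
+ws1 = nn.Dense(att_unit, activation='tanh', flatten=False)
+ws2 = nn.Dense(att_hops, flatten=False)
+ws1.initialize()
+ws2.initialize()
+
+H = nd.random.normal(shape=(4, 30, 600))   # BiLSTM outputs: [batch, seq_len, 2 * nhid]
+A = nd.softmax(nd.transpose(ws2(ws1(H)), axes=(0, 2, 1)), axis=-1)   # [batch, att_hops, seq_len]
+M = nd.batch_dot(A, H)                     # sentence embedding: [batch, att_hops, 2 * nhid]
+
+# Penalization term P = ||A A^T - I||_F^2 from item 2, encouraging diverse attention hops
+I = nd.eye(att_hops, ctx=A.context)
+P = (nd.batch_dot(A, nd.transpose(A, axes=(0, 2, 1))) - I).norm(axis=(1, 2)) ** 2
+```
+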
+## For sentiment classification
+
+1. **Training parameter description**
+
+   ```python
+   parser.add_argument('--emsize', type=int, default=300,
+                       help='size of word embeddings')
+   parser.add_argument('--nhid', type=int, default=300,
+                       help='number of hidden units per layer')
+   parser.add_argument('--nlayers', type=int, default=1,
+                       help='number of layers in BiLSTM')
+   parser.add_argument('--attention-unit', type=int, default=350,
+                       help='number of attention units')
+   parser.add_argument('--attention-hops', type=int, default=1,
+                       help='number of attention hops, for multi-hop attention model')
+   parser.add_argument('--drop-prob', type=float, default=0.5,
+                       help='dropout applied to layers (0 = no dropout)')
+   parser.add_argument('--clip', type=float, default=0.5,
+                       help='gradient clipping threshold to prevent exploding gradients in the LSTM')
+   parser.add_argument('--nfc', type=int, default=512,
+                       help='hidden (fully connected) layer size for classifier MLP')
+   parser.add_argument('--lr', type=float, default=.001,
+                       help='initial learning rate')
+   parser.add_argument('--epochs', type=int, default=10,
+                       help='upper epoch limit')
+   parser.add_argument('--loss-name', type=str, default='sce', help='loss function name')
+   parser.add_argument('--seed', type=int, default=2018,
+                       help='random seed')
+
+   parser.add_argument('--pool-way', type=str, default='flatten', help='how to pool the attention output: flatten, mean, or prune')
+   parser.add_argument('--prune-p', type=int, default=None, help='prune p size')
+   parser.add_argument('--prune-q', type=int, default=None, help='prune q size')
+
+   parser.add_argument('--batch-size', type=int, default=64,
+                       help='batch size for training')
+   parser.add_argument('--class-number', type=int, default=5,
+                       help='number of classes')
+   parser.add_argument('--optimizer', type=str, default='Adam',
+                       help='type of optimizer')
+   parser.add_argument('--penalization-coeff', type=float, default=0.1,
+                       help='the penalization coefficient')
+
+   parser.add_argument('--save', type=str, default='../models', help='path to save the final model')
+   parser.add_argument('--wv-name', type=str, choices={'glove', 'w2v', 'fasttext', 'random'},
+                       default='random', help='which pretrained word embedding to use')
+   parser.add_argument('--data-json-path', type=str, default='../data/sub_review_labels.json', help='raw data path')
+   parser.add_argument('--formated-data-path', type=str,
+                       default='../data/formated_data.pkl', help='formatted data path')
+   ```
+
+2. **Training details**
+
+   The original paper uses 500K reviews as the training set, 2,000 as the validation set, and 2,000 as the test set. Due to personal machine limitations, 200K reviews are randomly sampled as the training set and 2,000 are used as the validation set, while keeping the label distribution of the original data. The class weights of WeightedSoftmaxCrossEntropy are set according to the proportion of each category in this data; if your data is different and you want to use this loss function, you need to modify the class_weight values yourself (see the sketch after the usage command below).
+
+   Training usage (parameters can be customized):  
+
+   ```bash
+   python train_model.py --nlayers 1 --epochs 5 --attention-hops 2 --loss-name sce
+   ```
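+
+   As a rough illustration of the class-weight setup mentioned above (the values are the ones hard-coded in `code/train_model.py` for the sampled 5-class Yelp data; adjust them to your own label distribution):
+
+   ```python
+   from mxnet import nd
+   from weighted_softmaxCE import WeightedSoftmaxCE
+
+   # weights chosen according to the category proportions of the sampled data
+   class_weight = nd.array([3.0, 5.3, 4.0, 2.0, 1.0])
+   loss = WeightedSoftmaxCE()
+   # inside the training loop (see train_helper.py):
+   # l = loss(batch_pred, batch_y, class_weight, class_weight.shape[0])
+   ```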
+
+## Reference
+
+1. **[A Structured Self-Attentive Sentence Embedding](https://arxiv.org/abs/1703.03130)**
+
+2. **[The reviews of Yelp Data](https://www.kaggle.com/yelp-dataset/yelp-dataset#yelp_academic_dataset_review.json)**
diff --git a/example/self_attentive_sentence_embedding/README_CH.md 
b/example/self_attentive_sentence_embedding/README_CH.md
new file mode 100644
index 00000000000..14940981fe8
--- /dev/null
+++ b/example/self_attentive_sentence_embedding/README_CH.md
@@ -0,0 +1,106 @@
+## Implementation of Structured-Self-Attentive-Sentence-Embedding
+
+This is an implementation of the paper [A Structured Self-Attentive Sentence Embedding](https://arxiv.org/abs/1703.03130). It covers most of the details in the paper and reproduces the review star-rating (sentiment) classification experiment, one of the three experiments in the original paper, using the same dataset: [The reviews of Yelp Data](https://www.kaggle.com/yelp-dataset/yelp-dataset#yelp_academic_dataset_review.json). The model structure is shown below:
+
+![Bi_LSTM_Attention](./images/Bi_LSTM_Attention.png)
+
+
+
+## Requirements
+
+1. [MXNet](https://mxnet.apache.org/)
+2. [Gluon NLP](https://gluon-nlp.mxnet.io)
+3. [Numpy](http://www.numpy.org/)
+4. [Scikit-Learn](http://scikit-learn.org/stable/)
+5. [Python3](https://www.python.org) 
+
+## Implemented
+
+1. **Attention mechanism proposed in the paper.**
+
+
+   $$
+   A = \text{softmax}\left(W_{s2}\tanh\left(W_{s1}H^{T}\right)\right)
+   $$
+
+2. **Penalization constraint to ensure diversity of attention.**
+
+
+   $$
+   P = \left\| AA^{T} - I \right\|_F^2
+   $$
+
+3. **Parameter pruning proposed in the appendix of the paper.**
+
+
+
+   ![prune weights](./images/prune_weights.png)
+
+4. **Gradient clipping and learning rate decay.**
+
+5. **SoftmaxCrossEntropy with class weights.**
+
+## Using the paper's model for review star-rating classification
+
+1. **Training parameter description**
+
+   ```python
+   parser.add_argument('--emsize', type=int, default=300,
+                       help='size of word embeddings')
+   parser.add_argument('--nhid', type=int, default=300,
+                       help='number of hidden units per layer')
+   parser.add_argument('--nlayers', type=int, default=1,
+                       help='number of layers in BiLSTM')
+   parser.add_argument('--attention-unit', type=int, default=350,
+                       help='number of attention units')
+   parser.add_argument('--attention-hops', type=int, default=1,
+                       help='number of attention hops, for multi-hop attention model')
+   parser.add_argument('--drop-prob', type=float, default=0.5,
+                       help='dropout applied to layers (0 = no dropout)')
+   parser.add_argument('--clip', type=float, default=0.5,
+                       help='gradient clipping threshold to prevent exploding gradients in the LSTM')
+   parser.add_argument('--nfc', type=int, default=512,
+                       help='hidden (fully connected) layer size for classifier MLP')
+   parser.add_argument('--lr', type=float, default=.001,
+                       help='initial learning rate')
+   parser.add_argument('--epochs', type=int, default=10,
+                       help='upper epoch limit')
+   parser.add_argument('--loss-name', type=str, default='sce', help='loss function name')
+   parser.add_argument('--seed', type=int, default=2018,
+                       help='random seed')
+
+   parser.add_argument('--pool-way', type=str, default='flatten', help='how to pool the attention output: flatten, mean, or prune')
+   parser.add_argument('--prune-p', type=int, default=None, help='prune p size')
+   parser.add_argument('--prune-q', type=int, default=None, help='prune q size')
+
+   parser.add_argument('--batch-size', type=int, default=64,
+                       help='batch size for training')
+   parser.add_argument('--class-number', type=int, default=5,
+                       help='number of classes')
+   parser.add_argument('--optimizer', type=str, default='Adam',
+                       help='type of optimizer')
+   parser.add_argument('--penalization-coeff', type=float, default=0.1,
+                       help='the penalization coefficient')
+
+   parser.add_argument('--save', type=str, default='../models', help='path to save the final model')
+   parser.add_argument('--wv-name', type=str, choices={'glove', 'w2v', 'fasttext', 'random'},
+                       default='random', help='which pretrained word embedding to use')
+   parser.add_argument('--data-json-path', type=str, default='../data/sub_review_labels.json', help='raw data path')
+   parser.add_argument('--formated-data-path', type=str,
+                       default='../data/formated_data.pkl', help='formatted data path')
+   ```
+
+2. **Training details**
+
+   The original paper uses 500K reviews as the training set, 2,000 as the validation set, and 2,000 as the test set. Due to personal machine limitations, 200K reviews are randomly sampled as the training set and 2,000 are used as the validation set, while keeping the label distribution of the original data. The class weights of WeightedSoftmaxCrossEntropy are set according to the proportion of each category; if your data is different and you want to use this loss function, you need to modify the class_weight values yourself.
+
+   Training usage (parameters can be customized):
+
+   ```bash
+   python train_model.py --nlayers 1 --epochs 5 --attention-hops 2 --loss-name sce
+   ```
+
+## Reference
+
+1. **[A Structured Self-Attentive Sentence Embedding](https://arxiv.org/abs/1703.03130)**
+
+2. **[The reviews of Yelp Data](https://www.kaggle.com/yelp-dataset/yelp-dataset#yelp_academic_dataset_review.json)**
diff --git a/example/self_attentive_sentence_embedding/code/models.py 
b/example/self_attentive_sentence_embedding/code/models.py
new file mode 100644
index 00000000000..a932cf7e099
--- /dev/null
+++ b/example/self_attentive_sentence_embedding/code/models.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This module includes a self-attention layer and a bidirectional LSTM
+# combined with the attention mechanism for the sentiment classification model.
+# author: kenjewu
+
+import mxnet as mx
+from mxnet import nd, gluon
+from mxnet.gluon import nn, rnn
+
+
+class SelfAttention(nn.HybridBlock):
+    def __init__(self, att_unit, att_hops, **kwargs):
+        super(SelfAttention, self).__init__(**kwargs)
+        with self.name_scope():
+            self.ut_dense = nn.Dense(att_unit, activation='tanh', flatten=False)
+            self.et_dense = nn.Dense(att_hops, activation=None, flatten=False)
+
+    def hybrid_forward(self, F, x):
+        # x shape: [batch_size, seq_len, embedding_width]
+        # ut shape: [batch_size, seq_len, att_unit]
+        ut = self.ut_dense(x)
+        # et shape: [batch_size, seq_len, att_hops]
+        et = self.et_dense(ut)
+
+        # at shape: [batch_size,  att_hops, seq_len]
+        at = F.softmax(F.transpose(et, axes=(0, 2, 1)), axis=-1)
+        # output shape [batch_size, att_hops, embedding_width]
+        output = F.batch_dot(at, x)
+
+        return output
+
+
+class SelfAttentiveBiLSTM(nn.HybridBlock):
+    def __init__(self, vocab_len, emsize, nhide, nlayers, att_unit, att_hops, nfc, nclass,
+                 drop_prob, pool_way, prune_p=None, prune_q=None, **kwargs):
+        super(SelfAttentiveBiLSTM, self).__init__(**kwargs)
+        with self.name_scope():
+            self.embedding_layer = nn.Embedding(vocab_len, emsize)
+            self.bilstm = rnn.LSTM(nhide, num_layers=nlayers, dropout=drop_prob, bidirectional=True)
+            self.att_encoder = SelfAttention(att_unit, att_hops)
+            self.dense = nn.Dense(nfc, activation='tanh')
+            self.output_layer = nn.Dense(nclass)
+
+            self.dense_p, self.dense_q = None, None
+            if all([prune_p, prune_q]):
+                self.dense_p = nn.Dense(prune_p, activation='tanh', flatten=False)
+                self.dense_q = nn.Dense(prune_q, activation='tanh', flatten=False)
+
+            self.drop_prob = drop_prob
+            self.pool_way = pool_way
+
+    def hybrid_forward(self, F, inp):
+        # input_embed: [batch, len, emsize]
+        inp_embed = self.embedding_layer(inp)
+        h_output = self.bilstm(F.transpose(inp_embed, axes=(1, 0, 2)))
+        # att_output: [batch, att_hops, emsize]
+        att_output = self.att_encoder(F.transpose(h_output, axes=(1, 0, 2)))
+
+        dense_input = None
+        if self.pool_way == 'flatten':
+            dense_input = F.Dropout(F.flatten(att_output), self.drop_prob)
+        elif self.pool_way == 'mean':
+            dense_input = F.Dropout(F.mean(att_output, axis=1), self.drop_prob)
+        elif self.pool_way == 'prune' and all([self.dense_p, self.dense_q]):
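+            # Prune as described in the paper's appendix: compress the attention output
+            # along the hop axis (dense_p) and along the embedding axis (dense_q),
+            # then concatenate the two flattened parts.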
+            # p_section: [batch, att_hops, prune_p]
+            p_section = self.dense_p(att_output)
+            # q_section: [batch, emsize, prune_q]
+            q_section = self.dense_q(F.transpose(att_output, axes=(0, 2, 1)))
+            dense_input = F.Dropout(F.concat(F.flatten(p_section), F.flatten(q_section), dim=-1), self.drop_prob)
+
+        dense_out = self.dense(dense_input)
+        output = self.output_layer(F.Dropout(dense_out, self.drop_prob))
+
+        return output, att_output
diff --git a/example/self_attentive_sentence_embedding/code/prepare_data.py 
b/example/self_attentive_sentence_embedding/code/prepare_data.py
new file mode 100644
index 00000000000..9aa9bb5185f
--- /dev/null
+++ b/example/self_attentive_sentence_embedding/code/prepare_data.py
@@ -0,0 +1,169 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This module is used to parse the raw data and produce the training data needed for the model.
+# author: kenjewu
+
+import mxnet as mx
+import numpy as np
+import gluonnlp as nlp
+
+import os
+import re
+import json
+import pickle
+import collections
+import warnings
+warnings.filterwarnings('ignore')
+
+from sklearn.model_selection import train_test_split
+
+
+UNK = '<unk>'
+PAD = '<pad>'
+
+
+def clean_str(string):
+    """
+    Tokenization/string cleaning.
+    Original from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
+    """
+    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
+    string = re.sub(r"\'s", " \'s", string)
+    string = re.sub(r"\'ve", " \'ve", string)
+    string = re.sub(r"n\'t", " n\'t", string)
+    string = re.sub(r"\'re", " \'re", string)
+    string = re.sub(r"\'d", " \'d", string)
+    string = re.sub(r"\'ll", " \'ll", string)
+    string = re.sub(r",", " , ", string)
+    string = re.sub(r"!", " ! ", string)
+    string = re.sub(r"\(", " \( ", string)
+    string = re.sub(r"\)", " \) ", string)
+    string = re.sub(r"\?", " \? ", string)
+    string = re.sub(r"\s{2,}", " ", string)
+
+    return string.strip().lower()
+
+
+def pad_sequences(sequences, max_len, pad_value):
+    '''
+    Pad each sequence to the specified length; longer sequences are truncated.
+    Args:
+        sequences: A list of all sentences, a list of list
+        max_len: Specified maximum length
+        pad_value: Specified fill value
+    Returns:
+        paded_seqs: a numpy array of shape (len(sequences), max_len)
+    '''
+
+    # max_len = max(map(lambda x: len(x), sequences))
+
+    paded_seqs = np.zeros((len(sequences), max_len))
+    for idx, seq in enumerate(sequences):
+        paded = None
+        if len(seq) < max_len:
+            paded = np.array((seq + [pad_value] * (max_len - len(seq))))
+        else:
+            paded = np.array(seq[0:max_len])
+        paded_seqs[idx] = paded
+
+    return paded_seqs
+
+
+def get_vocab(sentences, wv_name):
+    '''
+    Get the vocab, which is an instance of nlp.Vocab
+    Args:
+        sentences: all sentences, a list of str.
+        wv_name: one of {'glove', 'w2v', 'fasttext', 'random'}; the type of pretrained word embedding to use.
+    Returns:
+        my_vocab: an instance of nlp.Vocab
+    '''
+    tokens = []
+    for sent in sentences:
+        tokens.extend(clean_str(sent).split())
+
+    token_counter = nlp.data.count_tokens(tokens)
+    my_vocab = nlp.Vocab(token_counter)
+
+    if wv_name == 'glove':
+        my_embedding = nlp.embedding.GloVe(source='glove.6B.50d', embedding_root='../data/embedding')
+    elif wv_name == 'w2v':
+        my_embedding = nlp.embedding.Word2Vec(
+            source='GoogleNews-vectors-negative300', embedding_root='../data/embedding')
+    elif wv_name == 'fasttext':
+        my_embedding = nlp.embedding.FastText(source='wiki.simple', embedding_root='../data/embedding')
+    else:
+        my_embedding = None
+
+    if my_embedding is not None:
+        my_vocab.set_embedding(my_embedding)
+
+    return my_vocab
+
+
+def sentences2idx(sentences, my_vocab):
+    '''
+    Convert all words of the sentences to their corresponding indices in the vocabulary.
+    Args:
+        sentences: all sentences, a list of str.
+        my_vocab: an instance of nlp.Vocab
+    Returns:
+        sentences_indices: the indices of all words, a list of lists.
+    '''
+    sentences_indices = []
+    for sent in sentences:
+        sentences_indices.append(my_vocab.to_indices(clean_str(sent).split()))
+    return sentences_indices
+
+
+def get_data(data_json_path, wv_name, formated_data_path):
+    '''
+    Process the raw data and obtain standard data that can be used for model training.
+    Args:
+        data_json_path: the path of the raw data, a json file.
+        wv_name: one of {'glove', 'w2v', 'fasttext', 'random'}; the type of pretrained word embedding to use.
+        formated_data_path: the path to save the processed standard data.
+    Returns:
+        formated_data: a dict with keys 'x', 'y' and 'vocab'.
+    '''
+
+    if os.path.exists(formated_data_path):
+        with open(formated_data_path, 'rb') as f:
+            formated_data = pickle.load(f)
+    else:
+        with open(data_json_path, 'r', encoding='utf-8') as fr:
+            data = json.load(fr)
+        sentences, labels = data['texts'], data['labels']
+
+        my_vocab = get_vocab(sentences, wv_name)
+        pad_num_value = my_vocab.to_indices(PAD)
+
+        # Convert the input sentences to integer indices
+        input_idx = sentences2idx(sentences, my_vocab)
+
+        # Prepare the training and validation data
+        max_seq_len = 100
+        input_paded = pad_sequences(input_idx, max_seq_len, pad_num_value)
+        labels = np.array(labels).reshape((-1, 1)) - 1
+
+        formated_data = {'x': input_paded, 'y': labels, 'vocab': my_vocab}
+        with open(formated_data_path, 'wb') as fw:
+            pickle.dump(formated_data, fw)
+
+    return formated_data
diff --git a/example/self_attentive_sentence_embedding/code/train_helper.py 
b/example/self_attentive_sentence_embedding/code/train_helper.py
new file mode 100644
index 00000000000..036032f964c
--- /dev/null
+++ b/example/self_attentive_sentence_embedding/code/train_helper.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Helper functions used during training
+# author:kenjewu
+
+import numpy as np
+from time import time
+
+import mxnet as mx
+from mxnet import autograd, gluon, nd
+from sklearn.metrics import accuracy_score, f1_score
+
+
+def train(
+        data_iter_train, data_iter_valid, model, loss, trainer, CTX, num_epochs, penal_coeff=0.0, clip=None,
+        class_weight=None, loss_name='wsce'):
+    '''
+    Function used in training
+    Args:
+        data_iter_train: the iter of training data
+        data_iter_valid: the iter of validation data
+        model: model to train
+        loss: loss function
+        trainer: the gluon.Trainer used to update the parameters
+        CTX: context
+        num_epochs: number of total epochs
+        penal_coeff: Penalty factor, default is 0.0
+        clip: gradient clipping threshold, default is None
+        class_weight: the weight of every class, default is None
+        loss_name: the name of loss function, default is 'wsce'
+    '''
+    print('Train on ', CTX)
+
+    for epoch in range(1, num_epochs + 1):
+        start = time()
+        train_loss = 0.
+        total_pred = []
+        total_true = []
+        n_batch = 0
+
+        for batch_x, batch_y in data_iter_train:
+            with autograd.record():
+                batch_pred, att_output = model(batch_x)
+                if loss_name == 'sce':
+                    l = loss(batch_pred, batch_y)
+                elif loss_name == 'wsce':
+                    l = loss(batch_pred, batch_y, class_weight, class_weight.shape[0])
+
+                # Penalty term encouraging diverse attention hops
+                temp = nd.batch_dot(att_output, nd.transpose(att_output, axes=(0, 2, 1))
+                                    ) - nd.eye(att_output.shape[1], ctx=CTX)
+                l = l + penal_coeff * temp.norm(axis=(1, 2))
+            l.backward()
+
+            # Gradient clipping
+            clip_params = [p.data() for p in model.collect_params().values()]
+            if clip is not None:
+                norm = nd.array([0.0], CTX)
+                for param in clip_params:
+                    norm += (param.grad ** 2).sum()
+                norm = norm.sqrt().asscalar()
+                if norm > clip:
+                    for param in clip_params:
+                        param.grad[:] *= clip / norm
+
+            # Update parameters
+            trainer.step(batch_x.shape[0])
+
+            batch_pred = np.argmax(nd.softmax(batch_pred, axis=1).asnumpy(), axis=1)
+            batch_true = np.reshape(batch_y.asnumpy(), (-1, ))
+            total_pred.extend(batch_pred.tolist())
+            total_true.extend(batch_true.tolist())
+            batch_train_loss = l.mean().asscalar()
+
+            n_batch += 1
+            train_loss += batch_train_loss
+
+            if n_batch % 400 == 0:
+                print('epoch %d, batch %d, batch_train_loss %.4f, batch_train_acc %.3f' %
+                      (epoch, n_batch, batch_train_loss, accuracy_score(batch_true, batch_pred)))
+
+        F1_train = f1_score(np.array(total_true), np.array(total_pred), average='weighted')
+        acc_train = accuracy_score(np.array(total_true), np.array(total_pred))
+        train_loss /= n_batch
+
+        F1_valid, acc_valid, valid_loss = evaluate(data_iter_valid, model, loss, penal_coeff, class_weight, loss_name)
+
+        print('epoch %d, learning_rate %.5f \n\t train_loss %.4f, acc_train %.3f, F1_train %.3f, ' %
+              (epoch, trainer.learning_rate, train_loss, acc_train, F1_train))
+        print('\t valid_loss %.4f, acc_valid %.3f, F1_valid %.3f, '
+              '\ntime %.1f sec' % (valid_loss, acc_valid, F1_valid, time() - start))
+        print('='*50)
+
+        # Learning rate decay
+        if epoch % 2 == 0:
+            trainer.set_learning_rate(trainer.learning_rate * 0.9)
+
+
+def evaluate(data_iter_valid, model, loss, penal_coeff=0.0, class_weight=None, loss_name='wsce'):
+    '''
+    the evaluation function
+    Args:
+        data_iter_valid: the iter of validation data
+        model: the model to evaluate
+        loss: loss function
+        penal_coeff: Penalty factor, default is 0.0
+        class_weight: the weight of every class, default is None
+        loss_name: the name of loss function, default is 'wsce'
+    Returns:
+        F1_valid: the f1 score
+        acc_valid: the accuracy score
+        valid_loss: the value of loss
+    '''
+    valid_loss = 0.
+    total_pred = []
+    total_true = []
+    n_batch = 0
+    for batch_x, batch_y in data_iter_valid:
+        batch_pred, att_output = model(batch_x)
+        if loss_name == 'sce':
+            l = loss(batch_pred, batch_y)
+        elif loss_name == 'wsce':
+            l = loss(batch_pred, batch_y, class_weight, class_weight.shape[0])
+        # Penalty term
+        temp = nd.batch_dot(att_output, nd.transpose(att_output, axes=(0, 2, 1))
+                            ) - nd.eye(att_output.shape[1], ctx=att_output.context)
+        l = l + penal_coeff * temp.norm(axis=(1, 2))
+        total_pred.extend(np.argmax(nd.softmax(batch_pred, axis=1).asnumpy(), axis=1).tolist())
+        total_true.extend(np.reshape(batch_y.asnumpy(), (-1,)).tolist())
+        n_batch += 1
+        valid_loss += l.mean().asscalar()
+
+    F1_valid = f1_score(np.array(total_true), np.array(total_pred), average='weighted')
+    acc_valid = accuracy_score(np.array(total_true), np.array(total_pred))
+    valid_loss /= n_batch
+
+    return F1_valid, acc_valid, valid_loss
diff --git a/example/self_attentive_sentence_embedding/code/train_model.py 
b/example/self_attentive_sentence_embedding/code/train_model.py
new file mode 100644
index 00000000000..8db11735b15
--- /dev/null
+++ b/example/self_attentive_sentence_embedding/code/train_model.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This is the training script
+# author: kenjewu
+
+import os
+import warnings
+warnings.filterwarnings('ignore')
+import mxnet as mx
+import gluonnlp as nlp
+from mxnet import nd, gluon, init
+from mxnet.gluon.data import ArrayDataset, DataLoader
+
+import train_helper as th
+from utils import get_args
+from prepare_data import get_data
+from models import SelfAttentiveBiLSTM
+from weighted_softmaxCE import WeightedSoftmaxCE
+
+
+def try_gpu():
+    """If GPU is available, return mx.gpu(0); else return mx.cpu()."""
+    try:
+        ctx = mx.gpu()
+        _ = nd.array([0], ctx=ctx)
+    except:
+        ctx = mx.cpu()
+    return ctx
+
+
+if __name__ == '__main__':
+    # 解析参数 (Parsing command line arguments)
+    args = get_args()
+    emsize = args.emsize
+    nhide = args.nhid
+    nlayers = args.nlayers
+    att_unit = args.attention_unit
+    att_hops = args.attention_hops
+    nfc = args.nfc
+    nclass = args.class_number
+    drop_prob = args.drop_prob
+    pool_way = args.pool_way
+    prune_p = args.prune_p
+    prune_q = args.prune_q
+
+    penal_coeff = args.penalization_coeff
+    optim = args.optimizer
+    lr = args.lr
+    num_epochs = args.epochs
+    batch_size = args.batch_size
+    loss_name = args.loss_name
+    clip = args.clip
+
+    # 设置 mxnet 随机数种子 (Set mxnet random number seed)
+    mx.random.seed(args.seed)
+
+    # 设置 gpu 或者 cpu (use GPU if available, otherwise CPU)
+    ctx = try_gpu()
+
+    # 获取训练数据与验证数据集 (Get training data and validation data set)
+    print('Getting the data...')
+    data = get_data(args.data_json_path, args.wv_name, args.formated_data_path)
+    x, y, my_vocab = data['x'], data['y'], data['vocab']
+
+    if any([args.wv_name == 'glove', args.wv_name == 'fasttext', args.wv_name == 'w2v']):
+        embedding_weights = my_vocab.embedding.idx_to_vec
+    else:
+        embedding_weights = None
+
+    data_set = ArrayDataset(nd.array(x, ctx=ctx), nd.array(y, ctx=ctx))
+    train_data_set, valid_data_set = nlp.data.train_valid_split(data_set, 0.01)
+    data_iter_train = DataLoader(train_data_set, batch_size=batch_size, shuffle=True, last_batch='rollover')
+    data_iter_valid = DataLoader(valid_data_set, batch_size=batch_size, shuffle=False)
+
+    # 配置模型 (Configuration model)
+    vocab_len = len(my_vocab)
+    model = SelfAttentiveBiLSTM(vocab_len, emsize, nhide, nlayers, att_unit, att_hops, nfc, nclass,
+                                drop_prob, pool_way, prune_p, prune_q)
+    model.initialize(init=init.Xavier(), ctx=ctx)
+    model.hybridize()
+    if embedding_weights is not None:
+        model.embedding_layer.weight.set_data(embedding_weights)
+        model.embedding_layer.collect_params().setattr('grad_req', 'null')
+
+    trainer = gluon.Trainer(model.collect_params(), optim, {'learning_rate': lr})
+
+    class_weight = None
+    if loss_name == 'sce':
+        loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    elif loss_name == 'wsce':
+        loss = WeightedSoftmaxCE()
+        class_weight = nd.array([3.0, 5.3, 4.0, 2.0, 1.0], ctx=ctx)
+
+    # 训练 (Train)
+    th.train(data_iter_train, data_iter_valid, model, loss, trainer, ctx,
+             num_epochs, penal_coeff=penal_coeff, clip=clip, class_weight=class_weight, loss_name=loss_name)
+
+    # 保存模型 (Save the structure and parameters of the model)
+    model_dir = args.save
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+    model_path = os.path.join(model_dir, 'self_att_bilstm_model')
+    model.export(model_path)
+    print('Training finished. The trained model has been saved to:', model_path)
diff --git a/example/self_attentive_sentence_embedding/code/utils.py 
b/example/self_attentive_sentence_embedding/code/utils.py
new file mode 100644
index 00000000000..c23b9265bca
--- /dev/null
+++ b/example/self_attentive_sentence_embedding/code/utils.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# Some general function modules
+# author: kenjewu
+
+
+import argparse
+
+
+def get_args():
+    '''
+    Parsing to get command line arguments
+    '''
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--emsize', type=int, default=300,
+                        help='size of word embeddings')
+    parser.add_argument('--nhid', type=int, default=300,
+                        help='number of hidden units per layer')
+    parser.add_argument('--nlayers', type=int, default=1,
+                        help='number of layers in BiLSTM')
+    parser.add_argument('--attention-unit', type=int, default=350,
+                        help='number of attention units')
+    parser.add_argument('--attention-hops', type=int, default=1,
+                        help='number of attention hops, for multi-hop attention model')
+    parser.add_argument('--drop-prob', type=float, default=0.5,
+                        help='dropout applied to layers (0 = no dropout)')
+    parser.add_argument('--clip', type=float, default=0.5,
+                        help='gradient clipping threshold to prevent exploding gradients in the LSTM')
+    parser.add_argument('--nfc', type=int, default=512,
+                        help='hidden (fully connected) layer size for classifier MLP')
+    parser.add_argument('--lr', type=float, default=.001,
+                        help='initial learning rate')
+    parser.add_argument('--epochs', type=int, default=10,
+                        help='upper epoch limit')
+    parser.add_argument('--loss-name', type=str, default='sce', help='loss function name')
+    parser.add_argument('--seed', type=int, default=2018,
+                        help='random seed')
+
+    parser.add_argument('--pool-way', type=str, default='flatten', help='how to pool the attention output: flatten, mean, or prune')
+    parser.add_argument('--prune-p', type=int, default=None, help='prune p size')
+    parser.add_argument('--prune-q', type=int, default=None, help='prune q size')
+
+    parser.add_argument('--batch-size', type=int, default=64,
+                        help='batch size for training')
+    parser.add_argument('--class-number', type=int, default=5,
+                        help='number of classes')
+    parser.add_argument('--optimizer', type=str, default='Adam',
+                        help='type of optimizer')
+    parser.add_argument('--penalization-coeff', type=float, default=0.1,
+                        help='the penalization coefficient')
+
+    parser.add_argument('--save', type=str, default='../models', help='path to save the final model')
+    parser.add_argument('--wv-name', type=str, choices={'glove', 'w2v', 'fasttext', 'random'},
+                        default='random', help='which pretrained word embedding to use')
+    parser.add_argument('--data-json-path', type=str, default='../data/sub_review_labels.json', help='raw data path')
+    parser.add_argument('--formated-data-path', type=str,
+                        default='../data/formated_data.pkl', help='formatted data path')
+
+    return parser.parse_args()
diff --git 
a/example/self_attentive_sentence_embedding/code/weighted_softmaxCE.py 
b/example/self_attentive_sentence_embedding/code/weighted_softmaxCE.py
new file mode 100644
index 00000000000..bed1b4b238e
--- /dev/null
+++ b/example/self_attentive_sentence_embedding/code/weighted_softmaxCE.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# weighted softmax cross entropy layer
+# author: kenjewu
+
+import mxnet as mx
+from mxnet.gluon import nn
+
+
+class WeightedSoftmaxCE(nn.HybridBlock):
+    def __init__(self, sparse_label=True, from_logits=False,  **kwargs):
+        super(WeightedSoftmaxCE, self).__init__(**kwargs)
+        with self.name_scope():
+            self.sparse_label = sparse_label
+            self.from_logits = from_logits
+
+    def hybrid_forward(self, F, pred, label, class_weight, depth=None):
+        if self.sparse_label:
+            label = F.reshape(label, shape=(-1, ))
+            label = F.one_hot(label, depth)
+        if not self.from_logits:
+            pred = F.log_softmax(pred, -1)
+
+        weight_label = F.broadcast_mul(label, class_weight)
+        loss = -F.sum(pred * weight_label, axis=-1)
+
+        # return F.mean(loss, axis=0, exclude=True)
+        return loss
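+
+# Example usage (illustrative; pred is the model output of shape [batch, nclass]
+# and label holds integer class ids, as in train_helper.py):
+#   loss_fn = WeightedSoftmaxCE()
+#   class_weight = mx.nd.array([3.0, 5.3, 4.0, 2.0, 1.0])
+#   l = loss_fn(pred, label, class_weight, class_weight.shape[0])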
diff --git 
a/example/self_attentive_sentence_embedding/images/Bi_LSTM_Attention.png 
b/example/self_attentive_sentence_embedding/images/Bi_LSTM_Attention.png
new file mode 100644
index 00000000000..245eb44abdd
Binary files /dev/null and 
b/example/self_attentive_sentence_embedding/images/Bi_LSTM_Attention.png differ
diff --git a/example/self_attentive_sentence_embedding/images/prune_weights.png 
b/example/self_attentive_sentence_embedding/images/prune_weights.png
new file mode 100644
index 00000000000..8cb55c19bb3
Binary files /dev/null and 
b/example/self_attentive_sentence_embedding/images/prune_weights.png differ
diff --git a/example/self_attentive_sentence_embedding/models/READEME.md 
b/example/self_attentive_sentence_embedding/models/READEME.md
new file mode 100644
index 00000000000..594244904b5
--- /dev/null
+++ b/example/self_attentive_sentence_embedding/models/READEME.md
@@ -0,0 +1,2 @@
+## Models dir
+


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
