This is an automated email from the ASF dual-hosted git repository. nswamy pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new c71a2f3 Usability improvement bi lstm sort (#8944) c71a2f3 is described below commit c71a2f3bc9a815b8ef9b6caa615013c74deb10ad Author: Anirudh Subramanian <anirudh2...@gmail.com> AuthorDate: Tue Dec 19 21:45:17 2017 -0800 Usability improvement bi lstm sort (#8944) * Improve usability for the bilstm example * Remove argparse from infer_sort since it changes existing usage --- example/bi-lstm-sort/README.md | 48 +++++++++++------------- example/bi-lstm-sort/infer_sort.py | 25 ++++++++++--- example/bi-lstm-sort/lstm.py | 1 - example/bi-lstm-sort/lstm_sort.py | 75 +++++++++++++++++++++++++++++++++----- example/bi-lstm-sort/rnn_model.py | 1 - example/bi-lstm-sort/sort_io.py | 1 - 6 files changed, 106 insertions(+), 45 deletions(-) diff --git a/example/bi-lstm-sort/README.md b/example/bi-lstm-sort/README.md index a590a18..3bacc86 100644 --- a/example/bi-lstm-sort/README.md +++ b/example/bi-lstm-sort/README.md @@ -1,28 +1,24 @@ This is an example of using bidirection lstm to sort an array. -Firstly, generate data by: - - python gen_data.py - -Move generated txt files to data directory - - mkdir data - mv *.txt data - -Then, train the model by: - - python lstm_sort.py - -At last, test model by: - - python infer_sort.py 234 189 785 763 231 - -and will output sorted seq - - 189 - 231 - 234 - 763 - 785 - - +Run the training script by doing the following: + +``` +python lstm_sort.py --start-range 100 --end-range 1000 --cpu +``` +You can provide the start-range and end-range for the numbers and whether to train on the cpu or not. +By default the script tries to train on the GPU. The default start-range is 100 and end-range is 1000. + +At last, test model by doing the following: + +``` +python infer_sort.py 234 189 785 763 231 +``` + +This should output the sorted seq like the following: +``` +189 +231 +234 +763 +785 +``` diff --git a/example/bi-lstm-sort/infer_sort.py b/example/bi-lstm-sort/infer_sort.py index b074c03..f81c6c0 100644 --- a/example/bi-lstm-sort/infer_sort.py +++ b/example/bi-lstm-sort/infer_sort.py @@ -18,20 +18,29 @@ # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme # pylint: disable=superfluous-parens, no-member, invalid-name import sys -sys.path.insert(0, "../../python") +import os +import argparse import numpy as np import mxnet as mx from sort_io import BucketSentenceIter, default_build_vocab from rnn_model import BiLSTMInferenceModel +TRAIN_FILE = "sort.train.txt" +TEST_FILE = "sort.test.txt" +VALID_FILE = "sort.valid.txt" +DATA_DIR = os.path.join(os.getcwd(), "data") +SEQ_LEN = 5 + def MakeInput(char, vocab, arr): idx = vocab[char] tmp = np.zeros((1,)) tmp[0] = idx arr[:] = tmp -if __name__ == '__main__': +def main(): + tks = sys.argv[1:] + assert len(tks) >= 5, "Please provide 5 numbers for sorting as sequence length is 5" batch_size = 1 buckets = [] num_hidden = 300 @@ -42,20 +51,21 @@ if __name__ == '__main__': learning_rate = 0.1 momentum = 0.9 - contexts = [mx.context.gpu(i) for i in range(1)] + contexts = [mx.context.cpu(i) for i in range(1)] - vocab = default_build_vocab("./data/sort.train.txt") + vocab = default_build_vocab(os.path.join(DATA_DIR, TRAIN_FILE)) rvocab = {} for k, v in vocab.items(): rvocab[v] = k _, arg_params, __ = mx.model.load_checkpoint("sort", 1) + for tk in tks: + assert (tk in vocab), "{} not in range of numbers that the model trained for.".format(tk) - model = BiLSTMInferenceModel(5, len(vocab), + model = BiLSTMInferenceModel(SEQ_LEN, len(vocab), num_hidden=num_hidden, num_embed=num_embed, num_label=len(vocab), arg_params=arg_params, ctx=contexts, dropout=0.0) - tks = sys.argv[1:] data = np.zeros((1, len(tks))) for k in range(len(tks)): data[0][k] = vocab[tks[k]] @@ -65,3 +75,6 @@ if __name__ == '__main__': for k in range(len(tks)): print(rvocab[np.argmax(prob, axis = 1)[k]]) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/example/bi-lstm-sort/lstm.py b/example/bi-lstm-sort/lstm.py index a082092..362481d 100644 --- a/example/bi-lstm-sort/lstm.py +++ b/example/bi-lstm-sort/lstm.py @@ -17,7 +17,6 @@ # pylint:skip-file import sys -sys.path.insert(0, "../../python") import mxnet as mx import numpy as np from collections import namedtuple diff --git a/example/bi-lstm-sort/lstm_sort.py b/example/bi-lstm-sort/lstm_sort.py index aef88b8..3fd4a2a 100644 --- a/example/bi-lstm-sort/lstm_sort.py +++ b/example/bi-lstm-sort/lstm_sort.py @@ -17,14 +17,65 @@ # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme # pylint: disable=superfluous-parens, no-member, invalid-name +import os import sys -sys.path.insert(0, "../../python") import numpy as np import mxnet as mx +import random +import argparse from lstm import bi_lstm_unroll from sort_io import BucketSentenceIter, default_build_vocab +import logging +head = '%(asctime)-15s %(message)s' +logging.basicConfig(level=logging.DEBUG, format=head) + + +TRAIN_FILE = "sort.train.txt" +TEST_FILE = "sort.test.txt" +VALID_FILE = "sort.valid.txt" +DATA_DIR = os.path.join(os.getcwd(), "data") +SEQ_LEN = 5 + +def gen_data(seq_len, start_range, end_range): + if not os.path.exists(DATA_DIR): + try: + logging.info('create directory %s', DATA_DIR) + os.makedirs(DATA_DIR) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise OSError('failed to create ' + DATA_DIR) + vocab = [str(x) for x in range(start_range, end_range)] + sw_train = open(os.path.join(DATA_DIR, TRAIN_FILE), "w") + sw_test = open(os.path.join(DATA_DIR, TEST_FILE), "w") + sw_valid = open(os.path.join(DATA_DIR, VALID_FILE), "w") + + for i in range(1000000): + seq = " ".join([vocab[random.randint(0, len(vocab) - 1)] for j in range(seq_len)]) + k = i % 50 + if k == 0: + sw_test.write(seq + "\n") + elif k == 1: + sw_valid.write(seq + "\n") + else: + sw_train.write(seq + "\n") + + sw_train.close() + sw_test.close() + +def parse_args(): + parser = argparse.ArgumentParser(description="Parse args for lstm_sort example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--start-range', type=int, default=100, + help='starting number of the range') + parser.add_argument('--end-range', type=int, default=1000, + help='Ending number of the range') + parser.add_argument('--cpu', action='store_true', + help='To use CPU for training') + return parser.parse_args() + + def Perplexity(label, pred): label = label.T.reshape((-1,)) loss = 0. @@ -32,7 +83,9 @@ def Perplexity(label, pred): loss += -np.log(max(1e-10, pred[i][int(label[i])])) return np.exp(loss / label.size) -if __name__ == '__main__': +def main(): + args = parse_args() + gen_data(SEQ_LEN, args.start_range, args.end_range) batch_size = 100 buckets = [] num_hidden = 300 @@ -43,9 +96,12 @@ if __name__ == '__main__': learning_rate = 0.1 momentum = 0.9 - contexts = [mx.context.gpu(i) for i in range(1)] + if args.cpu: + contexts = [mx.context.cpu(i) for i in range(1)] + else: + contexts = [mx.context.gpu(i) for i in range(1)] - vocab = default_build_vocab("./data/sort.train.txt") + vocab = default_build_vocab(os.path.join(DATA_DIR, TRAIN_FILE)) def sym_gen(seq_len): return bi_lstm_unroll(seq_len, len(vocab), @@ -56,9 +112,9 @@ if __name__ == '__main__': init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)] init_states = init_c + init_h - data_train = BucketSentenceIter("./data/sort.train.txt", vocab, + data_train = BucketSentenceIter(os.path.join(DATA_DIR, TRAIN_FILE), vocab, buckets, batch_size, init_states) - data_val = BucketSentenceIter("./data/sort.valid.txt", vocab, + data_val = BucketSentenceIter(os.path.join(DATA_DIR, VALID_FILE), vocab, buckets, batch_size, init_states) if len(buckets) == 1: @@ -74,12 +130,11 @@ if __name__ == '__main__': wd=0.00001, initializer=mx.init.Xavier(factor_type="in", magnitude=2.34)) - import logging - head = '%(asctime)-15s %(message)s' - logging.basicConfig(level=logging.DEBUG, format=head) - model.fit(X=data_train, eval_data=data_val, eval_metric = mx.metric.np(Perplexity), batch_end_callback=mx.callback.Speedometer(batch_size, 50),) model.save("sort") + +if __name__ == '__main__': + sys.exit(main()) diff --git a/example/bi-lstm-sort/rnn_model.py b/example/bi-lstm-sort/rnn_model.py index 202aae6..1079e90 100644 --- a/example/bi-lstm-sort/rnn_model.py +++ b/example/bi-lstm-sort/rnn_model.py @@ -18,7 +18,6 @@ # pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme # pylint: disable=superfluous-parens, no-member, invalid-name import sys -sys.path.insert(0, "../../python") import numpy as np import mxnet as mx diff --git a/example/bi-lstm-sort/sort_io.py b/example/bi-lstm-sort/sort_io.py index 8cb44c6..853d0ee 100644 --- a/example/bi-lstm-sort/sort_io.py +++ b/example/bi-lstm-sort/sort_io.py @@ -19,7 +19,6 @@ # pylint: disable=superfluous-parens, no-member, invalid-name from __future__ import print_function import sys -sys.path.insert(0, "../../python") import numpy as np import mxnet as mx -- To stop receiving notification emails like this one, please contact ['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].