OPENNLP-1009 - minor improvements / fixes
Project: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/repo
Commit: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/commit/6bfb15f0
Tree: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/tree/6bfb15f0
Diff: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/diff/6bfb15f0

Branch: refs/heads/master
Commit: 6bfb15f07c4639b2b1fe1940cf4f846acb3e1401
Parents: a63ec16
Author: Tommaso Teofili <[email protected]>
Authored: Tue May 9 16:40:12 2017 +0200
Committer: Tommaso Teofili <[email protected]>
Committed: Tue May 9 16:40:12 2017 +0200

----------------------------------------------------------------------
 .../src/main/java/opennlp/tools/dl/RNN.java     | 45 ++++++--------
 .../main/java/opennlp/tools/dl/StackedRNN.java  | 65 ++++++++++----------
 .../src/test/java/opennlp/tools/dl/RNNTest.java | 18 +++---
 .../java/opennlp/tools/dl/StackedRNNTest.java   | 17 +++--
 4 files changed, 69 insertions(+), 76 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/6bfb15f0/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
----------------------------------------------------------------------
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
index 155ec03..2fabecd 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java
@@ -55,6 +55,7 @@ public class RNN {
   protected final int hiddenLayerSize;
   protected final int epochs;
   protected final boolean useChars;
+  protected final int batch;
   protected final int vocabSize;
   protected final Map<String, Integer> charToIx;
   protected final Map<Integer, String> ixToChar;
@@ -71,14 +72,15 @@ public class RNN {
   private INDArray hPrev = null; // memory state
 
   public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) {
-    this(learningRate, seqLength, hiddenLayerSize, epochs, text, true);
+    this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true);
   }
 
-  public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) {
+  public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch, boolean useChars) {
     this.learningRate = learningRate;
     this.seqLength = seqLength;
     this.hiddenLayerSize = hiddenLayerSize;
     this.epochs = epochs;
+    this.batch = batch;
     this.useChars = useChars;
 
     String[] textTokens = useChars ? toStrings(text.toCharArray()) : text.split(" ");
@@ -169,21 +171,24 @@ public class RNN {
         System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
       }
 
-      // perform parameter update with Adagrad
-      mWxh.addi(dWxh.mul(dWxh));
-      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg))));
+      if (n % batch == 0) {
 
-      mWhh.addi(dWhh.mul(dWhh));
-      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
+        // perform parameter update with Adagrad
+        mWxh.addi(dWxh.mul(dWxh));
+        wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg))));
 
-      mWhy.addi(dWhy.mul(dWhy));
-      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(reg))));
+        mWhh.addi(dWhh.mul(dWhh));
+        whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
 
-      mbh.addi(dbh.mul(dbh));
-      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
+        mWhy.addi(dWhy.mul(dWhy));
+        why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(reg))));
 
-      mby.addi(dby.mul(dby));
-      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+        mbh.addi(dbh.mul(dbh));
+        bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
+
+        mby.addi(dby.mul(dby));
+        by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+      }
 
       p += seqLength; // move data pointer
       n++; // iteration counter
@@ -244,7 +249,7 @@ public class RNN {
     }
 
     // backward pass: compute gradients going backwards
-    INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
+    INDArray dhNext = Nd4j.zerosLike(hPrev);
     for (int t = inputs.length() - 1; t >= 0; t--) {
       INDArray dy = ps.getRow(t);
       dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y
@@ -334,17 +339,7 @@ public class RNN {
         ", epochs=" + epochs +
         ", vocabSize=" + vocabSize +
         ", useChars=" + useChars +
-        '}';
-  }
-
-
-  public String getHyperparamsString() {
-    return getClass().getName() + "{" +
-        "wxh=" + wxh +
-        ", whh=" + whh +
-        ", why=" + why +
-        ", bh=" + bh +
-        ", by=" + by +
+        ", batch=" + batch +
         '}';
   }


http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/6bfb15f0/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
----------------------------------------------------------------------
diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
index e7a49d7..e6ceb9b 100644
--- a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
+++ b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java
@@ -18,14 +18,6 @@
  */
 package opennlp.tools.dl;
 
-import org.apache.commons.math3.distribution.EnumeratedDistribution;
-import org.apache.commons.math3.util.Pair;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
-import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.ops.transforms.Transforms;
-
 import java.io.BufferedWriter;
 import java.io.File;
 import java.io.FileWriter;
@@ -34,6 +26,13 @@ import java.util.Date;
 import java.util.LinkedList;
 import java.util.List;
 
+import org.apache.commons.math3.distribution.EnumeratedDistribution;
+import org.apache.commons.math3.util.Pair;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.ops.transforms.Transforms;
+
 /**
  * A basic char/word-level stacked RNN model (2 hidden recurrent layers), based on Stacked RNN architecture from ICLR 2014's
  * "How to Construct Deep Recurrent Neural Networks" by Razvan Pascanu, Caglar Gulcehre, Kyunghyun Cho and Yoshua Bengio
@@ -61,11 +60,11 @@ public class StackedRNN extends RNN {
   private INDArray hPrev2 = null; // memory state
 
   public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) {
-    this(learningRate, seqLength, hiddenLayerSize, epochs, text, true);
+    this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true);
   }
 
-  public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) {
-    super(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
+  public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch, boolean useChars) {
+    super(learningRate, seqLength, hiddenLayerSize, epochs, text, batch, useChars);
 
     wxh = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize));
     whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
@@ -141,30 +140,32 @@ public class StackedRNN extends RNN {
         System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
       }
 
-      // perform parameter update with Adagrad
-      mWxh.addi(dWxh.mul(dWxh));
-      wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg))));
+      if (n % batch == 0) {
+        // perform parameter update with Adagrad
+        mWxh.addi(dWxh.mul(dWxh));
+        wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg))));
 
-      mWxh2.addi(dWxh2.mul(dWxh2));
-      wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2.add(reg))));
+        mWxh2.addi(dWxh2.mul(dWxh2));
+        wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2.add(reg))));
 
-      mWhh.addi(dWhh.mul(dWhh));
-      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
+        mWhh.addi(dWhh.mul(dWhh));
+        whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
 
-      mWhh2.addi(dWhh2.mul(dWhh2));
-      whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.add(reg))));
+        mWhh2.addi(dWhh2.mul(dWhh2));
+        whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.add(reg))));
 
-      mbh2.addi(dbh2.mul(dbh2));
-      bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.add(reg))));
+        mbh2.addi(dbh2.mul(dbh2));
+        bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.add(reg))));
 
-      mWh2y.addi(dWh2y.mul(dWh2y));
-      wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.add(reg))));
+        mWh2y.addi(dWh2y.mul(dWh2y));
+        wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.add(reg))));
 
-      mbh.addi(dbh.mul(dbh));
-      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
+        mbh.addi(dbh.mul(dbh));
+        bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));
 
-      mby.addi(dby.mul(dby));
-      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+        mby.addi(dby.mul(dby));
+        by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));
+      }
 
       p += seqLength; // move data pointer
       n++; // iteration counter
@@ -176,7 +177,7 @@ public class StackedRNN extends RNN {
    * hprev is Hx1 array of initial hidden state
    * returns the loss, gradients on model parameters and last hidden state
    */
-  private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWxh2, INDArray dWhh2, INDArray dWh2y, 
+  private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWxh2, INDArray dWhh2, INDArray dWh2y,
                          INDArray dbh, INDArray dbh2, INDArray dby) {
 
     INDArray xs = Nd4j.zeros(seqLength, vocabSize);
@@ -222,8 +223,9 @@ public class StackedRNN extends RNN {
     }
 
     // backward pass: compute gradients going backwards
-    INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
-    INDArray dh2Next = Nd4j.zerosLike(hs2.getRow(0));
+    INDArray dhNext = Nd4j.zerosLike(hPrev);
+    INDArray dh2Next = Nd4j.zerosLike(hPrev2);
+
     for (int t = seqLength - 1; t >= 0; t--) {
       INDArray dy = ps.getRow(t);
       dy.getRow(targets.getInt(t)).subi(1); // backprop into y
@@ -249,7 +251,6 @@ public class StackedRNN extends RNN {
       INDArray hsRow = t == 0 ? hPrev : hs.getRow(t - 1);
       dWhh.addi(dhraw.mmul(hsRow.transpose()));
       dhNext = whh.transpose().mmul(dhraw);
-
     }
 
     this.hPrev = hs.getRow(seqLength - 1);


http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/6bfb15f0/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
----------------------------------------------------------------------
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
index 2808f4d..57f7682 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java
@@ -18,7 +18,6 @@
  */
 package opennlp.tools.dl;
 
-import java.io.FileInputStream;
 import java.io.InputStream;
 import java.util.Arrays;
 import java.util.Collection;
@@ -40,44 +39,44 @@ public class RNNTest {
   private float learningRate;
   private int seqLength;
   private int hiddenLayerSize;
+  private int epochs;
+
   private Random r = new Random();
   private String text;
-  private final int epochs = 20;
   private List<String> words;
 
-  public RNNTest(float learningRate, int seqLength, int hiddenLayerSize) {
+  public RNNTest(float learningRate, int seqLength, int hiddenLayerSize, int epochs) {
     this.learningRate = learningRate;
     this.seqLength = seqLength;
     this.hiddenLayerSize = hiddenLayerSize;
+    this.epochs = epochs;
   }
 
   @Before
   public void setUp() throws Exception {
     InputStream stream = getClass().getResourceAsStream("/text/sentences.txt");
     text = IOUtils.toString(stream);
-    words = Arrays.asList(text.split(" "));
+    words = Arrays.asList(text.split("\\s"));
     stream.close();
   }
 
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][] {
-        {1e-1f, 25, 20},
-        {1e-1f, 25, 40},
-        {1e-1f, 25, 60},
+        {1e-1f, 15, 20, 5},
     });
   }
 
   @Test
   public void testVanillaCharRNNLearn() throws Exception {
-    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text);
+    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true);
     evaluate(rnn, true);
     rnn.serialize("target/crnn-weights-");
   }
 
   @Test
   public void testVanillaWordRNNLearn() throws Exception {
-    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs * 2, text, false);
+    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false);
     evaluate(rnn, true);
     rnn.serialize("target/wrnn-weights-");
   }
@@ -89,7 +88,6 @@ public class RNNTest {
     for (int i = 0; i < 2; i++) {
       int seed = r.nextInt(rnn.getVocabSize());
       String sample = rnn.sample(seed);
-      System.out.println(sample);
       if (checkRatio && rnn.useChars) {
         String[] sampleWords = sample.split(" ");
         for (String sw : sampleWords) {


http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/6bfb15f0/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
----------------------------------------------------------------------
diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
index ac0434c..686d603 100644
--- a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
+++ b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java
@@ -39,44 +39,44 @@ public class StackedRNNTest {
   private float learningRate;
   private int seqLength;
   private int hiddenLayerSize;
+  private int epochs;
+
   private Random r = new Random();
   private String text;
-  private final int epochs = 20;
   private List<String> words;
 
-  public StackedRNNTest(float learningRate, int seqLength, int hiddenLayerSize) {
+  public StackedRNNTest(float learningRate, int seqLength, int hiddenLayerSize, int epochs) {
     this.learningRate = learningRate;
     this.seqLength = seqLength;
     this.hiddenLayerSize = hiddenLayerSize;
+    this.epochs = epochs;
   }
 
   @Before
   public void setUp() throws Exception {
     InputStream stream = getClass().getResourceAsStream("/text/sentences.txt");
     text = IOUtils.toString(stream);
-    words = Arrays.asList(text.split(" "));
+    words = Arrays.asList(text.split("\\s"));
     stream.close();
   }
 
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][] {
-        {1e-1f, 25, 20},
-        {1e-1f, 25, 40},
-        {1e-1f, 25, 60},
+        {1e-1f, 15, 20, 5},
     });
   }
 
   @Test
   public void testStackedCharRNNLearn() throws Exception {
-    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text);
+    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true);
     evaluate(rnn, true);
     rnn.serialize("target/scrnn-weights-");
   }
 
   @Test
   public void testStackedWordRNNLearn() throws Exception {
-    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false);
+    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false);
     evaluate(rnn, true);
     rnn.serialize("target/swrnn-weights-");
   }
@@ -88,7 +88,6 @@ public class StackedRNNTest {
     for (int i = 0; i < 2; i++) {
       int seed = r.nextInt(rnn.getVocabSize());
       String sample = rnn.sample(seed);
-      System.out.println(sample);
      if (checkRatio && rnn.useChars) {
         String[] sampleWords = sample.split(" ");
         for (String sw : sampleWords) {
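----------------------------------------------------------------------

The behavioural change in both learn() loops above is that the Adagrad step is now gated on the iteration counter: parameters are updated only when n % batch == 0, while the old 5-argument constructors delegate with batch = 1, which keeps the previous update-every-iteration behaviour. The following stand-alone sketch only illustrates that cadence and is not part of this commit; the class and variable names (AdagradBatchSketch, w, mW, dW) are placeholders for the real wxh/whh/why matrices and their Adagrad memories.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Illustrative sketch of the batch-gated Adagrad update introduced above.
 */
public class AdagradBatchSketch {

  public static void main(String[] args) {
    float learningRate = 1e-1f;   // same learning rate the tests use
    double reg = 1e-8;            // stands in for the `reg` epsilon term in RNN.java
    int batch = 5;                // same value the char-RNN tests pass to the new constructor

    INDArray w = Nd4j.randn(4, 4);    // one weight matrix stands in for wxh, whh, why, bh, by
    INDArray mW = Nd4j.zeros(4, 4);   // Adagrad per-weight gradient accumulator

    for (int n = 0; n < 20; n++) {
      INDArray dW = Nd4j.randn(4, 4); // placeholder gradient; the real code gets it from lossFun()

      if (n % batch == 0) {
        // perform parameter update with Adagrad, as in the patched learn() loops
        mW.addi(dW.mul(dW));
        w.subi(dW.mul(learningRate).div(Transforms.sqrt(mW.add(reg))));
      }
      // on the other iterations gradients are computed but no update is applied
    }
  }
}

The tests exercise the same knob through the widened constructors, e.g. new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true) for the char-level models and batch = 1 for the word-level ones.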
