Repository: opennlp-sandbox Updated Branches: refs/heads/master 2707f6656 -> 96c088b00
<?xml version="1.0" encoding="UTF-8"?>
<!--
   Licensed to the Apache Software Foundation (ASF) under one or more
   contributor license agreements.  See the NOTICE file distributed with
   this work for additional information regarding copyright ownership.
   The ASF licenses this file to You under the Apache License, Version 2.0
   (the "License"); you may not use this file except in compliance with
   the License.  You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>

  <!-- Coordinates aligned with the opennlp-sandbox module layout (directory
       opennlp-dl); the draft's placeholder "burn:dl4jtest" was a leftover
       from a scratch project. -->
  <groupId>org.apache.opennlp</groupId>
  <artifactId>opennlp-dl</artifactId>
  <version>1.0-SNAPSHOT</version>

  <properties>
    <!-- Make the build independent of the platform default charset. -->
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
  </properties>

  <dependencies>
    <dependency>
      <groupId>org.apache.opennlp</groupId>
      <artifactId>opennlp-tools</artifactId>
      <version>1.7.2</version>
    </dependency>

    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-core</artifactId>
      <version>0.7.2</version>
    </dependency>

    <dependency>
      <groupId>org.nd4j</groupId>
      <!-- Swap in nd4j-cuda-8.0-platform to train on a CUDA GPU. -->
      <artifactId>nd4j-native-platform</artifactId>
      <!-- <artifactId>nd4j-cuda-8.0-platform</artifactId> -->
      <version>0.7.2</version>
    </dependency>

    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-nlp</artifactId>
      <version>0.7.2</version>
    </dependency>

    <dependency>
      <groupId>org.slf4j</groupId>
      <artifactId>slf4j-simple</artifactId>
      <version>1.7.12</version>
    </dependency>
  </dependencies>

  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <version>3.5.1</version>
        <configuration>
          <source>1.8</source>
          <target>1.8</target>
        </configuration>
      </plugin>
    </plugins>
  </build>
</project>
+ */ + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; +import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.Updater; +import org.deeplearning4j.nn.conf.layers.GravesLSTM; +import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +import opennlp.tools.namefind.BioCodec; +import opennlp.tools.namefind.NameSample; +import opennlp.tools.namefind.NameSampleDataStream; +import opennlp.tools.namefind.TokenNameFinder; +import opennlp.tools.namefind.TokenNameFinderEvaluator; +import opennlp.tools.util.MarkableFileInputStreamFactory; +import opennlp.tools.util.ObjectStream; +import opennlp.tools.util.PlainTextByLineStream; +import opennlp.tools.util.Span; + +// https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/word2vecsentiment/Word2VecSentimentRNN.java +public class NameFinderDL implements TokenNameFinder { + + private final MultiLayerNetwork network; + private final 
WordVectors wordVectors; + private int windowSize; + private String[] labels; + + public NameFinderDL(MultiLayerNetwork network, WordVectors wordVectors, int windowSize, + String[] labels) { + this.network = network; + this.wordVectors = wordVectors; + this.windowSize = windowSize; + this.labels = labels; + } + + static List<INDArray> mapToFeatureMatrices(WordVectors wordVectors, String[] tokens, int windowSize) { + + List<INDArray> matrices = new ArrayList<>(); + + // TODO: Dont' hard code word vector dimension ... + + for (int i = 0; i < tokens.length; i++) { + INDArray features = Nd4j.create(1, 300, windowSize); + for (int vectorIndex = 0; vectorIndex < windowSize; vectorIndex++) { + int tokenIndex = i + vectorIndex - ((windowSize - 1) / 2); + if (tokenIndex >= 0 && tokenIndex < tokens.length) { + String token = tokens[tokenIndex]; + double[] wv = wordVectors.getWordVector(token); + if (wv != null) { + INDArray vector = wordVectors.getWordVectorMatrix(token); + features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), + NDArrayIndex.point(vectorIndex)}, vector); + } + } + } + matrices.add(features); + } + + return matrices; + } + + static List<INDArray> mapToLabelVectors(NameSample sample, int windowSize, String[] labelStrings) { + + Map<String, Integer> labelToIndex = IntStream.range(0, labelStrings.length).boxed() + .collect(Collectors.toMap(i -> labelStrings[i], i -> i)); + + List<INDArray> vectors = new ArrayList<INDArray>(); + + for (int i = 0; i < sample.getSentence().length; i++) { + // encode the outcome as one-hot-representation + String outcomes[] = + new BioCodec().encode(sample.getNames(), sample.getSentence().length); + + INDArray labels = Nd4j.create(1, labelStrings.length, windowSize); + labels.putScalar(new int[]{0, labelToIndex.get(outcomes[i]), windowSize - 1}, 1.0d); + vectors.add(labels); + } + + return vectors; + } + + private static int max(INDArray array) { + int best = 0; + for (int i = 0; i < array.size(0); i++) { + if 
(array.getDouble(i) > array.getDouble(best)) { + best = i; + } + } + return best; + } + + @Override + public Span[] find(String[] tokens) { + List<INDArray> featureMartrices = mapToFeatureMatrices(wordVectors, tokens, windowSize); + + String[] outcomes = new String[tokens.length]; + for (int i = 0; i < tokens.length; i++) { + INDArray predictionMatrix = network.output(featureMartrices.get(i), false); + INDArray outcomeVector = predictionMatrix.get(NDArrayIndex.point(0), NDArrayIndex.all(), + NDArrayIndex.point(windowSize - 1)); + + outcomes[i] = labels[max(outcomeVector)]; + } + + // Delete invalid spans ... + for (int i = 0; i < outcomes.length; i++) { + if (outcomes[i].endsWith("cont") && (i == 0 || "other".equals(outcomes[i - 1]))) { + outcomes[i] = "other"; + } + } + + return new BioCodec().decode(Arrays.asList(outcomes)); + } + + @Override + public void clearAdaptiveData() { + } + + public static MultiLayerNetwork train(WordVectors wordVectors, ObjectStream<NameSample> samples, + int epochs, int windowSize, String[] labels) throws IOException { + int vectorSize = 300; + int layerSize = 256; + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1) + .updater(Updater.RMSPROP) + .regularization(true).l2(0.001) + .weightInit(WeightInit.XAVIER) + // .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue).gradientNormalizationThreshold(1.0) + .learningRate(0.01) + .list() + .layer(0, new GravesLSTM.Builder().nIn(vectorSize).nOut(layerSize) + .activation(Activation.TANH).build()) + .layer(1, new RnnOutputLayer.Builder().activation(Activation.SOFTMAX) + .lossFunction(LossFunctions.LossFunction.MCXENT).nIn(layerSize).nOut(3).build()) + .pretrain(false).backprop(true).build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + net.setListeners(new ScoreIterationListener(5)); + + // TODO: Extract labels on the fly from the data + + 
DataSetIterator train = new NameSampleDataSetIterator(samples, wordVectors, windowSize, labels); + + System.out.println("Starting training"); + + for (int i = 0; i < epochs; i++) { + net.fit(train); + train.reset(); + System.out.println(String.format("Finished epoche %d", i)); + } + + return net; + } + + public static void main(String[] args) throws Exception { + if (args.length != 3) { + System.out.println("Usage: trainFile testFile gloveTxt"); + return; + } + + String[] labels = new String[] { + "default-start", "default-cont", "other" + }; + + System.out.print("Loading vectors ... "); + WordVectors wordVectors = WordVectorSerializer.loadTxtVectors( + new File(args[2])); + System.out.println("Done"); + + int windowSize = 5; + + MultiLayerNetwork net = train(wordVectors, new NameSampleDataStream(new PlainTextByLineStream( + new MarkableFileInputStreamFactory(new File(args[0])), StandardCharsets.UTF_8)), 1, windowSize, labels); + + ObjectStream<NameSample> evalStream = new NameSampleDataStream(new PlainTextByLineStream( + new MarkableFileInputStreamFactory( + new File(args[1])), StandardCharsets.UTF_8)); + + NameFinderDL nameFinder = new NameFinderDL(net, wordVectors, windowSize, labels); + + System.out.print("Evaluating ... 
"); + TokenNameFinderEvaluator nameFinderEvaluator = new TokenNameFinderEvaluator(nameFinder); + nameFinderEvaluator.evaluate(evalStream); + + System.out.println("Done"); + + System.out.println(); + System.out.println(); + System.out.println("Results"); + + System.out.println(nameFinderEvaluator.getFMeasure().toString()); + } +} http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/96c088b0/opennlp-dl/src/main/java/NameSampleDataSetIterator.java ---------------------------------------------------------------------- diff --git a/opennlp-dl/src/main/java/NameSampleDataSetIterator.java b/opennlp-dl/src/main/java/NameSampleDataSetIterator.java new file mode 100644 index 0000000..f416a1d --- /dev/null +++ b/opennlp-dl/src/main/java/NameSampleDataSetIterator.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.api.DataSetPreProcessor; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; + +import opennlp.tools.namefind.NameSample; +import opennlp.tools.util.FilterObjectStream; +import opennlp.tools.util.ObjectStream; + +public class NameSampleDataSetIterator implements DataSetIterator { + + private static class NameSampleToDataSetStream extends FilterObjectStream<NameSample, DataSet> { + + private final WordVectors wordVectors; + private final String[] labels; + private int windowSize; + + private Iterator<DataSet> dataSets = Collections.emptyListIterator(); + + NameSampleToDataSetStream(ObjectStream<NameSample> samples, WordVectors wordVectors, int windowSize, String[] labels) { + super(samples); + this.wordVectors = wordVectors; + this.windowSize = windowSize; + this.labels = labels; + } + + private Iterator<DataSet> createDataSets(NameSample sample) { + List<INDArray> features = NameFinderDL.mapToFeatureMatrices(wordVectors, sample.getSentence(), + windowSize); + + List<INDArray> labels = NameFinderDL.mapToLabelVectors(sample, windowSize, this.labels); + + List<DataSet> dataSetList = new ArrayList<>(); + + for (int i = 0; i < features.size(); i++) { + dataSetList.add(new DataSet(features.get(i), labels.get(i))); + } + + return dataSetList.iterator(); + } + + @Override + public final DataSet read() throws IOException { + + if (dataSets.hasNext()) { + return dataSets.next(); + } + else { + NameSample sample; + while 
(!dataSets.hasNext() && (sample = samples.read()) != null) { + dataSets = createDataSets(sample); + } + + if (dataSets.hasNext()) { + return read(); + } + } + + return null; + } + } + + private final int windowSize; + private final String[] labels; + + private final int batchSize = 128; + private final int vectorSize = 300; + + private final int totalSamples; + + private int cursor = 0; + + private final ObjectStream<DataSet> samples; + + NameSampleDataSetIterator(ObjectStream<NameSample> samples, WordVectors wordVectors, int windowSize, + String labels[]) throws IOException { + this.windowSize = windowSize; + this.labels = labels; + + this.samples = new NameSampleToDataSetStream(samples, wordVectors, windowSize, labels); + + int total = 0; + + DataSet sample; + while ((sample = this.samples.read()) != null) { + total++; + } + + totalSamples = total; + + samples.reset(); + } + + public DataSet next(int num) { + if (cursor >= totalExamples()) throw new NoSuchElementException(); + + INDArray features = Nd4j.create(num, vectorSize, windowSize); + INDArray featuresMask = Nd4j.zeros(num, windowSize); + + INDArray labels = Nd4j.create(num, 3, windowSize); + INDArray labelsMask = Nd4j.zeros(num, windowSize); + + // iterate stream and copy to arrays + + for (int i = 0; i < num; i++) { + DataSet sample; + try { + sample = samples.read(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + if (sample != null) { + INDArray feature = sample.getFeatureMatrix(); + features.put(new INDArrayIndex[] {NDArrayIndex.point(i)}, feature.get(NDArrayIndex.point(0))); + + feature.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.all(), + NDArrayIndex.point(0)}); + + for (int j = 0; j < windowSize; j++) { + featuresMask.putScalar(new int[] {i, j}, 1.0); + } + + INDArray label = sample.getLabels(); + labels.put(new INDArrayIndex[] {NDArrayIndex.point(i)}, label.get(NDArrayIndex.point(0))); + labelsMask.putScalar(new int[] {i, windowSize - 1}, 1.0); + } + + 
cursor++; + } + + return new DataSet(features, labels, featuresMask, labelsMask); + } + + public int totalExamples() { + return totalSamples; + } + + public int inputColumns() { + return vectorSize; + } + + public int totalOutcomes() { + return getLabels().size(); + } + + public boolean resetSupported() { + return true; + } + + public boolean asyncSupported() { + return false; + } + + public void reset() { + cursor = 0; + + try { + samples.reset(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public int batch() { + return batchSize; + } + + public int cursor() { + return cursor; + } + + public int numExamples() { + return totalExamples(); + } + + public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) { + throw new UnsupportedOperationException(); + } + + public DataSetPreProcessor getPreProcessor() { + throw new UnsupportedOperationException(); + } + + public List<String> getLabels() { + return Arrays.asList("start","cont", "other"); + } + + public boolean hasNext() { + return cursor < numExamples(); + } + + public DataSet next() { + return next(batchSize); + } +}
