http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/test/resources/log4j-test.properties
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-tez/src/test/resources/log4j-test.properties b/flink-contrib/flink-tez/src/test/resources/log4j-test.properties
new file mode 100644
index 0000000..0845c81
--- /dev/null
+++ b/flink-contrib/flink-tez/src/test/resources/log4j-test.properties
@@ -0,0 +1,30 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# The root logger level is set to INFO here for debugging purposes;
+# set it to OFF to avoid flooding the build logs.
+log4j.rootLogger=INFO, testlogger
+
+# "testlogger" is set to be a ConsoleAppender.
+log4j.appender.testlogger=org.apache.log4j.ConsoleAppender
+log4j.appender.testlogger.target = System.err
+log4j.appender.testlogger.layout=org.apache.log4j.PatternLayout
+log4j.appender.testlogger.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n
+
+# suppress the irrelevant (wrong) warnings from the netty channel handler
+log4j.logger.org.jboss.netty.channel.DefaultChannelPipeline=ERROR, testlogger
\ No newline at end of file
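For context, a minimal sketch of how test code would log under the configuration above (not part of this commit; the object name is hypothetical). With the root level at INFO, messages at INFO and above are written to System.err through the "testlogger" console appender, while the org.jboss.netty.channel.DefaultChannelPipeline logger is capped at ERROR:

    import org.slf4j.LoggerFactory

    object LoggingSmokeTest {
      // SLF4J resolves to the log4j backend configured by log4j-test.properties on the test classpath
      private val log = LoggerFactory.getLogger(getClass)

      def main(args: Array[String]): Unit = {
        log.info("printed to System.err using the configured ConversionPattern")
        log.debug("suppressed: DEBUG is below the configured INFO root level")
      }
    }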
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/flink-tez/src/test/resources/logback-test.xml ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-tez/src/test/resources/logback-test.xml b/flink-contrib/flink-tez/src/test/resources/logback-test.xml new file mode 100644 index 0000000..48e4374 --- /dev/null +++ b/flink-contrib/flink-tez/src/test/resources/logback-test.xml @@ -0,0 +1,37 @@ +<!-- + ~ Licensed to the Apache Software Foundation (ASF) under one + ~ or more contributor license agreements. See the NOTICE file + ~ distributed with this work for additional information + ~ regarding copyright ownership. The ASF licenses this file + ~ to you under the Apache License, Version 2.0 (the + ~ "License"); you may not use this file except in compliance + ~ with the License. You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> + +<configuration> + <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> + <encoder> + <pattern>%d{HH:mm:ss.SSS} %-5level [%thread] %logger{60} - %msg%n</pattern> + </encoder> + </appender> + + <root level="WARN"> + <appender-ref ref="STDOUT"/> + </root> + + <!--<logger name="org.apache.flink.runtime.operators.BatchTask" level="OFF"/>--> + <!--<logger name="org.apache.flink.runtime.client.JobClient" level="OFF"/>--> + <!--<logger name="org.apache.flink.runtime.taskmanager.Task" level="OFF"/>--> + <!--<logger name="org.apache.flink.runtime.jobmanager.JobManager" level="OFF"/>--> + <!--<logger name="org.apache.flink.runtime.taskmanager.TaskManager" level="OFF"/>--> + <!--<logger name="org.apache.flink.runtime.executiongraph.ExecutionGraph" level="OFF"/>--> + <!--<logger name="org.apache.flink.runtime.jobmanager.EventCollector" level="OFF"/>--> +</configuration> \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-contrib/pom.xml ---------------------------------------------------------------------- diff --git a/flink-contrib/pom.xml b/flink-contrib/pom.xml index 8008058..5901621 100644 --- a/flink-contrib/pom.xml +++ b/flink-contrib/pom.xml @@ -31,6 +31,11 @@ under the License. <relativePath>..</relativePath> </parent> + <artifactId>flink-contrib-parent</artifactId> + <name>flink-contrib-parent</name> + + <packaging>pom</packaging> + <modules> <module>flink-storm</module> <module>flink-storm-examples</module> @@ -40,9 +45,14 @@ under the License. <module>flink-connector-wikiedits</module> </modules> - <artifactId>flink-contrib-parent</artifactId> - <name>flink-contrib</name> - <packaging>pom</packaging> + <profiles> + <profile> + <id>include-tez</id> + <modules> + <module>flink-tez</module> + </modules> + </profile> + </profiles> </project> http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-dist/src/main/assemblies/bin.xml ---------------------------------------------------------------------- diff --git a/flink-dist/src/main/assemblies/bin.xml b/flink-dist/src/main/assemblies/bin.xml index b067280..c28f01a 100644 --- a/flink-dist/src/main/assemblies/bin.xml +++ b/flink-dist/src/main/assemblies/bin.xml @@ -71,7 +71,7 @@ under the License. 
<!-- flink scala shell--> <fileSet> - <directory>../flink-staging/flink-scala-shell/start-script/</directory> + <directory>../flink-scala-shell/start-script/</directory> <outputDirectory>bin</outputDirectory> <fileMode>755</fileMode> </fileSet> @@ -138,17 +138,6 @@ under the License. <outputDirectory>tools</outputDirectory> <fileMode>0644</fileMode> </fileSet> - <fileSet> - <directory>../flink-clients/src/main/resources/web-docs</directory> - <outputDirectory>tools</outputDirectory> - <fileMode>0644</fileMode> - <excludes> - <exclude>*.html</exclude> - <exclude>img/delete-icon.png</exclude> - <exclude>img/GradientBoxes.png</exclude> - <exclude>img/gradient.jpg</exclude> - </excludes> - </fileSet> <!-- copy jar files of the batch examples --> <fileSet> http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-examples/flink-examples-batch/pom.xml ---------------------------------------------------------------------- diff --git a/flink-examples/flink-examples-batch/pom.xml b/flink-examples/flink-examples-batch/pom.xml index e31f35c..d18a6f1 100644 --- a/flink-examples/flink-examples-batch/pom.xml +++ b/flink-examples/flink-examples-batch/pom.xml @@ -246,13 +246,13 @@ under the License. <!-- EnumTriangles Basic --> <execution> - <id>EnumerateGraphTriangles</id> + <id>EnumTriangles</id> <phase>package</phase> <goals> <goal>jar</goal> </goals> <configuration> - <classifier>EnumerateGraphTriangles</classifier> + <classifier>EnumTriangles</classifier> <archive> <manifestEntries> @@ -383,7 +383,7 @@ under the License. <target> <copy file="${project.basedir}/target/flink-examples-batch-${project.version}-KMeans.jar" tofile="${project.basedir}/target/KMeans.jar" /> <copy file="${project.basedir}/target/flink-examples-batch-${project.version}-ConnectedComponents.jar" tofile="${project.basedir}/target/ConnectedComponents.jar" /> - <copy file="${project.basedir}/target/flink-examples-batch-${project.version}-EnumerateGraphTriangles.jar" tofile="${project.basedir}/target/EnumerateGraphTriangles.jar" /> + <copy file="${project.basedir}/target/flink-examples-batch-${project.version}-EnumTriangles.jar" tofile="${project.basedir}/target/EnumTriangles.jar" /> <copy file="${project.basedir}/target/flink-examples-batch-${project.version}-PageRank.jar" tofile="${project.basedir}/target/PageRank.jar" /> <copy file="${project.basedir}/target/flink-examples-batch-${project.version}-TransitiveClosure.jar" tofile="${project.basedir}/target/TransitiveClosure.jar" /> <copy file="${project.basedir}/target/flink-examples-batch-${project.version}-WebLogAnalysis.jar" tofile="${project.basedir}/target/WebLogAnalysis.jar" /> http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-fs-tests/pom.xml ---------------------------------------------------------------------- diff --git a/flink-fs-tests/pom.xml b/flink-fs-tests/pom.xml new file mode 100644 index 0000000..f8a2c3f --- /dev/null +++ b/flink-fs-tests/pom.xml @@ -0,0 +1,97 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. 
You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> + + <modelVersion>4.0.0</modelVersion> + + <parent> + <groupId>org.apache.flink</groupId> + <artifactId>flink-parent</artifactId> + <version>1.0-SNAPSHOT</version> + <relativePath>..</relativePath> + </parent> + + <artifactId>flink-fs-tests</artifactId> + <name>flink-fs-tests</name> + + <packaging>jar</packaging> + + <!-- + This is a Hadoop2 only flink module. + --> + <dependencies> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>${shading-artifact.name}</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-core</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-streaming-java</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-examples-batch</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-avro</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-test-utils</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-hdfs</artifactId> + <scope>test</scope> + <type>test-jar</type> + <version>${hadoop.version}</version><!--$NO-MVN-MAN-VER$--> + </dependency> + + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-common</artifactId> + <scope>test</scope> + <type>test-jar</type> + <version>${hadoop.version}</version><!--$NO-MVN-MAN-VER$--> + </dependency> + </dependencies> +</project> http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java ---------------------------------------------------------------------- diff --git a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java new file mode 100644 index 0000000..49dfc21 --- /dev/null +++ b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.hdfstests; + +import org.apache.commons.io.FileUtils; + +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.core.fs.FileStatus; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.Path; +import org.apache.flink.core.testutils.CommonTestUtils; +import org.apache.flink.runtime.state.StateHandle; +import org.apache.flink.runtime.state.filesystem.FileStreamStateHandle; +import org.apache.flink.runtime.state.filesystem.FsStateBackend; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hdfs.MiniDFSCluster; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Random; +import java.util.UUID; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class FileStateBackendTest { + + private static File TEMP_DIR; + + private static String HDFS_ROOT_URI; + + private static MiniDFSCluster HDFS_CLUSTER; + + private static FileSystem FS; + + // ------------------------------------------------------------------------ + // startup / shutdown + // ------------------------------------------------------------------------ + + @BeforeClass + public static void createHDFS() { + try { + TEMP_DIR = new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()); + + Configuration hdConf = new Configuration(); + hdConf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, TEMP_DIR.getAbsolutePath()); + MiniDFSCluster.Builder builder = new MiniDFSCluster.Builder(hdConf); + HDFS_CLUSTER = builder.build(); + + HDFS_ROOT_URI = "hdfs://" + HDFS_CLUSTER.getURI().getHost() + ":" + + HDFS_CLUSTER.getNameNodePort() + "/"; + + FS = FileSystem.get(new URI(HDFS_ROOT_URI)); + } + catch (Exception e) { + e.printStackTrace(); + fail("Could not create HDFS mini cluster " + e.getMessage()); + } + } + + @AfterClass + public static void destroyHDFS() { + try { + HDFS_CLUSTER.shutdown(); + FileUtils.deleteDirectory(TEMP_DIR); + } + catch (Exception ignored) {} + } + + // ------------------------------------------------------------------------ + // Tests + // ------------------------------------------------------------------------ + + @Test + public void testSetupAndSerialization() { + try { + URI baseUri = new URI(HDFS_ROOT_URI + UUID.randomUUID().toString()); + + FsStateBackend originalBackend = new FsStateBackend(baseUri); + + assertFalse(originalBackend.isInitialized()); + assertEquals(baseUri, 
originalBackend.getBasePath().toUri()); + assertNull(originalBackend.getCheckpointDirectory()); + + // serialize / copy the backend + FsStateBackend backend = CommonTestUtils.createCopySerializable(originalBackend); + assertFalse(backend.isInitialized()); + assertEquals(baseUri, backend.getBasePath().toUri()); + assertNull(backend.getCheckpointDirectory()); + + // no file operations should be possible right now + try { + backend.checkpointStateSerializable("exception train rolling in", 2L, System.currentTimeMillis()); + fail("should fail with an exception"); + } catch (IllegalStateException e) { + // supreme! + } + + backend.initializeForJob(new DummyEnvironment("test", 1, 0)); + assertNotNull(backend.getCheckpointDirectory()); + + Path checkpointDir = backend.getCheckpointDirectory(); + assertTrue(FS.exists(checkpointDir)); + assertTrue(isDirectoryEmpty(checkpointDir)); + + backend.disposeAllStateForCurrentJob(); + assertNull(backend.getCheckpointDirectory()); + + assertTrue(isDirectoryEmpty(baseUri)); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testSerializableState() { + try { + FsStateBackend backend = CommonTestUtils.createCopySerializable( + new FsStateBackend(randomHdfsFileUri(), 40)); + backend.initializeForJob(new DummyEnvironment("test", 1, 0)); + + Path checkpointDir = backend.getCheckpointDirectory(); + + String state1 = "dummy state"; + String state2 = "row row row your boat"; + Integer state3 = 42; + + StateHandle<String> handle1 = backend.checkpointStateSerializable(state1, 439568923746L, System.currentTimeMillis()); + StateHandle<String> handle2 = backend.checkpointStateSerializable(state2, 439568923746L, System.currentTimeMillis()); + StateHandle<Integer> handle3 = backend.checkpointStateSerializable(state3, 439568923746L, System.currentTimeMillis()); + + assertEquals(state1, handle1.getState(getClass().getClassLoader())); + handle1.discardState(); + + assertEquals(state2, handle2.getState(getClass().getClassLoader())); + handle2.discardState(); + + assertEquals(state3, handle3.getState(getClass().getClassLoader())); + handle3.discardState(); + + assertTrue(isDirectoryEmpty(checkpointDir)); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testStateOutputStream() { + try { + FsStateBackend backend = CommonTestUtils.createCopySerializable( + new FsStateBackend(randomHdfsFileUri(), 15)); + backend.initializeForJob(new DummyEnvironment("test", 1, 0)); + + Path checkpointDir = backend.getCheckpointDirectory(); + + byte[] state1 = new byte[1274673]; + byte[] state2 = new byte[1]; + byte[] state3 = new byte[0]; + byte[] state4 = new byte[177]; + + Random rnd = new Random(); + rnd.nextBytes(state1); + rnd.nextBytes(state2); + rnd.nextBytes(state3); + rnd.nextBytes(state4); + + long checkpointId = 97231523452L; + + FsStateBackend.FsCheckpointStateOutputStream stream1 = + backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); + FsStateBackend.FsCheckpointStateOutputStream stream2 = + backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); + FsStateBackend.FsCheckpointStateOutputStream stream3 = + backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); + + stream1.write(state1); + stream2.write(state2); + stream3.write(state3); + + FileStreamStateHandle handle1 = (FileStreamStateHandle) stream1.closeAndGetHandle(); + ByteStreamStateHandle handle2 = (ByteStreamStateHandle) 
stream2.closeAndGetHandle(); + ByteStreamStateHandle handle3 = (ByteStreamStateHandle) stream3.closeAndGetHandle(); + + // use with try-with-resources + StreamStateHandle handle4; + try (StateBackend.CheckpointStateOutputStream stream4 = + backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis())) { + stream4.write(state4); + handle4 = stream4.closeAndGetHandle(); + } + + // close before accessing handle + StateBackend.CheckpointStateOutputStream stream5 = + backend.createCheckpointStateOutputStream(checkpointId, System.currentTimeMillis()); + stream5.write(state4); + stream5.close(); + try { + stream5.closeAndGetHandle(); + fail(); + } catch (IOException e) { + // uh-huh + } + + validateBytesInStream(handle1.getState(getClass().getClassLoader()), state1); + handle1.discardState(); + assertFalse(isDirectoryEmpty(checkpointDir)); + ensureFileDeleted(handle1.getFilePath()); + + validateBytesInStream(handle2.getState(getClass().getClassLoader()), state2); + handle2.discardState(); + + validateBytesInStream(handle3.getState(getClass().getClassLoader()), state3); + handle3.discardState(); + + validateBytesInStream(handle4.getState(getClass().getClassLoader()), state4); + handle4.discardState(); + assertTrue(isDirectoryEmpty(checkpointDir)); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // ------------------------------------------------------------------------ + // Utilities + // ------------------------------------------------------------------------ + + private static void ensureFileDeleted(Path path) { + try { + assertFalse(FS.exists(path)); + } + catch (IOException ignored) {} + } + + private static boolean isDirectoryEmpty(URI directory) { + return isDirectoryEmpty(new Path(directory)); + } + + private static boolean isDirectoryEmpty(Path directory) { + try { + FileStatus[] nested = FS.listStatus(directory); + return nested == null || nested.length == 0; + } + catch (IOException e) { + return true; + } + } + + private static URI randomHdfsFileUri() { + String uriString = HDFS_ROOT_URI + UUID.randomUUID().toString(); + try { + return new URI(uriString); + } + catch (URISyntaxException e) { + throw new RuntimeException("Invalid test directory URI: " + uriString, e); + } + } + + private static void validateBytesInStream(InputStream is, byte[] data) throws IOException { + byte[] holder = new byte[data.length]; + + int pos = 0; + int read; + while (pos < holder.length && (read = is.read(holder, pos, holder.length - pos)) != -1) { + pos += read; + } + + assertEquals("not enough data", holder.length, pos); + assertEquals("too much data", -1, is.read()); + assertArrayEquals("wrong data", data, holder); + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/HDFSTest.java ---------------------------------------------------------------------- diff --git a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/HDFSTest.java b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/HDFSTest.java new file mode 100644 index 0000000..bc800a5 --- /dev/null +++ b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/HDFSTest.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.hdfstests; + +import org.apache.commons.io.IOUtils; +import org.apache.flink.api.common.io.FileOutputFormat; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.ExecutionEnvironmentFactory; +import org.apache.flink.api.java.LocalEnvironment; +import org.apache.flink.api.java.io.AvroOutputFormat; +import org.apache.flink.core.fs.FileSystem; +import org.apache.flink.core.fs.Path; +import org.apache.flink.examples.java.wordcount.WordCount; +import org.apache.flink.runtime.fs.hdfs.HadoopFileSystem; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileUtil; +import org.apache.hadoop.hdfs.MiniDFSCluster; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.io.StringWriter; + +/** + * This test should logically be located in the 'flink-runtime' tests. However, this project + * has already all dependencies required (flink-java-examples). Also, the ParallelismOneExecEnv is here. + */ +public class HDFSTest { + + protected String hdfsURI; + private MiniDFSCluster hdfsCluster; + private org.apache.hadoop.fs.Path hdPath; + protected org.apache.hadoop.fs.FileSystem hdfs; + + @Before + public void createHDFS() { + try { + Configuration hdConf = new Configuration(); + + File baseDir = new File("./target/hdfs/hdfsTest").getAbsoluteFile(); + FileUtil.fullyDelete(baseDir); + hdConf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath()); + MiniDFSCluster.Builder builder = new MiniDFSCluster.Builder(hdConf); + hdfsCluster = builder.build(); + + hdfsURI = "hdfs://" + hdfsCluster.getURI().getHost() + ":" + hdfsCluster.getNameNodePort() +"/"; + + hdPath = new org.apache.hadoop.fs.Path("/test"); + hdfs = hdPath.getFileSystem(hdConf); + FSDataOutputStream stream = hdfs.create(hdPath); + for(int i = 0; i < 10; i++) { + stream.write("Hello HDFS\n".getBytes()); + } + stream.close(); + + } catch(Throwable e) { + e.printStackTrace(); + Assert.fail("Test failed " + e.getMessage()); + } + } + + @After + public void destroyHDFS() { + try { + hdfs.delete(hdPath, false); + hdfsCluster.shutdown(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + } + + @Test + public void testHDFS() { + + Path file = new Path(hdfsURI + hdPath); + org.apache.hadoop.fs.Path result = new org.apache.hadoop.fs.Path(hdfsURI + "/result"); + try { + FileSystem fs = file.getFileSystem(); + Assert.assertTrue("Must be HadoopFileSystem", fs instanceof HadoopFileSystem); + + DopOneTestEnvironment.setAsContext(); + try { + WordCount.main(new String[]{file.toString(), result.toString()}); + } + catch(Throwable t) { + t.printStackTrace(); + Assert.fail("Test failed with " + t.getMessage()); + } + finally { + DopOneTestEnvironment.unsetAsContext(); + } + + 
Assert.assertTrue("No result file present", hdfs.exists(result)); + + // validate output: + org.apache.hadoop.fs.FSDataInputStream inStream = hdfs.open(result); + StringWriter writer = new StringWriter(); + IOUtils.copy(inStream, writer); + String resultString = writer.toString(); + + Assert.assertEquals("hdfs 10\n" + + "hello 10\n", resultString); + inStream.close(); + + } catch (IOException e) { + e.printStackTrace(); + Assert.fail("Error in test: " + e.getMessage() ); + } + } + + @Test + public void testAvroOut() { + String type = "one"; + AvroOutputFormat<String> avroOut = + new AvroOutputFormat<String>( String.class ); + + org.apache.hadoop.fs.Path result = new org.apache.hadoop.fs.Path(hdfsURI + "/avroTest"); + + avroOut.setOutputFilePath(new Path(result.toString())); + avroOut.setWriteMode(FileSystem.WriteMode.NO_OVERWRITE); + avroOut.setOutputDirectoryMode(FileOutputFormat.OutputDirectoryMode.ALWAYS); + + try { + avroOut.open(0, 2); + avroOut.writeRecord(type); + avroOut.close(); + + avroOut.open(1, 2); + avroOut.writeRecord(type); + avroOut.close(); + + + Assert.assertTrue("No result file present", hdfs.exists(result)); + FileStatus[] files = hdfs.listStatus(result); + Assert.assertEquals(2, files.length); + for(FileStatus file : files) { + Assert.assertTrue("1.avro".equals(file.getPath().getName()) || "2.avro".equals(file.getPath().getName())); + } + + } catch (IOException e) { + e.printStackTrace(); + Assert.fail(e.getMessage()); + } + } + + // package visible + static abstract class DopOneTestEnvironment extends ExecutionEnvironment { + + public static void setAsContext() { + final LocalEnvironment le = new LocalEnvironment(); + le.setParallelism(1); + + initializeContextEnvironment(new ExecutionEnvironmentFactory() { + + @Override + public ExecutionEnvironment createExecutionEnvironment() { + return le; + } + }); + } + + public static void unsetAsContext() { + resetContextEnvironment(); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-fs-tests/src/test/resources/log4j-test.properties ---------------------------------------------------------------------- diff --git a/flink-fs-tests/src/test/resources/log4j-test.properties b/flink-fs-tests/src/test/resources/log4j-test.properties new file mode 100644 index 0000000..f533ba2 --- /dev/null +++ b/flink-fs-tests/src/test/resources/log4j-test.properties @@ -0,0 +1,31 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Tachyon's test-jar dependency adds a log4j.properties file to classpath. 
+# Until the issue is resolved (see https://github.com/amplab/tachyon/pull/571) +# we provide a log4j.properties file ourselves. + +log4j.rootLogger=OFF, testlogger + +log4j.appender.testlogger=org.apache.log4j.ConsoleAppender +log4j.appender.testlogger.target = System.err +log4j.appender.testlogger.layout=org.apache.log4j.PatternLayout +log4j.appender.testlogger.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n + +# suppress the irrelevant (wrong) warnings from the netty channel handler +log4j.logger.org.jboss.netty.channel.DefaultChannelPipeline=ERROR, testlogger \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-fs-tests/src/test/resources/log4j.properties ---------------------------------------------------------------------- diff --git a/flink-fs-tests/src/test/resources/log4j.properties b/flink-fs-tests/src/test/resources/log4j.properties new file mode 100644 index 0000000..f533ba2 --- /dev/null +++ b/flink-fs-tests/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +# Tachyon's test-jar dependency adds a log4j.properties file to classpath. +# Until the issue is resolved (see https://github.com/amplab/tachyon/pull/571) +# we provide a log4j.properties file ourselves. + +log4j.rootLogger=OFF, testlogger + +log4j.appender.testlogger=org.apache.log4j.ConsoleAppender +log4j.appender.testlogger.target = System.err +log4j.appender.testlogger.layout=org.apache.log4j.PatternLayout +log4j.appender.testlogger.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n + +# suppress the irrelevant (wrong) warnings from the netty channel handler +log4j.logger.org.jboss.netty.channel.DefaultChannelPipeline=ERROR, testlogger \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/README.md ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/README.md b/flink-libraries/flink-ml/README.md new file mode 100644 index 0000000..5cabd7c --- /dev/null +++ b/flink-libraries/flink-ml/README.md @@ -0,0 +1,22 @@ +Flink-ML constitutes the machine learning library of Apache Flink. +Our vision is to make machine learning easily accessible to a wide audience and yet to achieve extraordinary performance. +For this purpose, Flink-ML is based on two pillars: + +Flink-ML contains implementations of popular ML algorithms which are highly optimized for Apache Flink. +Theses implementations allow to scale to data sizes which vastly exceed the memory of a single computer. 
+Flink-ML currently comprises the following algorithms: + +* Classification +** Soft-margin SVM +* Regression +** Multiple linear regression +* Recommendation +** Alternating least squares (ALS) + +Since most of the work in data analytics is related to post- and pre-processing of data where the performance is not crucial, Flink wants to offer a simple abstraction to do that. +Linear algebra, as common ground of many ML algorithms, represents such a high-level abstraction. +Therefore, Flink will support the Mahout DSL as a execution engine and provide tools to neatly integrate the optimized algorithms into a linear algebra program. + +Flink-ML has just been recently started. +As part of Apache Flink, it heavily relies on the active work and contributions of its community and others. +Thus, if you want to add a new algorithm to the library, then find out [how to contribute]((http://flink.apache.org/how-to-contribute.html)) and open a pull request! \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/pom.xml ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/pom.xml b/flink-libraries/flink-ml/pom.xml new file mode 100644 index 0000000..2dd0516 --- /dev/null +++ b/flink-libraries/flink-ml/pom.xml @@ -0,0 +1,162 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + ~ Licensed to the Apache Software Foundation (ASF) under one + ~ or more contributor license agreements. See the NOTICE file + ~ distributed with this work for additional information + ~ regarding copyright ownership. The ASF licenses this file + ~ to you under the Apache License, Version 2.0 (the + ~ "License"); you may not use this file except in compliance + ~ with the License. You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. 
+ --> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> + + <modelVersion>4.0.0</modelVersion> + + <parent> + <groupId>org.apache.flink</groupId> + <artifactId>flink-libraries</artifactId> + <version>1.0-SNAPSHOT</version> + <relativePath>..</relativePath> + </parent> + + <artifactId>flink-ml</artifactId> + <name>flink-ml</name> + + <packaging>jar</packaging> + + <dependencies> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-scala</artifactId> + <version>${project.version}</version> + </dependency> + + <dependency> + <groupId>org.scalanlp</groupId> + <artifactId>breeze_${scala.binary.version}</artifactId> + <version>0.11.2</version> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-clients</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-clients</artifactId> + <version>${project.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-core</artifactId> + <version>${project.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-test-utils</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + </dependencies> + + <build> + <plugins> + <plugin> + <groupId>org.scala-tools</groupId> + <artifactId>maven-scala-plugin</artifactId> + <version>2.15.2</version> + <executions> + <execution> + <goals> + <goal>compile</goal> + <goal>testCompile</goal> + </goals> + </execution> + </executions> + <configuration> + <sourceDir>src/main/scala</sourceDir> + <testSourceDir>src/test/scala</testSourceDir> + <jvmArgs> + <jvmArg>-Xms64m</jvmArg> + <jvmArg>-Xmx1024m</jvmArg> + </jvmArgs> + </configuration> + </plugin> + + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <version>1.0</version> + <configuration> + <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory> + <stdout>W</stdout> <!-- Skip coloring output --> + </configuration> + <executions> + <execution> + <id>scala-test</id> + <goals> + <goal>test</goal> + </goals> + <configuration> + <suffixes>(?<!(IT|Integration))(Test|Suite|Case)</suffixes> + <argLine>-Xms256m -Xmx800m -Dlog4j.configuration=${log4j.configuration} -Dlog.dir=${log.dir} -Dmvn.forkNumber=1 -XX:-UseGCOverheadLimit</argLine> + </configuration> + </execution> + <execution> + <id>integration-test</id> + <phase>integration-test</phase> + <goals> + <goal>test</goal> + </goals> + <configuration> + <suffixes>(IT|Integration)(Test|Suite|Case)</suffixes> + <argLine>-Xms256m -Xmx800m -Dlog4j.configuration=${log4j.configuration} -Dlog.dir=${log.dir} -Dmvn.forkNumber=1 -XX:-UseGCOverheadLimit</argLine> + </configuration> + </execution> + </executions> + </plugin> + + <plugin> + <groupId>org.scalastyle</groupId> + <artifactId>scalastyle-maven-plugin</artifactId> + <version>0.5.0</version> + <executions> + <execution> + <goals> + <goal>check</goal> + </goals> + </execution> + </executions> + <configuration> + <verbose>false</verbose> + <failOnViolation>true</failOnViolation> + <includeTestSourceDirectory>true</includeTestSourceDirectory> + 
<failOnWarning>false</failOnWarning> + <sourceDirectory>${basedir}/src/main/scala</sourceDirectory> + <testSourceDirectory>${basedir}/src/test/scala</testSourceDirectory> + <configLocation>${project.basedir}/../../tools/maven/scalastyle-config.xml</configLocation> + <outputFile>${project.basedir}/scalastyle-output.xml</outputFile> + <outputEncoding>UTF-8</outputEncoding> + </configuration> + </plugin> + </plugins> + </build> +</project> http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala new file mode 100644 index 0000000..804ab5f --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml + +import org.apache.flink.api.common.functions.RichMapFunction +import org.apache.flink.api.java.operators.DataSink +import org.apache.flink.api.scala._ +import org.apache.flink.configuration.Configuration +import org.apache.flink.ml.common.LabeledVector +import org.apache.flink.ml.math.SparseVector + +/** Convenience functions for machine learning tasks + * + * This object contains convenience functions for machine learning tasks: + * + * - readLibSVM: + * Reads a libSVM/SVMLight input file and returns a data set of [[LabeledVector]]. + * The file format is specified [http://svmlight.joachims.org/ here]. + * + * - writeLibSVM: + * Writes a data set of [[LabeledVector]] in libSVM/SVMLight format to disk. THe file format + * is specified [http://svmlight.joachims.org/ here]. + */ +object MLUtils { + + val DIMENSION = "dimension" + + /** Reads a file in libSVM/SVMLight format and converts the data into a data set of + * [[LabeledVector]]. The dimension of the [[LabeledVector]] is determined automatically. + * + * Since the libSVM/SVMLight format stores a vector in its sparse form, the [[LabeledVector]] + * will also be instantiated with a [[SparseVector]]. 
+ * + * @param env executionEnvironment [[ExecutionEnvironment]] + * @param filePath Path to the input file + * @return [[DataSet]] of [[LabeledVector]] containing the information of the libSVM/SVMLight + * file + */ + def readLibSVM(env: ExecutionEnvironment, filePath: String): DataSet[LabeledVector] = { + val labelCOODS = env.readTextFile(filePath).flatMap { + line => + // remove all comments which start with a '#' + val commentFreeLine = line.takeWhile(_ != '#').trim + + if(commentFreeLine.nonEmpty) { + val splits = commentFreeLine.split(' ') + val label = splits.head.toDouble + val sparseFeatures = splits.tail + val coos = sparseFeatures.map { + str => + val pair = str.split(':') + require(pair.length == 2, "Each feature entry has to have the form <feature>:<value>") + + // libSVM index is 1-based, but we expect it to be 0-based + val index = pair(0).toInt - 1 + val value = pair(1).toDouble + + (index, value) + } + + Some((label, coos)) + } else { + None + } + } + + // Calculate maximum dimension of vectors + val dimensionDS = labelCOODS.map { + labelCOO => + labelCOO._2.map( _._1 + 1 ).max + }.reduce(scala.math.max(_, _)) + + labelCOODS.map{ new RichMapFunction[(Double, Array[(Int, Double)]), LabeledVector] { + var dimension = 0 + + override def open(configuration: Configuration): Unit = { + dimension = getRuntimeContext.getBroadcastVariable(DIMENSION).get(0) + } + + override def map(value: (Double, Array[(Int, Double)])): LabeledVector = { + new LabeledVector(value._1, SparseVector.fromCOO(dimension, value._2)) + } + }}.withBroadcastSet(dimensionDS, DIMENSION) + } + + /** Writes a [[DataSet]] of [[LabeledVector]] to a file using the libSVM/SVMLight format. + * + * @param filePath Path to output file + * @param labeledVectors [[DataSet]] of [[LabeledVector]] to write to disk + * @return + */ + def writeLibSVM(filePath: String, labeledVectors: DataSet[LabeledVector]): DataSink[String] = { + val stringRepresentation = labeledVectors.map{ + labeledVector => + val vectorStr = labeledVector.vector. + // remove zero entries + filter( _._2 != 0). + map{case (idx, value) => (idx + 1) + ":" + value}. + mkString(" ") + + labeledVector.label + " " + vectorStr + } + + stringRepresentation.writeAsText(filePath) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala new file mode 100644 index 0000000..4a780e9 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala @@ -0,0 +1,550 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification + +import org.apache.flink.api.common.functions.RichMapFunction +import org.apache.flink.api.scala._ +import org.apache.flink.configuration.Configuration +import org.apache.flink.ml.common.FlinkMLTools.ModuloKeyPartitioner +import org.apache.flink.ml.common._ +import org.apache.flink.ml.math.Breeze._ +import org.apache.flink.ml.math.{DenseVector, Vector} +import org.apache.flink.ml.pipeline.{FitOperation, PredictOperation, Predictor} + +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import breeze.linalg.{DenseVector => BreezeDenseVector, Vector => BreezeVector} + +/** Implements a soft-margin SVM using the communication-efficient distributed dual coordinate + * ascent algorithm (CoCoA) with hinge-loss function. + * + * It can be used for binary classification problems, with the labels set as +1.0 to indiciate a + * positive example and -1.0 to indicate a negative example. + * + * The algorithm solves the following minimization problem: + * + * `min_{w in bbb"R"^d} lambda/2 ||w||^2 + 1/n sum_(i=1)^n l_{i}(w^Tx_i)` + * + * with `w` being the weight vector, `lambda` being the regularization constant, + * `x_{i} in bbb"R"^d` being the data points and `l_{i}` being the convex loss functions, which + * can also depend on the labels `y_{i} in bbb"R"`. + * In the current implementation the regularizer is the 2-norm and the loss functions are the + * hinge-loss functions: + * + * `l_{i} = max(0, 1 - y_{i} * w^Tx_i` + * + * With these choices, the problem definition is equivalent to a SVM with soft-margin. + * Thus, the algorithm allows us to train a SVM with soft-margin. + * + * The minimization problem is solved by applying stochastic dual coordinate ascent (SDCA). + * In order to make the algorithm efficient in a distributed setting, the CoCoA algorithm + * calculates several iterations of SDCA locally on a data block before merging the local + * updates into a valid global state. + * This state is redistributed to the different data partitions where the next round of local + * SDCA iterations is then executed. + * The number of outer iterations and local SDCA iterations control the overall network costs, + * because there is only network communication required for each outer iteration. + * The local SDCA iterations are embarrassingly parallel once the individual data partitions have + * been distributed across the cluster. + * + * Further details of the algorithm can be found [[http://arxiv.org/abs/1409.1458 here]]. + * + * @example + * {{{ + * val trainingDS: DataSet[LabeledVector] = env.readLibSVM(pathToTrainingFile) + * + * val svm = SVM() + * .setBlocks(10) + * + * svm.fit(trainingDS) + * + * val testingDS: DataSet[Vector] = env.readLibSVM(pathToTestingFile) + * .map(lv => lv.vector) + * + * val predictionDS: DataSet[(Vector, Double)] = svm.predict(testingDS) + * }}} + * + * =Parameters= + * + * - [[org.apache.flink.ml.classification.SVM.Blocks]]: + * Sets the number of blocks into which the input data will be split. On each block the local + * stochastic dual coordinate ascent method is executed. This number should be set at least to + * the degree of parallelism. If no value is specified, then the parallelism of the input + * [[DataSet]] is used as the number of blocks. 
(Default value: '''None''') + * + * - [[org.apache.flink.ml.classification.SVM.Iterations]]: + * Defines the maximum number of iterations of the outer loop method. In other words, it defines + * how often the SDCA method is applied to the blocked data. After each iteration, the locally + * computed weight vector updates have to be reduced to update the global weight vector value. + * The new weight vector is broadcast to all SDCA tasks at the beginning of each iteration. + * (Default value: '''10''') + * + * - [[org.apache.flink.ml.classification.SVM.LocalIterations]]: + * Defines the maximum number of SDCA iterations. In other words, it defines how many data points + * are drawn from each local data block to calculate the stochastic dual coordinate ascent. + * (Default value: '''10''') + * + * - [[org.apache.flink.ml.classification.SVM.Regularization]]: + * Defines the regularization constant of the SVM algorithm. The higher the value, the smaller + * will the 2-norm of the weight vector be. In case of a SVM with hinge loss this means that the + * SVM margin will be wider even though it might contain some false classifications. + * (Default value: '''1.0''') + * + * - [[org.apache.flink.ml.classification.SVM.Stepsize]]: + * Defines the initial step size for the updates of the weight vector. The larger the step size + * is, the larger will be the contribution of the weight vector updates to the next weight vector + * value. The effective scaling of the updates is `stepsize/blocks`. This value has to be tuned + * in case that the algorithm becomes instable. (Default value: '''1.0''') + * + * - [[org.apache.flink.ml.classification.SVM.Seed]]: + * Defines the seed to initialize the random number generator. The seed directly controls which + * data points are chosen for the SDCA method. (Default value: '''0''') + * + * - [[org.apache.flink.ml.classification.SVM.ThresholdValue]]: + * Defines the limiting value for the decision function above which examples are labeled as + * positive (+1.0). Examples with a decision function value below this value are classified as + * negative(-1.0). In order to get the raw decision function values you need to indicate it by + * using the [[org.apache.flink.ml.classification.SVM.OutputDecisionFunction]]. + * (Default value: '''0.0''') + * + * - [[org.apache.flink.ml.classification.SVM.OutputDecisionFunction]]: + * Determines whether the predict and evaluate functions of the SVM should return the distance + * to the separating hyperplane, or binary class labels. Setting this to true will return the raw + * distance to the hyperplane for each example. 
Setting it to false will return the binary + * class label (+1.0, -1.0) (Default value: '''false''') + */ +class SVM extends Predictor[SVM] { + + import SVM._ + + /** Stores the learned weight vector after the fit operation */ + var weightsOption: Option[DataSet[DenseVector]] = None + + /** Sets the number of data blocks/partitions + * + * @param blocks + * @return itself + */ + def setBlocks(blocks: Int): SVM = { + parameters.add(Blocks, blocks) + this + } + + /** Sets the number of outer iterations + * + * @param iterations + * @return itself + */ + def setIterations(iterations: Int): SVM = { + parameters.add(Iterations, iterations) + this + } + + /** Sets the number of local SDCA iterations + * + * @param localIterations + * @return itselft + */ + def setLocalIterations(localIterations: Int): SVM = { + parameters.add(LocalIterations, localIterations) + this + } + + /** Sets the regularization constant + * + * @param regularization + * @return itself + */ + def setRegularization(regularization: Double): SVM = { + parameters.add(Regularization, regularization) + this + } + + /** Sets the stepsize for the weight vector updates + * + * @param stepsize + * @return itself + */ + def setStepsize(stepsize: Double): SVM = { + parameters.add(Stepsize, stepsize) + this + } + + /** Sets the seed value for the random number generator + * + * @param seed + * @return itself + */ + def setSeed(seed: Long): SVM = { + parameters.add(Seed, seed) + this + } + + /** Sets the threshold above which elements are classified as positive. + * + * The [[predict ]] and [[evaluate]] functions will return +1.0 for items with a decision + * function value above this threshold, and -1.0 for items below it. + * @param threshold + * @return + */ + def setThreshold(threshold: Double): SVM = { + parameters.add(ThresholdValue, threshold) + this + } + + /** Sets whether the predictions should return the raw decision function value or the + * thresholded binary value. + * + * When setting this to true, predict and evaluate return the raw decision value, which is + * the distance from the separating hyperplane. + * When setting this to false, they return thresholded (+1.0, -1.0) values. + * + * @param outputDecisionFunction When set to true, [[predict ]] and [[evaluate]] return the raw + * decision function values. When set to false, they return the + * thresholded binary values (+1.0, -1.0). + */ + def setOutputDecisionFunction(outputDecisionFunction: Boolean): SVM = { + parameters.add(OutputDecisionFunction, outputDecisionFunction) + this + } +} + +/** Companion object of SVM. Contains convenience functions and the parameter type definitions + * of the algorithm. 
+ */ +object SVM{ + + val WEIGHT_VECTOR ="weightVector" + + // ========================================== Parameters ========================================= + + case object Blocks extends Parameter[Int] { + val defaultValue: Option[Int] = None + } + + case object Iterations extends Parameter[Int] { + val defaultValue = Some(10) + } + + case object LocalIterations extends Parameter[Int] { + val defaultValue = Some(10) + } + + case object Regularization extends Parameter[Double] { + val defaultValue = Some(1.0) + } + + case object Stepsize extends Parameter[Double] { + val defaultValue = Some(1.0) + } + + case object Seed extends Parameter[Long] { + val defaultValue = Some(Random.nextLong()) + } + + case object ThresholdValue extends Parameter[Double] { + val defaultValue = Some(0.0) + } + + case object OutputDecisionFunction extends Parameter[Boolean] { + val defaultValue = Some(false) + } + + // ========================================== Factory methods ==================================== + + def apply(): SVM = { + new SVM() + } + + // ========================================== Operations ========================================= + + /** Provides the operation that makes the predictions for individual examples. + * + * @tparam T + * @return A PredictOperation, through which it is possible to predict a value, given a + * feature vector + */ + implicit def predictVectors[T <: Vector] = { + new PredictOperation[SVM, DenseVector, T, Double](){ + + var thresholdValue: Double = _ + var outputDecisionFunction: Boolean = _ + + override def getModel(self: SVM, predictParameters: ParameterMap): DataSet[DenseVector] = { + thresholdValue = predictParameters(ThresholdValue) + outputDecisionFunction = predictParameters(OutputDecisionFunction) + self.weightsOption match { + case Some(model) => model + case None => { + throw new RuntimeException("The SVM model has not been trained. Call first fit" + + "before calling the predict operation.") + } + } + } + + override def predict(value: T, model: DenseVector): Double = { + val rawValue = value.asBreeze dot model.asBreeze + + if (outputDecisionFunction) { + rawValue + } else { + if (rawValue > thresholdValue) 1.0 else -1.0 + } + } + } + } + + /** [[FitOperation]] which trains a SVM with soft-margin based on the given training data set. 
+ * + */ + implicit val fitSVM = { + new FitOperation[SVM, LabeledVector] { + override def fit( + instance: SVM, + fitParameters: ParameterMap, + input: DataSet[LabeledVector]) + : Unit = { + val resultingParameters = instance.parameters ++ fitParameters + + // Check if the number of blocks/partitions has been specified + val blocks = resultingParameters.get(Blocks) match { + case Some(value) => value + case None => input.getParallelism + } + + val scaling = resultingParameters(Stepsize)/blocks + val iterations = resultingParameters(Iterations) + val localIterations = resultingParameters(LocalIterations) + val regularization = resultingParameters(Regularization) + val seed = resultingParameters(Seed) + + // Obtain DataSet with the dimension of the data points + val dimension = input.map{_.vector.size}.reduce{ + (a, b) => { + require(a == b, "Dimensions of feature vectors have to be equal.") + a + } + } + + val initialWeights = createInitialWeights(dimension) + + // Count the number of vectors, but keep the value in a DataSet to broadcast it later + // TODO: Once efficient count and intermediate result partitions are implemented, use count + val numberVectors = input map { x => 1 } reduce { _ + _ } + + // Group the input data into blocks in round robin fashion + val blockedInputNumberElements = FlinkMLTools.block( + input, + blocks, + Some(ModuloKeyPartitioner)). + cross(numberVectors). + map { x => x } + + val resultingWeights = initialWeights.iterate(iterations) { + weights => { + // compute the local SDCA to obtain the weight vector updates + val deltaWs = localDualMethod( + weights, + blockedInputNumberElements, + localIterations, + regularization, + scaling, + seed + ) + + // scale the weight vectors + val weightedDeltaWs = deltaWs map { + deltaW => { + deltaW :*= scaling + } + } + + // calculate the new weight vector by adding the weight vector updates to the weight + // vector value + weights.union(weightedDeltaWs).reduce { _ + _ } + } + } + + // Store the learned weight vector in hte given instance + instance.weightsOption = Some(resultingWeights.map(_.fromBreeze[DenseVector])) + } + } + } + + /** Creates a zero vector of length dimension + * + * @param dimension [[DataSet]] containing the dimension of the initial weight vector + * @return Zero vector of length dimension + */ + private def createInitialWeights(dimension: DataSet[Int]): DataSet[BreezeDenseVector[Double]] = { + dimension.map { + d => BreezeDenseVector.zeros[Double](d) + } + } + + /** Computes the local SDCA on the individual data blocks/partitions + * + * @param w Current weight vector + * @param blockedInputNumberElements Blocked/Partitioned input data + * @param localIterations Number of local SDCA iterations + * @param regularization Regularization constant + * @param scaling Scaling value for new weight vector updates + * @param seed Random number generator seed + * @return [[DataSet]] of weight vector updates. The weight vector updates are double arrays + */ + private def localDualMethod( + w: DataSet[BreezeDenseVector[Double]], + blockedInputNumberElements: DataSet[(Block[LabeledVector], Int)], + localIterations: Int, + regularization: Double, + scaling: Double, + seed: Long) + : DataSet[BreezeDenseVector[Double]] = { + /* + Rich mapper calculating for each data block the local SDCA. We use a RichMapFunction here, + because we broadcast the current value of the weight vector to all mappers. 
+ */ + val localSDCA = new RichMapFunction[(Block[LabeledVector], Int), BreezeDenseVector[Double]] { + var originalW: BreezeDenseVector[Double] = _ + // we keep the alphas across the outer loop iterations + val alphasArray = ArrayBuffer[BreezeDenseVector[Double]]() + // there might be several data blocks in one Flink partition, therefore store mapping + val idMapping = scala.collection.mutable.HashMap[Int, Int]() + var counter = 0 + + var r: Random = _ + + override def open(parameters: Configuration): Unit = { + originalW = getRuntimeContext.getBroadcastVariable(WEIGHT_VECTOR).get(0) + + if(r == null){ + r = new Random(seed ^ getRuntimeContext.getIndexOfThisSubtask) + } + } + + override def map(blockNumberElements: (Block[LabeledVector], Int)) + : BreezeDenseVector[Double] = { + val (block, numberElements) = blockNumberElements + + // check if we already processed a data block with the corresponding block index + val localIndex = idMapping.get(block.index) match { + case Some(idx) => idx + case None => + idMapping += (block.index -> counter) + counter += 1 + + alphasArray += BreezeDenseVector.zeros[Double](block.values.length) + + counter - 1 + } + + // create temporary alpha array for the local SDCA iterations + val tempAlphas = alphasArray(localIndex).copy + + val numLocalDatapoints = tempAlphas.length + val deltaAlphas = BreezeDenseVector.zeros[Double](numLocalDatapoints) + + val w = originalW.copy + + val deltaW = BreezeDenseVector.zeros[Double](originalW.length) + + for(i <- 1 to localIterations) { + // pick random data point for SDCA + val idx = r.nextInt(numLocalDatapoints) + + val LabeledVector(label, vector) = block.values(idx) + val alpha = tempAlphas(idx) + + // maximize the dual problem and retrieve alpha and weight vector updates + val (deltaAlpha, deltaWUpdate) = maximize( + vector.asBreeze, + label, + regularization, + alpha, + w, + numberElements) + + // update alpha values + tempAlphas(idx) += deltaAlpha + deltaAlphas(idx) += deltaAlpha + + // deltaWUpdate is already scaled with 1/lambda/n + w += deltaWUpdate + deltaW += deltaWUpdate + } + + // update local alpha values + alphasArray(localIndex) += deltaAlphas * scaling + + deltaW + } + } + + blockedInputNumberElements.map(localSDCA).withBroadcastSet(w, WEIGHT_VECTOR) + } + + /** Maximizes the dual problem using hinge loss functions. It returns the alpha and weight + * vector updates. 
+ * + * @param x Selected data point + * @param y Label of selected data point + * @param regularization Regularization constant + * @param alpha Alpha value of selected data point + * @param w Current weight vector value + * @param numberElements Number of elements in the training data set + * @return Alpha and weight vector updates + */ + private def maximize( + x: BreezeVector[Double], + y: Double, regularization: Double, + alpha: Double, + w: BreezeVector[Double], + numberElements: Int) + : (Double, BreezeVector[Double]) = { + // compute hinge loss gradient + val dotProduct = x dot w + val grad = (y * dotProduct - 1.0) * (regularization * numberElements) + + // compute projected gradient + var proj_grad = if(alpha <= 0.0){ + scala.math.min(grad, 0) + } else if(alpha >= 1.0) { + scala.math.max(grad, 0) + } else { + grad + } + + if(scala.math.abs(grad) != 0.0){ + val qii = x dot x + val newAlpha = if(qii != 0.0){ + scala.math.min(scala.math.max(alpha - (grad / qii), 0.0), 1.0) + } else { + 1.0 + } + + val deltaW = x * y * (newAlpha - alpha) / (regularization * numberElements) + + (newAlpha - alpha, deltaW) + } else { + (0.0 , BreezeVector.zeros(w.length)) + } + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala new file mode 100644 index 0000000..1af77ea --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common + +/** Base class for blocks of elements. + * + * TODO: Replace Vector type by Array type once Flink supports generic arrays + * + * @param index + * @param values + * @tparam T + */ +case class Block[T](index: Int, values: Vector[T]) {} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala new file mode 100644 index 0000000..553ec00 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common + +import org.apache.flink.api.common.functions.Partitioner +import org.apache.flink.api.common.io.FileOutputFormat.OutputDirectoryMode +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.io.{TypeSerializerInputFormat, TypeSerializerOutputFormat} +import org.apache.flink.api.scala._ +import org.apache.flink.core.fs.FileSystem.WriteMode +import org.apache.flink.core.fs.Path + +import scala.reflect.ClassTag + +/** FlinkTools contains a set of convenience functions for Flink's machine learning library: + * + * - persist: + * Takes up to 5 [[DataSet]]s and file paths. Each [[DataSet]] is written to the specified + * path and subsequently re-read from disk. This method can be used to effectively split the + * execution graph at the given [[DataSet]]. Writing it to disk triggers its materialization + * and specifying it as a source will prevent the re-execution of it. + * + * - block: + * Takes a DataSet of elements T and groups them in n blocks. + * + */ +object FlinkMLTools { + + /** Registers the different FlinkML related types for Kryo serialization + * + * @param env + */ + def registerFlinkMLTypes(env: ExecutionEnvironment): Unit = { + + // Vector types + env.registerType(classOf[org.apache.flink.ml.math.DenseVector]) + env.registerType(classOf[org.apache.flink.ml.math.SparseVector]) + + // Matrix types + env.registerType(classOf[org.apache.flink.ml.math.DenseMatrix]) + env.registerType(classOf[org.apache.flink.ml.math.SparseMatrix]) + + // Breeze Vector types + env.registerType(classOf[breeze.linalg.DenseVector[_]]) + env.registerType(classOf[breeze.linalg.SparseVector[_]]) + + // Breeze specialized types + env.registerType(breeze.linalg.DenseVector.zeros[Double](0).getClass) + env.registerType(breeze.linalg.SparseVector.zeros[Double](0).getClass) + + // Breeze Matrix types + env.registerType(classOf[breeze.linalg.DenseMatrix[Double]]) + env.registerType(classOf[breeze.linalg.CSCMatrix[Double]]) + + // Breeze specialized types + env.registerType(breeze.linalg.DenseMatrix.zeros[Double](0, 0).getClass) + env.registerType(breeze.linalg.CSCMatrix.zeros[Double](0, 0).getClass) + } + + /** Writes a [[DataSet]] to the specified path and returns it as a DataSource for subsequent + * operations. 
+ * + * @param dataset [[DataSet]] to write to disk + * @param path File path to write dataset to + * @tparam T Type of the [[DataSet]] elements + * @return [[DataSet]] reading the just written file + */ + def persist[T: ClassTag: TypeInformation](dataset: DataSet[T], path: String): DataSet[T] = { + val env = dataset.getExecutionEnvironment + val outputFormat = new TypeSerializerOutputFormat[T] + + val filePath = new Path(path) + + outputFormat.setOutputFilePath(filePath) + outputFormat.setWriteMode(WriteMode.OVERWRITE) + + dataset.output(outputFormat) + env.execute("FlinkTools persist") + + val inputFormat = new TypeSerializerInputFormat[T](dataset.getType) + inputFormat.setFilePath(filePath) + + env.createInput(inputFormat) + } + + /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for + * subsequent operations. + * + * @param ds1 First [[DataSet]] to write to disk + * @param ds2 Second [[DataSet]] to write to disk + * @param path1 Path for ds1 + * @param path2 Path for ds2 + * @tparam A Type of the first [[DataSet]]'s elements + * @tparam B Type of the second [[DataSet]]'s elements + * @return Tuple of [[DataSet]]s reading the just written files + */ + def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation](ds1: DataSet[A], ds2: + DataSet[B], path1: String, path2: String):(DataSet[A], DataSet[B]) = { + val env = ds1.getExecutionEnvironment + + val f1 = new Path(path1) + + val of1 = new TypeSerializerOutputFormat[A] + of1.setOutputFilePath(f1) + of1.setWriteMode(WriteMode.OVERWRITE) + + ds1.output(of1) + + val f2 = new Path(path2) + + val of2 = new TypeSerializerOutputFormat[B] + of2.setOutputFilePath(f2) + of2.setWriteMode(WriteMode.OVERWRITE) + + ds2.output(of2) + + env.execute("FlinkTools persist") + + val if1 = new TypeSerializerInputFormat[A](ds1.getType) + if1.setFilePath(f1) + + val if2 = new TypeSerializerInputFormat[B](ds2.getType) + if2.setFilePath(f2) + + (env.createInput(if1), env.createInput(if2)) + } + + /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for + * subsequent operations. 
+ * + * @param ds1 First [[DataSet]] to write to disk + * @param ds2 Second [[DataSet]] to write to disk + * @param ds3 Third [[DataSet]] to write to disk + * @param path1 Path for ds1 + * @param path2 Path for ds2 + * @param path3 Path for ds3 + * @tparam A Type of first [[DataSet]]'s elements + * @tparam B Type of second [[DataSet]]'s elements + * @tparam C Type of third [[DataSet]]'s elements + * @return Tuple of [[DataSet]]s reading the just written files + */ + def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation, + C: ClassTag: TypeInformation](ds1: DataSet[A], ds2: DataSet[B], ds3: DataSet[C], path1: + String, path2: String, path3: String): (DataSet[A], DataSet[B], DataSet[C]) = { + val env = ds1.getExecutionEnvironment + + val f1 = new Path(path1) + + val of1 = new TypeSerializerOutputFormat[A] + of1.setOutputFilePath(f1) + of1.setWriteMode(WriteMode.OVERWRITE) + + ds1.output(of1) + + val f2 = new Path(path2) + + val of2 = new TypeSerializerOutputFormat[B] + of2.setOutputFilePath(f2) + of2.setWriteMode(WriteMode.OVERWRITE) + + ds2.output(of2) + + val f3 = new Path(path3) + + val of3 = new TypeSerializerOutputFormat[C] + of3.setOutputFilePath(f3) + of3.setWriteMode(WriteMode.OVERWRITE) + + ds3.output(of3) + + env.execute("FlinkTools persist") + + val if1 = new TypeSerializerInputFormat[A](ds1.getType) + if1.setFilePath(f1) + + val if2 = new TypeSerializerInputFormat[B](ds2.getType) + if2.setFilePath(f2) + + val if3 = new TypeSerializerInputFormat[C](ds3.getType) + if3.setFilePath(f3) + + (env.createInput(if1), env.createInput(if2), env.createInput(if3)) + } + + /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for + * subsequent operations. + * + * @param ds1 First [[DataSet]] to write to disk + * @param ds2 Second [[DataSet]] to write to disk + * @param ds3 Third [[DataSet]] to write to disk + * @param ds4 Fourth [[DataSet]] to write to disk + * @param path1 Path for ds1 + * @param path2 Path for ds2 + * @param path3 Path for ds3 + * @param path4 Path for ds4 + * @tparam A Type of first [[DataSet]]'s elements + * @tparam B Type of second [[DataSet]]'s elements + * @tparam C Type of third [[DataSet]]'s elements + * @tparam D Type of fourth [[DataSet]]'s elements + * @return Tuple of [[DataSet]]s reading the just written files + */ + def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation, + C: ClassTag: TypeInformation, D: ClassTag: TypeInformation](ds1: DataSet[A], ds2: DataSet[B], + ds3: DataSet[C], ds4: DataSet[D], + path1: String, path2: String, path3: + String, path4: String): + (DataSet[A], DataSet[B], DataSet[C], DataSet[D]) = { + val env = ds1.getExecutionEnvironment + + val f1 = new Path(path1) + + val of1 = new TypeSerializerOutputFormat[A] + of1.setOutputFilePath(f1) + of1.setWriteMode(WriteMode.OVERWRITE) + + ds1.output(of1) + + val f2 = new Path(path2) + + val of2 = new TypeSerializerOutputFormat[B] + of2.setOutputFilePath(f2) + of2.setWriteMode(WriteMode.OVERWRITE) + + ds2.output(of2) + + val f3 = new Path(path3) + + val of3 = new TypeSerializerOutputFormat[C] + of3.setOutputFilePath(f3) + of3.setWriteMode(WriteMode.OVERWRITE) + + ds3.output(of3) + + val f4 = new Path(path4) + + val of4 = new TypeSerializerOutputFormat[D] + of4.setOutputFilePath(f4) + of4.setWriteMode(WriteMode.OVERWRITE) + + ds4.output(of4) + + env.execute("FlinkTools persist") + + val if1 = new TypeSerializerInputFormat[A](ds1.getType) + if1.setFilePath(f1) + + val if2 = new TypeSerializerInputFormat[B](ds2.getType) + 
if2.setFilePath(f2) + + val if3 = new TypeSerializerInputFormat[C](ds3.getType) + if3.setFilePath(f3) + + val if4 = new TypeSerializerInputFormat[D](ds4.getType) + if4.setFilePath(f4) + + (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4)) + } + + /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for + * subsequent operations. + * + * @param ds1 First [[DataSet]] to write to disk + * @param ds2 Second [[DataSet]] to write to disk + * @param ds3 Third [[DataSet]] to write to disk + * @param ds4 Fourth [[DataSet]] to write to disk + * @param ds5 Fifth [[DataSet]] to write to disk + * @param path1 Path for ds1 + * @param path2 Path for ds2 + * @param path3 Path for ds3 + * @param path4 Path for ds4 + * @param path5 Path for ds5 + * @tparam A Type of first [[DataSet]]'s elements + * @tparam B Type of second [[DataSet]]'s elements + * @tparam C Type of third [[DataSet]]'s elements + * @tparam D Type of fourth [[DataSet]]'s elements + * @tparam E Type of fifth [[DataSet]]'s elements + * @return Tuple of [[DataSet]]s reading the just written files + */ + def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation, + C: ClassTag: TypeInformation, D: ClassTag: TypeInformation, E: ClassTag: TypeInformation] + (ds1: DataSet[A], ds2: DataSet[B], ds3: DataSet[C], ds4: DataSet[D], ds5: DataSet[E], path1: + String, path2: String, path3: String, path4: String, path5: String): (DataSet[A], DataSet[B], + DataSet[C], DataSet[D], DataSet[E]) = { + val env = ds1.getExecutionEnvironment + + val f1 = new Path(path1) + + val of1 = new TypeSerializerOutputFormat[A] + of1.setOutputFilePath(f1) + of1.setWriteMode(WriteMode.OVERWRITE) + + ds1.output(of1) + + val f2 = new Path(path2) + + val of2 = new TypeSerializerOutputFormat[B] + of2.setOutputFilePath(f2) + of2.setOutputDirectoryMode(OutputDirectoryMode.ALWAYS) + of2.setWriteMode(WriteMode.OVERWRITE) + + ds2.output(of2) + + val f3 = new Path(path3) + + val of3 = new TypeSerializerOutputFormat[C] + of3.setOutputFilePath(f3) + of3.setWriteMode(WriteMode.OVERWRITE) + + ds3.output(of3) + + val f4 = new Path(path4) + + val of4 = new TypeSerializerOutputFormat[D] + of4.setOutputFilePath(f4) + of4.setWriteMode(WriteMode.OVERWRITE) + + ds4.output(of4) + + val f5 = new Path(path5) + + val of5 = new TypeSerializerOutputFormat[E] + of5.setOutputFilePath(f5) + of5.setWriteMode(WriteMode.OVERWRITE) + + ds5.output(of5) + + env.execute("FlinkTools persist") + + val if1 = new TypeSerializerInputFormat[A](ds1.getType) + if1.setFilePath(f1) + + val if2 = new TypeSerializerInputFormat[B](ds2.getType) + if2.setFilePath(f2) + + val if3 = new TypeSerializerInputFormat[C](ds3.getType) + if3.setFilePath(f3) + + val if4 = new TypeSerializerInputFormat[D](ds4.getType) + if4.setFilePath(f4) + + val if5 = new TypeSerializerInputFormat[E](ds5.getType) + if5.setFilePath(f5) + + (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4), env + .createInput(if5)) + } + + /** Groups the DataSet input into numBlocks blocks. 
+ * + * @param input + * @param numBlocks Number of Blocks + * @param partitionerOption Optional partitioner to control the partitioning + * @tparam T + * @return + */ + def block[T: TypeInformation: ClassTag]( + input: DataSet[T], + numBlocks: Int, + partitionerOption: Option[Partitioner[Int]] = None) + : DataSet[Block[T]] = { + val blockIDInput = input map { + element => + val blockID = element.hashCode() % numBlocks + + val blockIDResult = if(blockID < 0){ + blockID + numBlocks + } else { + blockID + } + + (blockIDResult, element) + } + + val preGroupBlockIDInput = partitionerOption match { + case Some(partitioner) => + blockIDInput partitionCustom(partitioner, 0) + + case None => blockIDInput + } + + preGroupBlockIDInput.groupBy(0).reduceGroup { + iter => { + val array = iter.toVector + + val blockID = array(0)._1 + val elements = array.map(_._2) + + Block[T](blockID, elements) + } + }.withForwardedFields("0 -> index") + } + + /** Distributes the elements by taking the modulo of their keys and assigning it to this channel + * + */ + object ModuloKeyPartitioner extends Partitioner[Int] { + override def partition(key: Int, numPartitions: Int): Int = { + val result = key % numPartitions + + if(result < 0) { + result + numPartitions + } else { + result + } + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala new file mode 100644 index 0000000..3b948c0 --- /dev/null +++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common + +import org.apache.flink.ml.math.Vector + +/** This class represents a vector with an associated label as it is required for many supervised + * learning tasks. + * + * @param label Label of the data point + * @param vector Data point + */ +case class LabeledVector(label: Double, vector: Vector) extends Serializable { + + override def equals(obj: Any): Boolean = { + obj match { + case labeledVector: LabeledVector => + vector.equals(labeledVector.vector) && label.equals(labeledVector.label) + case _ => false + } + } + + override def toString: String = { + s"LabeledVector($label, $vector)" + } +}
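To summarize the per-coordinate dual update implemented by maximize in SVM.scala above: with lambda = regularization and n = numberElements, the code computes grad = (y * (x dot w) - 1) * (lambda * n) and qii = x dot x; when grad is non-zero it sets newAlpha = min(max(alpha - grad / qii, 0), 1), falling back to 1.0 when qii is zero, and returns deltaAlpha = newAlpha - alpha together with deltaW = deltaAlpha * y * x / (lambda * n). Because deltaW is already divided by lambda * n, localDualMethod can add it to the local weight vector w directly, as the comment in its inner loop notes.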
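For orientation, here is a minimal driver sketch for the classifier defined in SVM.scala above. It assumes the setter methods and the fit/predict entry points that the SVM class and flink-ml's pipeline traits declare outside of this hunk; the import path, toy data, and parameter values are illustrative only.

import org.apache.flink.api.scala._
import org.apache.flink.ml.classification.SVM   // package path assumed
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.math.DenseVector

val env = ExecutionEnvironment.getExecutionEnvironment

// Tiny, linearly separable toy set; the SVM expects labels in {-1.0, +1.0}.
val training = env.fromElements(
  LabeledVector(1.0, DenseVector(1.0, 2.0)),
  LabeledVector(-1.0, DenseVector(-1.0, -2.0)))

// The setters are assumed to mirror the parameter case objects defined in the companion object.
val svm = SVM()
  .setBlocks(env.getParallelism)
  .setIterations(100)
  .setLocalIterations(10)
  .setRegularization(0.01)
  .setStepsize(1.0)
  .setSeed(42)

svm.fit(training)

// Predictions are thresholded to {-1.0, +1.0} unless OutputDecisionFunction is set to true.
val test = env.fromElements(DenseVector(0.5, 1.0), DenseVector(-0.5, -1.0))
svm.predict(test).print()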
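The helpers in FlinkMLTools.scala above can also be exercised on their own. A small sketch, assuming only the APIs visible in this hunk plus the standard Scala DataSet API; the sample data and the /tmp path are placeholders.

import org.apache.flink.api.scala._
import org.apache.flink.ml.common.{Block, FlinkMLTools}

val env = ExecutionEnvironment.getExecutionEnvironment

// Register the flink-ml vector and matrix types with Kryo up front.
FlinkMLTools.registerFlinkMLTypes(env)

val data: DataSet[Int] = env.fromCollection(1 to 100)

// Group the elements into 4 blocks; ModuloKeyPartitioner sends block i to partition i % numPartitions.
val blocked: DataSet[Block[Int]] =
  FlinkMLTools.block(data, 4, Some(FlinkMLTools.ModuloKeyPartitioner))

// Write the blocked data to disk and read it back, which splits the execution graph at this point.
// Note that persist calls env.execute internally, so the write happens eagerly.
val persisted: DataSet[Block[Int]] = FlinkMLTools.persist(blocked, "/tmp/blocked-data")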