This is an automated email from the ASF dual-hosted git repository.

tzulitai pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.11 by this push:
     new f33c30f  [FLINK-19300] Fix input stream read to prevent heap based timer loss
f33c30f is described below

commit f33c30feed2cdd36c04b373ef36f6c11a5f5d504
Author: acesine <xianggaod...@gmail.com>
AuthorDate: Mon Oct 5 18:02:24 2020 -0700

    [FLINK-19300] Fix input stream read to prevent heap based timer loss
    
    This closes #14042.
---
 .../core/io/PostVersionedIOReadableWritable.java   |  13 ++-
 .../main/java/org/apache/flink/util/IOUtils.java   |  26 +++++
 .../io/PostVersionedIOReadableWritableTest.java    | 128 +++++++++++++++------
 .../java/org/apache/flink/util/IOUtilsTest.java    |  52 +++++++++
 4 files changed, 183 insertions(+), 36 deletions(-)

diff --git a/flink-core/src/main/java/org/apache/flink/core/io/PostVersionedIOReadableWritable.java b/flink-core/src/main/java/org/apache/flink/core/io/PostVersionedIOReadableWritable.java
index 6edc983..e396384 100644
--- a/flink-core/src/main/java/org/apache/flink/core/io/PostVersionedIOReadableWritable.java
+++ b/flink-core/src/main/java/org/apache/flink/core/io/PostVersionedIOReadableWritable.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.util.IOUtils;
 
 import java.io.IOException;
 import java.io.InputStream;
@@ -62,7 +63,7 @@ public abstract class PostVersionedIOReadableWritable extends VersionedIOReadabl
         */
        public final void read(InputStream inputStream) throws IOException {
                byte[] tmp = new byte[VERSIONED_IDENTIFIER.length];
-               inputStream.read(tmp);
+               int totalRead = IOUtils.tryReadFully(inputStream, tmp);
 
                if (Arrays.equals(tmp, VERSIONED_IDENTIFIER)) {
                        DataInputView inputView = new 
DataInputViewStreamWrapper(inputStream);
@@ -70,10 +71,14 @@ public abstract class PostVersionedIOReadableWritable extends VersionedIOReadabl
                        super.read(inputView);
                        read(inputView, true);
                } else {
-                       PushbackInputStream resetStream = new 
PushbackInputStream(inputStream, VERSIONED_IDENTIFIER.length);
-                       resetStream.unread(tmp);
+                       InputStream streamToRead = inputStream;
+                       if (totalRead > 0) {
+                               PushbackInputStream resetStream = new 
PushbackInputStream(inputStream, totalRead);
+                               resetStream.unread(tmp, 0, totalRead);
+                               streamToRead = resetStream;
+                       }
 
-                       read(new DataInputViewStreamWrapper(resetStream), 
false);
+                       read(new DataInputViewStreamWrapper(streamToRead), 
false);
                }
        }
 
diff --git a/flink-core/src/main/java/org/apache/flink/util/IOUtils.java b/flink-core/src/main/java/org/apache/flink/util/IOUtils.java
index 0b8f210..8f204a9 100644
--- a/flink-core/src/main/java/org/apache/flink/util/IOUtils.java
+++ b/flink-core/src/main/java/org/apache/flink/util/IOUtils.java
@@ -142,6 +142,32 @@ public final class IOUtils {
        }
 
        /**
+        * Similar to {@link #readFully(InputStream, byte[], int, int)}. 
Returns the total number of
+        * bytes read into the buffer.
+        *
+        * @param in
+        *        The InputStream to read from
+        * @param buf
+        *        The buffer to fill
+        * @return
+        *        The total number of bytes read into the buffer
+        * @throws IOException
+        *         If the first byte cannot be read for any reason other than 
end of file,
+        *         or if the input stream has been closed, or if some other I/O 
error occurs.
+        */
+       public static int tryReadFully(final InputStream in, final byte[] buf) 
throws IOException {
+               int totalRead = 0;
+               while (totalRead != buf.length) {
+                       int read = in.read(buf, totalRead, buf.length - 
totalRead);
+                       if (read == -1) {
+                               break;
+                       }
+                       totalRead += read;
+               }
+               return totalRead;
+       }
+
+       /**
         * Similar to readFully(). Skips bytes in a loop.
         *
         * @param in
diff --git a/flink-core/src/test/java/org/apache/flink/core/io/PostVersionedIOReadableWritableTest.java b/flink-core/src/test/java/org/apache/flink/core/io/PostVersionedIOReadableWritableTest.java
index 2954536..3f41ad1 100644
--- a/flink-core/src/test/java/org/apache/flink/core/io/PostVersionedIOReadableWritableTest.java
+++ b/flink-core/src/test/java/org/apache/flink/core/io/PostVersionedIOReadableWritableTest.java
@@ -26,6 +26,7 @@ import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.io.EOFException;
 import java.io.IOException;
 
 /**
@@ -35,9 +36,51 @@ public class PostVersionedIOReadableWritableTest {
 
        @Test
        public void testReadVersioned() throws IOException {
+               byte[] payload = "test-data".getBytes();
+               byte[] serialized = 
serializeWithPostVersionedReadableWritable(payload);
+               byte[] restored = 
restoreWithPostVersionedReadableWritable(serialized, payload.length);
 
-               String payload = "test-data";
-               TestPostVersionedReadableWritable versionedReadableWritable = 
new TestPostVersionedReadableWritable(payload);
+               Assert.assertArrayEquals(payload, restored);
+       }
+
+       @Test
+       public void testReadNonVersioned() throws IOException {
+               byte[] preVersionedPayload = new byte[]{0x00, 0x00, 0x02, 0x33};
+               byte[] serialized = 
serializeWithNonVersionedReadableWritable(preVersionedPayload);
+               byte[] restored = 
restoreWithPostVersionedReadableWritable(serialized, 
preVersionedPayload.length);
+
+               Assert.assertArrayEquals(preVersionedPayload, restored);
+       }
+
+       @Test
+       public void testReadNonVersionedWithLongPayload() throws IOException {
+               byte[] preVersionedPayload = "test-data".getBytes();
+               byte[] serialized = 
serializeWithNonVersionedReadableWritable(preVersionedPayload);
+               byte[] restored = 
restoreWithPostVersionedReadableWritable(serialized, 
preVersionedPayload.length);
+
+               Assert.assertArrayEquals(preVersionedPayload, restored);
+       }
+
+       @Test
+       public void testReadNonVersionedWithShortPayload() throws IOException {
+               byte[] preVersionedPayload = new byte[]{-15, -51};
+               byte[] serialized = 
serializeWithNonVersionedReadableWritable(preVersionedPayload);
+               byte[] restored = 
restoreWithPostVersionedReadableWritable(serialized, 
preVersionedPayload.length);
+
+               Assert.assertArrayEquals(preVersionedPayload, restored);
+       }
+
+       @Test
+       public void testReadNonVersionedWithEmptyPayload() throws IOException {
+               byte[] preVersionedPayload = new byte[0];
+               byte[] serialized = 
serializeWithNonVersionedReadableWritable(preVersionedPayload);
+               byte[] restored = 
restoreWithPostVersionedReadableWritable(serialized, 
preVersionedPayload.length);
+
+               Assert.assertArrayEquals(preVersionedPayload, restored);
+       }
+
+       private byte[] serializeWithNonVersionedReadableWritable(byte[] 
payload) throws IOException {
+               TestNonVersionedReadableWritable versionedReadableWritable = 
new TestNonVersionedReadableWritable(payload);
 
                byte[] serialized;
                try (ByteArrayOutputStreamWithPos out = new 
ByteArrayOutputStreamWithPos()) {
@@ -45,42 +88,49 @@ public class PostVersionedIOReadableWritableTest {
                        serialized = out.toByteArray();
                }
 
-               TestPostVersionedReadableWritable 
restoredVersionedReadableWritable = new TestPostVersionedReadableWritable();
-               try(ByteArrayInputStreamWithPos in = new 
ByteArrayInputStreamWithPos(serialized)) {
-                       restoredVersionedReadableWritable.read(in);
-               }
-
-               Assert.assertEquals(payload, 
restoredVersionedReadableWritable.getData());
+               return serialized;
        }
 
-       @Test
-       public void testReadNonVersioned() throws IOException {
-               int preVersionedPayload = 563;
-
-               TestNonVersionedReadableWritable nonVersionedReadableWritable = 
new TestNonVersionedReadableWritable(preVersionedPayload);
+       private byte[] serializeWithPostVersionedReadableWritable(byte[] 
payload) throws IOException {
+               TestPostVersionedReadableWritable versionedReadableWritable = 
new TestPostVersionedReadableWritable(payload);
 
                byte[] serialized;
                try (ByteArrayOutputStreamWithPos out = new 
ByteArrayOutputStreamWithPos()) {
-                       nonVersionedReadableWritable.write(new 
DataOutputViewStreamWrapper(out));
+                       versionedReadableWritable.write(new 
DataOutputViewStreamWrapper(out));
                        serialized = out.toByteArray();
                }
 
-               TestPostVersionedReadableWritable 
restoredVersionedReadableWritable = new TestPostVersionedReadableWritable();
-               try(ByteArrayInputStreamWithPos in = new 
ByteArrayInputStreamWithPos(serialized)) {
+               return serialized;
+       }
+
+       private byte[] restoreWithPostVersionedReadableWritable(byte[] 
serialized, int expectedLength) throws IOException {
+               TestPostVersionedReadableWritable 
restoredVersionedReadableWritable = new 
TestPostVersionedReadableWritable(expectedLength);
+
+               try(ByteArrayInputStreamWithPos in = new 
TestByteArrayInputStreamProducingOneByteAtATime(serialized)) {
                        restoredVersionedReadableWritable.read(in);
                }
 
-               Assert.assertEquals(String.valueOf(preVersionedPayload), 
restoredVersionedReadableWritable.getData());
+               return restoredVersionedReadableWritable.getData();
+       }
+
+       private static void assertEmpty(DataInputView in) throws IOException {
+               try {
+                       in.readByte();
+                       Assert.fail();
+               } catch (EOFException ignore) {
+               }
        }
 
        static class TestPostVersionedReadableWritable extends 
PostVersionedIOReadableWritable {
 
                private static final int VERSION = 1;
-               private String data;
+               private byte[] data;
 
-               TestPostVersionedReadableWritable() {}
+               TestPostVersionedReadableWritable(int len) {
+                       this.data = new byte[len];
+               }
 
-               TestPostVersionedReadableWritable(String data) {
+               TestPostVersionedReadableWritable(byte[] data) {
                        this.data = data;
                }
 
@@ -92,40 +142,54 @@ public class PostVersionedIOReadableWritableTest {
                @Override
                public void write(DataOutputView out) throws IOException {
                        super.write(out);
-                       out.writeUTF(data);
+                       out.write(data);
                }
 
                @Override
                protected void read(DataInputView in, boolean wasVersioned) 
throws IOException {
-                       if (wasVersioned) {
-                               this.data = in.readUTF();
-                       } else {
-                               // in the previous non-versioned format, we 
wrote integers instead
-                               this.data = String.valueOf(in.readInt());
-                       }
+                       in.readFully(data);
+                       assertEmpty(in);
                }
 
-               public String getData() {
+               public byte[] getData() {
                        return data;
                }
        }
 
        static class TestNonVersionedReadableWritable implements 
IOReadableWritable {
 
-               private int data;
+               private byte[] data;
 
-               TestNonVersionedReadableWritable(int data) {
+               TestNonVersionedReadableWritable(byte[] data) {
                        this.data = data;
                }
 
                @Override
                public void write(DataOutputView out) throws IOException {
-                       out.writeInt(data);
+                       out.write(data);
                }
 
                @Override
                public void read(DataInputView in) throws IOException {
-                       this.data = in.readInt();
+                       in.readFully(data);
+                       assertEmpty(in);
+               }
+       }
+
+       static class TestByteArrayInputStreamProducingOneByteAtATime extends 
ByteArrayInputStreamWithPos {
+
+               public TestByteArrayInputStreamProducingOneByteAtATime(byte[] 
buf) {
+                       super(buf);
+               }
+
+               @Override
+               public int read(byte[] b, int off, int len) {
+                       return super.read(b, off, Math.min(len, 1));
+               }
+
+               @Override
+               public int read(byte[] b) throws IOException {
+                       return read(b, 0, b.length);
                }
        }
 
diff --git a/flink-core/src/test/java/org/apache/flink/util/IOUtilsTest.java b/flink-core/src/test/java/org/apache/flink/util/IOUtilsTest.java
new file mode 100644
index 0000000..9a9cf28
--- /dev/null
+++ b/flink-core/src/test/java/org/apache/flink/util/IOUtilsTest.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.util;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * Tests for the {@link IOUtils}.
+ */
+public class IOUtilsTest extends TestLogger {
+
+       @Test
+       public void testTryReadFullyFromLongerStream() throws IOException {
+               ByteArrayInputStream inputStream = new 
ByteArrayInputStream("test-data".getBytes());
+
+               byte[] out = new byte[4];
+               int read = IOUtils.tryReadFully(inputStream, out);
+
+               Assert.assertArrayEquals("test".getBytes(), 
Arrays.copyOfRange(out, 0, read));
+       }
+
+       @Test
+       public void testTryReadFullyFromShorterStream() throws IOException {
+               ByteArrayInputStream inputStream = new 
ByteArrayInputStream("t".getBytes());
+
+               byte[] out = new byte[4];
+               int read = IOUtils.tryReadFully(inputStream, out);
+
+               Assert.assertArrayEquals("t".getBytes(), 
Arrays.copyOfRange(out, 0, read));
+       }
+}

Reply via email to