Hi Raffaello,

There is way more code changed here than is needed to fix the bug.
General enhancements should be separated from bug fixes.
That makes it easier to review that the bug was fixed,
and easier to review the other code separately to check that there are no unexpected changes.

If some of the changes are motivated by expected performance improvements,
there should be JMH tests comparing the before and after.
The change to use byte arrays seems useful, but even using char[]
there is little danger of cache thrashing.

If the code using xxxExact was correct, don't change it.
Those methods are intrinsified and perform as well or better than using long.
Usually, it is better to leave code alone and not risk breaking it.
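For reference, here is a standalone sketch of how the pre-patch encodedOutLength() detects overflow with the xxxExact methods. It is simplified, not the JDK source: the class name ExactLength and the parameters doPadding, linemax and newlineLength are hypothetical stand-ins for the encoder's fields, and the throwOOME branch is reduced to returning -1.

```java
class ExactLength {
    // Same structure as the pre-patch code: any int overflow surfaces as
    // ArithmeticException from addExact/multiplyExact and is mapped to -1.
    static int encodedOutLength(int srclen, boolean doPadding,
                                int linemax, int newlineLength) {
        try {
            int len;
            if (doPadding) {
                len = Math.multiplyExact(4, Math.addExact(srclen, 2) / 3);
            } else {
                int n = srclen % 3;
                len = Math.addExact(Math.multiplyExact(4, srclen / 3),
                                    n == 0 ? 0 : n + 1);
            }
            if (linemax > 0) { // room for line separators
                len = Math.addExact(len, (len - 1) / linemax * newlineLength);
            }
            return len;
        } catch (ArithmeticException ex) {
            return -1; // encoded length does not fit in an int
        }
    }

    public static void main(String[] args) {
        System.out.println(encodedOutLength(8, true, -1, 0));                 // 12
        System.out.println(encodedOutLength(8, false, -1, 0));                // 11
        System.out.println(encodedOutLength(Integer.MAX_VALUE, true, -1, 0)); // -1
    }
}
```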

Special care needs to be taken when changing a method that is intrinsified.
The optimized version may use fields of the object and stop
working if they are changed.

In the test, the range of buffer sizes tested seems to waste a lot
of cycles on sizes larger than the encoded size of the input.
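A sketch of a tighter test loop, following the suggestion that sizes beyond the encoded length add little coverage: "12345678" encodes to 12 bytes, so only buffer sizes 1 through 12 are tried. The class and method names are hypothetical, and the check assumes a JDK that includes the JDK-8222187 fix.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;

class BoundedBufferSizes {
    // Decodes `encoded` through Base64.getDecoder().wrap(), reading with a
    // buffer of the given size, and returns all bytes that were read.
    static byte[] decodeWithBuffer(byte[] encoded, int bufferSize) {
        try (InputStream in = Base64.getDecoder().wrap(new ByteArrayInputStream(encoded))) {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buffer = new byte[bufferSize];
            int read;
            while ((read = in.read(buffer, 0, bufferSize)) >= 0) {
                out.write(buffer, 0, read);
            }
            return out.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Tries buffer sizes 1..encoded.length and returns the first size that
    // decodes incorrectly, or -1 if every size round-trips.
    static int firstFailingSize(byte[] orig) {
        byte[] encoded = Base64.getEncoder().encode(orig);
        for (int size = 1; size <= encoded.length; size++) {
            if (!Arrays.equals(decodeWithBuffer(encoded, size), orig)) {
                return size;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        byte[] orig = "12345678".getBytes(StandardCharsets.US_ASCII);
        // Expected to print -1 on a JDK that includes the JDK-8222187 fix.
        System.out.println(firstFailingSize(orig));
    }
}
```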

Regards, Roger


On 6/29/20 1:51 PM, Raffaello Giulietti wrote:
Hello,

here's a fix and an additional test, both in inline form and as an attachment.

The fix also contains a reimplementation of encodedOutLength() that makes use of long arithmetic rather than relying on addExact(), multiplyExact(), etc.
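The long-arithmetic version rests on a small identity. Writing srclen = 3q + n with 0 <= n <= 2, the unpadded length 4q + (n == 0 ? 0 : n + 1) equals (srclen + 2) / 3 + srclen, and computing in long keeps intermediate values clear of int overflow until a single final range check. The sketch below compares the two formulations; the class name and the parameters standing in for the encoder's fields are hypothetical, and the throwOOME branch is reduced to returning -1.

```java
class LongLength {
    // Formulation from the patch: compute in 64 bits, range-check once.
    // Returns -1 when the result does not fit in an int.
    static int encodedOutLength(int srclen, boolean doPadding,
                                int linemax, int newlineLength) {
        long len = doPadding
                ? (srclen + 2L) / 3 * 4
                : (srclen + 2L) / 3 + srclen;
        if (linemax > 0) { // room for line separators
            len += (len - 1) / linemax * newlineLength;
        }
        return len <= Integer.MAX_VALUE ? (int) len : -1;
    }

    // Pre-patch unpadded formula, valid while it does not overflow.
    static int unpaddedReference(int srclen) {
        int n = srclen % 3;
        return 4 * (srclen / 3) + (n == 0 ? 0 : n + 1);
    }

    public static void main(String[] args) {
        for (int srclen = 0; srclen < 1000; srclen++) {
            if (encodedOutLength(srclen, false, -1, 0) != unpaddedReference(srclen)) {
                throw new AssertionError("mismatch at " + srclen);
            }
        }
        System.out.println(encodedOutLength(8, true, 76, 2));                 // 12
        System.out.println(encodedOutLength(Integer.MAX_VALUE, true, -1, 0)); // -1
    }
}
```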

Further, the lookup tables are now declared as byte[], which should help keep them in caches longer because of their smaller size.
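In numbers: each 64-entry encode table shrinks from char[] (128 bytes of payload) to byte[] (64 bytes), and each 256-entry decode table from int[] (1 KiB) to byte[] (256 bytes). The sketch below builds byte[] tables the way the patch's static initializers do and checks that they are inverses; the class name is hypothetical and the alphabet is the standard RFC 4648 one. Reading an entry into an int sign-extends it, so the -1 and -2 markers survive.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

class ByteTables {
    // RFC 4648 "Table 1" alphabet, stored as bytes (all entries are ASCII).
    static final byte[] TO_BASE64 =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
            .getBytes(StandardCharsets.US_ASCII);

    // Reverse table: -1 for bytes outside the alphabet, -2 for '='.
    static final byte[] FROM_BASE64 = new byte[256];

    static {
        Arrays.fill(FROM_BASE64, (byte) -1);
        for (int i = 0; i < TO_BASE64.length; i++) {
            FROM_BASE64[TO_BASE64[i]] = (byte) i; // values 0..63 fit in a byte
        }
        FROM_BASE64['='] = -2;
    }

    public static void main(String[] args) {
        for (int i = 0; i < 64; i++) {
            if (FROM_BASE64[TO_BASE64[i]] != i) throw new AssertionError(i);
        }
        // prints "0 63 -1"
        System.out.println(FROM_BASE64['A'] + " " + FROM_BASE64['/'] + " " + FROM_BASE64['\n']);
    }
}
```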


Greetings
Raffaello

----

# HG changeset patch
# User lello
# Date 1593437938 -7200
#      Mon Jun 29 15:38:58 2020 +0200
# Node ID 58aadb9efae6d5d88479ad8171f2219d41cc761f
# Parent  4a91f6b96a506d9d67437338c33b6953a57bfbd9
8222187: java.util.Base64.Decoder stream adds unexpected null bytes at the end
Reviewed-by: TBD
Contributed-by: Raffaello Giulietti <raffaello.giulie...@gmail.com>

diff --git a/src/java.base/share/classes/java/util/Base64.java b/src/java.base/share/classes/java/util/Base64.java
--- a/src/java.base/share/classes/java/util/Base64.java
+++ b/src/java.base/share/classes/java/util/Base64.java
@@ -30,7 +30,6 @@
 import java.io.IOException;
 import java.io.OutputStream;
 import java.nio.ByteBuffer;
-import java.nio.charset.StandardCharsets;

 import sun.nio.cs.ISO_8859_1;

@@ -133,7 +132,7 @@
      */
     public static Encoder getMimeEncoder(int lineLength, byte[] lineSeparator) {
          Objects.requireNonNull(lineSeparator);
-         int[] base64 = Decoder.fromBase64;
+         byte[] base64 = Decoder.fromBase64;
          for (byte b : lineSeparator) {
              if (base64[b & 0xff] != -1)
                  throw new IllegalArgumentException(
@@ -216,7 +215,7 @@
          * index values into their "Base64 Alphabet" equivalents as specified
          * in "Table 1: The Base64 Alphabet" of RFC 2045 (and RFC 4648).
          */
-        private static final char[] toBase64 = {
+        private static final byte[] toBase64 = {
             'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
             'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
             'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
@@ -229,7 +228,7 @@
          * in Table 2 of the RFC 4648, with the '+' and '/' changed to '-' and
          * '_'. This table is used when BASE64_URL is specified.
          */
-        private static final char[] toBase64URL = {
+        private static final byte[] toBase64URL = {
             'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
             'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
             'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
@@ -238,7 +237,7 @@
         };

         private static final int MIMELINEMAX = 76;
-        private static final byte[] CRLF = new byte[] {'\r', '\n'};
+        private static final byte[] CRLF = {'\r', '\n'};

         static final Encoder RFC4648 = new Encoder(false, null, -1, true);
         static final Encoder RFC4648_URLSAFE = new Encoder(true, null, -1, true);
@@ -255,27 +254,19 @@
          *
          */
         private final int encodedOutLength(int srclen, boolean throwOOME) {
-            int len = 0;
-            try {
-                if (doPadding) {
-                    len = Math.multiplyExact(4, (Math.addExact(srclen, 2) / 3));
-                } else {
-                    int n = srclen % 3;
-                    len = Math.addExact(Math.multiplyExact(4, (srclen / 3)), (n == 0 ? 0 : n + 1));
-                }
-                if (linemax > 0) { // line separators
-                    len = Math.addExact(len, (len - 1) / linemax * newline.length);
-                }
-            } catch (ArithmeticException ex) {
-                if (throwOOME) {
-                    throw new OutOfMemoryError("Encoded size is too large");
-                } else {
-                    // let the caller know that encoded bytes length
-                    // is too large
-                    len = -1;
-                }
+            long len = doPadding
+                    ? (srclen + 2L) / 3 * 4
+                    : (srclen + 2L) / 3 + srclen;
+            if (linemax > 0) {
+                len += (len - 1) / linemax * newline.length;
             }
-            return len;
+            if (len <= Integer.MAX_VALUE) {
+                return (int) len;
+            }
+            if (throwOOME) {
+                throw new OutOfMemoryError("Encoded size is too large");
+            }
+            return -1;
         }

         /**
@@ -421,24 +412,24 @@

         @HotSpotIntrinsicCandidate
         private void encodeBlock(byte[] src, int sp, int sl, byte[] dst, int dp, boolean isURL) {
-            char[] base64 = isURL ? toBase64URL : toBase64;
+            byte[] base64 = isURL ? toBase64URL : toBase64;
             for (int sp0 = sp, dp0 = dp ; sp0 < sl; ) {
                 int bits = (src[sp0++] & 0xff) << 16 |
                            (src[sp0++] & 0xff) <<  8 |
                            (src[sp0++] & 0xff);
-                dst[dp0++] = (byte)base64[(bits >>> 18) & 0x3f];
-                dst[dp0++] = (byte)base64[(bits >>> 12) & 0x3f];
-                dst[dp0++] = (byte)base64[(bits >>> 6) & 0x3f];
-                dst[dp0++] = (byte)base64[bits & 0x3f];
+                dst[dp0++] = base64[(bits >>> 18) & 0x3f];
+                dst[dp0++] = base64[(bits >>> 12) & 0x3f];
+                dst[dp0++] = base64[(bits >>> 6)  & 0x3f];
+                dst[dp0++] = base64[bits & 0x3f];
             }
         }

         private int encode0(byte[] src, int off, int end, byte[] dst) {
-            char[] base64 = isURL ? toBase64URL : toBase64;
+            byte[] base64 = isURL ? toBase64URL : toBase64;
             int sp = off;
             int slen = (end - off) / 3 * 3;
             int sl = off + slen;
-            if (linemax > 0 && slen  > linemax / 4 * 3)
+            if (linemax > 0 && slen > linemax / 4 * 3)
                 slen = linemax / 4 * 3;
             int dp = 0;
             while (sp < sl) {
@@ -455,17 +446,17 @@
             }
             if (sp < end) {               // 1 or 2 leftover bytes
                 int b0 = src[sp++] & 0xff;
-                dst[dp++] = (byte)base64[b0 >> 2];
+                dst[dp++] = base64[b0 >> 2];
                 if (sp == end) {
-                    dst[dp++] = (byte)base64[(b0 << 4) & 0x3f];
+                    dst[dp++] = base64[(b0 << 4) & 0x3f];
                     if (doPadding) {
                         dst[dp++] = '=';
                         dst[dp++] = '=';
                     }
                 } else {
                     int b1 = src[sp++] & 0xff;
-                    dst[dp++] = (byte)base64[(b0 << 4) & 0x3f | (b1 >> 4)];
-                    dst[dp++] = (byte)base64[(b1 << 2) & 0x3f];
+                    dst[dp++] = base64[(b0 << 4) & 0x3f | (b1 >> 4)];
+                    dst[dp++] = base64[(b1 << 2) & 0x3f];
                     if (doPadding) {
                         dst[dp++] = '=';
                     }
@@ -523,11 +514,12 @@
          * the array are encoded to -1.
          *
          */
-        private static final int[] fromBase64 = new int[256];
+        private static final byte[] fromBase64 = new byte[256];
+
         static {
-            Arrays.fill(fromBase64, -1);
+            Arrays.fill(fromBase64, (byte) -1);
             for (int i = 0; i < Encoder.toBase64.length; i++)
-                fromBase64[Encoder.toBase64[i]] = i;
+                fromBase64[Encoder.toBase64[i]] = (byte) i;
             fromBase64['='] = -2;
         }

@@ -535,12 +527,12 @@
          * Lookup table for decoding "URL and Filename safe Base64 Alphabet"
          * as specified in Table2 of the RFC 4648.
          */
-        private static final int[] fromBase64URL = new int[256];
+        private static final byte[] fromBase64URL = new byte[256];

         static {
-            Arrays.fill(fromBase64URL, -1);
+            Arrays.fill(fromBase64URL, (byte) -1);
             for (int i = 0; i < Encoder.toBase64URL.length; i++)
-                fromBase64URL[Encoder.toBase64URL[i]] = i;
+                fromBase64URL[Encoder.toBase64URL[i]] = (byte) i;
             fromBase64URL['='] = -2;
         }

@@ -699,7 +691,7 @@
          *
          */
         private int decodedOutLength(byte[] src, int sp, int sl) {
-            int[] base64 = isURL ? fromBase64URL : fromBase64;
+            byte[] base64 = isURL ? fromBase64URL : fromBase64;
             int paddings = 0;
             int len = sl - sp;
             if (len == 0)
@@ -743,7 +735,7 @@
         }

         private int decode0(byte[] src, int sp, int sl, byte[] dst) {
-            int[] base64 = isURL ? fromBase64URL : fromBase64;
+            byte[] base64 = isURL ? fromBase64URL : fromBase64;
             int dp = 0;
             int bits = 0;
             int shiftto = 18;       // pos of first byte of 4-byte atom
@@ -832,14 +824,14 @@
         private int b0, b1, b2;
         private boolean closed = false;

-        private final char[] base64;    // byte->base64 mapping
+        private final byte[] base64;    // byte->base64 mapping
         private final byte[] newline;   // line separator, if needed
         private final int linemax;
         private final boolean doPadding;// whether or not to pad
         private int linepos = 0;
         private byte[] buf;

-        EncOutputStream(OutputStream os, char[] base64,
+        EncOutputStream(OutputStream os, byte[] base64,
                         byte[] newline, int linemax, boolean doPadding) {
             super(os);
             this.base64 = base64;
@@ -863,11 +855,11 @@
             }
         }

-        private void writeb4(char b1, char b2, char b3, char b4) throws IOException {
-            buf[0] = (byte)b1;
-            buf[1] = (byte)b2;
-            buf[2] = (byte)b3;
-            buf[3] = (byte)b4;
+        private void writeb4(byte b1, byte b2, byte b3, byte b4) throws IOException {
+            buf[0] = b1;
+            buf[1] = b2;
+            buf[2] = b3;
+            buf[3] = b4;
             out.write(buf, 0, 4);
         }

@@ -909,10 +901,10 @@
                     int bits = (b[sp++] & 0xff) << 16 |
                                (b[sp++] & 0xff) <<  8 |
                                (b[sp++] & 0xff);
-                    buf[dp++] = (byte)base64[(bits >>> 18) & 0x3f];
-                    buf[dp++] = (byte)base64[(bits >>> 12) & 0x3f];
-                    buf[dp++] = (byte)base64[(bits >>> 6)  & 0x3f];
-                    buf[dp++] = (byte)base64[bits & 0x3f];
+                    buf[dp++] = base64[(bits >>> 18) & 0x3f];
+                    buf[dp++] = base64[(bits >>> 12) & 0x3f];
+                    buf[dp++] = base64[(bits >>> 6) & 0x3f];
+                    buf[dp++] = base64[bits & 0x3f];
                 }
                 out.write(buf, 0, dp);
                 off = sl;
@@ -960,130 +952,119 @@
     private static class DecInputStream extends InputStream {

         private final InputStream is;
+        private final byte[] base64;    // mapping from alphabet to values
         private final boolean isMIME;
-        private final int[] base64;      // base64 -> byte mapping
-        private int bits = 0;            // 24-bit buffer for decoding
-        private int nextin = 18;         // next available "off" in "bits" for input;
-                                         // -> 18, 12, 6, 0
-        private int nextout = -8;        // next available "off" in "bits" for output;
-                                         // -> 8, 0, -8 (no byte for output)
-        private boolean eof = false;
-        private boolean closed = false;
+        private int bits;               // 24 bit buffer for decoding
+        private int wpos;               // writing bit pos inside bits
+                                        // one of 24 (left, msb), 18, 12, 6, 0
+        private int rpos;               // reading bit pos inside bits
+                                        // one of 24 (left, msb), 16, 8, 0
+        private boolean eos;
+        private boolean closed;
+        private byte[] onebuf = new byte[1];

-        DecInputStream(InputStream is, int[] base64, boolean isMIME) {
+        DecInputStream(InputStream is, byte[] base64, boolean isMIME) {
             this.is = is;
             this.base64 = base64;
             this.isMIME = isMIME;
         }

-        private byte[] sbBuf = new byte[1];
-
         @Override
         public int read() throws IOException {
-            return read(sbBuf, 0, 1) == -1 ? -1 : sbBuf[0] & 0xff;
+            return read(onebuf, 0, 1) >= 0 ? onebuf[0] & 0xff : -1;
+        }
+
+        private int leftovers(byte[] b, int off, int pos, int limit) {
+            eos = true;
+            while (rpos - 8 >= wpos && pos != limit) {
+                b[pos++] = (byte) (bits >> (rpos -= 8));
+            }
+            return pos - off != 0 || rpos - 8 >= wpos ? pos - off : -1;
         }

-        private int eof(byte[] b, int off, int len, int oldOff)
-            throws IOException
-        {
-            eof = true;
-            if (nextin != 18) {
-                if (nextin == 12)
-                    throw new IOException("Base64 stream has one un-decoded dangling byte.");
-                // treat ending xx/xxx without padding character legal.
-                // same logic as v == '=' below
-                b[off++] = (byte)(bits >> (16));
-                if (nextin == 0) {           // only one padding byte
-                    if (len == 1) {          // no enough output space
-                        bits >>= 8;          // shift to lowest byte
-                        nextout = 0;
-                    } else {
-                        b[off++] = (byte) (bits >>  8);
-                    }
-                }
+        private int eos(byte[] b, int off, int pos, int limit) throws IOException {
+            // wpos == 18: x     dangling single x, invalid unit
+            // accept ending xx or xxx without padding characters
+            if (wpos == 18) {
+                throw new IOException("Base64 stream has one un-decoded dangling byte");
             }
-            return off == oldOff ? -1 : off - oldOff;
+            rpos = 24;
+            return leftovers(b, off, pos, limit);
         }

-        private int padding(byte[] b, int off, int len, int oldOff)
-            throws IOException
-        {
-            // =     shiftto==18 unnecessary padding
-            // x=    shiftto==12 dangling x, invalid unit
-            // xx=   shiftto==6 && missing last '='
-            // xx=y  or last is not '='
-            if (nextin == 18 || nextin == 12 ||
-                nextin == 6 && is.read() != '=') {
-                throw new IOException("Illegal base64 ending sequence:" + nextin);
+        private int padding(byte[] b, int off, int pos, int limit) throws IOException {
+            // wpos == 24: =    (unnecessary padding)
+            // wpos == 18: x=   (dangling single x, invalid unit)
+            // wpos == 12 and missing last '=': xx=  (invalid padding)
+            // wpos == 12 and last is not '=': xx=x (invalid padding)
+            if (wpos >= 18 || wpos == 12 && is.read() != '=') {
+                throw new IOException("Illegal base64 ending sequence");
             }
-            b[off++] = (byte)(bits >> (16));
-            if (nextin == 0) {           // only one padding byte
-                if (len == 1) {          // no enough output space
-                    bits >>= 8;          // shift to lowest byte
-                    nextout = 0;
-                } else {
-                    b[off++] = (byte) (bits >>  8);
-                }
-            }
-            eof = true;
-            return off - oldOff;
+            rpos = 24;
+            return leftovers(b, off, pos, limit);
         }

         @Override
         public int read(byte[] b, int off, int len) throws IOException {
-            if (closed)
+            if (closed) {
                 throw new IOException("Stream is closed");
-            if (eof && nextout < 0)    // eof and no leftover
-                return -1;
-            if (off < 0 || len < 0 || len > b.length - off)
-                throw new IndexOutOfBoundsException();
-            int oldOff = off;
-            while (nextout >= 0) {       // leftover output byte(s) in bits buf
-                if (len == 0)
-                    return off - oldOff;
-                b[off++] = (byte)(bits >> nextout);
-                len--;
-                nextout -= 8;
+            }
+            Objects.checkFromIndexSize(off, len, b.length);
+
+            // limit can overflow to Integer.MIN_VALUE. However, as long
+            // as comparisons with pos are done as coded, there's no harm.
+            int pos = off;
+            int limit = off + len;
+            if (eos) {
+                return leftovers(b, off, pos, limit);
+            }
+
+            // leftovers from previous invocation; here, wpos == 0
+            while (rpos - 8 >= 0 && pos != limit) {
+                b[pos++] = (byte) (bits >> (rpos -= 8));
             }
+            if (pos == limit) {
+                return limit - off;
+            }
+
             bits = 0;
-            while (len > 0) {
-                int v = is.read();
-                if (v == -1) {
-                    return eof(b, off, len, oldOff);
+            wpos = 24;
+            while (pos != limit) {
+                int i = is.read();
+                if (i < 0) {
+                    return eos(b, off, pos, limit);
                 }
-                if ((v = base64[v]) < 0) {
-                    if (v == -2) {       // padding byte(s)
-                        return padding(b, off, len, oldOff);
-                    }
+                int v = base64[i];
+                if (v < 0) {
+                    // i not in alphabet
+                    // v = -2 (i is '=') or v = -1 (i is other, e.g., CR or LF)
                     if (v == -1) {
-                        if (!isMIME)
-                            throw new IOException("Illegal base64 character " +
-                                Integer.toString(v, 16));
-                        continue;        // skip if for rfc2045
+                        if (isMIME) {
+                            continue;
+                        }
+                        throw new IOException("Illegal base64 byte 0x" +
+                                Integer.toHexString(i));
                     }
-                    // neve be here
+                    return padding(b, off, pos, limit);
                 }
-                bits |= (v << nextin);
-                if (nextin == 0) {
-                    nextin = 18;         // clear for next in
-                    b[off++] = (byte)(bits >> 16);
-                    if (len == 1) {
-                        nextout = 8;    // 2 bytes left in bits
-                        break;
+                bits |= (v << (wpos -= 6));
+                if (wpos == 0) {
+                    if (limit - pos >= 3) {
+                        // frequently taken fast path, no need to track rpos
+                        b[pos++] = (byte) (bits >> 16);
+                        b[pos++] = (byte) (bits >> 8);
+                        b[pos++] = (byte) bits;
+                        bits = 0;
+                        wpos = 24;
+                    } else {
+                        rpos = 24;
+                        while (pos != limit) {
+                            b[pos++] = (byte) (bits >> (rpos -= 8));
+                        }
                     }
-                    b[off++] = (byte)(bits >> 8);
-                    if (len == 2) {
-                        nextout = 0;    // 1 byte left in bits
-                        break;
-                    }
-                    b[off++] = (byte)bits;
-                    len -= 3;
-                    bits = 0;
-                } else {
-                    nextin -= 6;
                 }
             }
-            return off - oldOff;
+            return limit - off;
         }

         @Override
diff --git a/test/jdk/java/util/Base64/TestBase64.java b/test/jdk/java/util/Base64/TestBase64.java
--- a/test/jdk/java/util/Base64/TestBase64.java
+++ b/test/jdk/java/util/Base64/TestBase64.java
@@ -144,6 +144,10 @@
         testDecoderKeepsAbstinence(Base64.getDecoder());
         testDecoderKeepsAbstinence(Base64.getUrlDecoder());
         testDecoderKeepsAbstinence(Base64.getMimeDecoder());
+
+        // tests patch addressing JDK-8222187
+        // https://bugs.openjdk.java.net/browse/JDK-8222187
+        testJDK_8222187();
     }

     private static void test(Base64.Encoder enc, Base64.Decoder dec,
@@ -607,4 +611,27 @@
             }
         }
     }
+
+    private static void testJDK_8222187() throws Throwable {
+        byte[] orig = "12345678".getBytes("US-ASCII");
+        byte[] encoded = Base64.getEncoder().encode(orig);
+        // decode using different buffer sizes
+        for (int bufferSize = 1; bufferSize <= 100; bufferSize++) {
+            try (
+                    InputStream in = Base64.getDecoder().wrap(
+                            new ByteArrayInputStream(encoded));
+                    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+            ) {
+                byte[] buffer = new byte[bufferSize];
+                int read;
+                while ((read = in.read(buffer, 0, bufferSize)) >= 0) {
+                    baos.write(buffer, 0, read);
+                }
+                // compare result, output info if lengths do not match
+                byte[] decoded = baos.toByteArray();
+                checkEqual(decoded, orig, "Base64 stream decoding failed!");
+            }
+        }
+
+    }
 }



On 2020-06-09 09:20, Raffaello Giulietti wrote:
Hi Lance,

before working on a fix, I just wanted to make sure that I'm not interfering with existing efforts. Thus, I don't have a fix yet.

I'll be using the example provided in the bug report as a basic test.

I'll show up here once the fix is ready.


Greetings
Raffaello



On 2020-06-08 22:34, Lance Andersen wrote:
Hi Raffaello,

If you are interested in providing a fix, you are more than welcome to do so.  Please include a test which validates your fix.

Best
lance

On Jun 8, 2020, at 3:45 PM, Raffaello Giulietti <raffaello.giulie...@gmail.com> wrote:

Raffaello

Lance Andersen | Principal Member of Technical Staff | +1.781.442.2037
Oracle Java Engineering
1 Network Drive
Burlington, MA 01803
lance.ander...@oracle.com
