Hello list,

Please find attached a patch that fixes JDK-8228580.
It works the same way the UDP timeout does: the timeout is calculated from the retry attempt (the initial timeout doubled for each retry), and the duration of every blocking call on the TCP socket is deducted from it.

I am listed as an Author[1] on the "jdk" project.

[1] https://openjdk.java.net/census#mmimica

--
Milan Mimica
diff -r 56df9a08ed9c src/jdk.naming.dns/share/classes/com/sun/jndi/dns/DnsClient.java
--- a/src/jdk.naming.dns/share/classes/com/sun/jndi/dns/DnsClient.java	Mon Aug 19 14:28:43 2019 +0100
+++ b/src/jdk.naming.dns/share/classes/com/sun/jndi/dns/DnsClient.java	Mon Aug 19 16:19:45 2019 +0000
@@ -30,6 +30,8 @@
 import java.net.DatagramPacket;
 import java.net.InetAddress;
+import java.net.InetSocketAddress;
 import java.net.Socket;
+import java.net.SocketTimeoutException;
 import java.security.SecureRandom;
 
 import javax.naming.*;
@@ -82,7 +84,7 @@
     private static final SecureRandom random = JCAUtil.getSecureRandom();
     private InetAddress[] servers;
     private int[] serverPorts;
-    private int timeout;                // initial timeout on UDP queries in ms
+    private int timeout;                // initial timeout on UDP and TCP queries in ms
     private int retries;                // number of UDP retries
 
     private final Object udpSocketLock = new Object();
@@ -100,7 +102,7 @@
     /*
      * Each server is of the form "server[:port]". IPv6 literal host names
      * include delimiting brackets.
-     * "timeout" is the initial timeout interval (in ms) for UDP queries,
+     * "timeout" is the initial timeout interval (in ms) for queries,
      * and "retries" gives the number of retries per server.
      */
     public DnsClient(String[] servers, int timeout, int retries)
@@ -237,6 +239,7 @@
 
                             // Try each server, starting with the one that just
                             // provided the truncated message.
+                            int retryTimeout = (timeout * (1 << retry));
                             for (int j = 0; j < servers.length; j++) {
                                 int ij = (i + j) % servers.length;
                                 if (doNotRetry[ij]) {
@@ -244,7 +247,7 @@
                                 }
                                 try {
                                     Tcp tcp =
-                                        new Tcp(servers[ij], serverPorts[ij]);
+                                        new Tcp(servers[ij], serverPorts[ij], retryTimeout);
                                     byte[] msg2;
                                     try {
                                         msg2 = doTcpQuery(tcp, pkt);
@@ -327,7 +330,7 @@
         // Try each name server.
         for (int i = 0; i < servers.length; i++) {
             try {
-                Tcp tcp = new Tcp(servers[i], serverPorts[i]);
+                Tcp tcp = new Tcp(servers[i], serverPorts[i], timeout);
                 byte[] msg;
                 try {
                     msg = doTcpQuery(tcp, pkt);
@@ -462,11 +465,11 @@
      */
     private byte[] continueTcpQuery(Tcp tcp) throws IOException {
 
-        int lenHi = tcp.in.read();      // high-order byte of response length
+        int lenHi = tcp.read();         // high-order byte of response length
         if (lenHi == -1) {
             return null;        // EOF
         }
-        int lenLo = tcp.in.read();      // low-order byte of response length
+        int lenLo = tcp.read();         // low-order byte of response length
         if (lenLo == -1) {
             throw new IOException("Corrupted DNS response: bad length");
         }
@@ -474,7 +477,7 @@
         byte[] msg = new byte[len];
         int pos = 0;            // next unfilled position in msg
         while (len > 0) {
-            int n = tcp.in.read(msg, pos, len);
+            int n = tcp.read(msg, pos, len);
             if (n == -1) {
                 throw new IOException(
                         "Corrupted DNS response: too little data");
@@ -683,19 +686,66 @@
 class Tcp {
 
     private Socket sock;
-    java.io.InputStream in;
+    private java.io.InputStream in;
     java.io.OutputStream out;
+    private final long deadline;    // System.nanoTime() when the budget ends
 
-    Tcp(InetAddress server, int port) throws IOException {
-        sock = new Socket(server, port);
-        sock.setTcpNoDelay(true);
-        out = new java.io.BufferedOutputStream(sock.getOutputStream());
-        in = new java.io.BufferedInputStream(sock.getInputStream());
+    /*
+     * Opens a TCP connection to server:port. "timeout" (in ms) is a single
+     * budget shared by the connect and by every later read on this
+     * connection, including each message of a multi-message response.
+     */
+    Tcp(InetAddress server, int port, int timeout) throws IOException {
+        deadline = System.nanoTime() + timeout * 1_000_000L;
+        sock = new Socket();
+        try {
+            sock.connect(new InetSocketAddress(server, port), timeout);
+            sock.setTcpNoDelay(true);
+            out = new java.io.BufferedOutputStream(sock.getOutputStream());
+            in = new java.io.BufferedInputStream(sock.getInputStream());
+        } catch (IOException | RuntimeException e) {
+            // Don't leak the socket if connecting or setting it up fails.
+            try {
+                sock.close();
+            } catch (IOException closeEx) {
+                e.addSuppressed(closeEx);
+            }
+            throw e;
+        }
     }
 
     void close() throws IOException {
         sock.close();
     }
+
+    private interface SocketReadOp {
+        int read() throws IOException;
+    }
+
+    /*
+     * Runs one blocking read with SO_TIMEOUT set to what is left of the
+     * budget; throws SocketTimeoutException once the budget is used up.
+     * Time is measured with System.nanoTime(), which, unlike
+     * System.currentTimeMillis(), is unaffected by wall-clock changes.
+     */
+    private int readWithTimeout(SocketReadOp reader) throws IOException {
+        long remaining = deadline - System.nanoTime();
+        if (remaining <= 0) {
+            throw new SocketTimeoutException();
+        }
+        // Round up to whole ms: an SO_TIMEOUT of 0 would mean "wait forever".
+        long millis = (remaining + 999_999L) / 1_000_000L;
+        sock.setSoTimeout((int) Math.min(millis, Integer.MAX_VALUE));
+        return reader.read();
+    }
+
+    int read() throws IOException {
+        return readWithTimeout(() -> in.read());
+    }
+
+    int read(byte[] b, int off, int len) throws IOException {
+        return readWithTimeout(() -> in.read(b, off, len));
+    }
 }
 
 /*