szilard-nemeth commented on a change in pull request #3259:
URL: https://github.com/apache/hadoop/pull/3259#discussion_r684236151



##########
File path: 
hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-shuffle/src/test/java/org/apache/hadoop/mapred/TestShuffleHandler.java
##########
@@ -106,10 +129,584 @@
       LoggerFactory.getLogger(TestShuffleHandler.class);
   private static final File ABS_LOG_DIR = GenericTestUtils.getTestDir(
       TestShuffleHandler.class.getSimpleName() + "LocDir");
+  // Control test execution properties with these flags.
+  private static final boolean DEBUG_MODE = false;
+  // WARNING: If this is set to true and proxy server is not running, tests
+  // will fail!
+  private static final boolean USE_PROXY = false;
+
+  // Attempt ids for the handlers under test (presumably passed on to
+  // ShuffleHeaderProvider via ShuffleHandlerForKeepAliveTests).
+  private static final long ATTEMPT_ID = 12345L;
+  private static final long ATTEMPT_ID_2 = 12346L;
+  // Presumably the default headerWriteCount handed to ResponseConfig.
+  private static final int HEADER_WRITE_COUNT = 100000;
+  // Read by HttpConnectionHelper; must be initialized before it is used.
+  private static TestExecution TEST_EXECUTION;
+
+  /**
+   * Settings that change how a test talks to the shuffle handler: the
+   * keep-alive timeout and port handed to the handler, connection timeouts,
+   * and whether requests are routed through a local HTTP proxy.
+   */
+  private static class TestExecution {
+    private static final int DEFAULT_KEEP_ALIVE_TIMEOUT = -100;
+    private static final int DEBUG_FRIENDLY_KEEP_ALIVE = 1000;
+    private static final int DEFAULT_PORT = 0; //random port
+    private static final int FIXED_PORT = 8088;
+    private static final String PROXY_HOST = "127.0.0.1";
+    private static final int PROXY_PORT = 8888;
+    /**
+     * Read/connect timeout in milliseconds used in debug mode (presumably
+     * large so that stepping through in a debugger does not time out).
+     */
+    private static final int DEBUG_CONNECTION_TIMEOUT_MS = 1000000;
+    private final boolean debugMode;
+    private final boolean useProxy;
+
+    /**
+     * @param debugMode use debug-friendly keep-alive, fixed port and long
+     *     connection timeouts
+     * @param useProxy route connections through PROXY_HOST:PROXY_PORT
+     */
+    public TestExecution(boolean debugMode, boolean useProxy) {
+      this.debugMode = debugMode;
+      this.useProxy = useProxy;
+    }
+
+    /** Returns the keep-alive timeout to configure for the handler. */
+    int getKeepAliveTimeout() {
+      return debugMode ? DEBUG_FRIENDLY_KEEP_ALIVE : DEFAULT_KEEP_ALIVE_TIMEOUT;
+    }
+
+    /**
+     * Opens a connection to {@code url}, through the HTTP proxy when
+     * {@code useProxy} is set.
+     */
+    HttpURLConnection openConnection(URL url) throws IOException {
+      if (useProxy) {
+        Proxy proxy = new Proxy(Proxy.Type.HTTP,
+            new InetSocketAddress(PROXY_HOST, PROXY_PORT));
+        return (HttpURLConnection) url.openConnection(proxy);
+      }
+      return (HttpURLConnection) url.openConnection();
+    }
+
+    /** Returns the fixed port in debug mode, otherwise 0 (random port). */
+    int shuffleHandlerPort() {
+      return debugMode ? FIXED_PORT : DEFAULT_PORT;
+    }
+
+    /** Applies long read/connect timeouts to {@code conn} in debug mode. */
+    void parameterizeConnection(URLConnection conn) {
+      // Use this instance's flag rather than the global DEBUG_MODE constant so
+      // that all TestExecution settings are driven by the same switch.
+      if (debugMode) {
+        conn.setReadTimeout(DEBUG_CONNECTION_TIMEOUT_MS);
+        conn.setConnectTimeout(DEBUG_CONNECTION_TIMEOUT_MS);
+      }
+    }
+  }
+  
+  /**
+   * Expected shape of a shuffle HTTP response. Once the serialized header size
+   * is supplied via {@link #setHeaderSize(long)}, the total Content-Length of
+   * the response is derived into {@link #contentLengthOfResponse}.
+   */
+  private static class ResponseConfig {
+    private static final int ONE_HEADER_DISPLACEMENT = 1;
+
+    private final int headerWriteCount;
+    private final long actualHeaderWriteCount;
+    private final int mapOutputCount;
+    private final int contentLengthOfOneMapOutput;
+    private long headerSize;
+    public long contentLengthOfResponse;
+
+    /**
+     * @throws IllegalStateException if a positive per-output content length is
+     *     requested with no map outputs
+     */
+    public ResponseConfig(int headerWriteCount, int mapOutputCount,
+        int contentLengthOfOneMapOutput) {
+      boolean hasMapOutputPayload = contentLengthOfOneMapOutput > 0;
+      boolean noMapOutputs = mapOutputCount <= 0;
+      if (hasMapOutputPayload && noMapOutputs) {
+        throw new IllegalStateException("mapOutputCount should be at least 1");
+      }
+      this.contentLengthOfOneMapOutput = contentLengthOfOneMapOutput;
+      this.mapOutputCount = mapOutputCount;
+      this.headerWriteCount = headerWriteCount;
+      // MapOutputSender#send writes the header N + 1 times (once, then N more),
+      // so (N + 1) * headerSize goes into the expected Content-length.
+      this.actualHeaderWriteCount = headerWriteCount + ONE_HEADER_DISPLACEMENT;
+    }
+
+    /** Stores the header size and recomputes the expected response length. */
+    private void setHeaderSize(long headerSize) {
+      this.headerSize = headerSize;
+      long allHeadersLength = actualHeaderWriteCount * headerSize;
+      this.contentLengthOfResponse =
+          computeContentLengthOfResponse(allHeadersLength);
+      LOG.debug("Content-length of all headers: {}", allHeadersLength);
+      LOG.debug("Content-length of one MapOutput: {}",
+          contentLengthOfOneMapOutput);
+      LOG.debug("Content-length of final HTTP response: {}",
+          contentLengthOfResponse);
+    }
+
+    /**
+     * Returns (all headers + one map output) multiplied by the map output
+     * count, treating a count of zero as one so the header bytes still count.
+     */
+    private long computeContentLengthOfResponse(long contentLengthOfAllHeaders) {
+      int multiplier = (mapOutputCount == 0) ? 1 : mapOutputCount;
+      long perOutput = contentLengthOfAllHeaders + contentLengthOfOneMapOutput;
+      return perOutput * multiplier;
+    }
+  }
+  
+  /**
+   * URL variants used by the shuffle tests. Presumably each one selects
+   * whether keep-alive is requested and how many map ids the URL carries.
+   */
+  private enum ShuffleUrlType {
+    SIMPLE,
+    WITH_KEEPALIVE,
+    WITH_KEEPALIVE_MULTIPLE_MAP_IDS,
+    WITH_KEEPALIVE_NO_MAP_IDS
+  }
+
+  /**
+   * Bytes read from a stream, decoded as UTF-8, together with a byte count
+   * supplied by the caller.
+   * NOTE(review): the whole {@code bytes} array is decoded regardless of
+   * {@code totalBytesRead}; confirm callers pass an exactly-sized array.
+   */
+  private static class InputStreamReadResult {
+    final String asString;
+    int totalBytesRead;
+
+    public InputStreamReadResult(byte[] bytes, int totalBytesRead) {
+      this.totalBytesRead = totalBytesRead;
+      this.asString = new String(bytes, StandardCharsets.UTF_8);
+    }
+  }
+
+  private static abstract class AdditionalMapOutputSenderOperations {
+    public abstract ChannelFuture perform(ChannelHandlerContext ctx, Channel 
ch) throws IOException;
+  }
+
+  /**
+   * ShuffleHandler subclass for the keep-alive tests. Its Shuffle stub:
+   * <ul>
+   *   <li>skips map output lookup and request verification;</li>
+   *   <li>takes the Content-Length from {@link HeaderPopulator};</li>
+   *   <li>writes synthetic headers through {@link MapOutputSender};</li>
+   *   <li>installs a logging encoder and a {@link CustomTimeoutHandler};</li>
+   *   <li>records errors in {@link #failures} instead of calling the default
+   *   sendError implementation.</li>
+   * </ul>
+   */
+  private class ShuffleHandlerForKeepAliveTests extends ShuffleHandler {
+    // Remote address of the last channel MapOutputSender wrote to.
+    final LastSocketAddress lastSocketAddress = new LastSocketAddress();
+    // Errors reported through the overridden sendError methods.
+    // NOTE(review): plain ArrayList; if sendError runs on Netty I/O threads
+    // while the test thread reads this list, a thread-safe list may be needed.
+    final ArrayList<Throwable> failures = new ArrayList<>();
+    final ShuffleHeaderProvider shuffleHeaderProvider;
+    final HeaderPopulator headerPopulator;
+    MapOutputSender mapOutputSender;
+    // Optional; invoked by CustomTimeoutHandler when the channel goes idle.
+    private Consumer<IdleStateEvent> channelIdleCallback;
+    private CustomTimeoutHandler customTimeoutHandler;
+    // When true, stop() is called after every reported error.
+    private boolean failImmediatelyOnErrors = false;
+    // When true, the channel is closed after an error is recorded.
+    private boolean closeChannelOnError = true;
+    private ResponseConfig responseConfig;
+
+    /**
+     * Same as the two-argument constructor, but also registers a callback that
+     * CustomTimeoutHandler invokes when it sees an idle channel.
+     */
+    public ShuffleHandlerForKeepAliveTests(long attemptId, ResponseConfig responseConfig,
+        Consumer<IdleStateEvent> channelIdleCallback) throws IOException {
+      this(attemptId, responseConfig);
+      this.channelIdleCallback = channelIdleCallback;
+    }
+
+    /**
+     * Creates a handler whose shuffle headers embed {@code attemptId}.
+     * Side effect: calls {@code responseConfig.setHeaderSize(...)}, which fills
+     * in the expected Content-Length on the passed-in config object.
+     */
+    public ShuffleHandlerForKeepAliveTests(long attemptId, ResponseConfig responseConfig) throws IOException {
+      this.responseConfig = responseConfig;
+      this.shuffleHeaderProvider = new ShuffleHeaderProvider(attemptId);
+      // HeaderPopulator later returns the Content-Length computed here.
+      this.responseConfig.setHeaderSize(shuffleHeaderProvider.getShuffleHeaderSize());
+      this.headerPopulator = new HeaderPopulator(this, responseConfig, shuffleHeaderProvider, true);
+      this.mapOutputSender = new MapOutputSender(responseConfig, lastSocketAddress, shuffleHeaderProvider);
+      setUseOutboundExceptionHandler(true);
+    }
+
+    /** If true, sendError stops this handler after recording the error. */
+    public void setFailImmediatelyOnErrors(boolean failImmediatelyOnErrors) {
+      this.failImmediatelyOnErrors = failImmediatelyOnErrors;
+    }
+
+    /** If true (the default), handleError closes the channel. */
+    public void setCloseChannelOnError(boolean closeChannelOnError) {
+      this.closeChannelOnError = closeChannelOnError;
+    }
+
+    /** Returns a Shuffle stub wired to this handler's test collaborators. */
+    @Override
+    protected Shuffle getShuffle(final Configuration conf) {
+      // replace the shuffle handler with one stubbed for testing
+      return new Shuffle(conf) {
+        // Stub: no map output info is looked up.
+        @Override
+        protected MapOutputInfo getMapOutputInfo(String mapId, int reduce,
+            String jobId, String user) throws IOException {
+          return null;
+        }
+        // No-op: request verification is bypassed.
+        @Override
+        protected void verifyRequest(String appid, ChannelHandlerContext ctx,
+            HttpRequest request, HttpResponse response, URL requestUri)
+            throws IOException {
+        }
+
+        // Content-Length comes from HeaderPopulator; infoMap is not consulted.
+        @Override
+        protected void populateHeaders(List<String> mapIds, String jobId,
+            String user, int reduce, HttpRequest request,
+            HttpResponse response, boolean keepAliveParam,
+            Map<String, MapOutputInfo> infoMap) throws IOException {
+          long contentLength = headerPopulator.populateHeaders(
+              keepAliveParam);
+          super.setResponseHeaders(response, keepAliveParam, contentLength);
+        }
+
+        // Writes synthetic headers via MapOutputSender instead of map output data.
+        @Override
+        protected ChannelFuture sendMapOutput(ChannelHandlerContext ctx,
+            Channel ch, String user, String mapId, int reduce,
+            MapOutputInfo info) throws IOException {
+          return mapOutputSender.send(ctx, ch);
+        }
+
+        // Swaps in the logging encoder and the custom timeout handler before
+        // the default channelActive handling runs.
+        @Override
+        public void channelActive(ChannelHandlerContext ctx) throws Exception {
+          ctx.pipeline().replace(HttpResponseEncoder.class, ENCODER_HANDLER_NAME, new LoggingHttpResponseEncoder(false));
+          replaceTimeoutHandlerWithCustom(ctx);
+          LOG.debug("Modified pipeline: {}", ctx.pipeline());
+          super.channelActive(ctx);
+        }
+
+        // Replaces the pipeline's TimeoutHandler with a CustomTimeoutHandler
+        // that uses the same keep-alive timeout value.
+        private void replaceTimeoutHandlerWithCustom(ChannelHandlerContext ctx) {
+          TimeoutHandler oldTimeoutHandler =
+              (TimeoutHandler)ctx.pipeline().get(TIMEOUT_HANDLER);
+          int timeoutValue =
+              oldTimeoutHandler.getConnectionKeepAliveTimeOut();
+          customTimeoutHandler = new CustomTimeoutHandler(timeoutValue, channelIdleCallback);
+          ctx.pipeline().replace(TIMEOUT_HANDLER, TIMEOUT_HANDLER, customTimeoutHandler);
+        }
+
+        // Does not call the default sendError: records the error via
+        // handleError and optionally stops the handler.
+        @Override
+        protected void sendError(ChannelHandlerContext ctx,
+            HttpResponseStatus status) {
+          String message = "Error while processing request. Status: " + status;
+          handleError(ctx, message);
+          if (failImmediatelyOnErrors) {
+            stop();
+          }
+        }
+
+        // Same as above, including the caller-supplied message in the record.
+        @Override
+        protected void sendError(ChannelHandlerContext ctx, String message,
+            HttpResponseStatus status) {
+          String errMessage = String.format("Error while processing request. " +
+              "Status: " +
+              "%s, message: %s", status, message);
+          handleError(ctx, errMessage);
+          if (failImmediatelyOnErrors) {
+            stop();
+          }
+        }
+      };
+    }
+
+    /** Logs the error, records it in failures, and optionally closes the channel. */
+    private void handleError(ChannelHandlerContext ctx, String message) {
+      LOG.error(message);
+      failures.add(new Error(message));
+      if (closeChannelOnError) {
+        LOG.warn("sendError: Closing channel");
+        ctx.channel().close();
+      }
+    }
+
+    /**
+     * TimeoutHandler that flags idle channels and invokes the optional idle
+     * callback before delegating to the default idle handling.
+     */
+    private class CustomTimeoutHandler extends TimeoutHandler {
+      // Set to true once channelIdle has been called.
+      private boolean channelIdle = false;
+      private final Consumer<IdleStateEvent> channelIdleCallback;
+
+      public CustomTimeoutHandler(int connectionKeepAliveTimeOut,
+          Consumer<IdleStateEvent> channelIdleCallback) {
+        super(connectionKeepAliveTimeOut);
+        this.channelIdleCallback = channelIdleCallback;
+      }
+
+      @Override
+      public void channelIdle(ChannelHandlerContext ctx, IdleStateEvent e) {
+        LOG.debug("Channel idle");
+        this.channelIdle = true;
+        if (channelIdleCallback != null) {
+          LOG.debug("Calling channel idle callback..");
+          channelIdleCallback.accept(e);
+        }
+        super.channelIdle(ctx, e);
+      }
+    }
+  }
+
+  private static class MapOutputSender {
+    private final ResponseConfig responseConfig;
+    private final LastSocketAddress lastSocketAddress;
+    private final ShuffleHeaderProvider shuffleHeaderProvider;
+    private AdditionalMapOutputSenderOperations 
additionalMapOutputSenderOperations;
+
+    public MapOutputSender(ResponseConfig responseConfig, LastSocketAddress 
lastSocketAddress,
+        ShuffleHeaderProvider shuffleHeaderProvider) {
+      this.responseConfig = responseConfig;
+      this.lastSocketAddress = lastSocketAddress;
+      this.shuffleHeaderProvider = shuffleHeaderProvider;
+    }
+
+    public ChannelFuture send(ChannelHandlerContext ctx, Channel ch) throws 
IOException {
+      LOG.debug("In MapOutputSender#send");
+      lastSocketAddress.setAddress(ch.remoteAddress());
+      ShuffleHeader header = shuffleHeaderProvider.createNewShuffleHeader();
+      writeOneHeader(ch, header);
+      ChannelFuture future = writeHeaderNTimes(ch, header, 
responseConfig.headerWriteCount);
+      // This is the last operation
+      // It's safe to increment ShuffleHeader counter for better identification
+      shuffleHeaderProvider.incrementCounter();
+      if (additionalMapOutputSenderOperations != null) {
+        return additionalMapOutputSenderOperations.perform(ctx, ch);
+      }
+      return future;
+    }
+    private void writeOneHeader(Channel ch, ShuffleHeader header) throws 
IOException {
+      DataOutputBuffer dob = new DataOutputBuffer();
+      header.write(dob);
+      LOG.debug("MapOutputSender#writeOneHeader before WriteAndFlush #1");
+      ch.writeAndFlush(wrappedBuffer(dob.getData(), 0, dob.getLength()));
+      LOG.debug("MapOutputSender#writeOneHeader after WriteAndFlush #1. 
outputBufferSize: " + dob.size());
+    }
+
+    private ChannelFuture writeHeaderNTimes(Channel ch, ShuffleHeader header, 
int iterations) throws IOException {
+      DataOutputBuffer dob = new DataOutputBuffer();
+      for (int i = 0; i < iterations; ++i) {
+        header.write(dob);
+      }
+      LOG.debug("MapOutputSender#writeHeaderNTimes WriteAndFlush big chunk of 
data, outputBufferSize: " + dob.size());
+      return ch.writeAndFlush(wrappedBuffer(dob.getData(), 0, 
dob.getLength()));
+    }
+  }
+
+  private static class ShuffleHeaderProvider {
+    private final long attemptId;
+    private final AtomicInteger attemptCounter;
+    private int cachedSize = Integer.MIN_VALUE;
+
+    public ShuffleHeaderProvider(long attemptId) {
+      this.attemptId = attemptId;
+      this.attemptCounter = new AtomicInteger();
+    }
+
+    ShuffleHeader createNewShuffleHeader() {
+      return new ShuffleHeader(String.format("attempt_%s_1_m_1_0%s", attemptId,
+          attemptCounter.get()), 5678, 5678, 1);
+    }
+
+    void incrementCounter() {
+      attemptCounter.incrementAndGet();
+    }
+
+    private int getShuffleHeaderSize() throws IOException {
+      if (cachedSize != Integer.MIN_VALUE) {
+        return cachedSize;
+      }
+      DataOutputBuffer dob = new DataOutputBuffer();
+      ShuffleHeader header = createNewShuffleHeader();
+      header.write(dob);
+      cachedSize = dob.size();
+      return cachedSize;
+    }
+  }
+
+  /**
+   * Supplies the response Content-Length for the stubbed Shuffle and, when
+   * configured, turns off the handler's connectionKeepAliveEnabled flag for
+   * keep-alive requests.
+   */
+  private static class HeaderPopulator {
+    private final ShuffleHandler shuffleHandler;
+    private final boolean disableKeepAliveConfig;
+    // Not read by populateHeaders; kept so the constructor is unchanged.
+    private final ShuffleHeaderProvider shuffleHeaderProvider;
+    private ResponseConfig responseConfig;
+
+    public HeaderPopulator(ShuffleHandler shuffleHandler,
+        ResponseConfig responseConfig,
+        ShuffleHeaderProvider shuffleHeaderProvider,
+        boolean disableKeepAliveConfig) {
+      this.shuffleHandler = shuffleHandler;
+      this.responseConfig = responseConfig;
+      this.disableKeepAliveConfig = disableKeepAliveConfig;
+      this.shuffleHeaderProvider = shuffleHeaderProvider;
+    }
+
+    /**
+     * Returns the precomputed Content-Length of the whole response.
+     * The value is {@link ResponseConfig#contentLengthOfResponse}, which is
+     * filled in by ResponseConfig#setHeaderSize. No headers need to be
+     * serialized here, and any buffer built here would never be sent.
+     *
+     * @param keepAliveParam keep-alive flag passed to Shuffle#populateHeaders
+     * @return expected Content-Length of the response
+     * @throws IOException declared for compatibility; not thrown here
+     */
+    public long populateHeaders(boolean keepAliveParam) throws IOException {
+      // for testing purpose;
+      // disable connectionKeepAliveEnabled if keepAliveParam is available
+      if (keepAliveParam && disableKeepAliveConfig) {
+        shuffleHandler.connectionKeepAliveEnabled = false;
+      }
+      return responseConfig.contentLengthOfResponse;
+    }
+  }
+
+  /**
+   * Snapshot of one HTTP exchange: the response headers and status code read
+   * from the connection, plus a payload length and socket address supplied by
+   * the caller (the addresses are compared by HttpConnectionAssert).
+   */
+  private static class HttpConnectionData {
+    private final Map<String, List<String>> headers;
+    private HttpURLConnection conn;
+    private int payloadLength;
+    private SocketAddress socket;
+    private int responseCode = -1;
+
+    private HttpConnectionData(HttpURLConnection conn, int payloadLength,
+        SocketAddress socket) {
+      this.headers = conn.getHeaderFields();
+      this.conn = conn;
+      this.payloadLength = payloadLength;
+      this.socket = socket;
+      try {
+        this.responseCode = conn.getResponseCode();
+      } catch (IOException e) {
+        // Fail the test the same way Assert.fail does (AssertionError), but
+        // keep the IOException as the cause for diagnosis.
+        throw new AssertionError(
+            "Failed to read response code from connection: " + conn, e);
+      }
+    }
+
+    /** Captures headers and response code of {@code conn}. */
+    static HttpConnectionData create(HttpURLConnection conn, int payloadLength,
+        SocketAddress socket) {
+      return new HttpConnectionData(conn, payloadLength, socket);
+    }
+  }
+
+  /**
+   * Fluent assertions over one {@link HttpConnectionData}, plus a static check
+   * that the first two recorded connections used the same socket address.
+   */
+  private static class HttpConnectionAssert {
+    private final HttpConnectionData connData;
+
+    private HttpConnectionAssert(HttpConnectionData connData) {
+      this.connData = connData;
+    }
+
+    static HttpConnectionAssert create(HttpConnectionData connData) {
+      return new HttpConnectionAssert(connData);
+    }
+
+    /**
+     * Asserts that the first two recorded connections report the same,
+     * non-null socket address.
+     */
+    public static void assertKeepAliveConnectionsAreSame(
+        HttpConnectionHelper httpConnectionHelper) {
+      Assert.assertTrue("At least two connection data " +
+          "is required to perform this assertion",
+          httpConnectionHelper.connectionData.size() >= 2);
+      SocketAddress firstAddress =
+          httpConnectionHelper.getConnectionData(0).socket;
+      SocketAddress secondAddress =
+          httpConnectionHelper.getConnectionData(1).socket;
+      Assert.assertNotNull("Initial shuffle address should not be null",
+          firstAddress);
+      Assert.assertNotNull("Keep-Alive shuffle address should not be null",
+          secondAddress);
+      Assert.assertEquals("Initial shuffle address and keep-alive shuffle "
+          + "address should be the same", firstAddress, secondAddress);
+    }
+
+    /** Asserts HTTP 200 and the keep-alive headers with {@code timeout}. */
+    public HttpConnectionAssert expectKeepAliveWithTimeout(long timeout) {
+      Assert.assertEquals(HttpURLConnection.HTTP_OK, connData.responseCode);
+      assertKeepAliveHeaders(timeout);
+      return this;
+    }
+
+    /** Asserts HTTP 400 while the keep-alive headers are still present. */
+    public HttpConnectionAssert expectBadRequest(long timeout) {
+      Assert.assertEquals(HttpURLConnection.HTTP_BAD_REQUEST,
+          connData.responseCode);
+      assertKeepAliveHeaders(timeout);
+      return this;
+    }
+
+    /** Asserts the caller-supplied payload length equals {@code size}. */
+    public HttpConnectionAssert expectResponseContentLength(long size) {
+      Assert.assertEquals(size, connData.payloadLength);
+      return this;
+    }
+
+    /**
+     * Checks the Connection and Keep-Alive headers shared by the keep-alive
+     * and bad-request expectations.
+     */
+    private void assertKeepAliveHeaders(long timeout) {
+      assertHeaderValue(HttpHeader.CONNECTION, HttpHeader.KEEP_ALIVE.asString());
+      assertHeaderValue(HttpHeader.KEEP_ALIVE, "timeout=" + timeout);
+    }
+
+    /** Asserts {@code header} is present exactly once with {@code expectedValue}. */
+    private void assertHeaderValue(HttpHeader header, String expectedValue) {
+      List<String> headerList = connData.headers.get(header.asString());
+      Assert.assertNotNull("Got null header value for header: " + header,
+          headerList);
+      Assert.assertFalse("Got empty header value for header: " + header,
+          headerList.isEmpty());
+      Assert.assertEquals("Unexpected size of header list for header: " + header,
+          1, headerList.size());
+      Assert.assertEquals("Unexpected value for header: " + header,
+          expectedValue, headerList.get(0));
+    }
+  }
+
+  private static class HttpConnectionHelper {
+    private final LastSocketAddress lastSocketAddress;
+    List<HttpConnectionData> connectionData = new ArrayList<>();
+
+    public HttpConnectionHelper(LastSocketAddress lastSocketAddress) {
+      this.lastSocketAddress = lastSocketAddress;
+    }
+
+    public void connectToUrls(String[] urls, ResponseConfig responseConfig) 
throws IOException {
+      connectToUrlsInternal(urls, responseConfig, HttpURLConnection.HTTP_OK);
+    }
+
+    public void connectToUrls(String[] urls, ResponseConfig responseConfig, 
int expectedHttpStatus) throws IOException {
+      connectToUrlsInternal(urls, responseConfig, expectedHttpStatus);
+    }
+
+    private void connectToUrlsInternal(String[] urls, ResponseConfig 
responseConfig, int expectedHttpStatus) throws IOException {
+      int requests = urls.length;
+      LOG.debug("Will connect to URLs: {}", Arrays.toString(urls));
+      for (int reqIdx = 0; reqIdx < requests; reqIdx++) {
+        String urlString = urls[reqIdx];
+        LOG.debug("Connecting to URL: {}", urlString);
+        URL url = new URL(urlString);
+        HttpURLConnection conn = TEST_EXECUTION.openConnection(url);
+        conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_NAME,
+            ShuffleHeader.DEFAULT_HTTP_HEADER_NAME);
+        conn.setRequestProperty(ShuffleHeader.HTTP_HEADER_VERSION,
+            ShuffleHeader.DEFAULT_HTTP_HEADER_VERSION);
+        TEST_EXECUTION.parameterizeConnection(conn);
+        conn.connect();
+        if (expectedHttpStatus == HttpURLConnection.HTTP_BAD_REQUEST) {
+          //Catch exception as error are caught with overridden sendError 
method
+          //Caught errors will be validated later.
+          try {
+            DataInputStream input = new DataInputStream(conn.getInputStream());
+          } catch (Exception e) {
+            return;
+          }

Review comment:
      Thanks for spotting this; it was a mistake.
   I changed it to a `continue` statement.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: common-issues-unsubscr...@hadoop.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: common-issues-unsubscr...@hadoop.apache.org
For additional commands, e-mail: common-issues-h...@hadoop.apache.org

Reply via email to