This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 675a7da39 [CELEBORN-368][FLINK] Pass exceptions in buffer stream. 
(#1304)
675a7da39 is described below

commit 675a7da393e33cf1c7ab043f027fb4d8913de700
Author: Ethan Feng <[email protected]>
AuthorDate: Fri Mar 3 15:43:30 2023 +0800

    [CELEBORN-368][FLINK] Pass exceptions in buffer stream. (#1304)
---
 .../plugin/flink/network/MessageDecoderExt.java    |  7 +++
 .../plugin/flink/network/ReadClientHandler.java    | 42 +++++++------
 .../TransportFrameDecoderWithBufferSupplier.java   | 16 +++++
 .../plugin/flink/RemoteBufferStreamReader.java     | 10 +++
 .../common/network/protocol/BufferStreamEnd.java   | 52 ++++++++++++++++
 .../celeborn/common/network/protocol/Message.java  | 13 +++-
 .../network/protocol/TransportableError.java       | 71 ++++++++++++++++++++++
 .../common/network/server/BufferStreamManager.java |  6 +-
 .../common/network/server/DataPartitionReader.java | 11 +++-
 .../service/deploy/worker/FetchHandler.scala       |  7 +++
 10 files changed, 214 insertions(+), 21 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
index fbee2e4fd..5b61ffed0 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
@@ -81,6 +81,13 @@ public class MessageDecoderExt {
         backlog = in.readInt();
         return new BacklogAnnouncement(streamId, backlog);
 
+      case TRANSPORTABLE_ERROR:
+        streamId = in.readLong();
+        int len = in.readInt();
+        byte[] errorBytes = new byte[len];
+        in.readBytes(errorBytes);
+        return new TransportableError(streamId, errorBytes);
+
       default:
         throw new IllegalArgumentException("Unexpected message type: " + type);
     }
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
index 8f768c26b..86afc80c8 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
@@ -25,7 +25,9 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
+import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
 import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.TransportableError;
 import org.apache.celeborn.common.network.server.BaseMessageHandler;
 import org.apache.celeborn.plugin.flink.protocol.ReadData;
 
@@ -43,7 +45,21 @@ public class ReadClientHandler extends BaseMessageHandler {
 
   public void removeHandler(long streamId) {
     streamHandlers.remove(streamId);
-    streamClients.remove(streamId);
+    TransportClient client = streamClients.remove(streamId);
+    // Tell the worker to release the stream; client is null if the stream was already removed.
+    if (client != null && client.isActive()) {
+      client.getChannel().writeAndFlush(new BufferStreamEnd(streamId));
+    }
+  }
+
+  private void processMessageInternal(long streamId, RequestMessage msg) {
+    Consumer<RequestMessage> streamHandler = streamHandlers.get(streamId);
+    if (streamHandler == null) { // stream already removed (see removeHandler) or never registered
+      logger.warn("Unexpected streamId received: {}", streamId);
+      return;
+    }
+    logger.debug("received streamId: {}, msg :{}", streamId, msg);
+    streamHandler.accept(msg);
   }
 
   @Override
@@ -53,27 +69,17 @@ public class ReadClientHandler extends BaseMessageHandler {
       case READ_DATA:
         ReadData readData = (ReadData) msg;
         streamId = readData.getStreamId();
-        if (streamHandlers.containsKey(streamId)) {
-          logger.debug(
-              "received streamId: {}, readData size:{}",
-              streamId,
-              readData.getFlinkBuffer().readableBytes());
-          streamHandlers.get(streamId).accept(msg);
-        } else {
-          logger.warn("Unexpected streamId received: {}", streamId);
-        }
+        processMessageInternal(streamId, readData);
         break;
       case BACKLOG_ANNOUNCEMENT:
         BacklogAnnouncement backlogAnnouncement = (BacklogAnnouncement) msg;
         streamId = backlogAnnouncement.getStreamId();
-        Consumer<RequestMessage> consumer = streamHandlers.get(streamId);
-        if (consumer != null) {
-          logger.debug(
-              "received streamId: {}, backlog: {}", streamId, 
backlogAnnouncement.getBacklog());
-          consumer.accept(msg);
-        } else {
-          logger.warn("Unexpected streamId received: {}", streamId);
-        }
+        processMessageInternal(streamId, backlogAnnouncement);
+        break;
+      case TRANSPORTABLE_ERROR:
+        TransportableError transportableError = ((TransportableError) msg);
+        streamId = transportableError.getStreamId();
+        processMessageInternal(streamId, transportableError);
         break;
       case ONE_WAY_MESSAGE:
         // ignore it.
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
index 2dc541628..ce85e7167 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
@@ -17,6 +17,8 @@
 
 package org.apache.celeborn.plugin.flink.network;
 
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Supplier;
 
@@ -24,13 +26,18 @@ import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundHandlerAdapter;
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
 import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
+import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
 import org.apache.celeborn.common.network.protocol.Message;
 import org.apache.celeborn.common.network.util.FrameDecoder;
 import org.apache.celeborn.plugin.flink.protocol.ReadData;
 
 public class TransportFrameDecoderWithBufferSupplier extends 
ChannelInboundHandlerAdapter
     implements FrameDecoder {
+  public static final Logger logger =
+      LoggerFactory.getLogger(TransportFrameDecoderWithBufferSupplier.class);
   private int msgSize = -1;
   private int bodySize = -1;
   private Message.Type curType = Message.Type.UNKNOWN_TYPE;
@@ -116,6 +123,8 @@ public class TransportFrameDecoderWithBufferSupplier 
extends ChannelInboundHandl
       io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
     ReadData readData = (ReadData) curMsg;
     if (externalBuf == null) {
+      Supplier<ByteBuf> supplier = bufferSuppliers.get(readData.getStreamId());
+      checkState(supplier != null, "Stream " + readData.getStreamId() + " buffer supplier is null");
-      externalBuf = bufferSuppliers.get(readData.getStreamId()).get();
+      externalBuf = supplier.get();
     }
     copyByteBuf(buf, externalBuf, bodySize);
@@ -144,6 +153,13 @@ public class TransportFrameDecoderWithBufferSupplier 
extends ChannelInboundHandl
           }
         }
       }
+    } catch (IllegalStateException e) {
+      // Decode ReadData might encounter IllegalStateException.
+      long streamId = ((ReadData) curMsg).getStreamId();
+      logger.info("Stream {} is closed,reply to server", streamId);
+      if (ctx.channel().isActive()) {
+        ctx.channel().writeAndFlush(new BufferStreamEnd(streamId));
+      }
     } finally {
       if (nettyBuf != null) {
         nettyBuf.release();
diff --git 
a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
 
b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
index 12cc0f324..9313ff701 100644
--- 
a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
+++ 
b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
@@ -26,6 +26,7 @@ import org.slf4j.LoggerFactory;
 import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
 import org.apache.celeborn.common.network.protocol.ReadAddCredit;
 import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.TransportableError;
 import org.apache.celeborn.plugin.flink.buffer.CreditListener;
 import org.apache.celeborn.plugin.flink.buffer.TransferBufferPool;
 import org.apache.celeborn.plugin.flink.protocol.ReadData;
@@ -72,6 +73,8 @@ public class RemoteBufferStreamReader extends CreditListener {
             dataReceived((ReadData) requestMessage);
           } else if (requestMessage instanceof BacklogAnnouncement) {
             backlogReceived(((BacklogAnnouncement) 
requestMessage).getBacklog());
+          } else if (requestMessage instanceof TransportableError) {
+            errorReceived(((TransportableError) 
requestMessage).getErrorMessage());
           }
         };
   }
@@ -125,6 +128,13 @@ public class RemoteBufferStreamReader extends 
CreditListener {
     }
   }
 
+  public void errorReceived(String errorMsg) {
+    if (!closed) { // report only the first error, then mark this reader closed
+      closed = true;
+      failureListener.accept(new IOException(errorMsg));
+    }
+  }
+
   public void dataReceived(ReadData readData) {
     logger.debug(
         "Rss buffer stream reader get streamid {} received readable bytes {}.",
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
new file mode 100644
index 000000000..d85e380d1
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+public class BufferStreamEnd extends RequestMessage {
+  private final long streamId; // stream the client no longer reads; the worker recycles it
+
+  public BufferStreamEnd(long streamId) {
+    this.streamId = streamId;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 8; // streamId encoded as a long
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(streamId);
+  }
+
+  @Override
+  public Type type() {
+    return Type.BUFFER_STREAM_END;
+  }
+
+  public static BufferStreamEnd decode(ByteBuf buffer) {
+    long streamId = buffer.readLong();
+    return new BufferStreamEnd(streamId);
+  }
+
+  public long getStreamId() {
+    return streamId;
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
index d2bbd55a6..31f81e336 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
@@ -92,7 +92,9 @@ public abstract class Message implements Encodable {
     READ_ADD_CREDIT(16),
     READ_DATA(17),
     OPEN_STREAM_WITH_CREDIT(18),
-    BACKLOG_ANNOUNCEMENT(19);
+    BACKLOG_ANNOUNCEMENT(19),
+    TRANSPORTABLE_ERROR(20),
+    BUFFER_STREAM_END(21);
     private final byte id;
 
     Type(int id) {
@@ -158,6 +160,10 @@ public abstract class Message implements Encodable {
           return OPEN_STREAM_WITH_CREDIT;
         case 19:
           return BACKLOG_ANNOUNCEMENT;
+        case 20:
+          return TRANSPORTABLE_ERROR;
+        case 21:
+          return BUFFER_STREAM_END;
         case -1:
           throw new IllegalArgumentException("User type messages cannot be 
decoded.");
         default:
@@ -223,6 +229,11 @@ public abstract class Message implements Encodable {
       case BACKLOG_ANNOUNCEMENT:
         return BacklogAnnouncement.decode(in);
 
+      case TRANSPORTABLE_ERROR:
+        return TransportableError.decode(in);
+      case BUFFER_STREAM_END:
+        return BufferStreamEnd.decode(in);
+
       default:
         throw new IllegalArgumentException("Unexpected message type: " + 
msgType);
     }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java
new file mode 100644
index 000000000..628bf8509
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import java.nio.charset.StandardCharsets;
+
+import io.netty.buffer.ByteBuf;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+
+public class TransportableError extends RequestMessage {
+  private final long streamId;
+  private final byte[] errorMessage; // UTF-8 bytes; a full stack trace when built from a Throwable
+
+  public TransportableError(long streamId, Throwable throwable) {
+    this.streamId = streamId;
+    this.errorMessage = ExceptionUtils.getStackTrace(throwable).getBytes(StandardCharsets.UTF_8);
+  }
+
+  public TransportableError(long streamId, byte[] errorMessage) {
+    this.streamId = streamId;
+    this.errorMessage = errorMessage;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 8 + 4 + errorMessage.length; // streamId + length prefix + message bytes
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(streamId);
+    buf.writeInt(errorMessage.length);
+    buf.writeBytes(errorMessage);
+  }
+
+  @Override
+  public Type type() {
+    return Type.TRANSPORTABLE_ERROR;
+  }
+
+  public static TransportableError decode(ByteBuf buf) {
+    long streamId = buf.readLong();
+    int msgLen = buf.readInt();
+    byte[] errorMsg = new byte[msgLen];
+    buf.readBytes(errorMsg);
+    return new TransportableError(streamId, errorMsg);
+  }
+
+  public long getStreamId() {
+    return streamId;
+  }
+
+  public String getErrorMessage() {
+    return new String(errorMessage, StandardCharsets.UTF_8);
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
 
b/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
index 9e367776d..09cf561b0 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
@@ -151,6 +151,10 @@ public class BufferStreamManager {
     }
   }
 
+  public void notifyStreamEndByClient(long streamId) {
+    recycleStream(streamId); // NOTE(review): stream may already be released; recycler must cope
+  }
+
   public void recycleStream(long streamId) {
     recycleStreamIds.add(new DelayedStreamId(streamId));
     startRecycleThread(); // lazy start thread
@@ -380,7 +384,7 @@ public class BufferStreamManager {
       DataPartitionReader dataPartitionReader = streamReaders.get(streamId);
       dataPartitionReader.release();
       if (dataPartitionReader.isFinished()) {
-        logger.info("release all for stream: {}", streamId);
+        logger.debug("release all for stream: {}", streamId);
         removeStream(streamId);
         streams.remove(streamId);
         servingStreams.remove(streamId);
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/server/DataPartitionReader.java
 
b/common/src/main/java/org/apache/celeborn/common/network/server/DataPartitionReader.java
index 612712147..b91a077c0 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/server/DataPartitionReader.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/server/DataPartitionReader.java
@@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory;
 import org.apache.celeborn.common.meta.FileInfo;
 import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
 import org.apache.celeborn.common.network.protocol.ReadData;
+import org.apache.celeborn.common.network.protocol.TransportableError;
 import org.apache.celeborn.common.network.server.memory.Recycler;
 import org.apache.celeborn.common.network.server.memory.WrappedDataBuffer;
 import org.apache.celeborn.common.util.Utils;
@@ -372,7 +373,11 @@ public class DataPartitionReader implements 
Comparable<DataPartitionReader> {
 
   private void notifyError(Throwable throwable) {
     logger.error("read error stream id {} message:{}", streamId, throwable.getMessage(), throwable);
-    // TODO notify client the exception
+    if (this.associatedChannel.isActive()) {
+      // If a stream fails, forward the exception on a best-effort basis; no response is expected.
+      // Do not close the channel: multiple streams share this very same channel.
+      this.associatedChannel.writeAndFlush(new TransportableError(streamId, throwable));
+    }
   }
 
   public long getPriority() {
@@ -445,4 +450,8 @@ public class DataPartitionReader implements 
Comparable<DataPartitionReader> {
     sb.append('}');
     return sb.toString();
   }
+
+  public long getStreamId() { // id of this reader's stream, as sent in TransportableError
+    return streamId;
+  }
 }
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index cbd5f53a3..787ee7d96 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -74,6 +74,9 @@ class FetchHandler(val conf: TransportConf) extends 
BaseMessageHandler with Logg
 
   override def receive(client: TransportClient, msg: RequestMessage): Unit = {
     msg match {
+      case r: BufferStreamEnd =>
+        rpcSource.updateMessageMetrics(r, 0)
+        handleEndStreamFromClient(client, r)
       case r: ReadAddCredit =>
         rpcSource.updateMessageMetrics(r, 0)
         handleReadAddCredit(client, r)
@@ -187,6 +190,10 @@ class FetchHandler(val conf: TransportConf) extends 
BaseMessageHandler with Logg
     }
   }
 
+  def handleEndStreamFromClient(client: TransportClient, req: BufferStreamEnd): Unit = {
+    bufferStreamManager.notifyStreamEndByClient(req.getStreamId) // NOTE: no stream ownership check
+  }
+
   def handleReadAddCredit(client: TransportClient, req: ReadAddCredit): Unit = 
{
     bufferStreamManager.addCredit(req.getCredit, req.getStreamId)
   }

Reply via email to