This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 675a7da39 [CELEBORN-368][FLINK] Pass exceptions in buffer stream.
(#1304)
675a7da39 is described below
commit 675a7da393e33cf1c7ab043f027fb4d8913de700
Author: Ethan Feng <[email protected]>
AuthorDate: Fri Mar 3 15:43:30 2023 +0800
[CELEBORN-368][FLINK] Pass exceptions in buffer stream. (#1304)
---
.../plugin/flink/network/MessageDecoderExt.java | 7 +++
.../plugin/flink/network/ReadClientHandler.java | 42 +++++++------
.../TransportFrameDecoderWithBufferSupplier.java | 16 +++++
.../plugin/flink/RemoteBufferStreamReader.java | 10 +++
.../common/network/protocol/BufferStreamEnd.java | 52 ++++++++++++++++
.../celeborn/common/network/protocol/Message.java | 13 +++-
.../network/protocol/TransportableError.java | 71 ++++++++++++++++++++++
.../common/network/server/BufferStreamManager.java | 6 +-
.../common/network/server/DataPartitionReader.java | 11 +++-
.../service/deploy/worker/FetchHandler.scala | 7 +++
10 files changed, 214 insertions(+), 21 deletions(-)
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
index fbee2e4fd..5b61ffed0 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java
@@ -81,6 +81,13 @@ public class MessageDecoderExt {
backlog = in.readInt();
return new BacklogAnnouncement(streamId, backlog);
+ case TRANSPORTABLE_ERROR:
+ streamId = in.readLong();
+ int len = in.readInt();
+ byte[] errorBytes = new byte[len];
+ in.readBytes(errorBytes);
+ return new TransportableError(streamId, errorBytes);
+
default:
throw new IllegalArgumentException("Unexpected message type: " + type);
}
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
index 8f768c26b..86afc80c8 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
@@ -25,7 +25,9 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
+import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
@@ -43,7 +45,21 @@ public class ReadClientHandler extends BaseMessageHandler {
public void removeHandler(long streamId) {
streamHandlers.remove(streamId);
- streamClients.remove(streamId);
+ TransportClient client = streamClients.remove(streamId);
+ // Tell the worker to release the stream; client is null if none was mapped to streamId.
+ if (client != null && client.isActive()) {
+ client.getChannel().writeAndFlush(new BufferStreamEnd(streamId));
+ }
+ }
+ /** Hands msg to the handler registered for streamId; logs a warning when there is none. */
+ private void processMessageInternal(long streamId, RequestMessage msg) {
+ Consumer<RequestMessage> handler = streamHandlers.get(streamId);
+ if (handler != null) {
+ logger.debug("received streamId: {}, msg :{}", streamId, msg);
+ handler.accept(msg);
+ } else {
+ logger.warn("Unexpected streamId received: {}", streamId);
+ }
}
@Override
@@ -53,27 +69,17 @@ public class ReadClientHandler extends BaseMessageHandler {
case READ_DATA:
ReadData readData = (ReadData) msg;
streamId = readData.getStreamId();
- if (streamHandlers.containsKey(streamId)) {
- logger.debug(
- "received streamId: {}, readData size:{}",
- streamId,
- readData.getFlinkBuffer().readableBytes());
- streamHandlers.get(streamId).accept(msg);
- } else {
- logger.warn("Unexpected streamId received: {}", streamId);
- }
+ processMessageInternal(streamId, readData);
break;
case BACKLOG_ANNOUNCEMENT:
BacklogAnnouncement backlogAnnouncement = (BacklogAnnouncement) msg;
streamId = backlogAnnouncement.getStreamId();
- Consumer<RequestMessage> consumer = streamHandlers.get(streamId);
- if (consumer != null) {
- logger.debug(
- "received streamId: {}, backlog: {}", streamId,
backlogAnnouncement.getBacklog());
- consumer.accept(msg);
- } else {
- logger.warn("Unexpected streamId received: {}", streamId);
- }
+ processMessageInternal(streamId, backlogAnnouncement);
+ break;
+ case TRANSPORTABLE_ERROR:
+ TransportableError transportableError = ((TransportableError) msg);
+ streamId = transportableError.getStreamId();
+ processMessageInternal(streamId, transportableError);
break;
case ONE_WAY_MESSAGE:
// ignore it.
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
index 2dc541628..ce85e7167 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java
@@ -17,6 +17,8 @@
package org.apache.celeborn.plugin.flink.network;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
@@ -24,13 +26,18 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
import org.apache.celeborn.common.network.protocol.Message;
import org.apache.celeborn.common.network.util.FrameDecoder;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
public class TransportFrameDecoderWithBufferSupplier extends
ChannelInboundHandlerAdapter
implements FrameDecoder {
+ public static final Logger logger =
+ LoggerFactory.getLogger(TransportFrameDecoderWithBufferSupplier.class);
private int msgSize = -1;
private int bodySize = -1;
private Message.Type curType = Message.Type.UNKNOWN_TYPE;
@@ -116,6 +123,8 @@ public class TransportFrameDecoderWithBufferSupplier
extends ChannelInboundHandl
io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
ReadData readData = (ReadData) curMsg;
if (externalBuf == null) {
+ Supplier<ByteBuf> supplier = bufferSuppliers.get(readData.getStreamId()); // null if none is registered
+ checkState(supplier != null, "Stream " + readData.getStreamId() + " buffer supplier is null");
externalBuf = bufferSuppliers.get(readData.getStreamId()).get();
}
copyByteBuf(buf, externalBuf, bodySize);
@@ -144,6 +153,13 @@ public class TransportFrameDecoderWithBufferSupplier
extends ChannelInboundHandl
}
}
}
+ } catch (IllegalStateException e) {
+ // Decode ReadData might encounter IllegalStateException.
+ long streamId = ((ReadData) curMsg).getStreamId();
+ logger.info("Stream {} is closed,reply to server", streamId);
+ if (ctx.channel().isActive()) {
+ ctx.channel().writeAndFlush(new BufferStreamEnd(streamId));
+ }
} finally {
if (nettyBuf != null) {
nettyBuf.release();
diff --git
a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
index 12cc0f324..9313ff701 100644
---
a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
+++
b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
@@ -26,6 +26,7 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
import org.apache.celeborn.common.network.protocol.ReadAddCredit;
import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.plugin.flink.buffer.CreditListener;
import org.apache.celeborn.plugin.flink.buffer.TransferBufferPool;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
@@ -72,6 +73,8 @@ public class RemoteBufferStreamReader extends CreditListener {
dataReceived((ReadData) requestMessage);
} else if (requestMessage instanceof BacklogAnnouncement) {
backlogReceived(((BacklogAnnouncement)
requestMessage).getBacklog());
+ } else if (requestMessage instanceof TransportableError) {
+ errorReceived(((TransportableError)
requestMessage).getErrorMessage());
}
};
}
@@ -125,6 +128,13 @@ public class RemoteBufferStreamReader extends
CreditListener {
}
}
+ public void errorReceived(String errorMsg) { // reports an error message received for this stream
+ if (!closed) {
+ closed = true; // set before notifying so a later error message is not reported again
+ failureListener.accept(new IOException(errorMsg));
+ }
+ }
+
public void dataReceived(ReadData readData) {
logger.debug(
"Rss buffer stream reader get streamid {} received readable bytes {}.",
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
new file mode 100644
index 000000000..d85e380d1
--- /dev/null
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/BufferStreamEnd.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+/** Request carrying only a stream id; the Flink client sends it to tell the worker a stream ended. */
+public class BufferStreamEnd extends RequestMessage {
+ private long streamId;
+ /** Creates the message for the given stream id. */
+ public BufferStreamEnd(long streamId) {
+ this.streamId = streamId;
+ }
+ /** The encoded body is a single long, i.e. 8 bytes. */
+ @Override
+ public int encodedLength() {
+ return 8;
+ }
+ /** Writes the stream id as a long. */
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(streamId);
+ }
+
+ @Override
+ public Type type() {
+ return Type.BUFFER_STREAM_END;
+ }
+ /** Reads the stream id written by encode and rebuilds the message. */
+ public static Message decode(ByteBuf buffer) {
+ long streamId = buffer.readLong();
+ return new BufferStreamEnd(streamId);
+ }
+
+ public long getStreamId() {
+ return streamId;
+ }
+}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
index d2bbd55a6..31f81e336 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
@@ -92,7 +92,9 @@ public abstract class Message implements Encodable {
READ_ADD_CREDIT(16),
READ_DATA(17),
OPEN_STREAM_WITH_CREDIT(18),
- BACKLOG_ANNOUNCEMENT(19);
+ BACKLOG_ANNOUNCEMENT(19),
+ TRANSPORTABLE_ERROR(20),
+ BUFFER_STREAM_END(21);
private final byte id;
Type(int id) {
@@ -158,6 +160,10 @@ public abstract class Message implements Encodable {
return OPEN_STREAM_WITH_CREDIT;
case 19:
return BACKLOG_ANNOUNCEMENT;
+ case 20:
+ return TRANSPORTABLE_ERROR;
+ case 21:
+ return BUFFER_STREAM_END;
case -1:
throw new IllegalArgumentException("User type messages cannot be
decoded.");
default:
@@ -223,6 +229,11 @@ public abstract class Message implements Encodable {
case BACKLOG_ANNOUNCEMENT:
return BacklogAnnouncement.decode(in);
+ case TRANSPORTABLE_ERROR:
+ return TransportableError.decode(in);
+ case BUFFER_STREAM_END:
+ return BufferStreamEnd.decode(in);
+
default:
throw new IllegalArgumentException("Unexpected message type: " +
msgType);
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java
new file mode 100644
index 000000000..628bf8509
--- /dev/null
+++
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import java.nio.charset.StandardCharsets;
+
+import io.netty.buffer.ByteBuf;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+/** Request carrying a stream id and an error text as UTF-8 bytes, used to pass a failure to the peer. */
+public class TransportableError extends RequestMessage {
+ private long streamId;
+ private byte[] errorMessage;
+ /** Builds the error from a throwable; the payload is its stack trace encoded as UTF-8. */
+ public TransportableError(long streamId, Throwable throwable) {
+ this.streamId = streamId;
+ this.errorMessage =
ExceptionUtils.getStackTrace(throwable).getBytes(StandardCharsets.UTF_8);
+ }
+ /** Builds the error from already-encoded payload bytes, as the decode path does. */
+ public TransportableError(long streamId, byte[] errorMessage) {
+ this.streamId = streamId;
+ this.errorMessage = errorMessage;
+ }
+ /** Body layout: 8-byte stream id, 4-byte payload length, then the payload bytes. */
+ @Override
+ public int encodedLength() {
+ return 8 + 4 + errorMessage.length;
+ }
+ /** Writes stream id, payload length and payload, in that order. */
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeLong(streamId);
+ buf.writeInt(errorMessage.length);
+ buf.writeBytes(errorMessage);
+ }
+
+ @Override
+ public Type type() {
+ return Type.TRANSPORTABLE_ERROR;
+ }
+ /** Reads the fields in the order encode wrote them and rebuilds the message. */
+ public static TransportableError decode(ByteBuf buf) {
+ long streamId = buf.readLong();
+ int msgLen = buf.readInt();
+ byte[] errorMsg = new byte[msgLen];
+ buf.readBytes(errorMsg);
+ return new TransportableError(streamId, errorMsg);
+ }
+
+ public long getStreamId() {
+ return streamId;
+ }
+ /** Returns the payload decoded as a UTF-8 string. */
+ public String getErrorMessage() {
+ return new String(errorMessage, StandardCharsets.UTF_8);
+ }
+}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
b/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
index 9e367776d..09cf561b0 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
@@ -151,6 +151,10 @@ public class BufferStreamManager {
}
}
+ public void notifyStreamEndByClient(long streamId) { // invoked when a client sends BufferStreamEnd
+ recycleStream(streamId); // queue the stream id for recycling
+ }
+
public void recycleStream(long streamId) {
recycleStreamIds.add(new DelayedStreamId(streamId));
startRecycleThread(); // lazy start thread
@@ -380,7 +384,7 @@ public class BufferStreamManager {
DataPartitionReader dataPartitionReader = streamReaders.get(streamId);
dataPartitionReader.release();
if (dataPartitionReader.isFinished()) {
- logger.info("release all for stream: {}", streamId);
+ logger.debug("release all for stream: {}", streamId);
removeStream(streamId);
streams.remove(streamId);
servingStreams.remove(streamId);
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/server/DataPartitionReader.java
b/common/src/main/java/org/apache/celeborn/common/network/server/DataPartitionReader.java
index 612712147..b91a077c0 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/server/DataPartitionReader.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/server/DataPartitionReader.java
@@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.meta.FileInfo;
import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
import org.apache.celeborn.common.network.protocol.ReadData;
+import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.common.network.server.memory.Recycler;
import org.apache.celeborn.common.network.server.memory.WrappedDataBuffer;
import org.apache.celeborn.common.util.Utils;
@@ -372,7 +373,11 @@ public class DataPartitionReader implements
Comparable<DataPartitionReader> {
private void notifyError(Throwable throwable) { // logs the failure, then sends it over the channel
logger.error("read error stream id {} message:{}", streamId,
throwable.getMessage(), throwable);
- // TODO notify client the exception
+ if (this.associatedChannel.isActive()) {
+ // If a stream has failed, send the exception on a best-effort basis and do not
expect response.
+ // Do not close the channel either, because multiple streams are using the very
same channel.
+ this.associatedChannel.writeAndFlush(new TransportableError(streamId,
throwable));
+ }
}
public long getPriority() {
@@ -445,4 +450,8 @@ public class DataPartitionReader implements
Comparable<DataPartitionReader> {
sb.append('}');
return sb.toString();
}
+ /** Returns the value of this reader's streamId field. */
+ public long getStreamId() {
+ return streamId;
+ }
}
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index cbd5f53a3..787ee7d96 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -74,6 +74,9 @@ class FetchHandler(val conf: TransportConf) extends
BaseMessageHandler with Logg
override def receive(client: TransportClient, msg: RequestMessage): Unit = {
msg match {
+ case r: BufferStreamEnd =>
+ rpcSource.updateMessageMetrics(r, 0)
+ handleEndStreamFromClient(client, r)
case r: ReadAddCredit =>
rpcSource.updateMessageMetrics(r, 0)
handleReadAddCredit(client, r)
@@ -187,6 +190,10 @@ class FetchHandler(val conf: TransportConf) extends
BaseMessageHandler with Logg
}
}
+ def handleEndStreamFromClient(client: TransportClient, req:
BufferStreamEnd): Unit = {
+ bufferStreamManager.notifyStreamEndByClient(req.getStreamId) // pass the ended stream's id to the manager
+ }
+
def handleReadAddCredit(client: TransportClient, req: ReadAddCredit): Unit =
{
bufferStreamManager.addCredit(req.getCredit, req.getStreamId)
}