Copilot commented on code in PR #11365: URL: https://github.com/apache/incubator-gluten/pull/11365#discussion_r2675446219
########## gluten-flink/runtime/src/main/java/org/apache/gluten/client/OperatorChainSliceGraphGenerator.java: ########## @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.client; + +import org.apache.gluten.streaming.api.operators.GlutenOneInputOperatorFactory; +import org.apache.gluten.streaming.api.operators.GlutenOperator; + +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.graph.StreamEdge; +import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +class OperatorChainSliceGraphGenerator { + private static final Logger LOG = LoggerFactory.getLogger(OperatorChainSliceGraphGenerator.class); + private OperatorChainSliceGraph chainSliceGraph = null; + private Map<Integer, List<Integer>> operatorParents; + private JobVertex jobVertex; + private Map<Integer, 
StreamConfig> chainedConfigs; + private final ClassLoader userClassloader; + + public OperatorChainSliceGraphGenerator(JobVertex jobVertex, ClassLoader userClassloader) { + this.operatorParents = new HashMap<>(); + this.jobVertex = jobVertex; + this.userClassloader = userClassloader; + } + + public OperatorChainSliceGraph getGraph() { + generateInternal(); + return chainSliceGraph; + } + + private void generateInternal() { + if (chainSliceGraph != null) { + return; + } + chainSliceGraph = new OperatorChainSliceGraph(); + + StreamConfig rootOpConfig = new StreamConfig(jobVertex.getConfiguration()); + + chainedConfigs = new HashMap<>(); + rootOpConfig + .getTransitiveChainedTaskConfigs(userClassloader) + .forEach( + (id, config) -> { + chainedConfigs.put(id, new StreamConfig(config.getConfiguration())); + }); + chainedConfigs.put(rootOpConfig.getVertexID(), rootOpConfig); + + collectOperatorParents(rootOpConfig, null); + + OperatorChainSlice chainSlice = new OperatorChainSlice(rootOpConfig.getVertexID()); + chainSlice.setOffloadable(isOffloadableOperator(rootOpConfig)); + chainSlice.getOperatorConfigs().add(rootOpConfig); + chainSliceGraph.addSlice(chainSlice.id(), chainSlice); + + advanceOperatorChainSlice(chainSlice, rootOpConfig); + } + + private void advanceOperatorChainSlice( + OperatorChainSlice chainSlice, StreamConfig currentOpConfig) { + List<StreamEdge> outputEdges = currentOpConfig.getChainedOutputs(userClassloader); + if (outputEdges == null || outputEdges.isEmpty()) { + return; + } + if (outputEdges.size() == 1) { + Integer targetId = outputEdges.get(0).getTargetId(); + StreamConfig childOpConfig = chainedConfigs.get(targetId); + // We don't coalesce operators into the same velox plan at present. Each operator is a + // separate velox plan. 
+ startNewOperatorChainSlice(chainSlice, childOpConfig); + } else { + for (StreamEdge edge : outputEdges) { + Integer targetId = edge.getTargetId(); + StreamConfig childOpConfig = chainedConfigs.get(targetId); + startNewOperatorChainSlice(chainSlice, childOpConfig); + } + } + } + + private void startNewOperatorChainSlice( + OperatorChainSlice parentChainSlice, StreamConfig childOpConfig) { + Boolean isFistVisit = false; + OperatorChainSlice childChainSlice = chainSliceGraph.getSlice(childOpConfig.getVertexID()); + if (childChainSlice == null) { + isFistVisit = true; + childChainSlice = new OperatorChainSlice(childOpConfig.getVertexID()); + } + + parentChainSlice.getOutputs().add(childChainSlice.id()); + childChainSlice.getInputs().add(parentChainSlice.id()); + // If this path has been visited, do not advance again. + if (isFistVisit) { Review Comment: Corrected spelling of 'isFistVisit' to 'isFirstVisit'. ```suggestion Boolean isFirstVisit = false; OperatorChainSlice childChainSlice = chainSliceGraph.getSlice(childOpConfig.getVertexID()); if (childChainSlice == null) { isFirstVisit = true; childChainSlice = new OperatorChainSlice(childOpConfig.getVertexID()); } parentChainSlice.getOutputs().add(childChainSlice.id()); childChainSlice.getInputs().add(parentChainSlice.id()); // If this path has been visited, do not advance again. if (isFirstVisit) { ``` ########## gluten-flink/runtime/src/main/java/org/apache/gluten/util/VectorInputBridge.java: ########## @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.table.runtime.operators; + +import org.apache.gluten.vectorized.FlinkRowToVLVectorConvertor; + +import io.github.zhztheplayer.velox4j.data.RowVector; +import io.github.zhztheplayer.velox4j.session.Session; +import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; +import io.github.zhztheplayer.velox4j.type.RowType; + +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.data.RowData; + +import org.apache.arrow.memory.BufferAllocator; + +import java.io.Serializable; + +// This bridge is used to convert the input data to RowVector. +public class VectorInputBridge<IN> implements Serializable { + private static final long serialVersionUID = 1L; + private final Class<IN> inClass; + private final String nodeId; + + public class RowVectorWrapper { + public RowVector rowVector; + public String nodeId; + + public RowVectorWrapper(RowVector rowVector, String nodeId) { + this.rowVector = rowVector; + this.nodeId = nodeId; + } + } + ; Review Comment: Remove the unnecessary semicolon after the closing brace of the inner class `RowVectorWrapper`. Semicolons after class declarations are not required in Java and should be removed for cleaner code. ```suggestion ``` ########## gluten-flink/runtime/src/main/java/org/apache/gluten/client/OffloadedJobGraphGenerator.java: ########## @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.client; + +import org.apache.gluten.streaming.api.operators.GlutenOneInputOperatorFactory; +import org.apache.gluten.streaming.api.operators.GlutenOperator; +import org.apache.gluten.streaming.api.operators.GlutenStreamSource; +import org.apache.gluten.table.runtime.keyselector.GlutenKeySelector; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenSourceFunction; +import org.apache.gluten.table.runtime.operators.GlutenTwoInputOperator; +import org.apache.gluten.table.runtime.typeutils.GlutenStatefulRecordSerializer; + +import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; +import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; +import io.github.zhztheplayer.velox4j.type.RowType; + +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.graph.StreamEdge; +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; +import 
org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.table.data.RowData; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/* + * If a operator is offloadable + * - update its input/output serializers as needed. + * - update its key selectors as needed. + * - coalesce it with its siblings as needed. Also need to update the stream edges. + */ +public class OffloadedJobGraphGenerator { + private static final Logger LOG = LoggerFactory.getLogger(OffloadedJobGraphGenerator.class); + private final JobGraph jobGraph; + private final ClassLoader userClassloader; + private boolean hasGenerated = false; + + public OffloadedJobGraphGenerator(JobGraph jobGraph, ClassLoader userClassloader) { + this.jobGraph = jobGraph; + this.userClassloader = userClassloader; + } + + public JobGraph generate() { + if (hasGenerated) { + throw new IllegalStateException("JobGraph has been generated."); + } + hasGenerated = true; + for (JobVertex jobVertex : jobGraph.getVertices()) { + offloadJobVertex(jobVertex); + } + return jobGraph; + } + + private void offloadJobVertex(JobVertex jobVertex) { + OperatorChainSliceGraphGenerator graphGenerator = + new OperatorChainSliceGraphGenerator(jobVertex, userClassloader); + OperatorChainSliceGraph chainSliceGraph = graphGenerator.getGraph(); + chainSliceGraph.dumpLog(); + + OperatorChainSlice sourceChainSlice = chainSliceGraph.getSourceSlice(); + OperatorChainSliceGraph targetChainSliceGraph = new OperatorChainSliceGraph(); + visitAndOffloadChainOperators(sourceChainSlice, chainSliceGraph, targetChainSliceGraph, 0); + visitAndUpdateStreamEdges(sourceChainSlice, chainSliceGraph, targetChainSliceGraph); + serializeAllOperatorsConfigs(targetChainSliceGraph); + + StreamConfig sourceConfig = 
sourceChainSlice.getOperatorConfigs().get(0); + StreamConfig targetSourceConfig = + targetChainSliceGraph.getSlice(sourceChainSlice.id()).getOperatorConfigs().get(0); + + Map<Integer, StreamConfig> chainedConfig = new HashMap<Integer, StreamConfig>(); + if (sourceChainSlice.isOffloadable()) { + // Update the first operator config + sourceConfig.setStreamOperatorFactory( + targetSourceConfig.getStreamOperatorFactory(userClassloader)); + List<StreamEdge> chainedOutputs = targetSourceConfig.getChainedOutputs(userClassloader); + sourceConfig.setChainedOutputs(targetSourceConfig.getChainedOutputs(userClassloader)); + + // Update the serializers and partitioners + sourceConfig.setTypeSerializerOut(targetSourceConfig.getTypeSerializerOut(userClassloader)); + sourceConfig.setInputs(targetSourceConfig.getInputs(userClassloader)); + KeySelector<?, ?> keySelector0 = targetSourceConfig.getStatePartitioner(0, userClassloader); + if (keySelector0 != null) { + sourceConfig.setStatePartitioner(0, keySelector0); + } + KeySelector<?, ?> keySelector1 = targetSourceConfig.getStatePartitioner(1, userClassloader); + if (keySelector1 != null) { + sourceConfig.setStatePartitioner(1, keySelector1); + } + + // The chained operators should be empty. 
+ } else { + List<StreamConfig> operatorConfigs = sourceChainSlice.getOperatorConfigs(); + for (int i = 1; i < operatorConfigs.size(); i++) { + StreamConfig opConfig = operatorConfigs.get(i); + chainedConfig.put(opConfig.getVertexID(), opConfig); + } + } + for (OperatorChainSlice chainSlice : targetChainSliceGraph.getSlices().values()) { + if (chainSlice.id().equals(sourceChainSlice.id())) { + continue; + } + List<StreamConfig> operatorConfigs = chainSlice.getOperatorConfigs(); + for (StreamConfig opConfig : operatorConfigs) { + chainedConfig.put(opConfig.getVertexID(), opConfig); + } + } + sourceConfig.setAndSerializeTransitiveChainedTaskConfigs(chainedConfig); + sourceConfig.serializeAllConfigs(); + } + + // Fold offloadable operator chain slice + private void visitAndOffloadChainOperators( + OperatorChainSlice chainSlice, + OperatorChainSliceGraph originalChainSliceGraph, + OperatorChainSliceGraph targetChainSliceGraph, + Integer chainedIndex) { + List<Integer> outputs = chainSlice.getOutputs(); + List<Integer> outputIndex = new ArrayList<>(); + OperatorChainSlice finalChainSlice = null; + if (chainSlice.isOffloadable()) { + finalChainSlice = + OffloadOperatorChainSlice(originalChainSliceGraph, chainSlice, chainedIndex); + chainedIndex = chainedIndex + 1; + } else { + finalChainSlice = applyUnoffloadableOperatorChainSlice(chainSlice, chainedIndex); + chainedIndex = chainedIndex + chainSlice.getOperatorConfigs().size(); + } + + finalChainSlice.getInputs().addAll(chainSlice.getInputs()); + finalChainSlice.getOutputs().addAll(chainSlice.getOutputs()); + targetChainSliceGraph.addSlice(chainSlice.id(), finalChainSlice); + + for (Integer outputChainIndex : outputs) { + OperatorChainSlice outputChainSlice = originalChainSliceGraph.getSlice(outputChainIndex); + OperatorChainSlice outputResultChainSlice = targetChainSliceGraph.getSlice(outputChainIndex); + if (outputResultChainSlice == null) { + visitAndOffloadChainOperators( + outputChainSlice, originalChainSliceGraph, 
targetChainSliceGraph, chainedIndex); + } + } + } + + // Keep the original operator chain slice as is. + private OperatorChainSlice applyUnoffloadableOperatorChainSlice( + OperatorChainSlice originalChainSlice, Integer chainedIndex) { + OperatorChainSlice finalChainSlice = new OperatorChainSlice(originalChainSlice.id()); + List<StreamConfig> operatorConfigs = originalChainSlice.getOperatorConfigs(); + for (StreamConfig opConfig : operatorConfigs) { + StreamConfig newOpConfig = new StreamConfig(new Configuration(opConfig.getConfiguration())); + newOpConfig.setChainIndex(chainedIndex); + finalChainSlice.getOperatorConfigs().add(newOpConfig); + } + finalChainSlice.setOffloadable(false); + return finalChainSlice; + } + + // Fold offloadable operator chain slice, and update the input/output channel serializers + private OperatorChainSlice OffloadOperatorChainSlice( + OperatorChainSliceGraph chainSliceGraph, + OperatorChainSlice originalChainSlice, + Integer chainedIndex) { + OperatorChainSlice finalChainSlice = new OperatorChainSlice(originalChainSlice.id()); + List<StreamConfig> operatorConfigs = originalChainSlice.getOperatorConfigs(); + + // May coalesce multiple operators into one in the future. + if (operatorConfigs.size() != 1) { + throw new UnsupportedOperationException( + "Only one operator is supported for offloaded operator chain slice."); + } + + StreamConfig originalOpConfig = operatorConfigs.get(0); + GlutenOperator originalOp = getGlutenOperator(originalOpConfig).get(); + StatefulPlanNode planNode = originalOp.getPlanNode(); + // Create a new operator config for the offloaded operator. + StreamConfig finalOpConfig = + new StreamConfig(new Configuration(originalOpConfig.getConfiguration())); + if (originalOp instanceof GlutenStreamSource) { + boolean couldOutputRowVector = couldOutputRowVector(originalChainSlice, chainSliceGraph); + Class<?> outClass = couldOutputRowVector ? 
StatefulRecord.class : RowData.class; + GlutenStreamSource newSourceOp = + new GlutenStreamSource( + new GlutenSourceFunction<>( + planNode, + originalOp.getOutputTypes(), + originalOp.getId(), + ((GlutenStreamSource) originalOp).getConnectorSplit(), + outClass)); + finalOpConfig.setStreamOperator(newSourceOp); + if (couldOutputRowVector) { + RowType rowType = originalOp.getOutputTypes().entrySet().iterator().next().getValue(); + finalOpConfig.setTypeSerializerOut( + new GlutenStatefulRecordSerializer(rowType, originalOp.getId())); + } + } else if (originalOp instanceof GlutenOneInputOperator) { + boolean couldOutputRowVector = couldOutputRowVector(originalChainSlice, chainSliceGraph); + boolean couldInputRowVector = couldInputRowVector(originalChainSlice, chainSliceGraph); + Class<?> inClass = couldInputRowVector ? StatefulRecord.class : RowData.class; + Class<?> outClass = couldOutputRowVector ? StatefulRecord.class : RowData.class; + GlutenOneInputOperator<?, ?> newOneInputOp = + new GlutenOneInputOperator<>( + planNode, + originalOp.getId(), + originalOp.getInputType(), + originalOp.getOutputTypes(), + inClass, + outClass, + originalOp.getDescription()); + finalOpConfig.setStreamOperator(newOneInputOp); + if (couldOutputRowVector) { + RowType rowType = originalOp.getOutputTypes().entrySet().iterator().next().getValue(); + finalOpConfig.setTypeSerializerOut( + new GlutenStatefulRecordSerializer(rowType, originalOp.getId())); + } + if (couldInputRowVector) { + finalOpConfig.setupNetworkInputs( + new GlutenStatefulRecordSerializer(originalOp.getInputType(), originalOp.getId())); + + // This node is the first node in the chain. If it has a state partitioner, we need to + // change it to GlutenKeySelector. 
+ KeySelector<?, ?> keySelector = originalOpConfig.getStatePartitioner(0, userClassloader); + if (keySelector != null) { + LOG.info( + "State partitioner ({}) found in the first node {}, change it to GlutenKeySelector.", + keySelector.getClass().getName(), + originalOp.getDescription()); + finalOpConfig.setStatePartitioner(0, new GlutenKeySelector()); + } + } + } else if (originalOp instanceof GlutenTwoInputOperator) { + GlutenTwoInputOperator<?, ?> twoInputOp = (GlutenTwoInputOperator<?, ?>) originalOp; + boolean couldOutputRowVector = couldOutputRowVector(originalChainSlice, chainSliceGraph); + boolean couldInputRowVector = couldInputRowVector(originalChainSlice, chainSliceGraph); + KeySelector<?, ?> keySelector0 = originalOpConfig.getStatePartitioner(0, userClassloader); + if (keySelector0 != null) { + LOG.info( + "State partitioner ({}) found in the first node {}, change it to GlutenKeySelector.", + keySelector0.getClass().getName(), + originalOp.getDescription()); + finalOpConfig.setStatePartitioner(0, new GlutenKeySelector()); + } + KeySelector<?, ?> keySelector1 = originalOpConfig.getStatePartitioner(1, userClassloader); + if (keySelector1 != null) { + LOG.info( + "State partitioner ({}) found in the second node {}, change it to GlutenKeySelector.", + keySelector1.getClass().getName(), + originalOp.getDescription()); + finalOpConfig.setStatePartitioner(1, new GlutenKeySelector()); + } + Class<?> inClass = couldInputRowVector ? StatefulRecord.class : RowData.class; + Class<?> outClass = couldOutputRowVector ? 
StatefulRecord.class : RowData.class; + GlutenTwoInputOperator<?, ?> newTwoInputOp = + new GlutenTwoInputOperator<>( + planNode, + twoInputOp.getLeftId(), + twoInputOp.getRightId(), + twoInputOp.getLeftInputType(), + twoInputOp.getRightInputType(), + twoInputOp.getOutputTypes(), + inClass, + outClass); + finalOpConfig.setStreamOperator(newTwoInputOp); + finalOpConfig.setStatePartitioner(0, new GlutenKeySelector()); + finalOpConfig.setStatePartitioner(1, new GlutenKeySelector()); + // Update the output channel serializer + if (couldOutputRowVector) { + RowType rowType = twoInputOp.getOutputTypes().entrySet().iterator().next().getValue(); + finalOpConfig.setTypeSerializerOut( + new GlutenStatefulRecordSerializer(rowType, twoInputOp.getId())); + } + // Update the input channel serializers + if (couldInputRowVector) { + finalOpConfig.setupNetworkInputs( + new GlutenStatefulRecordSerializer(twoInputOp.getLeftInputType(), twoInputOp.getId()), + new GlutenStatefulRecordSerializer(twoInputOp.getRightInputType(), twoInputOp.getId())); + } + } else { + throw new UnsupportedOperationException( + "Only GlutenStreamSource is supported for offloaded operator chain slice."); + } + + finalOpConfig.setChainIndex(chainedIndex); + finalChainSlice.getOperatorConfigs().add(finalOpConfig); + finalChainSlice.setOffloadable(true); + return finalChainSlice; + } + + private StreamNode mockStreamNode(StreamConfig streamConfig) { + return new StreamNode( + streamConfig.getVertexID(), + null, + null, + (StreamOperatorFactory<?>) streamConfig.getStreamOperatorFactory(userClassloader), + streamConfig.getOperatorName(), + null); + } + + // Incase the vetexs has been changed, update the stream edges. Review Comment: Corrected spelling of 'Incase' to 'In case' and 'vetexs' to 'vertices', and changed 'has' to 'have' to agree with the plural subject. 
```suggestion // In case the vertices have been changed, update the stream edges. ``` -- This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
