Copilot commented on code in PR #11365: URL: https://github.com/apache/incubator-gluten/pull/11365#discussion_r2684538572
########## gluten-flink/runtime/src/main/java/org/apache/gluten/util/VectorInputBridge.java: ########## @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.util; + +import io.github.zhztheplayer.velox4j.session.Session; +import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; +import io.github.zhztheplayer.velox4j.type.RowType; + +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.arrow.memory.BufferAllocator; + +import java.io.Serializable; + +/** + * Interface for converting input data from Flink StreamRecord to StatefulRecord. Different + * implementations handle different input types. + */ +public interface VectorInputBridge<IN> extends Serializable { + + /** + * Converts a StreamRecord to StatefulRecord based on the input type. + * + * @param inputData the input StreamRecord + * @param allocator buffer allocator for creating RowVector + * @param session Velox session + * @param inputType the RowType schema of the input + * @return StatefulRecord containing the converted or original data + */ + StatefulRecord convertToStatefulRecord( + StreamRecord<IN> inputData, BufferAllocator allocator, Session session, RowType inputType); + + /** Factory for creating VectorInputBridge instances based on input type. */ + class Factory { + /** + * Creates a VectorInputBridge instance for the given input class. + * + * @param inputClass the input class type + * @param nodeId the node ID for the bridge + * @param <IN> the input type + * @return a VectorInputBridge instance + * @throws UnsupportedOperationException if input class is not supported + */ + public static <IN> VectorInputBridge<IN> create(Class<IN> inputClass, String nodeId) { + if (inputClass.isAssignableFrom(org.apache.flink.table.data.RowData.class)) { + @SuppressWarnings("unchecked") + VectorInputBridge<IN> bridge = (VectorInputBridge<IN>) new RowDataInputBridge(nodeId); + return bridge; + } else if (inputClass.isAssignableFrom(StatefulRecord.class)) { + @SuppressWarnings("unchecked") + VectorInputBridge<IN> bridge = (VectorInputBridge<IN>) new StatefulRecordInputBridge(); + return bridge; + } else { + throw new UnsupportedOperationException("Unsupported input class: " + inputClass.getName()); + } + } + } + + /** + * Implementation for RowData input type. Converts RowData to RowVector and wraps in + * StatefulRecord. 
+ */ + class RowDataInputBridge implements VectorInputBridge<org.apache.flink.table.data.RowData> { + private static final long serialVersionUID = 1L; + private final String nodeId; + + public RowDataInputBridge(String nodeId) { + this.nodeId = nodeId; + } + + @Override + public StatefulRecord convertToStatefulRecord( + StreamRecord<org.apache.flink.table.data.RowData> inputData, + BufferAllocator allocator, + Session session, + RowType inputType) { + org.apache.flink.table.data.RowData rowData = inputData.getValue(); + var rowVector = + org.apache.gluten.vectorized.FlinkRowToVLVectorConvertor.fromRowData( + rowData, allocator, session, inputType); + StatefulRecord statefulRecord = new StatefulRecord(nodeId, rowVector.id(), 0, false, -1); + statefulRecord.setRowVector(rowVector); + return statefulRecord; + } + } + + /** Implementation for StatefulRecord input type. Passes through the StatefulRecord directly. */ + class StatefulRecordInputBridge implements VectorInputBridge<StatefulRecord> { + private static final long serialVersionUID = 1L; + + @Override + public StatefulRecord convertToStatefulRecord( + StreamRecord<StatefulRecord> inputData, + BufferAllocator allocator, + Session session, + RowType inputType) { + // Pass through the StatefulRecord directly. The original RowVector object is safe to close. Review Comment: This comment is ambiguous. It's unclear whether the comment means the caller should close the RowVector or that closing is handled elsewhere. Clarify the ownership and lifecycle management of the RowVector resource. ```suggestion // Pass through the StatefulRecord directly. This bridge does not take ownership of or close // the underlying RowVector; its lifecycle is managed by the code that created the // StatefulRecord. ``` ########## gluten-flink/runtime/src/main/java/org/apache/gluten/client/OperatorChainSliceGraph.java: ########## @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.client; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class OperatorChainSliceGraph { + private Map<Integer, OperatorChainSlice> slices; + + public OperatorChainSliceGraph() { + slices = new HashMap<>(); + } + + public void addSlice(Integer id, OperatorChainSlice chainSlice) { + slices.put(id, chainSlice); + } + + public OperatorChainSlice getSlice(Integer id) { + return slices.get(id); + } + + public OperatorChainSlice getSourceSlice() { + List<OperatorChainSlice> sourceCandidates = new ArrayList<>(); + + for (OperatorChainSlice chainSlice : slices.values()) { + if (chainSlice.getInputs().isEmpty()) { + sourceCandidates.add(chainSlice); + } + } + + if (sourceCandidates.isEmpty()) { + throw new IllegalStateException( + "No source suboperator chain found (no suboperator chain with empty inputs)"); + } else if (sourceCandidates.size() > 1) { + throw new IllegalStateException( + "Multiple source suboperator chains found: " + + sourceCandidates.size() + + " suboperator chains have empty inputs"); Review Comment: The term 'suboperator chain' is inconsistent with the class name 'OperatorChainSlice'. Consider using 'operator chain slice' consistently throughout the codebase for clarity. ```suggestion "No source operator chain slice found (no operator chain slice with empty inputs)"); } else if (sourceCandidates.size() > 1) { throw new IllegalStateException( "Multiple source operator chain slices found: " + sourceCandidates.size() + " operator chain slices have empty inputs"); ``` ########## gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenSourceFunction.java: ########## @@ -88,48 +96,139 @@ public ConnectorSplit getConnectorSplit() { } @Override - public void run(SourceContext<RowData> sourceContext) throws Exception { - LOG.debug("Running GlutenSourceFunction: " + Serde.toJson(planNode)); - memoryManager = MemoryManager.create(AllocationListener.NOOP); - session = Velox4j.newSession(memoryManager); - query = - new Query( - planNode, - VeloxQueryConfig.getConfig(getRuntimeContext()), - VeloxConnectorConfig.getConfig(getRuntimeContext())); - allocator = new RootAllocator(Long.MAX_VALUE); + public void open(Configuration parameters) throws Exception { + initSession(); + } - SerialTask task = session.queryOps().execute(query); - task.addSplit(id, split); - task.noMoreSplits(id); + @Override + public void run(SourceContext<OUT> sourceContext) throws Exception { while (isRunning) { UpIterator.State state = task.advance(); - if (state == UpIterator.State.AVAILABLE) { - final StatefulElement element = task.statefulGet(); - try (final RowVector outRv = element.asRecord().getRowVector()) { - List<RowData> rows = - FlinkRowToVLVectorConvertor.toRowData( - outRv, allocator, outputTypes.values().iterator().next()); - for (RowData row : rows) { - sourceContext.collect(row); - } - } - } else if (state == UpIterator.State.BLOCKED) { - LOG.debug("Get empty row"); + switch (state) { + case AVAILABLE: + processAvailableElement(sourceContext); + break; + case BLOCKED: + LOG.debug("Get empty row"); + break; + default: + LOG.info("Velox task finished"); + return; + } + taskMetrics.updateMetrics(task, id); + } + } + + /** Processes an available element from the task, handling records and watermarks. 
*/ + private void processAvailableElement(SourceContext<OUT> sourceContext) { + StatefulElement element = task.statefulGet(); + try { + if (element.isRecord()) { + processRecord(sourceContext, element.asRecord()); + } else if (element.isWatermark()) { + processWatermark(sourceContext, element.asWatermark()); } else { - LOG.info("Velox task finished"); - break; + LOG.debug("Ignoring element that is neither record nor watermark"); } + } finally { + element.close(); } + } + + /** Processes a StatefulRecord and collects it to the source context. */ + private void processRecord(SourceContext<OUT> sourceContext, StatefulRecord record) { + if (isRowDataOutput()) { + collectAsRowData(sourceContext, record); + } else if (isStatefulRecordOutput()) { + collectAsStatefulRecord(sourceContext, record); + } else { + throw new UnsupportedOperationException("Unsupported output class: " + outClass.getName()); + } + } + Review Comment: The output type checking logic is duplicated between `isRowDataOutput()`, `isStatefulRecordOutput()`, and the exception message. Consider refactoring to use a strategy pattern or a map of output type handlers to reduce code duplication and improve maintainability. ```suggestion /** Strategy for collecting a StatefulRecord into the source context. */ private interface OutputHandler { void collect(SourceContext<OUT> sourceContext, StatefulRecord record); } /** Returns the appropriate OutputHandler based on the configured output class. */ private OutputHandler getOutputHandler() { if (isRowDataOutput()) { return this::collectAsRowData; } if (isStatefulRecordOutput()) { return this::collectAsStatefulRecord; } throw unsupportedOutputException(); } /** Creates an exception for unsupported output class configurations. */ private UnsupportedOperationException unsupportedOutputException() { return new UnsupportedOperationException("Unsupported output class: " + outClass.getName()); } /** Processes a StatefulRecord and collects it to the source context. 
*/ private void processRecord(SourceContext<OUT> sourceContext, StatefulRecord record) { OutputHandler handler = getOutputHandler(); handler.collect(sourceContext, record); } ``` ########## gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenOneInputOperator.java: ########## @@ -104,43 +141,66 @@ public void open() throws Exception { mockInput, VeloxQueryConfig.getConfig(getRuntimeContext()), VeloxConnectorConfig.getConfig(getRuntimeContext())); - allocator = new RootAllocator(Long.MAX_VALUE); - task = session.queryOps().execute(query); - ExternalStreamConnectorSplit split = - new ExternalStreamConnectorSplit("connector-external-stream", inputQueue.id()); - task.addSplit(id, split); + task = sessionResource.getSession().queryOps().execute(query); + task.addSplit( + id, new ExternalStreamConnectorSplit("connector-external-stream", inputQueue.id())); task.noMoreSplits(id); } @Override - public void processElement(StreamRecord<RowData> element) { - try (RowVector inRv = - FlinkRowToVLVectorConvertor.fromRowData( - element.getValue(), allocator, session, inputType)) { - inputQueue.put(inRv); + public void open() throws Exception { + super.open(); + initSession(); + } + + @Override + public void processElement(StreamRecord<IN> element) { + if (element.getValue() == null) { + return; + } + StatefulRecord statefulRecord = + inputBridge.convertToStatefulRecord( + element, sessionResource.getAllocator(), sessionResource.getSession(), inputType); + inputQueue.put(statefulRecord.getRowVector()); + + // Only the rowvectors generated by this operator should be closed here. + if (getId().equals(statefulRecord.getNodeId())) { + statefulRecord.close(); + } + processElementInternal(); + } + + private void processElementInternal() { + while (true) { UpIterator.State state = task.advance(); if (state == UpIterator.State.AVAILABLE) { final StatefulElement statefulElement = task.statefulGet(); - - try (RowVector outRv = statefulElement.asRecord().getRowVector()) { - List<RowData> rows = - FlinkRowToVLVectorConvertor.toRowData( - outRv, allocator, outputTypes.values().iterator().next()); - for (RowData row : rows) { - output.collect(outElement.replace(row)); - } + if (statefulElement.isWatermark()) { + StatefulWatermark watermark = statefulElement.asWatermark(); + output.emitWatermark(new Watermark(watermark.getTimestamp())); + } else { + outputBridge.collect( + output, statefulElement.asRecord(), sessionResource.getAllocator(), outputType); } + statefulElement.close(); + } else { + break; } } } @Override public void close() throws Exception { - inputQueue.close(); - task.close(); - session.close(); - memoryManager.close(); - allocator.close(); + if (inputQueue != null) { + inputQueue.noMoreInput(); Review Comment: The `noMoreInput()` call should be wrapped in a null check similar to the other cleanup operations. If `inputQueue` is null, this will throw a NullPointerException. 
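For illustration, a minimal sketch of the guarded cleanup this comment asks for. It assumes the fields visible in the diff (`inputQueue`, `task`, `sessionResource`) and that `GlutenSessionResource` exposes a `close()`; it is a sketch of the pattern, not the PR's actual `close()` body:
```java
// Illustrative only: each resource may be null if open()/initSession() failed before
// initializing it, so every cleanup step is guarded independently. Assumes
// GlutenSessionResource exposes close(); inputQueue.noMoreInput()/close() and
// task.close() appear in the quoted diff.
@Override
public void close() throws Exception {
  if (inputQueue != null) {
    inputQueue.noMoreInput();
    inputQueue.close();
  }
  if (task != null) {
    task.close();
  }
  if (sessionResource != null) {
    sessionResource.close();
  }
  super.close();
}
```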
########## gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenOneInputOperator.java: ########## @@ -19,75 +19,112 @@ import org.apache.gluten.streaming.api.operators.GlutenOperator; import org.apache.gluten.table.runtime.config.VeloxConnectorConfig; import org.apache.gluten.table.runtime.config.VeloxQueryConfig; -import org.apache.gluten.vectorized.FlinkRowToVLVectorConvertor; +import org.apache.gluten.util.VectorInputBridge; +import org.apache.gluten.util.VectorOutputBridge; -import io.github.zhztheplayer.velox4j.Velox4j; import io.github.zhztheplayer.velox4j.connector.ExternalStreamConnectorSplit; import io.github.zhztheplayer.velox4j.connector.ExternalStreamTableHandle; import io.github.zhztheplayer.velox4j.connector.ExternalStreams; -import io.github.zhztheplayer.velox4j.data.RowVector; import io.github.zhztheplayer.velox4j.iterator.UpIterator; -import io.github.zhztheplayer.velox4j.memory.AllocationListener; -import io.github.zhztheplayer.velox4j.memory.MemoryManager; import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; import io.github.zhztheplayer.velox4j.plan.TableScanNode; import io.github.zhztheplayer.velox4j.query.Query; import io.github.zhztheplayer.velox4j.query.SerialTask; import io.github.zhztheplayer.velox4j.serde.Serde; -import io.github.zhztheplayer.velox4j.session.Session; import io.github.zhztheplayer.velox4j.stateful.StatefulElement; +import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; +import io.github.zhztheplayer.velox4j.stateful.StatefulWatermark; import io.github.zhztheplayer.velox4j.type.RowType; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.table.data.RowData; import org.apache.flink.table.runtime.operators.TableStreamOperator; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.List; import java.util.Map; /** Calculate operator in gluten, which will call Velox to run. 
*/ -public class GlutenOneInputOperator extends TableStreamOperator<RowData> - implements OneInputStreamOperator<RowData, RowData>, GlutenOperator { +public class GlutenOneInputOperator<IN, OUT> extends TableStreamOperator<OUT> + implements OneInputStreamOperator<IN, OUT>, GlutenOperator { private static final Logger LOG = LoggerFactory.getLogger(GlutenOneInputOperator.class); private final StatefulPlanNode glutenPlan; private final String id; private final RowType inputType; private final Map<String, RowType> outputTypes; + private final RowType outputType; + private final String description; - private StreamRecord<RowData> outElement = null; - - private MemoryManager memoryManager; - private Session session; - private Query query; - private ExternalStreams.BlockingQueue inputQueue; - private BufferAllocator allocator; - private SerialTask task; + private transient GlutenSessionResource sessionResource; + private transient Query query; + private transient ExternalStreams.BlockingQueue inputQueue; + private transient SerialTask task; + private final Class<IN> inClass; + private final Class<OUT> outClass; + private transient VectorInputBridge<IN> inputBridge; + private transient VectorOutputBridge<OUT> outputBridge; public GlutenOneInputOperator( - StatefulPlanNode plan, String id, RowType inputType, Map<String, RowType> outputTypes) { + StatefulPlanNode plan, + String id, + RowType inputType, + Map<String, RowType> outputTypes, + Class<IN> inClass, + Class<OUT> outClass, + String description) { + if (plan == null) { + throw new IllegalArgumentException("plan is null"); Review Comment: The error message should be more descriptive. Consider including context about which parameter is null and what the expected value should be. For example: 'StatefulPlanNode cannot be null in GlutenOneInputOperator constructor'. 
```suggestion throw new IllegalArgumentException( "StatefulPlanNode 'plan' cannot be null in GlutenOneInputOperator constructor"); ``` ########## gluten-flink/runtime/src/main/java/org/apache/gluten/table/runtime/operators/GlutenOneInputOperator.java: ########## @@ -19,75 +19,112 @@ import org.apache.gluten.streaming.api.operators.GlutenOperator; import org.apache.gluten.table.runtime.config.VeloxConnectorConfig; import org.apache.gluten.table.runtime.config.VeloxQueryConfig; -import org.apache.gluten.vectorized.FlinkRowToVLVectorConvertor; +import org.apache.gluten.util.VectorInputBridge; +import org.apache.gluten.util.VectorOutputBridge; -import io.github.zhztheplayer.velox4j.Velox4j; import io.github.zhztheplayer.velox4j.connector.ExternalStreamConnectorSplit; import io.github.zhztheplayer.velox4j.connector.ExternalStreamTableHandle; import io.github.zhztheplayer.velox4j.connector.ExternalStreams; -import io.github.zhztheplayer.velox4j.data.RowVector; import io.github.zhztheplayer.velox4j.iterator.UpIterator; -import io.github.zhztheplayer.velox4j.memory.AllocationListener; -import io.github.zhztheplayer.velox4j.memory.MemoryManager; import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; import io.github.zhztheplayer.velox4j.plan.TableScanNode; import io.github.zhztheplayer.velox4j.query.Query; import io.github.zhztheplayer.velox4j.query.SerialTask; import io.github.zhztheplayer.velox4j.serde.Serde; -import io.github.zhztheplayer.velox4j.session.Session; import io.github.zhztheplayer.velox4j.stateful.StatefulElement; +import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; +import io.github.zhztheplayer.velox4j.stateful.StatefulWatermark; import io.github.zhztheplayer.velox4j.type.RowType; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; -import org.apache.flink.table.data.RowData; import org.apache.flink.table.runtime.operators.TableStreamOperator; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.List; import java.util.Map; /** Calculate operator in gluten, which will call Velox to run. 
*/ -public class GlutenOneInputOperator extends TableStreamOperator<RowData> - implements OneInputStreamOperator<RowData, RowData>, GlutenOperator { +public class GlutenOneInputOperator<IN, OUT> extends TableStreamOperator<OUT> + implements OneInputStreamOperator<IN, OUT>, GlutenOperator { private static final Logger LOG = LoggerFactory.getLogger(GlutenOneInputOperator.class); private final StatefulPlanNode glutenPlan; private final String id; private final RowType inputType; private final Map<String, RowType> outputTypes; + private final RowType outputType; + private final String description; - private StreamRecord<RowData> outElement = null; - - private MemoryManager memoryManager; - private Session session; - private Query query; - private ExternalStreams.BlockingQueue inputQueue; - private BufferAllocator allocator; - private SerialTask task; + private transient GlutenSessionResource sessionResource; + private transient Query query; + private transient ExternalStreams.BlockingQueue inputQueue; + private transient SerialTask task; + private final Class<IN> inClass; + private final Class<OUT> outClass; + private transient VectorInputBridge<IN> inputBridge; + private transient VectorOutputBridge<OUT> outputBridge; public GlutenOneInputOperator( - StatefulPlanNode plan, String id, RowType inputType, Map<String, RowType> outputTypes) { + StatefulPlanNode plan, + String id, + RowType inputType, + Map<String, RowType> outputTypes, + Class<IN> inClass, + Class<OUT> outClass, + String description) { + if (plan == null) { + throw new IllegalArgumentException("plan is null"); + } this.glutenPlan = plan; this.id = id; this.inputType = inputType; this.outputTypes = outputTypes; + this.inClass = inClass; + this.outClass = outClass; + this.inputBridge = VectorInputBridge.Factory.create(inClass, getId()); + this.outputBridge = VectorOutputBridge.Factory.create(outClass); + this.outputType = outputTypes.values().iterator().next(); + this.description = description; + } + + public GlutenOneInputOperator( + StatefulPlanNode plan, + String id, + RowType inputType, + Map<String, RowType> outputTypes, + Class<IN> inClass, + Class<OUT> outClass) { + this(plan, id, inputType, outputTypes, inClass, outClass, ""); } @Override - public void open() throws Exception { - super.open(); - outElement = new StreamRecord(null); - memoryManager = MemoryManager.create(AllocationListener.NOOP); - session = Velox4j.newSession(memoryManager); + public String getDescription() { + return description; } - inputQueue = session.externalStreamOps().newBlockingQueue(); + void initSession() { + if (sessionResource != null) { + return; + } + if (inputBridge == null) { + inputBridge = VectorInputBridge.Factory.create(inClass, getId()); + } + if (outputBridge == null) { + outputBridge = VectorOutputBridge.Factory.create(outClass); + } + sessionResource = new GlutenSessionResource(); + inputQueue = sessionResource.getSession().externalStreamOps().newBlockingQueue(); // Add a mock input, as Velox does not allow the source to be empty. + if (inputType == null) { + throw new IllegalArgumentException("inputType is null. plan is " + Serde.toJson(glutenPlan)); Review Comment: The error message should not serialize the entire plan as it may be large and contain sensitive information. Consider logging the plan separately if needed for debugging, and keep the exception message concise: 'inputType cannot be null'. ```suggestion if (LOG.isDebugEnabled()) { LOG.debug("inputType is null. 
Plan: {}", Serde.toJson(glutenPlan)); } throw new IllegalArgumentException("inputType cannot be null"); ``` ########## gluten-flink/runtime/src/main/java/org/apache/gluten/client/OffloadedJobGraphGenerator.java: ########## @@ -0,0 +1,611 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.client; + +import org.apache.gluten.streaming.api.operators.GlutenOperator; +import org.apache.gluten.streaming.api.operators.GlutenStreamSource; +import org.apache.gluten.table.runtime.keyselector.GlutenKeySelector; +import org.apache.gluten.table.runtime.operators.GlutenOneInputOperator; +import org.apache.gluten.table.runtime.operators.GlutenSourceFunction; +import org.apache.gluten.table.runtime.operators.GlutenTwoInputOperator; +import org.apache.gluten.table.runtime.typeutils.GlutenStatefulRecordSerializer; +import org.apache.gluten.util.Utils; + +import io.github.zhztheplayer.velox4j.plan.StatefulPlanNode; +import io.github.zhztheplayer.velox4j.stateful.StatefulRecord; +import io.github.zhztheplayer.velox4j.type.RowType; + +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.jobgraph.IntermediateDataSet; +import org.apache.flink.runtime.jobgraph.JobEdge; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobVertex; +import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.graph.StreamEdge; +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.operators.StreamOperatorFactory; +import org.apache.flink.table.data.RowData; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/* + * Generates an offloaded JobGraph by transforming offloadable operators to Gluten operators. + * Main workflow: + * 1. For each JobVertex, generate an OperatorChainSliceGraph to identify offloadable slices. + * 2. Recursively visit chain slices: create offloaded operators for offloadable slices, + * keep original operators for unoffloadable ones. + * 3. For offloadable operators: update input/output serializers and state partitioners + * based on upstream/downstream operator capabilities. + * 4. Update stream edges and serialize all operator configurations. 
+ */ +public class OffloadedJobGraphGenerator { + private static final Logger LOG = LoggerFactory.getLogger(OffloadedJobGraphGenerator.class); + private final JobGraph jobGraph; + private final ClassLoader userClassloader; + private boolean hasGenerated = false; + + public OffloadedJobGraphGenerator(JobGraph jobGraph, ClassLoader userClassloader) { + this.jobGraph = jobGraph; + this.userClassloader = userClassloader; + } + + public JobGraph generate() { + if (hasGenerated) { + throw new IllegalStateException("JobGraph has been generated."); + } + hasGenerated = true; + for (JobVertex jobVertex : jobGraph.getVertices()) { + offloadJobVertex(jobVertex); + } + return jobGraph; + } + + private void offloadJobVertex(JobVertex jobVertex) { + OperatorChainSliceGraphGenerator graphGenerator = + new OperatorChainSliceGraphGenerator(jobVertex, userClassloader); + OperatorChainSliceGraph chainSliceGraph = graphGenerator.getGraph(); + LOG.info("OperatorChainSliceGraph:\n{}", chainSliceGraph); + + OperatorChainSlice sourceChainSlice = chainSliceGraph.getSourceSlice(); + OperatorChainSliceGraph offloadedChainSliceGraph = new OperatorChainSliceGraph(); + visitAndOffloadChainOperators( + sourceChainSlice, chainSliceGraph, offloadedChainSliceGraph, 0, jobVertex); + visitAndUpdateStreamEdges(sourceChainSlice, chainSliceGraph, offloadedChainSliceGraph); + serializeAllOperatorsConfigs(offloadedChainSliceGraph); + + StreamConfig sourceConfig = sourceChainSlice.getOperatorConfigs().get(0); + StreamConfig offloadedSourceConfig = + offloadedChainSliceGraph.getSlice(sourceChainSlice.id()).getOperatorConfigs().get(0); + + Map<Integer, StreamConfig> chainedConfigs = + collectChainedConfigs(sourceChainSlice, offloadedChainSliceGraph); + updateSourceConfigIfOffloadable( + sourceConfig, offloadedSourceConfig, sourceChainSlice, chainedConfigs); + sourceConfig.setAndSerializeTransitiveChainedTaskConfigs(chainedConfigs); + sourceConfig.serializeAllConfigs(); + } + + // Process and offload operator chain slices recursively + private void visitAndOffloadChainOperators( + OperatorChainSlice sourceChainSlice, + OperatorChainSliceGraph sourceChainSliceGraph, + OperatorChainSliceGraph offloadedChainSliceGraph, + Integer chainIndex, + JobVertex jobVertex) { + OperatorChainSlice processedChainSlice; + if (sourceChainSlice.isOffloadable()) { + processedChainSlice = + createOffloadedOperatorChainSlice( + sourceChainSliceGraph, sourceChainSlice, chainIndex, jobVertex); + chainIndex = chainIndex + 1; + } else { + processedChainSlice = createUnoffloadableOperatorChainSlice(sourceChainSlice, chainIndex); + chainIndex = chainIndex + sourceChainSlice.getOperatorConfigs().size(); + } + + processedChainSlice.getInputs().addAll(sourceChainSlice.getInputs()); + processedChainSlice.getOutputs().addAll(sourceChainSlice.getOutputs()); + offloadedChainSliceGraph.addSlice(sourceChainSlice.id(), processedChainSlice); + + // Recursively process downstream chain slices + for (Integer downstreamSliceId : sourceChainSlice.getOutputs()) { + OperatorChainSlice downstreamSourceSlice = sourceChainSliceGraph.getSlice(downstreamSliceId); + OperatorChainSlice downstreamProcessedSlice = + offloadedChainSliceGraph.getSlice(downstreamSliceId); + if (downstreamProcessedSlice == null) { + visitAndOffloadChainOperators( + downstreamSourceSlice, + sourceChainSliceGraph, + offloadedChainSliceGraph, + chainIndex, + jobVertex); + } + } + } + + // Keep the original operator chain slice as is. 
+ private OperatorChainSlice createUnoffloadableOperatorChainSlice( + OperatorChainSlice sourceChainSlice, Integer chainIndex) { + OperatorChainSlice unoffloadableChainSlice = new OperatorChainSlice(sourceChainSlice.id()); + List<StreamConfig> operatorConfigs = sourceChainSlice.getOperatorConfigs(); + for (StreamConfig opConfig : operatorConfigs) { + StreamConfig newOpConfig = new StreamConfig(new Configuration(opConfig.getConfiguration())); + newOpConfig.setChainIndex(chainIndex); + unoffloadableChainSlice.getOperatorConfigs().add(newOpConfig); + } + unoffloadableChainSlice.setOffloadable(false); + return unoffloadableChainSlice; + } + + // Create offloadable operator chain slice, and update the input/output channel serializers + private OperatorChainSlice createOffloadedOperatorChainSlice( + OperatorChainSliceGraph chainSliceGraph, + OperatorChainSlice sourceChainSlice, + Integer chainIndex, + JobVertex jobVertex) { + OperatorChainSlice offloadedChainSlice = new OperatorChainSlice(sourceChainSlice.id()); + List<StreamConfig> operatorConfigs = sourceChainSlice.getOperatorConfigs(); + + // May coalesce multiple operators into one in the future. + if (operatorConfigs.size() != 1) { + throw new UnsupportedOperationException( + "Only one operator is supported for offloaded operator chain slice."); + } + + StreamConfig sourceOpConfig = operatorConfigs.get(0); + GlutenOperator sourceOperator = Utils.getGlutenOperator(sourceOpConfig, userClassloader).get(); + StatefulPlanNode planNode = sourceOperator.getPlanNode(); + // Create a new operator config for the offloaded operator. + StreamConfig offloadedOpConfig = + new StreamConfig(new Configuration(sourceOpConfig.getConfiguration())); + if (sourceOperator instanceof GlutenStreamSource) { + boolean canOutputRowVector = canOutputRowVector(sourceChainSlice, chainSliceGraph, jobVertex); + Class<?> outClass = canOutputRowVector ? 
StatefulRecord.class : RowData.class; + GlutenStreamSource newSourceOp = + new GlutenStreamSource( + new GlutenSourceFunction<>( + planNode, + sourceOperator.getOutputTypes(), + sourceOperator.getId(), + ((GlutenStreamSource) sourceOperator).getConnectorSplit(), + outClass)); + offloadedOpConfig.setStreamOperator(newSourceOp); + if (canOutputRowVector) { + setOffloadedOutputSerializer(offloadedOpConfig, sourceOperator); + } + } else if (sourceOperator instanceof GlutenOneInputOperator) { + createOffloadedOneInputOperator( + sourceChainSlice, + chainSliceGraph, + jobVertex, + planNode, + (GlutenOneInputOperator<?, ?>) sourceOperator, + sourceOpConfig, + offloadedOpConfig); + } else if (sourceOperator instanceof GlutenTwoInputOperator) { + createOffloadedTwoInputOperator( + sourceChainSlice, + chainSliceGraph, + jobVertex, + planNode, + (GlutenTwoInputOperator<?, ?>) sourceOperator, + sourceOpConfig, + offloadedOpConfig); + } else { + throw new UnsupportedOperationException( + "Unsupported operator type for offloading: " + sourceOperator.getClass().getName()); + } + + offloadedOpConfig.setChainIndex(chainIndex); + offloadedChainSlice.getOperatorConfigs().add(offloadedOpConfig); + offloadedChainSlice.setOffloadable(true); + return offloadedChainSlice; + } + + private void createOffloadedOneInputOperator( + OperatorChainSlice sourceChainSlice, + OperatorChainSliceGraph chainSliceGraph, + JobVertex jobVertex, + StatefulPlanNode planNode, + GlutenOneInputOperator<?, ?> sourceOperator, + StreamConfig sourceOpConfig, + StreamConfig offloadedOpConfig) { + boolean canOutputRowVector = canOutputRowVector(sourceChainSlice, chainSliceGraph, jobVertex); + boolean canInputRowVector = canInputRowVector(sourceChainSlice, chainSliceGraph, jobVertex); + Class<?> inClass = canInputRowVector ? StatefulRecord.class : RowData.class; + Class<?> outClass = canOutputRowVector ? StatefulRecord.class : RowData.class; + GlutenOneInputOperator<?, ?> newOneInputOp = + new GlutenOneInputOperator<>( + planNode, + sourceOperator.getId(), + sourceOperator.getInputType(), + sourceOperator.getOutputTypes(), + inClass, + outClass, + sourceOperator.getDescription()); + offloadedOpConfig.setStreamOperator(newOneInputOp); + if (canOutputRowVector) { + setOffloadedOutputSerializer(offloadedOpConfig, sourceOperator); + } + if (canInputRowVector) { + setOffloadedInputSerializer(offloadedOpConfig, sourceOperator); + setOffloadedStatePartitioner( + sourceOpConfig, offloadedOpConfig, 0, sourceOperator.getDescription()); + } + } + + private void createOffloadedTwoInputOperator( + OperatorChainSlice sourceChainSlice, + OperatorChainSliceGraph chainSliceGraph, + JobVertex jobVertex, + StatefulPlanNode planNode, + GlutenTwoInputOperator<?, ?> sourceOperator, + StreamConfig sourceOpConfig, + StreamConfig offloadedOpConfig) { + boolean canOutputRowVector = canOutputRowVector(sourceChainSlice, chainSliceGraph, jobVertex); + boolean canInputRowVector = canInputRowVector(sourceChainSlice, chainSliceGraph, jobVertex); + setOffloadedStatePartitioner( + sourceOpConfig, offloadedOpConfig, 0, sourceOperator.getDescription()); + setOffloadedStatePartitioner( + sourceOpConfig, offloadedOpConfig, 1, sourceOperator.getDescription()); + Class<?> inClass = canInputRowVector ? StatefulRecord.class : RowData.class; + Class<?> outClass = canOutputRowVector ? 
StatefulRecord.class : RowData.class; + GlutenTwoInputOperator<?, ?> newTwoInputOp = + new GlutenTwoInputOperator<>( + planNode, + sourceOperator.getLeftId(), + sourceOperator.getRightId(), + sourceOperator.getLeftInputType(), + sourceOperator.getRightInputType(), + sourceOperator.getOutputTypes(), + inClass, + outClass); + offloadedOpConfig.setStreamOperator(newTwoInputOp); + offloadedOpConfig.setStatePartitioner(0, new GlutenKeySelector()); + offloadedOpConfig.setStatePartitioner(1, new GlutenKeySelector()); + if (canOutputRowVector) { + setOffloadedOutputSerializer(offloadedOpConfig, sourceOperator); + } + if (canInputRowVector) { + setOffloadedInputSerializersForTwoInputOperator(offloadedOpConfig, sourceOperator); + } + } + + private void setOffloadedOutputSerializer(StreamConfig opConfig, GlutenOperator operator) { + RowType rowType = operator.getOutputTypes().entrySet().iterator().next().getValue(); + opConfig.setTypeSerializerOut(new GlutenStatefulRecordSerializer(rowType, operator.getId())); + } + + private void setOffloadedInputSerializer(StreamConfig opConfig, GlutenOperator operator) { + opConfig.setupNetworkInputs( + new GlutenStatefulRecordSerializer(operator.getInputType(), operator.getId())); + } + + private void setOffloadedInputSerializersForTwoInputOperator( + StreamConfig opConfig, GlutenTwoInputOperator<?, ?> operator) { + opConfig.setupNetworkInputs( + new GlutenStatefulRecordSerializer(operator.getLeftInputType(), operator.getId()), + new GlutenStatefulRecordSerializer(operator.getRightInputType(), operator.getId())); + } + + private void setOffloadedStatePartitioner( + StreamConfig sourceOpConfig, + StreamConfig offloadedOpConfig, + int inputIndex, + String operatorDescription) { + KeySelector<?, ?> keySelector = sourceOpConfig.getStatePartitioner(inputIndex, userClassloader); + if (keySelector != null) { + LOG.info( + "State partitioner ({}) found in input {} of operator {}, change it to GlutenKeySelector.", + keySelector.getClass().getName(), + inputIndex, + operatorDescription); + offloadedOpConfig.setStatePartitioner(inputIndex, new GlutenKeySelector()); + } + } + + private StreamConfig findLastOperatorInChain(JobVertex vertex) { + StreamConfig rootStreamConfig = new StreamConfig(vertex.getConfiguration()); + Map<Integer, StreamConfig> chainedConfigs = + rootStreamConfig.getTransitiveChainedTaskConfigs(userClassloader); + chainedConfigs.put(rootStreamConfig.getVertexID(), rootStreamConfig); + + // Find the last operator (the one with no chained outputs) + for (StreamConfig config : chainedConfigs.values()) { + List<StreamEdge> chainedOutputs = config.getChainedOutputs(userClassloader); + if (chainedOutputs == null || chainedOutputs.isEmpty()) { + return config; + } + } + + // If no last operator found, use the root config + return rootStreamConfig; + } + + private StreamNode mockStreamNode(StreamConfig streamConfig) { + return new StreamNode( + streamConfig.getVertexID(), + null, + null, + (StreamOperatorFactory<?>) streamConfig.getStreamOperatorFactory(userClassloader), + streamConfig.getOperatorName(), + null); + } + + // Update stream edges when vertices have been changed due to offloading + private void visitAndUpdateStreamEdges( + OperatorChainSlice sourceChainSlice, + OperatorChainSliceGraph sourceChainSliceGraph, + OperatorChainSliceGraph offloadedChainSliceGraph) { + OperatorChainSlice offloadedChainSlice = + offloadedChainSliceGraph.getSlice(sourceChainSlice.id()); + if (offloadedChainSlice.isOffloadable()) { + updateStreamEdgesForOffloadedSlice( + 
sourceChainSlice, sourceChainSliceGraph, offloadedChainSliceGraph, offloadedChainSlice); } + + // Recursively update downstream chain slices + for (Integer downstreamSliceId : sourceChainSlice.getOutputs()) { + visitAndUpdateStreamEdges( + sourceChainSliceGraph.getSlice(downstreamSliceId), + sourceChainSliceGraph, + offloadedChainSliceGraph); + } + } + + private void updateStreamEdgesForOffloadedSlice( + OperatorChainSlice sourceChainSlice, + OperatorChainSliceGraph sourceChainSliceGraph, + OperatorChainSliceGraph offloadedChainSliceGraph, + OperatorChainSlice offloadedChainSlice) { + List<Integer> downstreamSliceIds = sourceChainSlice.getOutputs(); + if (downstreamSliceIds.isEmpty()) { + StreamConfig offloadedOpConfig = offloadedChainSlice.getOperatorConfigs().get(0); + offloadedOpConfig.setChainedOutputs(new ArrayList<>()); + return; + } + + List<StreamEdge> newOutputEdges = new ArrayList<>(); + List<StreamConfig> sourceOperatorConfigs = sourceChainSlice.getOperatorConfigs(); + StreamConfig lastSourceOpConfig = sourceOperatorConfigs.get(sourceOperatorConfigs.size() - 1); + List<StreamEdge> originalOutputEdges = lastSourceOpConfig.getChainedOutputs(userClassloader); + + for (int i = 0; i < downstreamSliceIds.size(); i++) { + Integer downstreamSliceId = downstreamSliceIds.get(i); + OperatorChainSlice downstreamOffloadedSlice = + offloadedChainSliceGraph.getSlice(downstreamSliceId); + StreamConfig downstreamOpConfig = downstreamOffloadedSlice.getOperatorConfigs().get(0); + StreamEdge originalEdge = originalOutputEdges.get(i); + StreamEdge newEdge = createStreamEdge(originalEdge, downstreamOpConfig, offloadedChainSlice); + newOutputEdges.add(newEdge); + } + + StreamConfig offloadedOpConfig = offloadedChainSlice.getOperatorConfigs().get(0); + offloadedOpConfig.setChainedOutputs(newOutputEdges); + } + + private StreamEdge createStreamEdge( + StreamEdge originalEdge, + StreamConfig downstreamOpConfig, + OperatorChainSlice offloadedChainSlice) { + StreamConfig sourceOpConfig = offloadedChainSlice.getOperatorConfigs().get(0); + return new StreamEdge( + mockStreamNode(sourceOpConfig), + mockStreamNode(downstreamOpConfig), + originalEdge.getTypeNumber(), + originalEdge.getBufferTimeout(), + originalEdge.getPartitioner(), + originalEdge.getOutputTag(), + originalEdge.getExchangeMode(), + 0, Review Comment: The magic number '0' for buffer size in StreamEdge creation is unclear. Add a comment explaining why buffer size is set to 0 or use a named constant to clarify the intent.
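A sketch of the named-constant option. The constant name is hypothetical, the "buffer size" reading of the parameter follows the comment above rather than the PR, and the constructor arguments after the literal are truncated in the quoted diff, so they stay elided here:
```java
// Sketch only. DEFAULT_EDGE_BUFFER_SIZE is a hypothetical name; reading this StreamEdge
// parameter as a buffer size follows the review comment above and is an assumption
// about the constructor, not something stated in the PR.
private static final int DEFAULT_EDGE_BUFFER_SIZE = 0;

private StreamEdge createStreamEdge(
    StreamEdge originalEdge,
    StreamConfig downstreamOpConfig,
    OperatorChainSlice offloadedChainSlice) {
  StreamConfig sourceOpConfig = offloadedChainSlice.getOperatorConfigs().get(0);
  return new StreamEdge(
      mockStreamNode(sourceOpConfig),
      mockStreamNode(downstreamOpConfig),
      originalEdge.getTypeNumber(),
      originalEdge.getBufferTimeout(),
      originalEdge.getPartitioner(),
      originalEdge.getOutputTag(),
      originalEdge.getExchangeMode(),
      DEFAULT_EDGE_BUFFER_SIZE /* was a bare 0; remaining arguments truncated in the quoted diff */);
}
```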
