This is an automated email from the ASF dual-hosted git repository. jark pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 523546101f0180999f11d68269aad53c59134064 Author: fengli <ldliu...@163.com> AuthorDate: Mon Aug 29 20:10:53 2022 +0800 [fixup][table-planner] Using user classloader instead of thread context classloader --- .../exec/batch/BatchExecPythonGroupAggregate.java | 9 ++-- .../batch/BatchExecPythonGroupWindowAggregate.java | 8 ++-- .../exec/batch/BatchExecPythonOverAggregate.java | 8 ++-- .../nodes/exec/common/CommonExecPythonCalc.java | 30 ++++++++---- .../exec/common/CommonExecPythonCorrelate.java | 24 ++++++---- .../stream/StreamExecPythonGroupAggregate.java | 11 +++-- .../StreamExecPythonGroupTableAggregate.java | 12 +++-- .../StreamExecPythonGroupWindowAggregate.java | 16 +++++-- .../exec/stream/StreamExecPythonOverAggregate.java | 10 ++-- .../plan/nodes/exec/utils/CommonPythonUtil.java | 53 ++++++++++++++-------- .../physical/common/CommonPhysicalMatchRule.java | 3 +- .../table/planner/delegation/PlannerBase.scala | 2 +- .../physical/batch/BatchPhysicalSortRule.scala | 1 - 13 files changed, 125 insertions(+), 62 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java index dbb6033c364..98e2ca2551d 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.java @@ -94,7 +94,8 @@ public class BatchExecPythonGroupAggregate extends ExecNodeBase<RowData> final RowType inputRowType = (RowType) inputEdge.getOutputType(); final RowType outputRowType = InternalTypeInfo.of(getOutputType()).toRowType(); Configuration pythonConfig = - CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config); + CommonPythonUtil.extractPythonConfiguration( + planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader()); OneInputTransformation<RowData, RowData> transform = createPythonOneInputTransformation( inputTransform, @@ -104,7 +105,8 @@ public class BatchExecPythonGroupAggregate extends ExecNodeBase<RowData> config, planner.getFlinkContext().getClassLoader()); - if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) { + if (CommonPythonUtil.isPythonWorkerUsingManagedMemory( + pythonConfig, planner.getFlinkContext().getClassLoader())) { transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); } return transform; @@ -149,7 +151,8 @@ public class BatchExecPythonGroupAggregate extends ExecNodeBase<RowData> int[] udafInputOffsets, PythonFunctionInfo[] pythonFunctionInfos) { final Class<?> clazz = - CommonPythonUtil.loadClass(ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME); + CommonPythonUtil.loadClass( + ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME, classLoader); RowType udfInputType = (RowType) Projection.of(udafInputOffsets).project(inputRowType); RowType udfOutputType = diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java index ae8a9c2ad02..930a2f7fe59 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupWindowAggregate.java @@ -114,7 +114,8 @@ public class BatchExecPythonGroupWindowAggregate extends ExecNodeBase<RowData> final Tuple2<Long, Long> windowSizeAndSlideSize = WindowCodeGenerator.getWindowDef(window); final Configuration pythonConfig = - CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config); + CommonPythonUtil.extractPythonConfiguration( + planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader()); int groupBufferLimitSize = pythonConfig.getInteger( ExecutionConfigOptions.TABLE_EXEC_WINDOW_AGG_BUFFER_SIZE_LIMIT); @@ -130,7 +131,8 @@ public class BatchExecPythonGroupWindowAggregate extends ExecNodeBase<RowData> pythonConfig, config, planner.getFlinkContext().getClassLoader()); - if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) { + if (CommonPythonUtil.isPythonWorkerUsingManagedMemory( + pythonConfig, planner.getFlinkContext().getClassLoader())) { transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); } return transform; @@ -204,7 +206,7 @@ public class BatchExecPythonGroupWindowAggregate extends ExecNodeBase<RowData> PythonFunctionInfo[] pythonFunctionInfos) { Class<?> clazz = CommonPythonUtil.loadClass( - ARROW_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME); + ARROW_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME, classLoader); RowType udfInputType = (RowType) Projection.of(udafInputOffsets).project(inputRowType); RowType udfOutputType = diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java index 5023931259f..9f4717aa5ef 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonOverAggregate.java @@ -153,7 +153,8 @@ public class BatchExecPythonOverAggregate extends BatchExecOverAggregateBase { } } Configuration pythonConfig = - CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config); + CommonPythonUtil.extractPythonConfiguration( + planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader()); OneInputTransformation<RowData, RowData> transform = createPythonOneInputTransformation( inputTransform, @@ -163,7 +164,8 @@ public class BatchExecPythonOverAggregate extends BatchExecOverAggregateBase { pythonConfig, config, planner.getFlinkContext().getClassLoader()); - if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) { + if (CommonPythonUtil.isPythonWorkerUsingManagedMemory( + pythonConfig, planner.getFlinkContext().getClassLoader())) { transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); } return transform; @@ -213,7 +215,7 @@ public class BatchExecPythonOverAggregate extends BatchExecOverAggregateBase { PythonFunctionInfo[] pythonFunctionInfos) { Class<?> clazz = CommonPythonUtil.loadClass( - ARROW_PYTHON_OVER_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME); + ARROW_PYTHON_OVER_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME, classLoader); RowType udfInputType = (RowType) Projection.of(udafInputOffsets).project(inputRowType); RowType udfOutputType = diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java index e102de9d063..d0249791edd 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.java @@ -108,14 +108,16 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData> final Transformation<RowData> inputTransform = (Transformation<RowData>) inputEdge.translateToPlan(planner); final Configuration pythonConfig = - CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config); + CommonPythonUtil.extractPythonConfiguration( + planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader()); OneInputTransformation<RowData, RowData> ret = createPythonOneInputTransformation( inputTransform, config, planner.getFlinkContext().getClassLoader(), pythonConfig); - if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) { + if (CommonPythonUtil.isPythonWorkerUsingManagedMemory( + pythonConfig, planner.getFlinkContext().getClassLoader())) { ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); } return ret; @@ -139,7 +141,7 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData> .collect(Collectors.toList()); Tuple2<int[], PythonFunctionInfo[]> extractResult = - extractPythonScalarFunctionInfos(pythonRexCalls); + extractPythonScalarFunctionInfos(pythonRexCalls, classLoader); int[] pythonUdfInputOffsets = extractResult.f0; PythonFunctionInfo[] pythonFunctionInfos = extractResult.f1; LogicalType[] inputLogicalTypes = @@ -185,11 +187,14 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData> } private Tuple2<int[], PythonFunctionInfo[]> extractPythonScalarFunctionInfos( - List<RexCall> rexCalls) { + List<RexCall> rexCalls, ClassLoader classLoader) { LinkedHashMap<RexNode, Integer> inputNodes = new LinkedHashMap<>(); PythonFunctionInfo[] pythonFunctionInfos = rexCalls.stream() - .map(x -> CommonPythonUtil.createPythonFunctionInfo(x, inputNodes)) + .map( + x -> + CommonPythonUtil.createPythonFunctionInfo( + x, inputNodes, classLoader)) .collect(Collectors.toList()) .toArray(new PythonFunctionInfo[rexCalls.size()]); @@ -221,14 +226,21 @@ public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData> int[] forwardedFields, boolean isArrow) { Class<?> clazz; - boolean isInProcessMode = CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig); + boolean isInProcessMode = + CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig, classLoader); if (isArrow) { - clazz = CommonPythonUtil.loadClass(ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME); + clazz = + CommonPythonUtil.loadClass( + ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME, classLoader); } else { if (isInProcessMode) { - clazz = CommonPythonUtil.loadClass(PYTHON_SCALAR_FUNCTION_OPERATOR_NAME); + clazz = + CommonPythonUtil.loadClass( + PYTHON_SCALAR_FUNCTION_OPERATOR_NAME, classLoader); } else { - clazz = CommonPythonUtil.loadClass(EMBEDDED_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME); + clazz = + CommonPythonUtil.loadClass( + EMBEDDED_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME, classLoader); } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java index 8661fd9b5b6..81940866104 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java @@ -102,7 +102,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData> final Transformation<RowData> inputTransform = (Transformation<RowData>) inputEdge.translateToPlan(planner); final Configuration pythonConfig = - CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config); + CommonPythonUtil.extractPythonConfiguration( + planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader()); final ExecNodeConfig pythonNodeConfig = ExecNodeConfig.ofNodeConfig(pythonConfig, config.isCompiled()); final OneInputTransformation<RowData, RowData> transform = @@ -111,7 +112,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData> pythonNodeConfig, planner.getFlinkContext().getClassLoader(), pythonConfig); - if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) { + if (CommonPythonUtil.isPythonWorkerUsingManagedMemory( + pythonConfig, planner.getFlinkContext().getClassLoader())) { transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); } return transform; @@ -122,7 +124,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData> ExecNodeConfig pythonNodeConfig, ClassLoader classLoader, Configuration pythonConfig) { - Tuple2<int[], PythonFunctionInfo> extractResult = extractPythonTableFunctionInfo(); + Tuple2<int[], PythonFunctionInfo> extractResult = + extractPythonTableFunctionInfo(classLoader); int[] pythonUdtfInputOffsets = extractResult.f0; PythonFunctionInfo pythonFunctionInfo = extractResult.f1; InternalTypeInfo<RowData> pythonOperatorInputRowType = @@ -146,10 +149,11 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData> inputTransform.getParallelism()); } - private Tuple2<int[], PythonFunctionInfo> extractPythonTableFunctionInfo() { + private Tuple2<int[], PythonFunctionInfo> extractPythonTableFunctionInfo( + ClassLoader classLoader) { LinkedHashMap<RexNode, Integer> inputNodes = new LinkedHashMap<>(); PythonFunctionInfo pythonTableFunctionInfo = - CommonPythonUtil.createPythonFunctionInfo(invocation, inputNodes); + CommonPythonUtil.createPythonFunctionInfo(invocation, inputNodes, classLoader); int[] udtfInputOffsets = inputNodes.keySet().stream() .filter(x -> x instanceof RexInputRef) @@ -168,7 +172,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData> InternalTypeInfo<RowData> outputRowType, PythonFunctionInfo pythonFunctionInfo, int[] udtfInputOffsets) { - boolean isInProcessMode = CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig); + boolean isInProcessMode = + CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig, classLoader); final RowType inputType = inputRowType.toRowType(); final RowType outputType = outputRowType.toRowType(); @@ -180,7 +185,9 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData> try { if (isInProcessMode) { - Class clazz = CommonPythonUtil.loadClass(PYTHON_TABLE_FUNCTION_OPERATOR_NAME); + Class clazz = + CommonPythonUtil.loadClass( + PYTHON_TABLE_FUNCTION_OPERATOR_NAME, classLoader); Constructor ctor = clazz.getConstructor( Configuration.class, @@ -206,7 +213,8 @@ public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData> udtfInputOffsets)); } else { Class clazz = - CommonPythonUtil.loadClass(EMBEDDED_PYTHON_TABLE_FUNCTION_OPERATOR_NAME); + CommonPythonUtil.loadClass( + EMBEDDED_PYTHON_TABLE_FUNCTION_OPERATOR_NAME, classLoader); Constructor ctor = clazz.getConstructor( Configuration.class, diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java index 55a8a8cd3d5..4595191332b 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.java @@ -175,10 +175,12 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase { PythonAggregateFunctionInfo[] pythonFunctionInfos = aggInfosAndDataViewSpecs.f0; DataViewSpec[][] dataViewSpecs = aggInfosAndDataViewSpecs.f1; Configuration pythonConfig = - CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config); + CommonPythonUtil.extractPythonConfiguration( + planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader()); final OneInputStreamOperator<RowData, RowData> operator = getPythonAggregateFunctionOperator( pythonConfig, + planner.getFlinkContext().getClassLoader(), inputRowType, InternalTypeInfo.of(getOutputType()).toRowType(), pythonFunctionInfos, @@ -196,7 +198,8 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase { InternalTypeInfo.of(getOutputType()), inputTransform.getParallelism()); - if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) { + if (CommonPythonUtil.isPythonWorkerUsingManagedMemory( + pythonConfig, planner.getFlinkContext().getClassLoader())) { transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); } @@ -214,6 +217,7 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase { @SuppressWarnings("unchecked") private OneInputStreamOperator<RowData, RowData> getPythonAggregateFunctionOperator( Configuration config, + ClassLoader classLoader, RowType inputType, RowType outputType, PythonAggregateFunctionInfo[] aggregateFunctions, @@ -222,7 +226,8 @@ public class StreamExecPythonGroupAggregate extends StreamExecAggregateBase { long maxIdleStateRetentionTime, int indexOfCountStar, boolean countStarInserted) { - Class<?> clazz = CommonPythonUtil.loadClass(PYTHON_STREAM_AGGREAGTE_OPERATOR_NAME); + Class<?> clazz = + CommonPythonUtil.loadClass(PYTHON_STREAM_AGGREAGTE_OPERATOR_NAME, classLoader); try { Constructor<?> ctor = clazz.getConstructor( diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java index 179b302941a..3d05d273def 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupTableAggregate.java @@ -131,10 +131,12 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData> PythonAggregateFunctionInfo[] pythonFunctionInfos = aggInfosAndDataViewSpecs.f0; DataViewSpec[][] dataViewSpecs = aggInfosAndDataViewSpecs.f1; Configuration pythonConfig = - CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config); + CommonPythonUtil.extractPythonConfiguration( + planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader()); OneInputStreamOperator<RowData, RowData> pythonOperator = getPythonTableAggregateFunctionOperator( pythonConfig, + planner.getFlinkContext().getClassLoader(), inputRowType, InternalTypeInfo.of(getOutputType()).toRowType(), pythonFunctionInfos, @@ -153,7 +155,8 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData> InternalTypeInfo.of(getOutputType()), inputTransform.getParallelism()); - if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) { + if (CommonPythonUtil.isPythonWorkerUsingManagedMemory( + pythonConfig, planner.getFlinkContext().getClassLoader())) { transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); } @@ -171,6 +174,7 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData> @SuppressWarnings("unchecked") private OneInputStreamOperator<RowData, RowData> getPythonTableAggregateFunctionOperator( Configuration config, + ClassLoader classLoader, RowType inputRowType, RowType outputRowType, PythonAggregateFunctionInfo[] aggregateFunctions, @@ -179,7 +183,9 @@ public class StreamExecPythonGroupTableAggregate extends ExecNodeBase<RowData> long maxIdleStateRetentionTime, boolean generateUpdateBefore, int indexOfCountStar) { - Class<?> clazz = CommonPythonUtil.loadClass(PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME); + Class<?> clazz = + CommonPythonUtil.loadClass( + PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME, classLoader); try { Constructor<?> ctor = clazz.getConstructor( diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java index 8aa55962285..6e210a642b0 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupWindowAggregate.java @@ -258,7 +258,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas WindowAssigner<?> windowAssigner = windowAssignerAndTrigger.f0; Trigger<?> trigger = windowAssignerAndTrigger.f1; final Configuration pythonConfig = - CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config); + CommonPythonUtil.extractPythonConfiguration( + planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader()); final ExecNodeConfig pythonNodeConfig = ExecNodeConfig.ofNodeConfig(pythonConfig, config.isCompiled()); boolean isGeneralPythonUDAF = @@ -289,6 +290,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas emitStrategy.getAllowLateness(), pythonConfig, pythonNodeConfig, + planner.getFlinkContext().getClassLoader(), shiftTimeZone); } else { transform = @@ -306,7 +308,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas shiftTimeZone); } - if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) { + if (CommonPythonUtil.isPythonWorkerUsingManagedMemory( + pythonConfig, planner.getFlinkContext().getClassLoader())) { transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); } // set KeyType and Selector for state @@ -436,6 +439,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas long allowance, Configuration pythonConfig, ExecNodeConfig pythonNodeConfig, + ClassLoader classLoader, ZoneId shiftTimeZone) { final int inputCountIndex = aggInfoList.getIndexOfCountStar(); final boolean countStarInserted = aggInfoList.countStarInserted(); @@ -446,6 +450,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas OneInputStreamOperator<RowData, RowData> pythonOperator = getGeneralPythonStreamGroupWindowAggregateFunctionOperator( pythonConfig, + classLoader, inputRowType, outputRowType, windowAssigner, @@ -484,7 +489,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas ZoneId shiftTimeZone) { Class clazz = CommonPythonUtil.loadClass( - ARROW_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME); + ARROW_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME, + classLoader); RowType userDefinedFunctionInputType = (RowType) Projection.of(udafInputOffsets).project(inputRowType); RowType userDefinedFunctionOutputType = @@ -542,6 +548,7 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas private OneInputStreamOperator<RowData, RowData> getGeneralPythonStreamGroupWindowAggregateFunctionOperator( Configuration config, + ClassLoader classLoader, RowType inputType, RowType outputType, WindowAssigner<?> windowAssigner, @@ -555,7 +562,8 @@ public class StreamExecPythonGroupWindowAggregate extends StreamExecAggregateBas ZoneId shiftTimeZone) { Class clazz = CommonPythonUtil.loadClass( - GENERAL_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME); + GENERAL_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME, + classLoader); boolean isRowTime = AggregateUtil.isRowtimeAttribute(window.timeAttribute()); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java index fd507b97d4a..d1057bcab6f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java @@ -197,7 +197,8 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData> } long precedingOffset = -1 * (long) boundValue; Configuration pythonConfig = - CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config); + CommonPythonUtil.extractPythonConfiguration( + planner.getExecEnv(), config, planner.getFlinkContext().getClassLoader()); OneInputTransformation<RowData, RowData> transform = createPythonOneInputTransformation( inputTransform, @@ -213,7 +214,8 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData> config, planner.getFlinkContext().getClassLoader()); - if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) { + if (CommonPythonUtil.isPythonWorkerUsingManagedMemory( + pythonConfig, planner.getFlinkContext().getClassLoader())) { transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON); } @@ -306,7 +308,7 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData> className = ARROW_PYTHON_OVER_WINDOW_ROWS_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME; } - Class<?> clazz = CommonPythonUtil.loadClass(className); + Class<?> clazz = CommonPythonUtil.loadClass(className, classLoader); try { Constructor<?> ctor = @@ -349,7 +351,7 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData> className = ARROW_PYTHON_OVER_WINDOW_RANGE_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME; } - Class<?> clazz = CommonPythonUtil.loadClass(className); + Class<?> clazz = CommonPythonUtil.loadClass(className, classLoader); try { Constructor<?> ctor = clazz.getConstructor( diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java index a949ad2afa6..201407b718a 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java @@ -98,9 +98,9 @@ public class CommonPythonUtil { private CommonPythonUtil() {} - public static Class<?> loadClass(String className) { + public static Class<?> loadClass(String className, ClassLoader classLoader) { try { - return Class.forName(className, false, Thread.currentThread().getContextClassLoader()); + return Class.forName(className, false, classLoader); } catch (ClassNotFoundException e) { throw new TableException( "The dependency of 'flink-python' is not present on the classpath.", e); @@ -108,8 +108,8 @@ public class CommonPythonUtil { } public static Configuration extractPythonConfiguration( - StreamExecutionEnvironment env, ReadableConfig tableConfig) { - Class<?> clazz = loadClass(PYTHON_CONFIG_UTILS_CLASS); + StreamExecutionEnvironment env, ReadableConfig tableConfig, ClassLoader classLoader) { + Class<?> clazz = loadClass(PYTHON_CONFIG_UTILS_CLASS, classLoader); try { StreamExecutionEnvironment realEnv = getRealEnvironment(env); Method method = @@ -125,20 +125,27 @@ public class CommonPythonUtil { } public static PythonFunctionInfo createPythonFunctionInfo( - RexCall pythonRexCall, Map<RexNode, Integer> inputNodes) { + RexCall pythonRexCall, Map<RexNode, Integer> inputNodes, ClassLoader classLoader) { SqlOperator operator = pythonRexCall.getOperator(); try { if (operator instanceof ScalarSqlFunction) { return createPythonFunctionInfo( - pythonRexCall, inputNodes, ((ScalarSqlFunction) operator).scalarFunction()); + pythonRexCall, + inputNodes, + ((ScalarSqlFunction) operator).scalarFunction(), + classLoader); } else if (operator instanceof TableSqlFunction) { return createPythonFunctionInfo( - pythonRexCall, inputNodes, ((TableSqlFunction) operator).udtf()); + pythonRexCall, + inputNodes, + ((TableSqlFunction) operator).udtf(), + classLoader); } else if (operator instanceof BridgingSqlFunction) { return createPythonFunctionInfo( pythonRexCall, inputNodes, - ((BridgingSqlFunction) operator).getDefinition()); + ((BridgingSqlFunction) operator).getDefinition(), + classLoader); } } catch (InvocationTargetException | IllegalAccessException e) { throw new TableException("Method pickleValue accessed failed. ", e); @@ -147,8 +154,9 @@ public class CommonPythonUtil { } @SuppressWarnings("unchecked") - public static boolean isPythonWorkerUsingManagedMemory(Configuration config) { - Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS); + public static boolean isPythonWorkerUsingManagedMemory( + Configuration config, ClassLoader classLoader) { + Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS, classLoader); try { return config.getBoolean( (ConfigOption<Boolean>) (clazz.getField("USE_MANAGED_MEMORY").get(null))); @@ -158,8 +166,9 @@ public class CommonPythonUtil { } @SuppressWarnings("unchecked") - public static boolean isPythonWorkerInProcessMode(Configuration config) { - Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS); + public static boolean isPythonWorkerInProcessMode( + Configuration config, ClassLoader classLoader) { + Class<?> clazz = loadClass(PYTHON_OPTIONS_CLASS, classLoader); try { return config.getString( (ConfigOption<String>) @@ -337,7 +346,8 @@ public class CommonPythonUtil { }); } - private static byte[] convertLiteralToPython(RexLiteral o, SqlTypeName typeName) + private static byte[] convertLiteralToPython( + RexLiteral o, SqlTypeName typeName, ClassLoader classLoader) throws InvocationTargetException, IllegalAccessException { byte type; Object value; @@ -396,16 +406,18 @@ public class CommonPythonUtil { throw new RuntimeException("Unsupported type " + typeName); } } - loadPickleValue(); + loadPickleValue(classLoader); return (byte[]) pickleValue.invoke(null, value, type); } - private static void loadPickleValue() { + private static void loadPickleValue(ClassLoader classLoader) { if (pickleValue == null) { synchronized (CommonPythonUtil.class) { if (pickleValue == null) { Class<?> clazz = - loadClass("org.apache.flink.api.common.python.PythonBridgeUtils"); + loadClass( + "org.apache.flink.api.common.python.PythonBridgeUtils", + classLoader); try { pickleValue = clazz.getMethod("pickleValue", Object.class, byte.class); } catch (NoSuchMethodException e) { @@ -419,18 +431,21 @@ public class CommonPythonUtil { private static PythonFunctionInfo createPythonFunctionInfo( RexCall pythonRexCall, Map<RexNode, Integer> inputNodes, - FunctionDefinition functionDefinition) + FunctionDefinition functionDefinition, + ClassLoader classLoader) throws InvocationTargetException, IllegalAccessException { ArrayList<Object> inputs = new ArrayList<>(); for (RexNode operand : pythonRexCall.getOperands()) { if (operand instanceof RexCall) { RexCall childPythonRexCall = (RexCall) operand; PythonFunctionInfo argPythonInfo = - createPythonFunctionInfo(childPythonRexCall, inputNodes); + createPythonFunctionInfo(childPythonRexCall, inputNodes, classLoader); inputs.add(argPythonInfo); } else if (operand instanceof RexLiteral) { RexLiteral literal = (RexLiteral) operand; - inputs.add(convertLiteralToPython(literal, literal.getType().getSqlTypeName())); + inputs.add( + convertLiteralToPython( + literal, literal.getType().getSqlTypeName(), classLoader)); } else { if (inputNodes.containsKey(operand)) { inputs.add(inputNodes.get(operand)); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java index 397476b0ebb..ad1b15e8ea8 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java @@ -25,6 +25,7 @@ import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalMatch; import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution; import org.apache.flink.table.planner.plan.utils.MatchUtil.AggregationPatternVariableFinder; import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor; +import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptRule; @@ -87,7 +88,7 @@ public abstract class CommonPhysicalMatchRule extends ConverterRule { Class.forName( "org.apache.flink.cep.pattern.Pattern", false, - Thread.currentThread().getContextClassLoader()); + ShortcutUtils.unwrapContext(rel).getClassLoader()); } catch (ClassNotFoundException e) { throw new TableException( "MATCH RECOGNIZE clause requires flink-cep dependency to be present on the classpath.", diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala index 8fb8f80d37d..2681ed64fad 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala @@ -473,7 +473,7 @@ abstract class PlannerBase( tableConfig.set(TABLE_QUERY_CURRENT_DATABASE, currentDatabase) // We pass only the configuration to avoid reconfiguration with the rootConfiguration - getExecEnv.configure(tableConfig.getConfiguration, Thread.currentThread().getContextClassLoader) + getExecEnv.configure(tableConfig.getConfiguration, classLoader) // Use config parallelism to override env parallelism. val defaultParallelism = diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala index b1b58d803c0..ef387d485d1 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortRule.scala @@ -20,7 +20,6 @@ package org.apache.flink.table.planner.plan.rules.physical.batch import org.apache.flink.annotation.Experimental import org.apache.flink.configuration.ConfigOption import org.apache.flink.configuration.ConfigOptions.key -import org.apache.flink.table.planner.calcite.FlinkContext import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalSort