This is an automated email from the ASF dual-hosted git repository.

kgyrtkirk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new e70bd7667f5 Refactor MSQSpec usages to reduce reliance on the native 
query (#17753)
e70bd7667f5 is described below

commit e70bd7667f545e4e4869ade73f04f1231bf7a8cd
Author: Zoltan Haindrich <[email protected]>
AuthorDate: Mon Feb 24 11:24:38 2025 +0100

    Refactor MSQSpec usages to reduce reliance on the native query (#17753)
    
    This is a preparatory refactoring step toward moving the production point of the staged plan closer to when the task is submitted.
    
    This change:
    * reduces the usages of `MSQSpec#getQuery` as much as possible
    * removes usages of `MSQSpec` during the querykit planning phase
    * moves the top-level QueryKit-related planning code out of the `ControllerImpl` class into a standalone class
    * lots of minor changes to use `MSQSpec#getContext` instead of 
`MSQSpec.getQuery().getContext()`
    
    This change was produced with refactoring steps only, so it should work just like the old code.
---
 .../msq/dart/controller/DartControllerContext.java |   4 +-
 .../org/apache/druid/msq/exec/ControllerImpl.java  | 249 ++----------------
 .../druid/msq/exec/QueryKitBasedMSQPlanner.java    | 288 +++++++++++++++++++++
 .../msq/indexing/IndexerControllerContext.java     |  40 +--
 .../druid/msq/indexing/MSQControllerTask.java      |  44 +++-
 .../org/apache/druid/msq/indexing/MSQSpec.java     |   6 +
 .../apache/druid/msq/sql/MSQTaskQueryMaker.java    |   2 +-
 .../msq/sql/resources/SqlStatementResource.java    |   3 +-
 .../druid/msq/util/MSQTaskQueryMakerUtils.java     |  13 +-
 .../dart/controller/DartControllerContextTest.java |   3 +-
 .../msq/test/MSQTestOverlordServiceClient.java     |   2 +-
 11 files changed, 384 insertions(+), 270 deletions(-)

diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java
index 52f21518b31..b7cf087a3f4 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java
@@ -140,7 +140,7 @@ public class DartControllerContext implements 
ControllerContext
         );
 
     final int maxConcurrentStages = 
MultiStageQueryContext.getMaxConcurrentStagesWithDefault(
-        querySpec.getQuery().context(),
+        querySpec.getContext(),
         DEFAULT_MAX_CONCURRENT_STAGES
     );
 
@@ -225,7 +225,7 @@ public class DartControllerContext implements 
ControllerContext
       final ControllerQueryKernelConfig queryKernelConfig
   )
   {
-    final QueryContext queryContext = querySpec.getQuery().context();
+    final QueryContext queryContext = querySpec.getContext();
     return new QueryKitSpec(
         queryKit,
         queryId,
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
index 4c8cc067294..b370ab58643 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
@@ -114,7 +114,6 @@ import org.apache.druid.msq.indexing.error.MSQErrorReport;
 import org.apache.druid.msq.indexing.error.MSQException;
 import org.apache.druid.msq.indexing.error.MSQFault;
 import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher;
-import org.apache.druid.msq.indexing.error.QueryNotSupportedFault;
 import org.apache.druid.msq.indexing.error.TooManyBucketsFault;
 import org.apache.druid.msq.indexing.error.TooManySegmentsInTimeChunkFault;
 import org.apache.druid.msq.indexing.error.TooManyWarningsFault;
@@ -130,7 +129,6 @@ import 
org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
 import org.apache.druid.msq.input.InputSpec;
 import org.apache.druid.msq.input.InputSpecSlicer;
 import org.apache.druid.msq.input.InputSpecSlicerFactory;
-import org.apache.druid.msq.input.InputSpecs;
 import org.apache.druid.msq.input.MapInputSpecSlicer;
 import org.apache.druid.msq.input.external.ExternalInputSpec;
 import org.apache.druid.msq.input.external.ExternalInputSpecSlicer;
@@ -143,7 +141,6 @@ import org.apache.druid.msq.input.stage.StageInputSpec;
 import org.apache.druid.msq.input.stage.StageInputSpecSlicer;
 import org.apache.druid.msq.input.table.TableInputSpec;
 import org.apache.druid.msq.kernel.QueryDefinition;
-import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
 import org.apache.druid.msq.kernel.StageDefinition;
 import org.apache.druid.msq.kernel.StageId;
 import org.apache.druid.msq.kernel.StagePartition;
@@ -152,32 +149,18 @@ import 
org.apache.druid.msq.kernel.controller.ControllerQueryKernel;
 import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig;
 import org.apache.druid.msq.kernel.controller.ControllerStagePhase;
 import org.apache.druid.msq.kernel.controller.WorkerInputs;
-import org.apache.druid.msq.querykit.MultiQueryKit;
-import org.apache.druid.msq.querykit.QueryKit;
-import org.apache.druid.msq.querykit.QueryKitSpec;
 import org.apache.druid.msq.querykit.QueryKitUtils;
-import org.apache.druid.msq.querykit.ShuffleSpecFactory;
-import org.apache.druid.msq.querykit.WindowOperatorQueryKit;
-import org.apache.druid.msq.querykit.groupby.GroupByQueryKit;
-import 
org.apache.druid.msq.querykit.results.ExportResultsFrameProcessorFactory;
-import org.apache.druid.msq.querykit.results.QueryResultFrameProcessorFactory;
-import org.apache.druid.msq.querykit.scan.ScanQueryKit;
 import org.apache.druid.msq.shuffle.input.DurableStorageInputChannelFactory;
 import org.apache.druid.msq.shuffle.input.WorkerInputChannelFactory;
 import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
 import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
 import org.apache.druid.msq.util.IntervalUtils;
 import org.apache.druid.msq.util.MSQFutureUtils;
-import org.apache.druid.msq.util.MSQTaskQueryMakerUtils;
 import org.apache.druid.msq.util.MultiStageQueryContext;
-import org.apache.druid.query.Query;
 import org.apache.druid.query.QueryContext;
 import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.groupby.GroupByQuery;
-import org.apache.druid.query.operator.WindowOperatorQuery;
-import org.apache.druid.query.scan.ScanQuery;
 import org.apache.druid.segment.IndexSpec;
-import org.apache.druid.segment.column.ColumnHolder;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.indexing.DataSchema;
@@ -187,8 +170,6 @@ import org.apache.druid.segment.transform.TransformSpec;
 import org.apache.druid.server.DruidNode;
 import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
 import org.apache.druid.sql.calcite.planner.ColumnMappings;
-import org.apache.druid.sql.http.ResultFormat;
-import org.apache.druid.storage.ExportStorageProvider;
 import org.apache.druid.timeline.CompactionState;
 import org.apache.druid.timeline.DataSegment;
 import org.apache.druid.timeline.partition.DimensionRangeShardSpec;
@@ -207,7 +188,6 @@ import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -461,7 +441,7 @@ public class ControllerImpl implements Controller
       }
     }
 
-    boolean shouldWaitForSegmentLoad = 
MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getQuery().context());
+    boolean shouldWaitForSegmentLoad = 
MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getContext());
 
     try {
       if (MSQControllerTask.isIngestion(querySpec)) {
@@ -612,13 +592,10 @@ public class ControllerImpl implements Controller
     this.netClient = closer.register(new 
ExceptionWrappingWorkerClient(context.newWorkerClient()));
     this.queryKernelConfig = context.queryKernelConfig(queryId, querySpec);
 
-    final QueryContext queryContext = querySpec.getQuery().context();
-    final QueryDefinition queryDef = makeQueryDefinition(
-        context.makeQueryKitSpec(makeQueryControllerToolKit(queryContext), 
queryId, querySpec, queryKernelConfig),
-        querySpec,
-        context,
-        resultsContext
-    );
+    final QueryContext queryContext = querySpec.getContext();
+    QueryKitBasedMSQPlanner qkPlanner = new QueryKitBasedMSQPlanner(context, 
querySpec, resultsContext, queryKernelConfig, queryId);
+
+    final QueryDefinition queryDef = qkPlanner.makeQueryDefinition();
 
     if (log.isDebugEnabled()) {
       try {
@@ -680,7 +657,7 @@ public class ControllerImpl implements Controller
         netClient,
         workerManager,
         queryKernelConfig.isFaultTolerant(),
-        
MultiStageQueryContext.getSketchEncoding(querySpec.getQuery().context())
+        MultiStageQueryContext.getSketchEncoding(queryContext)
     );
     closer.register(workerSketchFetcher::close);
 
@@ -1261,25 +1238,6 @@ public class ControllerImpl implements Controller
     return null;
   }
 
-  @SuppressWarnings("rawtypes")
-  private QueryKit<Query<?>> makeQueryControllerToolKit(QueryContext 
queryContext)
-  {
-    final Map<Class<? extends Query>, QueryKit> kitMap =
-        ImmutableMap.<Class<? extends Query>, QueryKit>builder()
-                    .put(ScanQuery.class, new 
ScanQueryKit(context.jsonMapper()))
-                    .put(GroupByQuery.class, new 
GroupByQueryKit(context.jsonMapper()))
-                    .put(
-                        WindowOperatorQuery.class,
-                        new WindowOperatorQueryKit(
-                            context.jsonMapper(),
-                            
MultiStageQueryContext.isWindowFunctionOperatorTransformationEnabled(queryContext)
-                        )
-                    )
-                    .build();
-
-    return new MultiQueryKit(kitMap);
-  }
-
   /**
    * A blocking function used to contact multiple workers. Checks if all the 
workers are running before contacting them.
    *
@@ -1440,7 +1398,7 @@ public class ControllerImpl implements Controller
                  .submit(new 
MarkSegmentsAsUnusedAction(destination.getDataSource(), interval));
         }
       } else {
-        if 
(MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getQuery().context()))
 {
+        if 
(MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getContext())) {
           segmentLoadWaiter = new SegmentLoadStatusFetcher(
               context.injector().getInstance(BrokerClient.class),
               context.jsonMapper(),
@@ -1456,7 +1414,7 @@ public class ControllerImpl implements Controller
         );
       }
     } else if (!segments.isEmpty()) {
-      if 
(MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getQuery().context()))
 {
+      if 
(MultiStageQueryContext.shouldWaitForSegmentLoad(querySpec.getContext())) {
         segmentLoadWaiter = new SegmentLoadStatusFetcher(
             context.injector().getInstance(BrokerClient.class),
             context.jsonMapper(),
@@ -1587,11 +1545,11 @@ public class ControllerImpl implements Controller
       @SuppressWarnings("unchecked")
       Set<DataSegment> segments = (Set<DataSegment>) 
queryKernel.getResultObjectForStage(finalStageId);
 
-      boolean storeCompactionState = 
QueryContext.of(querySpec.getQuery().getContext())
-                                                 .getBoolean(
-                                                     
Tasks.STORE_COMPACTION_STATE_KEY,
-                                                     
Tasks.DEFAULT_STORE_COMPACTION_STATE
-                                                 );
+      boolean storeCompactionState = querySpec.getContext()
+          .getBoolean(
+              Tasks.STORE_COMPACTION_STATE_KEY,
+              Tasks.DEFAULT_STORE_COMPACTION_STATE
+          );
 
       if (storeCompactionState) {
         DataSourceMSQDestination destination = (DataSourceMSQDestination) 
querySpec.getDestination();
@@ -1620,7 +1578,7 @@ public class ControllerImpl implements Controller
       }
       log.info("Query [%s] publishing %d segments.", queryDef.getQueryId(), 
segments.size());
       publishAllSegments(segments, compactionStateAnnotateFunction);
-    } else if (MSQControllerTask.isExport(querySpec)) {
+    } else if (MSQControllerTask.isExport(querySpec.getDestination())) {
       // Write manifest file.
       ExportMSQDestination destination = (ExportMSQDestination) 
querySpec.getDestination();
       ExportMetadataManager exportMetadataManager = new ExportMetadataManager(
@@ -1704,7 +1662,7 @@ public class ControllerImpl implements Controller
 
     GranularitySpec granularitySpec = new UniformGranularitySpec(
         segmentGranularity,
-        QueryContext.of(querySpec.getQuery().getContext())
+        querySpec.getContext()
                     
.getGranularity(DruidSqlInsert.SQL_INSERT_QUERY_GRANULARITY, jsonMapper),
         dataSchema.getGranularitySpec().isRollup(),
         // Not using dataSchema.getGranularitySpec().inputIntervals() as that 
always has ETERNITY
@@ -1769,173 +1727,6 @@ public class ControllerImpl implements Controller
     }
   }
 
-  @SuppressWarnings("unchecked")
-  private static QueryDefinition makeQueryDefinition(
-      final QueryKitSpec queryKitSpec,
-      final MSQSpec querySpec,
-      final ControllerContext controllerContext,
-      final ResultsContext resultsContext
-  )
-  {
-    final ObjectMapper jsonMapper = controllerContext.jsonMapper();
-    final MSQTuningConfig tuningConfig = querySpec.getTuningConfig();
-    final ColumnMappings columnMappings = querySpec.getColumnMappings();
-    final Query<?> queryToPlan;
-    final ShuffleSpecFactory resultShuffleSpecFactory;
-
-    if (MSQControllerTask.isIngestion(querySpec)) {
-      resultShuffleSpecFactory = querySpec.getDestination()
-                                          
.getShuffleSpecFactory(tuningConfig.getRowsPerSegment());
-
-      if (!columnMappings.hasUniqueOutputColumnNames()) {
-        // We do not expect to hit this case in production, because the SQL 
validator checks that column names
-        // are unique for INSERT and REPLACE statements (i.e. anything where 
MSQControllerTask.isIngestion would
-        // be true). This check is here as defensive programming.
-        throw new ISE("Column names are not unique: [%s]", 
columnMappings.getOutputColumnNames());
-      }
-
-      MSQTaskQueryMakerUtils.validateRealtimeReindex(querySpec);
-
-      if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) {
-        // We know there's a single time column, because we've checked 
columnMappings.hasUniqueOutputColumnNames().
-        final int timeColumn = 
columnMappings.getOutputColumnsByName(ColumnHolder.TIME_COLUMN_NAME).getInt(0);
-        queryToPlan = querySpec.getQuery().withOverriddenContext(
-            ImmutableMap.of(
-                QueryKitUtils.CTX_TIME_COLUMN_NAME,
-                columnMappings.getQueryColumnName(timeColumn)
-            )
-        );
-      } else {
-        queryToPlan = querySpec.getQuery();
-      }
-    } else {
-      resultShuffleSpecFactory =
-          querySpec.getDestination()
-                   
.getShuffleSpecFactory(MultiStageQueryContext.getRowsPerPage(querySpec.getQuery().context()));
-      queryToPlan = querySpec.getQuery();
-    }
-
-    final QueryDefinition queryDef;
-
-    try {
-      queryDef = queryKitSpec.getQueryKit().makeQueryDefinition(
-          queryKitSpec,
-          queryToPlan,
-          resultShuffleSpecFactory,
-          0
-      );
-    }
-    catch (MSQException e) {
-      // If the toolkit throws a MSQFault, don't wrap it in a more generic 
QueryNotSupportedFault
-      throw e;
-    }
-    catch (Exception e) {
-      throw new MSQException(e, QueryNotSupportedFault.INSTANCE);
-    }
-
-    if (MSQControllerTask.isIngestion(querySpec)) {
-      // Find the stage that provides shuffled input to the final 
segment-generation stage.
-      StageDefinition finalShuffleStageDef = 
queryDef.getFinalStageDefinition();
-
-      while (!finalShuffleStageDef.doesShuffle()
-             && 
InputSpecs.getStageNumbers(finalShuffleStageDef.getInputSpecs()).size() == 1) {
-        finalShuffleStageDef = queryDef.getStageDefinition(
-            
Iterables.getOnlyElement(InputSpecs.getStageNumbers(finalShuffleStageDef.getInputSpecs()))
-        );
-      }
-
-      if (!finalShuffleStageDef.doesShuffle()) {
-        finalShuffleStageDef = null;
-      }
-
-      // Add all query stages.
-      // Set shuffleCheckHasMultipleValues on the stage that serves as input 
to the final segment-generation stage.
-      final QueryDefinitionBuilder builder = 
QueryDefinition.builder(queryKitSpec.getQueryId());
-
-      for (final StageDefinition stageDef : queryDef.getStageDefinitions()) {
-        if (stageDef.equals(finalShuffleStageDef)) {
-          
builder.add(StageDefinition.builder(stageDef).shuffleCheckHasMultipleValues(true));
-        } else {
-          builder.add(StageDefinition.builder(stageDef));
-        }
-      }
-
-      final DataSourceMSQDestination destination = (DataSourceMSQDestination) 
querySpec.getDestination();
-      return builder.add(
-                        destination.getTerminalStageSpec()
-                                   .constructFinalStage(
-                                       queryDef,
-                                       querySpec,
-                                       jsonMapper
-                                   )
-                    )
-                    .build();
-    } else if (MSQControllerTask.writeFinalResultsToTaskReport(querySpec)) {
-      return queryDef;
-    } else if 
(MSQControllerTask.writeFinalStageResultsToDurableStorage(querySpec)) {
-
-      // attaching new query results stage if the final stage does sort during 
shuffle so that results are ordered.
-      StageDefinition finalShuffleStageDef = 
queryDef.getFinalStageDefinition();
-      if (finalShuffleStageDef.doesSortDuringShuffle()) {
-        final QueryDefinitionBuilder builder = 
QueryDefinition.builder(queryKitSpec.getQueryId());
-        builder.addAll(queryDef);
-        builder.add(StageDefinition.builder(queryDef.getNextStageNumber())
-                                   .inputs(new 
StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber()))
-                                   
.maxWorkerCount(tuningConfig.getMaxNumWorkers())
-                                   
.signature(finalShuffleStageDef.getSignature())
-                                   .shuffleSpec(null)
-                                   .processorFactory(new 
QueryResultFrameProcessorFactory())
-        );
-        return builder.build();
-      } else {
-        return queryDef;
-      }
-    } else if (MSQControllerTask.isExport(querySpec)) {
-      final ExportMSQDestination exportMSQDestination = (ExportMSQDestination) 
querySpec.getDestination();
-      final ExportStorageProvider exportStorageProvider = 
exportMSQDestination.getExportStorageProvider();
-
-      try {
-        // Check that the export destination is empty as a sanity check. We 
want to avoid modifying any other files with export.
-        Iterator<String> filesIterator = 
exportStorageProvider.createStorageConnector(controllerContext.taskTempDir())
-                                                              .listDir("");
-        if (filesIterator.hasNext()) {
-          throw DruidException.forPersona(DruidException.Persona.USER)
-                              
.ofCategory(DruidException.Category.RUNTIME_FAILURE)
-                              .build(
-                                  "Found files at provided export 
destination[%s]. Export is only allowed to "
-                                  + "an empty path. Please provide an empty 
path/subdirectory or move the existing files.",
-                                  exportStorageProvider.getBasePath()
-                              );
-        }
-      }
-      catch (IOException e) {
-        throw DruidException.forPersona(DruidException.Persona.USER)
-                            
.ofCategory(DruidException.Category.RUNTIME_FAILURE)
-                            .build(e, "Exception occurred while connecting to 
export destination.");
-      }
-
-      final ResultFormat resultFormat = exportMSQDestination.getResultFormat();
-      final QueryDefinitionBuilder builder = 
QueryDefinition.builder(queryKitSpec.getQueryId());
-      builder.addAll(queryDef);
-      builder.add(StageDefinition.builder(queryDef.getNextStageNumber())
-                                 .inputs(new 
StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber()))
-                                 
.maxWorkerCount(tuningConfig.getMaxNumWorkers())
-                                 
.signature(queryDef.getFinalStageDefinition().getSignature())
-                                 .shuffleSpec(null)
-                                 .processorFactory(new 
ExportResultsFrameProcessorFactory(
-                                     queryKitSpec.getQueryId(),
-                                     exportStorageProvider,
-                                     resultFormat,
-                                     columnMappings,
-                                     resultsContext
-                                 ))
-      );
-      return builder.build();
-    } else {
-      throw new ISE("Unsupported destination [%s]", 
querySpec.getDestination());
-    }
-  }
-
   private static String getDataSourceForIngestion(final MSQSpec querySpec)
   {
     return ((DataSourceMSQDestination) 
querySpec.getDestination()).getDataSource();
@@ -2542,7 +2333,7 @@ public class ControllerImpl implements Controller
     private void startStages() throws IOException, InterruptedException
     {
       final long maxInputBytesPerWorker =
-          
MultiStageQueryContext.getMaxInputBytesPerWorker(querySpec.getQuery().context());
+          
MultiStageQueryContext.getMaxInputBytesPerWorker(querySpec.getContext());
 
       logKernelStatus(queryDef.getQueryId(), queryKernel);
 
@@ -2602,7 +2393,7 @@ public class ControllerImpl implements Controller
       final StageId shuffleStageId = new StageId(queryDef.getQueryId(), 
shuffleStageNumber);
 
       final boolean isFailOnEmptyInsertEnabled =
-          
MultiStageQueryContext.isFailOnEmptyInsertEnabled(querySpec.getQuery().context());
+          
MultiStageQueryContext.isFailOnEmptyInsertEnabled(querySpec.getContext());
       final Boolean isShuffleStageOutputEmpty = 
queryKernel.isStageOutputEmpty(shuffleStageId);
       if (isFailOnEmptyInsertEnabled && 
Boolean.TRUE.equals(isShuffleStageOutputEmpty)) {
         throw new MSQException(new 
InsertCannotBeEmptyFault(getDataSourceForIngestion(querySpec)));
@@ -2769,12 +2560,12 @@ public class ControllerImpl implements Controller
 
       final InputChannelFactory inputChannelFactory;
 
-      if (queryKernelConfig.isDurableStorage() || 
MSQControllerTask.writeFinalStageResultsToDurableStorage(querySpec)) {
+      if (queryKernelConfig.isDurableStorage() || 
MSQControllerTask.writeFinalStageResultsToDurableStorage(querySpec.getDestination()))
 {
         inputChannelFactory = 
DurableStorageInputChannelFactory.createStandardImplementation(
             queryId(),
             MSQTasks.makeStorageConnector(context.injector()),
             closer,
-            MSQControllerTask.writeFinalStageResultsToDurableStorage(querySpec)
+            
MSQControllerTask.writeFinalStageResultsToDurableStorage(querySpec.getDestination())
         );
       } else {
         inputChannelFactory = new WorkerInputChannelFactory(netClient, () -> 
taskIds);
@@ -2794,7 +2585,7 @@ public class ControllerImpl implements Controller
             resultReaderExec,
             RESULT_READER_CANCELLATION_ID,
             null,
-            
MultiStageQueryContext.removeNullBytes(querySpec.getQuery().context())
+            MultiStageQueryContext.removeNullBytes(querySpec.getContext())
         );
 
         resultsChannel = ReadableConcatFrameChannel.open(
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/QueryKitBasedMSQPlanner.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/QueryKitBasedMSQPlanner.java
new file mode 100644
index 00000000000..fbeae38151b
--- /dev/null
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/QueryKitBasedMSQPlanner.java
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import org.apache.druid.error.DruidException;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.msq.indexing.MSQControllerTask;
+import org.apache.druid.msq.indexing.MSQSpec;
+import org.apache.druid.msq.indexing.MSQTuningConfig;
+import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination;
+import org.apache.druid.msq.indexing.destination.ExportMSQDestination;
+import org.apache.druid.msq.indexing.destination.MSQDestination;
+import org.apache.druid.msq.indexing.error.MSQException;
+import org.apache.druid.msq.indexing.error.QueryNotSupportedFault;
+import org.apache.druid.msq.input.InputSpecs;
+import org.apache.druid.msq.input.stage.StageInputSpec;
+import org.apache.druid.msq.kernel.QueryDefinition;
+import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
+import org.apache.druid.msq.kernel.StageDefinition;
+import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig;
+import org.apache.druid.msq.querykit.MultiQueryKit;
+import org.apache.druid.msq.querykit.QueryKit;
+import org.apache.druid.msq.querykit.QueryKitSpec;
+import org.apache.druid.msq.querykit.QueryKitUtils;
+import org.apache.druid.msq.querykit.ShuffleSpecFactory;
+import org.apache.druid.msq.querykit.WindowOperatorQueryKit;
+import org.apache.druid.msq.querykit.groupby.GroupByQueryKit;
+import 
org.apache.druid.msq.querykit.results.ExportResultsFrameProcessorFactory;
+import org.apache.druid.msq.querykit.results.QueryResultFrameProcessorFactory;
+import org.apache.druid.msq.querykit.scan.ScanQueryKit;
+import org.apache.druid.msq.util.MSQTaskQueryMakerUtils;
+import org.apache.druid.msq.util.MultiStageQueryContext;
+import org.apache.druid.query.Query;
+import org.apache.druid.query.QueryContext;
+import org.apache.druid.query.groupby.GroupByQuery;
+import org.apache.druid.query.operator.WindowOperatorQuery;
+import org.apache.druid.query.scan.ScanQuery;
+import org.apache.druid.segment.column.ColumnHolder;
+import org.apache.druid.sql.calcite.planner.ColumnMappings;
+import org.apache.druid.sql.http.ResultFormat;
+import org.apache.druid.storage.ExportStorageProvider;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Map;
+
+public class QueryKitBasedMSQPlanner
+{
+  private final ControllerContext context;
+  private final MSQSpec querySpec;
+  private final ResultsContext resultsContext;
+  private final QueryKitSpec queryKitSpec;
+  private final ObjectMapper jsonMapper;
+  private final MSQTuningConfig tuningConfig;
+  private final ColumnMappings columnMappings;
+  private final MSQDestination destination;
+  private final QueryContext queryContext;
+  private final Query<?> query;
+
+
+  public QueryKitBasedMSQPlanner(ControllerContext context, MSQSpec querySpec, 
ResultsContext resultsContext,
+      ControllerQueryKernelConfig queryKernelConfig, String queryId)
+  {
+    this.context = context;
+    this.querySpec = querySpec;
+    this.jsonMapper = context.jsonMapper();
+    this.tuningConfig = querySpec.getTuningConfig();
+    this.columnMappings = querySpec.getColumnMappings();
+    this.destination = querySpec.getDestination();
+    this.queryContext = querySpec.getContext();
+    this.query = querySpec.getQuery();
+    this.resultsContext = resultsContext;
+    this.queryKitSpec = context.makeQueryKitSpec(
+        makeQueryControllerToolKit(querySpec.getContext(), 
context.jsonMapper()), queryId, querySpec,
+        queryKernelConfig
+    );
+  }
+
+  @SuppressWarnings("rawtypes")
+  static QueryKit<Query<?>> makeQueryControllerToolKit(QueryContext 
queryContext, ObjectMapper jsonMapper)
+  {
+    final Map<Class<? extends Query>, QueryKit> kitMap =
+        ImmutableMap.<Class<? extends Query>, QueryKit>builder()
+                    .put(ScanQuery.class, new ScanQueryKit(jsonMapper))
+                    .put(GroupByQuery.class, new GroupByQueryKit(jsonMapper))
+                    .put(
+                        WindowOperatorQuery.class,
+                        new WindowOperatorQueryKit(
+                            jsonMapper,
+                            
MultiStageQueryContext.isWindowFunctionOperatorTransformationEnabled(queryContext)
+                        )
+                    )
+                    .build();
+
+    return new MultiQueryKit(kitMap);
+  }
+
+  @SuppressWarnings("unchecked")
+  QueryDefinition makeQueryDefinition()
+  {
+    boolean ingestion = MSQControllerTask.isIngestion(destination);
+    final Query<?> queryToPlan;
+    final ShuffleSpecFactory resultShuffleSpecFactory;
+
+    if (ingestion) {
+      resultShuffleSpecFactory = destination
+          .getShuffleSpecFactory(tuningConfig.getRowsPerSegment());
+
+      if (!columnMappings.hasUniqueOutputColumnNames()) {
+        // We do not expect to hit this case in production, because the SQL 
validator checks that column names
+        // are unique for INSERT and REPLACE statements (i.e. anything where 
MSQControllerTask.isIngestion would
+        // be true). This check is here as defensive programming.
+        throw new ISE("Column names are not unique: [%s]", 
columnMappings.getOutputColumnNames());
+      }
+
+      MSQTaskQueryMakerUtils.validateRealtimeReindex(queryContext, 
destination, query);
+
+      if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) {
+        // We know there's a single time column, because we've checked 
columnMappings.hasUniqueOutputColumnNames().
+        final int timeColumn = 
columnMappings.getOutputColumnsByName(ColumnHolder.TIME_COLUMN_NAME).getInt(0);
+        queryToPlan = query.withOverriddenContext(
+            ImmutableMap.of(
+                QueryKitUtils.CTX_TIME_COLUMN_NAME,
+                columnMappings.getQueryColumnName(timeColumn)
+            )
+        );
+      } else {
+        queryToPlan = query;
+      }
+    } else {
+      resultShuffleSpecFactory =
+          destination
+                   
.getShuffleSpecFactory(MultiStageQueryContext.getRowsPerPage(query.context()));
+      queryToPlan = query;
+    }
+
+    final QueryDefinition queryDef;
+
+    try {
+      queryDef = queryKitSpec.getQueryKit().makeQueryDefinition(
+          queryKitSpec,
+          queryToPlan,
+          resultShuffleSpecFactory,
+          0
+      );
+    }
+    catch (MSQException e) {
+      // If the toolkit throws a MSQFault, don't wrap it in a more generic 
QueryNotSupportedFault
+      throw e;
+    }
+    catch (Exception e) {
+      throw new MSQException(e, QueryNotSupportedFault.INSTANCE);
+    }
+
+    if (ingestion) {
+      // Find the stage that provides shuffled input to the final 
segment-generation stage.
+      StageDefinition finalShuffleStageDef = 
queryDef.getFinalStageDefinition();
+
+      while (!finalShuffleStageDef.doesShuffle()
+             && 
InputSpecs.getStageNumbers(finalShuffleStageDef.getInputSpecs()).size() == 1) {
+        finalShuffleStageDef = queryDef.getStageDefinition(
+            
Iterables.getOnlyElement(InputSpecs.getStageNumbers(finalShuffleStageDef.getInputSpecs()))
+        );
+      }
+
+      if (!finalShuffleStageDef.doesShuffle()) {
+        finalShuffleStageDef = null;
+      }
+
+      // Add all query stages.
+      // Set shuffleCheckHasMultipleValues on the stage that serves as input 
to the final segment-generation stage.
+      final QueryDefinitionBuilder builder = 
QueryDefinition.builder(queryKitSpec.getQueryId());
+
+      for (final StageDefinition stageDef : queryDef.getStageDefinitions()) {
+        if (stageDef.equals(finalShuffleStageDef)) {
+          
builder.add(StageDefinition.builder(stageDef).shuffleCheckHasMultipleValues(true));
+        } else {
+          builder.add(StageDefinition.builder(stageDef));
+        }
+      }
+
+      final DataSourceMSQDestination destination1 = (DataSourceMSQDestination) 
destination;
+      return builder.add(
+                        destination1.getTerminalStageSpec()
+                                   .constructFinalStage(
+                                       queryDef,
+                                       querySpec,
+                                       jsonMapper
+                                   )
+                    )
+                    .build();
+    } else if (MSQControllerTask.writeFinalResultsToTaskReport(destination)) {
+      return queryDef;
+    } else if 
(MSQControllerTask.writeFinalStageResultsToDurableStorage(destination)) {
+
+      // attaching new query results stage if the final stage does sort during 
shuffle so that results are ordered.
+      StageDefinition finalShuffleStageDef = 
queryDef.getFinalStageDefinition();
+      if (finalShuffleStageDef.doesSortDuringShuffle()) {
+        final QueryDefinitionBuilder builder = 
QueryDefinition.builder(queryKitSpec.getQueryId());
+        builder.addAll(queryDef);
+        builder.add(StageDefinition.builder(queryDef.getNextStageNumber())
+                                   .inputs(new 
StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber()))
+                                   
.maxWorkerCount(tuningConfig.getMaxNumWorkers())
+                                   
.signature(finalShuffleStageDef.getSignature())
+                                   .shuffleSpec(null)
+                                   .processorFactory(new 
QueryResultFrameProcessorFactory())
+        );
+        return builder.build();
+      } else {
+        return queryDef;
+      }
+    } else if (MSQControllerTask.isExport(destination)) {
+      final ExportMSQDestination exportMSQDestination = (ExportMSQDestination) 
destination;
+      final ExportStorageProvider exportStorageProvider = 
exportMSQDestination.getExportStorageProvider();
+
+      ensureExportLocationEmpty(context, destination);
+
+      final ResultFormat resultFormat = exportMSQDestination.getResultFormat();
+      final QueryDefinitionBuilder builder = 
QueryDefinition.builder(queryKitSpec.getQueryId());
+      builder.addAll(queryDef);
+      builder.add(StageDefinition.builder(queryDef.getNextStageNumber())
+                                 .inputs(new 
StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber()))
+                                 
.maxWorkerCount(tuningConfig.getMaxNumWorkers())
+                                 
.signature(queryDef.getFinalStageDefinition().getSignature())
+                                 .shuffleSpec(null)
+                                 .processorFactory(new 
ExportResultsFrameProcessorFactory(
+                                     queryKitSpec.getQueryId(),
+                                     exportStorageProvider,
+                                     resultFormat,
+                                     columnMappings,
+                                     resultsContext
+                                 ))
+      );
+      return builder.build();
+    } else {
+      throw new ISE("Unsupported destination [%s]", destination);
+    }
+  }
+
+  public static void ensureExportLocationEmpty(final ControllerContext 
context, final MSQDestination destination)
+  {
+    if (MSQControllerTask.isExport(destination)) {
+      final ExportMSQDestination exportMSQDestination = (ExportMSQDestination) 
destination;
+      final ExportStorageProvider exportStorageProvider = 
exportMSQDestination.getExportStorageProvider();
+
+      try {
+        // Check that the export destination is empty as a sanity check. We 
want
+        // to avoid modifying any other files with export.
+        Iterator<String> filesIterator = 
exportStorageProvider.createStorageConnector(context.taskTempDir())
+            .listDir("");
+        if (filesIterator.hasNext()) {
+          throw DruidException.forPersona(DruidException.Persona.USER)
+              .ofCategory(DruidException.Category.RUNTIME_FAILURE)
+              .build(
+                  "Found files at provided export destination[%s]. Export is 
only allowed to "
+                      + "an empty path. Please provide an empty 
path/subdirectory or move the existing files.",
+                  exportStorageProvider.getBasePath()
+              );
+        }
+      }
+      catch (IOException e) {
+        throw DruidException.forPersona(DruidException.Persona.USER)
+            .ofCategory(DruidException.Category.RUNTIME_FAILURE)
+            .build(e, "Exception occurred while connecting to export 
destination.");
+      }
+    }
+  }
+}
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java
index e74ec7f2a66..a78aa0648eb 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/IndexerControllerContext.java
@@ -27,7 +27,6 @@ import org.apache.druid.guice.annotations.Self;
 import org.apache.druid.indexing.common.TaskLockType;
 import org.apache.druid.indexing.common.TaskToolbox;
 import org.apache.druid.indexing.common.actions.TaskActionClient;
-import org.apache.druid.indexing.common.task.IndexTaskUtils;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.java.util.common.logger.Logger;
@@ -78,7 +77,10 @@ public class IndexerControllerContext implements 
ControllerContext
 
   private static final Logger log = new Logger(IndexerControllerContext.class);
 
-  private final MSQControllerTask task;
+  private final TaskLockType taskLockType;
+  private final String taskDataSource;
+  private final QueryContext taskQuerySpecContext;
+  private final Map<String, Object> taskContext;
   private final TaskToolbox toolbox;
   private final Injector injector;
   private final ServiceClientFactory clientFactory;
@@ -86,21 +88,29 @@ public class IndexerControllerContext implements 
ControllerContext
   private final ServiceMetricEvent.Builder metricBuilder;
   private final MemoryIntrospector memoryIntrospector;
 
+
+
   public IndexerControllerContext(
-      final MSQControllerTask task,
+      final TaskLockType taskLockType,
+      final String taskDataSource,
+      final QueryContext taskQuerySpecContext,
+      final Map<String, Object> taskContext,
+      final ServiceMetricEvent.Builder metricBuilder,
       final TaskToolbox toolbox,
       final Injector injector,
       final ServiceClientFactory clientFactory,
       final OverlordClient overlordClient
   )
   {
-    this.task = task;
+    this.taskLockType = taskLockType;
+    this.taskDataSource = taskDataSource;
+    this.taskQuerySpecContext = taskQuerySpecContext;
+    this.taskContext = taskContext;
     this.toolbox = toolbox;
     this.clientFactory = clientFactory;
     this.overlordClient = overlordClient;
-    this.metricBuilder = new ServiceMetricEvent.Builder();
+    this.metricBuilder = metricBuilder;
     this.memoryIntrospector = injector.getInstance(MemoryIntrospector.class);
-    IndexTaskUtils.setTaskDimensions(metricBuilder, task);
     final StorageConnectorProvider storageConnectorProvider = 
injector.getInstance(Key.get(StorageConnectorProvider.class, 
MultiStageQuery.class));
     final StorageConnector storageConnector = 
storageConnectorProvider.createStorageConnector(toolbox.getIndexingTmpDir());
     this.injector = injector.createChildInjector(
@@ -164,7 +174,7 @@ public class IndexerControllerContext implements 
ControllerContext
   public InputSpecSlicer newTableInputSpecSlicer(final WorkerManager 
workerManager)
   {
     final SegmentSource includeSegmentSource =
-        
MultiStageQueryContext.getSegmentSources(task.getQuerySpec().getQuery().context());
+        MultiStageQueryContext.getSegmentSources(taskQuerySpecContext);
     return new IndexerTableInputSpecSlicer(
         toolbox.getCoordinatorClient(),
         toolbox.getTaskActionClient(),
@@ -181,7 +191,7 @@ public class IndexerControllerContext implements 
ControllerContext
   @Override
   public TaskLockType taskLockType()
   {
-    return task.getTaskLockType();
+    return taskLockType;
   }
 
   @Override
@@ -195,7 +205,7 @@ public class IndexerControllerContext implements 
ControllerContext
   {
     ChatHandler chatHandler = new ControllerChatHandler(
         controller,
-        task.getDataSource(),
+        taskDataSource,
         toolbox.getAuthorizerMapper()
     );
     toolbox.getChatHandlerProvider().register(controller.queryId(), 
chatHandler, false);
@@ -212,10 +222,10 @@ public class IndexerControllerContext implements 
ControllerContext
   {
     return new MSQWorkerTaskLauncher(
         queryId,
-        task.getDataSource(),
+        taskDataSource,
         overlordClient,
         workerFailureListener,
-        makeTaskContext(querySpec, queryKernelConfig, task.getContext()),
+        makeTaskContext(querySpec, queryKernelConfig, taskContext),
         // 10 minutes +- 2 minutes jitter
         TimeUnit.SECONDS.toMillis(600 + 
ThreadLocalRandom.current().nextInt(-4, 5) * 30L),
         new MSQWorkerTaskLauncherConfig()
@@ -245,7 +255,7 @@ public class IndexerControllerContext implements 
ControllerContext
         // Assume tasks are symmetric: workers have the same number of 
processors available as a controller.
         // Create one partition per processor per task, for maximum 
parallelism.
         MultiStageQueryContext.getTargetPartitionsPerWorkerWithDefault(
-            querySpec.getQuery().context(),
+            querySpec.getContext(),
             memoryIntrospector.numProcessingThreads()
         )
     );
@@ -259,7 +269,7 @@ public class IndexerControllerContext implements 
ControllerContext
       final ControllerMemoryParameters memoryParameters
   )
   {
-    final QueryContext queryContext = querySpec.getQuery().context();
+    final QueryContext queryContext = querySpec.getContext();
     final int maxConcurrentStages =
         MultiStageQueryContext.getMaxConcurrentStagesWithDefault(queryContext, 
DEFAULT_MAX_CONCURRENT_STAGES);
     final boolean isFaultToleranceEnabled = 
MultiStageQueryContext.isFaultToleranceEnabled(queryContext);
@@ -313,7 +323,7 @@ public class IndexerControllerContext implements 
ControllerContext
       final int maxConcurrentStages
   )
   {
-    final QueryContext queryContext = querySpec.getQuery().context();
+    final QueryContext queryContext = querySpec.getContext();
     final long maxParseExceptions = 
MultiStageQueryContext.getMaxParseExceptions(queryContext);
     final boolean removeNullBytes = 
MultiStageQueryContext.removeNullBytes(queryContext);
     final boolean includeAllCounters = 
MultiStageQueryContext.getIncludeAllCounters(queryContext);
@@ -322,7 +332,7 @@ public class IndexerControllerContext implements 
ControllerContext
     builder
         .put(MultiStageQueryContext.CTX_DURABLE_SHUFFLE_STORAGE, 
durableStorageEnabled)
         .put(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED, maxParseExceptions)
-        .put(MultiStageQueryContext.CTX_IS_REINDEX, 
MSQControllerTask.isReplaceInputDataSourceTask(querySpec))
+        .put(MultiStageQueryContext.CTX_IS_REINDEX, 
MSQControllerTask.isReplaceInputDataSourceTask(querySpec.getQuery(), 
querySpec.getDestination()))
         .put(MultiStageQueryContext.CTX_MAX_CONCURRENT_STAGES, 
maxConcurrentStages)
         .put(MultiStageQueryContext.CTX_REMOVE_NULL_BYTES, removeNullBytes)
         .put(MultiStageQueryContext.CTX_INCLUDE_ALL_COUNTERS, 
includeAllCounters);
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
index 6e53542c3a4..d6d755b1fa9 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
@@ -40,9 +40,12 @@ import 
org.apache.druid.indexing.common.actions.TaskActionClient;
 import org.apache.druid.indexing.common.actions.TimeChunkLockTryAcquireAction;
 import org.apache.druid.indexing.common.config.TaskConfig;
 import org.apache.druid.indexing.common.task.AbstractTask;
+import org.apache.druid.indexing.common.task.IndexTaskUtils;
 import org.apache.druid.indexing.common.task.PendingSegmentAllocatingTask;
 import org.apache.druid.indexing.common.task.Tasks;
 import org.apache.druid.java.util.common.logger.Logger;
+import org.apache.druid.java.util.emitter.service.ServiceMetricEvent;
+import org.apache.druid.java.util.emitter.service.ServiceMetricEvent.Builder;
 import org.apache.druid.msq.exec.Controller;
 import org.apache.druid.msq.exec.ControllerContext;
 import org.apache.druid.msq.exec.ControllerImpl;
@@ -54,6 +57,7 @@ import 
org.apache.druid.msq.indexing.destination.ExportMSQDestination;
 import org.apache.druid.msq.indexing.destination.MSQDestination;
 import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
 import org.apache.druid.msq.util.MultiStageQueryContext;
+import org.apache.druid.query.Query;
 import org.apache.druid.query.QueryContext;
 import org.apache.druid.query.QueryContexts;
 import org.apache.druid.rpc.ServiceClientFactory;
@@ -266,8 +270,14 @@ public class MSQControllerTask extends AbstractTask 
implements ClientTaskQuery,
         injector.getInstance(Key.get(ServiceClientFactory.class, 
EscalatedGlobal.class));
     final OverlordClient overlordClient = 
injector.getInstance(OverlordClient.class)
                                                   
.withRetryPolicy(StandardRetryPolicy.unlimited());
+    Builder metricBuilder = new ServiceMetricEvent.Builder();
+    IndexTaskUtils.setTaskDimensions(metricBuilder, this);
     final ControllerContext context = new IndexerControllerContext(
-        this,
+        this.getTaskLockType(),
+        this.getDataSource(),
+        this.getQuerySpec().getContext(),
+        this.getContext(),
+        metricBuilder,
         toolbox,
         injector,
         clientFactory,
@@ -316,7 +326,7 @@ public class MSQControllerTask extends AbstractTask 
implements ClientTaskQuery,
               // Use the task context and override with the query context
               QueryContexts.override(
                   getContext(),
-                  querySpec.getQuery().getContext()
+                  querySpec.getContext().asMap()
               )
           ),
           ((DataSourceMSQDestination) 
querySpec.getDestination()).isReplaceTimeChunks()
@@ -351,42 +361,50 @@ public class MSQControllerTask extends AbstractTask 
implements ClientTaskQuery,
    */
   public static boolean isIngestion(final MSQSpec querySpec)
   {
-    return querySpec.getDestination() instanceof DataSourceMSQDestination;
+    return isIngestion(querySpec.getDestination());
+  }
+
+  /**
+   * Checks whether the task is an ingestion into a Druid datasource.
+   */
+  public static boolean isIngestion(MSQDestination destination)
+  {
+    return destination instanceof DataSourceMSQDestination;
   }
 
   /**
    * Checks whether the task is an export into external files.
    */
-  public static boolean isExport(final MSQSpec querySpec)
+  public static boolean isExport(MSQDestination destination)
   {
-    return querySpec.getDestination() instanceof ExportMSQDestination;
+    return destination instanceof ExportMSQDestination;
   }
 
   /**
    * Checks whether the task is an async query which writes frame files 
containing the final results into durable storage.
    */
-  public static boolean writeFinalStageResultsToDurableStorage(final MSQSpec 
querySpec)
+  public static boolean writeFinalStageResultsToDurableStorage(final 
MSQDestination destination)
   {
-    return querySpec.getDestination() instanceof DurableStorageMSQDestination;
+    return destination instanceof DurableStorageMSQDestination;
   }
 
   /**
    * Checks whether the task is an async query which writes frame files 
containing the final results into durable storage.
    */
-  public static boolean writeFinalResultsToTaskReport(final MSQSpec querySpec)
+  public static boolean writeFinalResultsToTaskReport(final MSQDestination 
destination)
   {
-    return querySpec.getDestination() instanceof TaskReportMSQDestination;
+    return destination instanceof TaskReportMSQDestination;
   }
 
   /**
    * Returns true if the task reads from the same table as the destination. In 
this case, we would prefer to fail
    * instead of reading any unused segments to ensure that old data is not 
read.
    */
-  public static boolean isReplaceInputDataSourceTask(MSQSpec querySpec)
+  public static boolean isReplaceInputDataSourceTask(Query<?> query, 
MSQDestination destination)
   {
-    if (isIngestion(querySpec)) {
-      final String targetDataSource = ((DataSourceMSQDestination) 
querySpec.getDestination()).getDataSource();
-      final Set<String> sourceTableNames = 
querySpec.getQuery().getDataSource().getTableNames();
+    if (isIngestion(destination)) {
+      final String targetDataSource = ((DataSourceMSQDestination) 
destination).getDataSource();
+      final Set<String> sourceTableNames = 
query.getDataSource().getTableNames();
       return sourceTableNames.contains(targetDataSource);
     } else {
       return false;
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQSpec.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQSpec.java
index 4bb4e32754f..98e2a060d26 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQSpec.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQSpec.java
@@ -26,6 +26,7 @@ import 
org.apache.druid.msq.indexing.destination.MSQDestination;
 import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
 import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
 import org.apache.druid.query.Query;
+import org.apache.druid.query.QueryContext;
 import org.apache.druid.sql.calcite.planner.ColumnMappings;
 
 import java.util.Map;
@@ -66,6 +67,11 @@ public class MSQSpec
     return query;
   }
 
+  public QueryContext getContext()
+  {
+    return query.context();
+  }
+
   @JsonProperty("columnMappings")
   public ColumnMappings getColumnMappings()
   {
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java
index 5462b991737..ae25810cdac 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java
@@ -301,7 +301,7 @@ public class MSQTaskQueryMaker implements QueryMaker
                .tuningConfig(new MSQTuningConfig(maxNumWorkers, 
maxRowsInMemory, rowsPerSegment, maxNumSegments, indexSpec))
                .build();
 
-    MSQTaskQueryMakerUtils.validateRealtimeReindex(querySpec);
+    MSQTaskQueryMakerUtils.validateRealtimeReindex(querySpec.getContext(), 
querySpec.getDestination(), querySpec.getQuery());
 
     return querySpec.withOverriddenContext(nativeQueryContext);
   }
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java
index e26969c6761..dcaa447539b 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java
@@ -665,7 +665,6 @@ public class SqlStatementResource
 
     MSQControllerTask msqControllerTask = (MSQControllerTask) 
taskPayloadResponse.getPayload();
     String queryUser = String.valueOf(msqControllerTask.getQuerySpec()
-                                                       .getQuery()
                                                        .getContext()
                                                        
.get(MSQTaskQueryMaker.USER_KEY));
 
@@ -721,7 +720,7 @@ public class SqlStatementResource
     if (resultFormatParam == null) {
       return QueryContexts.getAsEnum(
           RESULT_FORMAT,
-          msqSpec.getQuery().context().get(RESULT_FORMAT),
+          msqSpec.getContext().get(RESULT_FORMAT),
           ResultFormat.class,
           ResultFormat.DEFAULT_RESULT_FORMAT
       );
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java
index 36c90a21f00..838c438cbf4 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java
@@ -25,8 +25,10 @@ import org.apache.druid.error.InvalidSqlInput;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.msq.exec.SegmentSource;
 import org.apache.druid.msq.indexing.MSQControllerTask;
-import org.apache.druid.msq.indexing.MSQSpec;
 import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination;
+import org.apache.druid.msq.indexing.destination.MSQDestination;
+import org.apache.druid.query.Query;
+import org.apache.druid.query.QueryContext;
 
 import java.util.List;
 import java.util.Set;
@@ -96,18 +98,19 @@ public class MSQTaskQueryMakerUtils
    * Validates that a query does not read from a datasource that it is 
ingesting data into, if realtime segments are
    * being queried.
    */
-  public static void validateRealtimeReindex(final MSQSpec querySpec)
+  public static void validateRealtimeReindex(QueryContext context, 
MSQDestination destination, Query<?> query)
   {
-    final SegmentSource segmentSources = 
MultiStageQueryContext.getSegmentSources(querySpec.getQuery().context());
-    if (MSQControllerTask.isReplaceInputDataSourceTask(querySpec) && 
SegmentSource.REALTIME.equals(segmentSources)) {
+    final SegmentSource segmentSources = 
MultiStageQueryContext.getSegmentSources(context);
+    if (MSQControllerTask.isReplaceInputDataSourceTask(query, destination) && 
SegmentSource.REALTIME.equals(segmentSources)) {
       throw DruidException.forPersona(DruidException.Persona.USER)
                           .ofCategory(DruidException.Category.INVALID_INPUT)
                           .build("Cannot ingest into datasource[%s] since it 
is also being queried from, with "
                                  + "REALTIME segments included. Ingest to a 
different datasource, or disable querying "
                                  + "of realtime segments by modifying [%s] in 
the query context.",
-                                 ((DataSourceMSQDestination) 
querySpec.getDestination()).getDataSource(),
+                                 ((DataSourceMSQDestination) 
destination).getDataSource(),
                                  
MultiStageQueryContext.CTX_INCLUDE_SEGMENT_SOURCE
                           );
     }
   }
+
 }
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartControllerContextTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartControllerContextTest.java
index 0bf61054f32..5da56dde8bb 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartControllerContextTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartControllerContextTest.java
@@ -87,9 +87,8 @@ public class DartControllerContextTest
     mockCloser = MockitoAnnotations.openMocks(this);
     memoryIntrospector = new MemoryIntrospectorImpl(100_000_000, 0.75, 1, 1, 
null);
     Mockito.when(serverView.getDruidServerMetadatas()).thenReturn(SERVERS);
-    Mockito.when(querySpec.getQuery()).thenReturn(query);
     
Mockito.when(querySpec.getDestination()).thenReturn(TaskReportMSQDestination.instance());
-    Mockito.when(query.context()).thenReturn(queryContext);
+    Mockito.when(querySpec.getContext()).thenReturn(queryContext);
   }
 
   @AfterEach
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java
index 590a086c8e4..903ff1544c3 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java
@@ -104,7 +104,7 @@ public class MSQTestOverlordServiceClient extends 
NoopOverlordClient
           workerMemoryParameters,
           loadedSegmentMetadata,
           cTask.getTaskLockType(),
-          cTask.getQuerySpec().getQuery().context()
+          cTask.getQuerySpec().getContext()
       );
 
       inMemoryControllerTask.put(cTask.getId(), cTask);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to