juripetersen opened a new pull request, #30:
URL: https://github.com/apache/incubator-wayang-website/pull/30

   This PR provides a short .md guide that shows an exemplary usage of the previously introduced cost-model abstraction.
   The guide shows how the abstraction can be used to predict query-plan runtimes with a pre-trained ML model.
   
   # Using Machine Learning for query optimization in Apache Wayang (incubating)
   Apache Wayang (incubating) can be customized with concrete
   implementations of the `EstimatableCost` interface in order to optimize
   for a desired metric. The implementation can be enabled by providing it
   to a `Configuration`.
   
   ```java
   public class CustomEstimatableCost implements EstimatableCost {
       /* Provide concrete implementations to match desired cost function(s)
        * by implementing the interface in this class.
        */
   }
   public class WordCount {
       public static void main(String[] args) {
           /* Create a Wayang context and specify the platforms Wayang will consider. */
           Configuration config = new Configuration();
           /* Provide an EstimatableCost implementation for the optimizer to use. */
           config.setCostModel(new CustomEstimatableCost());
           WayangContext wayangContext = new WayangContext(config)
                   .withPlugin(Java.basicPlugin())
                   .withPlugin(Spark.basicPlugin());
           /*... omitted */
       }
   }
   ```
   
   In combination with an encoding scheme (sketched after the example) and a
   third-party package to load ML models, the following example shows how to
   predict the runtimes of query execution plans in Apache Wayang (incubating):
   
   ```java
   public class MLCost implements EstimatableCost {
       public EstimatableCostFactory getFactory() {
           return new Factory();
       }
   
       public static class Factory implements EstimatableCostFactory {
           @Override public EstimatableCost makeCost() {
               return new MLCost();
           }
       }
   
       @Override public ProbabilisticDoubleInterval getEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
           try {
               Configuration config = plan
                   .getOptimizationContext()
                   .getConfiguration();
               OrtMLModel model = OrtMLModel.getInstance(config);
   
               return ProbabilisticDoubleInterval.ofExactly(
                   model.runModel(OneHotEncoder.encode(plan))
               );
           } catch(Exception e) {
               return ProbabilisticDoubleInterval.zero;
           }
       }
   
       @Override public ProbabilisticDoubleInterval getParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
           try {
               Configuration config = plan
                   .getOptimizationContext()
                   .getConfiguration();
               OrtMLModel model = OrtMLModel.getInstance(config);
   
               return ProbabilisticDoubleInterval.ofExactly(
                   model.runModel(OneHotEncoder.encode(plan))
               );
           } catch(Exception e) {
               return ProbabilisticDoubleInterval.zero;
           }
       }
   
       /** Returns a squashed cost estimate. */
       @Override public double getSquashedEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
           try {
               Configuration config = plan
                   .getOptimizationContext()
                   .getConfiguration();
               OrtMLModel model = OrtMLModel.getInstance(config);
   
               return model.runModel(OneHotEncoder.encode(plan));
           } catch(Exception e) {
               return 0;
           }
       }
   
       @Override public double getSquashedParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
           try {
               Configuration config = plan
                   .getOptimizationContext()
                   .getConfiguration();
               OrtMLModel model = OrtMLModel.getInstance(config);
   
               return model.runModel(OneHotEncoder.encode(plan));
           } catch(Exception e) {
               return 0;
           }
       }
   
       @Override public Tuple<List<ProbabilisticDoubleInterval>, List<Double>> getParallelOperatorJunctionAllCostEstimate(PlanImplementation plan, Operator operator) {
           List<ProbabilisticDoubleInterval> intervalList = new ArrayList<>();
           List<Double> doubleList = new ArrayList<>();
           intervalList.add(this.getEstimate(plan, true));
           doubleList.add(this.getSquashedEstimate(plan, true));
   
           return new Tuple<>(intervalList, doubleList);
       }
   
       public PlanImplementation pickBestExecutionPlan(
               Collection<PlanImplementation> executionPlans,
               ExecutionPlan existingPlan,
               Set<Channel> openChannels,
               Set<ExecutionStage> executedStages) {
           final PlanImplementation bestPlanImplementation = executionPlans.stream()
                   .reduce((p1, p2) -> {
                       final double t1 = p1.getSquashedCostEstimate();
                       final double t2 = p2.getSquashedCostEstimate();
                       return t1 < t2 ? p1 : p2;
                   })
                   .orElseThrow(() -> new WayangException("Could not find an execution plan."));
           return bestPlanImplementation;
       }
   }
   ```
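   
   The `MLCost` example above delegates plan encoding to a `OneHotEncoder` that is
   not shown in this guide. The sketch below only illustrates the idea of such an
   encoding scheme: it maps a `PlanImplementation` to the `Vector<Long>` consumed by
   `OrtMLModel.runModel`. The operator vocabulary and the `getOperators()` accessor
   used for the traversal are assumptions made purely for illustration.
   
   ```java
   public class OneHotEncoder {
       /* Hypothetical, fixed vocabulary of operator types the model was trained on. */
       private static final List<String> OPERATOR_VOCABULARY = Arrays.asList(
           "TextFileSource", "FlatMapOperator", "MapOperator",
           "ReduceByOperator", "LocalCallbackSink"
       );
   
       public static Vector<Long> encode(PlanImplementation plan) {
           long[] counts = new long[OPERATOR_VOCABULARY.size()];
   
           /* Count how often each known operator type occurs in the plan;
            * getOperators() is an assumed accessor on PlanImplementation. */
           for (Operator operator : plan.getOperators()) {
               int index = OPERATOR_VOCABULARY.indexOf(operator.getClass().getSimpleName());
               if (index >= 0) {
                   counts[index]++;
               }
           }
   
           /* Wrap the counts into the Vector<Long> expected by OrtMLModel.runModel. */
           Vector<Long> encoded = new Vector<>();
           for (long count : counts) {
               encoded.add(count);
           }
           return encoded;
       }
   }
   ```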
   
   Third-party packages such as the ONNX Runtime (`ai.onnxruntime`) can be used to
   load pre-trained `.onnx` files that contain the desired ML models.
   
   ```java
   public class OrtMLModel {
   
       private static OrtMLModel INSTANCE;
   
       private OrtSession session;
       private OrtEnvironment env;
   
       private final Map<String, OnnxTensor> inputMap = new HashMap<>();
       private final Set<String> requestedOutputs = new HashSet<>();
   
       public static OrtMLModel getInstance(Configuration configuration) throws OrtException {
           if (INSTANCE == null) {
               INSTANCE = new OrtMLModel(configuration);
           }
   
           return INSTANCE;
       }
   
       private OrtMLModel(Configuration configuration) throws OrtException {
           this.loadModel(configuration.getStringProperty("wayang.ml.model.file"));
       }
   
       public void loadModel(String filePath) throws OrtException {
           if (this.env == null) {
               this.env = OrtEnvironment.getEnvironment();
           }
   
           if (this.session == null) {
               this.session = env.createSession(filePath, new OrtSession.SessionOptions());
           }
       }
   
       public void closeSession() throws OrtException {
           this.session.close();
           this.env.close();
       }
   
       /**
        * @param encodedVector the encoded feature vector of an execution plan
        * @return the predicted cost, or {@code Double.NaN} if the model output could not be read
        * @throws OrtException if the ONNX Runtime session fails
        */
       public double runModel(Vector<Long> encodedVector) throws OrtException {
           double costPrediction;
   
           OnnxTensor tensor = OnnxTensor.createTensor(env, encodedVector);
           this.inputMap.put("input", tensor);
           this.requestedOutputs.add("output");
   
           BiFunction<Result, String, Double> unwrapFunc = (r, s) -> {
               try {
                   return ((double[]) r.get(s).get().getValue())[0];
               } catch (OrtException e) {
                   return Double.NaN;
               }
           };
   
           try (Result r = session.run(inputMap, requestedOutputs)) {
               costPrediction = unwrapFunc.apply(r, "output");
           }
   
           return costPrediction;
       }
   }
   ```
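   
   To put the pieces together, point `OrtMLModel` at a pre-trained model via the
   `wayang.ml.model.file` property read in its constructor and register `MLCost`
   on the `Configuration`. The snippet below is only a sketch: the file path is
   illustrative, and it assumes `Configuration#setProperty` is available for
   setting the property.
   
   ```java
   public class WordCountWithMLCost {
       public static void main(String[] args) {
           Configuration config = new Configuration();
   
           /* Illustrative path; OrtMLModel reads this property to locate the .onnx file. */
           config.setProperty("wayang.ml.model.file", "/opt/wayang/models/cost-model.onnx");
   
           /* Let the optimizer score candidate plans with the ML-backed cost model. */
           config.setCostModel(new MLCost());
   
           WayangContext wayangContext = new WayangContext(config)
                   .withPlugin(Java.basicPlugin())
                   .withPlugin(Spark.basicPlugin());
           /* ... build and execute a WayangPlan as usual, omitted */
       }
   }
   ```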

