juripetersen opened a new pull request, #30:
URL: https://github.com/apache/incubator-wayang-website/pull/30
This PR provides a short .md guide that shows an exemplary usage of the
previously introduced cost model abstraction.
The guide shows how it can be utilized to predict query plan runtimes
with a pre-trained ML model.
# Using Machine Learning for query optimization in Apache Wayang (incubating)
Apache Wayang (incubating) can be customized with concrete
implementations of the `EstimatableCost` interface in order to optimize
for a desired metric. The implementation can be enabled by providing it
to a `Configuration`.
```java
public class CustomEstimatableCost implements EstimatableCost {
    /* Provide concrete implementations to match the desired cost
     * function(s) by implementing the interface methods in this class.
     */
}

public class WordCount {
    public static void main(String[] args) {
        /* Create a Wayang context and specify the platforms Wayang
         * will consider. */
        Configuration config = new Configuration();
        /* Provide an EstimatableCost implementation to the configuration. */
        config.setCostModel(new CustomEstimatableCost());
        WayangContext wayangContext = new WayangContext(config)
                .withPlugin(Java.basicPlugin())
                .withPlugin(Spark.basicPlugin());
        /* ... omitted ... */
    }
}
```
In combination with an encoding scheme and a third-party package for
loading ML models, the following example shows how to predict the
runtimes of query execution plans in Apache Wayang (incubating):
```java
public class MLCost implements EstimatableCost {

    public EstimatableCostFactory getFactory() {
        return new Factory();
    }

    public static class Factory implements EstimatableCostFactory {
        @Override public EstimatableCost makeCost() {
            return new MLCost();
        }
    }

    @Override public ProbabilisticDoubleInterval getEstimate(
            PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                    .getOptimizationContext()
                    .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            /* Encode the plan and let the model predict its cost. */
            return ProbabilisticDoubleInterval.ofExactly(
                model.runModel(OneHotEncoder.encode(plan))
            );
        } catch (Exception e) {
            return ProbabilisticDoubleInterval.zero;
        }
    }

    @Override public ProbabilisticDoubleInterval getParallelEstimate(
            PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                    .getOptimizationContext()
                    .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return ProbabilisticDoubleInterval.ofExactly(
                model.runModel(OneHotEncoder.encode(plan))
            );
        } catch (Exception e) {
            return ProbabilisticDoubleInterval.zero;
        }
    }

    /** Returns a squashed cost estimate. */
    @Override public double getSquashedEstimate(
            PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                    .getOptimizationContext()
                    .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return model.runModel(OneHotEncoder.encode(plan));
        } catch (Exception e) {
            return 0;
        }
    }

    @Override public double getSquashedParallelEstimate(
            PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                    .getOptimizationContext()
                    .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return model.runModel(OneHotEncoder.encode(plan));
        } catch (Exception e) {
            return 0;
        }
    }

    @Override public Tuple<List<ProbabilisticDoubleInterval>, List<Double>>
            getParallelOperatorJunctionAllCostEstimate(
                    PlanImplementation plan, Operator operator) {
        List<ProbabilisticDoubleInterval> intervalList = new ArrayList<>();
        List<Double> doubleList = new ArrayList<>();
        intervalList.add(this.getEstimate(plan, true));
        doubleList.add(this.getSquashedEstimate(plan, true));
        return new Tuple<>(intervalList, doubleList);
    }

    public PlanImplementation pickBestExecutionPlan(
            Collection<PlanImplementation> executionPlans,
            ExecutionPlan existingPlan,
            Set<Channel> openChannels,
            Set<ExecutionStage> executedStages) {
        /* Pick the plan implementation with the lowest squashed cost. */
        return executionPlans.stream()
                .reduce((p1, p2) -> {
                    final double t1 = p1.getSquashedCostEstimate();
                    final double t2 = p2.getSquashedCostEstimate();
                    return t1 < t2 ? p1 : p2;
                })
                .orElseThrow(() ->
                        new WayangException("Could not find an execution plan."));
    }
}
```
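The `OneHotEncoder` referenced above turns a `PlanImplementation` into the
fixed-length feature vector that the model consumes; its implementation is
not part of this guide. As a rough illustration only, a minimal sketch of
such an encoding scheme could count operator occurrences over a hand-picked
vocabulary, assuming here that `PlanImplementation` exposes its operators
via a `getOperators()` accessor:
```java
/* Hypothetical sketch of an encoding scheme; the operator vocabulary
 * and the getOperators() accessor are assumptions for illustration. */
public class OneHotEncoder {

    /* Fixed vocabulary that determines the layout of the feature vector. */
    private static final List<String> OPERATORS = Arrays.asList(
            "TextFileSource", "FlatMapOperator", "MapOperator",
            "ReduceByOperator", "LocalCallbackSink");

    public static Vector<Long> encode(PlanImplementation plan) {
        /* Start with an all-zero vector, one slot per known operator. */
        Vector<Long> encoded =
                new Vector<>(Collections.nCopies(OPERATORS.size(), 0L));
        /* Count how often each known operator occurs in the plan. */
        for (Operator operator : plan.getOperators()) {
            int index = OPERATORS.indexOf(operator.getClass().getSimpleName());
            if (index >= 0) {
                encoded.set(index, encoded.get(index) + 1);
            }
        }
        return encoded;
    }
}
```
A production encoder would also capture features such as cardinality
estimates and platform choices, but a plain count vector already matches
the `Vector<Long>` input that `runModel` below expects.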
Third-party packages such as `OnnxRuntime` can be used to load
pre-trained `.onnx` files that contain the desired ML models.
```java
public class OrtMLModel {

    private static OrtMLModel INSTANCE;

    private OrtSession session;
    private OrtEnvironment env;
    private final Map<String, OnnxTensor> inputMap = new HashMap<>();
    private final Set<String> requestedOutputs = new HashSet<>();

    public static OrtMLModel getInstance(Configuration configuration)
            throws OrtException {
        if (INSTANCE == null) {
            INSTANCE = new OrtMLModel(configuration);
        }
        return INSTANCE;
    }

    private OrtMLModel(Configuration configuration) throws OrtException {
        this.loadModel(configuration.getStringProperty("wayang.ml.model.file"));
    }

    public void loadModel(String filePath) throws OrtException {
        if (this.env == null) {
            this.env = OrtEnvironment.getEnvironment();
        }
        if (this.session == null) {
            this.session = env.createSession(filePath,
                    new OrtSession.SessionOptions());
        }
    }

    public void closeSession() throws OrtException {
        this.session.close();
        this.env.close();
    }

    /**
     * @param encodedVector the encoded query plan
     * @return the predicted cost, or NaN if the model output could not be read
     * @throws OrtException
     */
    public double runModel(Vector<Long> encodedVector) throws OrtException {
        double costPrediction;
        OnnxTensor tensor = OnnxTensor.createTensor(env, encodedVector);
        this.inputMap.put("input", tensor);
        this.requestedOutputs.add("output");
        /* Unwrap the first value of the named model output, falling back
         * to NaN if it cannot be read. */
        BiFunction<Result, String, Double> unwrapFunc = (r, s) -> {
            try {
                return ((double[]) r.get(s).get().getValue())[0];
            } catch (OrtException e) {
                return Double.NaN;
            }
        };
        try (Result r = session.run(inputMap, requestedOutputs)) {
            costPrediction = unwrapFunc.apply(r, "output");
        }
        return costPrediction;
    }
}
```
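With these pieces in place, the ML-based cost model is enabled like any
other `EstimatableCost` implementation. The following sketch assumes the
model location is supplied via the `wayang.ml.model.file` property that
`OrtMLModel` reads; the file path is a placeholder:
```java
public class MLWordCount {
    public static void main(String[] args) {
        Configuration config = new Configuration();
        /* Point OrtMLModel to the pre-trained model file (placeholder path). */
        config.setProperty("wayang.ml.model.file", "/path/to/model.onnx");
        /* Have the optimizer estimate plan costs with the ML model. */
        config.setCostModel(new MLCost());
        WayangContext wayangContext = new WayangContext(config)
                .withPlugin(Java.basicPlugin())
                .withPlugin(Spark.basicPlugin());
        /* ... build and execute a WayangPlan as usual ... */
    }
}
```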