Github user zuyu commented on a diff in the pull request: https://github.com/apache/incubator-quickstep/pull/309#discussion_r143527518 --- Diff: query_execution/ProbabilityStore.hpp --- @@ -0,0 +1,274 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + **/ + +#ifndef QUICKSTEP_QUERY_EXECUTION_PROBABILITY_STORE_HPP_ +#define QUICKSTEP_QUERY_EXECUTION_PROBABILITY_STORE_HPP_ + +#include <algorithm> +#include <cstddef> +#include <random> +#include <unordered_map> +#include <vector> + +#include "utility/Macros.hpp" + +#include "glog/logging.h" + +namespace quickstep { + +/** + * @brief A class that stores the probabilities of objects. We use an integer field + * called "key" to identify each object. + * A probability is expressed as a fraction. All the objects share a common denominator. + **/ +class ProbabilityStore { + public: + /** + * @brief Constructor. + **/ + ProbabilityStore() + : common_denominator_(1.0) {} + + /** + * @brief Get the number of objects in the store. + **/ + inline const std::size_t getNumObjects() const { + DCHECK_EQ(individual_probabilities_.size(), cumulative_probabilities_.size()); + return individual_probabilities_.size(); + } + + /** + * @brief Get the common denominator. + */ + inline const std::size_t getDenominator() const { + return common_denominator_; + } + + /** + * @brief Check if an object with a given key is present. + * @param key The key of the given object. + * @return True if the object is present, false otherwise. + */ + inline bool hasObject(const std::size_t key) const { + return (individual_probabilities_.find(key) != individual_probabilities_.end()); + } + + /** + * @brief Add individual (not cumulative) probability for a given object with + * updated denominator. + * + * @note This function leaves the cumulative probabilities in a consistent + * state. An alternative lazy implementation should be written if cost + * of calculating cumulative probabilities is high. + * + * @param key The key of the given object. + * @param numerator The numerator for the given object. + * @param new_denominator The updated denominator for the store. + **/ + void addOrUpdateObjectNewDenominator(const std::size_t key, + const float numerator, + const float new_denominator) { + CHECK_GT(new_denominator, 0u); + DCHECK_LE(numerator, new_denominator); + common_denominator_ = new_denominator; + addOrUpdateObjectHelper(key, numerator); + } + + /** + * @brief Add or update multiple objects with a new denominator. + * @param keys The keys to be added/updated. + * @param numerators The numerators to be added/updated. + * @param new_denominator The new denominator. + */ + void addOrUpdateObjectsNewDenominator( + const std::vector<std::size_t> &keys, + const std::vector<float> &numerators, + const float new_denominator) { + CHECK_GT(new_denominator, 0u); + common_denominator_ = new_denominator; + addOrUpdateObjectsHelper(keys, numerators); + } + + /** + * @brief Remove an object from the store. + * + * @note This function decrements the denominator with the value of the numerator being removed. + * + * @param key The key of the object to be removed. + **/ + void removeObject(const std::size_t key) { + auto individual_it = individual_probabilities_.find(key); + DCHECK(individual_it != individual_probabilities_.end()); + const float new_denominator = common_denominator_ - individual_it->second; + individual_probabilities_.erase(individual_it); + common_denominator_ = new_denominator; + updateCumulativeProbabilities(); + } + + /** + * @brief Get the individual probability (not cumulative) for an object. + * + * @param key The key of the object. + **/ + const float getIndividualProbability(const std::size_t key) const { + const auto it = individual_probabilities_.find(key); + DCHECK(it != individual_probabilities_.end()); + return it->second / common_denominator_; + } + + /** + * @brief Update the cumulative probabilities. + * + * @note This function should be called upon if there are any updates, + * additions or deletions to the individual probabilities. + * @note An efficient implementation should be written if there are large + * number of objects. + **/ + void updateCumulativeProbabilities() { + cumulative_probabilities_.clear(); + float cumulative_probability = 0; + for (const auto p : individual_probabilities_) { + cumulative_probability += p.second / common_denominator_; + cumulative_probabilities_.emplace_back(p.first, + cumulative_probability); + } + if (!cumulative_probabilities_.empty()) { + // Adjust the last cumulative probability manually to 1, so that + // floating addition related rounding issues are ignored. + cumulative_probabilities_.back().updateProbability(1); + } + } + + /** + * @brief Return a randomly chosen key. + * + * TODO(harshad) - If it is expensive to create the random device + * on every invocation of this function, make it a class variable. + * In which case, we won't be able to mark the function as const. + * + * @note The random number is uniformly chosen. + **/ + inline const std::size_t pickRandomKey() const { + std::uniform_real_distribution<float> dist(0.0, 1.0); + std::random_device rd; + const float chosen_probability = dist(rd); + return getKeyForProbability(chosen_probability); + } + + private: + class ProbabilityInfo { + public: + ProbabilityInfo(const std::size_t key, const float probability) + : key_(key), probability_(probability) { + // As GLOG doesn't provide DEBUG only checks for less than equal + // comparison for floats, we can't ensure that probability is less than + // 1.0. + } + + ProbabilityInfo(const ProbabilityInfo &other) = default; + + ProbabilityInfo& operator=(const ProbabilityInfo &other) = default; + + void updateProbability(const float new_probability) { + probability_ = new_probability; + } + + std::size_t key_; + float probability_; + }; + + /** + * @brief Get a key for a given cumulative probability. + * + * @param key_cumulative_probability The input cumulative probability. + * + * @return The object that has a cumulative probability that is greater than + * or equal to the input cumulative probability. + **/ + inline const std::size_t getKeyForProbability( --- End diff -- Remove `const`.
---