junrushao1994 commented on a change in pull request #5585: URL: https://github.com/apache/incubator-tvm/pull/5585#discussion_r426207287
########## File path: include/tvm/runtime/container.h ########## @@ -189,6 +189,759 @@ class InplaceArrayBase { } }; +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template <typename Converter, typename TIter> +class IterAdapter { + public: + using difference_type = typename std::iterator_traits<TIter>::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) + using iterator_category = typename std::iterator_traits<TIter>::iterator_category; + + explicit IterAdapter(TIter iter) : iter_(iter) {} + IterAdapter& operator++() { + ++iter_; + return *this; + } + IterAdapter& operator--() { + --iter_; + return *this; + } + IterAdapter& operator++(int) { + IterAdapter copy = *this; + ++iter_; + return copy; + } + IterAdapter& operator--(int) { + IterAdapter copy = *this; + --iter_; + return copy; + } + + IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } + + template <typename T = IterAdapter> + typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value, + typename T::difference_type>::type inline + operator-(const IterAdapter& rhs) const { + return iter_ - rhs.iter_; + } + + bool operator==(IterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(IterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template <typename Converter, typename TIter> +class ReverseIterAdapter { + public: + using difference_type = typename std::iterator_traits<TIter>::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) + using iterator_category = typename std::iterator_traits<TIter>::iterator_category; + + explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} + ReverseIterAdapter& operator++() { + --iter_; + return *this; + } + ReverseIterAdapter& operator--() { + ++iter_; + return *this; + } + ReverseIterAdapter& operator++(int) { + ReverseIterAdapter copy = *this; + --iter_; + return copy; + } + ReverseIterAdapter& operator--(int) { + ReverseIterAdapter copy = *this; + ++iter_; + return copy; + } + ReverseIterAdapter operator+(difference_type offset) const { + return ReverseIterAdapter(iter_ - offset); + } + + template <typename T = ReverseIterAdapter> + typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value, + typename T::difference_type>::type inline + operator-(const ReverseIterAdapter& rhs) const { + return rhs.iter_ - iter_; + } + + bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +/*! \brief array node content in array */ +class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> { + public: + /*! \return The size of the array */ + size_t size() const { return this->size_; } + + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const ObjectRef at(int64_t i) const { return this->operator[](i); } + + /*! \return begin constant iterator */ + const ObjectRef* begin() const { return static_cast<ObjectRef*>(InplaceArrayBase::AddressOf(0)); } + + /*! \return end constant iterator */ + const ObjectRef* end() const { return begin() + size_; } + + /*! \brief Release reference to all the elements */ + void clear() { + for (ObjectRef* itr = MutableEnd(); size_; --size_) { + (--itr)->ObjectRef::~ObjectRef(); + } + } + + /*! + * \brief Set i-th element of the array in-place + * \param i The index + * \param item The value to be set + */ + void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); } + + /*! + * \brief Constructs a container and copy from another + * \param cap The capacity of the container + * \param from Source of the copy + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr<ArrayNode> CopyFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + CHECK_GE(cap, size) << "ValueError: not enough capacity"; + ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap); + ObjectRef* write = p->MutableBegin(); + ObjectRef* read = from->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) ObjectRef(*read++); + } + return p; + } + + /*! + * \brief Constructs a container and move from another + * \param cap The capacity of the container + * \param from Source of the move + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr<ArrayNode> MoveFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + CHECK_GE(cap, size) << "ValueError: not enough capacity"; + ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap); + ObjectRef* write = p->MutableBegin(); + ObjectRef* read = from->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) ObjectRef(std::move(*read++)); + } + from->size_ = 0; + return p; + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr<ArrayNode> CreateRepeated(int64_t n, const ObjectRef& val) { + ObjectPtr<ArrayNode> p = ArrayNode::Empty(n); + ObjectRef* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < n; ++i) { + new (itr++) ObjectRef(val); + } + return p; + } + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray; + static constexpr const char* _type_key = "Array"; + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); + + private: + /*! \return Size of initialized memory, used by InplaceArrayBase. */ + size_t GetSize() const { return this->size_; } + + /*! \return begin mutable iterator */ + ObjectRef* MutableBegin() const { + return static_cast<ObjectRef*>(InplaceArrayBase::AddressOf(0)); + } + + /*! \return end mutable iterator */ + ObjectRef* MutableEnd() const { return MutableBegin() + size_; } + + /*! + * \brief Create an ArrayNode with the given capacity. + * \param n Required capacity + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr<ArrayNode> Empty(int64_t n = kInitSize) { + CHECK_GE(n, 0); + ObjectPtr<ArrayNode> p = make_inplace_array_object<ArrayNode, ObjectRef>(n); + p->capacity_ = n; + p->size_ = 0; + return p; + } + + /*! + * \brief Assign the elements starting idx from [first, last) + * \param idx The starting point + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return Self + */ + template <typename IterType> + ArrayNode* AssignRange(int64_t idx, IterType first, IterType last) { + ObjectRef* itr = MutableBegin() + idx; + for (; first != last; ++first) { + ObjectRef ref = *first; + new (itr++) ObjectRef(std::move(ref)); + } + return this; + } + + /*! + * \brief Move elements from right to left, requires src > dst + * \param dst Destination + * \param src Source + * \param numel Number of elements to be moved + * \return Self + */ + ArrayNode* MoveElementsLeft(int64_t dst, int64_t src, int64_t numel) { + ObjectRef* from = MutableBegin() + src; + ObjectRef* to = MutableBegin() + dst; + for (int64_t i = 0; i < numel; ++i) { + *to++ = std::move(*from++); + } + return this; + } + + /*! + * \brief Move elements from left to rigth, requires src < dst + * \param dst Destination + * \param src Source + * \param numel Number of elements to be moved + * \return Self + */ + ArrayNode* MoveElementsRight(int64_t dst, int64_t src, int64_t numel) { + ObjectRef* from = MutableBegin() + (src + numel); + ObjectRef* to = MutableBegin() + (dst + numel); + for (int64_t i = 0; i < numel; ++i) { + *--to = std::move(*--from); + } + return this; + } + + /*! + * \brief Enlarge the size of the array + * \param delta Size enlarged, should be positive + * \param val Default value + * \return Self + */ + ArrayNode* Enlarge(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) { + ObjectRef* itr = MutableEnd(); + while (delta-- > 0) { + new (itr++) ObjectRef(val); + ++size_; + } + return this; + } + + /*! + * \brief Shrinks the size of the array + * \param delta Size shrinked, should be positive + * \return Self + */ + ArrayNode* Shrink(int64_t delta) { + ObjectRef* itr = MutableEnd(); + while (delta-- > 0) { + (--itr)->ObjectRef::~ObjectRef(); + --size_; + } + return this; + } + + /*! \brief Number of elements used */ + int64_t size_; + + /*! \brief Number of elements allocated */ + int64_t capacity_; + + /*! \brief Initial size of ArrayNode */ + static const constexpr int64_t kInitSize = 16; + + /*! \brief Expansion factor of the Array */ + static const constexpr int64_t kIncFactor = 2; + + // CRTP parent class + friend InplaceArrayBase<ArrayNode, ObjectRef>; + + // Reference class + template <typename, typename> + friend class Array; + + // To specialize make_object<ArrayNode> + friend ObjectPtr<ArrayNode> make_object<>(); +}; + +/*! + * \brief Array container of ObjectRef in DSL graph. + * Array implements copy-on-write semantics, which means array is mutable + * but copy will happen when array is referenced in more than two places. + * + * operator[] only provide const access, use Set to mutate the content. + * \tparam T The content ObjectRef type. + */ +template <typename T, + typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type> +class Array : public ObjectRef { + public: + // constructors + /*! + * \brief default constructor + */ + Array() { data_ = ArrayNode::Empty(); } + + /*! + * \brief move constructor + * \param other source + */ + Array(Array<T>&& other) : ObjectRef() { // NOLINT(*) + data_ = std::move(other.data_); + } + + /*! + * \brief copy constructor + * \param other source + */ + Array(const Array<T>& other) : ObjectRef() { // NOLINT(*) + data_ = other.data_; + } + + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Array(ObjectPtr<Object> n) : ObjectRef(n) {} + + /*! + * \brief Constructor from iterator + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template <typename IterType> + Array(IterType first, IterType last) { + Assign(first, last); + } + + /*! + * \brief constructor from initializer list + * \param init The initializer list + */ + Array(std::initializer_list<T> init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief constructor from vector + * \param init The vector + */ + Array(const std::vector<T>& init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + */ + explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); } + + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array<T>& operator=(Array<T>&& other) { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array<T>& operator=(const Array<T>& other) { + data_ = other.data_; + return *this; + } + + public: + // iterators + struct ValueConverter { + using ResultType = T; + static T convert(const ObjectRef& n) { return DowncastNoCheck<T>(n); } + }; + + using iterator = IterAdapter<ValueConverter, const ObjectRef*>; + using reverse_iterator = ReverseIterAdapter<ValueConverter, const ObjectRef*>; + + /*! \return begin iterator */ + iterator begin() const { return iterator(GetArrayNode()->begin()); } + + /*! \return end iterator */ + iterator end() const { return iterator(GetArrayNode()->end()); } + + /*! \return rbegin iterator */ + reverse_iterator rbegin() const { return reverse_iterator(GetArrayNode()->end() - 1); } + + /*! \return rend iterator */ + reverse_iterator rend() const { return reverse_iterator(GetArrayNode()->begin() - 1); } + + public: + // const methods in std::vector + /*! + * \brief Immutably read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const T operator[](int64_t i) const { + ArrayNode* p = GetArrayNode(); + CHECK(p != nullptr) << "ValueError: cannot index a null array"; + CHECK(0 <= i && i < p->size_) << "IndexError: array index out of bound, indexing " << i + << " on an array of size " << p->size_; + return DowncastNoCheck<T>(*(p->begin() + i)); + } + + /*! \return The size of the array */ + size_t size() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 0 : GetArrayNode()->size_; + } + + /*! \return The capacity of the array */ + size_t capacity() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 0 : GetArrayNode()->capacity_; + } + + /*! \return Whether array is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the array */ + const T& front() const { + ArrayNode* p = GetArrayNode(); + CHECK(p != nullptr) << "ValueError: cannot index a null array"; + CHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; + return *p->begin(); + } + + /*! \return The last element of the array */ + const T& back() const { + ArrayNode* p = GetArrayNode(); + CHECK(p != nullptr) << "ValueError: cannot index a null array"; + CHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; + return *(p->end() - 1); + } + + public: + // mutation in std::vector, implements copy-on-write + + /*! + * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + void push_back(const T& item) { + ArrayNode* p = CopyOnWrite(1); + p->EmplaceInit(p->size_++, item); + } + + /*! + * \brief Insert an element into the given position + * \param position An iterator pointing to the insertion point + * \param val The element to insert + */ + void insert(iterator position, const T& val) { + CHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + auto addr = CopyOnWrite(1) // + ->Enlarge(1) // + ->MoveElementsRight(idx + 1, idx, size - idx) // + ->MutableBegin(); + new (addr + idx) ObjectRef(val); + } + + /*! + * \brief Insert a range of elements into the given position + * \param position An iterator pointing to the insertion point + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + template <typename IterType> + void insert(iterator position, IterType first, IterType last) { + if (first == last) { + return; + } + CHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + int64_t numel = std::distance(first, last); + CopyOnWrite(numel) + ->Enlarge(numel) + ->MoveElementsRight(idx + numel, idx, size - idx) + ->AssignRange(idx, first, last); + } + + /*! \brief Remove the last item of the list */ + void pop_back() { + CHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null"; + int64_t size = GetArrayNode()->size_; + CHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty"; + CopyOnWrite()->Shrink(1); + } + + /*! + * \brief Erase an element on the given position + * \param position An iterator pointing to the element to be erased + */ + void erase(iterator position) { + CHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; + int64_t st = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + CHECK(0 <= st && st < size) << "ValueError: cannot erase at index " << st + << ", because Array size is " << size; + CopyOnWrite() // + ->MoveElementsLeft(st, st + 1, size - st - 1) // + ->Shrink(1); + } + + /*! + * \brief Erase a given range of elements + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + void erase(iterator first, iterator last) { + if (first == last) { + return; + } + CHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; + int64_t size = GetArrayNode()->size_; + int64_t st = std::distance(begin(), first); + int64_t ed = std::distance(begin(), last); + CHECK_LT(st, ed) << "ValueError: cannot erase array in range [" << st << ", " << ed << ")"; + CHECK(0 <= st && st <= size && 0 <= ed && ed <= size) + << "ValueError: cannot erase array in range [" << st << ", " << ed << ")" + << ", because array size is " << size; + CopyOnWrite() // + ->MoveElementsLeft(st, ed, size - ed) // + ->Shrink(ed - st); + } + + /*! + * \brief Resize the array. + * \param n The new size. + */ + void resize(int64_t n) { + CHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size"; + if (data_ == nullptr) { + SwitchContainer(n); + return; + } + int64_t size = GetArrayNode()->size_; + if (size < n) { + CopyOnWrite(n - size)->Enlarge(n - size); + } else if (size > n) { + CopyOnWrite()->Shrink(size - n); + } + } + + /*! + * \brief Make sure the list has the capacity of at least n + * \param n lower bound of the capacity + */ + void reserve(int64_t n) { + if (data_ == nullptr || n > GetArrayNode()->capacity_) { + SwitchContainer(n); + } + } + + /*! \brief Release reference to all the elements */ + void clear() { + if (data_ != nullptr) { + ArrayNode* p = CopyOnWrite(); + p->clear(); + } + } + + public: + // Array's own methods + + /*! + * \brief set i-th element of the array. + * \param i The index + * \param value The value to be setted. + */ + void Set(int64_t i, const T& value) { + ArrayNode* p = this->CopyOnWrite(); + (*p)[i] = value; + } + + /*! \return The underlying ArrayNode */ + ArrayNode* GetArrayNode() const { return static_cast<ArrayNode*>(data_.get()); } + + /*! + * \brief Helper function to apply fmutate to mutate an array. + * \param fmutate The transformation function T -> T. + * \tparam F the type of the mutation function. + * \note This function performs copy on write optimization. + */ + template <typename F> + void MutateByApply(F fmutate) { + if (data_ == nullptr) { + return; + } + ArrayNode* p = GetArrayNode(); + ObjectRef* iter = p->MutableBegin(); + if (data_.unique()) { + for (int64_t i = 0; i < p->size_; ++i) { + T old_elem = DowncastNoCheck<T>(std::move(*iter)); + T new_elem = fmutate(std::move(old_elem)); + *iter++ = std::move(new_elem); + } + } else { + for (int64_t i = 0; i < p->size_; ++i) { + T old_elem = DowncastNoCheck<T>(*iter); + T new_elem = fmutate(old_elem); + // do nothing until the first real mutation begins. + if (new_elem.same_as(*iter++)) { + continue; + } + ObjectPtr<ArrayNode> copy = ArrayNode::Empty(p->capacity_); + // copy [0, i) + int64_t& size = copy->size_ = 0; + for (ObjectRef* old_iter = p->MutableBegin(); size < i;) { + copy->EmplaceInit(size++, *old_iter++); + } + // copy [i] + copy->EmplaceInit(size++, std::move(new_elem)); + // complete transformation in (i, size) + while (size < p->size_) { + T old_elem = DowncastNoCheck<T>(*iter++); + T new_elem = fmutate(old_elem); + copy->EmplaceInit(size++, std::move(new_elem)); + } + data_ = std::move(copy); + break; + } + } + } + + /*! + * \brief reset the array to content from iterator. + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template <typename IterType> + void Assign(IterType first, IterType last) { Review comment: Hmmm I think we should give a better name than `AssignRange`. `InitRange` sounds better ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org