Author: Aart Bik
Date: 2023-11-30T14:19:02-08:00
New Revision: 6fb7c2d713587a061cd281eda917746750559380

URL: https://github.com/llvm/llvm-project/commit/6fb7c2d713587a061cd281eda917746750559380
DIFF: https://github.com/llvm/llvm-project/commit/6fb7c2d713587a061cd281eda917746750559380.diff

LOG: [mlir][sparse] bug fix on all-dense lex insertion (#73987)

Fixes a bug that appended values after insertion completed.
Also slight optimization by avoiding all-Dense computation
for every lexInsert call

Added: 


Modified: 
    mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
    mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp

Removed: 



################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 19c49e6c487df..01c5f2382ffe6 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -186,6 +186,7 @@ class SparseTensorStorageBase {
 
 protected:
   const MapRef map; // non-owning pointers into dim2lvl/lvl2dim vectors
+  const bool allDense;
 };
 
 /// A memory-resident sparse tensor using a storage scheme based on
@@ -293,8 +294,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// Partially specialize lexicographical insertions based on template types.
   void lexInsert(const uint64_t *lvlCoords, V val) final {
     assert(lvlCoords);
-    bool allDense = std::all_of(getLvlTypes().begin(), getLvlTypes().end(),
-                                [](LevelType lt) { return isDenseLT(lt); });
     if (allDense) {
       uint64_t lvlRank = getLvlRank();
       uint64_t valIdx = 0;
@@ -363,10 +362,12 @@
 
   /// Finalizes lexicographic insertions.
   void endLexInsert() final {
-    if (values.empty())
-      finalizeSegment(0);
-    else
-      endPath(0);
+    if (!allDense) {
+      if (values.empty())
+        finalizeSegment(0);
+      else
+        endPath(0);
+    }
   }
 
   /// Allocates a new COO object and initializes it with the contents.
@@ -705,7 +706,6 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
   // we reserve position/coordinate space based on all previous dense
   // levels, which works well up to first sparse level; but we should
   // really use nnz and dense/sparse distribution.
-  bool allDense = true;
   uint64_t sz = 1;
   for (uint64_t l = 0; l < lvlRank; l++) {
     if (isCompressedLvl(l)) {
@@ -713,23 +713,19 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
       positions[l].push_back(0);
       coordinates[l].reserve(sz);
       sz = 1;
-      allDense = false;
     } else if (isLooseCompressedLvl(l)) {
       positions[l].reserve(2 * sz + 1); // last one unused
       positions[l].push_back(0);
       coordinates[l].reserve(sz);
       sz = 1;
-      allDense = false;
     } else if (isSingletonLvl(l)) {
       coordinates[l].reserve(sz);
       sz = 1;
-      allDense = false;
     } else if (is2OutOf4Lvl(l)) {
-      assert(allDense && l == lvlRank - 1 && "unexpected 2:4 usage");
+      assert(l == lvlRank - 1 && "unexpected 2:4 usage");
       sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
       coordinates[l].reserve(sz);
       values.reserve(sz);
-      allDense = false;
     } else { // Dense level.
       assert(isDenseLvl(l));
       sz = detail::checkedMul(sz, lvlSizes[l]);

diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 7f8f76f8ec189..0c7b3a228a65c 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -17,6 +17,13 @@
 
 using namespace mlir::sparse_tensor;
 
+static inline bool isAllDense(uint64_t lvlRank, const LevelType *lvlTypes) {
+  for (uint64_t l = 0; l < lvlRank; l++)
+    if (!isDenseLT(lvlTypes[l]))
+      return false;
+  return true;
+}
+
 SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const LevelType *lvlTypes,
@@ -26,15 +33,16 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
       lvlTypes(lvlTypes, lvlTypes + lvlRank),
       dim2lvlVec(dim2lvl, dim2lvl + lvlRank),
       lvl2dimVec(lvl2dim, lvl2dim + dimRank),
-      map(dimRank, lvlRank, dim2lvlVec.data(), lvl2dimVec.data()) {
+      map(dimRank, lvlRank, dim2lvlVec.data(), lvl2dimVec.data()),
+      allDense(isAllDense(lvlRank, lvlTypes)) {
   assert(dimSizes && lvlSizes && lvlTypes && dim2lvl && lvl2dim);
   // Validate dim-indexed parameters.
   assert(dimRank > 0 && "Trivial shape is unsupported");
-  for (uint64_t d = 0; d < dimRank; ++d)
+  for (uint64_t d = 0; d < dimRank; d++)
     assert(dimSizes[d] > 0 && "Dimension size zero has trivial storage");
   // Validate lvl-indexed parameters.
   assert(lvlRank > 0 && "Trivial shape is unsupported");
-  for (uint64_t l = 0; l < lvlRank; ++l) {
+  for (uint64_t l = 0; l < lvlRank; l++) {
     assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
     assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
            isSingletonLvl(l) || is2OutOf4Lvl(l));

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits