Author: Mehdi Amini
Date: 2024-02-15T13:26:07-08:00
New Revision: c11e879dec122a027ca9ab897fa9c6517cc3f33d
URL: https://github.com/llvm/llvm-project/commit/c11e879dec122a027ca9ab897fa9c6517cc3f33d DIFF: https://github.com/llvm/llvm-project/commit/c11e879dec122a027ca9ab897fa9c6517cc3f33d.diff LOG: Revert "[mlir][sparse] remove LevelType enum, construct LevelType from LevelF…" This reverts commit 235ec0f791749d94ac1ca1441b8b06d4ba09792c. Added: Modified: mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h mlir/lib/CAPI/Dialect/SparseTensor.cpp mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp mlir/unittests/Dialect/SparseTensor/MergerTest.cpp Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h index a20a7906189d01..74cc0dee554a17 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -153,9 +153,45 @@ enum class Action : uint32_t { kSortCOOInPlace = 8, }; +/// This enum defines all the sparse representations supportable by +/// the SparseTensor dialect. We use a lightweight encoding to encode +/// the "format" per se (dense, compressed, singleton, loose_compressed, +/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when +/// the format is NOutOfM. +/// The encoding is chosen for performance of the runtime library, and thus may +/// change in future versions; consequently, client code should use the +/// predicate functions defined below, rather than relying on knowledge +/// about the particular binary encoding. +/// +/// The `Undef` "format" is a special value used internally for cases +/// where we need to store an undefined or indeterminate `LevelType`. 
+/// It should not be used externally, since it does not indicate an +/// actual/representable format. +/// +/// Bit manipulations for LevelType: +/// +/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty | +/// +enum class LevelType : uint64_t { + Undef = 0x000000000000, + Dense = 0x000000010000, + Compressed = 0x000000020000, + CompressedNu = 0x000000020001, + CompressedNo = 0x000000020002, + CompressedNuNo = 0x000000020003, + Singleton = 0x000000040000, + SingletonNu = 0x000000040001, + SingletonNo = 0x000000040002, + SingletonNuNo = 0x000000040003, + LooseCompressed = 0x000000080000, + LooseCompressedNu = 0x000000080001, + LooseCompressedNo = 0x000000080002, + LooseCompressedNuNo = 0x000000080003, + NOutOfM = 0x000000100000, +}; + /// This enum defines all supported storage format without the level properties. enum class LevelFormat : uint64_t { - Undef = 0x00000000, Dense = 0x00010000, Compressed = 0x00020000, Singleton = 0x00040000, @@ -163,240 +199,327 @@ enum class LevelFormat : uint64_t { NOutOfM = 0x00100000, }; -template <LevelFormat... targets> -constexpr bool isAnyOfFmt(LevelFormat fmt) { - return (... || (targets == fmt)); -} - -/// Returns string representation of the given level format. -constexpr const char *toFormatString(LevelFormat lvlFmt) { - switch (lvlFmt) { - case LevelFormat::Undef: - return "undef"; - case LevelFormat::Dense: - return "dense"; - case LevelFormat::Compressed: - return "compressed"; - case LevelFormat::Singleton: - return "singleton"; - case LevelFormat::LooseCompressed: - return "loose_compressed"; - case LevelFormat::NOutOfM: - return "structured"; - } - return ""; -} - /// This enum defines all the nondefault properties for storage formats. -enum class LevelPropNonDefault : uint64_t { +enum class LevelPropertyNondefault : uint64_t { Nonunique = 0x0001, Nonordered = 0x0002, }; -/// Returns string representation of the given level properties. 
-constexpr const char *toPropString(LevelPropNonDefault lvlProp) { - switch (lvlProp) { - case LevelPropNonDefault::Nonunique: - return "nonunique"; - case LevelPropNonDefault::Nonordered: - return "nonordered"; - } - return ""; +/// Get N of NOutOfM level type. +constexpr uint64_t getN(LevelType lt) { + return (static_cast<uint64_t>(lt) >> 32) & 0xff; } -/// This enum defines all the sparse representations supportable by -/// the SparseTensor dialect. We use a lightweight encoding to encode -/// the "format" per se (dense, compressed, singleton, loose_compressed, -/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when -/// the format is NOutOfM. -/// The encoding is chosen for performance of the runtime library, and thus may -/// change in future versions; consequently, client code should use the -/// predicate functions defined below, rather than relying on knowledge -/// about the particular binary encoding. -/// -/// The `Undef` "format" is a special value used internally for cases -/// where we need to store an undefined or indeterminate `LevelType`. -/// It should not be used externally, since it does not indicate an -/// actual/representable format. - -struct LevelType { -public: - /// Check that the `LevelType` contains a valid (possibly undefined) value. - static constexpr bool isValidLvlBits(uint64_t lvlBits) { - auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000); - const uint64_t propertyBits = lvlBits & 0xffff; - // If undefined/dense/NOutOfM, then must be unique and ordered. - // Otherwise, the format must be one of the known ones. - return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense, - LevelFormat::NOutOfM>(fmt)) - ? (propertyBits == 0) - : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton, - LevelFormat::LooseCompressed>(fmt)); - } +/// Get M of NOutOfM level type. 
+constexpr uint64_t getM(LevelType lt) { + return (static_cast<uint64_t>(lt) >> 40) & 0xff; +} - /// Convert a LevelFormat to its corresponding LevelType with the given - /// properties. Returns std::nullopt when the properties are not applicable - /// for the input level format. - static std::optional<LevelType> - buildLvlType(LevelFormat lf, - const std::vector<LevelPropNonDefault> &properties, - uint64_t n = 0, uint64_t m = 0) { - assert((n & 0xff) == n && (m & 0xff) == m); - uint64_t newN = n << 32; - uint64_t newM = m << 40; - uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM; - for (auto p : properties) - ltBits |= static_cast<uint64_t>(p); - - return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits)) - : std::nullopt; - } - static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered, - bool unique, uint64_t n = 0, - uint64_t m = 0) { - std::vector<LevelPropNonDefault> properties; - if (!ordered) - properties.push_back(LevelPropNonDefault::Nonordered); - if (!unique) - properties.push_back(LevelPropNonDefault::Nonunique); - return buildLvlType(lf, properties, n, m); - } +/// Convert N of NOutOfM level type to the stored bits. +constexpr uint64_t nToBits(uint64_t n) { return n << 32; } - /// Explicit conversion from uint64_t. - constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) { - assert(isValidLvlBits(bits)); - }; +/// Convert M of NOutOfM level type to the stored bits. +constexpr uint64_t mToBits(uint64_t m) { return m << 40; } - /// Constructs a LevelType with the given format using all default properties. - /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) { - assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>()); - }; +/// Check if the `LevelType` is NOutOfM (regardless of +/// properties and block sizes). 
+constexpr bool isNOutOfMLT(LevelType lt) { + return ((static_cast<uint64_t>(lt) & 0x100000) == + static_cast<uint64_t>(LevelType::NOutOfM)); +} - /// Converts to uint64_t - explicit operator uint64_t() const { return lvlBits; } +/// Check if the `LevelType` is NOutOfM with the correct block sizes. +constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) { + return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m; +} - bool operator==(const LevelType lhs) const { - return static_cast<uint64_t>(lhs) == lvlBits; +/// Returns string representation of the given dimension level type. +constexpr const char *toMLIRString(LevelType lvlType) { + auto lt = static_cast<LevelType>(static_cast<uint64_t>(lvlType) & 0xffffffff); + switch (lt) { + case LevelType::Undef: + return "undef"; + case LevelType::Dense: + return "dense"; + case LevelType::Compressed: + return "compressed"; + case LevelType::CompressedNu: + return "compressed(nonunique)"; + case LevelType::CompressedNo: + return "compressed(nonordered)"; + case LevelType::CompressedNuNo: + return "compressed(nonunique, nonordered)"; + case LevelType::Singleton: + return "singleton"; + case LevelType::SingletonNu: + return "singleton(nonunique)"; + case LevelType::SingletonNo: + return "singleton(nonordered)"; + case LevelType::SingletonNuNo: + return "singleton(nonunique, nonordered)"; + case LevelType::LooseCompressed: + return "loose_compressed"; + case LevelType::LooseCompressedNu: + return "loose_compressed(nonunique)"; + case LevelType::LooseCompressedNo: + return "loose_compressed(nonordered)"; + case LevelType::LooseCompressedNuNo: + return "loose_compressed(nonunique, nonordered)"; + case LevelType::NOutOfM: + return "structured"; } - bool operator!=(const LevelType lhs) const { return !(*this == lhs); } + return ""; +} - LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); } +/// Check that the `LevelType` contains a valid (possibly undefined) value. 
+constexpr bool isValidLT(LevelType lt) { + const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000; + const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff; + // If undefined/dense/NOutOfM, then must be unique and ordered. + // Otherwise, the format must be one of the known ones. + return (formatBits <= 0x10000 || formatBits == 0x100000) + ? (propertyBits == 0) + : (formatBits == 0x20000 || formatBits == 0x40000 || + formatBits == 0x80000); +} - /// Get N of NOutOfM level type. - constexpr uint64_t getN() const { - assert(isa<LevelFormat::NOutOfM>()); - return (lvlBits >> 32) & 0xff; - } +/// Check if the `LevelType` is the special undefined value. +constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; } - /// Get M of NOutOfM level type. - constexpr uint64_t getM() const { - assert(isa<LevelFormat::NOutOfM>()); - return (lvlBits >> 40) & 0xff; - } +/// Check if the `LevelType` is dense (regardless of properties). +constexpr bool isDenseLT(LevelType lt) { + return (static_cast<uint64_t>(lt) & ~0xffff) == + static_cast<uint64_t>(LevelType::Dense); +} - /// Get the `LevelFormat` of the `LevelType`. - LevelFormat getLvlFmt() const { - return static_cast<LevelFormat>(lvlBits & 0xffff0000); - } +/// Check if the `LevelType` is compressed (regardless of properties). +constexpr bool isCompressedLT(LevelType lt) { + return (static_cast<uint64_t>(lt) & ~0xffff) == + static_cast<uint64_t>(LevelType::Compressed); +} - /// Check if the `LevelType` is in the `LevelFormat`. - template <LevelFormat fmt> - bool isa() const { - return getLvlFmt() == fmt; - } +/// Check if the `LevelType` is singleton (regardless of properties). 
+constexpr bool isSingletonLT(LevelType lt) { + return (static_cast<uint64_t>(lt) & ~0xffff) == + static_cast<uint64_t>(LevelType::Singleton); +} - /// Check if the `LevelType` has the properties - template <LevelPropNonDefault p> - bool isa() const { - return lvlBits & static_cast<uint64_t>(p); - } +/// Check if the `LevelType` is loose compressed (regardless of properties). +constexpr bool isLooseCompressedLT(LevelType lt) { + return (static_cast<uint64_t>(lt) & ~0xffff) == + static_cast<uint64_t>(LevelType::LooseCompressed); +} - /// Check if the `LevelType` needs positions array. - bool isWithPosLT() const { - return isa<LevelFormat::Compressed>() || - isa<LevelFormat::LooseCompressed>(); - } +/// Check if the `LevelType` needs positions array. +constexpr bool isWithPosLT(LevelType lt) { + return isCompressedLT(lt) || isLooseCompressedLT(lt); +} - /// Check if the `LevelType` needs coordinates array. - constexpr bool isWithCrdLT() const { - // All sparse levels has coordinate array. - return !isa<LevelFormat::Dense>(); - } +/// Check if the `LevelType` needs coordinates array. +constexpr bool isWithCrdLT(LevelType lt) { + return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) || + isNOutOfMLT(lt); +} - std::string toMLIRString() const { - std::string lvlStr = toFormatString(getLvlFmt()); - std::string propStr = ""; - if (isa<LevelPropNonDefault::Nonunique>()) - propStr += toPropString(LevelPropNonDefault::Nonunique); - - if (isa<LevelPropNonDefault::Nonordered>()) { - if (!propStr.empty()) - propStr += ", "; - propStr += toPropString(LevelPropNonDefault::Nonordered); - } - if (!propStr.empty()) - lvlStr += ("(" + propStr + ")"); - return lvlStr; - } +/// Check if the `LevelType` is ordered (regardless of storage format). 
+constexpr bool isOrderedLT(LevelType lt) { + return !(static_cast<uint64_t>(lt) & 2); + return !(static_cast<uint64_t>(lt) & 2); +} -private: - /// Bit manipulations for LevelType: - /// - /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty | - /// - uint64_t lvlBits; -}; +/// Check if the `LevelType` is unique (regardless of storage format). +constexpr bool isUniqueLT(LevelType lt) { + return !(static_cast<uint64_t>(lt) & 1); + return !(static_cast<uint64_t>(lt) & 1); +} -// For backward-compatibility. TODO: remove below after fully migration. -constexpr uint64_t nToBits(uint64_t n) { return n << 32; } -constexpr uint64_t mToBits(uint64_t m) { return m << 40; } +/// Convert a LevelType to its corresponding LevelFormat. +/// Returns std::nullopt when input lt is Undef. +constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) { + if (lt == LevelType::Undef) + return std::nullopt; + return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000); +} +/// Convert a LevelFormat to its corresponding LevelType with the given +/// properties. Returns std::nullopt when the properties are not applicable +/// for the input level format. inline std::optional<LevelType> buildLevelType(LevelFormat lf, - const std::vector<LevelPropNonDefault> &properties, + const std::vector<LevelPropertyNondefault> &properties, uint64_t n = 0, uint64_t m = 0) { - return LevelType::buildLvlType(lf, properties, n, m); + uint64_t newN = n << 32; + uint64_t newM = m << 40; + uint64_t ltInt = static_cast<uint64_t>(lf) | newN | newM; + for (auto p : properties) { + ltInt |= static_cast<uint64_t>(p); + } + auto lt = static_cast<LevelType>(ltInt); + return isValidLT(lt) ? 
std::optional(lt) : std::nullopt; } + inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered, bool unique, uint64_t n = 0, uint64_t m = 0) { - return LevelType::buildLvlType(lf, ordered, unique, n, m); + std::vector<LevelPropertyNondefault> properties; + if (!ordered) + properties.push_back(LevelPropertyNondefault::Nonordered); + if (!unique) + properties.push_back(LevelPropertyNondefault::Nonunique); + return buildLevelType(lf, properties, n, m); } -inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); } -inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); } -inline bool isCompressedLT(LevelType lt) { - return lt.isa<LevelFormat::Compressed>(); -} -inline bool isLooseCompressedLT(LevelType lt) { - return lt.isa<LevelFormat::LooseCompressed>(); -} -inline bool isSingletonLT(LevelType lt) { - return lt.isa<LevelFormat::Singleton>(); -} -inline bool isNOutOfMLT(LevelType lt) { return lt.isa<LevelFormat::NOutOfM>(); } -inline bool isOrderedLT(LevelType lt) { - return !lt.isa<LevelPropNonDefault::Nonordered>(); -} -inline bool isUniqueLT(LevelType lt) { - return !lt.isa<LevelPropNonDefault::Nonunique>(); -} -inline bool isWithCrdLT(LevelType lt) { return lt.isWithCrdLT(); } -inline bool isWithPosLT(LevelType lt) { return lt.isWithPosLT(); } -inline bool isValidLT(LevelType lt) { - return LevelType::isValidLvlBits(static_cast<uint64_t>(lt)); -} -inline std::optional<LevelFormat> getLevelFormat(LevelType lt) { - LevelFormat fmt = lt.getLvlFmt(); - if (fmt == LevelFormat::Undef) - return std::nullopt; - return fmt; -} -inline uint64_t getN(LevelType lt) { return lt.getN(); } -inline uint64_t getM(LevelType lt) { return lt.getM(); } -inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) { - return isNOutOfMLT(lt) && lt.getN() == n && lt.getM() == m; -} -inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); } + +// +// Ensure the above methods work as intended. 
+// + +static_assert( + (getLevelFormat(LevelType::Undef) == std::nullopt && + *getLevelFormat(LevelType::Dense) == LevelFormat::Dense && + *getLevelFormat(LevelType::Compressed) == LevelFormat::Compressed && + *getLevelFormat(LevelType::CompressedNu) == LevelFormat::Compressed && + *getLevelFormat(LevelType::CompressedNo) == LevelFormat::Compressed && + *getLevelFormat(LevelType::CompressedNuNo) == LevelFormat::Compressed && + *getLevelFormat(LevelType::Singleton) == LevelFormat::Singleton && + *getLevelFormat(LevelType::SingletonNu) == LevelFormat::Singleton && + *getLevelFormat(LevelType::SingletonNo) == LevelFormat::Singleton && + *getLevelFormat(LevelType::SingletonNuNo) == LevelFormat::Singleton && + *getLevelFormat(LevelType::LooseCompressed) == + LevelFormat::LooseCompressed && + *getLevelFormat(LevelType::LooseCompressedNu) == + LevelFormat::LooseCompressed && + *getLevelFormat(LevelType::LooseCompressedNo) == + LevelFormat::LooseCompressed && + *getLevelFormat(LevelType::LooseCompressedNuNo) == + LevelFormat::LooseCompressed && + *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM), + "getLevelFormat conversion is broken"); + +static_assert( + (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) && + isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) && + isValidLT(LevelType::CompressedNo) && + isValidLT(LevelType::CompressedNuNo) && isValidLT(LevelType::Singleton) && + isValidLT(LevelType::SingletonNu) && isValidLT(LevelType::SingletonNo) && + isValidLT(LevelType::SingletonNuNo) && + isValidLT(LevelType::LooseCompressed) && + isValidLT(LevelType::LooseCompressedNu) && + isValidLT(LevelType::LooseCompressedNo) && + isValidLT(LevelType::LooseCompressedNuNo) && + isValidLT(LevelType::NOutOfM)), + "isValidLT definition is broken"); + +static_assert((isDenseLT(LevelType::Dense) && + !isDenseLT(LevelType::Compressed) && + !isDenseLT(LevelType::CompressedNu) && + !isDenseLT(LevelType::CompressedNo) && + 
!isDenseLT(LevelType::CompressedNuNo) && + !isDenseLT(LevelType::Singleton) && + !isDenseLT(LevelType::SingletonNu) && + !isDenseLT(LevelType::SingletonNo) && + !isDenseLT(LevelType::SingletonNuNo) && + !isDenseLT(LevelType::LooseCompressed) && + !isDenseLT(LevelType::LooseCompressedNu) && + !isDenseLT(LevelType::LooseCompressedNo) && + !isDenseLT(LevelType::LooseCompressedNuNo) && + !isDenseLT(LevelType::NOutOfM)), + "isDenseLT definition is broken"); + +static_assert((!isCompressedLT(LevelType::Dense) && + isCompressedLT(LevelType::Compressed) && + isCompressedLT(LevelType::CompressedNu) && + isCompressedLT(LevelType::CompressedNo) && + isCompressedLT(LevelType::CompressedNuNo) && + !isCompressedLT(LevelType::Singleton) && + !isCompressedLT(LevelType::SingletonNu) && + !isCompressedLT(LevelType::SingletonNo) && + !isCompressedLT(LevelType::SingletonNuNo) && + !isCompressedLT(LevelType::LooseCompressed) && + !isCompressedLT(LevelType::LooseCompressedNu) && + !isCompressedLT(LevelType::LooseCompressedNo) && + !isCompressedLT(LevelType::LooseCompressedNuNo) && + !isCompressedLT(LevelType::NOutOfM)), + "isCompressedLT definition is broken"); + +static_assert((!isSingletonLT(LevelType::Dense) && + !isSingletonLT(LevelType::Compressed) && + !isSingletonLT(LevelType::CompressedNu) && + !isSingletonLT(LevelType::CompressedNo) && + !isSingletonLT(LevelType::CompressedNuNo) && + isSingletonLT(LevelType::Singleton) && + isSingletonLT(LevelType::SingletonNu) && + isSingletonLT(LevelType::SingletonNo) && + isSingletonLT(LevelType::SingletonNuNo) && + !isSingletonLT(LevelType::LooseCompressed) && + !isSingletonLT(LevelType::LooseCompressedNu) && + !isSingletonLT(LevelType::LooseCompressedNo) && + !isSingletonLT(LevelType::LooseCompressedNuNo) && + !isSingletonLT(LevelType::NOutOfM)), + "isSingletonLT definition is broken"); + +static_assert((!isLooseCompressedLT(LevelType::Dense) && + !isLooseCompressedLT(LevelType::Compressed) && + 
!isLooseCompressedLT(LevelType::CompressedNu) && + !isLooseCompressedLT(LevelType::CompressedNo) && + !isLooseCompressedLT(LevelType::CompressedNuNo) && + !isLooseCompressedLT(LevelType::Singleton) && + !isLooseCompressedLT(LevelType::SingletonNu) && + !isLooseCompressedLT(LevelType::SingletonNo) && + !isLooseCompressedLT(LevelType::SingletonNuNo) && + isLooseCompressedLT(LevelType::LooseCompressed) && + isLooseCompressedLT(LevelType::LooseCompressedNu) && + isLooseCompressedLT(LevelType::LooseCompressedNo) && + isLooseCompressedLT(LevelType::LooseCompressedNuNo) && + !isLooseCompressedLT(LevelType::NOutOfM)), + "isLooseCompressedLT definition is broken"); + +static_assert((!isNOutOfMLT(LevelType::Dense) && + !isNOutOfMLT(LevelType::Compressed) && + !isNOutOfMLT(LevelType::CompressedNu) && + !isNOutOfMLT(LevelType::CompressedNo) && + !isNOutOfMLT(LevelType::CompressedNuNo) && + !isNOutOfMLT(LevelType::Singleton) && + !isNOutOfMLT(LevelType::SingletonNu) && + !isNOutOfMLT(LevelType::SingletonNo) && + !isNOutOfMLT(LevelType::SingletonNuNo) && + !isNOutOfMLT(LevelType::LooseCompressed) && + !isNOutOfMLT(LevelType::LooseCompressedNu) && + !isNOutOfMLT(LevelType::LooseCompressedNo) && + !isNOutOfMLT(LevelType::LooseCompressedNuNo) && + isNOutOfMLT(LevelType::NOutOfM)), + "isNOutOfMLT definition is broken"); + +static_assert((isOrderedLT(LevelType::Dense) && + isOrderedLT(LevelType::Compressed) && + isOrderedLT(LevelType::CompressedNu) && + !isOrderedLT(LevelType::CompressedNo) && + !isOrderedLT(LevelType::CompressedNuNo) && + isOrderedLT(LevelType::Singleton) && + isOrderedLT(LevelType::SingletonNu) && + !isOrderedLT(LevelType::SingletonNo) && + !isOrderedLT(LevelType::SingletonNuNo) && + isOrderedLT(LevelType::LooseCompressed) && + isOrderedLT(LevelType::LooseCompressedNu) && + !isOrderedLT(LevelType::LooseCompressedNo) && + !isOrderedLT(LevelType::LooseCompressedNuNo) && + isOrderedLT(LevelType::NOutOfM)), + "isOrderedLT definition is broken"); + 
+static_assert((isUniqueLT(LevelType::Dense) && + isUniqueLT(LevelType::Compressed) && + !isUniqueLT(LevelType::CompressedNu) && + isUniqueLT(LevelType::CompressedNo) && + !isUniqueLT(LevelType::CompressedNuNo) && + isUniqueLT(LevelType::Singleton) && + !isUniqueLT(LevelType::SingletonNu) && + isUniqueLT(LevelType::SingletonNo) && + !isUniqueLT(LevelType::SingletonNuNo) && + isUniqueLT(LevelType::LooseCompressed) && + !isUniqueLT(LevelType::LooseCompressedNu) && + isUniqueLT(LevelType::LooseCompressedNo) && + !isUniqueLT(LevelType::LooseCompressedNuNo) && + isUniqueLT(LevelType::NOutOfM)), + "isUniqueLT definition is broken"); /// Bit manipulations for affine encoding. /// diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 3ae06f220c5281..55af8becbba20e 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -34,9 +34,9 @@ static_assert( "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch"); static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) == - static_cast<int>(LevelPropNonDefault::Nonordered) && + static_cast<int>(LevelPropertyNondefault::Nonordered) && static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == - static_cast<int>(LevelPropNonDefault::Nonunique), + static_cast<int>(LevelPropertyNondefault::Nonunique), "MlirSparseTensorLevelProperty (C-API) and " "LevelPropertyNondefault (C++) mismatch"); @@ -80,7 +80,7 @@ enum MlirSparseTensorLevelFormat mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) { LevelType lt = static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl)); - return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt()); + return static_cast<MlirSparseTensorLevelFormat>(*getLevelFormat(lt)); } int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { @@ -96,9 +96,9 @@ MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( const enum 
MlirSparseTensorLevelPropertyNondefault *properties, unsigned size, unsigned n, unsigned m) { - std::vector<LevelPropNonDefault> props; + std::vector<LevelPropertyNondefault> props; for (unsigned i = 0; i < size; i++) - props.push_back(static_cast<LevelPropNonDefault>(properties[i])); + props.push_back(static_cast<LevelPropertyNondefault>(properties[i])); return static_cast<MlirSparseTensorLevelType>( *buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m)); diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp index 380cccc989ec6a..0fb0d2761054b5 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp @@ -88,9 +88,9 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser, ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)), "expected valid level property (e.g. nonordered, nonunique or high)") if (strVal.compare("nonunique") == 0) { - *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique); + *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique); } else if (strVal.compare("nonordered") == 0) { - *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered); + *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered); } else { parser.emitError(loc, "unknown level property: ") << strVal; return failure(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 6d02645d860e96..aed43f26d54f11 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -35,14 +35,6 @@ using namespace mlir; using namespace mlir::sparse_tensor; -// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as -// well. 
-namespace mlir::sparse_tensor { -llvm::hash_code hash_value(LevelType lt) { - return llvm::hash_value(static_cast<uint64_t>(lt)); -} -} // namespace mlir::sparse_tensor - //===----------------------------------------------------------------------===// // Local Convenience Methods. //===----------------------------------------------------------------------===// @@ -91,11 +83,11 @@ void StorageLayout::foreachField( } // The values array. if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, - LevelFormat::Undef))) + LevelType::Undef))) return; // Put metadata at the end. if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel, - LevelFormat::Undef))) + LevelType::Undef))) return; } @@ -349,7 +341,7 @@ Level SparseTensorEncodingAttr::getLvlRank() const { LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { if (!getImpl()) - return LevelFormat::Dense; + return LevelType::Dense; assert(l < getLvlRank() && "Level is out of bounds"); return getLvlTypes()[l]; } @@ -983,7 +975,7 @@ static SparseTensorEncodingAttr getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { SmallVector<LevelType> lts; for (auto lt : enc.getLvlTypes()) - lts.push_back(lt.stripProperties()); + lts.push_back(*buildLevelType(*getLevelFormat(lt), true, true)); return SparseTensorEncodingAttr::get( enc.getContext(), lts, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 7326a6a3811284..235c5453f9cc98 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -46,7 +46,7 @@ static bool isZeroValue(Value val) { static bool isSparseTensor(Value v) { auto enc = getSparseTensorEncoding(v.getType()); return enc && !llvm::all_of(enc.getLvlTypes(), - [](auto lt) { return lt == LevelFormat::Dense; }); + [](auto lt) { return lt == LevelType::Dense; 
}); } static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp index 61a3703b73bf07..c85f8204ba7527 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp @@ -63,7 +63,7 @@ class SparseLevel : public SparseTensorLevel { class DenseLevel : public SparseTensorLevel { public: DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded) - : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize), + : SparseTensorLevel(tid, lvl, LevelType::Dense, lvlSize), encoded(encoded) {} Value peekCrdAt(OpBuilder &, Location, Value pos) const override { @@ -1275,7 +1275,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t, Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult() : b.create<tensor::DimOp>(l, t, lvl).getResult(); - switch (lt.getLvlFmt()) { + switch (*getLevelFormat(lt)) { case LevelFormat::Dense: return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding()); case LevelFormat::Compressed: { @@ -1296,8 +1296,6 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t, Value crd = genToCoordinates(b, l, t, lvl); return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd); } - case LevelFormat::Undef: - llvm_unreachable("undefined level format"); } llvm_unreachable("unrecognizable level format"); } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 731cd79a1e3b4b..96537cbb0c4836 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -226,8 +226,7 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops, syntheticTensor(numInputOutputTensors), numTensors(numInputOutputTensors + 1), numLoops(numLoops), 
hasSparseOut(false), - lvlTypes(numTensors, - std::vector<LevelType>(numLoops, LevelFormat::Undef)), + lvlTypes(numTensors, std::vector<LevelType>(numLoops, LevelType::Undef)), loopToLvl(numTensors, std::vector<std::optional<Level>>(numLoops, std::nullopt)), lvlToLoop(numTensors, diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp index 62a19c084cac0f..ce9c0e39b31b95 100644 --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -313,11 +313,11 @@ class MergerTest3T1L : public MergerTestBase { MergerTest3T1L() : MergerTestBase(3, 1) { EXPECT_TRUE(merger.getOutTensorID() == tid(2)); // Tensor 0: sparse input vector. - merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed); + merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed); // Tensor 1: sparse input vector. - merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed); + merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed); // Tensor 2: dense output vector. - merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense); + merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense); } }; @@ -327,13 +327,13 @@ class MergerTest4T1L : public MergerTestBase { MergerTest4T1L() : MergerTestBase(4, 1) { EXPECT_TRUE(merger.getOutTensorID() == tid(3)); // Tensor 0: sparse input vector. - merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed); + merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed); // Tensor 1: sparse input vector. 
- merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed); + merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed); // Tensor 2: sparse input vector - merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed); + merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed); // Tensor 3: dense output vector - merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense); + merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense); } }; @@ -347,11 +347,11 @@ class MergerTest3T1LD : public MergerTestBase { MergerTest3T1LD() : MergerTestBase(3, 1) { EXPECT_TRUE(merger.getOutTensorID() == tid(2)); // Tensor 0: sparse input vector. - merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed); + merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed); // Tensor 1: dense input vector. - merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense); + merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense); // Tensor 2: dense output vector. - merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense); + merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense); } }; @@ -365,13 +365,13 @@ class MergerTest4T1LU : public MergerTestBase { MergerTest4T1LU() : MergerTestBase(4, 1) { EXPECT_TRUE(merger.getOutTensorID() == tid(3)); // Tensor 0: undef input vector. - merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef); + merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef); // Tensor 1: dense input vector. - merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense); + merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense); // Tensor 2: undef input vector. - merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef); + merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Undef); // Tensor 3: dense output vector. 
- merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense); + merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense); } }; @@ -387,11 +387,11 @@ class MergerTest3T1LSo : public MergerTestBase { EXPECT_TRUE(merger.getSynTensorID() == tid(3)); merger.setHasSparseOut(true); // Tensor 0: undef input vector. - merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef); + merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef); // Tensor 1: undef input vector. - merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef); + merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Undef); // Tensor 2: sparse output vector. - merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed); + merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed); } }; _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits