https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/79935
>From fa5210448dea1f88d8e0a242543ad1be655087e0 Mon Sep 17 00:00:00 2001 From: Yinying Li <yinyin...@google.com> Date: Tue, 30 Jan 2024 01:01:52 +0000 Subject: [PATCH 1/3] [mlir][sparse] Expand LevelType to 64 bit and implement n out of m --- mlir/include/mlir-c/Dialect/SparseTensor.h | 28 +-- .../mlir/Dialect/SparseTensor/IR/Enums.h | 225 +++++++++++------- .../SparseTensor/IR/SparseTensorAttrDefs.td | 4 +- .../SparseTensor/IR/SparseTensorType.h | 2 +- .../mlir/Dialect/SparseTensor/Utils/Merger.h | 2 +- .../ExecutionEngine/SparseTensor/Storage.h | 14 +- .../Bindings/Python/DialectSparseTensor.cpp | 2 +- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 49 ++-- .../IR/Detail/DimLvlMapParser.cpp | 2 + .../SparseTensor/IR/Detail/LvlTypeParser.cpp | 55 ++++- .../SparseTensor/IR/Detail/LvlTypeParser.h | 6 +- .../Transforms/SparseGPUCodegen.cpp | 2 +- .../Transforms/SparseTensorCodegen.cpp | 6 +- .../Transforms/Sparsification.cpp | 2 +- .../Transforms/Utils/CodegenUtils.h | 2 +- .../Transforms/Utils/SparseTensorLevel.cpp | 2 +- .../lib/Dialect/SparseTensor/Utils/Merger.cpp | 4 +- .../ExecutionEngine/SparseTensor/Storage.cpp | 2 +- mlir/test/CAPI/sparse_tensor.c | 6 +- .../SparseTensor/GPU/gpu_matmul24_lib.mlir | 2 +- .../test/Dialect/SparseTensor/conversion.mlir | 16 +- .../SparseTensor/roundtrip_encoding.mlir | 12 +- .../SparseTensor/sparse_fill_zero.mlir | 12 +- .../SparseTensor/CPU/sparse_block_matmul.mlir | 2 +- .../Dialect/SparseTensor/CPU/sparse_ds.mlir | 2 +- .../CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir | 2 +- .../CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir | 2 +- .../python/dialects/sparse_tensor/dialect.py | 148 ++++++------ 28 files changed, 358 insertions(+), 255 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 41d024db04964..5fc1f51452482 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -26,20 +26,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); /// If updating, keep them in sync and update the static_assert in the impl /// file. enum MlirSparseTensorLevelType { - MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00 - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00 - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01 - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10, // 0b00010_10 - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11 - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16, // 0b00100_00 - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17, // 0b00100_01 - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18, // 0b00100_10 - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11 - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00 - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01 - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10 - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11 - MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00 + MLIR_SPARSE_TENSOR_LEVEL_DENSE = 65536, // 0x00_00_0001_0000 + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 131072, // 0x00_00_0002_0000 + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 131073, // 0x00_00_0002_0001 + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 131074, // 0x00_00_0002_0002 + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 131075, // 0x00_00_0002_0003 + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 262144, // 0x00_00_0004_0000 + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 262145, // 0x00_00_0004_0001 + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 262146, // 0x00_00_0004_0002 + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 262147, // 0x00_00_0004_0003 + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 524288, // 0x00_00_0008_0000 + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 524289, // 0x00_00_0008_0001 + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 524290, // 0x00_00_0008_0002 + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 524291, // 0x00_00_0008_0003 + MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 1048576, // 0x00_00_0010_0000 }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h index ac91bfa5ae622..6ddc9326179fe 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -154,9 +154,10 @@ enum class Action : uint32_t { /// This enum defines all the sparse representations supportable by /// the SparseTensor dialect. We use a lightweight encoding to encode -/// both the "format" per se (dense, compressed, singleton, loose_compressed, -/// two-out-of-four) as well as the "properties" (ordered, unique). The -/// encoding is chosen for performance of the runtime library, and thus may +/// the "format" per se (dense, compressed, singleton, loose_compressed, +/// n-out-of-m), the "properties" (ordered, unique) as well as n and m for +/// NOutOfM level type. +/// The encoding is chosen for performance of the runtime library, and thus may /// change in future versions; consequently, client code should use the /// predicate functions defined below, rather than relying on knowledge /// about the particular binary encoding. @@ -165,41 +166,74 @@ enum class Action : uint32_t { /// where we need to store an undefined or indeterminate `LevelType`. /// It should not be used externally, since it does not indicate an /// actual/representable format. -enum class LevelType : uint8_t { - Undef = 0, // 0b00000_00 - Dense = 4, // 0b00001_00 - Compressed = 8, // 0b00010_00 - CompressedNu = 9, // 0b00010_01 - CompressedNo = 10, // 0b00010_10 - CompressedNuNo = 11, // 0b00010_11 - Singleton = 16, // 0b00100_00 - SingletonNu = 17, // 0b00100_01 - SingletonNo = 18, // 0b00100_10 - SingletonNuNo = 19, // 0b00100_11 - LooseCompressed = 32, // 0b01000_00 - LooseCompressedNu = 33, // 0b01000_01 - LooseCompressedNo = 34, // 0b01000_10 - LooseCompressedNuNo = 35, // 0b01000_11 - TwoOutOfFour = 64, // 0b10000_00 +/// +/// Bit manipulations for LevelType: +/// +/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty | +/// +enum class LevelType : uint64_t { + Undef = 0, // 0x00_00_0000_0000 + Dense = 65536, // 0x00_00_0001_0000 + Compressed = 131072, // 0x00_00_0002_0000 + CompressedNu = 131073, // 0x00_00_0002_0001 + CompressedNo = 131074, // 0x00_00_0002_0002 + CompressedNuNo = 131075, // 0x00_00_0002_0003 + Singleton = 262144, // 0x00_00_0004_0000 + SingletonNu = 262145, // 0x00_00_0004_0001 + SingletonNo = 262146, // 0x00_00_0004_0002 + SingletonNuNo = 262147, // 0x00_00_0004_0003 + LooseCompressed = 524288, // 0x00_00_0008_0000 + LooseCompressedNu = 524289, // 0x00_00_0008_0001 + LooseCompressedNo = 524290, // 0x00_00_0008_0002 + LooseCompressedNuNo = 524291, // 0x00_00_0008_0003 + NOutOfM = 1048576, // 0x00_00_0010_0000 }; /// This enum defines all supported storage format without the level properties. -enum class LevelFormat : uint8_t { - Dense = 4, // 0b00001_00 - Compressed = 8, // 0b00010_00 - Singleton = 16, // 0b00100_00 - LooseCompressed = 32, // 0b01000_00 - TwoOutOfFour = 64, // 0b10000_00 +enum class LevelFormat : uint64_t { + Dense = 65536, // 0x0001_0000 + Compressed = 131072, // 0x0002_0000 + Singleton = 262144, // 0x0004_0000 + LooseCompressed = 524288, // 0x0008_0000 + NOutOfM = 1048576, // 0x0010_0000 }; /// This enum defines all the nondefault properties for storage formats. -enum class LevelPropertyNondefault : uint8_t { - Nonunique = 1, // 0b00000_01 - Nonordered = 2, // 0b00000_10 +enum class LevelPropertyNondefault : uint64_t { + Nonunique = 1, // 0x0001 + Nonordered = 2, // 0x0002 }; +/// Get N of NOutOfM level type. +constexpr uint64_t getN(LevelType lt) { + return (static_cast<uint64_t>(lt) >> 32) & 0xff; +} + +/// Get M of NOutOfM level type. +constexpr uint64_t getM(LevelType lt) { + return (static_cast<uint64_t>(lt) >> 40) & 0xff; +} + +/// Convert N of NOutOfM level type to the stored bits. +constexpr uint64_t nToBits(uint64_t n) { return n << 32; } + +/// Convert M of NOutOfM level type to the stored bits. +constexpr uint64_t mToBits(uint64_t m) { return m << 40; } + +/// Check if the `LevelType` is NOutOfM (regardless of +/// properties and block sizes). +constexpr bool isNOutOfMLT(LevelType lt) { + return ((static_cast<uint64_t>(lt) & 0x100000) == + static_cast<uint64_t>(LevelType::NOutOfM)); +} + +/// Check if the `LevelType` is NOutOfM with the correct block sizes. +constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) { + return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m; +} + /// Returns string representation of the given dimension level type. -constexpr const char *toMLIRString(LevelType lt) { +std::string toMLIRString(LevelType lt) { switch (lt) { case LevelType::Undef: return "undef"; @@ -229,21 +263,28 @@ constexpr const char *toMLIRString(LevelType lt) { return "loose_compressed(nonordered)"; case LevelType::LooseCompressedNuNo: return "loose_compressed(nonunique, nonordered)"; - case LevelType::TwoOutOfFour: - return "block2_4"; + default: + // If NOutOfM bit is set, print the [n, m] sizes. + if (isNOutOfMLT(lt)) { + unsigned n = getN(lt); + unsigned m = getM(lt); + return std::string("block[") + std::to_string(n) + ", " + + std::to_string(m) + "]"; + } } return ""; } /// Check that the `LevelType` contains a valid (possibly undefined) value. constexpr bool isValidLT(LevelType lt) { - const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2; - const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3; - // If undefined or dense, then must be unique and ordered. + const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000; + const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff; + // If undefined/dense/NOutOfM, then must be unique and ordered. // Otherwise, the format must be one of the known ones. - return (formatBits <= 1 || formatBits == 16) + return (formatBits <= 0x10000 || formatBits == 0x100000) ? (propertyBits == 0) - : (formatBits == 2 || formatBits == 4 || formatBits == 8); + : (formatBits == 0x20000 || formatBits == 0x40000 || + formatBits == 0x80000); } /// Check if the `LevelType` is the special undefined value. @@ -251,33 +292,28 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; } /// Check if the `LevelType` is dense (regardless of properties). constexpr bool isDenseLT(LevelType lt) { - return (static_cast<uint8_t>(lt) & ~3) == - static_cast<uint8_t>(LevelType::Dense); + return (static_cast<uint64_t>(lt) & ~0xffff) == + static_cast<uint64_t>(LevelType::Dense); } /// Check if the `LevelType` is compressed (regardless of properties). constexpr bool isCompressedLT(LevelType lt) { - return (static_cast<uint8_t>(lt) & ~3) == - static_cast<uint8_t>(LevelType::Compressed); + return (static_cast<uint64_t>(lt) & ~0xffff) == + static_cast<uint64_t>(LevelType::Compressed); } /// Check if the `LevelType` is singleton (regardless of properties). constexpr bool isSingletonLT(LevelType lt) { - return (static_cast<uint8_t>(lt) & ~3) == - static_cast<uint8_t>(LevelType::Singleton); + return (static_cast<uint64_t>(lt) & ~0xffff) == + static_cast<uint64_t>(LevelType::Singleton); } /// Check if the `LevelType` is loose compressed (regardless of properties). constexpr bool isLooseCompressedLT(LevelType lt) { - return (static_cast<uint8_t>(lt) & ~3) == - static_cast<uint8_t>(LevelType::LooseCompressed); + return (static_cast<uint64_t>(lt) & ~0xffff) == + static_cast<uint64_t>(LevelType::LooseCompressed); } -/// Check if the `LevelType` is 2OutOf4 (regardless of properties). -constexpr bool is2OutOf4LT(LevelType lt) { - return (static_cast<uint8_t>(lt) & ~3) == - static_cast<uint8_t>(LevelType::TwoOutOfFour); -} /// Check if the `LevelType` needs positions array. constexpr bool isWithPosLT(LevelType lt) { @@ -287,17 +323,17 @@ constexpr bool isWithPosLT(LevelType lt) { /// Check if the `LevelType` needs coordinates array. constexpr bool isWithCrdLT(LevelType lt) { return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) || - is2OutOf4LT(lt); + isNOutOfMLT(lt); } /// Check if the `LevelType` is ordered (regardless of storage format). constexpr bool isOrderedLT(LevelType lt) { - return !(static_cast<uint8_t>(lt) & 2); + return !(static_cast<uint64_t>(lt) & 2); } /// Check if the `LevelType` is unique (regardless of storage format). constexpr bool isUniqueLT(LevelType lt) { - return !(static_cast<uint8_t>(lt) & 1); + return !(static_cast<uint64_t>(lt) & 1); } /// Convert a LevelType to its corresponding LevelFormat. @@ -305,21 +341,25 @@ constexpr bool isUniqueLT(LevelType lt) { constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) { if (lt == LevelType::Undef) return std::nullopt; - return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3); + return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000); } /// Convert a LevelFormat to its corresponding LevelType with the given /// properties. Returns std::nullopt when the properties are not applicable /// for the input level format. constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered, - bool unique) { - auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) | - (ordered ? 0 : 2) | (unique ? 0 : 1)); + bool unique, uint64_t n = 0, + uint64_t m = 0) { + uint64_t newN = n << 32; + uint64_t newM = m << 40; + auto lt = + static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) | + (unique ? 0 : 1) | newN | newM); return isValidLT(lt) ? std::optional(lt) : std::nullopt; } // -// Ensure the above methods work as indended. +// Ensure the above methods work as intended. // static_assert( @@ -341,7 +381,7 @@ static_assert( LevelFormat::LooseCompressed && *getLevelFormat(LevelType::LooseCompressedNuNo) == LevelFormat::LooseCompressed && - *getLevelFormat(LevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour), + *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM), "getLevelFormat conversion is broken"); static_assert( @@ -373,13 +413,28 @@ static_assert( LevelType::LooseCompressedNo && *buildLevelType(LevelFormat::LooseCompressed, false, false) == LevelType::LooseCompressedNuNo && - buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt && - buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt && - buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt && - *buildLevelType(LevelFormat::TwoOutOfFour, true, true) == - LevelType::TwoOutOfFour), + buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt && + buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt && + buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt && + *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM), "buildLevelType conversion is broken"); +static_assert( + (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 && + getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 && + getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 && + getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10), + "getN/M conversion is broken"); + +static_assert( + (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4), + 2, 4) && + isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10), + 8, 10) && + !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4), + 2, 4)), + "isValidNOutOfMLT definition is broken"); + static_assert( (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) && isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) && @@ -391,7 +446,7 @@ static_assert( isValidLT(LevelType::LooseCompressedNu) && isValidLT(LevelType::LooseCompressedNo) && isValidLT(LevelType::LooseCompressedNuNo) && - isValidLT(LevelType::TwoOutOfFour)), + isValidLT(LevelType::NOutOfM)), "isValidLT definition is broken"); static_assert((isDenseLT(LevelType::Dense) && @@ -407,7 +462,7 @@ static_assert((isDenseLT(LevelType::Dense) && !isDenseLT(LevelType::LooseCompressedNu) && !isDenseLT(LevelType::LooseCompressedNo) && !isDenseLT(LevelType::LooseCompressedNuNo) && - !isDenseLT(LevelType::TwoOutOfFour)), + !isDenseLT(LevelType::NOutOfM)), "isDenseLT definition is broken"); static_assert((!isCompressedLT(LevelType::Dense) && @@ -423,7 +478,7 @@ static_assert((!isCompressedLT(LevelType::Dense) && !isCompressedLT(LevelType::LooseCompressedNu) && !isCompressedLT(LevelType::LooseCompressedNo) && !isCompressedLT(LevelType::LooseCompressedNuNo) && - !isCompressedLT(LevelType::TwoOutOfFour)), + !isCompressedLT(LevelType::NOutOfM)), "isCompressedLT definition is broken"); static_assert((!isSingletonLT(LevelType::Dense) && @@ -439,7 +494,7 @@ static_assert((!isSingletonLT(LevelType::Dense) && !isSingletonLT(LevelType::LooseCompressedNu) && !isSingletonLT(LevelType::LooseCompressedNo) && !isSingletonLT(LevelType::LooseCompressedNuNo) && - !isSingletonLT(LevelType::TwoOutOfFour)), + !isSingletonLT(LevelType::NOutOfM)), "isSingletonLT definition is broken"); static_assert((!isLooseCompressedLT(LevelType::Dense) && @@ -455,24 +510,24 @@ static_assert((!isLooseCompressedLT(LevelType::Dense) && isLooseCompressedLT(LevelType::LooseCompressedNu) && isLooseCompressedLT(LevelType::LooseCompressedNo) && isLooseCompressedLT(LevelType::LooseCompressedNuNo) && - !isLooseCompressedLT(LevelType::TwoOutOfFour)), + !isLooseCompressedLT(LevelType::NOutOfM)), "isLooseCompressedLT definition is broken"); -static_assert((!is2OutOf4LT(LevelType::Dense) && - !is2OutOf4LT(LevelType::Compressed) && - !is2OutOf4LT(LevelType::CompressedNu) && - !is2OutOf4LT(LevelType::CompressedNo) && - !is2OutOf4LT(LevelType::CompressedNuNo) && - !is2OutOf4LT(LevelType::Singleton) && - !is2OutOf4LT(LevelType::SingletonNu) && - !is2OutOf4LT(LevelType::SingletonNo) && - !is2OutOf4LT(LevelType::SingletonNuNo) && - !is2OutOf4LT(LevelType::LooseCompressed) && - !is2OutOf4LT(LevelType::LooseCompressedNu) && - !is2OutOf4LT(LevelType::LooseCompressedNo) && - !is2OutOf4LT(LevelType::LooseCompressedNuNo) && - is2OutOf4LT(LevelType::TwoOutOfFour)), - "is2OutOf4LT definition is broken"); +static_assert((!isNOutOfMLT(LevelType::Dense) && + !isNOutOfMLT(LevelType::Compressed) && + !isNOutOfMLT(LevelType::CompressedNu) && + !isNOutOfMLT(LevelType::CompressedNo) && + !isNOutOfMLT(LevelType::CompressedNuNo) && + !isNOutOfMLT(LevelType::Singleton) && + !isNOutOfMLT(LevelType::SingletonNu) && + !isNOutOfMLT(LevelType::SingletonNo) && + !isNOutOfMLT(LevelType::SingletonNuNo) && + !isNOutOfMLT(LevelType::LooseCompressed) && + !isNOutOfMLT(LevelType::LooseCompressedNu) && + !isNOutOfMLT(LevelType::LooseCompressedNo) && + !isNOutOfMLT(LevelType::LooseCompressedNuNo) && + isNOutOfMLT(LevelType::NOutOfM)), + "isNOutOfMLT definition is broken"); static_assert((isOrderedLT(LevelType::Dense) && isOrderedLT(LevelType::Compressed) && @@ -487,7 +542,7 @@ static_assert((isOrderedLT(LevelType::Dense) && isOrderedLT(LevelType::LooseCompressedNu) && !isOrderedLT(LevelType::LooseCompressedNo) && !isOrderedLT(LevelType::LooseCompressedNuNo) && - isOrderedLT(LevelType::TwoOutOfFour)), + isOrderedLT(LevelType::NOutOfM)), "isOrderedLT definition is broken"); static_assert((isUniqueLT(LevelType::Dense) && @@ -503,7 +558,7 @@ static_assert((isUniqueLT(LevelType::Dense) && !isUniqueLT(LevelType::LooseCompressedNu) && isUniqueLT(LevelType::LooseCompressedNo) && !isUniqueLT(LevelType::LooseCompressedNuNo) && - isUniqueLT(LevelType::TwoOutOfFour)), + isUniqueLT(LevelType::NOutOfM)), "isUniqueLT definition is broken"); /// Bit manipulations for affine encoding. diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td index 12c1068ae1f54..299ba0e603089 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -145,7 +145,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding", - **compressed** : only nonzeros along this level are stored - **loose_compressed** : as compressed, but allows for free space between regions - **singleton** : a variant of the compressed format, where coordinates have no siblings - - **block2_4** : the compression uses a 2:4 encoding per 1x4 block + - **block[2, 4]** : the compression uses a 2:4 encoding per 1x4 block For a compressed level, each position interval is represented in a compact way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies @@ -374,7 +374,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding", bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedLT(getLvlType(l)); } bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonLT(getLvlType(l)); } bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedLT(getLvlType(l)); } - bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4LT(getLvlType(l)); } + bool isNOutOfMLvl(::mlir::sparse_tensor::Level l) const { return isNOutOfMLT(getLvlType(l)); } bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedLT(getLvlType(l)); } bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueLT(getLvlType(l)); } diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h index 4c98129744bcd..4e2b85d35c1ac 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -291,7 +291,7 @@ class SparseTensorType { return isLooseCompressedLT(getLvlType(l)); } bool isSingletonLvl(Level l) const { return isSingletonLT(getLvlType(l)); } - bool is2OutOf4Lvl(Level l) const { return is2OutOf4LT(getLvlType(l)); } + bool isNOutOfMLvl(Level l) const { return isNOutOfMLT(getLvlType(l)); } bool isOrderedLvl(Level l) const { return isOrderedLT(getLvlType(l)); } bool isUniqueLvl(Level l) const { return isUniqueLT(getLvlType(l)); } bool isWithPos(Level l) const { return isWithPosLT(getLvlType(l)); } diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h index 4a34bb2e003e8..490ef3071af1b 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -510,7 +510,7 @@ class Merger { if (isLvlWithNonTrivialIdxExp(b)) { auto lt = getLoopDependentLevelType(b); return isCompressedLT(lt) || isSingletonLT(lt) || - isLooseCompressedLT(lt) || is2OutOf4LT(lt); + isLooseCompressedLT(lt) || isNOutOfMLT(lt); } return false; } diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h index 01c5f2382ffe6..1d8d9bcfb3b2c 100644 --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h @@ -124,7 +124,7 @@ class SparseTensorStorageBase { bool isSingletonLvl(uint64_t l) const { return isSingletonLT(getLvlType(l)); } /// Safely checks if the level uses 2 out of 4 storage. - bool is2OutOf4Lvl(uint64_t l) const { return is2OutOf4LT(getLvlType(l)); } + bool isNOutOfMLvl(uint64_t l) const { return isNOutOfMLT(getLvlType(l)); } /// Safely checks if the level is ordered. bool isOrderedLvl(uint64_t l) const { return isOrderedLT(getLvlType(l)); } @@ -450,7 +450,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase { void appendCrd(uint64_t lvl, uint64_t full, uint64_t crd) { if (!isDenseLvl(lvl)) { assert(isCompressedLvl(lvl) || isLooseCompressedLvl(lvl) || - isSingletonLvl(lvl) || is2OutOf4Lvl(lvl)); + isSingletonLvl(lvl) || isNOutOfMLvl(lvl)); coordinates[lvl].push_back(detail::checkOverflowCast<C>(crd)); } else { // Dense level. assert(crd >= full && "Coordinate was already filled"); @@ -473,7 +473,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase { return positions[l][parentSz]; if (isLooseCompressedLvl(l)) return positions[l][2 * parentSz - 1]; - if (isSingletonLvl(l) || is2OutOf4Lvl(l)) + if (isSingletonLvl(l) || isNOutOfMLvl(l)) return parentSz; // new size same as the parent assert(isDenseLvl(l)); return parentSz * getLvlSize(l); @@ -527,7 +527,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase { uint64_t pos = coordinates[l].size(); positions[l].insert(positions[l].end(), 2 * count, detail::checkOverflowCast<P>(pos)); - } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) { + } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) { return; // Nothing to finalize. } else { // Dense dimension. assert(isDenseLvl(l)); @@ -624,7 +624,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase { lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]); toCOO(pos, l + 1, dimCoords); } - } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) { + } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) { assert(parentPos < coordinates[l].size()); lvlCursor[l] = static_cast<uint64_t>(coordinates[l][parentPos]); toCOO(parentPos, l + 1, dimCoords); @@ -721,7 +721,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage( } else if (isSingletonLvl(l)) { coordinates[l].reserve(sz); sz = 1; - } else if (is2OutOf4Lvl(l)) { + } else if (isNOutOfMLvl(l)) { assert(l == lvlRank - 1 && "unexpected 2:4 usage"); sz = detail::checkedMul(sz, lvlSizes[l]) / 2; coordinates[l].reserve(sz); @@ -791,7 +791,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage( } } else if (isSingletonLvl(l)) { assert(0 && "general singleton not supported yet"); - } else if (is2OutOf4Lvl(l)) { + } else if (isNOutOfMLvl(l)) { assert(0 && "2Out4 not supported yet"); } else { assert(isDenseLvl(l)); diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 8706c523988b1..f68d77dc129ad 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -25,7 +25,7 @@ using namespace mlir::python::adaptors; static void populateDialectSparseTensorSubmodule(const py::module &m) { py::enum_<MlirSparseTensorLevelType>(m, "LevelType", py::module_local()) .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE) - .value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR) + .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) .value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) .value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index e4534ad132385..a34b9a29b0e90 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -20,25 +20,36 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, mlir::sparse_tensor::SparseTensorDialect) // Ensure the C-API enums are int-castable to C++ equivalents. -static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == - static_cast<int>(LevelType::Dense) && - static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == - static_cast<int>(LevelType::Compressed) && - static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) == - static_cast<int>(LevelType::CompressedNu) && - static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) == - static_cast<int>(LevelType::CompressedNo) && - static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) == - static_cast<int>(LevelType::CompressedNuNo) && - static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == - static_cast<int>(LevelType::Singleton) && - static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) == - static_cast<int>(LevelType::SingletonNu) && - static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) == - static_cast<int>(LevelType::SingletonNo) && - static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) == - static_cast<int>(LevelType::SingletonNuNo), - "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch"); +static_assert( + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == + static_cast<int>(LevelType::Dense) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == + static_cast<int>(LevelType::Compressed) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) == + static_cast<int>(LevelType::CompressedNu) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) == + static_cast<int>(LevelType::CompressedNo) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) == + static_cast<int>(LevelType::CompressedNuNo) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == + static_cast<int>(LevelType::Singleton) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) == + static_cast<int>(LevelType::SingletonNu) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) == + static_cast<int>(LevelType::SingletonNo) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) == + static_cast<int>(LevelType::SingletonNuNo) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) == + static_cast<int>(LevelType::LooseCompressed) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) == + static_cast<int>(LevelType::LooseCompressedNu) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) == + static_cast<int>(LevelType::LooseCompressedNo) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) == + static_cast<int>(LevelType::LooseCompressedNuNo) && + static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) == + static_cast<int>(LevelType::NOutOfM), + "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch"); bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { return isa<SparseTensorEncodingAttr>(unwrap(attr)); diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp index 56b435c57d30a..95874d4857fc8 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp @@ -299,6 +299,8 @@ ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) { FAILURE_IF_FAILED(type) lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type)); + llvm::errs() << "type = " << toMLIRString(static_cast<LevelType>(*type)) + << "\n"; return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp index eb7ea63a4e88b..14ebe14b49f64 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp @@ -29,12 +29,21 @@ using namespace mlir::sparse_tensor::ir_detail; // `LvlTypeParser` implementation. //===----------------------------------------------------------------------===// -FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const { +FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const { StringRef base; const auto loc = parser.getCurrentLocation(); ERROR_IF(failed(parser.parseOptionalKeyword(&base)), "expected valid level format (e.g. dense, compressed or singleton)") - uint8_t properties = 0; + uint64_t properties = 0; + SmallVector<unsigned> blockSizes; + + if (base.compare("block") == 0) { + ParseResult res = parser.parseCommaSeparatedList( + mlir::OpAsmParser::Delimiter::OptionalSquare, + [&]() -> ParseResult { return parseBlockSize(parser, &blockSizes); }, + " in block n out of m"); + FAILURE_IF_FAILED(res) + } ParseResult res = parser.parseCommaSeparatedList( mlir::OpAsmParser::Delimiter::OptionalParen, @@ -44,15 +53,21 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const { // Set the base bit for properties. if (base.compare("dense") == 0) { - properties |= static_cast<uint8_t>(LevelFormat::Dense); + properties |= static_cast<uint64_t>(LevelFormat::Dense); } else if (base.compare("compressed") == 0) { - properties |= static_cast<uint8_t>(LevelFormat::Compressed); - } else if (base.compare("block2_4") == 0) { - properties |= static_cast<uint8_t>(LevelFormat::TwoOutOfFour); + properties |= static_cast<uint64_t>(LevelFormat::Compressed); + } else if (base.compare("block") == 0) { + if (blockSizes.size() != 2) { + parser.emitError(loc, "expected exactly 2 block sizes"); + return failure(); + } + properties |= static_cast<uint64_t>(LevelFormat::NOutOfM); + properties |= nToBits(blockSizes[0]) | mToBits(blockSizes[1]); + llvm::errs() << "properties1: " << properties << "\n"; } else if (base.compare("loose_compressed") == 0) { - properties |= static_cast<uint8_t>(LevelFormat::LooseCompressed); + properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed); } else if (base.compare("singleton") == 0) { - properties |= static_cast<uint8_t>(LevelFormat::Singleton); + properties |= static_cast<uint64_t>(LevelFormat::Singleton); } else { parser.emitError(loc, "unknown level format: ") << base; return failure(); @@ -64,15 +79,15 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const { } ParseResult LvlTypeParser::parseProperty(AsmParser &parser, - uint8_t *properties) const { + uint64_t *properties) const { StringRef strVal; auto loc = parser.getCurrentLocation(); ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)), "expected valid level property (e.g. nonordered, nonunique or high)") if (strVal.compare("nonunique") == 0) { - *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonunique); + *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique); } else if (strVal.compare("nonordered") == 0) { - *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonordered); + *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered); } else { parser.emitError(loc, "unknown level property: ") << strVal; return failure(); @@ -80,4 +95,22 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser, return success(); } +ParseResult +LvlTypeParser::parseBlockSize(AsmParser &parser, + SmallVector<unsigned> *blockSizes) const { + int intVal; + auto loc = parser.getCurrentLocation(); + OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal); + if (intValParseResult.has_value()) { + if (failed(*intValParseResult)) { + parser.emitError(loc, "failed to parse block size"); + return failure(); + } + blockSizes->push_back(intVal); + return success(); + } + parser.emitError(loc, "expected valid integer for block size"); + return failure(); +} + //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h index 5e2f11b75d4da..78ae667f97923 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h @@ -18,10 +18,12 @@ namespace ir_detail { class LvlTypeParser { public: LvlTypeParser() = default; - FailureOr<uint8_t> parseLvlType(AsmParser &parser) const; + FailureOr<uint64_t> parseLvlType(AsmParser &parser) const; private: - ParseResult parseProperty(AsmParser &parser, uint8_t *properties) const; + ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const; + ParseResult parseBlockSize(AsmParser &parser, + SmallVector<unsigned> *blockSizes) const; }; } // namespace ir_detail diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp index 87a37a7926e9e..23676eccdfb28 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -451,7 +451,7 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) { /// Test for 2:4 matrix with suitable metadata. static bool isAdmissible24(SparseTensorType &aTp) { return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) && - aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp); + aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp); } /// Test for conversion into 2:4 matrix. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 491501a3381b9..d4459c6ea1e52 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -130,7 +130,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc, createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl, /*value=*/posZero, /*repeat=*/linear); return; - } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) { + } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) { return; // nothing to do } // Keep compounding the size, but nothing needs to be initialized @@ -409,7 +409,7 @@ static void genEndInsert(OpBuilder &builder, Location loc, } } else { assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) || - is2OutOf4LT(lt)); + isNOutOfMLT(lt)); } } } @@ -488,7 +488,7 @@ class SparseInsertGenerator } parentPos = genCompressed(builder, loc, desc, coords, value, parentPos, lvl); - } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) { + } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) { // Create: // coordinates[lvl].push_back(coords[lvl]) // positions[lvl] = positions[lvl-1] diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 5266ca7213bfc..cc39f21001168 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -891,7 +891,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr, assert(curr == env.merger().loop(b)); Value clause; if (isCompressedLT(lt) || isSingletonLT(lt) || - isLooseCompressedLT(lt) || is2OutOf4LT(lt)) { + isLooseCompressedLT(lt) || isNOutOfMLT(lt)) { assert(lvl.has_value()); const Value crd = env.emitter().getCoord(tid, *lvl); const Value lvar = env.getLoopVar(curr); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h index 8d54b5959d871..cc119bc704559 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h @@ -423,7 +423,7 @@ inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc, /// Generates a constant of the internal dimension level type encoding. inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, LevelType lt) { - return constantI8(builder, loc, static_cast<uint8_t>(lt)); + return constantI64(builder, loc, static_cast<uint64_t>(lt)); } inline bool isZeroRankedTensorOrScalar(Type type) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp index e43896942d7fe..14051fe631f09 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp @@ -1221,7 +1221,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t, Value crd = genToCoordinates(b, l, t, lvl); return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd); } - case LevelFormat::TwoOutOfFour: { + case LevelFormat::NOutOfM: { Value crd = genToCoordinates(b, l, t, lvl); return std::make_unique<TwoOutFourLevel>(tid, lvl, lt, sz, crd); } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 6cdf5f8c0168b..96537cbb0c483 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -489,7 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) { if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) { const auto lt = getLvlType(b); if (!isCompressedLT(lt) && !isSingletonLT(lt) && - !isLooseCompressedLT(lt) && !is2OutOf4LT(lt)) { + !isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) { if (reset) simple.reset(b); reset = true; @@ -670,7 +670,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const { for (TensorLoopId b : bits.set_bits()) { const auto lt = getLvlType(b); if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) || - is2OutOf4LT(lt)) + isNOutOfMLT(lt)) return true; } return hasSparseIdxReduction(bits); diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp index 0c7b3a228a65c..9e8b240899d80 100644 --- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp @@ -45,7 +45,7 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT for (uint64_t l = 0; l < lvlRank; l++) { assert(lvlSizes[l] > 0 && "Level size zero has trivial storage"); assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) || - isSingletonLvl(l) || is2OutOf4Lvl(l)); + isSingletonLvl(l) || isNOutOfMLvl(l)); } } diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c index b0bc9bb6e881a..ea4c56b7ec0c5 100644 --- a/mlir/test/CAPI/sparse_tensor.c +++ b/mlir/test/CAPI/sparse_tensor.c @@ -37,9 +37,9 @@ static int testRoundtripEncoding(MlirContext ctx) { mlirSparseTensorEncodingAttrGetDimToLvl(originalAttr); // CHECK: (d0, d1)[s0] -> (s0, d0, d1) mlirAffineMapDump(dimToLvl); - // CHECK: level_type: 4 - // CHECK: level_type: 8 - // CHECK: level_type: 8 + // CHECK: level_type: 65536 + // CHECK: level_type: 131072 + // CHECK: level_type: 131072 MlirAffineMap lvlToDim = mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr); int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr); diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir index 6fe7ec906f30e..e4884a9bf393f 100644 --- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir +++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir @@ -4,7 +4,7 @@ map = ( i, j ) -> ( i : dense, j floordiv 4 : dense, - j mod 4 : block2_4 + j mod 4 : block[2, 4] ) }> diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir index e4e825bf85043..465f210862660 100644 --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -78,8 +78,8 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind // CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex> // CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex> // CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}}) -// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8> -// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8> +// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi64> +// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi64> to memref<?xi64> // CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex> // CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex> // CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]]) @@ -96,8 +96,8 @@ func.func @sparse_new1d(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> { // CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex> // CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}}) // CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]]) -// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8> -// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8> +// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64> +// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64> // CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex> // CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex> // CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]]) @@ -114,8 +114,8 @@ func.func @sparse_new2d(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CSR> { // CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex> // CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}}) // CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]]) -// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8> -// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8> +// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi64> +// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi64> to memref<?xi64> // CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex> // CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex> // CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex> @@ -136,10 +136,10 @@ func.func @sparse_new3d(%arg0: !llvm.ptr) -> tensor<?x?x?xf32, #SparseTensor> { // CHECK-DAG: %[[Empty:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8> +// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64> // CHECK-DAG: %[[Sizes0:.*]] = memref.alloca() : memref<2xindex> // CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex> -// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8> +// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64> // CHECK-DAG: %[[Sizes:.*]] = memref.cast %[[Sizes0]] : memref<2xindex> to memref<?xindex> // CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex> // CHECK-DAG: memref.store %[[I]], %[[Sizes0]][%[[C0]]] : memref<2xindex> diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir index 20702bb985028..966a7ff2d38e1 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir @@ -207,12 +207,12 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) { map = ( i, j ) -> ( i : dense, j floordiv 4 : dense, - j mod 4 : block2_4 + j mod 4 : block[2, 4] ), crdWidth = 8 // we would even like just 2-bits }> -// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block2_4), crdWidth = 8 }> +// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block[2, 4]), crdWidth = 8 }> // CHECK-LABEL: func private @NV_24( // CHECK-SAME: tensor<?x?xf64, #[[$NV_24]]> func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) { @@ -226,11 +226,11 @@ func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) { ( i : dense, j : dense, k floordiv 4 : dense, - k mod 4 : block2_4 + k mod 4 : block[2, 4] ) }> -// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block2_4) }> +// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block[2, 4]) }> // CHECK-LABEL: func private @NV_24( // CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]> func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) { @@ -244,11 +244,11 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) { ( i : dense, k floordiv 4 : dense, j : dense, - k mod 4 : block2_4 + k mod 4 : block[2, 4] ) }> -// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block2_4) }> +// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block[2, 4]) }> // CHECK-LABEL: func private @NV_24( // CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]> func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) { diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir index 40367f12f85a4..d04fbe8ed5c22 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir @@ -14,11 +14,11 @@ // CHECK-DAG: %[[VAL_8:.*]] = arith.constant true // CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index // CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 8 : i8 -// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi8> -// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8> -// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8> -// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi8> +// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 131072 : i64 +// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64> +// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64> +// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64> +// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64> // CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex> // CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex> // CHECK: memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex> @@ -28,7 +28,7 @@ // CHECK: memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex> // CHECK: memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex> // CHECK: %[[VAL_18:.*]] = llvm.mlir.zero : !llvm.ptr -// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr +// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr // CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<300xf64> // CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64> // CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<300xi1> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir index 4bc080fc538fc..554d6207aef7e 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir @@ -59,7 +59,7 @@ map = ( i, j ) -> ( i : dense, j floordiv 4 : dense, - j mod 4 : block2_4 + j mod 4 : block[2, 4] ), }> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir index df5b48a3b6ece..9935d7c69e63a 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir @@ -41,7 +41,7 @@ #NV_24 = #sparse_tensor.encoding<{ map = ( i, j ) -> ( i : dense, j floordiv 4 : dense, - j mod 4 : block2_4), + j mod 4 : block[2, 4]), crdWidth = 8 }> diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir index 17b50b46d073a..25454f5c06b45 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir @@ -20,7 +20,7 @@ map = ( i, j ) -> ( i : dense, j floordiv 4 : dense, - j mod 4 : block2_4 + j mod 4 : block[2, 4] ) }> diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir index eb99a027a8860..da735b4a3b58a 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir @@ -20,7 +20,7 @@ map = ( i, j ) -> ( i : dense, j floordiv 4 : dense, - j mod 4 : block2_4 + j mod 4 : block[2, 4] ) }> diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py index 88a5595d75aea..e9296b961e7fe 100644 --- a/mlir/test/python/dialects/sparse_tensor/dialect.py +++ b/mlir/test/python/dialects/sparse_tensor/dialect.py @@ -13,85 +13,85 @@ def run(f): # CHECK-LABEL: TEST: testEncodingAttr1D @run def testEncodingAttr1D(): - with Context() as ctx: - parsed = Attribute.parse( - "#sparse_tensor.encoding<{" - " map = (d0) -> (d0 : compressed)," - " posWidth = 16," - " crdWidth = 32" - "}>" - ) - # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }> - print(parsed) - - casted = st.EncodingAttr(parsed) - # CHECK: equal: True - print(f"equal: {casted == parsed}") - - # CHECK: lvl_types: [<LevelType.compressed: 8>] - print(f"lvl_types: {casted.lvl_types}") - # CHECK: dim_to_lvl: (d0) -> (d0) - print(f"dim_to_lvl: {casted.dim_to_lvl}") - # CHECK: lvl_to_dim: (d0) -> (d0) - print(f"lvl_to_dim: {casted.lvl_to_dim}") - # CHECK: pos_width: 16 - print(f"pos_width: {casted.pos_width}") - # CHECK: crd_width: 32 - print(f"crd_width: {casted.crd_width}") - - created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0) - # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> - print(created) - # CHECK: created_equal: False - print(f"created_equal: {created == casted}") - - # Verify that the factory creates an instance of the proper type. - # CHECK: is_proper_instance: True - print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}") - # CHECK: created_pos_width: 0 - print(f"created_pos_width: {created.pos_width}") + with Context() as ctx: + parsed = Attribute.parse( + "#sparse_tensor.encoding<{" + " map = (d0) -> (d0 : compressed)," + " posWidth = 16," + " crdWidth = 32" + "}>" + ) + # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }> + print(parsed) + + casted = st.EncodingAttr(parsed) + # CHECK: equal: True + print(f"equal: {casted == parsed}") + + # CHECK: lvl_types: [<LevelType.compressed: 131072>] + print(f"lvl_types: {casted.lvl_types}") + # CHECK: dim_to_lvl: (d0) -> (d0) + print(f"dim_to_lvl: {casted.dim_to_lvl}") + # CHECK: lvl_to_dim: (d0) -> (d0) + print(f"lvl_to_dim: {casted.lvl_to_dim}") + # CHECK: pos_width: 16 + print(f"pos_width: {casted.pos_width}") + # CHECK: crd_width: 32 + print(f"crd_width: {casted.crd_width}") + + created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0) + # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> + print(created) + # CHECK: created_equal: False + print(f"created_equal: {created == casted}") + + # Verify that the factory creates an instance of the proper type. + # CHECK: is_proper_instance: True + print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}") + # CHECK: created_pos_width: 0 + print(f"created_pos_width: {created.pos_width}") # CHECK-LABEL: TEST: testEncodingAttr2D @run def testEncodingAttr2D(): - with Context() as ctx: - parsed = Attribute.parse( - "#sparse_tensor.encoding<{" - " map = (d0, d1) -> (d1 : dense, d0 : compressed)," - " posWidth = 8," - " crdWidth = 32" - "}>" - ) - # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }> - print(parsed) - - casted = st.EncodingAttr(parsed) - # CHECK: equal: True - print(f"equal: {casted == parsed}") - - # CHECK: lvl_types: [<LevelType.dense: 4>, <LevelType.compressed: 8>] - print(f"lvl_types: {casted.lvl_types}") - # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0) - print(f"dim_to_lvl: {casted.dim_to_lvl}") - # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0) - print(f"lvl_to_dim: {casted.lvl_to_dim}") - # CHECK: pos_width: 8 - print(f"pos_width: {casted.pos_width}") - # CHECK: crd_width: 32 - print(f"crd_width: {casted.crd_width}") - - created = st.EncodingAttr.get( - casted.lvl_types, - casted.dim_to_lvl, - casted.lvl_to_dim, - 8, - 32, - ) - # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }> - print(created) - # CHECK: created_equal: True - print(f"created_equal: {created == casted}") + with Context() as ctx: + parsed = Attribute.parse( + "#sparse_tensor.encoding<{" + " map = (d0, d1) -> (d1 : dense, d0 : compressed)," + " posWidth = 8," + " crdWidth = 32" + "}>" + ) + # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }> + print(parsed) + + casted = st.EncodingAttr(parsed) + # CHECK: equal: True + print(f"equal: {casted == parsed}") + + # CHECK: lvl_types: [<LevelType.dense: 65536>, <LevelType.compressed: 131072>] + print(f"lvl_types: {casted.lvl_types}") + # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0) + print(f"dim_to_lvl: {casted.dim_to_lvl}") + # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0) + print(f"lvl_to_dim: {casted.lvl_to_dim}") + # CHECK: pos_width: 8 + print(f"pos_width: {casted.pos_width}") + # CHECK: crd_width: 32 + print(f"crd_width: {casted.crd_width}") + + created = st.EncodingAttr.get( + casted.lvl_types, + casted.dim_to_lvl, + casted.lvl_to_dim, + 8, + 32, + ) + # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }> + print(created) + # CHECK: created_equal: True + print(f"created_equal: {created == casted}") # CHECK-LABEL: TEST: testEncodingAttrOnTensorType >From a91ee4a2701822a50dc048563740a60cc00caf05 Mon Sep 17 00:00:00 2001 From: Yinying Li <yinyin...@google.com> Date: Tue, 30 Jan 2024 02:58:18 +0000 Subject: [PATCH 2/3] format --- .../include/mlir/Dialect/SparseTensor/IR/Enums.h | 8 ++------ .../SparseTensor/IR/Detail/DimLvlMapParser.cpp | 2 -- .../SparseTensor/IR/Detail/LvlTypeParser.cpp | 1 - .../SparseTensor/IR/SparseTensorDialect.cpp | 16 ++++++++++++++-- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h index 6ddc9326179fe..15802a5ad3563 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -233,7 +233,7 @@ constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) { } /// Returns string representation of the given dimension level type. -std::string toMLIRString(LevelType lt) { +constexpr const char *toMLIRString(LevelType lt) { switch (lt) { case LevelType::Undef: return "undef"; @@ -264,12 +264,8 @@ std::string toMLIRString(LevelType lt) { case LevelType::LooseCompressedNuNo: return "loose_compressed(nonunique, nonordered)"; default: - // If NOutOfM bit is set, print the [n, m] sizes. if (isNOutOfMLT(lt)) { - unsigned n = getN(lt); - unsigned m = getM(lt); - return std::string("block[") + std::to_string(n) + ", " + - std::to_string(m) + "]"; + return "block"; } } return ""; diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp index 95874d4857fc8..56b435c57d30a 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp @@ -299,8 +299,6 @@ ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) { FAILURE_IF_FAILED(type) lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type)); - llvm::errs() << "type = " << toMLIRString(static_cast<LevelType>(*type)) - << "\n"; return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp index 14ebe14b49f64..993ad9be8a012 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp @@ -63,7 +63,6 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const { } properties |= static_cast<uint64_t>(LevelFormat::NOutOfM); properties |= nToBits(blockSizes[0]) | mToBits(blockSizes[1]); - llvm::errs() << "properties1: " << properties << "\n"; } else if (base.compare("loose_compressed") == 0) { properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed); } else if (base.compare("singleton") == 0) { diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 6033ebf6897ce..d56d90a2d6130 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -613,16 +613,28 @@ void SparseTensorEncodingAttr::printDimensions( } } +std::string getNOutOfMString(LevelType lt) { + if (isNOutOfMLT(lt)) { + unsigned n = getN(lt); + unsigned m = getM(lt); + auto output = "[" + std::to_string(n) + ", " + std::to_string(m) + "]"; + return output; + } + return ""; +} + void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<LevelType> lvlTypes) const { for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { map.getResult(i).print(printer.getStream()); - printer << " : " << toMLIRString(lvlTypes[i]) << ", "; + printer << " : " << toMLIRString(lvlTypes[i]) + << getNOutOfMString(lvlTypes[i]) << ", "; } if (map.getNumResults() >= 1) { auto lastIndex = map.getNumResults() - 1; map.getResult(lastIndex).print(printer.getStream()); - printer << " : " << toMLIRString(lvlTypes[lastIndex]); + printer << " : " << toMLIRString(lvlTypes[lastIndex]) + << getNOutOfMString(lvlTypes[lastIndex]); } } >From 8f669440ab9e817cf628f81f274e45859cec7f68 Mon Sep 17 00:00:00 2001 From: Yinying Li <yinyin...@google.com> Date: Tue, 30 Jan 2024 03:13:30 +0000 Subject: [PATCH 3/3] python format --- .../mlir/Dialect/SparseTensor/IR/Enums.h | 1 - .../python/dialects/sparse_tensor/dialect.py | 148 +++++++++--------- 2 files changed, 74 insertions(+), 75 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h index 15802a5ad3563..99443957d01d5 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -310,7 +310,6 @@ constexpr bool isLooseCompressedLT(LevelType lt) { static_cast<uint64_t>(LevelType::LooseCompressed); } - /// Check if the `LevelType` needs positions array. constexpr bool isWithPosLT(LevelType lt) { return isCompressedLT(lt) || isLooseCompressedLT(lt); diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py index e9296b961e7fe..75c47a57f78af 100644 --- a/mlir/test/python/dialects/sparse_tensor/dialect.py +++ b/mlir/test/python/dialects/sparse_tensor/dialect.py @@ -13,85 +13,85 @@ def run(f): # CHECK-LABEL: TEST: testEncodingAttr1D @run def testEncodingAttr1D(): - with Context() as ctx: - parsed = Attribute.parse( - "#sparse_tensor.encoding<{" - " map = (d0) -> (d0 : compressed)," - " posWidth = 16," - " crdWidth = 32" - "}>" - ) - # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }> - print(parsed) - - casted = st.EncodingAttr(parsed) - # CHECK: equal: True - print(f"equal: {casted == parsed}") - - # CHECK: lvl_types: [<LevelType.compressed: 131072>] - print(f"lvl_types: {casted.lvl_types}") - # CHECK: dim_to_lvl: (d0) -> (d0) - print(f"dim_to_lvl: {casted.dim_to_lvl}") - # CHECK: lvl_to_dim: (d0) -> (d0) - print(f"lvl_to_dim: {casted.lvl_to_dim}") - # CHECK: pos_width: 16 - print(f"pos_width: {casted.pos_width}") - # CHECK: crd_width: 32 - print(f"crd_width: {casted.crd_width}") - - created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0) - # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> - print(created) - # CHECK: created_equal: False - print(f"created_equal: {created == casted}") - - # Verify that the factory creates an instance of the proper type. - # CHECK: is_proper_instance: True - print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}") - # CHECK: created_pos_width: 0 - print(f"created_pos_width: {created.pos_width}") + with Context() as ctx: + parsed = Attribute.parse( + "#sparse_tensor.encoding<{" + " map = (d0) -> (d0 : compressed)," + " posWidth = 16," + " crdWidth = 32" + "}>" + ) + # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }> + print(parsed) + + casted = st.EncodingAttr(parsed) + # CHECK: equal: True + print(f"equal: {casted == parsed}") + + # CHECK: lvl_types: [<LevelType.compressed: 131072>] + print(f"lvl_types: {casted.lvl_types}") + # CHECK: dim_to_lvl: (d0) -> (d0) + print(f"dim_to_lvl: {casted.dim_to_lvl}") + # CHECK: lvl_to_dim: (d0) -> (d0) + print(f"lvl_to_dim: {casted.lvl_to_dim}") + # CHECK: pos_width: 16 + print(f"pos_width: {casted.pos_width}") + # CHECK: crd_width: 32 + print(f"crd_width: {casted.crd_width}") + + created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0) + # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> + print(created) + # CHECK: created_equal: False + print(f"created_equal: {created == casted}") + + # Verify that the factory creates an instance of the proper type. + # CHECK: is_proper_instance: True + print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}") + # CHECK: created_pos_width: 0 + print(f"created_pos_width: {created.pos_width}") # CHECK-LABEL: TEST: testEncodingAttr2D @run def testEncodingAttr2D(): - with Context() as ctx: - parsed = Attribute.parse( - "#sparse_tensor.encoding<{" - " map = (d0, d1) -> (d1 : dense, d0 : compressed)," - " posWidth = 8," - " crdWidth = 32" - "}>" - ) - # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }> - print(parsed) - - casted = st.EncodingAttr(parsed) - # CHECK: equal: True - print(f"equal: {casted == parsed}") - - # CHECK: lvl_types: [<LevelType.dense: 65536>, <LevelType.compressed: 131072>] - print(f"lvl_types: {casted.lvl_types}") - # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0) - print(f"dim_to_lvl: {casted.dim_to_lvl}") - # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0) - print(f"lvl_to_dim: {casted.lvl_to_dim}") - # CHECK: pos_width: 8 - print(f"pos_width: {casted.pos_width}") - # CHECK: crd_width: 32 - print(f"crd_width: {casted.crd_width}") - - created = st.EncodingAttr.get( - casted.lvl_types, - casted.dim_to_lvl, - casted.lvl_to_dim, - 8, - 32, - ) - # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }> - print(created) - # CHECK: created_equal: True - print(f"created_equal: {created == casted}") + with Context() as ctx: + parsed = Attribute.parse( + "#sparse_tensor.encoding<{" + " map = (d0, d1) -> (d1 : dense, d0 : compressed)," + " posWidth = 8," + " crdWidth = 32" + "}>" + ) + # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }> + print(parsed) + + casted = st.EncodingAttr(parsed) + # CHECK: equal: True + print(f"equal: {casted == parsed}") + + # CHECK: lvl_types: [<LevelType.dense: 65536>, <LevelType.compressed: 131072>] + print(f"lvl_types: {casted.lvl_types}") + # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0) + print(f"dim_to_lvl: {casted.dim_to_lvl}") + # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0) + print(f"lvl_to_dim: {casted.lvl_to_dim}") + # CHECK: pos_width: 8 + print(f"pos_width: {casted.pos_width}") + # CHECK: crd_width: 32 + print(f"crd_width: {casted.crd_width}") + + created = st.EncodingAttr.get( + casted.lvl_types, + casted.dim_to_lvl, + casted.lvl_to_dim, + 8, + 32, + ) + # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }> + print(created) + # CHECK: created_equal: True + print(f"created_equal: {created == casted}") # CHECK-LABEL: TEST: testEncodingAttrOnTensorType _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits