https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/181385
>From de9ef4c7b90cdcecde51251f8ed5052c4185dd63 Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Fri, 13 Feb 2026 16:58:41 +0000 Subject: [PATCH] [mlir][IR] Separate `DenseStringElementsAttr` from `DenseElementsAttr` --- mlir/include/mlir/IR/BuiltinAttributes.h | 37 ++---- mlir/include/mlir/IR/BuiltinAttributes.td | 110 +++++++++++++++--- mlir/include/mlir/IR/CommonAttrConstraints.td | 4 +- mlir/lib/AsmParser/AttributeParser.cpp | 36 ++++-- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 22 ++-- mlir/lib/IR/BuiltinAttributes.cpp | 70 ++++------- mlir/test/IR/parser.mlir | 4 - mlir/test/mlir-tblgen/openmp-clause-ops.td | 2 +- mlir/unittests/IR/AttributeTest.cpp | 37 +++--- 9 files changed, 192 insertions(+), 130 deletions(-) diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index ee6a8f4e4d948..3ba943c7ccd41 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -152,10 +152,6 @@ class DenseElementsAttr : public Attribute { /// Overload of the above 'get' method that is specialized for boolean values. static DenseElementsAttr get(ShapedType type, ArrayRef<bool> values); - /// Overload of the above 'get' method that is specialized for StringRef - /// values. - static DenseElementsAttr get(ShapedType type, ArrayRef<StringRef> values); - /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. 'type' must be a vector or tensor with static @@ -223,7 +219,8 @@ class DenseElementsAttr : public Attribute { decltype(std::declval<AttrT>().template getValues<T>()); /// A utility iterator that allows walking over the internal Attribute values - /// of a DenseElementsAttr. + /// of a dense elements attribute (DenseElementsAttr or + /// DenseStringElementsAttr). class AttributeElementIterator : public llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, Attribute, @@ -232,11 +229,9 @@ class DenseElementsAttr : public Attribute { /// Accesses the Attribute value at this iterator position. Attribute operator*() const; - private: - friend DenseElementsAttr; - - /// Constructs a new iterator. - AttributeElementIterator(DenseElementsAttr attr, size_t index); + /// Constructs a new iterator. Accepts any attribute implementing + /// ElementsAttr (e.g. DenseElementsAttr, DenseStringElementsAttr). + AttributeElementIterator(Attribute attr, size_t index); }; /// Iterator for walking raw element values of the specified type 'T', which @@ -461,21 +456,6 @@ class DenseElementsAttr : public Attribute { ElementIterator<T>(rawData, splat, getNumElements())); } - /// Try to get the held element values as a range of StringRef. - template <typename T> - using StringRefValueTemplateCheckT = - std::enable_if_t<std::is_same<T, StringRef>::value>; - template <typename T, typename = StringRefValueTemplateCheckT<T>> - FailureOr<iterator_range_impl<ElementIterator<StringRef>>> - tryGetValues() const { - auto stringRefs = getRawStringData(); - const char *ptr = reinterpret_cast<const char *>(stringRefs.data()); - bool splat = isSplat(); - return iterator_range_impl<ElementIterator<StringRef>>( - getType(), ElementIterator<StringRef>(ptr, splat, 0), - ElementIterator<StringRef>(ptr, splat, getNumElements())); - } - /// Try to get the held element values as a range of Attributes. template <typename T> using AttributeValueTemplateCheckT = @@ -484,8 +464,8 @@ class DenseElementsAttr : public Attribute { FailureOr<iterator_range_impl<AttributeElementIterator>> tryGetValues() const { return iterator_range_impl<AttributeElementIterator>( - getType(), AttributeElementIterator(*this, 0), - AttributeElementIterator(*this, getNumElements())); + getType(), AttributeElementIterator(Attribute(*this), 0), + AttributeElementIterator(Attribute(*this), getNumElements())); } /// Try to get the held element values a range of T, where T is a derived @@ -578,9 +558,6 @@ class DenseElementsAttr : public Attribute { /// form the user might expect. ArrayRef<char> getRawData() const; - /// Return the raw StringRef data held by this attribute. - ArrayRef<StringRef> getRawStringData() const; - /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor /// with static shape. ShapedType getType() const; diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index dced379d1f979..064783ae5f87a 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -395,8 +395,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< //===----------------------------------------------------------------------===// def Builtin_DenseStringElementsAttr : Builtin_Attr< - "DenseStringElements", "dense_string_elements", [ElementsAttrInterface], - "DenseElementsAttr" + "DenseStringElements", "dense_string_elements", [ElementsAttrInterface] > { let summary = "An Attribute containing a dense multi-dimensional array of " "strings"; @@ -431,13 +430,97 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr< }]>, ]; let extraClassDeclaration = [{ - using DenseElementsAttr::empty; - using DenseElementsAttr::getNumElements; - using DenseElementsAttr::getElementType; - using DenseElementsAttr::getValues; - using DenseElementsAttr::isSplat; - using DenseElementsAttr::size; - using DenseElementsAttr::value_begin; + /// Iterator for walking StringRef element values. + class StringRefElementIterator + : public detail::DenseElementIndexedIteratorImpl<StringRefElementIterator, + const StringRef> { + public: + const StringRef &operator*() const { + return reinterpret_cast<const StringRef *>(this->getData())[this->getDataIndex()]; + } + StringRefElementIterator(const char *data, bool isSplat, size_t dataIndex) + : detail::DenseElementIndexedIteratorImpl<StringRefElementIterator, + const StringRef>( + data, isSplat, dataIndex) {} + }; + + /// Iterator for walking element values as Attribute (StringAttr). + class StringAttributeElementIterator + : public llvm::indexed_accessor_iterator<StringAttributeElementIterator, + const void *, Attribute, + Attribute, Attribute> { + public: + Attribute operator*() const; + StringAttributeElementIterator(const DenseStringElementsAttr *attr, + size_t index) + : llvm::indexed_accessor_iterator<StringAttributeElementIterator, + const void *, Attribute, + Attribute, Attribute>( + attr->getAsOpaquePointer(), index) {} + }; + + /// Return the type of this attribute (vector or tensor with static shape). + ShapedType getType() const; + + /// Helper methods for ElementsAttr interface. + bool empty() const { return getNumElements() == 0; } + int64_t getNumElements() const { return getType().getNumElements(); } + Type getElementType() const { return getType().getElementType(); } + bool isSplat() const { return getRawStringData().size() == 1; } + int64_t size() const { return getNumElements(); } + + /// Return the raw StringRef data held by this attribute. + ArrayRef<StringRef> getRawStringData() const; + + /// Try to get the held element values as a range of StringRef. + template <typename T> + using StringRefValueTemplateCheckT = + std::enable_if_t<std::is_same<T, StringRef>::value>; + template <typename T, typename = StringRefValueTemplateCheckT<T>> + FailureOr<detail::ElementsAttrRange<StringRefElementIterator>> + tryGetValues() const { + auto stringRefs = getRawStringData(); + const char *ptr = reinterpret_cast<const char *>(stringRefs.data()); + bool splat = isSplat(); + return detail::ElementsAttrRange<StringRefElementIterator>( + getType(), StringRefElementIterator(ptr, splat, 0), + StringRefElementIterator(ptr, splat, getNumElements())); + } + + /// Try to get the held element values as a range of Attributes. + template <typename T> + using AttributeValueTemplateCheckT = + std::enable_if_t<std::is_same<T, Attribute>::value>; + template <typename T, typename = AttributeValueTemplateCheckT<T>> + FailureOr<detail::ElementsAttrRange<StringAttributeElementIterator>> + tryGetValues() const { + return detail::ElementsAttrRange<StringAttributeElementIterator>( + getType(), StringAttributeElementIterator(this, 0), + StringAttributeElementIterator(this, getNumElements())); + } + + template <typename T> + auto getValues() const { + auto range = tryGetValues<T>(); + assert(succeeded(range) && "element type cannot be iterated"); + return std::move(*range); + } + template <typename T> + auto value_begin() const { return getValues<T>().begin(); } + template <typename T> + auto value_end() const { return getValues<T>().end(); } + /// Return the splat value. Asserts that the attribute is a splat. + template <typename T> + auto getSplatValue() const { + assert(isSplat() && "expected the attribute to be a splat"); + return *value_begin<T>(); + } + template <typename T> + auto try_value_begin() const { + auto range = tryGetValues<T>(); + using iterator = decltype(range->begin()); + return failed(range) ? FailureOr<iterator>(failure()) : range->begin(); + } /// The set of data types that can be iterated by this attribute. using ContiguousIterableTypesT = std::tuple<StringRef>; @@ -449,11 +532,6 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr< auto try_value_begin_impl(OverloadToken<T>) const { return try_value_begin<T>(); } - - protected: - friend DenseElementsAttr; - - public: }]; let genAccessors = 0; let genStorageClass = 0; @@ -931,9 +1009,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr< std::complex<int16_t>, std::complex<int32_t>, std::complex<int64_t>, // Float types. APFloat, float, double, - std::complex<APFloat>, std::complex<float>, std::complex<double>, - // String types. - StringRef + std::complex<APFloat>, std::complex<float>, std::complex<double> >; using ElementsAttr::Trait<SparseElementsAttr>::getValues; using ElementsAttr::Trait<SparseElementsAttr>::value_begin; diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td index ba6cf55a8fb9e..634881f5813f3 100644 --- a/mlir/include/mlir/IR/CommonAttrConstraints.td +++ b/mlir/include/mlir/IR/CommonAttrConstraints.td @@ -565,8 +565,8 @@ def StringElementsAttr : ElementsAttrBase< CPred<"::llvm::isa<::mlir::DenseStringElementsAttr>($_self)" >, "string elements attribute"> { - let storageType = [{ ::mlir::DenseElementsAttr }]; - let returnType = [{ ::mlir::DenseElementsAttr }]; + let storageType = [{ ::mlir::DenseStringElementsAttr }]; + let returnType = [{ ::mlir::DenseStringElementsAttr }]; let convertFromStorage = "$_self"; } diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index dc9744a42b730..15c2e0225f98b 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -472,8 +472,8 @@ class TensorLiteralParser { ParseResult parse(bool allowHex); /// Build a dense attribute instance with the parsed elements and the given - /// shaped type. - DenseElementsAttr getAttr(SMLoc loc, ShapedType type); + /// shaped type. Returns DenseElementsAttr or DenseStringElementsAttr. + Attribute getAttr(SMLoc loc, ShapedType type); ArrayRef<int64_t> getShape() const { return shape; } @@ -487,7 +487,7 @@ class TensorLiteralParser { std::vector<APFloat> &floatValues); /// Build a Dense String attribute for the given type. - DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); + DenseStringElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy); /// Build a Dense attribute with hex data for the given type. DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type); @@ -539,7 +539,7 @@ ParseResult TensorLiteralParser::parse(bool allowHex) { /// Build a dense attribute instance with the parsed elements and the given /// shaped type. -DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { +Attribute TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) { Type eltType = type.getElementType(); // Check to see if we parse the literal from a hex string. @@ -679,8 +679,8 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy, } /// Build a Dense String attribute for the given type. -DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, - Type eltTy) { +DenseStringElementsAttr +TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type, Type eltTy) { if (hexStorage.has_value()) { auto stringValue = hexStorage->getStringValue(); return DenseStringElementsAttr::get(type, {stringValue}); @@ -1174,6 +1174,13 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { if (!type) return nullptr; + // SparseElementsAttr only supports int/float element types. + if (!type.getElementType().isIntOrIndexOrFloat()) { + emitError(loc) << "sparse elements attribute does not support string " + "element type"; + return nullptr; + } + // Construct the sparse elements attr using zero element indice/value // attributes. ShapedType indicesType = @@ -1219,9 +1226,10 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { // Otherwise, set the shape to the one parsed by the literal parser. indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); } - auto indices = indiceParser.getAttr(indicesLoc, indicesType); - if (!indices) + auto indicesAttr = indiceParser.getAttr(indicesLoc, indicesType); + if (!indicesAttr) return nullptr; + auto indices = llvm::cast<DenseIntElementsAttr>(indicesAttr); // If the values are a splat, set the shape explicitly based on the number of // indices. The number of indices is encoded in the first dimension of the @@ -1231,10 +1239,18 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) { valuesParser.getShape().empty() ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType) : RankedTensorType::get(valuesParser.getShape(), valuesEltType); - auto values = valuesParser.getAttr(valuesLoc, valuesType); - if (!values) + auto valuesAttr = valuesParser.getAttr(valuesLoc, valuesType); + if (!valuesAttr) return nullptr; + // SparseElementsAttr only supports DenseElementsAttr for values (not string). + auto values = llvm::dyn_cast<DenseElementsAttr>(valuesAttr); + if (!values) { + emitError(valuesLoc) + << "dense string elements not supported in sparse elements attribute"; + return nullptr; + } + // Build the sparse elements attribute by the indices and values. return getChecked<SparseElementsAttr>(loc, type, indices, values); } diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 44a3deaf57db5..7325179c047c5 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -728,8 +728,8 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, for (intptr_t i = 0; i < numElements; ++i) values.push_back(unwrap(strs[i])); - return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), - values)); + return wrap(DenseStringElementsAttr::get( + llvm::cast<ShapedType>(unwrap(shapedType)), values)); } MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, @@ -743,12 +743,18 @@ MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, //===----------------------------------------------------------------------===// bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { - return llvm::cast<DenseElementsAttr>(unwrap(attr)).isSplat(); + Attribute a = unwrap(attr); + if (auto strAttr = llvm::dyn_cast<DenseStringElementsAttr>(a)) + return strAttr.isSplat(); + return llvm::cast<DenseElementsAttr>(a).isSplat(); } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { + mlir::Attribute a = unwrap(attr); + if (auto strAttr = llvm::dyn_cast<DenseStringElementsAttr>(a)) + return wrap(strAttr.getSplatValue<mlir::Attribute>()); return wrap( - llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<Attribute>()); + llvm::cast<DenseElementsAttr>(a).getSplatValue<mlir::Attribute>()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<bool>(); @@ -778,8 +784,8 @@ double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { return llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<double>(); } MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { - return wrap( - llvm::cast<DenseElementsAttr>(unwrap(attr)).getSplatValue<StringRef>()); + return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr)) + .getSplatValue<llvm::StringRef>()); } //===----------------------------------------------------------------------===// @@ -824,8 +830,8 @@ double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { - return wrap( - llvm::cast<DenseElementsAttr>(unwrap(attr)).getValues<StringRef>()[pos]); + return wrap(llvm::cast<DenseStringElementsAttr>(unwrap(attr)) + .getValues<StringRef>()[pos]); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index bbbc9198a68ab..e288be3271fab 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -589,23 +589,14 @@ static bool hasSameNumElementsOrSplat(ShapedType type, const Values &values) { //===----------------------------------------------------------------------===// DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( - DenseElementsAttr attr, size_t index) + Attribute attr, size_t index) : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, Attribute, Attribute, Attribute>( attr.getAsOpaquePointer(), index) {} Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base)); - Type eltTy = owner.getElementType(); - - // Handle strings specially. - if (llvm::isa<DenseStringElementsAttr>(owner)) { - ArrayRef<StringRef> vals = owner.getRawStringData(); - return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); - } - - // All other types should implement DenseElementTypeInterface. - auto denseEltTy = llvm::cast<DenseElementType>(eltTy); + auto denseEltTy = llvm::cast<DenseElementType>(owner.getElementType()); ArrayRef<char> rawData = owner.getRawData(); // Storage is byte-aligned: align bit size up to next byte boundary. size_t bitSize = denseEltTy.getDenseElementBitSize(); @@ -864,28 +855,13 @@ template class DenseArrayAttrImpl<double>; /// Method for support type inquiry through isa, cast and dyn_cast. bool DenseElementsAttr::classof(Attribute attr) { - return llvm::isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(attr); + return llvm::isa<DenseIntOrFPElementsAttr>(attr); } DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<Attribute> values) { assert(hasSameNumElementsOrSplat(type, values)); - Type eltType = type.getElementType(); - - // Handle strings specially. - if (!llvm::isa<DenseElementType>(eltType)) { - SmallVector<StringRef, 8> stringValues; - stringValues.reserve(values.size()); - for (Attribute attr : values) { - assert(llvm::isa<StringAttr>(attr) && - "expected string value for non-DenseElementType element"); - stringValues.push_back(llvm::cast<StringAttr>(attr).getValue()); - } - return get(type, stringValues); - } - - // All other types go through DenseElementTypeInterface. - auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType); + auto denseEltType = llvm::dyn_cast<DenseElementType>(type.getElementType()); assert(denseEltType && "attempted to get DenseElementsAttr with unsupported element type"); SmallVector<char> data; @@ -906,12 +882,6 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, values.size())); } -DenseElementsAttr DenseElementsAttr::get(ShapedType type, - ArrayRef<StringRef> values) { - assert(!type.getElementType().isIntOrFloat()); - return DenseStringElementsAttr::get(type, values); -} - /// Constructs a dense integer elements attribute from an array of APInt /// values. Each APInt value is expected to have the same bitwidth as the /// element type of 'type'. @@ -1048,9 +1018,6 @@ bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, /// values are the same. bool DenseElementsAttr::isSplat() const { // Splat iff the data array has exactly one element. - if (isa<DenseStringElementsAttr>(*this)) - return getRawStringData().size() == 1; - // FP/Int case. size_t storageSize = llvm::divideCeil( getDenseElementBitWidth(getType().getElementType()), CHAR_BIT); return getRawData().size() == storageSize; @@ -1100,10 +1067,6 @@ ArrayRef<char> DenseElementsAttr::getRawData() const { return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; } -ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { - return static_cast<DenseStringElementsAttrStorage *>(impl)->data; -} - /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has been reshaped to 'newType'. The new type must have the /// same total number of elements as well as element type. @@ -1390,6 +1353,27 @@ bool DenseIntElementsAttr::classof(Attribute attr) { return false; } +//===----------------------------------------------------------------------===// +// DenseStringElementsAttr +//===----------------------------------------------------------------------===// + +ShapedType DenseStringElementsAttr::getType() const { + return static_cast<const DenseStringElementsAttrStorage *>(impl)->type; +} + +ArrayRef<StringRef> DenseStringElementsAttr::getRawStringData() const { + return static_cast<const DenseStringElementsAttrStorage *>(impl)->data; +} + +Attribute +DenseStringElementsAttr::StringAttributeElementIterator::operator*() const { + auto attr = llvm::cast<DenseStringElementsAttr>( + Attribute::getFromOpaquePointer(this->base)); + auto data = attr.getRawStringData(); + return StringAttr::get(attr.isSplat() ? data.front() : data[this->index], + attr.getElementType()); +} + //===----------------------------------------------------------------------===// // DenseResourceElementsAttr //===----------------------------------------------------------------------===// @@ -1557,10 +1541,6 @@ Attribute SparseElementsAttr::getZeroAttr() const { ArrayRef<Attribute>{zero, zero}); } - // Handle string type. - if (llvm::isa<DenseStringElementsAttr>(getValues())) - return StringAttr::get("", eltType); - // Otherwise, this is an integer. return IntegerAttr::get(eltType, 0); } diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 3bb6e38b4d613..c4a415e626760 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -797,10 +797,6 @@ func.func @sparsetensorattr() -> () { // CHECK: "foof321"() {bar = sparse<> : tensor<f32>} : () -> () "foof321"(){bar = sparse<> : tensor<f32>} : () -> () -// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<>>} : () -> () - "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<>>} : () -> () -// CHECK: "foostr"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}"a", "b", "c"]> : tensor<2x2x2x!unknown<>>} : () -> () - "foostr"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], ["a", "b", "c"]> : tensor<2x2x2x!unknown<>>} : () -> () return } diff --git a/mlir/test/mlir-tblgen/openmp-clause-ops.td b/mlir/test/mlir-tblgen/openmp-clause-ops.td index 3e5896a00182b..c502b21c3baf8 100644 --- a/mlir/test/mlir-tblgen/openmp-clause-ops.td +++ b/mlir/test/mlir-tblgen/openmp-clause-ops.td @@ -59,7 +59,7 @@ def OpenMP_MyFirstClause : OpenMP_Clause< // CHECK-NEXT: ::mlir::IntegerAttr complexOptIntAttr; // CHECK-NEXT: ::mlir::ElementsAttr elementsAttr; -// CHECK-NEXT: ::mlir::DenseElementsAttr stringElementsAttr; +// CHECK-NEXT: ::mlir::DenseStringElementsAttr stringElementsAttr; // CHECK-NEXT: } def OpenMP_MySecondClause : OpenMP_Clause< diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index 404aa8c0dcf3d..f72aebabce280 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -38,6 +38,21 @@ static void testSplat(Type eltType, const EltTy &splatElt) { EXPECT_TRUE(newValue == splatElt); } +template <> +void testSplat<StringRef>(Type eltType, const StringRef &splatElt) { + RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); + + DenseStringElementsAttr splat = DenseStringElementsAttr::get(shape, splatElt); + EXPECT_TRUE(splat.isSplat()); + + auto detectedSplat = + DenseStringElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt})); + EXPECT_EQ(detectedSplat, splat); + + for (auto newValue : detectedSplat.getValues<StringRef>()) + EXPECT_TRUE(newValue == splatElt); +} + namespace { TEST(DenseSplatTest, BoolSplat) { MLIRContext context; @@ -184,8 +199,16 @@ TEST(DenseSplatTest, StringAttrSplat) { context.allowUnregisteredDialects(); Type stringType = OpaqueType::get(StringAttr::get(&context, "test"), "string"); + RankedTensorType shape = RankedTensorType::get({2, 1}, stringType); Attribute stringAttr = StringAttr::get("test-string", stringType); - testSplat(stringType, stringAttr); + StringRef value = llvm::cast<StringAttr>(stringAttr).getValue(); + DenseStringElementsAttr splat = DenseStringElementsAttr::get(shape, value); + EXPECT_TRUE(splat.isSplat()); + auto detectedSplat = + DenseStringElementsAttr::get(shape, llvm::ArrayRef({value, value})); + EXPECT_EQ(detectedSplat, splat); + for (auto newValue : detectedSplat.getValues<StringRef>()) + EXPECT_TRUE(newValue == value); } TEST(DenseComplexTest, ComplexFloatSplat) { @@ -396,11 +419,9 @@ TEST(SparseElementsAttrTest, GetZero) { IntegerType intTy = IntegerType::get(&context, 32); FloatType floatTy = Float32Type::get(&context); - Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string"); ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); - ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); @@ -413,13 +434,8 @@ TEST(SparseElementsAttrTest, GetZero) { RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); - RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); - auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")}); - auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); - auto sparseString = - SparseElementsAttr::get(tensorString, indices, stringValue); // Only index (0, 0) contains an element, others are supposed to return // the zero/empty value. @@ -432,11 +448,6 @@ TEST(SparseElementsAttrTest, GetZero) { cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]); EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); EXPECT_TRUE(zeroFloatValue.getType() == floatTy); - - auto zeroStringValue = - cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]); - EXPECT_TRUE(zeroStringValue.empty()); - EXPECT_TRUE(zeroStringValue.getType() == stringTy); } //===----------------------------------------------------------------------===// _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
