llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-llvm-analysis Author: S. VenkataKeerthy (svkeerthy) <details> <summary>Changes</summary> Refactored IR2Vec vocabulary and introduced IR (semantics) agnostic VocabStorage - Vocabulary "has-a" VocabStorage - Vocabulary deals with LLVM IR specific entities. This would help in efficient reuse of parts of the logic for MIR. - Storage uses a section-based approach instead of a flat vector, improving organization and access patterns. --- Patch is 43.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158376.diff 6 Files Affected: - (modified) llvm/include/llvm/Analysis/IR2Vec.h (+114-31) - (modified) llvm/lib/Analysis/IR2Vec.cpp (+174-56) - (modified) llvm/lib/Analysis/InlineAdvisor.cpp (+1-1) - (modified) llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp (+3-3) - (modified) llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp (+6-3) - (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+273-28) ``````````diff diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 4a6db5d895a62..7d51a7320d194 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -45,6 +45,7 @@ #include "llvm/Support/JSON.h" #include <array> #include <map> +#include <optional> namespace llvm { @@ -144,6 +145,73 @@ struct Embedding { using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>; using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>; +/// Generic storage class for section-based vocabularies. +/// VocabStorage provides a generic foundation for storing and accessing +/// embeddings organized into sections. +class VocabStorage { +private: + /// Section-based storage + std::vector<std::vector<Embedding>> Sections; + + size_t TotalSize = 0; + unsigned Dimension = 0; + +public: + /// Default constructor creates empty storage (invalid state) + VocabStorage() : Sections(), TotalSize(0), Dimension(0) {} + + /// Create a VocabStorage with pre-organized section data + VocabStorage(std::vector<std::vector<Embedding>> &&SectionData); + + VocabStorage(VocabStorage &&) = default; + VocabStorage &operator=(VocabStorage &&Other); + + VocabStorage(const VocabStorage &) = delete; + VocabStorage &operator=(const VocabStorage &) = delete; + + /// Get total number of entries across all sections + size_t size() const { return TotalSize; } + + /// Get number of sections + unsigned getNumSections() const { + return static_cast<unsigned>(Sections.size()); + } + + /// Section-based access: Storage[sectionId][localIndex] + const std::vector<Embedding> &operator[](unsigned SectionId) const { + assert(SectionId < Sections.size() && "Invalid section ID"); + return Sections[SectionId]; + } + + /// Get vocabulary dimension + unsigned getDimension() const { return Dimension; } + + /// Check if vocabulary is valid (has data) + bool isValid() const { return TotalSize > 0; } + + /// Iterator support for section-based access + class const_iterator { + const VocabStorage *Storage; + unsigned SectionId; + size_t LocalIndex; + + public: + const_iterator(const VocabStorage *Storage, unsigned SectionId, + size_t LocalIndex) + : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {} + + LLVM_ABI const Embedding &operator*() const; + LLVM_ABI const_iterator &operator++(); + LLVM_ABI bool operator==(const const_iterator &Other) const; + LLVM_ABI bool operator!=(const const_iterator &Other) const; + }; + + const_iterator begin() const { return const_iterator(this, 0, 0); } + const_iterator end() const { + return const_iterator(this, getNumSections(), 0); + } +}; + /// Class for storing and accessing the IR2Vec vocabulary. /// The Vocabulary class manages seed embeddings for LLVM IR entities. The /// seed embeddings are the initial learned representations of the entities @@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>; class Vocabulary { friend class llvm::IR2VecVocabAnalysis; - // Vocabulary Slot Layout: + // Vocabulary Layout: // +----------------+------------------------------------------------------+ // | Entity Type | Index Range | // +----------------+------------------------------------------------------+ @@ -175,8 +243,16 @@ class Vocabulary { // Note: "Similar" LLVM Types are grouped/canonicalized together. // Operands include Comparison predicates (ICmp/FCmp). // This can be extended to include other specializations in future. - using VocabVector = std::vector<ir2vec::Embedding>; - VocabVector Vocab; + enum class Section : unsigned { + Opcodes = 0, + CanonicalTypes = 1, + Operands = 2, + Predicates = 3, + MaxSections + }; + + // Use section-based storage for better organization and efficiency + VocabStorage Storage; static constexpr unsigned NumICmpPredicates = static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) - @@ -228,9 +304,18 @@ class Vocabulary { NumICmpPredicates + NumFCmpPredicates; Vocabulary() = default; - LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {} + LLVM_ABI Vocabulary(VocabStorage &&Storage) : Storage(std::move(Storage)) {} + + Vocabulary(const Vocabulary &) = delete; + Vocabulary &operator=(const Vocabulary &) = delete; + + Vocabulary(Vocabulary &&) = default; + Vocabulary &operator=(Vocabulary &&Other); + + LLVM_ABI bool isValid() const { + return Storage.size() == NumCanonicalEntries; + } - LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; }; LLVM_ABI unsigned getDimension() const; /// Total number of entries (opcodes + canonicalized types + operand kinds + /// predicates) @@ -251,12 +336,11 @@ class Vocabulary { /// Function to get vocabulary key for a given predicate LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P); - /// Functions to return the slot index or position of a given Opcode, TypeID, - /// or OperandKind in the vocabulary. - LLVM_ABI static unsigned getSlotIndex(unsigned Opcode); - LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID); - LLVM_ABI static unsigned getSlotIndex(const Value &Op); - LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P); + /// Functions to return flat index + LLVM_ABI static unsigned getIndex(unsigned Opcode); + LLVM_ABI static unsigned getIndex(Type::TypeID TypeID); + LLVM_ABI static unsigned getIndex(const Value &Op); + LLVM_ABI static unsigned getIndex(CmpInst::Predicate P); /// Accessors to get the embedding for a given entity. LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const; @@ -265,26 +349,21 @@ class Vocabulary { LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const; /// Const Iterator type aliases - using const_iterator = VocabVector::const_iterator; + using const_iterator = VocabStorage::const_iterator; + const_iterator begin() const { assert(isValid() && "IR2Vec Vocabulary is invalid"); - return Vocab.begin(); + return Storage.begin(); } - const_iterator cbegin() const { - assert(isValid() && "IR2Vec Vocabulary is invalid"); - return Vocab.cbegin(); - } + const_iterator cbegin() const { return begin(); } const_iterator end() const { assert(isValid() && "IR2Vec Vocabulary is invalid"); - return Vocab.end(); + return Storage.end(); } - const_iterator cend() const { - assert(isValid() && "IR2Vec Vocabulary is invalid"); - return Vocab.cend(); - } + const_iterator cend() const { return end(); } /// Returns the string key for a given index position in the vocabulary. /// This is useful for debugging or printing the vocabulary. Do not use this @@ -292,7 +371,7 @@ class Vocabulary { LLVM_ABI static StringRef getStringKey(unsigned Pos); /// Create a dummy vocabulary for testing purposes. - LLVM_ABI static VocabVector createDummyVocabForTest(unsigned Dim = 1); + LLVM_ABI static VocabStorage createDummyVocabForTest(unsigned Dim = 1); LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, ModuleAnalysisManager::Invalidator &Inv) const; @@ -301,12 +380,16 @@ class Vocabulary { constexpr static unsigned NumCanonicalEntries = MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds; - // Base offsets for slot layout to simplify index computation + // Base offsets for flat index computation constexpr static unsigned OperandBaseOffset = MaxOpcodes + MaxCanonicalTypeIDs; constexpr static unsigned PredicateBaseOffset = OperandBaseOffset + MaxOperandKinds; + /// Functions for predicate index calculations + static unsigned getPredicateLocalIndex(CmpInst::Predicate P); + static CmpInst::Predicate getPredicateFromLocalIndex(unsigned LocalIndex); + /// String mappings for CanonicalTypeID values static constexpr StringLiteral CanonicalTypeNames[] = { "FloatTy", "VoidTy", "LabelTy", "MetadataTy", @@ -452,22 +535,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder { /// mapping between an entity of the IR (like opcode, type, argument, etc.) and /// its corresponding embedding. class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> { - using VocabVector = std::vector<ir2vec::Embedding>; using VocabMap = std::map<std::string, ir2vec::Embedding>; - VocabMap OpcVocab, TypeVocab, ArgVocab; - VocabVector Vocab; + std::optional<ir2vec::VocabStorage> Vocab; - Error readVocabulary(); + Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab, + VocabMap &ArgVocab); Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, unsigned &Dim); - void generateNumMappedVocab(); + void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab, + VocabMap &ArgVocab); void emitError(Error Err, LLVMContext &Ctx); public: LLVM_ABI static AnalysisKey Key; IR2VecVocabAnalysis() = default; - LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab); - LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab); + LLVM_ABI explicit IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab) + : Vocab(std::move(Vocab)) {} using Result = ir2vec::Vocabulary; LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM); }; diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index f51f0898cb37e..eeba109eb7dbd 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Module.h" @@ -261,55 +262,121 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const { BBVecMap[&BB] = BBVector; } +// ==----------------------------------------------------------------------===// +// VocabStorage +//===----------------------------------------------------------------------===// + +VocabStorage::VocabStorage(std::vector<std::vector<Embedding>> &&SectionData) + : Sections(std::move(SectionData)) { + TotalSize = 0; + Dimension = 0; + assert(!Sections.empty() && "Vocabulary has no sections"); + assert(!Sections[0].empty() && "First section of vocabulary is empty"); + + // Compute total size across all sections + for (const auto &Section : Sections) + TotalSize += Section.size(); + + // Get dimension from the first embedding in the first section - all + // embeddings must have the same dimension + Dimension = static_cast<unsigned>(Sections[0][0].size()); +} + +VocabStorage &VocabStorage::operator=(VocabStorage &&Other) { + if (this != &Other) { + Sections = std::move(Other.Sections); + TotalSize = Other.TotalSize; + Dimension = Other.Dimension; + Other.TotalSize = 0; + Other.Dimension = 0; + } + return *this; +} + +const Embedding &VocabStorage::const_iterator::operator*() const { + assert(SectionId < Storage->Sections.size() && "Invalid section ID"); + assert(LocalIndex < Storage->Sections[SectionId].size() && + "Local index out of range"); + return Storage->Sections[SectionId][LocalIndex]; +} + +VocabStorage::const_iterator &VocabStorage::const_iterator::operator++() { + ++LocalIndex; + // Check if we need to move to the next section + while (SectionId < Storage->getNumSections() && + LocalIndex >= Storage->Sections[SectionId].size()) { + LocalIndex = 0; + ++SectionId; + } + return *this; +} + +bool VocabStorage::const_iterator::operator==( + const const_iterator &Other) const { + return Storage == Other.Storage && SectionId == Other.SectionId && + LocalIndex == Other.LocalIndex; +} + +bool VocabStorage::const_iterator::operator!=( + const const_iterator &Other) const { + return !(*this == Other); +} + // ==----------------------------------------------------------------------===// // Vocabulary //===----------------------------------------------------------------------===// +Vocabulary &Vocabulary::operator=(Vocabulary &&Other) { + if (this != &Other) + Storage = std::move(Other.Storage); + return *this; +} + unsigned Vocabulary::getDimension() const { assert(isValid() && "IR2Vec Vocabulary is invalid"); - return Vocab[0].size(); + return Storage.getDimension(); } -unsigned Vocabulary::getSlotIndex(unsigned Opcode) { +unsigned Vocabulary::getIndex(unsigned Opcode) { assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); return Opcode - 1; // Convert to zero-based index } -unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) { +unsigned Vocabulary::getIndex(Type::TypeID TypeID) { assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID"); return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID)); } -unsigned Vocabulary::getSlotIndex(const Value &Op) { +unsigned Vocabulary::getIndex(const Value &Op) { unsigned Index = static_cast<unsigned>(getOperandKind(&Op)); assert(Index < MaxOperandKinds && "Invalid OperandKind"); return OperandBaseOffset + Index; } -unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) { - unsigned PU = static_cast<unsigned>(P); - unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE); - unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE); - - unsigned PredIdx = - (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC); - return PredicateBaseOffset + PredIdx; +unsigned Vocabulary::getIndex(CmpInst::Predicate P) { + return PredicateBaseOffset + getPredicateLocalIndex(P); } const Embedding &Vocabulary::operator[](unsigned Opcode) const { - return Vocab[getSlotIndex(Opcode)]; + assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); + return Storage[static_cast<unsigned>(Section::Opcodes)][Opcode - 1]; } const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const { - return Vocab[getSlotIndex(TypeID)]; + assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID"); + unsigned LocalIndex = static_cast<unsigned>(getCanonicalTypeID(TypeID)); + return Storage[static_cast<unsigned>(Section::CanonicalTypes)][LocalIndex]; } const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const { - return Vocab[getSlotIndex(Arg)]; + unsigned LocalIndex = static_cast<unsigned>(getOperandKind(&Arg)); + assert(LocalIndex < MaxOperandKinds && "Invalid OperandKind"); + return Storage[static_cast<unsigned>(Section::Operands)][LocalIndex]; } const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const { - return Vocab[getSlotIndex(P)]; + unsigned LocalIndex = getPredicateLocalIndex(P); + return Storage[static_cast<unsigned>(Section::Predicates)][LocalIndex]; } StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { @@ -359,12 +426,26 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) { assert(Index < MaxPredicateKinds && "Invalid predicate index"); - unsigned PredEnumVal = - (Index < NumFCmpPredicates) - ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index) - : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + - (Index - NumFCmpPredicates)); - return static_cast<CmpInst::Predicate>(PredEnumVal); + return getPredicateFromLocalIndex(Index); +} + +unsigned Vocabulary::getPredicateLocalIndex(CmpInst::Predicate P) { + if (P >= CmpInst::FIRST_FCMP_PREDICATE && P <= CmpInst::LAST_FCMP_PREDICATE) + return P - CmpInst::FIRST_FCMP_PREDICATE; + else + return P - CmpInst::FIRST_ICMP_PREDICATE + + (CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1); +} + +CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) { + unsigned fcmpRange = + CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1; + if (LocalIndex < fcmpRange) + return static_cast<CmpInst::Predicate>(CmpInst::FIRST_FCMP_PREDICATE + + LocalIndex); + else + return static_cast<CmpInst::Predicate>(CmpInst::FIRST_ICMP_PREDICATE + + LocalIndex - fcmpRange); } StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) { @@ -401,17 +482,51 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA, return !(PAC.preservedWhenStateless()); } -Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) { - VocabVector DummyVocab; - DummyVocab.reserve(NumCanonicalEntries); +VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) { float DummyVal = 0.1f; - // Create a dummy vocabulary with entries for all opcodes, types, operands - // and predicates - for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) { - DummyVocab.push_back(Embedding(Dim, DummyVal)); + + // Create sections for opcodes, types, operands, and predicates + // Order must match Vocabulary::Section enum + std::vector<std::vector<Embedding>> Sections; + Sections.reserve(4); + + // Opcodes section + std::vector<Embedding> OpcodeSec; + OpcodeSec.reserve(MaxOpcodes); + for (unsigned I = 0; I < MaxOpcodes; ++I) { + OpcodeSec.emplace_back(Dim, DummyVal); + DummyVal += 0.1f; + } + Sections.push_back(std::move(OpcodeSec)); + + // Types section + std::vector<Embedding> TypeSec; + TypeSec.reserve(MaxCanonicalTypeIDs); + for (unsigned I = 0; I < MaxCanonicalTypeIDs; ++I) { + TypeSec.emplace_back(Dim, DummyVal); + DummyVal += 0.1f; + } + Sections.push_back(std::move(TypeSec)); + + // Operands section + std::vector<Embedding> OperandSec; + OperandSec.reserve(MaxOperandKinds); + for (unsigned I = 0; I < MaxOperandKinds; ++I) { + OperandSec.emplace_back(Dim, DummyVal); DummyVal += 0.1f; } - return DummyVocab; + Sections.push_back(std::move(OperandSec)); + + // Predicates section + std::vector<Embedding> PredicateSec; + PredicateSec.reserve(MaxPredicateKinds); + for (unsigned I = 0; I < MaxPredicateKinds; ++I) { + PredicateSec.emplace_back(Dim, DummyVal); + DummyVal += 0.1f; + } + Sections.push_back(std::move(PredicateSec)); + + return VocabStorage(std::move(Sections)); } // ==----------------------------------------------------------------------===// @@ -457,7 +572,9 @@ Error IR2VecVocabAnalysis::parseVocabSection( // FIXME: Make this optional. We can avoid file reads // by auto-generating a default vocabulary during the build time. -Error IR2VecVocabAnalysis::readVocabulary() { +Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab, + VocabMap &TypeVocab, + VocabMap &ArgVocab) { auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true); if (!BufOrError) return createFileError(VocabFile, BufOrError.getError()); @@ -488,7 +605,9 @@ Error IR2VecVocabAnalysis::readVocabulary() { return Error::success(); } -void IR2VecVocabAnalysis::generateNumMappedVocab() { +void IR2VecVocabAnalysis::generateVocabStorage(VocabMap &OpcVocab, + VocabMap &TypeVocab, + VocabMap &ArgVocab) { // Helper for handling missing entities in the vocabulary. // Currently, we use a zero vector. In the future, we will throw an error to @@ -506,7 +625,6 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Opcodes std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes, Embedding(Dim)); - NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes); for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) { StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1); auto It = OpcVocab.find(VocabKey.str()); @@ -515,13 +633,10 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { else handleMissingEntity(VocabKey.str()); } - Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(), - NumericOpcodeEmbeddings.end()); ... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/158376 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits