llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Farzon Lotfi (farzonl) <details> <summary>Changes</summary> fixes #166206 - Add swizzle support if row index is constant - Add test cases - Add new AST type - Add new LValue for Matrix Row Type - TODO: Make the new LValue a dynamic index version of ExtVectorElt --- Patch is 72.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170779.diff 33 Files Affected: - (modified) clang/include/clang/AST/ComputeDependence.h (+2) - (modified) clang/include/clang/AST/Expr.h (+67) - (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+1) - (modified) clang/include/clang/AST/Stmt.h (+1) - (modified) clang/include/clang/Basic/StmtNodes.td (+1) - (modified) clang/include/clang/Sema/Sema.h (+3) - (modified) clang/lib/AST/ComputeDependence.cpp (+4) - (modified) clang/lib/AST/Expr.cpp (+1) - (modified) clang/lib/AST/ExprClassification.cpp (+3) - (modified) clang/lib/AST/ExprConstant.cpp (+1) - (modified) clang/lib/AST/ItaniumMangle.cpp (+9) - (modified) clang/lib/AST/StmtPrinter.cpp (+8) - (modified) clang/lib/AST/StmtProfile.cpp (+5) - (modified) clang/lib/CodeGen/CGExpr.cpp (+93) - (modified) clang/lib/CodeGen/CGExprScalar.cpp (+35) - (modified) clang/lib/CodeGen/CGValue.h (+18-1) - (modified) clang/lib/CodeGen/CodeGenFunction.h (+1) - (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1) - (modified) clang/lib/Sema/SemaExpr.cpp (+60-1) - (modified) clang/lib/Sema/TreeTransform.h (+29) - (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+8) - (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+9) - (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+5) - (added) clang/test/AST/HLSL/matrix-single-subscript-getter.hlsl (+77) - (added) clang/test/AST/HLSL/matrix-single-subscript-setter.hlsl (+59) - (added) clang/test/AST/HLSL/matrix-single-subscript-swizzle.hlsl (+56) - (added) clang/test/AST/HLSL/pch_with_matrix_single_subscript.hlsl (+16) - (added) 
clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptConstSwizzle.hlsl (+60) - (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptDynamicSwizzle.hlsl (+10) - (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptGetter.hlsl (+205) - (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl (+126) - (added) clang/test/SemaHLSL/matrix_single_subscript_errors.hlsl (+12) - (modified) clang/tools/libclang/CXCursor.cpp (+5) ``````````diff diff --git a/clang/include/clang/AST/ComputeDependence.h b/clang/include/clang/AST/ComputeDependence.h index c298f2620f211..895105640b931 100644 --- a/clang/include/clang/AST/ComputeDependence.h +++ b/clang/include/clang/AST/ComputeDependence.h @@ -28,6 +28,7 @@ class ParenExpr; class UnaryOperator; class UnaryExprOrTypeTraitExpr; class ArraySubscriptExpr; +class MatrixSingleSubscriptExpr; class MatrixSubscriptExpr; class CompoundLiteralExpr; class ImplicitCastExpr; @@ -117,6 +118,7 @@ ExprDependence computeDependence(ParenExpr *E); ExprDependence computeDependence(UnaryOperator *E, const ASTContext &Ctx); ExprDependence computeDependence(UnaryExprOrTypeTraitExpr *E); ExprDependence computeDependence(ArraySubscriptExpr *E); +ExprDependence computeDependence(MatrixSingleSubscriptExpr *E); ExprDependence computeDependence(MatrixSubscriptExpr *E); ExprDependence computeDependence(CompoundLiteralExpr *E); ExprDependence computeDependence(ImplicitCastExpr *E); diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h index 573cc72db35c6..16d9bbe8ff7c1 100644 --- a/clang/include/clang/AST/Expr.h +++ b/clang/include/clang/AST/Expr.h @@ -2790,6 +2790,73 @@ class ArraySubscriptExpr : public Expr { } }; +/// MatrixSingleSubscriptExpr - Matrix single subscript expression for the +/// MatrixType extension when you want to get\set a vector from a Matrix. 
+class MatrixSingleSubscriptExpr : public Expr { + enum { BASE, ROW_IDX, END_EXPR }; + Stmt *SubExprs[END_EXPR]; + +public: + /// matrix[row] + /// + /// \param Base The matrix expression. + /// \param RowIdx The row index expression. + /// \param T The type of the row (usually a vector type). + /// \param RBracketLoc Location of the closing ']'. + MatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx, QualType T, + SourceLocation RBracketLoc) + : Expr(MatrixSingleSubscriptExprClass, T, + Base->getValueKind(), // lvalue/rvalue follows the matrix base + OK_MatrixComponent) { // or OK_Ordinary/OK_VectorComponent if you + // prefer + SubExprs[BASE] = Base; + SubExprs[ROW_IDX] = RowIdx; + ArrayOrMatrixSubscriptExprBits.RBracketLoc = RBracketLoc; + setDependence(computeDependence(this)); + } + + /// Create an empty matrix single-subscript expression. + explicit MatrixSingleSubscriptExpr(EmptyShell Shell) + : Expr(MatrixSingleSubscriptExprClass, Shell) {} + + Expr *getBase() { return cast<Expr>(SubExprs[BASE]); } + const Expr *getBase() const { return cast<Expr>(SubExprs[BASE]); } + void setBase(Expr *E) { SubExprs[BASE] = E; } + + Expr *getRowIdx() { return cast<Expr>(SubExprs[ROW_IDX]); } + const Expr *getRowIdx() const { return cast<Expr>(SubExprs[ROW_IDX]); } + void setRowIdx(Expr *E) { SubExprs[ROW_IDX] = E; } + + SourceLocation getBeginLoc() const LLVM_READONLY { + return getBase()->getBeginLoc(); + } + + SourceLocation getEndLoc() const { return getRBracketLoc(); } + + SourceLocation getExprLoc() const LLVM_READONLY { + return getBase()->getExprLoc(); + } + + SourceLocation getRBracketLoc() const { + return ArrayOrMatrixSubscriptExprBits.RBracketLoc; + } + void setRBracketLoc(SourceLocation L) { + ArrayOrMatrixSubscriptExprBits.RBracketLoc = L; + } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == MatrixSingleSubscriptExprClass; + } + + // Iterators + child_range children() { + return child_range(&SubExprs[0], &SubExprs[0] + END_EXPR); + } + 
const_child_range children() const { + return const_child_range(&SubExprs[0], &SubExprs[0] + END_EXPR); + } +}; + /// MatrixSubscriptExpr - Matrix subscript expression for the MatrixType /// extension. /// MatrixSubscriptExpr can be either incomplete (only Base and RowIdx are set diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index 8f427427d71ed..92409b72e4f0c 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -2893,6 +2893,7 @@ DEF_TRAVERSE_STMT(CXXMemberCallExpr, {}) // over the children. DEF_TRAVERSE_STMT(AddrLabelExpr, {}) DEF_TRAVERSE_STMT(ArraySubscriptExpr, {}) +DEF_TRAVERSE_STMT(MatrixSingleSubscriptExpr, {}) DEF_TRAVERSE_STMT(MatrixSubscriptExpr, {}) DEF_TRAVERSE_STMT(ArraySectionExpr, {}) DEF_TRAVERSE_STMT(OMPArrayShapingExpr, {}) diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h index e1cca34d2212c..21d0a7dfe577c 100644 --- a/clang/include/clang/AST/Stmt.h +++ b/clang/include/clang/AST/Stmt.h @@ -530,6 +530,7 @@ class alignas(void *) Stmt { class ArrayOrMatrixSubscriptExprBitfields { friend class ArraySubscriptExpr; friend class MatrixSubscriptExpr; + friend class MatrixSingleSubscriptExpr; LLVM_PREFERRED_TYPE(ExprBitfields) unsigned : NumExprBits; diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td index bf3686bb372d5..ada74807e56e2 100644 --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -74,6 +74,7 @@ def UnaryOperator : StmtNode<Expr>; def OffsetOfExpr : StmtNode<Expr>; def UnaryExprOrTypeTraitExpr : StmtNode<Expr>; def ArraySubscriptExpr : StmtNode<Expr>; +def MatrixSingleSubscriptExpr : StmtNode<Expr>; def MatrixSubscriptExpr : StmtNode<Expr>; def ArraySectionExpr : StmtNode<Expr>; def OMPIteratorExpr : StmtNode<Expr>; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 
4a601a0eaf1b9..d4d5c3d8bed17 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -7406,6 +7406,9 @@ class Sema final : public SemaBase { ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc, Expr *Idx, SourceLocation RLoc); + ExprResult CreateBuiltinMatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx, + SourceLocation RBLoc); + ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx, Expr *ColumnIdx, SourceLocation RBLoc); diff --git a/clang/lib/AST/ComputeDependence.cpp b/clang/lib/AST/ComputeDependence.cpp index 638080ea781a9..8429f17d26be5 100644 --- a/clang/lib/AST/ComputeDependence.cpp +++ b/clang/lib/AST/ComputeDependence.cpp @@ -115,6 +115,10 @@ ExprDependence clang::computeDependence(ArraySubscriptExpr *E) { return E->getLHS()->getDependence() | E->getRHS()->getDependence(); } +ExprDependence clang::computeDependence(MatrixSingleSubscriptExpr *E) { + return E->getBase()->getDependence() | E->getRowIdx()->getDependence(); +} + ExprDependence clang::computeDependence(MatrixSubscriptExpr *E) { return E->getBase()->getDependence() | E->getRowIdx()->getDependence() | (E->getColumnIdx() ? 
E->getColumnIdx()->getDependence() diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp index ca7f3e16a9276..b400b2a083d9b 100644 --- a/clang/lib/AST/Expr.cpp +++ b/clang/lib/AST/Expr.cpp @@ -3789,6 +3789,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx, case ParenExprClass: case ArraySubscriptExprClass: + case MatrixSingleSubscriptExprClass: case MatrixSubscriptExprClass: case ArraySectionExprClass: case OMPArrayShapingExprClass: diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp index aeacd0dc765ef..9995d1b411c5b 100644 --- a/clang/lib/AST/ExprClassification.cpp +++ b/clang/lib/AST/ExprClassification.cpp @@ -259,6 +259,9 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) { } return Cl::CL_LValue; + case Expr::MatrixSingleSubscriptExprClass: + return ClassifyInternal(Ctx, cast<MatrixSingleSubscriptExpr>(E)->getBase()); + // Subscripting matrix types behaves like member accesses. case Expr::MatrixSubscriptExprClass: return ClassifyInternal(Ctx, cast<MatrixSubscriptExpr>(E)->getBase()); diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index 11c5e1c6e90f4..52481dc71b75d 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -20667,6 +20667,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) { case Expr::ImaginaryLiteralClass: case Expr::StringLiteralClass: case Expr::ArraySubscriptExprClass: + case Expr::MatrixSingleSubscriptExprClass: case Expr::MatrixSubscriptExprClass: case Expr::ArraySectionExprClass: case Expr::OMPArrayShapingExprClass: diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp index 5572e0a7ae59c..cb71987fba766 100644 --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -5482,6 +5482,15 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity, break; } + case Expr::MatrixSingleSubscriptExprClass: { + NotPrimaryExpr(); + const 
MatrixSingleSubscriptExpr *ME = cast<MatrixSingleSubscriptExpr>(E); + Out << "ix"; + mangleExpression(ME->getBase()); + mangleExpression(ME->getRowIdx()); + break; + } + case Expr::MatrixSubscriptExprClass: { NotPrimaryExpr(); const MatrixSubscriptExpr *ME = cast<MatrixSubscriptExpr>(E); diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp index ff8ca01ec5477..51b9c47f22ff0 100644 --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -1685,6 +1685,14 @@ void StmtPrinter::VisitArraySubscriptExpr(ArraySubscriptExpr *Node) { OS << "]"; } +void StmtPrinter::VisitMatrixSingleSubscriptExpr( + MatrixSingleSubscriptExpr *Node) { + PrintExpr(Node->getBase()); + OS << "["; + PrintExpr(Node->getRowIdx()); + OS << "]"; +} + void StmtPrinter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *Node) { PrintExpr(Node->getBase()); OS << "["; diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index 4a8c638c85331..c7b7c65715dfc 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -1508,6 +1508,11 @@ void StmtProfiler::VisitArraySubscriptExpr(const ArraySubscriptExpr *S) { VisitExpr(S); } +void StmtProfiler::VisitMatrixSingleSubscriptExpr( + const MatrixSingleSubscriptExpr *S) { + VisitExpr(S); +} + void StmtProfiler::VisitMatrixSubscriptExpr(const MatrixSubscriptExpr *S) { VisitExpr(S); } diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index e842158236cd4..ca06b5df94cb3 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -1796,6 +1796,8 @@ LValue CodeGenFunction::EmitLValueHelper(const Expr *E, return EmitUnaryOpLValue(cast<UnaryOperator>(E)); case Expr::ArraySubscriptExprClass: return EmitArraySubscriptExpr(cast<ArraySubscriptExpr>(E)); + case Expr::MatrixSingleSubscriptExprClass: + return EmitMatrixSingleSubscriptExpr(cast<MatrixSingleSubscriptExpr>(E)); case Expr::MatrixSubscriptExprClass: return 
EmitMatrixSubscriptExpr(cast<MatrixSubscriptExpr>(E)); case Expr::ArraySectionExprClass: @@ -2440,6 +2442,35 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) { Builder.CreateLoad(LV.getMatrixAddress(), LV.isVolatileQualified()); return RValue::get(Builder.CreateExtractElement(Load, Idx, "matrixext")); } + if (LV.isMatrixRow()) { + QualType MatTy = LV.getType(); + const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>(); + + unsigned NumRows = MT->getNumRows(); + unsigned NumCols = MT->getNumColumns(); + + llvm::Value *MatrixVec = EmitLoadOfScalar(LV, Loc); + + llvm::Value *Row = LV.getMatrixRowIdx(); + llvm::Value *Result = + llvm::UndefValue::get(ConvertType(LV.getType())); // <NumCols x T> + + llvm::MatrixBuilder MB(Builder); + + for (unsigned Col = 0; Col < NumCols; ++Col) { + llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col); + + llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows); + + llvm::Value *Elt = Builder.CreateExtractElement(MatrixVec, EltIndex); + + llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col); + + Result = Builder.CreateInsertElement(Result, Elt, Lane); + } + + return RValue::get(Result); + } assert(LV.isBitField() && "Unknown LValue type!"); return EmitLoadOfBitfieldLValue(LV, Loc); @@ -2662,6 +2693,36 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst, addInstToCurrentSourceAtom(I, Vec); return; } + if (Dst.isMatrixRow()) { + QualType MatTy = Dst.getType(); + const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>(); + + unsigned NumRows = MT->getNumRows(); + unsigned NumCols = MT->getNumColumns(); + + llvm::Value *MatrixVec = + Builder.CreateLoad(Dst.getAddress(), "matrix.load"); + + llvm::Value *Row = Dst.getMatrixRowIdx(); + llvm::Value *RowVal = Src.getScalarVal(); // <NumCols x T> + + llvm::MatrixBuilder MB(Builder); + + for (unsigned Col = 0; Col < NumCols; ++Col) { + llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), 
Col); + + llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows); + + llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col); + + llvm::Value *NewElt = Builder.CreateExtractElement(RowVal, Lane); + + MatrixVec = Builder.CreateInsertElement(MatrixVec, NewElt, EltIndex); + } + + Builder.CreateStore(MatrixVec, Dst.getAddress()); + return; + } assert(Dst.isBitField() && "Unknown LValue type"); return EmitStoreThroughBitfieldLValue(Src, Dst); @@ -4874,6 +4935,35 @@ llvm::Value *CodeGenFunction::EmitMatrixIndexExpr(const Expr *E) { return Builder.CreateIntCast(Idx, IntPtrTy, IsSigned); } +LValue CodeGenFunction::EmitMatrixSingleSubscriptExpr( + const MatrixSingleSubscriptExpr *E) { + LValue Base = EmitLValue(E->getBase()); + llvm::Value *RowIdx = EmitMatrixIndexExpr(E->getRowIdx()); + + if (auto *RowConst = llvm::dyn_cast<llvm::ConstantInt>(RowIdx)) { + + // Extract matrix shape from the AST type + const auto *MatTy = E->getBase()->getType()->castAs<ConstantMatrixType>(); + unsigned NumCols = MatTy->getNumColumns(); + llvm::SmallVector<llvm::Constant *, 8> Indices; + Indices.reserve(NumCols); + + unsigned Row = RowConst->getZExtValue(); + unsigned Start = Row * NumCols; + for (unsigned C = 0; C < NumCols; ++C) { + Indices.push_back(llvm::ConstantInt::get(Int32Ty, Start + C)); + } + llvm::Constant *Elts = llvm::ConstantVector::get(Indices); + return LValue::MakeExtVectorElt( + MaybeConvertMatrixAddress(Base.getAddress(), *this), Elts, + E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo()); + } + + return LValue::MakeMatrixRow( + MaybeConvertMatrixAddress(Base.getAddress(), *this), RowIdx, + E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo()); +} + LValue CodeGenFunction::EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E) { assert( !E->isIncomplete() && @@ -5146,6 +5236,9 @@ EmitExtVectorElementExpr(const ExtVectorElementExpr *E) { return LValue::MakeExtVectorElt(Base.getAddress(), CV, type, Base.getBaseInfo(), 
TBAAAccessInfo()); } + if (Base.isMatrixRow()) + return EmitUnsupportedLValue(E, "Matrix single index swizzle"); + assert(Base.isExtVectorElt() && "Can only subscript lvalue vec elts here!"); llvm::Constant *BaseElts = Base.getExtVectorElts(); diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp index 769bc37b0e131..70397e8cb99c2 100644 --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -599,6 +599,7 @@ class ScalarExprEmitter } Value *VisitArraySubscriptExpr(ArraySubscriptExpr *E); + Value *VisitMatrixSingleSubscriptExpr(MatrixSingleSubscriptExpr *E); Value *VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E); Value *VisitShuffleVectorExpr(ShuffleVectorExpr *E); Value *VisitConvertVectorExpr(ConvertVectorExpr *E); @@ -2109,6 +2110,40 @@ Value *ScalarExprEmitter::VisitArraySubscriptExpr(ArraySubscriptExpr *E) { return Builder.CreateExtractElement(Base, Idx, "vecext"); } +Value *ScalarExprEmitter::VisitMatrixSingleSubscriptExpr( + MatrixSingleSubscriptExpr *E) { + TestAndClearIgnoreResultAssign(); + + auto *MatrixTy = E->getBase()->getType()->castAs<ConstantMatrixType>(); + unsigned NumRows = MatrixTy->getNumRows(); + unsigned NumColumns = MatrixTy->getNumColumns(); + + // Row index + Value *RowIdx = CGF.EmitMatrixIndexExpr(E->getRowIdx()); + + llvm::MatrixBuilder MB(Builder); + + // The row index must be in [0, NumRows) + if (CGF.CGM.getCodeGenOpts().OptimizationLevel > 0) + MB.CreateIndexAssumption(RowIdx, NumRows); + + Value *FlatMatrix = Visit(E->getBase()); + llvm::Type *ElemTy = CGF.ConvertType(MatrixTy->getElementType()); + auto *ResultTy = llvm::FixedVectorType::get(ElemTy, NumColumns); + Value *RowVec = llvm::UndefValue::get(ResultTy); + + for (unsigned Col = 0; Col != NumColumns; ++Col) { + Value *ColVal = llvm::ConstantInt::get(RowIdx->getType(), Col); + Value *EltIdx = MB.CreateIndex(RowIdx, ColVal, NumRows, "matrix_row_idx"); + Value *Elt = + Builder.CreateExtractElement(FlatMatrix, EltIdx, 
"matrix_elem"); + Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col); + RowVec = Builder.CreateInsertElement(RowVec, Elt, Lane, "matrix_row_ins"); + } + + return RowVec; +} + Value *ScalarExprEmitter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) { TestAndClearIgnoreResultAssign(); diff --git a/clang/lib/CodeGen/CGValue.h b/clang/lib/CodeGen/CGValue.h index 6b381b59e71cd..c08ca70de10e1 100644 --- a/clang/lib/CodeGen/CGValue.h +++ b/clang/lib/CodeGen/CGValue.h @@ -187,7 +187,8 @@ class LValue { BitField, // This is a bitfield l-value, use getBitfield*. ExtVectorElt, // This is an extended vector subset, use getExtVectorComp GlobalReg, // This is a register l-value, use getGlobalReg() - MatrixElt // This is a matrix element, use getVector* + MatrixElt, // This is a matrix element, use getVector* + MatrixRow // This is a matrix vector subset, use getVector* } LVType; union { @@ -282,6 +283,7 @@ class LValue { bool isExtVectorElt() const { return LVType == ExtVectorElt; } bool isGlobalReg() const { return LVType == GlobalReg; } bool isMatrixElt() const { return LVType == MatrixElt; } + bool isMatrixRow() const { return LVType == MatrixRow; } bool isVolatileQualified() const { return Quals.hasVolatile(); } bool isRestrictQualified() const { return Quals.hasRestrict(); } @@ -398,6 +400,11 @@ class LValue { return VectorIdx; } + llvm::Value *getMatrixRowIdx() const { + assert(isMatrixRow()); + return VectorIdx; + } + // extended vector elements. 
Address getExtVectorAddress() const { assert(isExtVectorElt()); @@ -486,6 +493,16 @@ class LValue { return R; } + static LValue MakeMatrixRow(Address Addr, llvm::Value *RowIdx, + QualType MatrixTy, LValueBaseInfo BaseInfo, + TBAAAccessInfo TBAAInfo) { + LValue LV; + LV.LVType = MatrixRow; + LV.VectorIdx = RowIdx; // store the row index here + LV.Initialize(MatrixTy, MatrixTy.getQualifiers(), Addr, BaseInfo, TBAAInfo); + return LV; + } + static LValue MakeMatrixElt(Address matAddress, llvm::Value *Idx, QualType type, LValueBaseInfo BaseInfo, TBAAAccessInfo TBAAInfo) { diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 8c4c1c8c2dc95..3abe516debcb0 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -4412,6 +4412,7 @@ class CodeGenFunction : public CodeGenTypeCache { LValue EmitArraySubscriptExpr(const ArraySubscriptExpr *E, bool Accessed = false); llvm::Value *EmitMatrixIndexExpr(const Expr *E); + LValue EmitMatrixSingleSubscriptExpr(const MatrixSingleSubscriptExpr *E); LValue EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E); LValue EmitArraySectionExpr(const ArraySect... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/170779 _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
