llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Farzon Lotfi (farzonl)

<details>
<summary>Changes</summary>

fixes #<!-- -->166206

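- Add swizzle support when the row index is a constant (a swizzle on a dynamically indexed row is currently emitted as unsupported)
- Add test cases
- Add a new AST node, `MatrixSingleSubscriptExpr`
- Add a new `MatrixRow` LValue kind for matrix rows
- TODO: Make the new LValue a dynamic-index version of ExtVectorElt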

---

Patch is 72.24 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/170779.diff


33 Files Affected:

- (modified) clang/include/clang/AST/ComputeDependence.h (+2) 
- (modified) clang/include/clang/AST/Expr.h (+67) 
- (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+1) 
- (modified) clang/include/clang/AST/Stmt.h (+1) 
- (modified) clang/include/clang/Basic/StmtNodes.td (+1) 
- (modified) clang/include/clang/Sema/Sema.h (+3) 
- (modified) clang/lib/AST/ComputeDependence.cpp (+4) 
- (modified) clang/lib/AST/Expr.cpp (+1) 
- (modified) clang/lib/AST/ExprClassification.cpp (+3) 
- (modified) clang/lib/AST/ExprConstant.cpp (+1) 
- (modified) clang/lib/AST/ItaniumMangle.cpp (+9) 
- (modified) clang/lib/AST/StmtPrinter.cpp (+8) 
- (modified) clang/lib/AST/StmtProfile.cpp (+5) 
- (modified) clang/lib/CodeGen/CGExpr.cpp (+93) 
- (modified) clang/lib/CodeGen/CGExprScalar.cpp (+35) 
- (modified) clang/lib/CodeGen/CGValue.h (+18-1) 
- (modified) clang/lib/CodeGen/CodeGenFunction.h (+1) 
- (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1) 
- (modified) clang/lib/Sema/SemaExpr.cpp (+60-1) 
- (modified) clang/lib/Sema/TreeTransform.h (+29) 
- (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+8) 
- (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+9) 
- (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+5) 
- (added) clang/test/AST/HLSL/matrix-single-subscript-getter.hlsl (+77) 
- (added) clang/test/AST/HLSL/matrix-single-subscript-setter.hlsl (+59) 
- (added) clang/test/AST/HLSL/matrix-single-subscript-swizzle.hlsl (+56) 
- (added) clang/test/AST/HLSL/pch_with_matrix_single_subscript.hlsl (+16) 
- (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptConstSwizzle.hlsl (+60)
- (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptDynamicSwizzle.hlsl (+10)
- (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptGetter.hlsl (+205)
- (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl (+126)
- (added) clang/test/SemaHLSL/matrix_single_subscript_errors.hlsl (+12) 
- (modified) clang/tools/libclang/CXCursor.cpp (+5) 


``````````diff
diff --git a/clang/include/clang/AST/ComputeDependence.h 
b/clang/include/clang/AST/ComputeDependence.h
index c298f2620f211..895105640b931 100644
--- a/clang/include/clang/AST/ComputeDependence.h
+++ b/clang/include/clang/AST/ComputeDependence.h
@@ -28,6 +28,7 @@ class ParenExpr;
 class UnaryOperator;
 class UnaryExprOrTypeTraitExpr;
 class ArraySubscriptExpr;
+class MatrixSingleSubscriptExpr;
 class MatrixSubscriptExpr;
 class CompoundLiteralExpr;
 class ImplicitCastExpr;
@@ -117,6 +118,7 @@ ExprDependence computeDependence(ParenExpr *E);
 ExprDependence computeDependence(UnaryOperator *E, const ASTContext &Ctx);
 ExprDependence computeDependence(UnaryExprOrTypeTraitExpr *E);
 ExprDependence computeDependence(ArraySubscriptExpr *E);
+ExprDependence computeDependence(MatrixSingleSubscriptExpr *E);
 ExprDependence computeDependence(MatrixSubscriptExpr *E);
 ExprDependence computeDependence(CompoundLiteralExpr *E);
 ExprDependence computeDependence(ImplicitCastExpr *E);
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 573cc72db35c6..16d9bbe8ff7c1 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -2790,6 +2790,73 @@ class ArraySubscriptExpr : public Expr {
   }
 };
 
+/// MatrixSingleSubscriptExpr - Matrix single subscript expression for the
+/// MatrixType extension when you want to get\set a vector from a Matrix.
+class MatrixSingleSubscriptExpr : public Expr {
+  enum { BASE, ROW_IDX, END_EXPR };
+  Stmt *SubExprs[END_EXPR];
+
+public:
+  /// matrix[row]
+  ///
+  /// \param Base        The matrix expression.
+  /// \param RowIdx      The row index expression.
+  /// \param T           The type of the row (usually a vector type).
+  /// \param RBracketLoc Location of the closing ']'.
+  MatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx, QualType T,
+                            SourceLocation RBracketLoc)
+      : Expr(MatrixSingleSubscriptExprClass, T,
+             Base->getValueKind(), // lvalue/rvalue follows the matrix base
+             OK_MatrixComponent) { // or OK_Ordinary/OK_VectorComponent if you
+                                   // prefer
+    SubExprs[BASE] = Base;
+    SubExprs[ROW_IDX] = RowIdx;
+    ArrayOrMatrixSubscriptExprBits.RBracketLoc = RBracketLoc;
+    setDependence(computeDependence(this));
+  }
+
+  /// Create an empty matrix single-subscript expression.
+  explicit MatrixSingleSubscriptExpr(EmptyShell Shell)
+      : Expr(MatrixSingleSubscriptExprClass, Shell) {}
+
+  Expr *getBase() { return cast<Expr>(SubExprs[BASE]); }
+  const Expr *getBase() const { return cast<Expr>(SubExprs[BASE]); }
+  void setBase(Expr *E) { SubExprs[BASE] = E; }
+
+  Expr *getRowIdx() { return cast<Expr>(SubExprs[ROW_IDX]); }
+  const Expr *getRowIdx() const { return cast<Expr>(SubExprs[ROW_IDX]); }
+  void setRowIdx(Expr *E) { SubExprs[ROW_IDX] = E; }
+
+  SourceLocation getBeginLoc() const LLVM_READONLY {
+    return getBase()->getBeginLoc();
+  }
+
+  SourceLocation getEndLoc() const { return getRBracketLoc(); }
+
+  SourceLocation getExprLoc() const LLVM_READONLY {
+    return getBase()->getExprLoc();
+  }
+
+  SourceLocation getRBracketLoc() const {
+    return ArrayOrMatrixSubscriptExprBits.RBracketLoc;
+  }
+  void setRBracketLoc(SourceLocation L) {
+    ArrayOrMatrixSubscriptExprBits.RBracketLoc = L;
+  }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == MatrixSingleSubscriptExprClass;
+  }
+
+  // Iterators
+  child_range children() {
+    return child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
+  }
+  const_child_range children() const {
+    return const_child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
+  }
+};
+
 /// MatrixSubscriptExpr - Matrix subscript expression for the MatrixType
 /// extension.
 /// MatrixSubscriptExpr can be either incomplete (only Base and RowIdx are set
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h 
b/clang/include/clang/AST/RecursiveASTVisitor.h
index 8f427427d71ed..92409b72e4f0c 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -2893,6 +2893,7 @@ DEF_TRAVERSE_STMT(CXXMemberCallExpr, {})
 // over the children.
 DEF_TRAVERSE_STMT(AddrLabelExpr, {})
 DEF_TRAVERSE_STMT(ArraySubscriptExpr, {})
+DEF_TRAVERSE_STMT(MatrixSingleSubscriptExpr, {})
 DEF_TRAVERSE_STMT(MatrixSubscriptExpr, {})
 DEF_TRAVERSE_STMT(ArraySectionExpr, {})
 DEF_TRAVERSE_STMT(OMPArrayShapingExpr, {})
diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h
index e1cca34d2212c..21d0a7dfe577c 100644
--- a/clang/include/clang/AST/Stmt.h
+++ b/clang/include/clang/AST/Stmt.h
@@ -530,6 +530,7 @@ class alignas(void *) Stmt {
   class ArrayOrMatrixSubscriptExprBitfields {
     friend class ArraySubscriptExpr;
     friend class MatrixSubscriptExpr;
+    friend class MatrixSingleSubscriptExpr;
 
     LLVM_PREFERRED_TYPE(ExprBitfields)
     unsigned : NumExprBits;
diff --git a/clang/include/clang/Basic/StmtNodes.td 
b/clang/include/clang/Basic/StmtNodes.td
index bf3686bb372d5..ada74807e56e2 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -74,6 +74,7 @@ def UnaryOperator : StmtNode<Expr>;
 def OffsetOfExpr : StmtNode<Expr>;
 def UnaryExprOrTypeTraitExpr : StmtNode<Expr>;
 def ArraySubscriptExpr : StmtNode<Expr>;
+def MatrixSingleSubscriptExpr : StmtNode<Expr>;
 def MatrixSubscriptExpr : StmtNode<Expr>;
 def ArraySectionExpr : StmtNode<Expr>;
 def OMPIteratorExpr : StmtNode<Expr>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 4a601a0eaf1b9..d4d5c3d8bed17 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -7406,6 +7406,9 @@ class Sema final : public SemaBase {
   ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc,
                                              Expr *Idx, SourceLocation RLoc);
 
+  ExprResult CreateBuiltinMatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx,
+                                                    SourceLocation RBLoc);
+
   ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx,
                                               Expr *ColumnIdx,
                                               SourceLocation RBLoc);
diff --git a/clang/lib/AST/ComputeDependence.cpp 
b/clang/lib/AST/ComputeDependence.cpp
index 638080ea781a9..8429f17d26be5 100644
--- a/clang/lib/AST/ComputeDependence.cpp
+++ b/clang/lib/AST/ComputeDependence.cpp
@@ -115,6 +115,10 @@ ExprDependence clang::computeDependence(ArraySubscriptExpr 
*E) {
   return E->getLHS()->getDependence() | E->getRHS()->getDependence();
 }
 
+ExprDependence clang::computeDependence(MatrixSingleSubscriptExpr *E) {
+  return E->getBase()->getDependence() | E->getRowIdx()->getDependence();
+}
+
 ExprDependence clang::computeDependence(MatrixSubscriptExpr *E) {
   return E->getBase()->getDependence() | E->getRowIdx()->getDependence() |
          (E->getColumnIdx() ? E->getColumnIdx()->getDependence()
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index ca7f3e16a9276..b400b2a083d9b 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -3789,6 +3789,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
 
   case ParenExprClass:
   case ArraySubscriptExprClass:
+  case MatrixSingleSubscriptExprClass:
   case MatrixSubscriptExprClass:
   case ArraySectionExprClass:
   case OMPArrayShapingExprClass:
diff --git a/clang/lib/AST/ExprClassification.cpp 
b/clang/lib/AST/ExprClassification.cpp
index aeacd0dc765ef..9995d1b411c5b 100644
--- a/clang/lib/AST/ExprClassification.cpp
+++ b/clang/lib/AST/ExprClassification.cpp
@@ -259,6 +259,9 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const 
Expr *E) {
     }
     return Cl::CL_LValue;
 
+  case Expr::MatrixSingleSubscriptExprClass:
+    return ClassifyInternal(Ctx, 
cast<MatrixSingleSubscriptExpr>(E)->getBase());
+
   // Subscripting matrix types behaves like member accesses.
   case Expr::MatrixSubscriptExprClass:
     return ClassifyInternal(Ctx, cast<MatrixSubscriptExpr>(E)->getBase());
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 11c5e1c6e90f4..52481dc71b75d 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -20667,6 +20667,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext 
&Ctx) {
   case Expr::ImaginaryLiteralClass:
   case Expr::StringLiteralClass:
   case Expr::ArraySubscriptExprClass:
+  case Expr::MatrixSingleSubscriptExprClass:
   case Expr::MatrixSubscriptExprClass:
   case Expr::ArraySectionExprClass:
   case Expr::OMPArrayShapingExprClass:
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 5572e0a7ae59c..cb71987fba766 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -5482,6 +5482,15 @@ void CXXNameMangler::mangleExpression(const Expr *E, 
unsigned Arity,
     break;
   }
 
+  case Expr::MatrixSingleSubscriptExprClass: {
+    NotPrimaryExpr();
+    const MatrixSingleSubscriptExpr *ME = cast<MatrixSingleSubscriptExpr>(E);
+    Out << "ix";
+    mangleExpression(ME->getBase());
+    mangleExpression(ME->getRowIdx());
+    break;
+  }
+
   case Expr::MatrixSubscriptExprClass: {
     NotPrimaryExpr();
     const MatrixSubscriptExpr *ME = cast<MatrixSubscriptExpr>(E);
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index ff8ca01ec5477..51b9c47f22ff0 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -1685,6 +1685,14 @@ void 
StmtPrinter::VisitArraySubscriptExpr(ArraySubscriptExpr *Node) {
   OS << "]";
 }
 
+void StmtPrinter::VisitMatrixSingleSubscriptExpr(
+    MatrixSingleSubscriptExpr *Node) {
+  PrintExpr(Node->getBase());
+  OS << "[";
+  PrintExpr(Node->getRowIdx());
+  OS << "]";
+}
+
 void StmtPrinter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *Node) {
   PrintExpr(Node->getBase());
   OS << "[";
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 4a8c638c85331..c7b7c65715dfc 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -1508,6 +1508,11 @@ void StmtProfiler::VisitArraySubscriptExpr(const 
ArraySubscriptExpr *S) {
   VisitExpr(S);
 }
 
+void StmtProfiler::VisitMatrixSingleSubscriptExpr(
+    const MatrixSingleSubscriptExpr *S) {
+  VisitExpr(S);
+}
+
 void StmtProfiler::VisitMatrixSubscriptExpr(const MatrixSubscriptExpr *S) {
   VisitExpr(S);
 }
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index e842158236cd4..ca06b5df94cb3 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1796,6 +1796,8 @@ LValue CodeGenFunction::EmitLValueHelper(const Expr *E,
     return EmitUnaryOpLValue(cast<UnaryOperator>(E));
   case Expr::ArraySubscriptExprClass:
     return EmitArraySubscriptExpr(cast<ArraySubscriptExpr>(E));
+  case Expr::MatrixSingleSubscriptExprClass:
+    return EmitMatrixSingleSubscriptExpr(cast<MatrixSingleSubscriptExpr>(E));
   case Expr::MatrixSubscriptExprClass:
     return EmitMatrixSubscriptExpr(cast<MatrixSubscriptExpr>(E));
   case Expr::ArraySectionExprClass:
@@ -2440,6 +2442,35 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, 
SourceLocation Loc) {
         Builder.CreateLoad(LV.getMatrixAddress(), LV.isVolatileQualified());
     return RValue::get(Builder.CreateExtractElement(Load, Idx, "matrixext"));
   }
+  if (LV.isMatrixRow()) {
+    QualType MatTy = LV.getType();
+    const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>();
+
+    unsigned NumRows = MT->getNumRows();
+    unsigned NumCols = MT->getNumColumns();
+
+    llvm::Value *MatrixVec = EmitLoadOfScalar(LV, Loc);
+
+    llvm::Value *Row = LV.getMatrixRowIdx();
+    llvm::Value *Result =
+        llvm::UndefValue::get(ConvertType(LV.getType())); // <NumCols x T>
+
+    llvm::MatrixBuilder MB(Builder);
+
+    for (unsigned Col = 0; Col < NumCols; ++Col) {
+      llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
+
+      llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
+
+      llvm::Value *Elt = Builder.CreateExtractElement(MatrixVec, EltIndex);
+
+      llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+
+      Result = Builder.CreateInsertElement(Result, Elt, Lane);
+    }
+
+    return RValue::get(Result);
+  }
 
   assert(LV.isBitField() && "Unknown LValue type!");
   return EmitLoadOfBitfieldLValue(LV, Loc);
@@ -2662,6 +2693,36 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, 
LValue Dst,
       addInstToCurrentSourceAtom(I, Vec);
       return;
     }
+    if (Dst.isMatrixRow()) {
+      QualType MatTy = Dst.getType();
+      const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>();
+
+      unsigned NumRows = MT->getNumRows();
+      unsigned NumCols = MT->getNumColumns();
+
+      llvm::Value *MatrixVec =
+          Builder.CreateLoad(Dst.getAddress(), "matrix.load");
+
+      llvm::Value *Row = Dst.getMatrixRowIdx();
+      llvm::Value *RowVal = Src.getScalarVal(); // <NumCols x T>
+
+      llvm::MatrixBuilder MB(Builder);
+
+      for (unsigned Col = 0; Col < NumCols; ++Col) {
+        llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
+
+        llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
+
+        llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+
+        llvm::Value *NewElt = Builder.CreateExtractElement(RowVal, Lane);
+
+        MatrixVec = Builder.CreateInsertElement(MatrixVec, NewElt, EltIndex);
+      }
+
+      Builder.CreateStore(MatrixVec, Dst.getAddress());
+      return;
+    }
 
     assert(Dst.isBitField() && "Unknown LValue type");
     return EmitStoreThroughBitfieldLValue(Src, Dst);
@@ -4874,6 +4935,35 @@ llvm::Value *CodeGenFunction::EmitMatrixIndexExpr(const 
Expr *E) {
   return Builder.CreateIntCast(Idx, IntPtrTy, IsSigned);
 }
 
+LValue CodeGenFunction::EmitMatrixSingleSubscriptExpr(
+    const MatrixSingleSubscriptExpr *E) {
+  LValue Base = EmitLValue(E->getBase());
+  llvm::Value *RowIdx = EmitMatrixIndexExpr(E->getRowIdx());
+
+  if (auto *RowConst = llvm::dyn_cast<llvm::ConstantInt>(RowIdx)) {
+
+    // Extract matrix shape from the AST type
+    const auto *MatTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
+    unsigned NumCols = MatTy->getNumColumns();
+    llvm::SmallVector<llvm::Constant *, 8> Indices;
+    Indices.reserve(NumCols);
+
+    unsigned Row = RowConst->getZExtValue();
+    unsigned Start = Row * NumCols;
+    for (unsigned C = 0; C < NumCols; ++C) {
+      Indices.push_back(llvm::ConstantInt::get(Int32Ty, Start + C));
+    }
+    llvm::Constant *Elts = llvm::ConstantVector::get(Indices);
+    return LValue::MakeExtVectorElt(
+        MaybeConvertMatrixAddress(Base.getAddress(), *this), Elts,
+        E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
+  }
+
+  return LValue::MakeMatrixRow(
+      MaybeConvertMatrixAddress(Base.getAddress(), *this), RowIdx,
+      E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
+}
+
 LValue CodeGenFunction::EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E) {
   assert(
       !E->isIncomplete() &&
@@ -5146,6 +5236,9 @@ EmitExtVectorElementExpr(const ExtVectorElementExpr *E) {
     return LValue::MakeExtVectorElt(Base.getAddress(), CV, type,
                                     Base.getBaseInfo(), TBAAAccessInfo());
   }
+  if (Base.isMatrixRow())
+    return EmitUnsupportedLValue(E, "Matrix single index swizzle");
+
   assert(Base.isExtVectorElt() && "Can only subscript lvalue vec elts here!");
 
   llvm::Constant *BaseElts = Base.getExtVectorElts();
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp 
b/clang/lib/CodeGen/CGExprScalar.cpp
index 769bc37b0e131..70397e8cb99c2 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -599,6 +599,7 @@ class ScalarExprEmitter
   }
 
   Value *VisitArraySubscriptExpr(ArraySubscriptExpr *E);
+  Value *VisitMatrixSingleSubscriptExpr(MatrixSingleSubscriptExpr *E);
   Value *VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E);
   Value *VisitShuffleVectorExpr(ShuffleVectorExpr *E);
   Value *VisitConvertVectorExpr(ConvertVectorExpr *E);
@@ -2109,6 +2110,40 @@ Value 
*ScalarExprEmitter::VisitArraySubscriptExpr(ArraySubscriptExpr *E) {
   return Builder.CreateExtractElement(Base, Idx, "vecext");
 }
 
+Value *ScalarExprEmitter::VisitMatrixSingleSubscriptExpr(
+    MatrixSingleSubscriptExpr *E) {
+  TestAndClearIgnoreResultAssign();
+
+  auto *MatrixTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
+  unsigned NumRows = MatrixTy->getNumRows();
+  unsigned NumColumns = MatrixTy->getNumColumns();
+
+  // Row index
+  Value *RowIdx = CGF.EmitMatrixIndexExpr(E->getRowIdx());
+
+  llvm::MatrixBuilder MB(Builder);
+
+  // The row index must be in [0, NumRows)
+  if (CGF.CGM.getCodeGenOpts().OptimizationLevel > 0)
+    MB.CreateIndexAssumption(RowIdx, NumRows);
+
+  Value *FlatMatrix = Visit(E->getBase());
+  llvm::Type *ElemTy = CGF.ConvertType(MatrixTy->getElementType());
+  auto *ResultTy = llvm::FixedVectorType::get(ElemTy, NumColumns);
+  Value *RowVec = llvm::UndefValue::get(ResultTy);
+
+  for (unsigned Col = 0; Col != NumColumns; ++Col) {
+    Value *ColVal = llvm::ConstantInt::get(RowIdx->getType(), Col);
+    Value *EltIdx = MB.CreateIndex(RowIdx, ColVal, NumRows, "matrix_row_idx");
+    Value *Elt =
+        Builder.CreateExtractElement(FlatMatrix, EltIdx, "matrix_elem");
+    Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+    RowVec = Builder.CreateInsertElement(RowVec, Elt, Lane, "matrix_row_ins");
+  }
+
+  return RowVec;
+}
+
 Value *ScalarExprEmitter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) {
   TestAndClearIgnoreResultAssign();
 
diff --git a/clang/lib/CodeGen/CGValue.h b/clang/lib/CodeGen/CGValue.h
index 6b381b59e71cd..c08ca70de10e1 100644
--- a/clang/lib/CodeGen/CGValue.h
+++ b/clang/lib/CodeGen/CGValue.h
@@ -187,7 +187,8 @@ class LValue {
     BitField,     // This is a bitfield l-value, use getBitfield*.
     ExtVectorElt, // This is an extended vector subset, use getExtVectorComp
     GlobalReg,    // This is a register l-value, use getGlobalReg()
-    MatrixElt     // This is a matrix element, use getVector*
+    MatrixElt,    // This is a matrix element, use getVector*
+    MatrixRow     // This is a matrix vector subset, use getVector*
   } LVType;
 
   union {
@@ -282,6 +283,7 @@ class LValue {
   bool isExtVectorElt() const { return LVType == ExtVectorElt; }
   bool isGlobalReg() const { return LVType == GlobalReg; }
   bool isMatrixElt() const { return LVType == MatrixElt; }
+  bool isMatrixRow() const { return LVType == MatrixRow; }
 
   bool isVolatileQualified() const { return Quals.hasVolatile(); }
   bool isRestrictQualified() const { return Quals.hasRestrict(); }
@@ -398,6 +400,11 @@ class LValue {
     return VectorIdx;
   }
 
+  llvm::Value *getMatrixRowIdx() const {
+    assert(isMatrixRow());
+    return VectorIdx;
+  }
+
   // extended vector elements.
   Address getExtVectorAddress() const {
     assert(isExtVectorElt());
@@ -486,6 +493,16 @@ class LValue {
     return R;
   }
 
+  static LValue MakeMatrixRow(Address Addr, llvm::Value *RowIdx,
+                              QualType MatrixTy, LValueBaseInfo BaseInfo,
+                              TBAAAccessInfo TBAAInfo) {
+    LValue LV;
+    LV.LVType = MatrixRow;
+    LV.VectorIdx = RowIdx; // store the row index here
+    LV.Initialize(MatrixTy, MatrixTy.getQualifiers(), Addr, BaseInfo, 
TBAAInfo);
+    return LV;
+  }
+
   static LValue MakeMatrixElt(Address matAddress, llvm::Value *Idx,
                               QualType type, LValueBaseInfo BaseInfo,
                               TBAAAccessInfo TBAAInfo) {
diff --git a/clang/lib/CodeGen/CodeGenFunction.h 
b/clang/lib/CodeGen/CodeGenFunction.h
index 8c4c1c8c2dc95..3abe516debcb0 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4412,6 +4412,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   LValue EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
                                 bool Accessed = false);
   llvm::Value *EmitMatrixIndexExpr(const Expr *E);
+  LValue EmitMatrixSingleSubscriptExpr(const MatrixSingleSubscriptExpr *E);
   LValue EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E);
   LValue EmitArraySectionExpr(const ArraySect...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/170779
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to