https://github.com/Icohedron created 
https://github.com/llvm/llvm-project/pull/184840

I forgot to change the target branch before merging the original PR. This PR is
a cherry-pick of the squashed-and-merged commit
b16aa4b7ec665911c74300cd7442659b70973d13 from #183424.

This PR fixes #182963
This PR is an extension of #178762 which has already been merged.

This PR adds support for `ConstantMatrixType` and the HLSL casts 
`CK_HLSLArrayRValue`, `CK_HLSLMatrixTruncation`,
`CK_HLSLAggregateSplatCast`, and `CK_HLSLElementwiseCast` to the bytecode 
constexpr evaluator.

The implementations of CK_HLSLAggregateSplatCast and CK_HLSLElementwiseCast are 
incomplete, as they still need to support struct and array types to enable use 
of the experimental new constant interpreter on other existing HLSL constexpr 
tests. The completion of the implementations of these casts will be tracked in 
a separate issue (#183426) and implemented in a separate PR.

Assisted-by: claude-opus-4.6

>From b6a48d2a7df7bc284df42df3464dc577cb9d0f09 Mon Sep 17 00:00:00 2001
From: "Deric C." <[email protected]>
Date: Thu, 5 Mar 2026 09:45:48 -0800
Subject: [PATCH] [clang][bytecode][HLSL][Matrix] Support `ConstantMatrixType`
 and more HLSL casts in the new constant interpreter for basic matrix
 constexpr evaluation in HLSL (#183424)

Fixes #182963
This PR is an extension of #178762 and is to be merged immediately after
it.

This PR adds support for `ConstantMatrixType` and the HLSL casts
`CK_HLSLArrayRValue`, `CK_HLSLMatrixTruncation`,
`CK_HLSLAggregateSplatCast`, and `CK_HLSLElementwiseCast` to the
bytecode constexpr evaluator.

The implementations of CK_HLSLAggregateSplatCast and
CK_HLSLElementwiseCast are incomplete, as they still need to support
struct and array types to enable use of the experimental new constant
interpreter on other existing HLSL constexpr tests. The completion of
the implementations of these casts will be tracked in a separate issue
(#183426) and implemented in a separate PR.

Assisted-by: claude-opus-4.6
---
 clang/lib/AST/ByteCode/Compiler.cpp           | 226 ++++++++++++++++++
 clang/lib/AST/ByteCode/Compiler.h             |   5 +
 clang/lib/AST/ByteCode/Pointer.cpp            |  18 ++
 clang/lib/AST/ByteCode/Program.cpp            |  10 +
 .../BuiltinMatrix/MatrixConstantExpr.hlsl     |   4 +
 5 files changed, 263 insertions(+)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp 
b/clang/lib/AST/ByteCode/Compiler.cpp
index e33fa4f86c052..93ad8eb26f29e 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -811,6 +811,199 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) 
{
   case CK_LValueBitCast:
     return this->emitInvalidCast(CastKind::ReinterpretLike, /*Fatal=*/true, 
CE);
 
+  case CK_HLSLArrayRValue: {
+    // Non-decaying array rvalue cast - creates an rvalue copy of an lvalue
+    // array, similar to LValueToRValue for composite types.
+    if (!Initializing) {
+      UnsignedOrNone LocalIndex = allocateLocal(CE);
+      if (!LocalIndex)
+        return false;
+      if (!this->emitGetPtrLocal(*LocalIndex, CE))
+        return false;
+    }
+    if (!this->visit(SubExpr))
+      return false;
+    return this->emitMemcpy(CE);
+  }
+
+  case CK_HLSLMatrixTruncation: {
+    assert(SubExpr->getType()->isConstantMatrixType());
+    if (OptPrimType ResultT = classify(CE)) {
+      assert(!DiscardResult);
+      // Result must be either a float or integer. Take the first element.
+      if (!this->visit(SubExpr))
+        return false;
+      return this->emitArrayElemPop(*ResultT, 0, CE);
+    }
+    // Otherwise, this truncates to a constant matrix type.
+    assert(CE->getType()->isConstantMatrixType());
+
+    if (!Initializing) {
+      UnsignedOrNone LocalIndex = allocateTemporary(CE);
+      if (!LocalIndex)
+        return false;
+      if (!this->emitGetPtrLocal(*LocalIndex, CE))
+        return false;
+    }
+    unsigned ToSize =
+        CE->getType()->getAs<ConstantMatrixType>()->getNumElementsFlattened();
+    if (!this->visit(SubExpr))
+      return false;
+    return this->emitCopyArray(classifyMatrixElementType(SubExpr->getType()), 
0,
+                               0, ToSize, CE);
+  }
+
+  case CK_HLSLAggregateSplatCast: {
+    // Aggregate splat cast: convert a scalar value to one of an aggregate type,
+    // inserting casts when necessary to convert the scalar to the aggregate's
+    // element type(s).
+    // TODO: Aggregate splat to struct and array types
+    assert(canClassify(SubExpr->getType()));
+
+    unsigned NumElems;
+    PrimType DestElemT;
+    QualType DestElemType;
+    if (const auto *VT = CE->getType()->getAs<VectorType>()) {
+      NumElems = VT->getNumElements();
+      DestElemType = VT->getElementType();
+    } else if (const auto *MT = CE->getType()->getAs<ConstantMatrixType>()) {
+      NumElems = MT->getNumElementsFlattened();
+      DestElemType = MT->getElementType();
+    } else {
+      return false;
+    }
+    DestElemT = classifyPrim(DestElemType);
+
+    if (!Initializing) {
+      UnsignedOrNone LocalIndex = allocateLocal(CE);
+      if (!LocalIndex)
+        return false;
+      if (!this->emitGetPtrLocal(*LocalIndex, CE))
+        return false;
+    }
+
+    PrimType SrcElemT = classifyPrim(SubExpr->getType());
+    unsigned SrcOffset =
+        allocateLocalPrimitive(SubExpr, DestElemT, /*IsConst=*/true);
+
+    if (!this->visit(SubExpr))
+      return false;
+    if (SrcElemT != DestElemT) {
+      if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE))
+        return false;
+    }
+    if (!this->emitSetLocal(DestElemT, SrcOffset, CE))
+      return false;
+
+    for (unsigned I = 0; I != NumElems; ++I) {
+      if (!this->emitGetLocal(DestElemT, SrcOffset, CE))
+        return false;
+      if (!this->emitInitElem(DestElemT, I, CE))
+        return false;
+    }
+    return true;
+  }
+
+  case CK_HLSLElementwiseCast: {
+    // Elementwise cast: flatten source elements of one aggregate type and store
+    // to a destination scalar or aggregate type of the same or fewer number of
+    // elements, while inserting casts as necessary.
+    // TODO: Elementwise cast to structs, nested arrays, and arrays of composite
+    // types
+    QualType SrcType = SubExpr->getType();
+    QualType DestType = CE->getType();
+
+    // Allowed SrcTypes
+    const auto *SrcVT = SrcType->getAs<VectorType>();
+    const auto *SrcMT = SrcType->getAs<ConstantMatrixType>();
+    const auto *SrcAT = SrcType->getAsArrayTypeUnsafe();
+    const auto *SrcCAT = SrcAT ? dyn_cast<ConstantArrayType>(SrcAT) : nullptr;
+
+    // Allowed DestTypes
+    const auto *DestVT = DestType->getAs<VectorType>();
+    const auto *DestMT = DestType->getAs<ConstantMatrixType>();
+    const auto *DestAT = DestType->getAsArrayTypeUnsafe();
+    const auto *DestCAT =
+        DestAT ? dyn_cast<ConstantArrayType>(DestAT) : nullptr;
+    const OptPrimType DestPT = classify(DestType);
+
+    if (!SrcVT && !SrcMT && !SrcCAT)
+      return false;
+    if (!DestVT && !DestMT && !DestCAT && !DestPT)
+      return false;
+
+    unsigned SrcNumElems;
+    PrimType SrcElemT;
+    if (SrcVT) {
+      SrcNumElems = SrcVT->getNumElements();
+      SrcElemT = classifyPrim(SrcVT->getElementType());
+    } else if (SrcMT) {
+      SrcNumElems = SrcMT->getNumElementsFlattened();
+      SrcElemT = classifyPrim(SrcMT->getElementType());
+    } else if (SrcCAT) {
+      SrcNumElems = SrcCAT->getZExtSize();
+      SrcElemT = classifyPrim(SrcCAT->getElementType());
+    }
+
+    if (DestPT) {
+      // Scalar destination: extract element 0 and cast.
+      if (!this->visit(SubExpr))
+        return false;
+      if (!this->emitArrayElemPop(SrcElemT, 0, CE))
+        return false;
+      if (SrcElemT != *DestPT) {
+        if (!this->emitPrimCast(SrcElemT, *DestPT, DestType, CE))
+          return false;
+      }
+      return true;
+    }
+
+    unsigned DestNumElems;
+    PrimType DestElemT;
+    QualType DestElemType;
+    if (DestVT) {
+      DestNumElems = DestVT->getNumElements();
+      DestElemType = DestVT->getElementType();
+    } else if (DestMT) {
+      DestNumElems = DestMT->getNumElementsFlattened();
+      DestElemType = DestMT->getElementType();
+    } else if (DestCAT) {
+      DestNumElems = DestCAT->getZExtSize();
+      DestElemType = DestCAT->getElementType();
+    }
+    DestElemT = classifyPrim(DestElemType);
+
+    if (!Initializing) {
+      UnsignedOrNone LocalIndex = allocateTemporary(CE);
+      if (!LocalIndex)
+        return false;
+      if (!this->emitGetPtrLocal(*LocalIndex, CE))
+        return false;
+    }
+
+    unsigned SrcOffset =
+        allocateLocalPrimitive(SubExpr, PT_Ptr, /*IsConst=*/true);
+    if (!this->visit(SubExpr))
+      return false;
+    if (!this->emitSetLocal(PT_Ptr, SrcOffset, CE))
+      return false;
+
+    unsigned NumElems = std::min(SrcNumElems, DestNumElems);
+    for (unsigned I = 0; I != NumElems; ++I) {
+      if (!this->emitGetLocal(PT_Ptr, SrcOffset, CE))
+        return false;
+      if (!this->emitArrayElemPop(SrcElemT, I, CE))
+        return false;
+      if (SrcElemT != DestElemT) {
+        if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE))
+          return false;
+      }
+      if (!this->emitInitElem(DestElemT, I, CE))
+        return false;
+    }
+    return true;
+  }
+
   default:
     return this->emitInvalid(CE);
   }
@@ -1813,6 +2006,20 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr(
     return true;
   }
 
+  if (const auto *MT = E->getType()->getAs<ConstantMatrixType>()) {
+    unsigned NumElems = MT->getNumElementsFlattened();
+    QualType ElemQT = MT->getElementType();
+    PrimType ElemT = classifyPrim(ElemQT);
+
+    for (unsigned I = 0; I != NumElems; ++I) {
+      if (!this->visitZeroInitializer(ElemT, ElemQT, E))
+        return false;
+      if (!this->emitInitElem(ElemT, I, E))
+        return false;
+    }
+    return true;
+  }
+
   return false;
 }
 
@@ -2129,6 +2336,25 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const 
Expr *> Inits,
     return true;
   }
 
+  if (const auto *MT = QT->getAs<ConstantMatrixType>()) {
+    unsigned NumElems = MT->getNumElementsFlattened();
+    assert(Inits.size() == NumElems);
+
+    QualType ElemQT = MT->getElementType();
+    PrimType ElemT = classifyPrim(ElemQT);
+
+    // Matrix initializer list elements are in row-major order, which matches
+    // the matrix APValue convention and therefore no index remapping is
+    // required.
+    for (unsigned I = 0; I != NumElems; ++I) {
+      if (!this->visit(Inits[I]))
+        return false;
+      if (!this->emitInitElem(ElemT, I, E))
+        return false;
+    }
+    return true;
+  }
+
   return false;
 }
 
diff --git a/clang/lib/AST/ByteCode/Compiler.h 
b/clang/lib/AST/ByteCode/Compiler.h
index 8f1f9dad4469e..717928dc1fbbd 100644
--- a/clang/lib/AST/ByteCode/Compiler.h
+++ b/clang/lib/AST/ByteCode/Compiler.h
@@ -406,6 +406,11 @@ class Compiler : public 
ConstStmtVisitor<Compiler<Emitter>, bool>,
     return *this->classify(T->getAs<VectorType>()->getElementType());
   }
 
+  PrimType classifyMatrixElementType(QualType T) const {
+    assert(T->isMatrixType());
+    return *this->classify(T->getAs<MatrixType>()->getElementType());
+  }
+
   bool emitComplexReal(const Expr *SubExpr);
   bool emitComplexBoolCast(const Expr *E);
   bool emitComplexComparison(const Expr *LHS, const Expr *RHS,
diff --git a/clang/lib/AST/ByteCode/Pointer.cpp 
b/clang/lib/AST/ByteCode/Pointer.cpp
index e237013f4199c..f4352e7edf5f8 100644
--- a/clang/lib/AST/ByteCode/Pointer.cpp
+++ b/clang/lib/AST/ByteCode/Pointer.cpp
@@ -934,6 +934,24 @@ std::optional<APValue> Pointer::toRValue(const Context 
&Ctx,
       return true;
     }
 
+    // Constant Matrix types.
+    if (const auto *MT = Ty->getAs<ConstantMatrixType>()) {
+      assert(Ptr.getFieldDesc()->isPrimitiveArray());
+      QualType ElemTy = MT->getElementType();
+      PrimType ElemT = *Ctx.classify(ElemTy);
+      unsigned NumElems = MT->getNumElementsFlattened();
+
+      SmallVector<APValue> Values;
+      Values.reserve(NumElems);
+      for (unsigned I = 0; I != NumElems; ++I) {
+        TYPE_SWITCH(ElemT,
+                    { Values.push_back(Ptr.elem<T>(I).toAPValue(ASTCtx)); });
+      }
+
+      R = APValue(Values.data(), MT->getNumRows(), MT->getNumColumns());
+      return true;
+    }
+
     llvm_unreachable("invalid value to return");
   };
 
diff --git a/clang/lib/AST/ByteCode/Program.cpp 
b/clang/lib/AST/ByteCode/Program.cpp
index bf511ff215db9..efef5db177e56 100644
--- a/clang/lib/AST/ByteCode/Program.cpp
+++ b/clang/lib/AST/ByteCode/Program.cpp
@@ -483,5 +483,15 @@ Descriptor *Program::createDescriptor(const DeclTy &D, 
const Type *Ty,
                               IsTemporary, IsMutable);
   }
 
+  // Same with constant matrix types.
+  if (const auto *MT = Ty->getAs<ConstantMatrixType>()) {
+    OptPrimType ElemTy = Ctx.classify(MT->getElementType());
+    if (!ElemTy)
+      return nullptr;
+
+    return allocateDescriptor(D, *ElemTy, MDSize, 
MT->getNumElementsFlattened(),
+                              IsConst, IsTemporary, IsMutable);
+  }
+
   return nullptr;
 }
diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl 
b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
index 64220980d9edc..2c55e1a0ee4b3 100644
--- a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
@@ -1,5 +1,7 @@
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library 
-finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major 
-verify %s
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library 
-finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -verify 
%s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library 
-finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major 
-fexperimental-new-constant-interpreter -verify %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library 
-finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major 
-fexperimental-new-constant-interpreter -verify %s
 
 // expected-no-diagnostics
 
@@ -43,6 +45,8 @@ export void fn() {
     _Static_assert(FA4[1] == 2.5, "Woo!");
     _Static_assert(FA4[2] == 3.5, "Woo!");
     _Static_assert(FA4[3] == 4.5, "Woo!");
+    constexpr float F = (float)FA4;
+    _Static_assert(F == 1.5, "Woo!");
   }
 
   // Array cast to matrix to vector

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to