[llvm-branch-commits] [clang] [clang][bytecode][HLSL][Matrix] Support `ConstantMatrixType` and more HLSL casts in the new constant interpreter for basic matrix constexpr evaluation in HLSL (PR #183424)

Deric C. via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Feb 26 09:59:27 PST 2026


https://github.com/Icohedron updated https://github.com/llvm/llvm-project/pull/183424

>From c33ea3969b9f527d18a37932751c625d9b3f71d1 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 16:13:16 -0800
Subject: [PATCH 1/8] Support ConstantMatrixTypes and HLSL casts in the bytecode
 constexpr evaluator

This commit adds support for ConstantMatrixType and the HLSL casts
CK_HLSLArrayRValue, CK_HLSLMatrixTruncation, CK_HLSLAggregateSplatCast,
and CK_HLSLElementwiseCast to the bytecode constexpr evaluator.

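The implementations of CK_HLSLAggregateSplatCast and
CK_HLSLElementwiseCast are incomplete: neither supports struct types yet,
aggregate splats to array types are not supported, and elementwise casts
only handle one-dimensional constant arrays of scalars.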

Assisted-by: claude-opus-4.6
---
 clang/lib/AST/ByteCode/Compiler.cpp           | 207 ++++++++++++++++++
 clang/lib/AST/ByteCode/Compiler.h             |   5 +
 clang/lib/AST/ByteCode/Pointer.cpp            |  18 ++
 clang/lib/AST/ByteCode/Program.cpp            |  11 +
 .../BuiltinMatrix/MatrixConstantExpr.hlsl     |   2 +
 5 files changed, 243 insertions(+)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 15e65a4d96581..87982e67dcb51 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -811,6 +811,180 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
   case CK_LValueBitCast:
     return this->emitInvalidCast(CastKind::ReinterpretLike, /*Fatal=*/true, CE);
 
+  case CK_HLSLArrayRValue: {
+    // Non-decaying array rvalue cast - creates an rvalue copy of an lvalue
+    // array, similar to LValueToRValue for composite types.
+    if (!Initializing) {
+      UnsignedOrNone LocalIndex = allocateLocal(CE);
+      if (!LocalIndex)
+        return false;
+      if (!this->emitGetPtrLocal(*LocalIndex, CE))
+        return false;
+    }
+    if (!this->visit(SubExpr))
+      return false;
+    return this->emitMemcpy(CE);
+  }
+
+  case CK_HLSLMatrixTruncation: {
+    assert(SubExpr->getType()->isConstantMatrixType());
+    if (OptPrimType ResultT = classify(CE)) {
+      assert(!DiscardResult);
+      // Result must be either a float or integer. Take the first element.
+      if (!this->visit(SubExpr))
+        return false;
+      return this->emitArrayElemPop(*ResultT, 0, CE);
+    }
+    // Otherwise, this truncates to a constant matrix type.
+    assert(CE->getType()->isConstantMatrixType());
+
+    if (!Initializing) {
+      UnsignedOrNone LocalIndex = allocateTemporary(CE);
+      if (!LocalIndex)
+        return false;
+      if (!this->emitGetPtrLocal(*LocalIndex, CE))
+        return false;
+    }
+    unsigned ToSize =
+        CE->getType()->getAs<ConstantMatrixType>()->getNumElementsFlattened();
+    if (!this->visit(SubExpr))
+      return false;
+    return this->emitCopyArray(classifyMatrixElementType(SubExpr->getType()), 0,
+                               0, ToSize, CE);
+  }
+
+  case CK_HLSLAggregateSplatCast: {
+    // Aggregate splat cast: convert a scalar value to one of an aggregate type.
+    // TODO: Aggregate splat to struct and array types
+    assert(canClassify(SubExpr->getType()));
+
+    unsigned NumElts;
+    PrimType DestElemT;
+    QualType DestElemType;
+    if (const auto *VT = CE->getType()->getAs<VectorType>()) {
+      NumElts = VT->getNumElements();
+      DestElemType = VT->getElementType();
+    } else if (const auto *MT =
+                   CE->getType()->getAs<ConstantMatrixType>()) {
+      NumElts = MT->getNumElementsFlattened();
+      DestElemType = MT->getElementType();
+    } else {
+      return false;
+    }
+    DestElemT = classifyPrim(DestElemType);
+
+    if (!Initializing) {
+      UnsignedOrNone LocalIndex = allocateLocal(CE);
+      if (!LocalIndex)
+        return false;
+      if (!this->emitGetPtrLocal(*LocalIndex, CE))
+        return false;
+    }
+
+    PrimType SrcElemT = classifyPrim(SubExpr->getType());
+    unsigned SrcOffset =
+        allocateLocalPrimitive(SubExpr, DestElemT, /*IsConst=*/true);
+
+    if (!this->visit(SubExpr))
+      return false;
+    if (classifyPrim(SubExpr) == PT_Ptr && !this->emitLoadPop(SrcElemT, CE))
+      return false;
+    if (SrcElemT != DestElemT) {
+      if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE))
+        return false;
+    }
+    if (!this->emitSetLocal(DestElemT, SrcOffset, CE))
+      return false;
+
+    for (unsigned I = 0; I != NumElts; ++I) {
+      if (!this->emitGetLocal(DestElemT, SrcOffset, CE))
+        return false;
+      if (!this->emitInitElem(DestElemT, I, CE))
+        return false;
+    }
+    return true;
+  }
+
+  case CK_HLSLElementwiseCast: {
+    // Elementwise cast: flatten source elements of one aggregate type and store
+    // to a destination aggregate type of the same or fewer number of elements.
+    // TODO: Elementwise cast to structs, nested arrays, and arrays of composite
+    // types
+    QualType SrcType = SubExpr->getType();
+    QualType DestType = CE->getType();
+
+    unsigned SrcNumElts;
+    PrimType SrcElemT;
+    if (const auto *VT = SrcType->getAs<VectorType>()) {
+      SrcNumElts = VT->getNumElements();
+      SrcElemT = classifyPrim(VT->getElementType());
+    } else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) {
+      SrcNumElts = MT->getNumElementsFlattened();
+      SrcElemT = classifyPrim(MT->getElementType());
+    } else if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) {
+      if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
+        SrcNumElts = CAT->getZExtSize();
+        SrcElemT = classifyPrim(CAT->getElementType());
+      } else {
+        return false;
+      }
+    } else {
+      return false;
+    }
+
+    unsigned DestNumElts;
+    PrimType DestElemT;
+    QualType DestElemType;
+    if (const auto *VT = DestType->getAs<VectorType>()) {
+      DestNumElts = VT->getNumElements();
+      DestElemType = VT->getElementType();
+    } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) {
+      DestNumElts = MT->getNumElementsFlattened();
+      DestElemType = MT->getElementType();
+    } else if (const auto *AT = DestType->getAsArrayTypeUnsafe()) {
+      if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
+        DestNumElts = CAT->getZExtSize();
+        DestElemType = CAT->getElementType();
+      } else {
+        return false;
+      }
+    } else {
+      return false;
+    }
+    DestElemT = classifyPrim(DestElemType);
+
+    if (!Initializing) {
+      UnsignedOrNone LocalIndex =
+          classify(DestType) ? allocateLocal(CE) : allocateTemporary(CE);
+      if (!LocalIndex)
+        return false;
+      if (!this->emitGetPtrLocal(*LocalIndex, CE))
+        return false;
+    }
+
+    unsigned SrcOffset =
+        allocateLocalPrimitive(SubExpr, PT_Ptr, /*IsConst=*/true);
+    if (!this->visit(SubExpr))
+      return false;
+    if (!this->emitSetLocal(PT_Ptr, SrcOffset, CE))
+      return false;
+
+    unsigned NumElts = std::min(SrcNumElts, DestNumElts);
+    for (unsigned I = 0; I != NumElts; ++I) {
+      if (!this->emitGetLocal(PT_Ptr, SrcOffset, CE))
+        return false;
+      if (!this->emitArrayElemPop(SrcElemT, I, CE))
+        return false;
+      if (SrcElemT != DestElemT) {
+        if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE))
+          return false;
+      }
+      if (!this->emitInitElem(DestElemT, I, CE))
+        return false;
+    }
+    return true;
+  }
+
   default:
     return this->emitInvalid(CE);
   }
@@ -1813,6 +1987,20 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr(
     return true;
   }
 
+  if (const auto *MT = E->getType()->getAs<ConstantMatrixType>()) {
+    unsigned NumElts = MT->getNumElementsFlattened();
+    QualType ElemQT = MT->getElementType();
+    PrimType ElemT = classifyPrim(ElemQT);
+
+    for (unsigned I = 0; I < NumElts; ++I) {
+      if (!this->visitZeroInitializer(ElemT, ElemQT, E))
+        return false;
+      if (!this->emitInitElem(ElemT, I, E))
+        return false;
+    }
+    return true;
+  }
+
   return false;
 }
 
@@ -2129,6 +2317,25 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits,
     return true;
   }
 
+  if (const auto *MT = QT->getAs<ConstantMatrixType>()) {
+    unsigned NumElts = MT->getNumElementsFlattened();
+    assert(Inits.size() == NumElts);
+
+    QualType ElemQT = MT->getElementType();
+    PrimType ElemT = classifyPrim(ElemQT);
+
+    // InitListExpr elements are in column-major order.
+    // Store in row-major order to match APValue convention.
+    for (unsigned I = 0; I < NumElts; ++I) {
+      if (!this->visit(Inits[I]))
+        return false;
+      if (!this->emitInitElem(ElemT,
+                              MT->mapColumnMajorToRowMajorFlattenedIndex(I), E))
+        return false;
+    }
+    return true;
+  }
+
   return false;
 }
 
diff --git a/clang/lib/AST/ByteCode/Compiler.h b/clang/lib/AST/ByteCode/Compiler.h
index 1bd15c3d79563..74ded47e88792 100644
--- a/clang/lib/AST/ByteCode/Compiler.h
+++ b/clang/lib/AST/ByteCode/Compiler.h
@@ -406,6 +406,11 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,
     return *this->classify(T->getAs<VectorType>()->getElementType());
   }
 
+  PrimType classifyMatrixElementType(QualType T) const {
+    assert(T->isMatrixType());
+    return *this->classify(T->getAs<MatrixType>()->getElementType());
+  }
+
   bool emitComplexReal(const Expr *SubExpr);
   bool emitComplexBoolCast(const Expr *E);
   bool emitComplexComparison(const Expr *LHS, const Expr *RHS,
diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp
index e237013f4199c..a569a4221cf2f 100644
--- a/clang/lib/AST/ByteCode/Pointer.cpp
+++ b/clang/lib/AST/ByteCode/Pointer.cpp
@@ -934,6 +934,24 @@ std::optional<APValue> Pointer::toRValue(const Context &Ctx,
       return true;
     }
 
+    // Constant Matrix types.
+    if (const auto *MT = Ty->getAs<ConstantMatrixType>()) {
+      assert(Ptr.getFieldDesc()->isPrimitiveArray());
+      QualType ElemTy = MT->getElementType();
+      PrimType ElemT = *Ctx.classify(ElemTy);
+      unsigned NumElts = MT->getNumElementsFlattened();
+
+      SmallVector<APValue> Values;
+      Values.reserve(NumElts);
+      for (unsigned I = 0; I != NumElts; ++I) {
+        TYPE_SWITCH(ElemT,
+                    { Values.push_back(Ptr.elem<T>(I).toAPValue(ASTCtx)); });
+      }
+
+      R = APValue(Values.data(), MT->getNumRows(), MT->getNumColumns());
+      return true;
+    }
+
     llvm_unreachable("invalid value to return");
   };
 
diff --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp
index 76fec63a8920d..8876ded409415 100644
--- a/clang/lib/AST/ByteCode/Program.cpp
+++ b/clang/lib/AST/ByteCode/Program.cpp
@@ -474,5 +474,16 @@ Descriptor *Program::createDescriptor(const DeclTy &D, const Type *Ty,
                               IsTemporary, IsMutable);
   }
 
+  // Same with constant matrix types.
+  if (const auto *MT = Ty->getAs<ConstantMatrixType>()) {
+    OptPrimType ElemTy = Ctx.classify(MT->getElementType());
+    if (!ElemTy)
+      return nullptr;
+
+    return allocateDescriptor(D, *ElemTy, MDSize,
+                              MT->getNumElementsFlattened(), IsConst,
+                              IsTemporary, IsMutable);
+  }
+
   return nullptr;
 }
diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
index 64220980d9edc..608af75bff4bb 100644
--- a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
@@ -1,5 +1,7 @@
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major -verify %s
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -verify %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major -fexperimental-new-constant-interpreter -verify %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -fexperimental-new-constant-interpreter -verify %s
 
 // expected-no-diagnostics
 

>From cd61ecb70cf65b5df44154c418179db137399645 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 16:59:02 -0800
Subject: [PATCH 2/8] Apply clang-format

---
 clang/lib/AST/ByteCode/Compiler.cpp | 3 +--
 clang/lib/AST/ByteCode/Program.cpp  | 5 ++---
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 87982e67dcb51..6f5ed3b7dd107 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -864,8 +864,7 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
     if (const auto *VT = CE->getType()->getAs<VectorType>()) {
       NumElts = VT->getNumElements();
       DestElemType = VT->getElementType();
-    } else if (const auto *MT =
-                   CE->getType()->getAs<ConstantMatrixType>()) {
+    } else if (const auto *MT = CE->getType()->getAs<ConstantMatrixType>()) {
       NumElts = MT->getNumElementsFlattened();
       DestElemType = MT->getElementType();
     } else {
diff --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp
index 8876ded409415..a2cba3da27675 100644
--- a/clang/lib/AST/ByteCode/Program.cpp
+++ b/clang/lib/AST/ByteCode/Program.cpp
@@ -480,9 +480,8 @@ Descriptor *Program::createDescriptor(const DeclTy &D, const Type *Ty,
     if (!ElemTy)
       return nullptr;
 
-    return allocateDescriptor(D, *ElemTy, MDSize,
-                              MT->getNumElementsFlattened(), IsConst,
-                              IsTemporary, IsMutable);
+    return allocateDescriptor(D, *ElemTy, MDSize, MT->getNumElementsFlattened(),
+                              IsConst, IsTemporary, IsMutable);
   }
 
   return nullptr;

>From 7bc9164b5d3a2977f55136844f7a39c9c952f1cd Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 22:02:56 -0800
Subject: [PATCH 3/8] Rename NumElts to NumElems

---
 clang/lib/AST/ByteCode/Compiler.cpp | 38 ++++++++++++++---------------
 clang/lib/AST/ByteCode/Pointer.cpp  |  6 ++---
 2 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 6f5ed3b7dd107..32ec16570a5b6 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -858,14 +858,14 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
     // TODO: Aggregate splat to struct and array types
     assert(canClassify(SubExpr->getType()));
 
-    unsigned NumElts;
+    unsigned NumElems;
     PrimType DestElemT;
     QualType DestElemType;
     if (const auto *VT = CE->getType()->getAs<VectorType>()) {
-      NumElts = VT->getNumElements();
+      NumElems = VT->getNumElements();
       DestElemType = VT->getElementType();
     } else if (const auto *MT = CE->getType()->getAs<ConstantMatrixType>()) {
-      NumElts = MT->getNumElementsFlattened();
+      NumElems = MT->getNumElementsFlattened();
       DestElemType = MT->getElementType();
     } else {
       return false;
@@ -895,7 +895,7 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
     if (!this->emitSetLocal(DestElemT, SrcOffset, CE))
       return false;
 
-    for (unsigned I = 0; I != NumElts; ++I) {
+    for (unsigned I = 0; I != NumElems; ++I) {
       if (!this->emitGetLocal(DestElemT, SrcOffset, CE))
         return false;
       if (!this->emitInitElem(DestElemT, I, CE))
@@ -912,17 +912,17 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
     QualType SrcType = SubExpr->getType();
     QualType DestType = CE->getType();
 
-    unsigned SrcNumElts;
+    unsigned SrcNumElems;
     PrimType SrcElemT;
     if (const auto *VT = SrcType->getAs<VectorType>()) {
-      SrcNumElts = VT->getNumElements();
+      SrcNumElems = VT->getNumElements();
       SrcElemT = classifyPrim(VT->getElementType());
     } else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) {
-      SrcNumElts = MT->getNumElementsFlattened();
+      SrcNumElems = MT->getNumElementsFlattened();
       SrcElemT = classifyPrim(MT->getElementType());
     } else if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) {
       if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
-        SrcNumElts = CAT->getZExtSize();
+        SrcNumElems = CAT->getZExtSize();
         SrcElemT = classifyPrim(CAT->getElementType());
       } else {
         return false;
@@ -931,18 +931,18 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
       return false;
     }
 
-    unsigned DestNumElts;
+    unsigned DestNumElems;
     PrimType DestElemT;
     QualType DestElemType;
     if (const auto *VT = DestType->getAs<VectorType>()) {
-      DestNumElts = VT->getNumElements();
+      DestNumElems = VT->getNumElements();
       DestElemType = VT->getElementType();
     } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) {
-      DestNumElts = MT->getNumElementsFlattened();
+      DestNumElems = MT->getNumElementsFlattened();
       DestElemType = MT->getElementType();
     } else if (const auto *AT = DestType->getAsArrayTypeUnsafe()) {
       if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
-        DestNumElts = CAT->getZExtSize();
+        DestNumElems = CAT->getZExtSize();
         DestElemType = CAT->getElementType();
       } else {
         return false;
@@ -968,8 +968,8 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
     if (!this->emitSetLocal(PT_Ptr, SrcOffset, CE))
       return false;
 
-    unsigned NumElts = std::min(SrcNumElts, DestNumElts);
-    for (unsigned I = 0; I != NumElts; ++I) {
+    unsigned NumElems = std::min(SrcNumElems, DestNumElems);
+    for (unsigned I = 0; I != NumElems; ++I) {
       if (!this->emitGetLocal(PT_Ptr, SrcOffset, CE))
         return false;
       if (!this->emitArrayElemPop(SrcElemT, I, CE))
@@ -1987,11 +1987,11 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr(
   }
 
   if (const auto *MT = E->getType()->getAs<ConstantMatrixType>()) {
-    unsigned NumElts = MT->getNumElementsFlattened();
+    unsigned NumElems = MT->getNumElementsFlattened();
     QualType ElemQT = MT->getElementType();
     PrimType ElemT = classifyPrim(ElemQT);
 
-    for (unsigned I = 0; I < NumElts; ++I) {
+    for (unsigned I = 0; I < NumElems; ++I) {
       if (!this->visitZeroInitializer(ElemT, ElemQT, E))
         return false;
       if (!this->emitInitElem(ElemT, I, E))
@@ -2317,15 +2317,15 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits,
   }
 
   if (const auto *MT = QT->getAs<ConstantMatrixType>()) {
-    unsigned NumElts = MT->getNumElementsFlattened();
-    assert(Inits.size() == NumElts);
+    unsigned NumElems = MT->getNumElementsFlattened();
+    assert(Inits.size() == NumElems);
 
     QualType ElemQT = MT->getElementType();
     PrimType ElemT = classifyPrim(ElemQT);
 
     // InitListExpr elements are in column-major order.
     // Store in row-major order to match APValue convention.
-    for (unsigned I = 0; I < NumElts; ++I) {
+    for (unsigned I = 0; I < NumElems; ++I) {
       if (!this->visit(Inits[I]))
         return false;
       if (!this->emitInitElem(ElemT,
diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp
index a569a4221cf2f..f4352e7edf5f8 100644
--- a/clang/lib/AST/ByteCode/Pointer.cpp
+++ b/clang/lib/AST/ByteCode/Pointer.cpp
@@ -939,11 +939,11 @@ std::optional<APValue> Pointer::toRValue(const Context &Ctx,
       assert(Ptr.getFieldDesc()->isPrimitiveArray());
       QualType ElemTy = MT->getElementType();
       PrimType ElemT = *Ctx.classify(ElemTy);
-      unsigned NumElts = MT->getNumElementsFlattened();
+      unsigned NumElems = MT->getNumElementsFlattened();
 
       SmallVector<APValue> Values;
-      Values.reserve(NumElts);
-      for (unsigned I = 0; I != NumElts; ++I) {
+      Values.reserve(NumElems);
+      for (unsigned I = 0; I != NumElems; ++I) {
         TYPE_SWITCH(ElemT,
                     { Values.push_back(Ptr.elem<T>(I).toAPValue(ASTCtx)); });
       }

>From 4bd726d2c2dd898b8403103cb4e0ab962f8b9de4 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 22:10:49 -0800
Subject: [PATCH 4/8] Remove unnecessary SubExpr ptr check and load

This code was copied over from the CK_VectorSplat case but is not needed
for CK_HLSLAggregateSplatCast. The `classifyPrim(SubExpr) == PT_Ptr &&
!this->emitLoadPop(SrcElemT, CE)` check is unnecessary because an
lvalue-to-rvalue conversion is always inserted on the operand of an
aggregate splat cast, so SubExpr will never evaluate to a pointer.

See https://github.com/llvm/llvm-project/blob/143664fcd3df825befdb9586151d53aefef3d7d0/clang/lib/Sema/SemaCast.cpp#L2939-L2940
for confirmation that this is the case.
---
 clang/lib/AST/ByteCode/Compiler.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 32ec16570a5b6..9cdaa3cc31288 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -886,8 +886,6 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
 
     if (!this->visit(SubExpr))
       return false;
-    if (classifyPrim(SubExpr) == PT_Ptr && !this->emitLoadPop(SrcElemT, CE))
-      return false;
     if (SrcElemT != DestElemT) {
       if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE))
         return false;

>From 952f2cb56d629a6f07ddce29a0b92ac5bcd3dc04 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 23:16:09 -0800
Subject: [PATCH 5/8] Support scalar DestType for HLSLElementwiseCast

Also reword the comments for the HLSLElementwiseCast and
HLSLAggregateSplatCast cases to make them clearer.

Assisted-by: claude-opus-4.6
---
 clang/lib/AST/ByteCode/Compiler.cpp           | 22 +++++++++++++++----
 .../BuiltinMatrix/MatrixConstantExpr.hlsl     |  2 ++
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 9cdaa3cc31288..08c39538107df 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -854,7 +854,9 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
   }
 
   case CK_HLSLAggregateSplatCast: {
-    // Aggregate splat cast: convert a scalar value to one of an aggregate type.
+    // Aggregate splat cast: convert a scalar value to one of an aggregate type,
+    // inserting casts when necessary to convert the scalar to the aggregate's
+    // element type(s).
     // TODO: Aggregate splat to struct and array types
     assert(canClassify(SubExpr->getType()));
 
@@ -904,7 +906,8 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
 
   case CK_HLSLElementwiseCast: {
     // Elementwise cast: flatten source elements of one aggregate type and store
-    // to a destination aggregate type of the same or fewer number of elements.
+    // to a destination scalar or aggregate type of the same or fewer number of
+    // elements, while inserting casts as necessary.
     // TODO: Elementwise cast to structs, nested arrays, and arrays of composite
     // types
     QualType SrcType = SubExpr->getType();
@@ -945,14 +948,25 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
       } else {
         return false;
       }
+    } else if (classify(DestType)) {
+      // Scalar destination: extract element 0 and cast.
+      PrimType DestT = classifyPrim(DestType);
+      if (!this->visit(SubExpr))
+        return false;
+      if (!this->emitArrayElemPop(SrcElemT, 0, CE))
+        return false;
+      if (SrcElemT != DestT) {
+        if (!this->emitPrimCast(SrcElemT, DestT, DestType, CE))
+          return false;
+      }
+      return true;
     } else {
       return false;
     }
     DestElemT = classifyPrim(DestElemType);
 
     if (!Initializing) {
-      UnsignedOrNone LocalIndex =
-          classify(DestType) ? allocateLocal(CE) : allocateTemporary(CE);
+      UnsignedOrNone LocalIndex = allocateTemporary(CE);
       if (!LocalIndex)
         return false;
       if (!this->emitGetPtrLocal(*LocalIndex, CE))
diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
index 608af75bff4bb..2c55e1a0ee4b3 100644
--- a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
@@ -45,6 +45,8 @@ export void fn() {
     _Static_assert(FA4[1] == 2.5, "Woo!");
     _Static_assert(FA4[2] == 3.5, "Woo!");
     _Static_assert(FA4[3] == 4.5, "Woo!");
+    constexpr float F = (float)FA4;
+    _Static_assert(F == 1.5, "Woo!");
   }
 
   // Array cast to matrix to vector

>From 974eec6611da919c1489f9d498e9be5913723660 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 23:27:32 -0800
Subject: [PATCH 6/8] Change < to != in loop guards

---
 clang/lib/AST/ByteCode/Compiler.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 08c39538107df..914da43a3d714 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -2003,7 +2003,7 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr(
     QualType ElemQT = MT->getElementType();
     PrimType ElemT = classifyPrim(ElemQT);
 
-    for (unsigned I = 0; I < NumElems; ++I) {
+    for (unsigned I = 0; I != NumElems; ++I) {
       if (!this->visitZeroInitializer(ElemT, ElemQT, E))
         return false;
       if (!this->emitInitElem(ElemT, I, E))
@@ -2337,7 +2337,7 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits,
 
     // InitListExpr elements are in column-major order.
     // Store in row-major order to match APValue convention.
-    for (unsigned I = 0; I < NumElems; ++I) {
+    for (unsigned I = 0; I != NumElems; ++I) {
       if (!this->visit(Inits[I]))
         return false;
       if (!this->emitInitElem(ElemT,

>From eb43998a7e2e081ced9c23fcf45307d26c74e052 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Thu, 26 Feb 2026 09:04:28 -0800
Subject: [PATCH 7/8] Replace PrimType DestT with OptPrimType DestT

---
 clang/lib/AST/ByteCode/Compiler.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 914da43a3d714..2f7026b961bbd 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -948,15 +948,14 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
       } else {
         return false;
       }
-    } else if (classify(DestType)) {
+    } else if (OptPrimType DestT = classify(DestType)) {
       // Scalar destination: extract element 0 and cast.
-      PrimType DestT = classifyPrim(DestType);
       if (!this->visit(SubExpr))
         return false;
       if (!this->emitArrayElemPop(SrcElemT, 0, CE))
         return false;
-      if (SrcElemT != DestT) {
-        if (!this->emitPrimCast(SrcElemT, DestT, DestType, CE))
+      if (SrcElemT != *DestT) {
+        if (!this->emitPrimCast(SrcElemT, *DestT, DestType, CE))
           return false;
       }
       return true;

>From 6cb36787d65fe34087213c7a79c9d9203df9060c Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Thu, 26 Feb 2026 09:59:06 -0800
Subject: [PATCH 8/8] Early return false from HLSLElementwiseCast if Src or
 Dest types are not allowed

---
 clang/lib/AST/ByteCode/Compiler.cpp | 85 ++++++++++++++++-------------
 1 file changed, 47 insertions(+), 38 deletions(-)

diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 2f7026b961bbd..990539446ce22 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -913,54 +913,63 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
     QualType SrcType = SubExpr->getType();
     QualType DestType = CE->getType();
 
-    unsigned SrcNumElems;
-    PrimType SrcElemT;
-    if (const auto *VT = SrcType->getAs<VectorType>()) {
-      SrcNumElems = VT->getNumElements();
-      SrcElemT = classifyPrim(VT->getElementType());
-    } else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) {
-      SrcNumElems = MT->getNumElementsFlattened();
-      SrcElemT = classifyPrim(MT->getElementType());
-    } else if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) {
-      if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
-        SrcNumElems = CAT->getZExtSize();
-        SrcElemT = classifyPrim(CAT->getElementType());
-      } else {
-        return false;
-      }
-    } else {
+    // Allowed SrcTypes
+    const auto *SrcVT = SrcType->getAs<VectorType>();
+    const auto *SrcMT = SrcType->getAs<ConstantMatrixType>();
+    const auto *SrcAT = SrcType->getAsArrayTypeUnsafe();
+    const auto *SrcCAT = SrcAT ? dyn_cast<ConstantArrayType>(SrcAT) : nullptr;
+
+    // Allowed DestTypes
+    const auto *DestVT = DestType->getAs<VectorType>();
+    const auto *DestMT = DestType->getAs<ConstantMatrixType>();
+    const auto *DestAT = DestType->getAsArrayTypeUnsafe();
+    const auto *DestCAT =
+        DestAT ? dyn_cast<ConstantArrayType>(DestAT) : nullptr;
+    const OptPrimType DestPT = classify(DestType);
+
+    if (!SrcVT && !SrcMT && !SrcCAT)
+      return false;
+    if (!DestVT && !DestMT && !DestCAT && !DestPT)
       return false;
-    }
 
-    unsigned DestNumElems;
-    PrimType DestElemT;
-    QualType DestElemType;
-    if (const auto *VT = DestType->getAs<VectorType>()) {
-      DestNumElems = VT->getNumElements();
-      DestElemType = VT->getElementType();
-    } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) {
-      DestNumElems = MT->getNumElementsFlattened();
-      DestElemType = MT->getElementType();
-    } else if (const auto *AT = DestType->getAsArrayTypeUnsafe()) {
-      if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
-        DestNumElems = CAT->getZExtSize();
-        DestElemType = CAT->getElementType();
-      } else {
-        return false;
-      }
-    } else if (OptPrimType DestT = classify(DestType)) {
+    unsigned SrcNumElems;
+    PrimType SrcElemT;
+    if (SrcVT) {
+      SrcNumElems = SrcVT->getNumElements();
+      SrcElemT = classifyPrim(SrcVT->getElementType());
+    } else if (SrcMT) {
+      SrcNumElems = SrcMT->getNumElementsFlattened();
+      SrcElemT = classifyPrim(SrcMT->getElementType());
+    } else if (SrcCAT) {
+      SrcNumElems = SrcCAT->getZExtSize();
+      SrcElemT = classifyPrim(SrcCAT->getElementType());
+    }
+
+    if (DestPT) {
       // Scalar destination: extract element 0 and cast.
       if (!this->visit(SubExpr))
         return false;
       if (!this->emitArrayElemPop(SrcElemT, 0, CE))
         return false;
-      if (SrcElemT != *DestT) {
-        if (!this->emitPrimCast(SrcElemT, *DestT, DestType, CE))
+      if (SrcElemT != *DestPT) {
+        if (!this->emitPrimCast(SrcElemT, *DestPT, DestType, CE))
           return false;
       }
       return true;
-    } else {
-      return false;
+    }
+
+    unsigned DestNumElems;
+    PrimType DestElemT;
+    QualType DestElemType;
+    if (DestVT) {
+      DestNumElems = DestVT->getNumElements();
+      DestElemType = DestVT->getElementType();
+    } else if (DestMT) {
+      DestNumElems = DestMT->getNumElementsFlattened();
+      DestElemType = DestMT->getElementType();
+    } else if (DestCAT) {
+      DestNumElems = DestCAT->getZExtSize();
+      DestElemType = DestCAT->getElementType();
     }
     DestElemT = classifyPrim(DestElemType);
 



More information about the llvm-branch-commits mailing list