[llvm-branch-commits] [clang] [clang][bytecode][HLSL][Matrix] Support `ConstantMatrixType` and more HLSL casts in the new constant interpreter for basic matrix constexpr evaluation in HLSL (PR #183424)

Deric C. via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Feb 26 09:44:06 PST 2026


================
@@ -811,6 +811,190 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
   case CK_LValueBitCast:
     return this->emitInvalidCast(CastKind::ReinterpretLike, /*Fatal=*/true, CE);
 
+  case CK_HLSLArrayRValue: {
+    // Non-decaying array rvalue cast - creates an rvalue copy of an lvalue
+    // array, similar to LValueToRValue for composite types.
+    if (!Initializing) {
+      UnsignedOrNone LocalIndex = allocateLocal(CE);
+      if (!LocalIndex)
+        return false;
+      if (!this->emitGetPtrLocal(*LocalIndex, CE))
+        return false;
+    }
+    if (!this->visit(SubExpr))
+      return false;
+    return this->emitMemcpy(CE);
+  }
+
+  case CK_HLSLMatrixTruncation: {
+    assert(SubExpr->getType()->isConstantMatrixType());
+    if (OptPrimType ResultT = classify(CE)) {
+      assert(!DiscardResult);
+      // Result must be either a float or integer. Take the first element.
+      if (!this->visit(SubExpr))
+        return false;
+      return this->emitArrayElemPop(*ResultT, 0, CE);
+    }
+    // Otherwise, this truncates to a constant matrix type.
+    assert(CE->getType()->isConstantMatrixType());
+
+    if (!Initializing) {
+      UnsignedOrNone LocalIndex = allocateTemporary(CE);
+      if (!LocalIndex)
+        return false;
+      if (!this->emitGetPtrLocal(*LocalIndex, CE))
+        return false;
+    }
+    unsigned ToSize =
+        CE->getType()->getAs<ConstantMatrixType>()->getNumElementsFlattened();
+    if (!this->visit(SubExpr))
+      return false;
+    return this->emitCopyArray(classifyMatrixElementType(SubExpr->getType()), 0,
+                               0, ToSize, CE);
+  }
+
+  case CK_HLSLAggregateSplatCast: {
+    // Aggregate splat cast: convert a scalar value to one of an aggregate type,
+    // inserting casts when necessary to convert the scalar to the aggregate's
+    // element type(s).
+    // TODO: Aggregate splat to struct and array types
+    assert(canClassify(SubExpr->getType()));
+
+    unsigned NumElems;
+    PrimType DestElemT;
+    QualType DestElemType;
+    if (const auto *VT = CE->getType()->getAs<VectorType>()) {
+      NumElems = VT->getNumElements();
+      DestElemType = VT->getElementType();
+    } else if (const auto *MT = CE->getType()->getAs<ConstantMatrixType>()) {
+      NumElems = MT->getNumElementsFlattened();
+      DestElemType = MT->getElementType();
+    } else {
+      return false;
+    }
+    DestElemT = classifyPrim(DestElemType);
+
+    if (!Initializing) {
+      UnsignedOrNone LocalIndex = allocateLocal(CE);
+      if (!LocalIndex)
+        return false;
+      if (!this->emitGetPtrLocal(*LocalIndex, CE))
+        return false;
+    }
+
+    PrimType SrcElemT = classifyPrim(SubExpr->getType());
+    unsigned SrcOffset =
+        allocateLocalPrimitive(SubExpr, DestElemT, /*IsConst=*/true);
+
+    if (!this->visit(SubExpr))
+      return false;
+    if (SrcElemT != DestElemT) {
+      if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE))
+        return false;
+    }
+    if (!this->emitSetLocal(DestElemT, SrcOffset, CE))
+      return false;
+
+    for (unsigned I = 0; I != NumElems; ++I) {
+      if (!this->emitGetLocal(DestElemT, SrcOffset, CE))
+        return false;
+      if (!this->emitInitElem(DestElemT, I, CE))
+        return false;
+    }
+    return true;
+  }
+
+  case CK_HLSLElementwiseCast: {
+    // Elementwise cast: flatten source elements of one aggregate type and store
+    // to a destination scalar or aggregate type of the same or fewer number of
+    // elements, while inserting casts as necessary.
+    // TODO: Elementwise cast to structs, nested arrays, and arrays of composite
+    // types
+    QualType SrcType = SubExpr->getType();
+    QualType DestType = CE->getType();
+
+    unsigned SrcNumElems;
+    PrimType SrcElemT;
+    if (const auto *VT = SrcType->getAs<VectorType>()) {
+      SrcNumElems = VT->getNumElements();
+      SrcElemT = classifyPrim(VT->getElementType());
+    } else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) {
+      SrcNumElems = MT->getNumElementsFlattened();
+      SrcElemT = classifyPrim(MT->getElementType());
+    } else if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) {
+      if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
+        SrcNumElems = CAT->getZExtSize();
+        SrcElemT = classifyPrim(CAT->getElementType());
+      } else {
+        return false;
+      }
+    } else {
+      return false;
----------------
Icohedron wrote:

I could change the code to:
```c++
    const auto *SrcVT = SrcType->getAs<VectorType>();
    const auto *SrcMT = SrcType->getAs<ConstantMatrixType>();
    const auto *SrcAT = SrcType->getAsArrayTypeUnsafe();

    if (!SrcVT && !SrcMT && !SrcAT)
      return false;

    if (SrcVT) {
      SrcNumElems = SrcVT->getNumElements();
      SrcElemT = classifyPrim(SrcVT->getElementType());
    } else if (SrcMT) {
      SrcNumElems = SrcMT->getNumElementsFlattened();
      SrcElemT = classifyPrim(SrcMT->getElementType());
    } else if (SrcAT) {
      if (const auto *CAT = dyn_cast<ConstantArrayType>(SrcAT)) {
        SrcNumElems = CAT->getZExtSize();
        SrcElemT = classifyPrim(CAT->getElementType());
      } else {
        return false;
      }
    }
```
but I don't think this significantly improves code readability. The [LLVM coding standards](https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code) suggest using early exits to reduce nesting, but the level of nesting does not change in this case.

https://github.com/llvm/llvm-project/pull/183424


More information about the llvm-branch-commits mailing list