[llvm] [DirectX] Implement `memcpy` in DXIL CBuffer Access pass (PR #144436)

Justin Bogner via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 30 09:44:16 PDT 2025


================
@@ -54,114 +58,248 @@ struct CBufferRowIntrin {
     }
   }
 };
-} // namespace
 
-static size_t getOffsetForCBufferGEP(GEPOperator *GEP, GlobalVariable *Global,
-                                     const DataLayout &DL) {
-  // Since we should always have a constant offset, we should only ever have a
-  // single GEP of indirection from the Global.
-  assert(GEP->getPointerOperand() == Global &&
-         "Indirect access to resource handle");
+// Helper for creating CBuffer handles and loading data from them
+struct CBufferResource {
+  GlobalVariable *GVHandle;
+  GlobalVariable *Member;
+  size_t MemberOffset;
 
-  APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
-  bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset);
-  (void)Success;
-  assert(Success && "Offsets into cbuffer globals must be constant");
+  LoadInst *Handle;
 
-  if (auto *ATy = dyn_cast<ArrayType>(Global->getValueType()))
-    ConstantOffset = hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy);
+  CBufferResource(GlobalVariable *GVHandle, GlobalVariable *Member,
+                  size_t MemberOffset)
+      : GVHandle(GVHandle), Member(Member), MemberOffset(MemberOffset) {}
 
-  return ConstantOffset.getZExtValue();
-}
+  const DataLayout &getDataLayout() { return GVHandle->getDataLayout(); }
+  Type *getValueType() { return Member->getValueType(); }
+  iterator_range<ConstantDataSequential::user_iterator> users() {
+    return Member->users();
+  }
 
-/// Replace access via cbuffer global with a load from the cbuffer handle
-/// itself.
-static void replaceAccess(LoadInst *LI, GlobalVariable *Global,
-                          GlobalVariable *HandleGV, size_t BaseOffset,
-                          SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
-  const DataLayout &DL = HandleGV->getDataLayout();
+  /// Get the byte offset of a Pointer-typed Value * `Val` relative to Member.
+  /// `Val` can either be Member itself, or a GEP of a constant offset from
+  /// Member
+  size_t getOffsetForCBufferGEP(Value *Val) {
+    assert(isa<PointerType>(Val->getType()) &&
+           "Expected a pointer-typed value");
+
+    if (Val == Member)
+      return 0;
+
+    if (auto *GEP = dyn_cast<GEPOperator>(Val)) {
+      // Since we should always have a constant offset, we should only ever have
+      // a single GEP of indirection from the Global.
+      assert(GEP->getPointerOperand() == Member &&
+             "Indirect access to resource handle");
+
+      const DataLayout &DL = getDataLayout();
+      APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
+      bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset);
+      (void)Success;
+      assert(Success && "Offsets into cbuffer globals must be constant");
+
+      if (auto *ATy = dyn_cast<ArrayType>(Member->getValueType()))
+        ConstantOffset =
+            hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy);
+
+      return ConstantOffset.getZExtValue();
+    }
 
-  size_t Offset = BaseOffset;
-  if (auto *GEP = dyn_cast<GEPOperator>(LI->getPointerOperand()))
-    Offset += getOffsetForCBufferGEP(GEP, Global, DL);
-  else if (LI->getPointerOperand() != Global)
-    llvm_unreachable("Load instruction doesn't reference cbuffer global");
+    llvm_unreachable("Expected Val to be a GlobalVariable or GEP");
+  }
 
-  IRBuilder<> Builder(LI);
-  auto *Handle = Builder.CreateLoad(HandleGV->getValueType(), HandleGV,
-                                    HandleGV->getName());
-
-  Type *Ty = LI->getType();
-  CBufferRowIntrin Intrin(DL, Ty->getScalarType());
-  // The cbuffer consists of some number of 16-byte rows.
-  unsigned int CurrentRow = Offset / hlsl::CBufferRowSizeInBytes;
-  unsigned int CurrentIndex =
-      (Offset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize;
-
-  auto *CBufLoad = Builder.CreateIntrinsic(
-      Intrin.RetTy, Intrin.IID,
-      {Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,
-      LI->getName());
-  auto *Elt =
-      Builder.CreateExtractValue(CBufLoad, {CurrentIndex++}, LI->getName());
-
-  Value *Result = nullptr;
-  unsigned int Remaining =
-      ((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
-  if (Remaining == 0) {
-    // We only have a single element, so we're done.
-    Result = Elt;
-
-    // However, if we loaded a <1 x T>, then we need to adjust the type here.
-    if (auto *VT = dyn_cast<FixedVectorType>(LI->getType())) {
-      assert(VT->getNumElements() == 1 && "Can't have multiple elements here");
-      Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
-                                           Builder.getInt32(0));
-    }
-  } else {
-    // Walk each element and extract it, wrapping to new rows as needed.
-    SmallVector<Value *> Extracts{Elt};
-    while (Remaining--) {
-      CurrentIndex %= Intrin.NumElts;
-
-      if (CurrentIndex == 0)
-        CBufLoad = Builder.CreateIntrinsic(
-            Intrin.RetTy, Intrin.IID,
-            {Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
-            nullptr, LI->getName());
-
-      Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
-                                                    LI->getName()));
+  /// Create a handle for this cbuffer resource using the IRBuilder `Builder`
+  /// and set it as the current handle to use for subsequent calls to
+  /// `loadValue`.
+  void createAndSetCurrentHandle(IRBuilder<> &Builder) {
+    Handle = Builder.CreateLoad(GVHandle->getValueType(), GVHandle,
+                                GVHandle->getName());
+  }
+
+  /// Load a value of type `Ty` at offset `Offset` using the handle from the
+  /// last call to `createAndSetCurrentHandle`
+  Value *loadValue(IRBuilder<> &Builder, Type *Ty, size_t Offset,
+                   const Twine &Name = "") {
+    assert(Handle &&
+           "Expected a handle for this cbuffer global resource to be created "
+           "before loading a value from it");
+    const DataLayout &DL = getDataLayout();
+
+    size_t TargetOffset = MemberOffset + Offset;
+    CBufferRowIntrin Intrin(DL, Ty->getScalarType());
+    // The cbuffer consists of some number of 16-byte rows.
+    unsigned int CurrentRow = TargetOffset / hlsl::CBufferRowSizeInBytes;
+    unsigned int CurrentIndex =
+        (TargetOffset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize;
+
+    auto *CBufLoad = Builder.CreateIntrinsic(
+        Intrin.RetTy, Intrin.IID,
+        {Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,
+        Name + ".load");
+    auto *Elt = Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
+                                           Name + ".extract");
+
+    Value *Result = nullptr;
+    unsigned int Remaining =
+        ((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
+    if (Remaining == 0) {
+      // We only have a single element, so we're done.
+      Result = Elt;
+
+      // However, if we loaded a <1 x T>, then we need to adjust the type here.
+      if (auto *VT = dyn_cast<FixedVectorType>(Ty)) {
+        assert(VT->getNumElements() == 1 &&
+               "Can't have multiple elements here");
+        Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
+                                             Builder.getInt32(0), Name);
+      }
+    } else {
+      // Walk each element and extract it, wrapping to new rows as needed.
+      SmallVector<Value *> Extracts{Elt};
+      while (Remaining--) {
+        CurrentIndex %= Intrin.NumElts;
+
+        if (CurrentIndex == 0)
+          CBufLoad = Builder.CreateIntrinsic(
+              Intrin.RetTy, Intrin.IID,
+              {Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
+              nullptr, Name + ".load");
+
+        Extracts.push_back(Builder.CreateExtractValue(
+            CBufLoad, {CurrentIndex++}, Name + ".extract"));
+      }
+
+      // Finally, we build up the original loaded value.
+      Result = PoisonValue::get(Ty);
+      for (int I = 0, E = Extracts.size(); I < E; ++I)
+        Result = Builder.CreateInsertElement(Result, Extracts[I],
+                                             Builder.getInt32(I),
+                                             Name + formatv(".upto{}", I));
     }
 
-    // Finally, we build up the original loaded value.
-    Result = PoisonValue::get(Ty);
-    for (int I = 0, E = Extracts.size(); I < E; ++I)
-      Result =
-          Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I));
+    return Result;
   }
+};
+
+} // namespace
 
+/// Replace load via cbuffer global with a load from the cbuffer handle itself.
+static void replaceLoad(LoadInst *LI, CBufferResource &CBR,
+                        SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
+  size_t Offset = CBR.getOffsetForCBufferGEP(LI->getPointerOperand());
+  IRBuilder<> Builder(LI);
+  CBR.createAndSetCurrentHandle(Builder);
+  Value *Result = CBR.loadValue(Builder, LI->getType(), Offset, LI->getName());
   LI->replaceAllUsesWith(Result);
   DeadInsts.push_back(LI);
 }
 
-static void replaceAccessesWithHandle(GlobalVariable *Global,
-                                      GlobalVariable *HandleGV,
-                                      size_t BaseOffset) {
+/// Replace memcpy from a cbuffer global with a memcpy from the cbuffer handle
+/// itself. Assumes the cbuffer global is an array, and that the number of bytes
+/// to copy is divisible by the array element allocation size.
+/// The memcpy source must also be a direct cbuffer global reference, not a GEP.
+static void replaceMemCpy(MemCpyInst *MCI, CBufferResource &CBR) {
+
+  ArrayType *ArrTy = dyn_cast<ArrayType>(CBR.getValueType());
+  assert(ArrTy && "MemCpy lowering is only supported for array types");
+
+  // This assumption vastly simplifies the implementation
+  if (MCI->getSource() != CBR.Member)
+    reportFatalUsageError(
+        "Expected MemCpy source to be a cbuffer global variable");
+
+  const std::string Name = ("memcpy." + MCI->getDest()->getName() + "." +
+                            MCI->getSource()->getName())
+                               .str();
+
+  ConstantInt *Length = dyn_cast<ConstantInt>(MCI->getLength());
+  uint64_t ByteLength = Length->getZExtValue();
+
+  // If length to copy is zero, no memcpy is needed
+  if (ByteLength == 0) {
+    MCI->eraseFromParent();
+    return;
+  }
+
+  const DataLayout &DL = CBR.getDataLayout();
+
+  Type *ElemTy = ArrTy->getElementType();
+  size_t ElemSize = DL.getTypeAllocSize(ElemTy);
+  assert(ByteLength % ElemSize == 0 &&
+         "Length of bytes to MemCpy must be divisible by allocation size of "
+         "source/destination array elements");
+  size_t ElemsToCpy = ByteLength / ElemSize;
+
+  IRBuilder<> Builder(MCI);
+  CBR.createAndSetCurrentHandle(Builder);
+
+  // This function recursively copies N array elements from the CBuffer Resource
+  // to the MemCpy Destination. Recursion is used to unravel multidimensional
+  // arrays into a sequence of scalar/vector extracts and stores.
+  auto CopyElemsImpl = [&Builder, &MCI, &Name, &CBR,
+                        &DL](const auto &Self, ArrayType *ArrTy,
+                             size_t ArrOffset, size_t N) -> void {
----------------
bogner wrote:

It would probably help readability to make this a `static void` function at file scope rather than jumping through hoops to write a recursive lambda.

https://github.com/llvm/llvm-project/pull/144436


More information about the llvm-commits mailing list