[llvm] Dx lower to rawbufferload dxil ops draft rebase (PR #116845)

Zhengxing li via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 19 09:33:53 PST 2024


https://github.com/lizhengxing created https://github.com/llvm/llvm-project/pull/116845

This is the draft PR for lowering llvm.dx.rawbufferload to DXIL ops.

There are 3 commits. 
The DXILResourceMap change: [[DirectX] Add Resource uses to Resource Handle map in DXILResourceMap · lizhengxing/llvm-project at 392daa1](https://github.com/lizhengxing/llvm-project/commit/392daa10503f89a2f081ec2df2fea5ea5db980f6)
The Raw Buffer Load change: [Lower llvm.dx.rawbufferload to dxil ops · lizhengxing/llvm-project at 731703f](https://github.com/lizhengxing/llvm-project/commit/731703f8130d47d5d2e3b89ff413ef7d4d38523e)
 
The DXILResourceMap changes needed by the Raw Buffer Load commit are split out into the third commit: [Changes in DXILResourceMap for lowering llvm.dx.rawbufferload to dxil… · lizhengxing/llvm-project at 5abec9d](https://github.com/lizhengxing/llvm-project/commit/5abec9db0adc2203355478e524dcc651336c3f40)

>From 8e1bca0aa7269bcb210d2e5e1a95e7d8375ba162 Mon Sep 17 00:00:00 2001
From: Zhengxing Li <zhengxingli at microsoft.com>
Date: Wed, 16 Oct 2024 09:10:59 -0700
Subject: [PATCH 1/3] [DirectX] Add Resource uses to Resource Handle map in
 DXILResourceMap

When lowering some resource-use intrinsics to DXIL operations, the lowering needs information about the resource that each intrinsic uses.

This PR adds a map from resource uses to resource handles in DXILResourceMap, which lets each resource use find its resource information.

This PR is also useful to #106188
---
 llvm/include/llvm/Analysis/DXILResource.h     | 27 +++++++++
 llvm/lib/Analysis/DXILResource.cpp            | 55 +++++++++++++++++++
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    |  9 +++
 .../Analysis/DXILResource/resource-map.ll     | 36 ++++++++++++
 .../DirectX/DXILResource/dxil-resource-map.ll | 48 ++++++++++++++++
 5 files changed, 175 insertions(+)
 create mode 100644 llvm/test/Analysis/DXILResource/resource-map.ll
 create mode 100644 llvm/test/CodeGen/DirectX/DXILResource/dxil-resource-map.ll

diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 6b577c02f05450..016e45e78c3984 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -264,6 +264,8 @@ class ResourceInfo {
 class DXILResourceMap {
   SmallVector<dxil::ResourceInfo> Resources;
   DenseMap<CallInst *, unsigned> CallMap;
+  // Mapping from Resource use to Resource Handle
+  DenseMap<CallInst *, CallInst *> ResUseToHandleMap;
   unsigned FirstUAV = 0;
   unsigned FirstCBuffer = 0;
   unsigned FirstSampler = 0;
@@ -335,6 +337,31 @@ class DXILResourceMap {
   }
 
   void print(raw_ostream &OS) const;
+
+  void updateResourceMap(CallInst *origCallInst, CallInst *newCallInst);
+
+  // Record that origResUse was replaced one-for-one by newResUse: newResUse
+  // inherits origResUse's resource handle and origResUse's entry is dropped.
+  void updateResUseMap(CallInst *origResUse, CallInst *newResUse) {
+    assert(origResUse && newResUse && origResUse != newResUse &&
+           "Wrong Inputs");
+    updateResUseMapCommon(origResUse, newResUse,
+                          /*keepOrigResUseInMap=*/false);
+  }
+
+  // Return the resource handle call recorded for resUse. Asserts if resUse
+  // is unknown; when asserts are compiled out, returns nullptr instead of
+  // dereferencing the map's end() iterator.
+  CallInst *findResHandleByUse(CallInst *resUse) {
+    auto Pos = ResUseToHandleMap.find(resUse);
+    assert((Pos != ResUseToHandleMap.end()) &&
+           "Can't find the resource handle");
+    if (Pos == ResUseToHandleMap.end())
+      return nullptr;
+    return Pos->second;
+  }
+
+private:
+  // Map newResUse to the handle currently recorded for origResUse, then drop
+  // origResUse's entry unless the caller asks to keep it (the multi-use
+  // update needs it for the remaining lookups).
+  void updateResUseMapCommon(CallInst *origResUse, CallInst *newResUse,
+                             bool keepOrigResUseInMap) {
+    CallInst *Handle = findResHandleByUse(origResUse);
+    ResUseToHandleMap.try_emplace(newResUse, Handle);
+    if (keepOrigResUseInMap)
+      return;
+    ResUseToHandleMap.erase(origResUse);
+  }
 };
 
 class DXILResourceAnalysis : public AnalysisInfoMixin<DXILResourceAnalysis> {
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 2802480481690d..601d2648ae0288 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -719,6 +719,12 @@ DXILResourceMap::DXILResourceMap(
     if (Resources.empty() || RI != Resources.back())
       Resources.push_back(RI);
     CallMap[CI] = Resources.size() - 1;
+
+    // Build ResUseToHandleMap
+    for (auto it = CI->users().begin(); it != CI->users().end(); ++it) {
+      CallInst *CI_Use = dyn_cast<CallInst>(*it);
+      ResUseToHandleMap[CI_Use] = CI;
+    }
   }
 
   unsigned Size = Resources.size();
@@ -744,6 +750,47 @@ DXILResourceMap::DXILResourceMap(
   }
 }
 
+// Re-key the resource maps after a resource handle call has been lowered.
+//
+// Parameters:
+//   origCallInst - the original resource handle call.
+//   newCallInst  - the call that replaces it after lowering.
+//
+// When origCallInst is lowered to newCallInst, origCallInst and its uses are
+// replaced, so the [origCallInst, resource index] entry in CallMap and the
+// [use of origCallInst, origCallInst] entries in ResUseToHandleMap must be
+// updated. This function:
+//   1. Moves origCallInst's resource index in CallMap over to newCallInst.
+//   2. Remaps every call that uses origCallInst to newCallInst in
+//      ResUseToHandleMap. Users that are not calls are skipped rather than
+//      being inserted as null keys.
+//
+// Stale entries must not be left behind: origCallInst is erased after
+// lowering, and DXILResourceMap::print dereferences the map entries.
+//
+// NOTE: Call this before origCallInst->replaceAllUsesWith() and
+// origCallInst->eraseFromParent(), because it walks origCallInst's users.
+void DXILResourceMap::updateResourceMap(CallInst *origCallInst,
+                                        CallInst *newCallInst) {
+  assert(origCallInst && newCallInst && origCallInst != newCallInst &&
+         "Wrong Inputs");
+
+  // Use find() rather than operator[]: operator[] would silently insert
+  // index 0 (binding the handle to the first resource) for an unknown call.
+  auto Pos = CallMap.find(origCallInst);
+  assert(Pos != CallMap.end() && "Resource handle is not in CallMap");
+  if (Pos == CallMap.end())
+    return;
+  unsigned Index = Pos->second;
+  CallMap.erase(Pos);
+  CallMap.try_emplace(newCallInst, Index);
+
+  // Update ResUseToHandleMap since the resource handle changed.
+  for (User *U : origCallInst->users())
+    if (auto *ResUse = dyn_cast<CallInst>(U))
+      ResUseToHandleMap[ResUse] = newCallInst;
+}
+
 void DXILResourceMap::print(raw_ostream &OS) const {
   for (unsigned I = 0, E = Resources.size(); I != E; ++I) {
     OS << "Binding " << I << ":\n";
@@ -756,6 +803,14 @@ void DXILResourceMap::print(raw_ostream &OS) const {
     CI->print(OS);
     OS << "\n";
   }
+
+  for (const auto &[ResUse, ResHandle] : ResUseToHandleMap) {
+    OS << "\n";
+    OS << "Resource " << CallMap.find(ResHandle)->second;
+    OS << " is used by ";
+    ResUse->print(OS);
+    OS << "\n";
+  }
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 9f124394363a38..2dae9a8209ec57 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -249,6 +249,8 @@ class OpLowerer {
 
       removeResourceGlobals(CI);
 
+      DRM.updateResourceMap(CI, *OpCall);
+
       CI->replaceAllUsesWith(Cast);
       CI->eraseFromParent();
       return Error::success();
@@ -295,6 +297,8 @@ class OpLowerer {
 
       removeResourceGlobals(CI);
 
+      DRM.updateResourceMap(CI, *OpBind);
+
       CI->replaceAllUsesWith(Cast);
       CI->eraseFromParent();
 
@@ -479,6 +483,9 @@ class OpLowerer {
           OpCode::BufferLoad, Args, CI->getName(), NewRetTy);
       if (Error E = OpCall.takeError())
         return E;
+
+      DRM.updateResUseMap(CI, *OpCall);
+
       if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
         return E;
 
@@ -547,6 +554,8 @@ class OpLowerer {
       if (Error E = OpCall.takeError())
         return E;
 
+      DRM.updateResUseMap(CI, *OpCall);
+
       CI->eraseFromParent();
       return Error::success();
     });
diff --git a/llvm/test/Analysis/DXILResource/resource-map.ll b/llvm/test/Analysis/DXILResource/resource-map.ll
new file mode 100644
index 00000000000000..65255d4c942e53
--- /dev/null
+++ b/llvm/test/Analysis/DXILResource/resource-map.ll
@@ -0,0 +1,36 @@
+; RUN: opt -disable-output -passes="print<dxil-resource>" < %s 2>&1 | FileCheck %s
+
+; Verify that print<dxil-resource> reports, for a single RWBuffer<float4>
+; binding, the resource info, the handle call bound to it, and both the load
+; and the store that use the handle.
+define void @test_typedbuffer() {
+  ; RWBuffer<float4> Buf : register(u5, space3)
+  %uav1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_1_0(
+                  i32 3, i32 5, i32 1, i32 0, i1 false)
+  ; CHECK: Binding [[UAV1:[0-9]+]]:
+  ; CHECK:   Symbol: ptr undef
+  ; CHECK:   Name: ""
+  ; CHECK:   Binding:
+  ; CHECK:     Record ID: 0
+  ; CHECK:     Space: 3
+  ; CHECK:     Lower Bound: 5
+  ; CHECK:     Size: 1
+  ; CHECK:   Class: UAV
+  ; CHECK:   Kind: TypedBuffer
+  ; CHECK:   Globally Coherent: 0
+  ; CHECK:   HasCounter: 0
+  ; CHECK:   IsROV: 0
+  ; CHECK:   Element Type: f32
+  ; CHECK:   Element Count: 4
+
+  ; The per-use lines are printed by iterating a DenseMap, so their order is
+  ; not fixed and they are matched unordered.
+  ; CHECK:     Call bound to [[UAV1]]:  %uav1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0t(i32 3, i32 5, i32 1, i32 0, i1 false)
+  ; CHECK-DAG: Resource [[UAV1]] is used by   %data0 = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 0)
+  ; CHECK-DAG: Resource [[UAV1]] is used by   call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 2, <4 x float> %data0)
+
+  %data0 = call <4 x float> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 0)
+  call void @llvm.dx.typedBufferStore(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1,
+      i32 2, <4 x float> %data0)
+
+  ret void
+}
+
diff --git a/llvm/test/CodeGen/DirectX/DXILResource/dxil-resource-map.ll b/llvm/test/CodeGen/DirectX/DXILResource/dxil-resource-map.ll
new file mode 100644
index 00000000000000..ac5f3d16145974
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/DXILResource/dxil-resource-map.ll
@@ -0,0 +1,48 @@
+; RUN: opt -disable-output -passes="print<dxil-resource>,dxil-op-lower,print<dxil-resource>" -mtriple=dxil-pc-shadermodel6.6-compute < %s 2>&1 | FileCheck %s -check-prefixes=CHECK,CHECK_SM66
+; RUN: opt -disable-output -passes="print<dxil-resource>,dxil-op-lower,print<dxil-resource>" -mtriple=dxil-pc-shadermodel6.2-compute < %s 2>&1 | FileCheck %s -check-prefixes=CHECK,CHECK_SM62
+
+; Verify that DXILResourceMap is kept up to date across dxil-op-lower. Before
+; lowering, the binding is reported against the intrinsic calls. After
+; lowering, it is reported against the new createHandleFromBinding (SM 6.6)
+; or createHandle (SM 6.2) call, and the uses are the lowered bufferLoad and
+; bufferStore ops.
+define void @test_typedbuffer() {
+  ; RWBuffer<float4> Buf : register(u5, space3)
+  %uav1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_1_0(
+                  i32 3, i32 5, i32 1, i32 0, i1 false)
+  ; CHECK: Binding [[UAV1:[0-9]+]]:
+  ; CHECK:   Symbol: ptr undef
+  ; CHECK:   Name: ""
+  ; CHECK:   Binding:
+  ; CHECK:     Record ID: 0
+  ; CHECK:     Space: 3
+  ; CHECK:     Lower Bound: 5
+  ; CHECK:     Size: 1
+  ; CHECK:   Class: UAV
+  ; CHECK:   Kind: TypedBuffer
+  ; CHECK:   Globally Coherent: 0
+  ; CHECK:   HasCounter: 0
+  ; CHECK:   IsROV: 0
+  ; CHECK:   Element Type: f32
+  ; CHECK:   Element Count: 4
+
+  ; CHECK:     Call bound to [[UAV1]]:  %uav1 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0t(i32 3, i32 5, i32 1, i32 0, i1 false)
+  ; CHECK-DAG: Resource [[UAV1]] is used by   %data0 = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 0)
+  ; CHECK-DAG: Resource [[UAV1]] is used by   call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 2, <4 x float> %data0)
+
+  %data0 = call <4 x float> @llvm.dx.typedBufferLoad(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1, i32 0)
+  call void @llvm.dx.typedBufferStore(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %uav1,
+      i32 2, <4 x float> %data0)
+
+  ;
+  ;;; After dxil-op-lower, the DXILResourceMap info should be updated.
+  ;
+  ; CHECK_SM66:     Call bound to [[UAV1]]:  %uav11 = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217, %dx.types.ResBind { i32 5, i32 5, i32 3, i8 1 }, i32 0, i1 false)
+  ; CHECK_SM66-DAG: Resource [[UAV1]] is used by   %data02 = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle %uav1_annot, i32 0, i32 undef)
+  ; CHECK_SM66-DAG: Resource [[UAV1]] is used by   call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %uav1_annot, i32 2, i32 undef, float %9, float %10, float %11, float %12, i8 15)
+  ;
+  ; CHECK_SM62:     Call bound to [[UAV1]]:  %uav11 = call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 0, i1 false)
+  ; CHECK_SM62-DAG: Resource [[UAV1]] is used by   %data02 = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle %uav11, i32 0, i32 undef)
+  ; CHECK_SM62-DAG: Resource [[UAV1]] is used by   call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle %uav11, i32 2, i32 undef, float %9, float %10, float %11, float %12, i8 15)
+
+  ret void
+}
+

>From 27c361bd7cf21506e46b074c59a57bbd6c60ceb1 Mon Sep 17 00:00:00 2001
From: Zhengxing Li <zhengxingli at microsoft.com>
Date: Fri, 4 Oct 2024 17:14:27 -0700
Subject: [PATCH 2/3] Lower llvm.dx.rawbufferload to dxil ops

This PR lowers the @llvm.dx.rawBufferLoad intrinsic to @dx.op.rawBufferLoad DXIL operations.
---
 llvm/include/llvm/IR/IntrinsicsDirectX.td  |   3 +
 llvm/lib/Target/DirectX/DXIL.td            |  11 ++
 llvm/lib/Target/DirectX/DXILOpLowering.cpp | 190 ++++++++++++++++++++
 llvm/test/CodeGen/DirectX/RawBufferLoad.ll | 192 +++++++++++++++++++++
 4 files changed, 396 insertions(+)
 create mode 100644 llvm/test/CodeGen/DirectX/RawBufferLoad.ll

diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 48a9595f844f05..d8be8d002a0917 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -27,6 +27,9 @@ def int_dx_handle_fromBinding
           [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
           [IntrNoMem]>;
 
+// Load from a raw (ByteAddress) or structured buffer. Operands: resource
+// handle, element index, byte offset. Marked IntrReadMem, like
+// int_dx_typedBufferLoad, so the load is not modeled as writing memory.
+def int_dx_rawBufferLoad
+    : DefaultAttrsIntrinsic<[llvm_any_ty],
+                            [llvm_any_ty, llvm_i32_ty, llvm_i32_ty],
+                            [IntrReadMem]>;
+
 def int_dx_typedBufferLoad
     : DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty],
                             [IntrReadMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 1a8e110491cc87..a638371d6031ad 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -854,6 +854,17 @@ def AnnotateHandle : DXILOp<216, annotateHandle> {
   let stages = [Stages<DXIL1_6, [all_stages]>];
 }
 
+def RawBufferLoad : DXILOp<139, rawBufferLoad> {
+  let Doc = "reads from a ByteAddressBuffer or StructuredBuffer";
+  // Handle, Coord0, Coord1, mask, alignment
+  let arguments = [HandleTy, Int32Ty, Int32Ty, Int8Ty, Int32Ty];
+  let result = OverloadTy;
+  // rawBufferLoad was introduced in DXIL 1.2 (shader model 6.2), not 1.0.
+  // TODO(review): confirm whether the double/i64 overloads (DXIL 1.3+)
+  // should also be listed.
+  let overloads =
+      [Overloads<DXIL1_2,
+                 [ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
+  let stages = [Stages<DXIL1_2, [all_stages]>];
+}
+
 def CreateHandleFromBinding : DXILOp<217, createHandleFromBinding> {
   let Doc = "create resource handle from binding";
   let arguments = [ResBindTy, Int32Ty, Int1Ty];
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 2dae9a8209ec57..6fa856274e591b 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -628,6 +628,193 @@ class OpLowerer {
     });
   }
 
+  // Emit one dx.op.rawBufferLoad that reads NumComponents (1-4) elements of
+  // Ty's scalar type.
+  //
+  //   handle    - the DXIL resource handle to load from.
+  //   bufIdx    - the structured buffer element index, or nullptr for a byte
+  //               address buffer load; in that case offset becomes the only
+  //               coordinate and the second coordinate is undef.
+  //   offset    - the byte offset.
+  //   alignment - the i32 alignment operand of the op.
+  //
+  // Returns the new call, or nullptr if the op could not be created; callers
+  // must check for nullptr.
+  Value *GenerateRawBufLd(Value *handle, Value *bufIdx, Value *offset, Type *Ty,
+                          IRBuilder<> &Builder, unsigned NumComponents,
+                          Constant *alignment) {
+    if (bufIdx == nullptr) {
+      // This is actually a byte address buffer load with a struct template
+      // type. The call takes only one coordinate for the offset.
+      bufIdx = offset;
+      offset = UndefValue::get(offset->getType());
+    }
+
+    // One mask bit per loaded component:
+    //   1 -> 0b0001 (X), 2 -> 0b0011 (XY), 3 -> 0b0111 (XYZ), 4 -> 0b1111
+    assert(NumComponents > 0 && NumComponents <= 4 &&
+           "rawBufferLoad reads between 1 and 4 components");
+    Constant *mask = Builder.getInt8((1u << NumComponents) - 1);
+
+    Value *Args[] = {handle, bufIdx, offset, mask, alignment};
+    Type *NewRetTy = OpBuilder.getResRetType(Ty->getScalarType());
+    Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
+        OpCode::RawBufferLoad, Args, "", NewRetTy); // TODO: Need name argument?
+    if (!OpCall) {
+      // This interface returns a plain Value*, so the error cannot be
+      // propagated from here; signal failure with nullptr instead. The Error
+      // must still be consumed: destroying an unhandled llvm::Error aborts in
+      // builds with ABI-breaking checks enabled.
+      consumeError(OpCall.takeError());
+      return nullptr;
+    }
+
+    return *OpCall;
+  }
+
+  // Split a load of ElemCount elements of Ty's scalar type into
+  // dx.op.rawBufferLoad calls of at most 4 components each, appending the
+  // calls (nullptr for a failed one) to bufLds in order. Each chunk advances
+  // the byte offset by 4 * element size. The alignment operand is
+  // min(baseAlign, element size). A null baseOffset means offset 0.
+  //
+  // isScalarTy is currently unused.
+  void TranslateRawBufVecLd(Type *Ty, unsigned ElemCount, IRBuilder<> &Builder,
+                            Value *handle, Value *bufIdx, Value *baseOffset,
+                            const DataLayout &DL, std::vector<Value *> &bufLds,
+                            unsigned baseAlign, bool isScalarTy) {
+    constexpr unsigned MaxComponents = 4;
+
+    Type *VecEltTy = Ty->getScalarType();
+    unsigned EltSize = DL.getTypeAllocSize(VecEltTy).getFixedValue();
+    Constant *alignmentVal = Builder.getInt32(std::min(baseAlign, EltSize));
+
+    if (baseOffset == nullptr)
+      baseOffset = Builder.getInt32(0);
+
+    for (unsigned Loaded = 0; Loaded < ElemCount;) {
+      unsigned Count = std::min(MaxComponents, ElemCount - Loaded);
+      bufLds.push_back(GenerateRawBufLd(handle, bufIdx, baseOffset, Ty,
+                                        Builder, Count, alignmentVal));
+      Loaded += Count;
+      // Only advance the offset when another load follows; after the last
+      // chunk the add would be dead.
+      if (Loaded < ElemCount)
+        baseOffset = Builder.CreateAdd(
+            baseOffset, Builder.getInt32(MaxComponents * EltSize));
+    }
+  }
+
+  // Replace all uses of the rawBufferLoad intrinsic call Intrin with values
+  // extracted from the dx.op.rawBufferLoad calls in bufLds, then erase
+  // Intrin. Element I of the original result lives in component I % 4 of
+  // bufLds[I / 4]. Returns an error (leaving Intrin in place) if any of the
+  // loads failed to be created.
+  Error replaceMultiResRetsUses(CallInst *Intrin,
+                                std::vector<Value *> &bufLds) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+
+    // TODO: Handle the check bit (status result), as replaceResRetUses does
+    // for typed buffer loads.
+
+    if (is_contained(bufLds, nullptr))
+      return make_error<StringError>("Failed to create dx.op.rawBufferLoad",
+                                     inconvertibleErrorCode());
+
+    auto getLoad = [&](unsigned LoadIdx) {
+      assert(LoadIdx < bufLds.size() && "Not enough raw buffer loads");
+      return cast<CallInst>(bufLds[LoadIdx]);
+    };
+
+    Type *OldTy = Intrin->getType();
+
+    // For scalars, we just extract the first element.
+    if (!isa<FixedVectorType>(OldTy)) {
+      Value *EVI = IRB.CreateExtractValue(getLoad(0), 0);
+      Intrin->replaceAllUsesWith(EVI);
+      Intrin->eraseFromParent();
+      return Error::success();
+    }
+
+    const auto *VecTy = cast<FixedVectorType>(OldTy);
+    const unsigned N = VecTy->getNumElements();
+
+    // Lazily created extractvalue for each vector element.
+    std::vector<Value *> Extracts(N);
+    auto getElement = [&](unsigned I) {
+      if (!Extracts[I])
+        Extracts[I] = IRB.CreateExtractValue(getLoad(I / 4), I % 4);
+      return Extracts[I];
+    };
+
+    // The users of the operation should all be scalarized, so we attempt to
+    // replace the extractelements with extractvalues directly.
+    for (Use &U : make_early_inc_range(Intrin->uses())) {
+      auto *EEI = dyn_cast<ExtractElementInst>(U.getUser());
+      if (!EEI)
+        continue;
+      // TODO: Dynamic indices keep their extractelement, which is then fed
+      // by the rebuilt vector below.
+      auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand());
+      if (!IndexOp)
+        continue;
+      size_t IndexVal = IndexOp->getZExtValue();
+      assert(IndexVal < N && "Index into buffer load out of range");
+      EEI->replaceAllUsesWith(getElement(IndexVal));
+      EEI->eraseFromParent();
+    }
+
+    // If we still have uses, then we're not fully scalarized and need to
+    // recreate the vector. This happens for things like exported functions
+    // from libraries, and for dynamically indexed extracts.
+    if (!Intrin->use_empty()) {
+      for (unsigned I = 0; I != N; ++I)
+        getElement(I);
+
+      Value *Vec = UndefValue::get(OldTy);
+      for (unsigned I = 0; I != N; ++I)
+        Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
+
+      Intrin->replaceAllUsesWith(Vec);
+    }
+
+    // TODO: dx.op.rawBufferLoad calls whose results end up unused are left
+    // in place.
+    Intrin->eraseFromParent();
+    return Error::success();
+  }
+
+  // Lower each call to llvm.dx.rawBufferLoad in F into dx.op.rawBufferLoad
+  // calls of at most 4 components each, then rewrite the intrinsic's users.
+  // Returns the result of replaceFunction, which lowerIntrinsics accumulates
+  // into its error flag.
+  [[nodiscard]] bool lowerRawBufferLoad(Function &F) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+
+    return replaceFunction(F, [&](CallInst *CI) -> Error {
+      IRB.SetInsertPoint(CI);
+
+      // TODO: Look up the resource through DRM and take its kind (expected
+      // to be StructuredBuffer or RawBuffer) from the ResourceInfo. Until
+      // then, every load is lowered as a StructuredBuffer load.
+      ResourceKind RCKind = dxil::ResourceKind::StructuredBuffer;
+
+      Type *Ty = CI->getType();
+      // TODO: Check whether bool loads need special handling.
+
+      unsigned numComponents = 1;
+      if (Ty->isVectorTy())
+        numComponents = cast<FixedVectorType>(Ty)->getNumElements();
+
+      Value *Handle =
+          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
+      Value *bufIdx = CI->getArgOperand(1);
+      Value *baseOffset = CI->getArgOperand(2);
+      bool isScalarTy = !Ty->isVectorTy();
+
+      // Base alignment in bytes: 8 for structured buffers, otherwise 4.
+      unsigned BaseAlign =
+          RCKind == dxil::ResourceKind::StructuredBuffer ? 8 : 4;
+
+      std::vector<Value *> bufLds;
+      TranslateRawBufVecLd(Ty, numComponents, IRB, Handle, bufIdx, baseOffset,
+                           F.getDataLayout(), bufLds, BaseAlign, isScalarTy);
+
+      return replaceMultiResRetsUses(CI, bufLds);
+    });
+  }
+
   bool lowerIntrinsics() {
     bool Updated = false;
     bool HasErrors = false;
@@ -647,6 +834,9 @@ class OpLowerer {
       case Intrinsic::dx_handle_fromBinding:
         HasErrors |= lowerHandleFromBinding(F);
         break;
+      case Intrinsic::dx_rawBufferLoad:
+        HasErrors |= lowerRawBufferLoad(F);
+        break;
       case Intrinsic::dx_typedBufferLoad:
         HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/false);
         break;
diff --git a/llvm/test/CodeGen/DirectX/RawBufferLoad.ll b/llvm/test/CodeGen/DirectX/RawBufferLoad.ll
new file mode 100644
index 00000000000000..500fec544e36d5
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/RawBufferLoad.ll
@@ -0,0 +1,192 @@
+; RUN: opt -S -dxil-op-lower %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+declare void @scalar_user(float)
+declare void @vector_user(<4 x float>)
+declare void @vector_user_v3f32x4(<12 x float>)
+
+;; StructuredBuffer load
+
+; Structured buffer of <4 x float>: each intrinsic call becomes a single
+; 4-component dx.op.rawBufferLoad. Covers users that only extract elements,
+; a user that needs the whole vector, and a mix of both.
+define void @loadv4f32() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.RawBuffer", <4 x float>, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.RawBuffer_v4f32_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; The temporary casts should all have been cleaned up
+  ; CHECK-NOT: %dx.cast_handle
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle [[HANDLE]], i32 0, i32 0, i8 15, i32 4)
+  %data0 = call <4 x float> @llvm.dx.rawBufferLoad(
+      target("dx.RawBuffer", <4 x float>, 0, 0) %buffer, i32 0, i32 0)
+
+  ; The extract order depends on the users, so don't enforce that here.
+  ; CHECK-DAG: [[VAL0_0:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
+  %data0_0 = extractelement <4 x float> %data0, i32 0
+  ; CHECK-DAG: [[VAL0_2:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
+  %data0_2 = extractelement <4 x float> %data0, i32 2
+
+  ; If all of the uses are extracts, we skip creating a vector
+  ; CHECK-NOT: insertelement
+  ; CHECK-DAG: call void @scalar_user(float [[VAL0_0]])
+  ; CHECK-DAG: call void @scalar_user(float [[VAL0_2]])
+  call void @scalar_user(float %data0_0)
+  call void @scalar_user(float %data0_2)
+
+  ; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle [[HANDLE]], i32 4, i32 0, i8 15, i32 4)
+  %data4 = call <4 x float> @llvm.dx.rawBufferLoad(
+      target("dx.RawBuffer", <4 x float>, 0, 0) %buffer, i32 4, i32 0)
+
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3
+  ; CHECK: insertelement <4 x float> undef
+  ; CHECK: insertelement <4 x float>
+  ; CHECK: insertelement <4 x float>
+  ; CHECK: insertelement <4 x float>
+  call void @vector_user(<4 x float> %data4)
+
+  ; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle [[HANDLE]], i32 12, i32 0, i8 15, i32 4)
+  %data12 = call <4 x float> @llvm.dx.rawBufferLoad(
+      target("dx.RawBuffer", <4 x float>, 0, 0) %buffer, i32 12, i32 0)
+
+  ; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3
+  %data12_3 = extractelement <4 x float> %data12, i32 3
+
+  ; If there are a mix of users we need the vector, but extracts are direct
+  ; CHECK: call void @scalar_user(float [[DATA12_3]])
+  call void @scalar_user(float %data12_3)
+  call void @vector_user(<4 x float> %data12)
+
+  ret void
+}
+
+; Structured buffer of <12 x float>: each intrinsic call is split into three
+; 4-component loads at byte offsets 0, 16 and 32.
+define void @loadv3f32x4() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.RawBuffer", <12 x float>, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.RawBuffer_v3f32x4_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; The temporary casts should all have been cleaned up
+  ; CHECK-NOT: %dx.cast_handle
+
+  ; CHECK: [[DATA0_3:%.*]] = call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle [[HANDLE]], i32 0, i32 0, i8 15, i32 4)
+  ; CHECK: [[DATA4_7:%.*]] = call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle [[HANDLE]], i32 0, i32 16, i8 15, i32 4)
+  ; CHECK: [[DATA8_11:%.*]] = call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle [[HANDLE]], i32 0, i32 32, i8 15, i32 4)
+  %data0 = call <12 x float> @llvm.dx.rawBufferLoad(
+      target("dx.RawBuffer", <12 x float>, 0, 0) %buffer, i32 0, i32 0)
+
+  ; The extract order depends on the users, so don't enforce that here.
+  ; CHECK-DAG: [[VAL0_2:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA0_3]], 2
+  %data0_2 = extractelement <12 x float> %data0, i32 2
+  ; CHECK-DAG: [[VAL0_7:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA4_7]], 3
+  %data0_7 = extractelement <12 x float> %data0, i32 7
+
+  ; If all of the uses are extracts, we skip creating a vector
+  ; CHECK-NOT: insertelement
+  ; CHECK-DAG: call void @scalar_user(float [[VAL0_2]])
+  ; CHECK-DAG: call void @scalar_user(float [[VAL0_7]])
+  call void @scalar_user(float %data0_2)
+  call void @scalar_user(float %data0_7)
+
+  ;; Vector Use
+  ;
+  ; CHECK: [[DATA3_0_3:%.*]] = call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle [[HANDLE]], i32 3, i32 0, i8 15, i32 4)
+  ; CHECK: [[DATA3_4_7:%.*]] = call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle [[HANDLE]], i32 3, i32 16, i8 15, i32 4)
+  ; CHECK: [[DATA3_8_11:%.*]] = call %dx.types.ResRet.f32 @dx.op.rawBufferLoad.f32(i32 139, %dx.types.Handle [[HANDLE]], i32 3, i32 32, i8 15, i32 4)
+  ; CHECK: [[VAL3_0:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_0_3]], 0
+  ; CHECK: [[VAL3_1:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_0_3]], 1
+  ; CHECK: [[VAL3_2:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_0_3]], 2
+  ; CHECK: [[VAL3_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_0_3]], 3
+  ; CHECK: [[VAL3_4:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_4_7]], 0
+  ; CHECK: [[VAL3_5:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_4_7]], 1
+  ; CHECK: [[VAL3_6:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_4_7]], 2
+  ; CHECK: [[VAL3_7:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_4_7]], 3
+  ; CHECK: [[VAL3_8:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_8_11]], 0
+  ; CHECK: [[VAL3_9:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_8_11]], 1
+  ; CHECK: [[VAL3_10:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_8_11]], 2
+  ; CHECK: [[VAL3_11:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA3_8_11]], 3
+  ; CHECK: insertelement <12 x float> undef
+  ; CHECK: insertelement <12 x float>
+  ; CHECK: insertelement <12 x float>
+  ; CHECK: insertelement <12 x float>
+  ; CHECK: insertelement <12 x float>
+  ; CHECK: insertelement <12 x float>
+  ; CHECK: insertelement <12 x float>
+  ; CHECK: insertelement <12 x float>
+  ; CHECK: insertelement <12 x float>
+  ; CHECK: insertelement <12 x float>
+  ; CHECK: insertelement <12 x float>
+  ; CHECK: [[VAL3_VecRes:%.*]] = insertelement <12 x float>
+  ; CHECK: call void @vector_user_v3f32x4(<12 x float> [[VAL3_VecRes]])
+  %data3 = call <12 x float> @llvm.dx.rawBufferLoad(
+      target("dx.RawBuffer", <12 x float>, 0, 0) %buffer, i32 3, i32 0)
+  call void @vector_user_v3f32x4(<12 x float> %data3)
+
+  ret void
+}
+
+; <8 x i32> is split into two 4-component loads at byte offsets 0 and 16.
+define void @loadv4i32x2() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.RawBuffer", <8 x i32>, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.RawBuffer_v4i32x2_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0_3:%.*]] = call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle [[HANDLE]], i32 0, i32 0, i8 15, i32 4)
+  ; CHECK: [[DATA4_7:%.*]] = call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle [[HANDLE]], i32 0, i32 16, i8 15, i32 4)
+  %data0 = call <8 x i32> @llvm.dx.rawBufferLoad(
+      target("dx.RawBuffer", <8 x i32>, 0, 0) %buffer, i32 0, i32 0)
+
+  ret void
+}
+
+; <4 x half> is a single 4-component load; the alignment operand is the
+; 2-byte element size.
+define void @loadv4f16() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.RawBuffer", <4 x half>, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.RawBuffer_v4f16_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.rawBufferLoad.f16(i32 139, %dx.types.Handle [[HANDLE]], i32 0, i32 0, i8 15, i32 2)
+  %data0 = call <4 x half> @llvm.dx.rawBufferLoad(
+      target("dx.RawBuffer", <4 x half>, 0, 0) %buffer, i32 0, i32 0)
+
+  ret void
+}
+
+; <6 x i16> is split into a 4-component load (mask 15) and a 2-component
+; load (mask 3) at byte offset 8.
+define void @loadv2i16x3() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.RawBuffer", <6 x i16>, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.RawBuffer_v2i16x3_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0_3:%.*]] = call %dx.types.ResRet.i16 @dx.op.rawBufferLoad.i16(i32 139, %dx.types.Handle [[HANDLE]], i32 0, i32 0, i8 15, i32 2)
+  ; CHECK: [[DATA4_5:%.*]] = call %dx.types.ResRet.i16 @dx.op.rawBufferLoad.i16(i32 139, %dx.types.Handle [[HANDLE]], i32 0, i32 8, i8 3, i32 2)
+  %data0 = call <6 x i16> @llvm.dx.rawBufferLoad(
+      target("dx.RawBuffer", <6 x i16>, 0, 0) %buffer, i32 0, i32 0)
+
+  ret void
+}
+
+;; ByteAddressBuffer load
+; NOTE: lowerRawBufferLoad currently treats every raw buffer load as a
+; structured buffer load, so both coordinates are passed through as written
+; (i32 0, i32 0) rather than as (offset, undef).
+define void @load_2() {
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.RawBuffer", i8, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.RawBuffer_i8_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0_1:%.*]] = call %dx.types.ResRet.i32 @dx.op.rawBufferLoad.i32(i32 139, %dx.types.Handle [[HANDLE]], i32 0, i32 0, i8 3, i32 4)
+  %data0 = call <2 x i32> @llvm.dx.rawBufferLoad(
+      target("dx.RawBuffer", i8, 0, 0) %buffer, i32 0, i32 0)
+
+  ret void
+}

>From 3a8f3a1996ed2c87aae08632fcb4fc08893cfae1 Mon Sep 17 00:00:00 2001
From: Zhengxing Li <zhengxingli at microsoft.com>
Date: Thu, 24 Oct 2024 14:25:15 -0700
Subject: [PATCH 3/3] Changes in DXILResourceMap for lowering
 llvm.dx.rawbufferload to dxil ops

---
 llvm/include/llvm/Analysis/DXILResource.h  |  5 +++++
 llvm/lib/Analysis/DXILResource.cpp         | 14 ++++++++++++++
 llvm/lib/Target/DirectX/DXILOpLowering.cpp | 14 +++++++-------
 3 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 016e45e78c3984..7f64688fba183a 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -340,6 +340,11 @@ class DXILResourceMap {
 
   void updateResourceMap(CallInst *origCallInst, CallInst *newCallInst);
 
+  // Update ResUseMap when origResUse is replaced by several new calls.
+  void updateResUseMap(CallInst *origResUse,
+                       std::vector<Value *> &multiNewResUse);
+
+  // Update ResUseMap when origResUse is replaced by a single new call.
   void updateResUseMap(CallInst *origResUse, CallInst *newResUse) {
     assert((origResUse != nullptr) && (newResUse != nullptr) &&
            (origResUse != newResUse) && "Wrong Inputs");
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 601d2648ae0288..be4fad8a84a171 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -791,6 +791,20 @@ void DXILResourceMap::updateResourceMap(CallInst *origCallInst,
   }
 }
 
+// Record every call in multiNewResUse as a replacement use of origResUse.
+void DXILResourceMap::updateResUseMap(CallInst *origResUse,
+                                      std::vector<Value *> &multiNewResUse) {
+  assert(origResUse != nullptr && !multiNewResUse.empty() && "Wrong Inputs");
+
+  // Keep origResUse mapped until the last new use is recorded, so no stale
+  // entry survives once the caller erases origResUse.
+  for (size_t I = 0, E = multiNewResUse.size(); I != E; ++I) {
+    CallInst *newResUse = cast<CallInst>(multiNewResUse[I]);
+    bool keepOrigResUseInMap = (I + 1 != E);
+    updateResUseMapCommon(origResUse, newResUse, keepOrigResUseInMap);
+  }
+}
+
 void DXILResourceMap::print(raw_ostream &OS) const {
   for (unsigned I = 0, E = Resources.size(); I != E; ++I) {
     OS << "Binding " << I << ":\n";
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 6fa856274e591b..d2b6f05fc936bd 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -704,6 +704,7 @@ class OpLowerer {
       Value *EVI = IRB.CreateExtractValue(Op, 0);
 
       Intrin->replaceAllUsesWith(EVI);
+      DRM.updateResUseMap(Intrin, Op);
       Intrin->eraseFromParent();
 
       return Error::success();
@@ -761,6 +762,7 @@ class OpLowerer {
     // TODO:
     // Remove the dx.op.rawbufferload without any uses now?
 
+    DRM.updateResUseMap(Intrin, bufLds);
     Intrin->eraseFromParent();
 
     return Error::success();
@@ -771,16 +773,14 @@ class OpLowerer {
 
     return replaceFunction(F, [&](CallInst *CI) -> Error {
       IRB.SetInsertPoint(CI);
-#if 0
-      auto *It = DRM.find(dyn_cast<CallInst>(CI->getArgOperand(0)));
+
+      auto *It = DRM.find(DRM.findResHandleByUse(CI));
       assert(It != DRM.end() && "Resource not in map?");
       dxil::ResourceInfo &RI = *It;
 
-      assert((RI.getResourceKind() == dxil::ResourceKind::StructuredBuffer) ||
-             (RI.getResourceKind() == dxil::ResourceKind::RawBuffer));
-#else
-      ResourceKind RCKind = dxil::ResourceKind::StructuredBuffer;
-#endif
+      ResourceKind RCKind =  RI.getResourceKind();
+      assert((RCKind == dxil::ResourceKind::StructuredBuffer) ||
+             (RCKind == dxil::ResourceKind::RawBuffer));
 
       Type *Ty = CI->getType();
       std::vector<Value *> bufLds;



More information about the llvm-commits mailing list