[llvm-branch-commits] [DirectX] Lower `@llvm.dx.typedBufferStore` to DXIL ops (PR #104253)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Aug 14 14:29:53 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Justin Bogner (bogner)

<details>
<summary>Changes</summary>

The `@llvm.dx.typedBufferStore` intrinsic is lowered to `@dx.op.bufferStore`.


---
Full diff: https://github.com/llvm/llvm-project/pull/104253.diff


5 Files Affected:

- (modified) llvm/docs/DirectX/DXILResources.rst (+3-3) 
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+3-2) 
- (modified) llvm/lib/Target/DirectX/DXIL.td (+12) 
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+40) 
- (added) llvm/test/CodeGen/DirectX/BufferStore.ll (+92) 


``````````diff
diff --git a/llvm/docs/DirectX/DXILResources.rst b/llvm/docs/DirectX/DXILResources.rst
index aef88bc43b224d..cd6c527a546610 100644
--- a/llvm/docs/DirectX/DXILResources.rst
+++ b/llvm/docs/DirectX/DXILResources.rst
@@ -365,11 +365,11 @@ Examples:
 
 .. code-block:: llvm
 
-   call void @llvm.dx.bufferStore.tdx.Buffer_f32_1_0t(
+   call void @llvm.dx.typedBufferStore.tdx.Buffer_v4f32_1_0_0t(
        target("dx.TypedBuffer", f32, 1, 0) %buf, i32 %index, <4 x f32> %data)
-   call void @llvm.dx.bufferStore.tdx.Buffer_f16_1_0t(
+   call void @llvm.dx.typedBufferStore.tdx.Buffer_v4f16_1_0_0t(
        target("dx.TypedBuffer", f16, 1, 0) %buf, i32 %index, <4 x f16> %data)
-   call void @llvm.dx.bufferStore.tdx.Buffer_f64_1_0t(
+   call void @llvm.dx.typedBufferStore.tdx.Buffer_v2f64_1_0_0t(
        target("dx.TypedBuffer", f64, 1, 0) %buf, i32 %index, <2 x f64> %data)
 
 .. list-table:: ``@llvm.dx.rawBufferPtr``
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index d817b610fa71a0..67351ad8f9b91f 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -31,8 +31,9 @@ def int_dx_handle_fromBinding
           [IntrNoMem]>;
 
 def int_dx_typedBufferLoad
-    : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
-                            [llvm_any_ty, llvm_i32_ty]>;
+    : DefaultAttrsIntrinsic<[llvm_anyvector_ty], [llvm_any_ty, llvm_i32_ty]>;
+def int_dx_typedBufferStore
+    : DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i32_ty, llvm_anyvector_ty]>;
 
 // Cast between target extension handle types and dxil-style opaque handles
 def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index b114148f84e843..50eff20e810d79 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -697,6 +697,18 @@ def BufferLoad : DXILOp<68, bufferLoad> {
   let stages = [Stages<DXIL1_0, [all_stages]>];
 }
 
+def BufferStore : DXILOp<69, bufferStore> {
+  let Doc = "writes to an RWTypedBuffer";
+  // Handle, Coord0, Coord1, Val0, Val1, Val2, Val3, Mask
+  let arguments = [
+    HandleTy, Int32Ty, Int32Ty, OverloadTy, OverloadTy, OverloadTy, OverloadTy,
+    Int8Ty
+  ];
+  let result = VoidTy;
+  let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, Int16Ty, Int32Ty]>];
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
 def ThreadId :  DXILOp<93, threadId> {
   let Doc = "Reads the thread ID";
   let LLVMIntrinsic = int_dx_thread_id;
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 46dfc905b5875c..9d5377edd3a780 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -289,6 +289,43 @@ class OpLowerer {
     });
   }
 
+  void lowerTypedBufferStore(Function &F) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+    Type *Int8Ty = IRB.getInt8Ty();
+    Type *Int32Ty = IRB.getInt32Ty();
+
+    replaceFunction(F, [&](CallInst *CI) -> Error {
+      IRB.SetInsertPoint(CI);
+
+      Value *Handle =
+          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
+      Value *Index0 = CI->getArgOperand(1);
+      Value *Index1 = UndefValue::get(Int32Ty);
+      // For typed stores, the mask must always cover all four elements.
+      Constant *Mask = ConstantInt::get(Int8Ty, 0xF);
+
+      Value *Data = CI->getArgOperand(2);
+      Value *Data0 =
+          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 0));
+      Value *Data1 =
+          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 1));
+      Value *Data2 =
+          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 2));
+      Value *Data3 =
+          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 3));
+
+      std::array<Value *, 8> Args{Handle, Index0, Index1, Data0,
+                                  Data1,  Data2,  Data3,  Mask};
+      Expected<CallInst *> OpCall =
+          OpBuilder.tryCreateOp(OpCode::BufferStore, Args);
+      if (Error E = OpCall.takeError())
+        return E;
+
+      CI->eraseFromParent();
+      return Error::success();
+    });
+  }
+
   bool lowerIntrinsics() {
     bool Updated = false;
 
@@ -310,6 +347,9 @@ class OpLowerer {
       case Intrinsic::dx_typedBufferLoad:
         lowerTypedBufferLoad(F);
         break;
+      case Intrinsic::dx_typedBufferStore:
+        lowerTypedBufferStore(F);
+        break;
       }
       Updated = true;
     }
diff --git a/llvm/test/CodeGen/DirectX/BufferStore.ll b/llvm/test/CodeGen/DirectX/BufferStore.ll
new file mode 100644
index 00000000000000..102084816a6f24
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/BufferStore.ll
@@ -0,0 +1,92 @@
+; RUN: opt -S -dxil-op-lower %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+define void @storefloat(<4 x float> %data, i32 %index) {
+
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
+                  i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; The temporary casts should all have been cleaned up
+  ; CHECK-NOT: %dx.cast_handle
+
+  ; CHECK: [[DATA0_0:%.*]] = extractelement <4 x float> %data, i32 0
+  ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x float> %data, i32 1
+  ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x float> %data, i32 2
+  ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x float> %data, i32 3
+  ; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float [[DATA0_0]], float [[DATA0_1]], float [[DATA0_2]], float [[DATA0_3]], i8 15)
+  call void @llvm.dx.typedBufferStore(
+       target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer,
+       i32 %index, <4 x float> %data)
+
+  ret void
+}
+
+define void @storeint(<4 x i32> %data, i32 %index) {
+
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x i32>, 1, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_1_0_0(
+                  i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK: [[DATA0_0:%.*]] = extractelement <4 x i32> %data, i32 0
+  ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x i32> %data, i32 1
+  ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x i32> %data, i32 2
+  ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x i32> %data, i32 3
+  ; CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, i32 [[DATA0_0]], i32 [[DATA0_1]], i32 [[DATA0_2]], i32 [[DATA0_3]], i8 15)
+  call void @llvm.dx.typedBufferStore(
+       target("dx.TypedBuffer", <4 x i32>, 1, 0, 0) %buffer,
+       i32 %index, <4 x i32> %data)
+
+  ret void
+}
+
+define void @storehalf(<4 x half> %data, i32 %index) {
+
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x half>, 1, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_1_0_0(
+                  i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; The temporary casts should all have been cleaned up
+  ; CHECK-NOT: %dx.cast_handle
+
+  ; CHECK: [[DATA0_0:%.*]] = extractelement <4 x half> %data, i32 0
+  ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x half> %data, i32 1
+  ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x half> %data, i32 2
+  ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x half> %data, i32 3
+  ; CHECK: call void @dx.op.bufferStore.f16(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, half [[DATA0_0]], half [[DATA0_1]], half [[DATA0_2]], half [[DATA0_3]], i8 15)
+  call void @llvm.dx.typedBufferStore(
+       target("dx.TypedBuffer", <4 x half>, 1, 0, 0) %buffer,
+       i32 %index, <4 x half> %data)
+
+  ret void
+}
+
+define void @storei16(<4 x i16> %data, i32 %index) {
+
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x i16>, 1, 0, 0)
+              @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_1_0_0(
+                  i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; The temporary casts should all have been cleaned up
+  ; CHECK-NOT: %dx.cast_handle
+
+  ; CHECK: [[DATA0_0:%.*]] = extractelement <4 x i16> %data, i32 0
+  ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x i16> %data, i32 1
+  ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x i16> %data, i32 2
+  ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x i16> %data, i32 3
+  ; CHECK: call void @dx.op.bufferStore.i16(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, i16 [[DATA0_0]], i16 [[DATA0_1]], i16 [[DATA0_2]], i16 [[DATA0_3]], i8 15)
+  call void @llvm.dx.typedBufferStore(
+       target("dx.TypedBuffer", <4 x i16>, 1, 0, 0) %buffer,
+       i32 %index, <4 x i16> %data)
+
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/104253


More information about the llvm-branch-commits mailing list