[llvm] [DirectX] Match DXC when storing `RWBuffer<float>` (PR #129911)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 5 10:31:13 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Justin Bogner (bogner)

<details>
<summary>Changes</summary>

Update the lowering of `llvm.dx.resource.store.typedbuffer` to match DXC: when storing fewer than 4 elements, repeat the first element in the remaining slots and always use a full write mask, instead of rejecting the store.

Fixes #128110

---
Full diff: https://github.com/llvm/llvm-project/pull/129911.diff


4 Files Affected:

- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1-1) 
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+9-9) 
- (modified) llvm/test/CodeGen/DirectX/BufferStore-errors.ll (+1-17) 
- (modified) llvm/test/CodeGen/DirectX/BufferStore.ll (+48-1) 


``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 87de68cb3ad4f..ead7286f4311c 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -34,7 +34,7 @@ def int_dx_resource_load_typedbuffer
     : DefaultAttrsIntrinsic<[llvm_any_ty, llvm_i1_ty],
                             [llvm_any_ty, llvm_i32_ty], [IntrReadMem]>;
 def int_dx_resource_store_typedbuffer
-    : DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i32_ty, llvm_anyvector_ty],
+    : DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i32_ty, llvm_any_ty],
                             [IntrWriteMem]>;
 def int_dx_resource_load_rawbuffer
     : DefaultAttrsIntrinsic<[llvm_any_ty, llvm_i1_ty],
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index bc41347faf06c..07467654aa88d 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -649,16 +649,13 @@ class OpLowerer {
 
       uint64_t NumElements =
           DL.getTypeSizeInBits(DataTy) / DL.getTypeSizeInBits(ScalarTy);
-      Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
+      Value *Mask =
+          ConstantInt::get(Int8Ty, IsRaw ? ~(~0U << NumElements) : 15U);
 
       // TODO: check that we only have vector or scalar...
-      if (!IsRaw && NumElements != 4)
-        return make_error<StringError>(
-            "typedBufferStore data must be a vector of 4 elements",
-            inconvertibleErrorCode());
-      else if (NumElements > 4)
+      if (NumElements > 4)
         return make_error<StringError>(
-            "rawBufferStore data must have at most 4 elements",
+            "Buffer store data must have at most 4 elements",
             inconvertibleErrorCode());
 
       std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
@@ -687,10 +684,13 @@ class OpLowerer {
         if (DataElements[I] == nullptr)
           DataElements[I] =
               IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I));
-      // For any elements beyond the length of the vector, fill up with undef.
+
+      // For any elements beyond the length of the vector, we should fill it up
+      // with undef - however, for typed buffers we repeat the first element to
+      // match DXC.
       for (int I = NumElements, E = 4; I < E; ++I)
         if (DataElements[I] == nullptr)
-          DataElements[I] = UndefValue::get(ScalarTy);
+          DataElements[I] = IsRaw ? UndefValue::get(ScalarTy) : DataElements[0];
 
       dxil::OpCode Op = OpCode::BufferStore;
       SmallVector<Value *, 9> Args{
diff --git a/llvm/test/CodeGen/DirectX/BufferStore-errors.ll b/llvm/test/CodeGen/DirectX/BufferStore-errors.ll
index 6e529973bd604..8d041b1ebfeaa 100644
--- a/llvm/test/CodeGen/DirectX/BufferStore-errors.ll
+++ b/llvm/test/CodeGen/DirectX/BufferStore-errors.ll
@@ -5,7 +5,7 @@ target triple = "dxil-pc-shadermodel6.6-compute"
 
 ; CHECK: error:
 ; CHECK-SAME: in function storetoomany
-; CHECK-SAME: typedBufferStore data must be a vector of 4 elements
+; CHECK-SAME: Buffer store data must have at most 4 elements
 define void @storetoomany(<5 x float> %data, i32 %index) "hlsl.export" {
   %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
       @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4f32_1_0_0(
@@ -18,20 +18,4 @@ define void @storetoomany(<5 x float> %data, i32 %index) "hlsl.export" {
   ret void
 }
 
-; CHECK: error:
-; CHECK-SAME: in function storetoofew
-; CHECK-SAME: typedBufferStore data must be a vector of 4 elements
-define void @storetoofew(<3 x i32> %data, i32 %index) "hlsl.export" {
-  %buffer = call target("dx.TypedBuffer", <4 x i32>, 1, 0, 0)
-      @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4i32_1_0_0(
-          i32 0, i32 0, i32 1, i32 0, i1 false)
-
-  call void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v4i32_1_0_0t.v3i32(
-      target("dx.TypedBuffer", <4 x i32>, 1, 0, 0) %buffer,
-      i32 %index, <3 x i32> %data)
-
-  ret void
-}
-
 declare void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v4f32_1_0_0t.v5f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0), i32, <5 x float>)
-declare void @llvm.dx.resource.store.typedbuffer.tdx.TypedBuffer_v4i32_1_0_0t.v3i32(target("dx.TypedBuffer", <4 x i32>, 1, 0, 0), i32, <3 x i32>)
diff --git a/llvm/test/CodeGen/DirectX/BufferStore.ll b/llvm/test/CodeGen/DirectX/BufferStore.ll
index 6892228b0d8ae..363a3c723bfd5 100644
--- a/llvm/test/CodeGen/DirectX/BufferStore.ll
+++ b/llvm/test/CodeGen/DirectX/BufferStore.ll
@@ -2,7 +2,8 @@
 
 target triple = "dxil-pc-shadermodel6.6-compute"
 
-define void @storefloat(<4 x float> %data, i32 %index) {
+; CHECK-LABEL: define void @storefloats
+define void @storefloats(<4 x float> %data, i32 %index) {
 
   ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
   ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
@@ -25,6 +26,49 @@ define void @storefloat(<4 x float> %data, i32 %index) {
   ret void
 }
 
+; CHECK-LABEL: define void @storeonefloat
+define void @storeonefloat(float %data, i32 %index) {
+
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", float, 1, 0, 0)
+      @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; The temporary casts should all have been cleaned up
+  ; CHECK-NOT: %dx.resource.casthandle
+
+  ; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float %data, float %data, float %data, float %data, i8 15){{$}}
+  call void @llvm.dx.resource.store.typedbuffer(
+      target("dx.TypedBuffer", float, 1, 0, 0) %buffer,
+      i32 %index, float %data)
+
+  ret void
+}
+
+; CHECK-LABEL: define void @storetwofloat
+define void @storetwofloat(<2 x float> %data, i32 %index) {
+
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <2 x float>, 1, 0, 0)
+      @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v2f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; The temporary casts should all have been cleaned up
+  ; CHECK-NOT: %dx.resource.casthandle
+
+  ; CHECK: [[DATA0_0:%.*]] = extractelement <2 x float> %data, i32 0
+  ; CHECK: [[DATA0_1:%.*]] = extractelement <2 x float> %data, i32 1
+  ; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float [[DATA0_0]], float [[DATA0_1]], float [[DATA0_0]], float [[DATA0_0]], i8 15){{$}}
+  call void @llvm.dx.resource.store.typedbuffer(
+      target("dx.TypedBuffer", <2 x float>, 1, 0, 0) %buffer,
+      i32 %index, <2 x float> %data)
+
+  ret void
+}
+
+; CHECK-LABEL: define void @storeint
 define void @storeint(<4 x i32> %data, i32 %index) {
 
   ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
@@ -45,6 +89,7 @@ define void @storeint(<4 x i32> %data, i32 %index) {
   ret void
 }
 
+; CHECK-LABEL: define void @storehalf
 define void @storehalf(<4 x half> %data, i32 %index) {
 
   ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
@@ -68,6 +113,7 @@ define void @storehalf(<4 x half> %data, i32 %index) {
   ret void
 }
 
+; CHECK-LABEL: define void @storei16
 define void @storei16(<4 x i16> %data, i32 %index) {
 
   ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
@@ -91,6 +137,7 @@ define void @storei16(<4 x i16> %data, i32 %index) {
   ret void
 }
 
+; CHECK-LABEL: define void @store_scalarized_floats
 define void @store_scalarized_floats(float %data0, float %data1, float %data2, float %data3, i32 %index) {
 
   ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,

``````````

</details>


https://github.com/llvm/llvm-project/pull/129911


More information about the llvm-commits mailing list