[llvm] [DirectX] Clean up extra vectors when lowering to buffer store (PR #116721)

Justin Bogner via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 20 09:41:18 PST 2024


https://github.com/bogner updated https://github.com/llvm/llvm-project/pull/116721

>From ec3c4f440a089e2f157344c998a9d3de2e691ad1 Mon Sep 17 00:00:00 2001
From: Justin Bogner <mail at justinbogner.com>
Date: Fri, 15 Nov 2024 10:56:14 -0800
Subject: [PATCH 1/2] [DirectX] Clean up extra vectors when lowering to buffer
 store

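DXILOpLowering runs after scalarization but `@llvm.dx.typedBufferStore`
takes a vector, so the vector argument is usually an artifact: a chain of
`insertelement`s built solely to feed the store. Forward the scalar values
directly instead of creating a vector just to extract elements from it
immediately.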
---
 llvm/lib/Target/DirectX/DXILOpLowering.cpp | 45 ++++++++++++++++------
 llvm/test/CodeGen/DirectX/BufferStore.ll   | 24 ++++++++++++
 2 files changed, 58 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 9f124394363a38..00daa87a53aa13 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -531,23 +531,46 @@ class OpLowerer {
         return make_error<StringError>(
             "typedBufferStore data must be a vector of 4 elements",
             inconvertibleErrorCode());
-      Value *Data0 =
-          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 0));
-      Value *Data1 =
-          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 1));
-      Value *Data2 =
-          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 2));
-      Value *Data3 =
-          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 3));
-
-      std::array<Value *, 8> Args{Handle, Index0, Index1, Data0,
-                                  Data1,  Data2,  Data3,  Mask};
+
+      // Since we're post-scalarizer, we likely have a vector that's constructed
+      // solely for the argument of the store.
+      std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
+      auto *IEI = dyn_cast<InsertElementInst>(Data);
+      while (IEI) {
+        auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
+        if (!IndexOp)
+          break;
+        size_t IndexVal = IndexOp->getZExtValue();
+        assert(IndexVal < 4 && "Too many elements for buffer store");
+        DataElements[IndexVal] = IEI->getOperand(1);
+        IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
+      }
+
+      // If for some reason we weren't able to forward the arguments from the
+      // scalarizer artifact, then we need to actually extract elements from the
+      // vector.
+      for (int I = 0, E = 4; I != E; ++I)
+        if (DataElements[I] == nullptr)
+          DataElements[I] =
+              IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I));
+
+      std::array<Value *, 8> Args{
+          Handle,          Index0,          Index1,          DataElements[0],
+          DataElements[1], DataElements[2], DataElements[3], Mask};
       Expected<CallInst *> OpCall =
           OpBuilder.tryCreateOp(OpCode::BufferStore, Args, CI->getName());
       if (Error E = OpCall.takeError())
         return E;
 
       CI->eraseFromParent();
+      // Clean up any leftover `insertelement`s
+      IEI = dyn_cast<InsertElementInst>(Data);
+      while (IEI && IEI->use_empty()) {
+        InsertElementInst *Tmp = IEI;
+        IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
+        Tmp->eraseFromParent();
+      }
+
       return Error::success();
     });
   }
diff --git a/llvm/test/CodeGen/DirectX/BufferStore.ll b/llvm/test/CodeGen/DirectX/BufferStore.ll
index 9ea7735be59c81..81cc5fd328e0a7 100644
--- a/llvm/test/CodeGen/DirectX/BufferStore.ll
+++ b/llvm/test/CodeGen/DirectX/BufferStore.ll
@@ -90,3 +90,27 @@ define void @storei16(<4 x i16> %data, i32 %index) {
 
   ret void
 }
+
+define void @store_scalarized_floats(float %data0, float %data1, float %data2, float %data3, i32 %index) {
+
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; We shouldn't end up with any inserts/extracts.
+  ; CHECK-NOT: insertelement
+  ; CHECK-NOT: extractelement
+
+  ; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float %data0, float %data1, float %data2, float %data3, i8 15)
+  %vec.upto0 = insertelement <4 x float> poison, float %data0, i64 0
+  %vec.upto1 = insertelement <4 x float> %vec.upto0, float %data1, i64 1
+  %vec.upto2 = insertelement <4 x float> %vec.upto1, float %data2, i64 2
+  %vec = insertelement <4 x float> %vec.upto2, float %data3, i64 3
+  call void @llvm.dx.typedBufferStore(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer,
+      i32 %index, <4 x float> %vec)
+
+  ret void
+}

>From 3ad9863845385ac874191af42506673428aebe84 Mon Sep 17 00:00:00 2001
From: Justin Bogner <mail at justinbogner.com>
Date: Wed, 20 Nov 2024 09:40:47 -0800
Subject: [PATCH 2/2] Improve comment

---
 llvm/lib/Target/DirectX/DXILOpLowering.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 00daa87a53aa13..0df36b326e5dcb 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -533,7 +533,8 @@ class OpLowerer {
             inconvertibleErrorCode());
 
       // Since we're post-scalarizer, we likely have a vector that's constructed
-      // solely for the argument of the store.
+      // solely for the argument of the store. If so, just use the scalar values
+      // from before they're inserted into the temporary.
       std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
       auto *IEI = dyn_cast<InsertElementInst>(Data);
       while (IEI) {



More information about the llvm-commits mailing list