[llvm] 2c88ac9 - [DirectX] Clean up extra vectors when lowering to buffer store (#116721)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 2 13:34:10 PST 2024


Author: Justin Bogner
Date: 2024-12-02T13:34:06-08:00
New Revision: 2c88ac9da9f842875592b232ba957da341e62ea5

URL: https://github.com/llvm/llvm-project/commit/2c88ac9da9f842875592b232ba957da341e62ea5
DIFF: https://github.com/llvm/llvm-project/commit/2c88ac9da9f842875592b232ba957da341e62ea5.diff

LOG: [DirectX] Clean up extra vectors when lowering to buffer store (#116721)

DXILOpLowering runs after scalarization, but `@llvm.dx.typedBufferStore`
takes a vector, so its data argument is usually a scalarizer artifact.
Avoid creating a vector just to extract its elements again immediately.

Added: 
    

Modified: 
    llvm/lib/Target/DirectX/DXILOpLowering.cpp
    llvm/test/CodeGen/DirectX/BufferStore.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index a0d46efd1763ea..d9e70da6ed653a 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -568,23 +568,47 @@ class OpLowerer {
         return make_error<StringError>(
             "typedBufferStore data must be a vector of 4 elements",
             inconvertibleErrorCode());
-      Value *Data0 =
-          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 0));
-      Value *Data1 =
-          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 1));
-      Value *Data2 =
-          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 2));
-      Value *Data3 =
-          IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 3));
-
-      std::array<Value *, 8> Args{Handle, Index0, Index1, Data0,
-                                  Data1,  Data2,  Data3,  Mask};
+
+      // Since we're post-scalarizer, we likely have a vector that's constructed
+      // solely for the argument of the store. If so, just use the scalar values
+      // from before they're inserted into the temporary.
+      std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
+      auto *IEI = dyn_cast<InsertElementInst>(Data);
+      while (IEI) {
+        auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
+        if (!IndexOp)
+          break;
+        size_t IndexVal = IndexOp->getZExtValue();
+        assert(IndexVal < 4 && "Too many elements for buffer store");
+        DataElements[IndexVal] = IEI->getOperand(1);
+        IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
+      }
+
+      // If for some reason we weren't able to forward the arguments from the
+      // scalarizer artifact, then we need to actually extract elements from the
+      // vector.
+      for (int I = 0, E = 4; I != E; ++I)
+        if (DataElements[I] == nullptr)
+          DataElements[I] =
+              IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I));
+
+      std::array<Value *, 8> Args{
+          Handle,          Index0,          Index1,          DataElements[0],
+          DataElements[1], DataElements[2], DataElements[3], Mask};
       Expected<CallInst *> OpCall =
           OpBuilder.tryCreateOp(OpCode::BufferStore, Args, CI->getName());
       if (Error E = OpCall.takeError())
         return E;
 
       CI->eraseFromParent();
+      // Clean up any leftover `insertelement`s
+      IEI = dyn_cast<InsertElementInst>(Data);
+      while (IEI && IEI->use_empty()) {
+        InsertElementInst *Tmp = IEI;
+        IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
+        Tmp->eraseFromParent();
+      }
+
       return Error::success();
     });
   }

diff  --git a/llvm/test/CodeGen/DirectX/BufferStore.ll b/llvm/test/CodeGen/DirectX/BufferStore.ll
index 9ea7735be59c81..81cc5fd328e0a7 100644
--- a/llvm/test/CodeGen/DirectX/BufferStore.ll
+++ b/llvm/test/CodeGen/DirectX/BufferStore.ll
@@ -90,3 +90,27 @@ define void @storei16(<4 x i16> %data, i32 %index) {
 
   ret void
 }
+
+define void @store_scalarized_floats(float %data0, float %data1, float %data2, float %data3, i32 %index) {
+
+  ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 217,
+  ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]]
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; We shouldn't end up with any inserts/extracts.
+  ; CHECK-NOT: insertelement
+  ; CHECK-NOT: extractelement
+
+  ; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float %data0, float %data1, float %data2, float %data3, i8 15)
+  %vec.upto0 = insertelement <4 x float> poison, float %data0, i64 0
+  %vec.upto1 = insertelement <4 x float> %vec.upto0, float %data1, i64 1
+  %vec.upto2 = insertelement <4 x float> %vec.upto1, float %data2, i64 2
+  %vec = insertelement <4 x float> %vec.upto2, float %data3, i64 3
+  call void @llvm.dx.typedBufferStore(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer,
+      i32 %index, <4 x float> %vec)
+
+  ret void
+}


        


More information about the llvm-commits mailing list