[llvm] b88eef9 - [DSE] Add predicated vector length store support for masked store elimination (#134175)

via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 9 18:12:18 PDT 2025


Author: Michael Berg
Date: 2025-04-09T18:12:15-07:00
New Revision: b88eef95e72ed54580c3f65c49cc8b768a764938

URL: https://github.com/llvm/llvm-project/commit/b88eef95e72ed54580c3f65c49cc8b768a764938
DIFF: https://github.com/llvm/llvm-project/commit/b88eef95e72ed54580c3f65c49cc8b768a764938.diff

LOG: [DSE] Add predicated vector length store support for masked store elimination (#134175)

In isMaskedStoreOverwrite we process two stores where one fully overwrites
the other. Here we add support for predicated vector length stores so that
DSE will also eliminate this variant of masked stores.

This is the follow-up installment mentioned in:
https://reviews.llvm.org/D132700

Added: 
    llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll

Modified: 
    llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 935f21fd484f3..141af344f0e16 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -248,28 +248,43 @@ static OverwriteResult isMaskedStoreOverwrite(const Instruction *KillingI,
     return OW_Unknown;
   if (KillingII->getIntrinsicID() != DeadII->getIntrinsicID())
     return OW_Unknown;
-  if (KillingII->getIntrinsicID() == Intrinsic::masked_store) {
-    // Type size.
-    VectorType *KillingTy =
-        cast<VectorType>(KillingII->getArgOperand(0)->getType());
-    VectorType *DeadTy = cast<VectorType>(DeadII->getArgOperand(0)->getType());
-    if (KillingTy->getScalarSizeInBits() != DeadTy->getScalarSizeInBits())
+
+  switch (KillingII->getIntrinsicID()) {
+  case Intrinsic::masked_store:
+  case Intrinsic::vp_store: {
+    const DataLayout &DL = KillingII->getDataLayout();
+    auto *KillingTy = KillingII->getArgOperand(0)->getType();
+    auto *DeadTy = DeadII->getArgOperand(0)->getType();
+    if (DL.getTypeSizeInBits(KillingTy) != DL.getTypeSizeInBits(DeadTy))
       return OW_Unknown;
     // Element count.
-    if (KillingTy->getElementCount() != DeadTy->getElementCount())
+    if (cast<VectorType>(KillingTy)->getElementCount() !=
+        cast<VectorType>(DeadTy)->getElementCount())
       return OW_Unknown;
     // Pointers.
-    Value *KillingPtr = KillingII->getArgOperand(1)->stripPointerCasts();
-    Value *DeadPtr = DeadII->getArgOperand(1)->stripPointerCasts();
+    Value *KillingPtr = KillingII->getArgOperand(1);
+    Value *DeadPtr = DeadII->getArgOperand(1);
     if (KillingPtr != DeadPtr && !AA.isMustAlias(KillingPtr, DeadPtr))
       return OW_Unknown;
-    // Masks.
-    // TODO: check that KillingII's mask is a superset of the DeadII's mask.
-    if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
-      return OW_Unknown;
+    if (KillingII->getIntrinsicID() == Intrinsic::masked_store) {
+      // Masks.
+      // TODO: check that KillingII's mask is a superset of the DeadII's mask.
+      if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
+        return OW_Unknown;
+    } else if (KillingII->getIntrinsicID() == Intrinsic::vp_store) {
+      // Masks.
+      // TODO: check that KillingII's mask is a superset of the DeadII's mask.
+      if (KillingII->getArgOperand(2) != DeadII->getArgOperand(2))
+        return OW_Unknown;
+      // Lengths.
+      if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
+        return OW_Unknown;
+    }
     return OW_Complete;
   }
-  return OW_Unknown;
+  default:
+    return OW_Unknown;
+  }
 }
 
 /// Return 'OW_Complete' if a store to the 'KillingLoc' location completely

diff  --git a/llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll b/llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll
new file mode 100644
index 0000000000000..7ba1354d8cd0b
--- /dev/null
+++ b/llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll
@@ -0,0 +1,93 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=dse -S < %s | FileCheck %s
+
+; Test elimination of dead predicated vector length stores (llvm.vp.store)
+
+define void @test1(ptr %a, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:    [[VP_OP:%.*]] = call <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32> [[V1:%.*]], <vscale x 8 x i32> [[V2:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[VP_OP]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  %vp.op = call <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %vp.op, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  ret void
+}
+
+; False test for different vector lengths
+
+define void @test2(ptr %a, i32 %vl1, i32 %vl2, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL1:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL2:%.*]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl1)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl2)
+  ret void
+}
+
+; False test for different types
+
+define void @test3(ptr %a, i32 %vl1, i32 %vl2, <vscale x 4 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test3(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 4 x i1> splat (i1 true), i32 [[VL1:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL2:%.*]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32> %v1, ptr nonnull %a, <vscale x 4 x i1> splat (i1 true), i32 %vl1)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl2)
+  ret void
+}
+
+; False test for different element count
+
+define void @test4(ptr %a, i32 %vl, <vscale x 4 x i64> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test4(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 4 x i1> splat (i1 true), i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64> %v1, ptr nonnull %a, <vscale x 4 x i1> splat (i1 true), i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  ret void
+}
+
+; False test for different masks
+
+define void @test5(ptr %a, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> %m1, <vscale x 8 x i1> %m2) {
+;
+; CHECK-LABEL: @test5(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> [[M1:%.*]], i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> [[M2:%.*]], i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> %m1, i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> %m2, i32 %vl)
+  ret void
+}
+
+; False test for different pointers
+
+define void @test6(ptr %a, ptr %b, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> %m1) {
+;
+; CHECK-LABEL: @test6(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> [[M1:%.*]], i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[B:%.*]], <vscale x 8 x i1> [[M1]], i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> %m1, i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %b, <vscale x 8 x i1> %m1, i32 %vl)
+  ret void
+}
+
+declare <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32>, <vscale x 8 x i32>, <vscale x 8 x i1>, i32)
+declare void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32>, ptr nocapture, <vscale x 8 x i1>, i32)
+declare void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32>, ptr nocapture, <vscale x 4 x i1>, i32)
+declare void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64>, ptr nocapture, <vscale x 4 x i1>, i32)
+


        


More information about the llvm-commits mailing list