[llvm] [DSE] Add predicated vector length store support for masked store elimination (PR #134175)

Michael Berg via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 7 13:55:10 PDT 2025


https://github.com/mcberg2021 updated https://github.com/llvm/llvm-project/pull/134175

>From 526a862362b0823f09f51787b96751ad74e27524 Mon Sep 17 00:00:00 2001
From: Michael Berg <michael.berg at sifive.com>
Date: Mon, 31 Mar 2025 14:40:55 -0700
Subject: [PATCH] [DSE] Add predicated vector length store support for masked
 store elimination

isMaskedStoreOverwrite detects a masked store that fully overwrites an
earlier one. Extend it to predicated vector length stores (llvm.vp.store)
so that DSE can also eliminate dead stores of this masked-store variant.
---
 .../Scalar/DeadStoreElimination.cpp           | 43 ++++++---
 .../DeadStoreElimination/dead-vp.store.ll     | 93 +++++++++++++++++++
 2 files changed, 122 insertions(+), 14 deletions(-)
 create mode 100644 llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 935f21fd484f3..141af344f0e16 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -248,28 +248,43 @@ static OverwriteResult isMaskedStoreOverwrite(const Instruction *KillingI,
     return OW_Unknown;
   if (KillingII->getIntrinsicID() != DeadII->getIntrinsicID())
     return OW_Unknown;
-  if (KillingII->getIntrinsicID() == Intrinsic::masked_store) {
-    // Type size.
-    VectorType *KillingTy =
-        cast<VectorType>(KillingII->getArgOperand(0)->getType());
-    VectorType *DeadTy = cast<VectorType>(DeadII->getArgOperand(0)->getType());
-    if (KillingTy->getScalarSizeInBits() != DeadTy->getScalarSizeInBits())
+
+  switch (KillingII->getIntrinsicID()) {
+  case Intrinsic::masked_store:
+  case Intrinsic::vp_store: {
+    const DataLayout &DL = KillingII->getDataLayout();
+    auto *KillingTy = KillingII->getArgOperand(0)->getType();
+    auto *DeadTy = DeadII->getArgOperand(0)->getType();
+    if (DL.getTypeSizeInBits(KillingTy) != DL.getTypeSizeInBits(DeadTy))
       return OW_Unknown;
     // Element count.
-    if (KillingTy->getElementCount() != DeadTy->getElementCount())
+    if (cast<VectorType>(KillingTy)->getElementCount() !=
+        cast<VectorType>(DeadTy)->getElementCount())
       return OW_Unknown;
     // Pointers.
-    Value *KillingPtr = KillingII->getArgOperand(1)->stripPointerCasts();
-    Value *DeadPtr = DeadII->getArgOperand(1)->stripPointerCasts();
+    Value *KillingPtr = KillingII->getArgOperand(1);
+    Value *DeadPtr = DeadII->getArgOperand(1);
     if (KillingPtr != DeadPtr && !AA.isMustAlias(KillingPtr, DeadPtr))
       return OW_Unknown;
-    // Masks.
-    // TODO: check that KillingII's mask is a superset of the DeadII's mask.
-    if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
-      return OW_Unknown;
+    if (KillingII->getIntrinsicID() == Intrinsic::masked_store) {
+      // Masks.
+      // TODO: check that KillingII's mask is a superset of the DeadII's mask.
+      if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
+        return OW_Unknown;
+    } else if (KillingII->getIntrinsicID() == Intrinsic::vp_store) {
+      // Masks.
+      // TODO: check that KillingII's mask is a superset of the DeadII's mask.
+      if (KillingII->getArgOperand(2) != DeadII->getArgOperand(2))
+        return OW_Unknown;
+      // Lengths.
+      if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
+        return OW_Unknown;
+    }
     return OW_Complete;
   }
-  return OW_Unknown;
+  default:
+    return OW_Unknown;
+  }
 }
 
 /// Return 'OW_Complete' if a store to the 'KillingLoc' location completely
diff --git a/llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll b/llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll
new file mode 100644
index 0000000000000..7ba1354d8cd0b
--- /dev/null
+++ b/llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll
@@ -0,0 +1,93 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=dse -S < %s | FileCheck %s
+
+; Test predicated vector length masked stores for elimination
+
+define void @test1(ptr %a, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:    [[VP_OP:%.*]] = call <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32> [[V1:%.*]], <vscale x 8 x i32> [[V2:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[VP_OP]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  %vp.op = call <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %vp.op, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  ret void
+}
+
+; Negative test: stores with different vector lengths must not be eliminated
+
+define void @test2(ptr %a, i32 %vl1, i32 %vl2, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL1:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL2:%.*]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl1)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl2)
+  ret void
+}
+
+; Negative test: stores of different vector types must not be eliminated
+
+define void @test3(ptr %a, i32 %vl1, i32 %vl2, <vscale x 4 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test3(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 4 x i1> splat (i1 true), i32 [[VL1:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL2:%.*]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32> %v1, ptr nonnull %a, <vscale x 4 x i1> splat (i1 true), i32 %vl1)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl2)
+  ret void
+}
+
+; Negative test: stores with different element counts must not be eliminated
+
+define void @test4(ptr %a, i32 %vl, <vscale x 4 x i64> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test4(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 4 x i1> splat (i1 true), i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64> %v1, ptr nonnull %a, <vscale x 4 x i1> splat (i1 true), i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+  ret void
+}
+
+; Negative test: stores with different masks must not be eliminated
+
+define void @test5(ptr %a, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> %m1, <vscale x 8 x i1> %m2) {
+;
+; CHECK-LABEL: @test5(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> [[M1:%.*]], i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> [[M2:%.*]], i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> %m1, i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> %m2, i32 %vl)
+  ret void
+}
+
+; Negative test: stores to different pointers must not be eliminated
+
+define void @test6(ptr %a, ptr %b, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> %m1) {
+;
+; CHECK-LABEL: @test6(
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> [[M1:%.*]], i32 [[VL:%.*]])
+; CHECK-NEXT:    call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[B:%.*]], <vscale x 8 x i1> [[M1]], i32 [[VL]])
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> %m1, i32 %vl)
+  call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %b, <vscale x 8 x i1> %m1, i32 %vl)
+  ret void
+}
+
+declare <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32>, <vscale x 8 x i32>, <vscale x 8 x i1>, i32)
+declare void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32>, ptr nocapture, <vscale x 8 x i1>, i32)
+declare void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32>, ptr nocapture, <vscale x 4 x i1>, i32)
+declare void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64>, ptr nocapture, <vscale x 4 x i1>, i32)
+



More information about the llvm-commits mailing list