[llvm] b88eef9 - [DSE] Add predicated vector length store support for masked store elimination (#134175)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Apr 9 18:12:18 PDT 2025
Author: Michael Berg
Date: 2025-04-09T18:12:15-07:00
New Revision: b88eef95e72ed54580c3f65c49cc8b768a764938
URL: https://github.com/llvm/llvm-project/commit/b88eef95e72ed54580c3f65c49cc8b768a764938
DIFF: https://github.com/llvm/llvm-project/commit/b88eef95e72ed54580c3f65c49cc8b768a764938.diff
LOG: [DSE] Add predicated vector length store support for masked store elimination (#134175)
In isMaskedStoreOverwrite we process two stores where one fully overwrites
the other; here we add support for predicated vector length stores so that
DSE will eliminate this variant of masked stores.
This is the follow-up installment mentioned in:
https://reviews.llvm.org/D132700
Added:
llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll
Modified:
llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 935f21fd484f3..141af344f0e16 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -248,28 +248,43 @@ static OverwriteResult isMaskedStoreOverwrite(const Instruction *KillingI,
return OW_Unknown;
if (KillingII->getIntrinsicID() != DeadII->getIntrinsicID())
return OW_Unknown;
- if (KillingII->getIntrinsicID() == Intrinsic::masked_store) {
- // Type size.
- VectorType *KillingTy =
- cast<VectorType>(KillingII->getArgOperand(0)->getType());
- VectorType *DeadTy = cast<VectorType>(DeadII->getArgOperand(0)->getType());
- if (KillingTy->getScalarSizeInBits() != DeadTy->getScalarSizeInBits())
+
+ switch (KillingII->getIntrinsicID()) {
+ case Intrinsic::masked_store:
+ case Intrinsic::vp_store: {
+ const DataLayout &DL = KillingII->getDataLayout();
+ auto *KillingTy = KillingII->getArgOperand(0)->getType();
+ auto *DeadTy = DeadII->getArgOperand(0)->getType();
+ if (DL.getTypeSizeInBits(KillingTy) != DL.getTypeSizeInBits(DeadTy))
return OW_Unknown;
// Element count.
- if (KillingTy->getElementCount() != DeadTy->getElementCount())
+ if (cast<VectorType>(KillingTy)->getElementCount() !=
+ cast<VectorType>(DeadTy)->getElementCount())
return OW_Unknown;
// Pointers.
- Value *KillingPtr = KillingII->getArgOperand(1)->stripPointerCasts();
- Value *DeadPtr = DeadII->getArgOperand(1)->stripPointerCasts();
+ Value *KillingPtr = KillingII->getArgOperand(1);
+ Value *DeadPtr = DeadII->getArgOperand(1);
if (KillingPtr != DeadPtr && !AA.isMustAlias(KillingPtr, DeadPtr))
return OW_Unknown;
- // Masks.
- // TODO: check that KillingII's mask is a superset of the DeadII's mask.
- if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
- return OW_Unknown;
+ if (KillingII->getIntrinsicID() == Intrinsic::masked_store) {
+ // Masks.
+ // TODO: check that KillingII's mask is a superset of the DeadII's mask.
+ if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
+ return OW_Unknown;
+ } else if (KillingII->getIntrinsicID() == Intrinsic::vp_store) {
+ // Masks.
+ // TODO: check that KillingII's mask is a superset of the DeadII's mask.
+ if (KillingII->getArgOperand(2) != DeadII->getArgOperand(2))
+ return OW_Unknown;
+ // Lengths.
+ if (KillingII->getArgOperand(3) != DeadII->getArgOperand(3))
+ return OW_Unknown;
+ }
return OW_Complete;
}
- return OW_Unknown;
+ default:
+ return OW_Unknown;
+ }
}
/// Return 'OW_Complete' if a store to the 'KillingLoc' location completely
diff --git a/llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll b/llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll
new file mode 100644
index 0000000000000..7ba1354d8cd0b
--- /dev/null
+++ b/llvm/test/Transforms/DeadStoreElimination/dead-vp.store.ll
@@ -0,0 +1,93 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes=dse -S < %s | FileCheck %s
+
+; Test predicated vector length masked stores for elimination
+
+define void @test1(ptr %a, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test1(
+; CHECK-NEXT: [[VP_OP:%.*]] = call <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32> [[V1:%.*]], <vscale x 8 x i32> [[V2:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL:%.*]])
+; CHECK-NEXT: call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[VP_OP]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL]])
+; CHECK-NEXT: ret void
+;
+ call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+ %vp.op = call <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+ call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %vp.op, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+ ret void
+}
+
+; False test for different vector lengths
+
+define void @test2(ptr %a, i32 %vl1, i32 %vl2, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test2(
+; CHECK-NEXT: call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> splat (i1 true), i32 [[VL1:%.*]])
+; CHECK-NEXT: call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL2:%.*]])
+; CHECK-NEXT: ret void
+;
+ call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl1)
+ call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl2)
+ ret void
+}
+
+; False test for different types
+
+define void @test3(ptr %a, i32 %vl1, i32 %vl2, <vscale x 4 x i32> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test3(
+; CHECK-NEXT: call void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 4 x i1> splat (i1 true), i32 [[VL1:%.*]])
+; CHECK-NEXT: call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL2:%.*]])
+; CHECK-NEXT: ret void
+;
+ call void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32> %v1, ptr nonnull %a, <vscale x 4 x i1> splat (i1 true), i32 %vl1)
+ call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl2)
+ ret void
+}
+
+; False test for different element count
+
+define void @test4(ptr %a, i32 %vl, <vscale x 4 x i64> %v1, <vscale x 8 x i32> %v2) {
+;
+; CHECK-LABEL: @test4(
+; CHECK-NEXT: call void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 4 x i1> splat (i1 true), i32 [[VL:%.*]])
+; CHECK-NEXT: call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> splat (i1 true), i32 [[VL]])
+; CHECK-NEXT: ret void
+;
+ call void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64> %v1, ptr nonnull %a, <vscale x 4 x i1> splat (i1 true), i32 %vl)
+ call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> splat (i1 true), i32 %vl)
+ ret void
+}
+
+; False test for different masks
+
+define void @test5(ptr %a, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> %m1, <vscale x 8 x i1> %m2) {
+;
+; CHECK-LABEL: @test5(
+; CHECK-NEXT: call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> [[M1:%.*]], i32 [[VL:%.*]])
+; CHECK-NEXT: call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[A]], <vscale x 8 x i1> [[M2:%.*]], i32 [[VL]])
+; CHECK-NEXT: ret void
+;
+ call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> %m1, i32 %vl)
+ call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %a, <vscale x 8 x i1> %m2, i32 %vl)
+ ret void
+}
+
+; False test for different pointers
+
+define void @test6(ptr %a, ptr %b, i32 %vl, <vscale x 8 x i32> %v1, <vscale x 8 x i32> %v2, <vscale x 8 x i1> %m1) {
+;
+; CHECK-LABEL: @test6(
+; CHECK-NEXT: call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V1:%.*]], ptr nonnull [[A:%.*]], <vscale x 8 x i1> [[M1:%.*]], i32 [[VL:%.*]])
+; CHECK-NEXT: call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> [[V2:%.*]], ptr nonnull [[B:%.*]], <vscale x 8 x i1> [[M1]], i32 [[VL]])
+; CHECK-NEXT: ret void
+;
+ call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v1, ptr nonnull %a, <vscale x 8 x i1> %m1, i32 %vl)
+ call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %v2, ptr nonnull %b, <vscale x 8 x i1> %m1, i32 %vl)
+ ret void
+}
+
+declare <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32>, <vscale x 8 x i32>, <vscale x 8 x i1>, i32)
+declare void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32>, ptr nocapture, <vscale x 8 x i1>, i32)
+declare void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32>, ptr nocapture, <vscale x 4 x i1>, i32)
+declare void @llvm.vp.store.nxv4i64.p0(<vscale x 4 x i64>, ptr nocapture, <vscale x 4 x i1>, i32)
+
More information about the llvm-commits
mailing list