[llvm] DSE: lift limitation on sizes being non-scalable (PR #110670)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 28 09:11:00 PST 2024


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/110670

>From 7f0bf5f2a3fe88e8be7b1fe90c21ea5398deb7e2 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Tue, 1 Oct 2024 14:16:20 +0100
Subject: [PATCH] DSE: lift limitation on sizes being non-scalable

As AliasAnalysis now has support for scalable sizes, lift the limitation
on analyzing scalable sizes in DeadStoreElimination.
---
 .../Scalar/DeadStoreElimination.cpp           | 70 +++++++++++++++----
 .../stores-of-existing-values.ll              | 12 +---
 2 files changed, 60 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 5555b5e29cc74f..379550cc10c112 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -271,6 +271,54 @@ static OverwriteResult isMaskedStoreOverwrite(const Instruction *KillingI,
   return OW_Unknown;
 }
 
+// Return an upper bound, in bytes, on the size described by DeadSz: for a
+// scalable size, KnownMin * max(vscale) per the function's vscale_range. If
+// that product overflows, saturate instead of falling back to KnownMin,
+// which would under-estimate the dead access. The result is capped at
+// INT64_MAX so adding a non-negative int64_t offset cannot wrap uint64_t.
+static uint64_t getDeadSizeFactoringVScale(const LocationSize &DeadSz,
+                                           const Function &F) {
+  const uint64_t KnownMin = DeadSz.getValue().getKnownMinValue();
+  if (!DeadSz.isScalable())
+    return KnownMin;
+  APInt Upper = getVScaleRange(&F, 64).getUnsignedMax().umul_sat(
+      APInt(64, KnownMin));
+  return APIntOps::umin(Upper, APInt::getSignedMaxValue(64)).getZExtValue();
+}
+
+// Return {lower bound of the killing size, upper bound of the dead size} in
+// bytes, factoring in vscale_range. Under-estimating the killing size or
+// over-estimating the dead size is conservative; the reverse is unsound.
+static std::pair<uint64_t, uint64_t>
+getSizesFactoringVScale(const LocationSize &KillingSz,
+                        const LocationSize &DeadSz, const Function &F) {
+  const uint64_t KillingMin = KillingSz.getValue().getKnownMinValue();
+  const uint64_t DeadMin = DeadSz.getValue().getKnownMinValue();
+  if (!KillingSz.isScalable() && !DeadSz.isScalable())
+    return {KillingMin, DeadMin};
+  ConstantRange CR = getVScaleRange(&F, 64);
+  bool OverflowL, OverflowU;
+  APInt Lower = CR.getUnsignedMin().umul_ov(APInt(64, KillingMin), OverflowL);
+  APInt Upper = CR.getUnsignedMax().umul_ov(APInt(64, DeadMin), OverflowU);
+  if (KillingSz.isScalable() && DeadSz.isScalable()) {
+    // Both sizes scale with the same runtime vscale, so comparing the known
+    // minimums directly is sound; use the bounds only if neither overflows.
+    // NOTE(review): a known-min interval merged with byte-sized intervals in
+    // isPartialOverwrite's IOL may still be unsound -- confirm.
+    if (!OverflowL && !OverflowU)
+      return {Lower.getZExtValue(), Upper.getZExtValue()};
+    return {KillingMin, DeadMin};
+  }
+  // KnownMin is itself a lower bound, so it is a safe fallback for KillingSz.
+  if (KillingSz.isScalable())
+    return {OverflowL ? KillingMin : Lower.getZExtValue(), DeadMin};
+  // Only DeadSz is scalable: saturate on overflow (KnownMin would under-
+  // estimate it), capped at INT64_MAX so a non-negative Off + Size can't wrap.
+  APInt Max = APInt::getSignedMaxValue(64);
+  return {KillingMin,
+          (OverflowU ? Max : APIntOps::umin(Upper, Max)).getZExtValue()};
+}
+
 /// Return 'OW_Complete' if a store to the 'KillingLoc' location completely
 /// overwrites a store to the 'DeadLoc' location, 'OW_End' if the end of the
 /// 'DeadLoc' location is completely overwritten by 'KillingLoc', 'OW_Begin'
@@ -285,9 +333,11 @@ static OverwriteResult isPartialOverwrite(const MemoryLocation &KillingLoc,
                                           const MemoryLocation &DeadLoc,
                                           int64_t KillingOff, int64_t DeadOff,
                                           Instruction *DeadI,
-                                          InstOverlapIntervalsTy &IOL) {
-  const uint64_t KillingSize = KillingLoc.Size.getValue();
-  const uint64_t DeadSize = DeadLoc.Size.getValue();
+                                          InstOverlapIntervalsTy &IOL,
+                                          const Function &F) {
+  auto [KillingSize, DeadSize] =
+      getSizesFactoringVScale(KillingLoc.Size, DeadLoc.Size, F);
+
   // We may now overlap, although the overlap is not complete. There might also
   // be other incomplete overlaps, and together, they might cover the complete
   // dead store.
@@ -1063,15 +1113,9 @@ struct DSEState {
       return isMaskedStoreOverwrite(KillingI, DeadI, BatchAA);
     }
 
-    const TypeSize KillingSize = KillingLocSize.getValue();
-    const TypeSize DeadSize = DeadLoc.Size.getValue();
-    // Bail on doing Size comparison which depends on AA for now
-    // TODO: Remove AnyScalable once Alias Analysis deal with scalable vectors
-    const bool AnyScalable =
-        DeadSize.isScalable() || KillingLocSize.isScalable();
+    auto [KillingSize, DeadSize] =
+        getSizesFactoringVScale(KillingLocSize, DeadLoc.Size, F);
 
-    if (AnyScalable)
-      return OW_Unknown;
     // Query the alias information
     AliasResult AAR = BatchAA.alias(KillingLoc, DeadLoc);
 
@@ -2171,7 +2215,7 @@ struct DSEState {
 
       const Value *Ptr = Loc.Ptr->stripPointerCasts();
       int64_t DeadStart = 0;
-      uint64_t DeadSize = Loc.Size.getValue();
+      uint64_t DeadSize = getDeadSizeFactoringVScale(Loc.Size, F);
       GetPointerBaseWithConstantOffset(Ptr, DeadStart, DL);
       OverlapIntervalsTy &IntervalMap = OI.second;
       Changed |= tryToShortenEnd(DeadI, IntervalMap, DeadStart, DeadSize);
@@ -2422,7 +2466,7 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
         auto &IOL = IOLs[DeadLocWrapper.DefInst->getParent()];
         OR = isPartialOverwrite(KillingLocWrapper.MemLoc, DeadLocWrapper.MemLoc,
                                 KillingOffset, DeadOffset,
-                                DeadLocWrapper.DefInst, IOL);
+                                DeadLocWrapper.DefInst, IOL, F);
       }
       if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) {
         auto *DeadSI = dyn_cast<StoreInst>(DeadLocWrapper.DefInst);
diff --git a/llvm/test/Transforms/DeadStoreElimination/stores-of-existing-values.ll b/llvm/test/Transforms/DeadStoreElimination/stores-of-existing-values.ll
index 7193bb501c89de..e4061b0ff294af 100644
--- a/llvm/test/Transforms/DeadStoreElimination/stores-of-existing-values.ll
+++ b/llvm/test/Transforms/DeadStoreElimination/stores-of-existing-values.ll
@@ -658,9 +658,7 @@ exit:
 
 define void @scalable_scalable_redundant_store(ptr %ptr) {
 ; CHECK-LABEL: @scalable_scalable_redundant_store(
-; CHECK-NEXT:    [[GEP_PTR_2:%.*]] = getelementptr i64, ptr [[PTR:%.*]], i64 2
-; CHECK-NEXT:    store <vscale x 2 x i64> zeroinitializer, ptr [[GEP_PTR_2]], align 16
-; CHECK-NEXT:    store <vscale x 4 x i64> zeroinitializer, ptr [[PTR]], align 32
+; CHECK-NEXT:    store <vscale x 4 x i64> zeroinitializer, ptr [[PTR:%.*]], align 32
 ; CHECK-NEXT:    ret void
 ;
   %gep.ptr.2 = getelementptr i64, ptr %ptr, i64 2
@@ -697,9 +695,7 @@ define void @scalable_scalable_nonconst_offset_neg(ptr %ptr, i64 %i) {
 
 define void @scalable_fixed_redundant_store(ptr %ptr) vscale_range(1, 2) {
 ; CHECK-LABEL: @scalable_fixed_redundant_store(
-; CHECK-NEXT:    [[GEP_PTR_2:%.*]] = getelementptr i64, ptr [[PTR:%.*]], i64 2
-; CHECK-NEXT:    store <2 x i64> zeroinitializer, ptr [[GEP_PTR_2]], align 16
-; CHECK-NEXT:    store <vscale x 4 x i64> zeroinitializer, ptr [[PTR]], align 32
+; CHECK-NEXT:    store <vscale x 4 x i64> zeroinitializer, ptr [[PTR:%.*]], align 32
 ; CHECK-NEXT:    ret void
 ;
   %gep.ptr.2 = getelementptr i64, ptr %ptr, i64 2
@@ -723,9 +719,7 @@ define void @scalable_fixed_neg(ptr %ptr) vscale_range(1, 2) {
 
 define void @fixed_scalable_redundant_store(ptr %ptr) vscale_range(1, 2) {
 ; CHECK-LABEL: @fixed_scalable_redundant_store(
-; CHECK-NEXT:    [[GEP_PTR_2:%.*]] = getelementptr i64, ptr [[PTR:%.*]], i64 2
-; CHECK-NEXT:    store <vscale x 2 x i64> zeroinitializer, ptr [[GEP_PTR_2]], align 16
-; CHECK-NEXT:    store <8 x i64> zeroinitializer, ptr [[PTR]], align 64
+; CHECK-NEXT:    store <8 x i64> zeroinitializer, ptr [[PTR:%.*]], align 64
 ; CHECK-NEXT:    ret void
 ;
   %gep.ptr.2 = getelementptr i64, ptr %ptr, i64 2



More information about the llvm-commits mailing list