[llvm-branch-commits] [llvm] [IR] Account for byte width in m_PtrAdd (PR #106540)

Sergei Barannikov via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Aug 29 04:51:23 PDT 2024


https://github.com/s-barannikov created https://github.com/llvm/llvm-project/pull/106540

The matcher has few uses so far, so just pass a DataLayout (DL) argument to it.
This follows the approach taken by m_PtrToIntSameSize; I don't see a better way
to deliver the byte width to the matcher.

From 7cca42199cf0bae1b5648ff6b1dcc42677001548 Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <barannikov88 at gmail.com>
Date: Thu, 29 Aug 2024 00:54:20 +0300
Subject: [PATCH] [IR] Account for byte width in m_PtrAdd

The matcher has few uses so far, so just pass a DataLayout (DL) argument to it.
This follows the approach taken by m_PtrToIntSameSize; I don't see a better way
to deliver the byte width to the matcher.
---
 llvm/include/llvm/IR/PatternMatch.h           | 13 ++++++----
 llvm/lib/Analysis/InstructionSimplify.cpp     |  2 +-
 .../InstCombineSimplifyDemanded.cpp           |  7 +++---
 .../InstCombine/InstructionCombining.cpp      |  2 +-
 llvm/unittests/IR/PatternMatch.cpp            | 25 ++++++++++++++-----
 5 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 8c6b7895470b8d..29fe7eed73cff5 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1845,15 +1845,17 @@ struct m_SplatOrPoisonMask {
 };
 
 template <typename PointerOpTy, typename OffsetOpTy> struct PtrAdd_match {
+  const DataLayout &DL;
   PointerOpTy PointerOp;
   OffsetOpTy OffsetOp;
 
-  PtrAdd_match(const PointerOpTy &PointerOp, const OffsetOpTy &OffsetOp)
-      : PointerOp(PointerOp), OffsetOp(OffsetOp) {}
+  PtrAdd_match(const DataLayout &DL, const PointerOpTy &PointerOp,
+               const OffsetOpTy &OffsetOp)
+      : DL(DL), PointerOp(PointerOp), OffsetOp(OffsetOp) {}
 
   template <typename OpTy> bool match(OpTy *V) {
     auto *GEP = dyn_cast<GEPOperator>(V);
-    return GEP && GEP->getSourceElementType()->isIntegerTy(8) &&
+    return GEP && GEP->getSourceElementType()->isIntegerTy(DL.getByteWidth()) &&
            PointerOp.match(GEP->getPointerOperand()) &&
            OffsetOp.match(GEP->idx_begin()->get());
   }
@@ -1895,8 +1897,9 @@ inline auto m_GEP(const OperandTypes &...Ops) {
 /// Matches GEP with i8 source element type
 template <typename PointerOpTy, typename OffsetOpTy>
 inline PtrAdd_match<PointerOpTy, OffsetOpTy>
-m_PtrAdd(const PointerOpTy &PointerOp, const OffsetOpTy &OffsetOp) {
-  return PtrAdd_match<PointerOpTy, OffsetOpTy>(PointerOp, OffsetOp);
+m_PtrAdd(const DataLayout &DL, const PointerOpTy &PointerOp,
+         const OffsetOpTy &OffsetOp) {
+  return PtrAdd_match<PointerOpTy, OffsetOpTy>(DL, PointerOp, OffsetOp);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 32a9f1ab34fb3f..172dd0c8d86982 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -5373,7 +5373,7 @@ static Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
   // ptrtoint (ptradd (Ptr, X - ptrtoint(Ptr))) -> X
   Value *Ptr, *X;
   if (CastOpc == Instruction::PtrToInt &&
-      match(Op, m_PtrAdd(m_Value(Ptr),
+      match(Op, m_PtrAdd(Q.DL, m_Value(Ptr),
                          m_Sub(m_Value(X), m_PtrToInt(m_Deferred(Ptr))))) &&
       X->getType() == Ty && Ty == Q.DL.getIndexType(Ptr->getType()))
     return X;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 9c4d206692fac6..c6484d125e63ef 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -987,9 +987,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I,
         Value *InnerPtr;
         uint64_t GEPIndex;
         uint64_t PtrMaskImmediate;
-        if (match(I, m_Intrinsic<Intrinsic::ptrmask>(
-                         m_PtrAdd(m_Value(InnerPtr), m_ConstantInt(GEPIndex)),
-                         m_ConstantInt(PtrMaskImmediate)))) {
+        if (match(I,
+                  m_Intrinsic<Intrinsic::ptrmask>(
+                      m_PtrAdd(DL, m_Value(InnerPtr), m_ConstantInt(GEPIndex)),
+                      m_ConstantInt(PtrMaskImmediate)))) {
 
           LHSKnown = computeKnownBits(InnerPtr, Depth + 1, I);
           if (!LHSKnown.isZero()) {
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 8a96d1d0fb4c90..37eddcf6c6dc94 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2350,7 +2350,7 @@ static Instruction *canonicalizeGEPOfConstGEPI8(GetElementPtrInst &GEP,
   auto &DL = IC.getDataLayout();
   Value *Base;
   const APInt *C1;
-  if (!match(Src, m_PtrAdd(m_Value(Base), m_APInt(C1))))
+  if (!match(Src, m_PtrAdd(DL, m_Value(Base), m_APInt(C1))))
     return nullptr;
   Value *VarIndex;
   const APInt *C2;
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 379f97fb63139f..d7cb004626a3b6 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -2419,26 +2419,39 @@ TEST_F(PatternMatchTest, ConstExpr) {
   EXPECT_TRUE(match(V, m_ConstantExpr()));
 }
 
-TEST_F(PatternMatchTest, PtrAdd) {
+// PatternMatchTest parametrized by byte width.
+class PatternMatchByteParamTest : public PatternMatchTest,
+                             public ::testing::WithParamInterface<unsigned> {
+public:
+  PatternMatchByteParamTest() {
+    M->setDataLayout("b:" + std::to_string(GetParam()));
+  }
+};
+
+INSTANTIATE_TEST_SUITE_P(ByteWidths, PatternMatchByteParamTest,
+                         ::testing::Values(8, 16, 32));
+
+TEST_P(PatternMatchByteParamTest, PtrAdd) {
+  const DataLayout &DL = M->getDataLayout();
   Type *PtrTy = PointerType::getUnqual(Ctx);
   Type *IdxTy = Type::getInt64Ty(Ctx);
   Constant *Null = Constant::getNullValue(PtrTy);
   Constant *Offset = ConstantInt::get(IdxTy, 42);
   Value *PtrAdd = IRB.CreatePtrAdd(Null, Offset);
   Value *OtherGEP = IRB.CreateGEP(IdxTy, Null, Offset);
-  Value *PtrAddConst =
-      ConstantExpr::getGetElementPtr(Type::getInt8Ty(Ctx), Null, Offset);
+  Value *PtrAddConst = ConstantExpr::getGetElementPtr(
+      Type::getIntNTy(Ctx, DL.getByteWidth()), Null, Offset);
 
   Value *A, *B;
-  EXPECT_TRUE(match(PtrAdd, m_PtrAdd(m_Value(A), m_Value(B))));
+  EXPECT_TRUE(match(PtrAdd, m_PtrAdd(DL, m_Value(A), m_Value(B))));
   EXPECT_EQ(A, Null);
   EXPECT_EQ(B, Offset);
 
-  EXPECT_TRUE(match(PtrAddConst, m_PtrAdd(m_Value(A), m_Value(B))));
+  EXPECT_TRUE(match(PtrAddConst, m_PtrAdd(DL, m_Value(A), m_Value(B))));
   EXPECT_EQ(A, Null);
   EXPECT_EQ(B, Offset);
 
-  EXPECT_FALSE(match(OtherGEP, m_PtrAdd(m_Value(A), m_Value(B))));
+  EXPECT_FALSE(match(OtherGEP, m_PtrAdd(DL, m_Value(A), m_Value(B))));
 }
 
 } // anonymous namespace.



More information about the llvm-branch-commits mailing list