[llvm] ae978ba - [LoopFlatten] Recognise gep+gep (#72515)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 10 06:33:04 PST 2024


Author: John Brawn
Date: 2024-01-10T14:32:59Z
New Revision: ae978baaf6cc5566036b89ceaadcabb47361ba2f

URL: https://github.com/llvm/llvm-project/commit/ae978baaf6cc5566036b89ceaadcabb47361ba2f
DIFF: https://github.com/llvm/llvm-project/commit/ae978baaf6cc5566036b89ceaadcabb47361ba2f.diff

LOG: [LoopFlatten] Recognise gep+gep (#72515)

Now that InstCombine canonicalises add+gep to gep+gep, LoopFlatten needs
to recognise (gep (gep ptr, i*M), j) as a pattern it can optimise.

Added: 
    llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll

Modified: 
    llvm/include/llvm/IR/PatternMatch.h
    llvm/lib/Transforms/Scalar/LoopFlatten.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 447ac0f2aa6139..90d99a6031c834 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1495,6 +1495,36 @@ struct ThreeOps_match {
   }
 };
 
+/// Matches instructions with Opcode and exactly sizeof...(OperandTypes)
+/// operands, where the N-th operand must be matched by the N-th sub-pattern.
+/// Used to build matchers such as m_GEP that take a variable operand list.
+template <unsigned Opcode, typename... OperandTypes> struct AnyOps_match {
+  std::tuple<OperandTypes...> Operands;
+
+  AnyOps_match(const OperandTypes &...Ops) : Operands(Ops...) {}
+
+  // Operand matching is a fold over && across an index sequence: operands are
+  // matched left to right and matching stops at the first failure. Unlike
+  // recursing up to sizeof...(OperandTypes) - 1, which underflows when the
+  // pack is empty, this also accepts an empty pattern list (an instruction
+  // with no operands), because an empty && fold evaluates to true.
+  template <size_t... Idx>
+  bool match_operands(const Instruction *I, std::index_sequence<Idx...>) {
+    return (std::get<Idx>(Operands).match(I->getOperand(Idx)) && ...);
+  }
+
+  template <typename OpTy> bool match(OpTy *V) {
+    // The value ID of an instruction encodes its opcode, so this rejects both
+    // non-instructions and other opcodes before the cast below.
+    if (V->getValueID() != Value::InstructionVal + Opcode)
+      return false;
+    auto *I = cast<Instruction>(V);
+    // The operand count must match exactly; extra or missing operands fail.
+    return I->getNumOperands() == sizeof...(OperandTypes) &&
+           match_operands(I, std::index_sequence_for<OperandTypes...>{});
+  }
+};
+
 /// Matches SelectInst.
 template <typename Cond, typename LHS, typename RHS>
 inline ThreeOps_match<Cond, LHS, RHS, Instruction::Select>
@@ -1611,6 +1641,12 @@ m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp) {
                                                                   PointerOp);
 }
 
+/// Matches a GetElementPtrInst whose operands (base pointer first, then the
+/// indices) are matched in order by Ops; the operand count must be exact.
+template <typename... OperandTypes>
+inline AnyOps_match<Instruction::GetElementPtr, OperandTypes...>
+m_GEP(const OperandTypes &...Ops) { return {Ops...}; }
+
 //===----------------------------------------------------------------------===//
 // Matchers for CastInst classes
 //

diff  --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index eef94636578d83..533cefaf106133 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -207,6 +207,12 @@ struct FlattenInfo {
         match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),
                                   m_Value(MatchedItCount)));
 
+    // Matches the pattern ptr+i*M+j, with the two additions being done via GEP.
+    bool IsGEP = match(U, m_GEP(m_GEP(m_Value(), m_Value(MatchedMul)),
+                                m_Specific(InnerInductionPHI))) &&
+                 match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
+                                           m_Value(MatchedItCount)));
+
     if (!MatchedItCount)
       return false;
 
@@ -224,7 +230,7 @@ struct FlattenInfo {
 
     // Look through extends if the IV has been widened. Don't look through
     // extends if we already looked through a trunc.
-    if (Widened && IsAdd &&
+    if (Widened && (IsAdd || IsGEP) &&
         (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
       assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&
              "Unexpected type mismatch in types after widening");
@@ -236,7 +242,7 @@ struct FlattenInfo {
     LLVM_DEBUG(dbgs() << "Looking for inner trip count: ";
                InnerTripCount->dump());
 
-    if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
+    if ((IsAdd || IsAddTrunc || IsGEP) && MatchedItCount == InnerTripCount) {
       LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n");
       ValidOuterPHIUses.insert(MatchedMul);
       LinearIVUses.insert(U);
@@ -646,33 +652,40 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
   if (OR != OverflowResult::MayOverflow)
     return OR;
 
-  for (Value *V : FI.LinearIVUses) {
-    for (Value *U : V->users()) {
-      if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
-        for (Value *GEPUser : U->users()) {
-          auto *GEPUserInst = cast<Instruction>(GEPUser);
-          if (!isa<LoadInst>(GEPUserInst) &&
-              !(isa<StoreInst>(GEPUserInst) &&
-                GEP == GEPUserInst->getOperand(1)))
-            continue;
-          if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst,
-                                                      FI.InnerLoop))
-            continue;
-          // The IV is used as the operand of a GEP which dominates the loop
-          // latch, and the IV is at least as wide as the address space of the
-          // GEP. In this case, the GEP would wrap around the address space
-          // before the IV increment wraps, which would be UB.
-          if (GEP->isInBounds() &&
-              V->getType()->getIntegerBitWidth() >=
-                  DL.getPointerTypeSizeInBits(GEP->getType())) {
-            LLVM_DEBUG(
-                dbgs() << "use of linear IV would be UB if overflow occurred: ";
-                GEP->dump());
-            return OverflowResult::NeverOverflows;
-          }
-        }
+  auto CheckGEP = [&](GetElementPtrInst *GEP, Value *GEPOperand) {
+    for (Value *GEPUser : GEP->users()) {
+      auto *GEPUserInst = cast<Instruction>(GEPUser);
+      if (!isa<LoadInst>(GEPUserInst) &&
+          !(isa<StoreInst>(GEPUserInst) && GEP == GEPUserInst->getOperand(1)))
+        continue;
+      if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst, FI.InnerLoop))
+        continue;
+      // The IV is used as the operand of a GEP which dominates the loop
+      // latch, and the IV is at least as wide as the address space of the
+      // GEP. In this case, the GEP would wrap around the address space
+      // before the IV increment wraps, which would be UB.
+      if (GEP->isInBounds() &&
+          GEPOperand->getType()->getIntegerBitWidth() >=
+              DL.getPointerTypeSizeInBits(GEP->getType())) {
+        LLVM_DEBUG(
+            dbgs() << "use of linear IV would be UB if overflow occurred: ";
+            GEP->dump());
+        return true;
       }
     }
+    return false;
+  };
+
+  // Check if any IV user is, or is used by, a GEP that would cause UB if the
+  // multiply overflows.
+  for (Value *V : FI.LinearIVUses) {
+    if (auto *GEP = dyn_cast<GetElementPtrInst>(V))
+      if (GEP->getNumIndices() == 1 && CheckGEP(GEP, GEP->getOperand(1)))
+        return OverflowResult::NeverOverflows;
+    for (Value *U : V->users())
+      if (auto *GEP = dyn_cast<GetElementPtrInst>(U))
+        if (CheckGEP(GEP, V))
+          return OverflowResult::NeverOverflows;
   }
 
   return OverflowResult::MayOverflow;
@@ -778,6 +791,18 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
       OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),
                                        "flatten.trunciv");
 
+    if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
+      // Replace the GEP with one that uses OuterValue as the offset.
+      auto *InnerGEP = cast<GetElementPtrInst>(GEP->getOperand(0));
+      Value *Base = InnerGEP->getOperand(0);
+      // When the base of the GEP doesn't dominate the outer induction phi then
+      // we need to insert the new GEP where the old GEP was.
+      if (!DT->dominates(Base, &*Builder.GetInsertPoint()))
+        Builder.SetInsertPoint(cast<Instruction>(V));
+      OuterValue = Builder.CreateGEP(GEP->getSourceElementType(), Base,
+                                     OuterValue, "flatten." + V->getName());
+    }
+
     LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with:      ";
                OuterValue->dump());
     V->replaceAllUsesWith(OuterValue);

diff  --git a/llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll
new file mode 100644
index 00000000000000..f4b8ea97237fe6
--- /dev/null
+++ b/llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll
@@ -0,0 +1,137 @@
+; RUN: opt < %s -S -passes='loop(loop-flatten),verify' -verify-loop-info -verify-dom-info -verify-scev | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
+
+; We should flatten the loops and merge the two geps into one gep on %A.
+; CHECK-LABEL: test1
+define void @test1(i32 %N, ptr %A) {
+entry:
+  %cmp3 = icmp ult i32 0, %N
+  br i1 %cmp3, label %for.outer.preheader, label %for.end
+
+; CHECK-LABEL: for.outer.preheader:
+; CHECK: %flatten.tripcount = mul i32 %N, %N
+for.outer.preheader:
+  br label %for.inner.preheader
+
+; CHECK-LABEL: for.inner.preheader:
+; CHECK: %flatten.arrayidx = getelementptr i32, ptr %A, i32 %i
+for.inner.preheader:
+  %i = phi i32 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
+  br label %for.inner
+
+; CHECK-LABEL: for.inner:
+; CHECK: store i32 0, ptr %flatten.arrayidx, align 4
+; CHECK: br label %for.outer
+for.inner:
+  %j = phi i32 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
+  %mul = mul i32 %i, %N
+  %gep = getelementptr inbounds i32, ptr %A, i32 %mul
+  %arrayidx = getelementptr inbounds i32, ptr %gep, i32 %j
+  store i32 0, ptr %arrayidx, align 4
+  %inc1 = add nuw i32 %j, 1
+  %cmp2 = icmp ult i32 %inc1, %N
+  br i1 %cmp2, label %for.inner, label %for.outer
+
+; CHECK-LABEL: for.outer:
+; CHECK: %cmp1 = icmp ult i32 %inc2, %flatten.tripcount
+for.outer:
+  %inc2 = add i32 %i, 1
+  %cmp1 = icmp ult i32 %inc2, %N
+  br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
+
+; We can flatten, but %ptr is loaded inside the inner loop, so the flattened
+; gep must be inserted after that load rather than in the inner preheader.
+; CHECK-LABEL: test2
+define void @test2(i32 %N, ptr %A) {
+entry:
+  %cmp3 = icmp ult i32 0, %N
+  br i1 %cmp3, label %for.outer.preheader, label %for.end
+
+; CHECK-LABEL: for.outer.preheader:
+; CHECK: %flatten.tripcount = mul i32 %N, %N
+for.outer.preheader:
+  br label %for.inner.preheader
+
+; CHECK-LABEL: for.inner.preheader:
+; CHECK-NOT: getelementptr i32, ptr %ptr, i32 %i
+for.inner.preheader:
+  %i = phi i32 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
+  br label %for.inner
+
+; CHECK-LABEL: for.inner:
+; CHECK: %flatten.arrayidx = getelementptr i32, ptr %ptr, i32 %i
+; CHECK: store i32 0, ptr %flatten.arrayidx, align 4
+; CHECK: br label %for.outer
+for.inner:
+  %j = phi i32 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
+  %ptr = load volatile ptr, ptr %A, align 4
+  %mul = mul i32 %i, %N
+  %gep = getelementptr inbounds i32, ptr %ptr, i32 %mul
+  %arrayidx = getelementptr inbounds i32, ptr %gep, i32 %j
+  store i32 0, ptr %arrayidx, align 4
+  %inc1 = add nuw i32 %j, 1
+  %cmp2 = icmp ult i32 %inc1, %N
+  br i1 %cmp2, label %for.inner, label %for.outer
+
+; CHECK-LABEL: for.outer:
+; CHECK: %cmp1 = icmp ult i32 %inc2, %flatten.tripcount
+for.outer:
+  %inc2 = add i32 %i, 1
+  %cmp1 = icmp ult i32 %inc2, %N
+  br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
+
+; We can't flatten when the gep index type (i16) is narrower than the pointer.
+; CHECK-LABEL: test3
+define void @test3(i16 %N, ptr %A) {
+entry:
+  %cmp3 = icmp ult i16 0, %N
+  br i1 %cmp3, label %for.outer.preheader, label %for.end
+
+for.outer.preheader:
+  br label %for.inner.preheader
+
+; CHECK-LABEL: for.inner.preheader:
+; CHECK-NOT: getelementptr i32, ptr %A, i16 %i
+for.inner.preheader:
+  %i = phi i16 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
+  br label %for.inner
+
+; CHECK-LABEL: for.inner:
+; CHECK-NOT: getelementptr i32, ptr %A, i16 %i
+; CHECK: br i1 %cmp2, label %for.inner, label %for.outer
+for.inner:
+  %j = phi i16 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
+  %mul = mul i16 %i, %N
+  %gep = getelementptr inbounds i32, ptr %A, i16 %mul
+  %arrayidx = getelementptr inbounds i32, ptr %gep, i16 %j
+  store i32 0, ptr %arrayidx, align 4
+  %inc1 = add nuw i16 %j, 1
+  %cmp2 = icmp ult i16 %inc1, %N
+  br i1 %cmp2, label %for.inner, label %for.outer
+
+for.outer:
+  %inc2 = add i16 %i, 1
+  %cmp1 = icmp ult i16 %inc2, %N
+  br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}


        


More information about the llvm-commits mailing list