[llvm] ae978ba - [LoopFlatten] Recognise gep+gep (#72515)
Author: John Brawn
Date: 2024-01-10T14:32:59Z
New Revision: ae978baaf6cc5566036b89ceaadcabb47361ba2f
URL: https://github.com/llvm/llvm-project/commit/ae978baaf6cc5566036b89ceaadcabb47361ba2f
DIFF: https://github.com/llvm/llvm-project/commit/ae978baaf6cc5566036b89ceaadcabb47361ba2f.diff
LOG: [LoopFlatten] Recognise gep+gep (#72515)
Now that InstCombine canonicalises add+gep to gep+gep, LoopFlatten needs
to recognise (gep (gep ptr (i*M)), j) as something it can optimise.
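For context, a loop nest of roughly the following shape ends up computing the
inner store address as (gep (gep A (i*N)) j) after InstCombine's add+gep to
gep+gep canonicalisation. This is only a sketch in C-style C++; the function
and variable names are illustrative, not taken from the patch (test1 in the
new test file exercises the equivalent IR):

// Sketch only: the inner store addresses A[i*N + j]. After InstCombine's
// canonicalisation the address is computed as gep(gep(A, i*N), j) rather
// than as a single gep whose index is an add, which is the form LoopFlatten
// now has to recognise.
void zero2d(int *A, unsigned N) {
  for (unsigned i = 0; i < N; ++i)
    for (unsigned j = 0; j < N; ++j)
      A[i * N + j] = 0;
}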
Added:
llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll
Modified:
llvm/include/llvm/IR/PatternMatch.h
llvm/lib/Transforms/Scalar/LoopFlatten.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 447ac0f2aa6139..90d99a6031c834 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1495,6 +1495,36 @@ struct ThreeOps_match {
}
};
+/// Matches instructions with Opcode and any number of operands
+template <unsigned Opcode, typename... OperandTypes> struct AnyOps_match {
+ std::tuple<OperandTypes...> Operands;
+
+ AnyOps_match(const OperandTypes &...Ops) : Operands(Ops...) {}
+
+ // Operand matching works by recursively calling match_operands, matching the
+ // operands left to right. The first version is called for each operand but
+ // the last, for which the second version is called. The second version of
+ // match_operands is also used to match each individual operand.
+ template <int Idx, int Last>
+ std::enable_if_t<Idx != Last, bool> match_operands(const Instruction *I) {
+ return match_operands<Idx, Idx>(I) && match_operands<Idx + 1, Last>(I);
+ }
+
+ template <int Idx, int Last>
+ std::enable_if_t<Idx == Last, bool> match_operands(const Instruction *I) {
+ return std::get<Idx>(Operands).match(I->getOperand(Idx));
+ }
+
+ template <typename OpTy> bool match(OpTy *V) {
+ if (V->getValueID() == Value::InstructionVal + Opcode) {
+ auto *I = cast<Instruction>(V);
+ return I->getNumOperands() == sizeof...(OperandTypes) &&
+ match_operands<0, sizeof...(OperandTypes) - 1>(I);
+ }
+ return false;
+ }
+};
+
/// Matches SelectInst.
template <typename Cond, typename LHS, typename RHS>
inline ThreeOps_match<Cond, LHS, RHS, Instruction::Select>
@@ -1611,6 +1641,12 @@ m_Store(const ValueOpTy &ValueOp, const PointerOpTy &PointerOp) {
PointerOp);
}
+/// Matches GetElementPtrInst.
+template <typename... OperandTypes>
+inline auto m_GEP(const OperandTypes &...Ops) {
+ return AnyOps_match<Instruction::GetElementPtr, OperandTypes...>(Ops...);
+}
+
//===----------------------------------------------------------------------===//
// Matchers for CastInst classes
//
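A rough usage sketch of the new matcher (not part of the patch; the helper
name is hypothetical): because AnyOps_match requires the instruction's operand
count to equal the number of sub-matchers, m_GEP with two arguments only
matches a GEP with a base pointer and exactly one index.

#include "llvm/IR/PatternMatch.h"

// Hypothetical helper, shown only to illustrate m_GEP: match a GEP with
// exactly two operands (base pointer plus a single index) and capture the
// index value.
static bool matchSingleIndexGEP(llvm::Value *V, llvm::Value *&Idx) {
  using namespace llvm::PatternMatch;
  return match(V, m_GEP(m_Value(), m_Value(Idx)));
}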
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index eef94636578d83..533cefaf106133 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -207,6 +207,12 @@ struct FlattenInfo {
match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(OuterInductionPHI)),
m_Value(MatchedItCount)));
+ // Matches the pattern ptr+i*M+j, with the two additions being done via GEP.
+ bool IsGEP = match(U, m_GEP(m_GEP(m_Value(), m_Value(MatchedMul)),
+ m_Specific(InnerInductionPHI))) &&
+ match(MatchedMul, m_c_Mul(m_Specific(OuterInductionPHI),
+ m_Value(MatchedItCount)));
+
if (!MatchedItCount)
return false;
@@ -224,7 +230,7 @@ struct FlattenInfo {
// Look through extends if the IV has been widened. Don't look through
// extends if we already looked through a trunc.
- if (Widened && IsAdd &&
+ if (Widened && (IsAdd || IsGEP) &&
(isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
assert(MatchedItCount->getType() == InnerInductionPHI->getType() &&
"Unexpected type mismatch in types after widening");
@@ -236,7 +242,7 @@ struct FlattenInfo {
LLVM_DEBUG(dbgs() << "Looking for inner trip count: ";
InnerTripCount->dump());
- if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
+ if ((IsAdd || IsAddTrunc || IsGEP) && MatchedItCount == InnerTripCount) {
LLVM_DEBUG(dbgs() << "Found. This sse is optimisable\n");
ValidOuterPHIUses.insert(MatchedMul);
LinearIVUses.insert(U);
@@ -646,33 +652,40 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
if (OR != OverflowResult::MayOverflow)
return OR;
- for (Value *V : FI.LinearIVUses) {
- for (Value *U : V->users()) {
- if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
- for (Value *GEPUser : U->users()) {
- auto *GEPUserInst = cast<Instruction>(GEPUser);
- if (!isa<LoadInst>(GEPUserInst) &&
- !(isa<StoreInst>(GEPUserInst) &&
- GEP == GEPUserInst->getOperand(1)))
- continue;
- if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst,
- FI.InnerLoop))
- continue;
- // The IV is used as the operand of a GEP which dominates the loop
- // latch, and the IV is at least as wide as the address space of the
- // GEP. In this case, the GEP would wrap around the address space
- // before the IV increment wraps, which would be UB.
- if (GEP->isInBounds() &&
- V->getType()->getIntegerBitWidth() >=
- DL.getPointerTypeSizeInBits(GEP->getType())) {
- LLVM_DEBUG(
- dbgs() << "use of linear IV would be UB if overflow occurred: ";
- GEP->dump());
- return OverflowResult::NeverOverflows;
- }
- }
+ auto CheckGEP = [&](GetElementPtrInst *GEP, Value *GEPOperand) {
+ for (Value *GEPUser : GEP->users()) {
+ auto *GEPUserInst = cast<Instruction>(GEPUser);
+ if (!isa<LoadInst>(GEPUserInst) &&
+ !(isa<StoreInst>(GEPUserInst) && GEP == GEPUserInst->getOperand(1)))
+ continue;
+ if (!isGuaranteedToExecuteForEveryIteration(GEPUserInst, FI.InnerLoop))
+ continue;
+ // The IV is used as the operand of a GEP which dominates the loop
+ // latch, and the IV is at least as wide as the address space of the
+ // GEP. In this case, the GEP would wrap around the address space
+ // before the IV increment wraps, which would be UB.
+ if (GEP->isInBounds() &&
+ GEPOperand->getType()->getIntegerBitWidth() >=
+ DL.getPointerTypeSizeInBits(GEP->getType())) {
+ LLVM_DEBUG(
+ dbgs() << "use of linear IV would be UB if overflow occurred: ";
+ GEP->dump());
+ return true;
}
}
+ return false;
+ };
+
+ // Check if any IV user is, or is used by, a GEP that would cause UB if the
+ // multiply overflows.
+ for (Value *V : FI.LinearIVUses) {
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(V))
+ if (GEP->getNumIndices() == 1 && CheckGEP(GEP, GEP->getOperand(1)))
+ return OverflowResult::NeverOverflows;
+ for (Value *U : V->users())
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(U))
+ if (CheckGEP(GEP, V))
+ return OverflowResult::NeverOverflows;
}
return OverflowResult::MayOverflow;
@@ -778,6 +791,18 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),
"flatten.trunciv");
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
+ // Replace the GEP with one that uses OuterValue as the offset.
+ auto *InnerGEP = cast<GetElementPtrInst>(GEP->getOperand(0));
+ Value *Base = InnerGEP->getOperand(0);
+ // When the base of the GEP doesn't dominate the outer induction phi then
+ // we need to insert the new GEP where the old GEP was.
+ if (!DT->dominates(Base, &*Builder.GetInsertPoint()))
+ Builder.SetInsertPoint(cast<Instruction>(V));
+ OuterValue = Builder.CreateGEP(GEP->getSourceElementType(), Base,
+ OuterValue, "flatten." + V->getName());
+ }
+
LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with: ";
OuterValue->dump());
V->replaceAllUsesWith(OuterValue);
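The DoFlattenLoopPair change above can be summarised with the following
sketch (illustrative only; the function name and signature are hypothetical,
not part of the patch). For a matched gep+gep use, the outer gep is replaced
by a single gep of the original base indexed by the flattened induction value;
when the base does not dominate the usual insert point (test2 below loads it
inside the inner loop), the new gep is created at the old gep's position.

#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

// Illustrative sketch of the rewrite for a matched gep+gep use:
// gep(gep(Base, i*M), j) becomes a single gep(Base, FlattenedIV).
static llvm::Value *rewriteFlattenedGEP(llvm::IRBuilder<> &Builder,
                                        llvm::GetElementPtrInst *OuterGEP,
                                        llvm::Value *FlattenedIV,
                                        llvm::DominatorTree &DT) {
  auto *InnerGEP = llvm::cast<llvm::GetElementPtrInst>(OuterGEP->getOperand(0));
  llvm::Value *Base = InnerGEP->getOperand(0);
  // If the base is defined inside the loop (e.g. reloaded every iteration, as
  // in test2 below), it does not dominate the preheader insert point, so emit
  // the new GEP where the old one was.
  if (!DT.dominates(Base, &*Builder.GetInsertPoint()))
    Builder.SetInsertPoint(OuterGEP);
  return Builder.CreateGEP(OuterGEP->getSourceElementType(), Base, FlattenedIV,
                           "flatten." + OuterGEP->getName());
}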
diff --git a/llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll
new file mode 100644
index 00000000000000..f4b8ea97237fe6
--- /dev/null
+++ b/llvm/test/Transforms/LoopFlatten/loop-flatten-gep.ll
@@ -0,0 +1,137 @@
+; RUN: opt < %s -S -passes='loop(loop-flatten),verify' -verify-loop-info -verify-dom-info -verify-scev | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
+
+; We should be able to flatten the loops and turn the two geps into one.
+; CHECK-LABEL: test1
+define void @test1(i32 %N, ptr %A) {
+entry:
+ %cmp3 = icmp ult i32 0, %N
+ br i1 %cmp3, label %for.outer.preheader, label %for.end
+
+; CHECK-LABEL: for.outer.preheader:
+; CHECK: %flatten.tripcount = mul i32 %N, %N
+for.outer.preheader:
+ br label %for.inner.preheader
+
+; CHECK-LABEL: for.inner.preheader:
+; CHECK: %flatten.arrayidx = getelementptr i32, ptr %A, i32 %i
+for.inner.preheader:
+ %i = phi i32 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
+ br label %for.inner
+
+; CHECK-LABEL: for.inner:
+; CHECK: store i32 0, ptr %flatten.arrayidx, align 4
+; CHECK: br label %for.outer
+for.inner:
+ %j = phi i32 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
+ %mul = mul i32 %i, %N
+ %gep = getelementptr inbounds i32, ptr %A, i32 %mul
+ %arrayidx = getelementptr inbounds i32, ptr %gep, i32 %j
+ store i32 0, ptr %arrayidx, align 4
+ %inc1 = add nuw i32 %j, 1
+ %cmp2 = icmp ult i32 %inc1, %N
+ br i1 %cmp2, label %for.inner, label %for.outer
+
+; CHECK-LABEL: for.outer:
+; CHECK: %cmp1 = icmp ult i32 %inc2, %flatten.tripcount
+for.outer:
+ %inc2 = add i32 %i, 1
+ %cmp1 = icmp ult i32 %inc2, %N
+ br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit
+
+for.end.loopexit:
+ br label %for.end
+
+for.end:
+ ret void
+}
+
+; We can flatten, but the flattened gep has to be inserted after the load it
+; depends on.
+; CHECK-LABEL: test2
+define void @test2(i32 %N, ptr %A) {
+entry:
+ %cmp3 = icmp ult i32 0, %N
+ br i1 %cmp3, label %for.outer.preheader, label %for.end
+
+; CHECK-LABEL: for.outer.preheader:
+; CHECK: %flatten.tripcount = mul i32 %N, %N
+for.outer.preheader:
+ br label %for.inner.preheader
+
+; CHECK-LABEL: for.inner.preheader:
+; CHECK-NOT: getelementptr i32, ptr %ptr, i32 %i
+for.inner.preheader:
+ %i = phi i32 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
+ br label %for.inner
+
+; CHECK-LABEL: for.inner:
+; CHECK: %flatten.arrayidx = getelementptr i32, ptr %ptr, i32 %i
+; CHECK: store i32 0, ptr %flatten.arrayidx, align 4
+; CHECK: br label %for.outer
+for.inner:
+ %j = phi i32 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
+ %ptr = load volatile ptr, ptr %A, align 4
+ %mul = mul i32 %i, %N
+ %gep = getelementptr inbounds i32, ptr %ptr, i32 %mul
+ %arrayidx = getelementptr inbounds i32, ptr %gep, i32 %j
+ store i32 0, ptr %arrayidx, align 4
+ %inc1 = add nuw i32 %j, 1
+ %cmp2 = icmp ult i32 %inc1, %N
+ br i1 %cmp2, label %for.inner, label %for.outer
+
+; CHECK-LABEL: for.outer:
+; CHECK: %cmp1 = icmp ult i32 %inc2, %flatten.tripcount
+for.outer:
+ %inc2 = add i32 %i, 1
+ %cmp1 = icmp ult i32 %inc2, %N
+ br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit
+
+for.end.loopexit:
+ br label %for.end
+
+for.end:
+ ret void
+}
+
+; We can't flatten if the gep offset is smaller than the pointer size.
+; CHECK-LABEL: test3
+define void @test3(i16 %N, ptr %A) {
+entry:
+ %cmp3 = icmp ult i16 0, %N
+ br i1 %cmp3, label %for.outer.preheader, label %for.end
+
+for.outer.preheader:
+ br label %for.inner.preheader
+
+; CHECK-LABEL: for.inner.preheader:
+; CHECK-NOT: getelementptr i32, ptr %A, i16 %i
+for.inner.preheader:
+ %i = phi i16 [ 0, %for.outer.preheader ], [ %inc2, %for.outer ]
+ br label %for.inner
+
+; CHECK-LABEL: for.inner:
+; CHECK-NOT: getelementptr i32, ptr %A, i16 %i
+; CHECK: br i1 %cmp2, label %for.inner, label %for.outer
+for.inner:
+ %j = phi i16 [ 0, %for.inner.preheader ], [ %inc1, %for.inner ]
+ %mul = mul i16 %i, %N
+ %gep = getelementptr inbounds i32, ptr %A, i16 %mul
+ %arrayidx = getelementptr inbounds i32, ptr %gep, i16 %j
+ store i32 0, ptr %arrayidx, align 4
+ %inc1 = add nuw i16 %j, 1
+ %cmp2 = icmp ult i16 %inc1, %N
+ br i1 %cmp2, label %for.inner, label %for.outer
+
+for.outer:
+ %inc2 = add i16 %i, 1
+ %cmp1 = icmp ult i16 %inc2, %N
+ br i1 %cmp1, label %for.inner.preheader, label %for.end.loopexit
+
+for.end.loopexit:
+ br label %for.end
+
+for.end:
+ ret void
+}