[llvm] 18e1373 - [BPF] Undo transformation for LICM.cpp:hoistMinMax()
Eduard Zingerman via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 11 12:32:37 PDT 2023
Author: Eduard Zingerman
Date: 2023-07-11T22:30:34+03:00
New Revision: 18e13739b8c02e0b82e3e587c33e8731c8a46b0a
URL: https://github.com/llvm/llvm-project/commit/18e13739b8c02e0b82e3e587c33e8731c8a46b0a
DIFF: https://github.com/llvm/llvm-project/commit/18e13739b8c02e0b82e3e587c33e8731c8a46b0a.diff
LOG: [BPF] Undo transformation for LICM.cpp:hoistMinMax()
Extended BPFCheckAndAdjustIR pass with sinkMinMax() transformation
that undoes LICM hoistMinMax pass.
The undo transformation converts the following patterns:
x < min(a, b) -> x < a && x < b
x > min(a, b) -> x > a || x > b
x < max(a, b) -> x < a || x < b
x > max(a, b) -> x > a && x > b
Where 'a' or 'b' is a constant.
Also supports `sext min(...) ...` and `zext min(...) ...`.
~~~
This was previously committed as 09feee559a29 and reverted in
0bf9bfeacc8c because of the testbot memory leak report:
https://lab.llvm.org/buildbot/#/builders/5/builds/34931
The memory leak issue was caused by an incorrect instruction removal
sequence in sinkMinMaxInBB(). The two-step removal
    I->dropAllReferences();
    I->removeFromParent();
was fixed to a single call:
    I->eraseFromParent();
Differential Revision: https://reviews.llvm.org/D147990
Added:
llvm/test/CodeGen/BPF/sink-min-max.ll
Modified:
llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
index 6b74e56d6b3e9b..a3616ae7ebabef 100644
--- a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
+++ b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
@@ -18,8 +18,10 @@
#include "BPF.h"
#include "BPFCORE.h"
#include "BPFTargetMachine.h"
+#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
@@ -41,12 +43,14 @@ class BPFCheckAndAdjustIR final : public ModulePass {
public:
static char ID;
BPFCheckAndAdjustIR() : ModulePass(ID) {}
+ virtual void getAnalysisUsage(AnalysisUsage &AU) const override;
private:
void checkIR(Module &M);
bool adjustIR(Module &M);
bool removePassThroughBuiltin(Module &M);
bool removeCompareBuiltin(Module &M);
+ bool sinkMinMax(Module &M);
};
} // End anonymous namespace
@@ -161,9 +165,206 @@ bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
return Changed;
}
+// Describes one `icmp` that has a min/max intrinsic call (possibly wrapped
+// in a zext/sext) as one of its operands. The first three fields are
+// supplied by the constructor; the remaining ones start out null and are
+// filled in by the matcher in sinkMinMaxInBB().
+struct MinMaxSinkInfo {
+  // The comparison that is a candidate for rewriting.
+  ICmpInst *ICmp;
+  // The icmp operand that is not the min/max expression.
+  Value *Other;
+  // Predicate to use when `Other` is written on the left-hand side, i.e.
+  // the comparison reads `Other <Predicate> min/max`.
+  ICmpInst::Predicate Predicate;
+  // The matched smin/umin/smax/umax call.
+  CallInst *MinMax = nullptr;
+  // Set when the min/max result reaches the icmp through a zext.
+  ZExtInst *ZExt = nullptr;
+  // Set when the min/max result reaches the icmp through a sext.
+  SExtInst *SExt = nullptr;
+
+  MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
+      : ICmp(ICmp), Other(Other), Predicate(Predicate) {}
+};
+
+// Rewrites comparisons against min/max intrinsics inside BB.
+//
+// Every relational `icmp` in BB that has exactly one operand of the form
+//   (fn %a %b) or (ext (fn %a %b))
+// where
+//   ext := sext | zext
+//   fn  := smin | umin | smax | umax
+// and for which Filter(<the fn call>) returns true is replaced by two
+// comparisons against %a and %b (x is the other icmp operand):
+//
+//   x < min(a, b) -> x < a && x < b
+//   x > min(a, b) -> x > a || x > b
+//   x < max(a, b) -> x < a || x < b
+//   x > max(a, b) -> x > a && x > b
+//
+// Candidates whose predicate/min-max/extension combination does not
+// preserve value ordering are left untouched (see the check in the rewrite
+// loop). The original icmp, the extension and the min/max call are erased
+// once they have no remaining uses.
+//
+// Returns true if BB was modified.
+static bool sinkMinMaxInBB(BasicBlock &BB,
+                           const std::function<bool(Instruction *)> &Filter) {
+  // Check if V is:
+  //   (fn %a %b) or (ext (fn %a %b))
+  // Where:
+  //   ext := sext | zext
+  //   fn  := smin | umin | smax | umax
+  // On success records the call and the extension (if any) in Info.
+  // Filter is captured by reference: the lambda never outlives this
+  // function, so there is no need to copy the std::function.
+  auto IsMinMaxCall = [&](Value *V, MinMaxSinkInfo &Info) {
+    if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
+      V = ZExt->getOperand(0);
+      Info.ZExt = ZExt;
+    } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
+      V = SExt->getOperand(0);
+      Info.SExt = SExt;
+    }
+
+    auto *Call = dyn_cast<CallInst>(V);
+    if (!Call)
+      return false;
+
+    auto *Called = dyn_cast<Function>(Call->getCalledOperand());
+    if (!Called)
+      return false;
+
+    switch (Called->getIntrinsicID()) {
+    case Intrinsic::smin:
+    case Intrinsic::umin:
+    case Intrinsic::smax:
+    case Intrinsic::umax:
+      break;
+    default:
+      return false;
+    }
+
+    if (!Filter(Call))
+      return false;
+
+    Info.MinMax = Call;
+
+    return true;
+  };
+
+  // Re-applies to V the same extension that was applied to the min/max
+  // result, so that the new comparisons have matching operand types.
+  auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
+                             MinMaxSinkInfo &Info) {
+    if (Info.SExt) {
+      if (Info.SExt->getType() == V->getType())
+        return V;
+      return Builder.CreateSExt(V, Info.SExt->getType());
+    }
+    if (Info.ZExt) {
+      if (Info.ZExt->getType() == V->getType())
+        return V;
+      return Builder.CreateZExt(V, Info.ZExt->getType());
+    }
+    return V;
+  };
+
+  bool Changed = false;
+  SmallVector<MinMaxSinkInfo, 2> SinkList;
+
+  // Check BB for instructions like:
+  //   insn := (icmp %a (fn ...)) | (icmp (fn ...) %a)
+  //
+  // Where:
+  //   fn  := min | max | (ext (min ...)) | (ext (max ...))
+  //   ext := sext | zext
+  //
+  // Put such instructions to SinkList. Collecting first and rewriting
+  // afterwards avoids modifying BB while iterating over it.
+  for (Instruction &I : BB) {
+    ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
+    if (!ICmp)
+      continue;
+    if (!ICmp->isRelational())
+      continue;
+    // `First` assumes min/max is operand 0, so the predicate is swapped to
+    // keep the non-min/max operand on the left; `Second` assumes min/max
+    // is operand 1 and keeps the predicate as is.
+    MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
+                         ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
+    MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
+    bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
+    bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
+    // Only handle the case when exactly one operand is a min/max.
+    if (FirstMinMax == SecondMinMax)
+      continue;
+    SinkList.push_back(FirstMinMax ? First : Second);
+  }
+
+  // Iterate SinkList and replace each (icmp ...) with corresponding
+  // `x < a && x < b` or similar expression.
+  for (auto &Info : SinkList) {
+    ICmpInst *ICmp = Info.ICmp;
+    CallInst *MinMax = Info.MinMax;
+    Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
+    ICmpInst::Predicate P = Info.Predicate;
+
+    // The rewrite is sound only if the min/max, the optional extension and
+    // the comparison all agree on how values are ordered:
+    //  - the predicate must have the same signedness as the min/max, e.g.
+    //    `5 <u smin(-1, 1)` is true while `5 <u -1 && 5 <u 1` is false;
+    //  - zext does not preserve signed order (-1 becomes a large positive
+    //    value), so a signed min/max can't be distributed over a zext.
+    // sext preserves both signed and unsigned order, so it needs no extra
+    // restriction.
+    bool IsSignedMinMax = IID == Intrinsic::smin || IID == Intrinsic::smax;
+    if (ICmpInst::isSigned(P) != IsSignedMinMax)
+      continue;
+    if (Info.ZExt && IsSignedMinMax)
+      continue;
+
+    IRBuilder<> Builder(ICmp);
+    Value *X = Info.Other;
+    Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
+    Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
+    bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
+    bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
+    bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
+    bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
+    assert(IsMin ^ IsMax);
+    assert(IsLess ^ IsGreater);
+
+    Value *Replacement;
+    Value *LHS = Builder.CreateICmp(P, X, A);
+    Value *RHS = Builder.CreateICmp(P, X, B);
+    if ((IsLess && IsMin) || (IsGreater && IsMax))
+      // x < min(a, b) -> x < a && x < b
+      // x > max(a, b) -> x > a && x > b
+      Replacement = Builder.CreateLogicalAnd(LHS, RHS);
+    else
+      // x > min(a, b) -> x > a || x > b
+      // x < max(a, b) -> x < a || x < b
+      Replacement = Builder.CreateLogicalOr(LHS, RHS);
+
+    ICmp->replaceAllUsesWith(Replacement);
+
+    // Erase in use-before-def order so that each instruction becomes
+    // unused before it is inspected. Instructions that still have other
+    // users (e.g. a min/max shared by several comparisons) are kept.
+    Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
+    for (Instruction *I : ToRemove)
+      if (I && I->use_empty())
+        I->eraseFromParent();
+
+    Changed = true;
+  }
+
+  return Changed;
+}
+
+// Splits comparisons against min/max back into two comparisons:
+//
+//   x < min(a, b) -> x < a && x < b
+//   x > min(a, b) -> x > a || x > b
+//   x < max(a, b) -> x < a || x < b
+//   x > max(a, b) -> x > a && x > b
+//
+// LICM.cpp:hoistMinMax() produces the left-hand side forms, and those
+// might lead to BPF verification failures for older kernels.
+//
+// The verifier rejects a program when:
+// - the RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
+// - the verifier can recognize RHS as a constant scalar in some context;
+// - the verifier can't recognize RHS1 as a constant scalar in the same
+//   context.
+//
+// "Constant scalar" here does not mean a compile time constant: it is a
+// register holding a scalar value that is known to the verifier at some
+// point in time during abstract interpretation.
+//
+// To keep "collateral" changes small, every block that belongs to a loop
+// is visited and only those min/max calls are considered whose innermost
+// loop differs from the innermost loop of the block holding the icmp.
+//
+// See also:
+//   https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
+//
+// Returns true if any function in M was modified.
+bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
+  bool Modified = false;
+
+  for (Function &Fn : M) {
+    // Declarations have no body, hence no loops to inspect.
+    if (Fn.isDeclaration())
+      continue;
+
+    LoopInfo &Loops = getAnalysis<LoopInfoWrapperPass>(Fn).getLoopInfo();
+    for (Loop *TopLevel : Loops) {
+      for (BasicBlock *Block : TopLevel->blocks()) {
+        Loop *Innermost = Loops.getLoopFor(Block);
+        // Accept only instructions that do not share the innermost loop
+        // with the block being processed.
+        auto DefinedInOtherLoop = [&](Instruction *Candidate) {
+          Loop *CandidateLoop = Loops.getLoopFor(Candidate->getParent());
+          return CandidateLoop != Innermost;
+        };
+        if (sinkMinMaxInBB(*Block, DefinedInOtherLoop))
+          Modified = true;
+      }
+    }
+  }
+
+  return Modified;
+}
+
+void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const { // Declares the analyses this pass requires.
+ AU.addRequired<LoopInfoWrapperPass>(); // sinkMinMax() fetches LoopInfo via getAnalysis<LoopInfoWrapperPass>(F).
+}
+
bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
bool Changed = removePassThroughBuiltin(M);
Changed = removeCompareBuiltin(M) || Changed;
+ Changed = sinkMinMax(M) || Changed; // Call is the left operand of ||, so it always runs regardless of Changed.
return Changed;
}
diff --git a/llvm/test/CodeGen/BPF/sink-min-max.ll b/llvm/test/CodeGen/BPF/sink-min-max.ll
new file mode 100644
index 00000000000000..5ee080839985d2
--- /dev/null
+++ b/llvm/test/CodeGen/BPF/sink-min-max.ll
@@ -0,0 +1,258 @@
+; RUN: opt --bpf-check-and-opt-ir -S -mtriple=bpf-pc-linux %s | FileCheck %s
+
+; Test plan:
+; @test1: x < umin(i64 a, i64 b)
+; @test2: x < umax(i64 a, i64 b)
+; @test3: x >= umin(i64 a, i64 b)
+; @test4: x >= umax(i64 a, i64 b)
+; @test5: umin(i64 a, i64 b) >= x
+; @test6: x < smin(i64 a, i64 b)
+; @test7: x < umin(i32 a, i32 b)
+; @test8: x < zext i64 umin(i32 a, i32 b)
+; @test9: x < sext i64 umin(i32 a, i32 b)
+; @test10: check that umin belonging to the same loop is not touched
+; @test11: check that nested loops are processed
+
+define i32 @test1(i64 %a, i64 %b, i64 %x) {
+entry:
+ %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b) ; umin is computed in entry, before the loop
+ br label %loop
+loop:
+ %cmp = icmp ult i64 %x, %min ; x <u umin(a, b): CHECK lines expect two icmp ult joined by a logical and
+ br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test1
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = icmp ult i64 %x, %a
+; CHECK-NEXT: %1 = icmp ult i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test2(i64 %a, i64 %b, i64 %x) {
+entry:
+ %max = tail call i64 @llvm.umax.i64(i64 %a, i64 %b) ; umax is computed in entry, before the loop
+ br label %loop
+loop:
+ %cmp = icmp ult i64 %x, %max ; x <u umax(a, b): CHECK lines expect two icmp ult joined by a logical or
+ br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test2
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = icmp ult i64 %x, %a
+; CHECK-NEXT: %1 = icmp ult i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 true, i1 %1
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test3(i64 %a, i64 %b, i64 %x) {
+entry:
+ %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b) ; umin is computed in entry, before the loop
+ br label %loop
+loop:
+ %cmp = icmp uge i64 %x, %min ; x >=u umin(a, b): CHECK lines expect two icmp uge joined by a logical or
+ br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test3
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = icmp uge i64 %x, %a
+; CHECK-NEXT: %1 = icmp uge i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 true, i1 %1
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test4(i64 %a, i64 %b, i64 %x) {
+entry:
+ %max = tail call i64 @llvm.umax.i64(i64 %a, i64 %b) ; umax is computed in entry, before the loop
+ br label %loop
+loop:
+ %cmp = icmp uge i64 %x, %max ; x >=u umax(a, b): CHECK lines expect two icmp uge joined by a logical and
+ br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test4
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = icmp uge i64 %x, %a
+; CHECK-NEXT: %1 = icmp uge i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test5(i64 %a, i64 %b, i64 %x) {
+entry:
+ %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b) ; umin is computed in entry, before the loop
+ br label %loop
+loop:
+ %cmp = icmp uge i64 %min, %x ; min/max is the LEFT operand: CHECK lines expect the swapped predicate (ule) with %x on the left
+ br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test5
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK: %0 = icmp ule i64 %x, %a
+; CHECK-NEXT: %1 = icmp ule i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test6(i64 %a, i64 %b, i64 %x) {
+entry:
+ %min = tail call i64 @llvm.smin.i64(i64 %a, i64 %b) ; signed min, computed in entry before the loop
+ br label %loop
+loop:
+ %cmp = icmp slt i64 %x, %min ; signed predicate with signed min: CHECK lines expect two icmp slt joined by a logical and
+ br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test6
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK: %0 = icmp slt i64 %x, %a
+; CHECK-NEXT: %1 = icmp slt i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test7(i32 %a, i32 %b, i32 %x) {
+entry:
+ %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b) ; same shape as @test1 but with i32 operands
+ br label %loop
+loop:
+ %cmp = icmp ult i32 %x, %min ; CHECK lines expect two i32 icmp ult joined by a logical and
+ br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test7
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK: %0 = icmp ult i32 %x, %a
+; CHECK-NEXT: %1 = icmp ult i32 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %loop, label %ret
+
+define i32 @test8(i32 %a, i32 %b, i64 %x) {
+entry:
+ %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b) ; i32 umin, computed in entry before the loop
+ br label %loop
+loop:
+ %ext = zext i32 %min to i64 ; min result reaches the icmp through a zext
+ %cmp = icmp ult i64 %x, %ext ; CHECK lines expect %a and %b to be zero-extended separately, then compared
+ br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test8
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = zext i32 %a to i64
+; CHECK-NEXT: %1 = zext i32 %b to i64
+; CHECK-NEXT: %2 = icmp ult i64 %x, %0
+; CHECK-NEXT: %3 = icmp ult i64 %x, %1
+; CHECK-NEXT: %4 = select i1 %2, i1 %3, i1 false
+; CHECK-NEXT: br i1 %4, label %loop, label %ret
+
+define i32 @test9(i32 %a, i32 %b, i64 %x) {
+entry:
+ %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b) ; i32 umin, computed in entry before the loop
+ br label %loop
+loop:
+ %ext = sext i32 %min to i64 ; min result reaches the icmp through a sext
+ %cmp = icmp ult i64 %x, %ext ; CHECK lines expect %a and %b to be sign-extended separately, then compared
+ br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test9
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %0 = sext i32 %a to i64
+; CHECK-NEXT: %1 = sext i32 %b to i64
+; CHECK-NEXT: %2 = icmp ult i64 %x, %0
+; CHECK-NEXT: %3 = icmp ult i64 %x, %1
+; CHECK-NEXT: %4 = select i1 %2, i1 %3, i1 false
+; CHECK-NEXT: br i1 %4, label %loop, label %ret
+
+; umin within the loop body is unchanged
+define i32 @test10(i64 %a, i64 %b, i64 %x) {
+entry:
+ br label %loop
+loop:
+ %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b) ; umin sits in the same loop block as the icmp
+ %cmp = icmp ult i64 %x, %min ; CHECK lines expect %min and %cmp to stay exactly as written
+ br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK: @test10
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+; CHECK-NEXT: %cmp = icmp ult i64 %x, %min
+; CHECK-NEXT: br i1 %cmp, label %loop, label %ret
+
+; umin from outer loop body is processed
+define i32 @test11(i64 %a, i64 %b, i64 %x) {
+entry:
+ br label %loop
+
+loop:
+ %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b) ; umin lives in the outer loop
+ br label %nested.loop
+nested.loop:
+ %cmp = icmp ult i64 %x, %min ; icmp lives in the nested loop: CHECK lines expect it to be split into two compares
+ br i1 %cmp, label %nested.loop, label %loop
+
+ret: ret i32 0
+}
+
+; CHECK: @test11
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT: loop:
+; CHECK-NEXT: br label %nested.loop
+; CHECK-EMPTY:
+; CHECK-NEXT: nested.loop:
+; CHECK-NEXT: %0 = icmp ult i64 %x, %a
+; CHECK-NEXT: %1 = icmp ult i64 %x, %b
+; CHECK-NEXT: %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT: br i1 %2, label %nested.loop, label %loop
+
+declare i64 @llvm.umin.i64(i64, i64)
+declare i64 @llvm.smin.i64(i64, i64)
+declare i64 @llvm.umax.i64(i64, i64)
+declare i64 @llvm.smax.i64(i64, i64)
+
+declare i32 @llvm.umin.i32(i32, i32)
+declare i32 @llvm.smin.i32(i32, i32)
+declare i32 @llvm.umax.i32(i32, i32)
+declare i32 @llvm.smax.i32(i32, i32)
More information about the llvm-commits
mailing list