[llvm] 18e1373 - [BPF] Undo transformation for LICM.cpp:hoistMinMax()

Eduard Zingerman via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 11 12:32:37 PDT 2023


Author: Eduard Zingerman
Date: 2023-07-11T22:30:34+03:00
New Revision: 18e13739b8c02e0b82e3e587c33e8731c8a46b0a

URL: https://github.com/llvm/llvm-project/commit/18e13739b8c02e0b82e3e587c33e8731c8a46b0a
DIFF: https://github.com/llvm/llvm-project/commit/18e13739b8c02e0b82e3e587c33e8731c8a46b0a.diff

LOG: [BPF] Undo transformation for LICM.cpp:hoistMinMax()

Extend the BPFCheckAndAdjustIR pass with a sinkMinMax() transformation
that undoes the LICM hoistMinMax() transformation.

The undo transformation converts the following patterns:

    x < min(a, b) -> x < a && x < b
    x > min(a, b) -> x > a || x > b
    x < max(a, b) -> x < a || x < b
    x > max(a, b) -> x > a && x > b

Where 'a' or 'b' is a constant.
Also supports `sext min(...) ...` and `zext min(...) ...`.

~~~

This was previously committed as 09feee559a29 and reverted in
0bf9bfeacc8c because of a memory leak reported by the buildbot:
  https://lab.llvm.org/buildbot/#/builders/5/builds/34931

The memory leak was caused by an incorrect instruction removal
sequence in sinkMinMaxInBB():

    I->dropAllReferences();  -------->  I->eraseFromParent();
    I->removeFromParent();   fixed to

Differential Revision: https://reviews.llvm.org/D147990

Added: 
    llvm/test/CodeGen/BPF/sink-min-max.ll

Modified: 
    llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
index 6b74e56d6b3e9b..a3616ae7ebabef 100644
--- a/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
+++ b/llvm/lib/Target/BPF/BPFCheckAndAdjustIR.cpp
@@ -18,8 +18,10 @@
 #include "BPF.h"
 #include "BPFCORE.h"
 #include "BPFTargetMachine.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
@@ -41,12 +43,14 @@ class BPFCheckAndAdjustIR final : public ModulePass {
 public:
   static char ID;
   BPFCheckAndAdjustIR() : ModulePass(ID) {}
+  virtual void getAnalysisUsage(AnalysisUsage &AU) const override;
 
 private:
   void checkIR(Module &M);
   bool adjustIR(Module &M);
   bool removePassThroughBuiltin(Module &M);
   bool removeCompareBuiltin(Module &M);
+  bool sinkMinMax(Module &M);
 };
 } // End anonymous namespace
 
@@ -161,9 +165,206 @@ bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
   return Changed;
 }
 
+namespace {
+// One candidate `icmp` found by sinkMinMaxInBB(), normalized so that the
+// comparison reads `Other Predicate <minmax>`, where <minmax> is MinMax,
+// optionally wrapped in ZExt or SExt (at most one of them is set).
+// Kept in an anonymous namespace: it is local to this translation unit.
+struct MinMaxSinkInfo {
+  ICmpInst *ICmp;
+  Value *Other;
+  ICmpInst::Predicate Predicate;
+  CallInst *MinMax = nullptr;
+  ZExtInst *ZExt = nullptr;
+  SExtInst *SExt = nullptr;
+
+  MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
+      : ICmp(ICmp), Other(Other), Predicate(Predicate) {}
+};
+} // namespace
+
+// Rewrite every relational `icmp` in BB that compares some value against a
+// min/max intrinsic (optionally wrapped in sext/zext) accepted by Filter:
+//
+//   x < min(a, b) -> x < a && x < b
+//   x > min(a, b) -> x > a || x > b
+//   x < max(a, b) -> x < a || x < b
+//   x > max(a, b) -> x > a && x > b
+//
+// New instructions are inserted right before the original icmp, which is
+// then erased, together with the extension and the min/max call when they
+// are left without uses. Returns true if BB was modified.
+static bool sinkMinMaxInBB(BasicBlock &BB,
+                           const std::function<bool(Instruction *)> &Filter) {
+  // Check if V is:
+  //   (fn %a %b) or (ext (fn %a %b))
+  // Where:
+  //   ext := sext | zext
+  //   fn  := smin | umin | smax | umax
+  // Filter is captured by reference (it outlives this lambda) instead of
+  // copying the std::function.
+  auto IsMinMaxCall = [&](Value *V, MinMaxSinkInfo &Info) {
+    if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
+      V = ZExt->getOperand(0);
+      Info.ZExt = ZExt;
+    } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
+      V = SExt->getOperand(0);
+      Info.SExt = SExt;
+    }
+
+    auto *Call = dyn_cast<CallInst>(V);
+    if (!Call)
+      return false;
+
+    auto *Called = dyn_cast<Function>(Call->getCalledOperand());
+    if (!Called)
+      return false;
+
+    switch (Called->getIntrinsicID()) {
+    case Intrinsic::smin:
+    case Intrinsic::umin:
+    case Intrinsic::smax:
+    case Intrinsic::umax:
+      break;
+    default:
+      return false;
+    }
+
+    if (!Filter(Call))
+      return false;
+
+    Info.MinMax = Call;
+
+    return true;
+  };
+
+  // Extend V to the type of the sext/zext recorded in Info, if any.
+  auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
+                             MinMaxSinkInfo &Info) {
+    if (Info.SExt) {
+      if (Info.SExt->getType() == V->getType())
+        return V;
+      return Builder.CreateSExt(V, Info.SExt->getType());
+    }
+    if (Info.ZExt) {
+      if (Info.ZExt->getType() == V->getType())
+        return V;
+      return Builder.CreateZExt(V, Info.ZExt->getType());
+    }
+    return V;
+  };
+
+  bool Changed = false;
+  SmallVector<MinMaxSinkInfo, 2> SinkList;
+
+  // Check BB for instructions like:
+  //   insn := (icmp %a (fn ...)) | (icmp (fn ...)  %a)
+  //
+  // Where:
+  //   fn  := min | max | (ext (min ...)) | (ext (max ...))
+  //   ext := sext | zext
+  //
+  // Exactly one icmp operand must match. Put such instructions to SinkList.
+  for (Instruction &I : BB) {
+    ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
+    if (!ICmp)
+      continue;
+    if (!ICmp->isRelational())
+      continue;
+    MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
+                         ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
+    MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
+    bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
+    bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
+    if (!(FirstMinMax ^ SecondMinMax))
+      continue;
+    SinkList.push_back(FirstMinMax ? First : Second);
+  }
+
+  // Iterate SinkList and replace each (icmp ...) with corresponding
+  // `x < a && x < b` or similar expression.
+  for (auto &Info : SinkList) {
+    ICmpInst *ICmp = Info.ICmp;
+    CallInst *MinMax = Info.MinMax;
+    Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
+    ICmpInst::Predicate P = Info.Predicate;
+
+    // The rewrite is only sound if ext(minmax(a, b)) equals the min/max of
+    // ext(a) and ext(b) in the ordering used by P:
+    // - with no extension or with sext, the min/max signedness must match
+    //   the signedness of P; e.g. `x <u smin(-1, 1)` holds for x = 5, but
+    //   `x <u -1 && x <u 1` does not;
+    // - zext only preserves the order of unsigned values, so additionally
+    //   require an unsigned min/max (together with the check above this
+    //   also requires an unsigned P, which is conservative but sound).
+    bool IsSignedMinMax = IID == Intrinsic::smin || IID == Intrinsic::smax;
+    if (ICmpInst::isSigned(P) != IsSignedMinMax)
+      continue;
+    if (Info.ZExt && IsSignedMinMax)
+      continue;
+
+    IRBuilder<> Builder(ICmp);
+    Value *X = Info.Other;
+    Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
+    Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
+    bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
+    bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
+    bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
+    bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
+    assert(IsMin ^ IsMax);
+    assert(IsLess ^ IsGreater);
+
+    Value *Replacement;
+    Value *LHS = Builder.CreateICmp(P, X, A);
+    Value *RHS = Builder.CreateICmp(P, X, B);
+    if ((IsLess && IsMin) || (IsGreater && IsMax))
+      // x < min(a, b) -> x < a && x < b
+      // x > max(a, b) -> x > a && x > b
+      Replacement = Builder.CreateLogicalAnd(LHS, RHS);
+    else
+      // x > min(a, b) -> x > a || x > b
+      // x < max(a, b) -> x < a || x < b
+      Replacement = Builder.CreateLogicalOr(LHS, RHS);
+
+    ICmp->replaceAllUsesWith(Replacement);
+
+    // Erase in def-use order: the icmp first, then the extension, then the
+    // min/max call, each only once nothing else uses it.
+    Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
+    for (Instruction *I : ToRemove)
+      if (I && I->use_empty())
+        I->eraseFromParent();
+
+    Changed = true;
+  }
+
+  return Changed;
+}
+
+// Undo LICM.cpp:hoistMinMax() by rewriting:
+//
+//   x < min(a, b) -> x < a && x < b
+//   x > min(a, b) -> x > a || x > b
+//   x < max(a, b) -> x < a || x < b
+//   x > max(a, b) -> x > a && x > b
+//
+// The hoisted form might lead to BPF verification failures on older
+// kernels.
+//
+// To keep "collateral" changes small, only icmp instructions located in
+// loop blocks are considered, and only when the min/max call they compare
+// against is defined outside of the innermost loop containing the icmp.
+//
+// Verification fails when:
+// - the RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
+// - the verifier can recognize RHS as a constant scalar in some context;
+// - the verifier can't recognize RHS1 as a constant scalar in the same
+//   context.
+//
+// Here a "constant scalar" is not a compile time constant, but a register
+// holding a scalar value known to the verifier at some point in time
+// during abstract interpretation.
+//
+// See also:
+//   https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
+bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
+  bool Modified = false;
+
+  for (Function &Fn : M) {
+    if (Fn.isDeclaration())
+      continue;
+
+    LoopInfo &Loops = getAnalysis<LoopInfoWrapperPass>(Fn).getLoopInfo();
+    // Top-level loops cover the blocks of their nested loops as well.
+    for (Loop *TopLevel : Loops) {
+      for (BasicBlock *Block : TopLevel->blocks()) {
+        Loop *Innermost = Loops.getLoopFor(Block);
+        // Accept only min/max calls living outside Block's innermost loop.
+        auto FromOtherLoop = [&Loops, Innermost](Instruction *Inst) {
+          return Loops.getLoopFor(Inst->getParent()) != Innermost;
+        };
+        if (sinkMinMaxInBB(*Block, FromOtherLoop))
+          Modified = true;
+      }
+    }
+  }
+
+  return Modified;
+}
+
+// sinkMinMax() asks for LoopInfo of every defined function, so the legacy
+// pass manager has to schedule LoopInfoWrapperPass before this pass.
+void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &Usage) const {
+  Usage.addRequired<LoopInfoWrapperPass>();
+}
+
+// Run every IR adjustment on M; returns true if any of them changed M.
 bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
   bool Changed = removePassThroughBuiltin(M);
   Changed = removeCompareBuiltin(M) || Changed;
+  // The call is on the left of `||` so it always runs, whatever Changed is.
+  Changed = sinkMinMax(M) || Changed;
   return Changed;
 }
 

diff  --git a/llvm/test/CodeGen/BPF/sink-min-max.ll b/llvm/test/CodeGen/BPF/sink-min-max.ll
new file mode 100644
index 00000000000000..5ee080839985d2
--- /dev/null
+++ b/llvm/test/CodeGen/BPF/sink-min-max.ll
@@ -0,0 +1,258 @@
+; RUN: opt --bpf-check-and-opt-ir -S -mtriple=bpf-pc-linux %s | FileCheck %s
+
+; Test plan:
+; @test1: x <  umin(i64 a, i64 b)
+; @test2: x <  umax(i64 a, i64 b)
+; @test3: x >= umin(i64 a, i64 b)
+; @test4: x >= umax(i64 a, i64 b)
+; @test5: umin(i64 a, i64 b) >= x
+; @test6: x <  smin(i64 a, i64 b)
+; @test7: x <  umin(i32 a, i32 b)
+; @test8: x <  zext i64 umin(i32 a, i32 b)
+; @test9: x <  sext i64 umin(i32 a, i32 b)
+; @test10: check that umin belonging to the same loop is not touched
+; @test11: check that nested loops are processed
+
+; umin computed before the loop: x <u umin(a, b) -> (x <u a) && (x <u b),
+; the logical and being emitted as a select.
+define i32 @test1(i64 %a, i64 %b, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp ult i64 %x, %min
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK:       @test1
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  loop:
+; CHECK-NEXT:    %0 = icmp ult i64 %x, %a
+; CHECK-NEXT:    %1 = icmp ult i64 %x, %b
+; CHECK-NEXT:    %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT:    br i1 %2, label %loop, label %ret
+
+; umax computed before the loop: x <u umax(a, b) -> (x <u a) || (x <u b),
+; the logical or being emitted as a select.
+define i32 @test2(i64 %a, i64 %b, i64 %x) {
+entry:
+  %max = tail call i64 @llvm.umax.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp ult i64 %x, %max
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK:       @test2
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  loop:
+; CHECK-NEXT:    %0 = icmp ult i64 %x, %a
+; CHECK-NEXT:    %1 = icmp ult i64 %x, %b
+; CHECK-NEXT:    %2 = select i1 %0, i1 true, i1 %1
+; CHECK-NEXT:    br i1 %2, label %loop, label %ret
+
+; Greater-or-equal against umin: x >=u umin(a, b) -> (x >=u a) || (x >=u b).
+define i32 @test3(i64 %a, i64 %b, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp uge i64 %x, %min
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK:       @test3
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  loop:
+; CHECK-NEXT:    %0 = icmp uge i64 %x, %a
+; CHECK-NEXT:    %1 = icmp uge i64 %x, %b
+; CHECK-NEXT:    %2 = select i1 %0, i1 true, i1 %1
+; CHECK-NEXT:    br i1 %2, label %loop, label %ret
+
+; Greater-or-equal against umax: x >=u umax(a, b) -> (x >=u a) && (x >=u b).
+define i32 @test4(i64 %a, i64 %b, i64 %x) {
+entry:
+  %max = tail call i64 @llvm.umax.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp uge i64 %x, %max
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK:       @test4
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  loop:
+; CHECK-NEXT:    %0 = icmp uge i64 %x, %a
+; CHECK-NEXT:    %1 = icmp uge i64 %x, %b
+; CHECK-NEXT:    %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT:    br i1 %2, label %loop, label %ret
+
+; min/max as the first icmp operand: umin(a, b) >=u x is normalized to
+; x <=u umin(a, b), which becomes (x <=u a) && (x <=u b).
+define i32 @test5(i64 %a, i64 %b, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp uge i64 %min, %x
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK:       @test5
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  loop:
+; CHECK-NEXT:    %0 = icmp ule i64 %x, %a
+; CHECK-NEXT:    %1 = icmp ule i64 %x, %b
+; CHECK-NEXT:    %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT:    br i1 %2, label %loop, label %ret
+
+; Signed variant: x <s smin(a, b) -> (x <s a) && (x <s b).
+define i32 @test6(i64 %a, i64 %b, i64 %x) {
+entry:
+  %min = tail call i64 @llvm.smin.i64(i64 %a, i64 %b)
+  br label %loop
+loop:
+  %cmp = icmp slt i64 %x, %min
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK:       @test6
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  loop:
+; CHECK-NEXT:    %0 = icmp slt i64 %x, %a
+; CHECK-NEXT:    %1 = icmp slt i64 %x, %b
+; CHECK-NEXT:    %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT:    br i1 %2, label %loop, label %ret
+
+; 32-bit variant: x <u umin(i32 a, i32 b) -> (x <u a) && (x <u b).
+define i32 @test7(i32 %a, i32 %b, i32 %x) {
+entry:
+  %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b)
+  br label %loop
+loop:
+  %cmp = icmp ult i32 %x, %min
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK:       @test7
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  loop:
+; CHECK-NEXT:    %0 = icmp ult i32 %x, %a
+; CHECK-NEXT:    %1 = icmp ult i32 %x, %b
+; CHECK-NEXT:    %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT:    br i1 %2, label %loop, label %ret
+
+; zext of a 32-bit umin compared with a 64-bit x: each umin operand is
+; zero-extended inside the loop before the two compares.
+define i32 @test8(i32 %a, i32 %b, i64 %x) {
+entry:
+  %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b)
+  br label %loop
+loop:
+  %ext = zext i32 %min to i64
+  %cmp = icmp ult i64 %x, %ext
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK:       @test8
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  loop:
+; CHECK-NEXT:    %0 = zext i32 %a to i64
+; CHECK-NEXT:    %1 = zext i32 %b to i64
+; CHECK-NEXT:    %2 = icmp ult i64 %x, %0
+; CHECK-NEXT:    %3 = icmp ult i64 %x, %1
+; CHECK-NEXT:    %4 = select i1 %2, i1 %3, i1 false
+; CHECK-NEXT:    br i1 %4, label %loop, label %ret
+
+; sext of a 32-bit umin compared with a 64-bit x: each umin operand is
+; sign-extended inside the loop before the two compares.
+define i32 @test9(i32 %a, i32 %b, i64 %x) {
+entry:
+  %min = tail call i32 @llvm.umin.i32(i32 %a, i32 %b)
+  br label %loop
+loop:
+  %ext = sext i32 %min to i64
+  %cmp = icmp ult i64 %x, %ext
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK:       @test9
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  loop:
+; CHECK-NEXT:    %0 = sext i32 %a to i64
+; CHECK-NEXT:    %1 = sext i32 %b to i64
+; CHECK-NEXT:    %2 = icmp ult i64 %x, %0
+; CHECK-NEXT:    %3 = icmp ult i64 %x, %1
+; CHECK-NEXT:    %4 = select i1 %2, i1 %3, i1 false
+; CHECK-NEXT:    br i1 %4, label %loop, label %ret
+
+; umin within the loop body is unchanged
+; (only min/max calls defined outside the innermost loop of the icmp are
+; rewritten, and here both live in the same loop).
+define i32 @test10(i64 %a, i64 %b, i64 %x) {
+entry:
+  br label %loop
+loop:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+  %cmp = icmp ult i64 %x, %min
+  br i1 %cmp, label %loop, label %ret
+ret: ret i32 0
+}
+
+; CHECK:       @test10
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  loop:
+; CHECK-NEXT:    %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+; CHECK-NEXT:    %cmp = icmp ult i64 %x, %min
+; CHECK-NEXT:    br i1 %cmp, label %loop, label %ret
+
+; umin from outer loop body is processed
+; (the icmp lives in %nested.loop while the umin lives in the enclosing
+; %loop, so they belong to different innermost loops).
+define i32 @test11(i64 %a, i64 %b, i64 %x) {
+entry:
+  br label %loop
+
+loop:
+  %min = tail call i64 @llvm.umin.i64(i64 %a, i64 %b)
+  br label %nested.loop
+nested.loop:
+  %cmp = icmp ult i64 %x, %min
+  br i1 %cmp, label %nested.loop, label %loop
+
+ret: ret i32 0
+}
+
+; CHECK:       @test11
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label %loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  loop:
+; CHECK-NEXT:    br label %nested.loop
+; CHECK-EMPTY:
+; CHECK-NEXT:  nested.loop:
+; CHECK-NEXT:    %0 = icmp ult i64 %x, %a
+; CHECK-NEXT:    %1 = icmp ult i64 %x, %b
+; CHECK-NEXT:    %2 = select i1 %0, i1 %1, i1 false
+; CHECK-NEXT:    br i1 %2, label %nested.loop, label %loop
+
+; Declarations of the min/max intrinsics, grouped by operation, for the
+; i32 and i64 widths.
+declare i32 @llvm.umin.i32(i32, i32)
+declare i64 @llvm.umin.i64(i64, i64)
+
+declare i32 @llvm.smin.i32(i32, i32)
+declare i64 @llvm.smin.i64(i64, i64)
+
+declare i32 @llvm.umax.i32(i32, i32)
+declare i64 @llvm.umax.i64(i64, i64)
+
+declare i32 @llvm.smax.i32(i32, i32)
+declare i64 @llvm.smax.i64(i64, i64)


        


More information about the llvm-commits mailing list