[llvm] [SelectOpt] Add support for AShr/LShr operands (PR #118495)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 3 07:00:22 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Igor Kirillov (igogo-x86)

<details>
<summary>Changes</summary>

For conditional increments with sign check conditions like X < 0 or X >= 0, the compiler may generate code like this:

  %cmp = icmp sgt i64 %1, -1
  %shift = ashr i64 %1, 63
  %j.next = add nsw i64 %j, %shift
  %sel = select i1 %cmp ...

where %cmp is not used in the shift/add computation itself but is used in some other implicit or regular select expressions. This patch allows the SelectOptimize pass to recognise these cases.

---
Full diff: https://github.com/llvm/llvm-project/pull/118495.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectOptimize.cpp (+63-23) 
- (modified) llvm/test/CodeGen/AArch64/selectopt-cast.ll (+124) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectOptimize.cpp b/llvm/lib/CodeGen/SelectOptimize.cpp
index 484705eabbc42e..9725f4322c5932 100644
--- a/llvm/lib/CodeGen/SelectOptimize.cpp
+++ b/llvm/lib/CodeGen/SelectOptimize.cpp
@@ -218,7 +218,7 @@ class SelectOptimizeImpl {
 private:
   // Select groups consist of consecutive select-like instructions with the same
   // condition. Between select-likes could be any number of auxiliary
-  // instructions related to the condition like not, zext
+  // instructions related to the condition like not, zext, ashr/lshr
   struct SelectGroup {
     Value *Condition;
     SmallVector<SelectLike, 2> Selects;
@@ -496,7 +496,13 @@ static Value *getTrueOrFalseValue(
 
   auto *CBO = BO->clone();
   auto CondIdx = SI.getConditionOpIndex();
-  CBO->setOperand(CondIdx, ConstantInt::get(CBO->getType(), 1));
+  auto *AuxI = cast<Instruction>(CBO->getOperand(CondIdx));
+  if (isa<ZExtInst>(AuxI) || isa<LShrOperator>(AuxI)) {
+    CBO->setOperand(CondIdx, ConstantInt::get(CBO->getType(), 1));
+  } else {
+    assert(isa<AShrOperator>(AuxI) && "Unexpected opcode");
+    CBO->setOperand(CondIdx, ConstantInt::get(CBO->getType(), -1));
+  }
 
   unsigned OtherIdx = 1 - CondIdx;
   if (auto *IV = dyn_cast<Instruction>(CBO->getOperand(OtherIdx))) {
@@ -755,6 +761,9 @@ void SelectOptimizeImpl::collectSelectGroups(BasicBlock &BB,
   // zero or some constant value on True/False branch, such as:
   // * ZExt(1bit)
   // * Not(1bit)
+  // * A(L)Shr(Val), ValBitSize - 1, where there is a condition like `Val <= 0`
+  // earlier in the BB. For conditions that check the sign of Val, the compiler
+  // may generate shifts instead of ZExt/SExt.
   struct SelectLikeInfo {
     Value *Cond;
     bool IsAuxiliary;
@@ -763,11 +772,19 @@ void SelectOptimizeImpl::collectSelectGroups(BasicBlock &BB,
   };
 
   DenseMap<Value *, SelectLikeInfo> SelectInfo;
+  // Keeps visited comparisons to help identify AShr/LShr variants of auxiliary
+  // instructions.
+  SmallPtrSet<CmpInst *, 2> SeenCmp;
 
   // Check if the instruction is SelectLike or might be part of SelectLike
   // expression, put information into SelectInfo and return the iterator to the
   // inserted position.
-  auto ProcessSelectInfo = [&SelectInfo](Instruction *I) {
+  auto ProcessSelectInfo = [&SelectInfo, &SeenCmp](Instruction *I) {
+    if (auto *Cmp = dyn_cast<CmpInst>(I)) {
+      SeenCmp.insert(Cmp);
+      return SelectInfo.end();
+    }
+
     Value *Cond;
     if (match(I, m_OneUse(m_ZExt(m_Value(Cond)))) &&
         Cond->getType()->isIntegerTy(1)) {
@@ -784,30 +801,53 @@ void SelectOptimizeImpl::collectSelectGroups(BasicBlock &BB,
       bool Inverted = match(Cond, m_Not(m_Value(Cond)));
       return SelectInfo.insert({I, {Cond, false, Inverted, 0}}).first;
     }
-
-    // An Or(zext(i1 X), Y) can also be treated like a select, with condition X
-    // and values Y|1 and Y.
-    if (auto *BO = dyn_cast<BinaryOperator>(I)) {
-      switch (I->getOpcode()) {
-      case Instruction::Add:
-      case Instruction::Sub: {
-        Value *X;
-        if (!((PatternMatch::match(I->getOperand(0),
-                                   m_OneUse(m_ZExt(m_Value(X)))) ||
-               PatternMatch::match(I->getOperand(1),
-                                   m_OneUse(m_ZExt(m_Value(X))))) &&
-              X->getType()->isIntegerTy(1)))
+    Value *Val;
+    ConstantInt *Shift;
+    if (match(I, m_Shr(m_Value(Val), m_ConstantInt(Shift))) &&
+        I->getType()->getIntegerBitWidth() == Shift->getZExtValue() + 1) {
+      for (auto *CmpI : SeenCmp) {
+        auto Pred = CmpI->getPredicate();
+        if (Val != CmpI->getOperand(0))
           return SelectInfo.end();
-        break;
-      }
-      case Instruction::Or:
-        if (BO->getType()->isIntegerTy(1) || BO->getOpcode() != Instruction::Or)
-          return SelectInfo.end();
-        break;
+        if ((Pred == CmpInst::ICMP_SGT &&
+             match(CmpI->getOperand(1), m_ConstantInt<-1>())) ||
+            (Pred == CmpInst::ICMP_SGE &&
+             match(CmpI->getOperand(1), m_Zero())) ||
+            (Pred == CmpInst::ICMP_SLT &&
+             match(CmpI->getOperand(1), m_Zero())) ||
+            (Pred == CmpInst::ICMP_SLE &&
+             match(CmpI->getOperand(1), m_ConstantInt<-1>()))) {
+          bool Inverted =
+              Pred == CmpInst::ICMP_SGT || Pred == CmpInst::ICMP_SGE;
+          return SelectInfo.insert({I, {CmpI, true, Inverted, 0}}).first;
+        }
       }
+    }
+
+    // A BinOp(Aux(X), Y) can also be treated like a select, with condition X
+    // and values Y|1 and Y.
+    // `Aux` can be either `ZExt(1bit)` or `XShr(Val), ValBitSize - 1`
+    // `BinOp` can be Add, Sub, Or
+    Value *X;
+    auto MatchZExtPattern = m_c_BinOp(m_Value(), m_OneUse(m_ZExt(m_Value(X))));
+    auto MatchShiftPattern =
+        m_c_BinOp(m_Value(), m_OneUse(m_Shr(m_Value(X), m_ConstantInt(Shift))));
+
+    // This check is unnecessary, but it prevents costly access to the
+    // SelectInfo map.
+    if ((match(I, MatchZExtPattern) && X->getType()->isIntegerTy(1)) ||
+        (match(I, MatchShiftPattern) &&
+         X->getType()->getIntegerBitWidth() == Shift->getZExtValue() + 1)) {
+      if (I->getOpcode() != Instruction::Add &&
+          I->getOpcode() != Instruction::Sub &&
+          I->getOpcode() != Instruction::Or)
+        return SelectInfo.end();
+
+      if (I->getOpcode() == Instruction::Or && I->getType()->isIntegerTy(1))
+        return SelectInfo.end();
 
       for (unsigned Idx = 0; Idx < 2; Idx++) {
-        auto *Op = BO->getOperand(Idx);
+        auto *Op = I->getOperand(Idx);
         auto It = SelectInfo.find(Op);
         if (It != SelectInfo.end() && It->second.IsAuxiliary) {
           Cond = It->second.Cond;
diff --git a/llvm/test/CodeGen/AArch64/selectopt-cast.ll b/llvm/test/CodeGen/AArch64/selectopt-cast.ll
index 102b89df32b03b..3b8b9780a02ee0 100644
--- a/llvm/test/CodeGen/AArch64/selectopt-cast.ll
+++ b/llvm/test/CodeGen/AArch64/selectopt-cast.ll
@@ -735,3 +735,127 @@ loop:
 exit:
   ret void
 }
+
+define void @test_add_lshr_add_regular_select(ptr %dst, ptr %src, i64 %i.start, i64 %j.start) {
+; CHECK-LABEL: @test_add_lshr_add_regular_select(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 100000, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[I_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[J_START:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
+; CHECK-NEXT:    [[TMP0:%.*]] = load ptr, ptr [[GEP_I]], align 8
+; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds i64, ptr [[TMP0]], i64 [[J]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load i64, ptr [[GEP_J]], align 8
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i64 [[TMP1]], -1
+; CHECK-NEXT:    [[SHIFT:%.*]] = lshr i64 [[TMP1]], 63
+; CHECK-NEXT:    [[CMP_FROZEN:%.*]] = freeze i1 [[CMP]]
+; CHECK-NEXT:    br i1 [[CMP_FROZEN]], label [[SELECT_TRUE_SINK:%.*]], label [[SELECT_FALSE_SINK:%.*]]
+; CHECK:       select.true.sink:
+; CHECK-NEXT:    [[TMP2:%.*]] = add nsw i64 [[I]], 1
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.false.sink:
+; CHECK-NEXT:    [[TMP3:%.*]] = add nsw i64 [[J]], 1
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[J]], [[SELECT_TRUE_SINK]] ], [ [[TMP3]], [[SELECT_FALSE_SINK]] ]
+; CHECK-NEXT:    [[I_NEXT]] = phi i64 [ [[TMP2]], [[SELECT_TRUE_SINK]] ], [ [[I]], [[SELECT_FALSE_SINK]] ]
+; CHECK-NEXT:    [[COND:%.*]] = phi i64 [ [[J]], [[SELECT_TRUE_SINK]] ], [ [[I]], [[SELECT_FALSE_SINK]] ]
+; CHECK-NEXT:    [[INC:%.*]] = zext i1 [[CMP]] to i64
+; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr i64, ptr [[DST:%.*]], i64 [[IV]]
+; CHECK-NEXT:    store i64 [[COND]], ptr [[GEP_DST]], align 8
+; CHECK-NEXT:    [[IV_NEXT]] = add nsw i64 [[IV]], -1
+; CHECK-NEXT:    [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 0
+; CHECK-NEXT:    br i1 [[EC]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop
+
+loop:
+  %iv = phi i64 [ 100000, %entry ], [ %iv.next, %loop ]
+  %i = phi i64 [ %i.start, %entry ], [ %i.next, %loop ]
+  %j = phi i64 [ %j.start, %entry ], [ %j.next, %loop ]
+  %gep.i = getelementptr inbounds ptr, ptr %src, i64 %i
+  %0 = load ptr, ptr %gep.i, align 8
+  %gep.j = getelementptr inbounds i64, ptr %0, i64 %j
+  %1 = load i64, ptr %gep.j, align 8
+  %cmp = icmp sgt i64 %1, -1
+  %shift = lshr i64 %1, 63
+  %j.next = add nsw i64 %j, %shift
+  %inc = zext i1 %cmp to i64
+  %i.next = add nsw i64 %i, %inc
+  %cond = select i1 %cmp, i64 %j, i64 %i
+  %gep.dst = getelementptr i64, ptr %dst, i64 %iv
+  store i64 %cond, ptr %gep.dst, align 8
+  %iv.next = add nsw i64 %iv, -1
+  %ec = icmp eq i64 %iv.next, 0
+  br i1 %ec, label %exit, label %loop
+
+exit:
+  ret void
+}
+
+define void @test_add_ashr_add_regular_select(ptr %dst, ptr %src, i64 %i.start, i64 %j.start) {
+; CHECK-LABEL: @test_add_ashr_add_regular_select(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 100000, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[I_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[J_START:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
+; CHECK-NEXT:    [[TMP0:%.*]] = load ptr, ptr [[GEP_I]], align 8
+; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds i64, ptr [[TMP0]], i64 [[J]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load i64, ptr [[GEP_J]], align 8
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i64 [[TMP1]], -1
+; CHECK-NEXT:    [[SHIFT:%.*]] = ashr i64 [[TMP1]], 63
+; CHECK-NEXT:    [[CMP_FROZEN:%.*]] = freeze i1 [[CMP]]
+; CHECK-NEXT:    br i1 [[CMP_FROZEN]], label [[SELECT_TRUE_SINK:%.*]], label [[SELECT_FALSE_SINK:%.*]]
+; CHECK:       select.true.sink:
+; CHECK-NEXT:    [[TMP2:%.*]] = add nsw i64 [[I]], 1
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.false.sink:
+; CHECK-NEXT:    [[TMP3:%.*]] = add nsw i64 [[J]], -1
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[J]], [[SELECT_TRUE_SINK]] ], [ [[TMP3]], [[SELECT_FALSE_SINK]] ]
+; CHECK-NEXT:    [[I_NEXT]] = phi i64 [ [[TMP2]], [[SELECT_TRUE_SINK]] ], [ [[I]], [[SELECT_FALSE_SINK]] ]
+; CHECK-NEXT:    [[COND:%.*]] = phi i64 [ [[J]], [[SELECT_TRUE_SINK]] ], [ [[I]], [[SELECT_FALSE_SINK]] ]
+; CHECK-NEXT:    [[INC:%.*]] = zext i1 [[CMP]] to i64
+; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr i64, ptr [[DST:%.*]], i64 [[IV]]
+; CHECK-NEXT:    store i64 [[COND]], ptr [[GEP_DST]], align 8
+; CHECK-NEXT:    [[IV_NEXT]] = add nsw i64 [[IV]], -1
+; CHECK-NEXT:    [[EC:%.*]] = icmp eq i64 [[IV_NEXT]], 0
+; CHECK-NEXT:    br i1 [[EC]], label [[EXIT:%.*]], label [[LOOP]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop
+
+loop:
+  %iv = phi i64 [ 100000, %entry ], [ %iv.next, %loop ]
+  %i = phi i64 [ %i.start, %entry ], [ %i.next, %loop ]
+  %j = phi i64 [ %j.start, %entry ], [ %j.next, %loop ]
+  %gep.i = getelementptr inbounds ptr, ptr %src, i64 %i
+  %0 = load ptr, ptr %gep.i, align 8
+  %gep.j = getelementptr inbounds i64, ptr %0, i64 %j
+  %1 = load i64, ptr %gep.j, align 8
+  %cmp = icmp sgt i64 %1, -1
+  %shift = ashr i64 %1, 63
+  %j.next = add nsw i64 %j, %shift
+  %inc = zext i1 %cmp to i64
+  %i.next = add nsw i64 %i, %inc
+  %cond = select i1 %cmp, i64 %j, i64 %i
+  %gep.dst = getelementptr i64, ptr %dst, i64 %iv
+  store i64 %cond, ptr %gep.dst, align 8
+  %iv.next = add nsw i64 %iv, -1
+  %ec = icmp eq i64 %iv.next, 0
+  br i1 %ec, label %exit, label %loop
+
+exit:
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/118495


More information about the llvm-commits mailing list