[llvm] [SelectOpt] Support ADD and SUB with zext operands. (PR #115489)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 28 01:15:25 PST 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/115489

>From 5dbb7ea0ee74a0aefa674356978b44a4a60b71ce Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 16 Oct 2024 19:26:59 +0100
Subject: [PATCH 1/2] [SelectOpt] Support add and sub with zext operands.

Extend the support for implicit selects in the form of OR with a
ZExt operand to support ADD and SUB binops as well. They can similarly
form implicit selects, which can be profitable to convert back to
branches.
---
 llvm/lib/CodeGen/SelectOptimize.cpp           | 98 +++++++++++++++++++
 .../AArch64/AArch64TargetTransformInfo.cpp    | 22 +++--
 llvm/test/CodeGen/AArch64/selectopt-cast.ll   | 50 +++++++---
 3 files changed, 146 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectOptimize.cpp b/llvm/lib/CodeGen/SelectOptimize.cpp
index d480642171e8e5..3517e3eed7f818 100644
--- a/llvm/lib/CodeGen/SelectOptimize.cpp
+++ b/llvm/lib/CodeGen/SelectOptimize.cpp
@@ -138,8 +138,49 @@ class SelectOptimizeImpl {
     unsigned CondIdx;
 
   public:
+<<<<<<< HEAD
     SelectLike(Instruction *I, bool Inverted = false, unsigned CondIdx = 0)
         : I(I), Inverted(Inverted), CondIdx(CondIdx) {}
+=======
+    /// Match a select or select-like instruction, returning a SelectLike.
+    static SelectLike match(Instruction *I) {
+      // Select instruction are what we are usually looking for.
+      if (isa<SelectInst>(I))
+        return SelectLike(I);
+
+      // An Or(zext(i1 X), Y) can also be treated like a select, with condition
+      // C and values Y|1 and Y.
+      switch (I->getOpcode()) {
+      case Instruction::Add:
+      case Instruction::Or:
+      case Instruction::Sub: {
+        Value *X;
+        if ((PatternMatch::match(I->getOperand(0),
+                                 m_OneUse(m_ZExt(m_Value(X)))) ||
+             PatternMatch::match(I->getOperand(1),
+                                 m_OneUse(m_ZExt(m_Value(X))))) &&
+            X->getType()->isIntegerTy(1))
+          return SelectLike(I);
+        break;
+      }
+      }
+
+      return SelectLike(nullptr);
+    }
+
+    bool isValid() { return I; }
+    operator bool() { return isValid(); }
+
+    /// Invert the select by inverting the condition and switching the operands.
+    void setInverted() {
+      assert(!Inverted && "Trying to invert an inverted SelectLike");
+      assert(isa<Instruction>(getCondition()) &&
+             cast<Instruction>(getCondition())->getOpcode() ==
+                 Instruction::Xor);
+      Inverted = true;
+    }
+    bool isInverted() const { return Inverted; }
+>>>>>>> 7e5ca4eafa3c ([SelectOpt] Support add and sub with zext operands.)
 
     Instruction *getI() { return I; }
     const Instruction *getI() const { return I; }
@@ -195,6 +236,7 @@ class SelectOptimizeImpl {
           return It != InstCostMap.end() ? It->second.NonPredCost
                                          : Scaled64::getZero();
         }
+<<<<<<< HEAD
         return Scaled64::getZero();
       }
       // If getTrue(False)Value() return nullptr, it means we are dealing with
@@ -212,6 +254,48 @@ class SelectOptimizeImpl {
           TotalCost += It->second.NonPredCost;
       }
       return TotalCost;
+=======
+
+      // BinaryOp case - add the cost of an extra BinOp to the cost of the False
+      // case.
+      if (isa<BinaryOperator>(I)) {
+        if (auto OpI = dyn_cast<Instruction>(getFalseValue())) {
+          auto It = InstCostMap.find(I);
+          if (It != InstCostMap.end()) {
+            InstructionCost OrCost = TTI->getArithmeticInstrCost(
+                I->getOpcode(), OpI->getType(),
+                TargetTransformInfo::TCK_Latency,
+                {TargetTransformInfo::OK_AnyValue,
+                 TargetTransformInfo::OP_None},
+                {TTI::OK_UniformConstantValue, TTI::OP_PowerOf2});
+            return It->second.NonPredCost + Scaled64::get(*OrCost.getValue());
+          }
+        }
+      }
+
+      return Scaled64::getZero();
+    }
+
+    /// Return the NonPredCost cost of the false op, given the costs in
+    /// InstCostMap. This may need to be generated for select-like instructions.
+    Scaled64
+    getFalseOpCost(DenseMap<const Instruction *, CostInfo> &InstCostMap,
+                   const TargetTransformInfo *TTI) {
+      if (isa<SelectInst>(I))
+        if (auto *I = dyn_cast<Instruction>(getFalseValue())) {
+          auto It = InstCostMap.find(I);
+          return It != InstCostMap.end() ? It->second.NonPredCost
+                                         : Scaled64::getZero();
+        }
+
+      // Or case - return the cost of the false case
+      if (isa<BinaryOperator>(I))
+        if (auto I = dyn_cast<Instruction>(getFalseValue()))
+          if (auto It = InstCostMap.find(I); It != InstCostMap.end())
+            return It->second.NonPredCost;
+
+      return Scaled64::getZero();
+>>>>>>> 7e5ca4eafa3c ([SelectOpt] Support add and sub with zext operands.)
     }
   };
 
@@ -488,9 +572,23 @@ static Value *getTrueOrFalseValue(
     return V;
   }
 
+<<<<<<< HEAD
   auto *BO = cast<BinaryOperator>(SI.getI());
   assert(BO->getOpcode() == Instruction::Or &&
          "Only currently handling Or instructions.");
+=======
+  if (auto *BinOp = dyn_cast<BinaryOperator>(SI.getI())) {
+    assert((BinOp->getOpcode() == Instruction::Add ||
+            BinOp->getOpcode() == Instruction::Or ||
+            BinOp->getOpcode() == Instruction::Sub) &&
+           "Only currently handling Add, Or and Sub instructions.");
+    V = SI.getFalseValue();
+    if (isTrue) {
+      Constant *CI = ConstantInt::get(V->getType(), 1);
+      V = IB.CreateBinOp(BinOp->getOpcode(), V, CI);
+    }
+  }
+>>>>>>> 7e5ca4eafa3c ([SelectOpt] Support add and sub with zext operands.)
 
   auto *CBO = BO->clone();
   auto CondIdx = SI.getConditionOpIndex();
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 7a1e401bca18cb..f8914b5ca1e25d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4796,14 +4796,20 @@ AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
 }
 
 bool AArch64TTIImpl::shouldTreatInstructionLikeSelect(const Instruction *I) {
-  // For the binary operators (e.g. or) we need to be more careful than
-  // selects, here we only transform them if they are already at a natural
-  // break point in the code - the end of a block with an unconditional
-  // terminator.
-  if (EnableOrLikeSelectOpt && I->getOpcode() == Instruction::Or &&
-      isa<BranchInst>(I->getNextNode()) &&
-      cast<BranchInst>(I->getNextNode())->isUnconditional())
-    return true;
+  if (EnableOrLikeSelectOpt) {
+    // For the binary operators (e.g. or) we need to be more careful than
+    // selects, here we only transform them if they are already at a natural
+    // break point in the code - the end of a block with an unconditional
+    // terminator.
+    if (I->getOpcode() == Instruction::Or &&
+        isa<BranchInst>(I->getNextNode()) &&
+        cast<BranchInst>(I->getNextNode())->isUnconditional())
+      return true;
+
+    if (I->getOpcode() == Instruction::Add ||
+        I->getOpcode() == Instruction::Sub)
+      return true;
+  }
   return BaseT::shouldTreatInstructionLikeSelect(I);
 }
 
diff --git a/llvm/test/CodeGen/AArch64/selectopt-cast.ll b/llvm/test/CodeGen/AArch64/selectopt-cast.ll
index c15e6e1d1b697c..6489c8d6c2d0ae 100644
--- a/llvm/test/CodeGen/AArch64/selectopt-cast.ll
+++ b/llvm/test/CodeGen/AArch64/selectopt-cast.ll
@@ -7,16 +7,22 @@ define void @test_add_zext(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.star
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[LOOP]] ]
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[SELECT_END]] ]
 ; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
 ; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
 ; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
 ; CHECK-NEXT:    [[L_J:%.*]] = load ptr, ptr [[GEP_J]], align 8
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[CMP3]] to i64
-; CHECK-NEXT:    [[J_NEXT]] = add nsw i64 [[J]], [[DEC]]
+; CHECK-NEXT:    [[CMP3_FROZEN:%.*]] = freeze i1 [[CMP3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[J]], 1
+; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
+; CHECK:       select.false:
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
@@ -101,9 +107,6 @@ define void @test_add_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[J_START:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[LOOP]] ]
 ; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
 ; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
 ; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
@@ -111,7 +114,13 @@ define void @test_add_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[NOT_CMP3:%.*]] = xor i1 [[CMP3]], true
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[NOT_CMP3]] to i64
-; CHECK-NEXT:    [[J_NEXT]] = add nsw i64 [[J]], [[DEC]]
+; CHECK-NEXT:    [[NOT_CMP3_FROZEN:%.*]] = freeze i1 [[NOT_CMP3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[J]], 1
+; CHECK-NEXT:    br i1 [[NOT_CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
+; CHECK:       select.false:
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
@@ -356,16 +365,22 @@ define void @test_sub_zext(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.star
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[LOOP]] ]
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[SELECT_END]] ]
 ; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
 ; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
 ; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
 ; CHECK-NEXT:    [[L_J:%.*]] = load ptr, ptr [[GEP_J]], align 8
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[CMP3]] to i64
-; CHECK-NEXT:    [[J_NEXT]] = sub nsw i64 [[J]], [[DEC]]
+; CHECK-NEXT:    [[CMP3_FROZEN:%.*]] = freeze i1 [[CMP3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = sub i64 [[J]], 1
+; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
+; CHECK:       select.false:
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
@@ -450,9 +465,6 @@ define void @test_sub_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[J_START:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[LOOP]] ]
 ; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
 ; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
 ; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
@@ -460,7 +472,13 @@ define void @test_sub_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[NOT_CMP3:%.*]] = xor i1 [[CMP3]], true
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[NOT_CMP3]] to i64
-; CHECK-NEXT:    [[J_NEXT]] = sub nsw i64 [[J]], [[DEC]]
+; CHECK-NEXT:    [[NOT_CMP3_FROZEN:%.*]] = freeze i1 [[NOT_CMP3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = sub i64 [[J]], 1
+; CHECK-NEXT:    br i1 [[NOT_CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
+; CHECK:       select.false:
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1

>From 76ce3b6b22b1e2d5bb0af6d4dd4befc1b98122ae Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 27 Nov 2024 17:22:23 +0000
Subject: [PATCH 2/2] !fixup rebase on top of latest main

---
 llvm/lib/CodeGen/SelectOptimize.cpp         | 123 ++++----------------
 llvm/test/CodeGen/AArch64/selectopt-cast.ll |  98 +++++++++-------
 2 files changed, 79 insertions(+), 142 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectOptimize.cpp b/llvm/lib/CodeGen/SelectOptimize.cpp
index 3517e3eed7f818..484705eabbc42e 100644
--- a/llvm/lib/CodeGen/SelectOptimize.cpp
+++ b/llvm/lib/CodeGen/SelectOptimize.cpp
@@ -138,49 +138,8 @@ class SelectOptimizeImpl {
     unsigned CondIdx;
 
   public:
-<<<<<<< HEAD
     SelectLike(Instruction *I, bool Inverted = false, unsigned CondIdx = 0)
         : I(I), Inverted(Inverted), CondIdx(CondIdx) {}
-=======
-    /// Match a select or select-like instruction, returning a SelectLike.
-    static SelectLike match(Instruction *I) {
-      // Select instruction are what we are usually looking for.
-      if (isa<SelectInst>(I))
-        return SelectLike(I);
-
-      // An Or(zext(i1 X), Y) can also be treated like a select, with condition
-      // C and values Y|1 and Y.
-      switch (I->getOpcode()) {
-      case Instruction::Add:
-      case Instruction::Or:
-      case Instruction::Sub: {
-        Value *X;
-        if ((PatternMatch::match(I->getOperand(0),
-                                 m_OneUse(m_ZExt(m_Value(X)))) ||
-             PatternMatch::match(I->getOperand(1),
-                                 m_OneUse(m_ZExt(m_Value(X))))) &&
-            X->getType()->isIntegerTy(1))
-          return SelectLike(I);
-        break;
-      }
-      }
-
-      return SelectLike(nullptr);
-    }
-
-    bool isValid() { return I; }
-    operator bool() { return isValid(); }
-
-    /// Invert the select by inverting the condition and switching the operands.
-    void setInverted() {
-      assert(!Inverted && "Trying to invert an inverted SelectLike");
-      assert(isa<Instruction>(getCondition()) &&
-             cast<Instruction>(getCondition())->getOpcode() ==
-                 Instruction::Xor);
-      Inverted = true;
-    }
-    bool isInverted() const { return Inverted; }
->>>>>>> 7e5ca4eafa3c ([SelectOpt] Support add and sub with zext operands.)
 
     Instruction *getI() { return I; }
     const Instruction *getI() const { return I; }
@@ -236,7 +195,6 @@ class SelectOptimizeImpl {
           return It != InstCostMap.end() ? It->second.NonPredCost
                                          : Scaled64::getZero();
         }
-<<<<<<< HEAD
         return Scaled64::getZero();
       }
       // If getTrue(False)Value() return nullptr, it means we are dealing with
@@ -254,48 +212,6 @@ class SelectOptimizeImpl {
           TotalCost += It->second.NonPredCost;
       }
       return TotalCost;
-=======
-
-      // BinaryOp case - add the cost of an extra BinOp to the cost of the False
-      // case.
-      if (isa<BinaryOperator>(I)) {
-        if (auto OpI = dyn_cast<Instruction>(getFalseValue())) {
-          auto It = InstCostMap.find(I);
-          if (It != InstCostMap.end()) {
-            InstructionCost OrCost = TTI->getArithmeticInstrCost(
-                I->getOpcode(), OpI->getType(),
-                TargetTransformInfo::TCK_Latency,
-                {TargetTransformInfo::OK_AnyValue,
-                 TargetTransformInfo::OP_None},
-                {TTI::OK_UniformConstantValue, TTI::OP_PowerOf2});
-            return It->second.NonPredCost + Scaled64::get(*OrCost.getValue());
-          }
-        }
-      }
-
-      return Scaled64::getZero();
-    }
-
-    /// Return the NonPredCost cost of the false op, given the costs in
-    /// InstCostMap. This may need to be generated for select-like instructions.
-    Scaled64
-    getFalseOpCost(DenseMap<const Instruction *, CostInfo> &InstCostMap,
-                   const TargetTransformInfo *TTI) {
-      if (isa<SelectInst>(I))
-        if (auto *I = dyn_cast<Instruction>(getFalseValue())) {
-          auto It = InstCostMap.find(I);
-          return It != InstCostMap.end() ? It->second.NonPredCost
-                                         : Scaled64::getZero();
-        }
-
-      // Or case - return the cost of the false case
-      if (isa<BinaryOperator>(I))
-        if (auto I = dyn_cast<Instruction>(getFalseValue()))
-          if (auto It = InstCostMap.find(I); It != InstCostMap.end())
-            return It->second.NonPredCost;
-
-      return Scaled64::getZero();
->>>>>>> 7e5ca4eafa3c ([SelectOpt] Support add and sub with zext operands.)
     }
   };
 
@@ -572,23 +488,11 @@ static Value *getTrueOrFalseValue(
     return V;
   }
 
-<<<<<<< HEAD
   auto *BO = cast<BinaryOperator>(SI.getI());
-  assert(BO->getOpcode() == Instruction::Or &&
-         "Only currently handling Or instructions.");
-=======
-  if (auto *BinOp = dyn_cast<BinaryOperator>(SI.getI())) {
-    assert((BinOp->getOpcode() == Instruction::Add ||
-            BinOp->getOpcode() == Instruction::Or ||
-            BinOp->getOpcode() == Instruction::Sub) &&
-           "Only currently handling Add, Or and Sub instructions.");
-    V = SI.getFalseValue();
-    if (isTrue) {
-      Constant *CI = ConstantInt::get(V->getType(), 1);
-      V = IB.CreateBinOp(BinOp->getOpcode(), V, CI);
-    }
-  }
->>>>>>> 7e5ca4eafa3c ([SelectOpt] Support add and sub with zext operands.)
+  assert((BO->getOpcode() == Instruction::Add ||
+          BO->getOpcode() == Instruction::Or ||
+          BO->getOpcode() == Instruction::Sub) &&
+         "Only currently handling Add, Or and Sub binary operators.");
 
   auto *CBO = BO->clone();
   auto CondIdx = SI.getConditionOpIndex();
@@ -884,8 +788,23 @@ void SelectOptimizeImpl::collectSelectGroups(BasicBlock &BB,
     // An Or(zext(i1 X), Y) can also be treated like a select, with condition X
     // and values Y|1 and Y.
     if (auto *BO = dyn_cast<BinaryOperator>(I)) {
-      if (BO->getType()->isIntegerTy(1) || BO->getOpcode() != Instruction::Or)
-        return SelectInfo.end();
+      switch (I->getOpcode()) {
+      case Instruction::Add:
+      case Instruction::Sub: {
+        Value *X;
+        if (!((PatternMatch::match(I->getOperand(0),
+                                   m_OneUse(m_ZExt(m_Value(X)))) ||
+               PatternMatch::match(I->getOperand(1),
+                                   m_OneUse(m_ZExt(m_Value(X))))) &&
+              X->getType()->isIntegerTy(1)))
+          return SelectInfo.end();
+        break;
+      }
+      case Instruction::Or:
+        if (BO->getType()->isIntegerTy(1) || BO->getOpcode() != Instruction::Or)
+          return SelectInfo.end();
+        break;
+      }
 
       for (unsigned Idx = 0; Idx < 2; Idx++) {
         auto *Op = BO->getOperand(Idx);
diff --git a/llvm/test/CodeGen/AArch64/selectopt-cast.ll b/llvm/test/CodeGen/AArch64/selectopt-cast.ll
index 6489c8d6c2d0ae..102b89df32b03b 100644
--- a/llvm/test/CodeGen/AArch64/selectopt-cast.ll
+++ b/llvm/test/CodeGen/AArch64/selectopt-cast.ll
@@ -17,12 +17,12 @@ define void @test_add_zext(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.star
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[CMP3]] to i64
 ; CHECK-NEXT:    [[CMP3_FROZEN:%.*]] = freeze i1 [[CMP3]]
-; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[J]], 1
-; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
-; CHECK:       select.false:
+; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_TRUE_SINK:%.*]], label [[SELECT_END]]
+; CHECK:       select.true.sink:
+; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i64 [[J]], 1
 ; CHECK-NEXT:    br label [[SELECT_END]]
 ; CHECK:       select.end:
-; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[SELECT_TRUE_SINK]] ], [ [[J]], [[LOOP]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
@@ -60,20 +60,26 @@ define void @test_add_zext_first_op(ptr %dst, ptr %src, i64 %j.start, i64 %p, i6
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[J_START:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[LOOP]] ]
-; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
-; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
-; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
+; CHECK-NEXT:    [[IV1:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[J_START:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
+; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_DST]], align 8
+; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[DST]], i64 [[J]]
 ; CHECK-NEXT:    [[L_J:%.*]] = load ptr, ptr [[GEP_J]], align 8
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[CMP3]] to i64
-; CHECK-NEXT:    [[J_NEXT]] = add nsw i64 [[DEC]], [[J]]
-; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
-; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
-; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
-; CHECK-NEXT:    [[EC:%.*]] = icmp eq i64 [[IV]], [[J_START]]
+; CHECK-NEXT:    [[CMP3_FROZEN:%.*]] = freeze i1 [[CMP3]]
+; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_TRUE_SINK:%.*]], label [[SELECT_END]]
+; CHECK:       select.true.sink:
+; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i64 1, [[J]]
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[SELECT_TRUE_SINK]] ], [ [[J]], [[LOOP]] ]
+; CHECK-NEXT:    [[GEP_DST1:%.*]] = getelementptr inbounds ptr, ptr [[DST1:%.*]], i64 [[IV1]]
+; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST1]], align 8
+; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV1]], 1
+; CHECK-NEXT:    [[EC:%.*]] = icmp eq i64 [[IV1]], [[J_START]]
 ; CHECK-NEXT:    br i1 [[EC]], label [[EXIT:%.*]], label [[LOOP]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret void
@@ -107,6 +113,9 @@ define void @test_add_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[J_START:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[SELECT_END]] ]
 ; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
 ; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
 ; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
@@ -114,13 +123,13 @@ define void @test_add_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[NOT_CMP3:%.*]] = xor i1 [[CMP3]], true
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[NOT_CMP3]] to i64
-; CHECK-NEXT:    [[NOT_CMP3_FROZEN:%.*]] = freeze i1 [[NOT_CMP3]]
-; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[J]], 1
-; CHECK-NEXT:    br i1 [[NOT_CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
-; CHECK:       select.false:
+; CHECK-NEXT:    [[CMP3_FROZEN:%.*]] = freeze i1 [[CMP3]]
+; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE_SINK:%.*]]
+; CHECK:       select.false.sink:
+; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i64 [[J]], 1
 ; CHECK-NEXT:    br label [[SELECT_END]]
 ; CHECK:       select.end:
-; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[J]], [[LOOP]] ], [ [[TMP0]], [[SELECT_FALSE_SINK]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
@@ -375,12 +384,12 @@ define void @test_sub_zext(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.star
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[CMP3]] to i64
 ; CHECK-NEXT:    [[CMP3_FROZEN:%.*]] = freeze i1 [[CMP3]]
-; CHECK-NEXT:    [[TMP0:%.*]] = sub i64 [[J]], 1
-; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
-; CHECK:       select.false:
+; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_TRUE_SINK:%.*]], label [[SELECT_END]]
+; CHECK:       select.true.sink:
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i64 [[J]], 1
 ; CHECK-NEXT:    br label [[SELECT_END]]
 ; CHECK:       select.end:
-; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[SELECT_TRUE_SINK]] ], [ [[J]], [[LOOP]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
@@ -418,20 +427,26 @@ define void @test_sub_zext_first_op(ptr %dst, ptr %src, i64 %j.start, i64 %p, i6
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[J_START:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[LOOP]] ]
-; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
-; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
-; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
+; CHECK-NEXT:    [[IV1:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[J_START:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
+; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_DST]], align 8
+; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[DST]], i64 [[J]]
 ; CHECK-NEXT:    [[L_J:%.*]] = load ptr, ptr [[GEP_J]], align 8
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[CMP3]] to i64
-; CHECK-NEXT:    [[J_NEXT]] = sub nsw i64 [[DEC]], [[J]]
-; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
-; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
-; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
-; CHECK-NEXT:    [[EC:%.*]] = icmp eq i64 [[IV]], [[J_START]]
+; CHECK-NEXT:    [[CMP3_FROZEN:%.*]] = freeze i1 [[CMP3]]
+; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_TRUE_SINK:%.*]], label [[SELECT_END]]
+; CHECK:       select.true.sink:
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i64 1, [[J]]
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[SELECT_TRUE_SINK]] ], [ [[J]], [[LOOP]] ]
+; CHECK-NEXT:    [[GEP_DST1:%.*]] = getelementptr inbounds ptr, ptr [[DST1:%.*]], i64 [[IV1]]
+; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST1]], align 8
+; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV1]], 1
+; CHECK-NEXT:    [[EC:%.*]] = icmp eq i64 [[IV1]], [[J_START]]
 ; CHECK-NEXT:    br i1 [[EC]], label [[EXIT:%.*]], label [[LOOP]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret void
@@ -465,6 +480,9 @@ define void @test_sub_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[J_START:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[SELECT_END]] ]
 ; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
 ; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
 ; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
@@ -472,13 +490,13 @@ define void @test_sub_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[NOT_CMP3:%.*]] = xor i1 [[CMP3]], true
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[NOT_CMP3]] to i64
-; CHECK-NEXT:    [[NOT_CMP3_FROZEN:%.*]] = freeze i1 [[NOT_CMP3]]
-; CHECK-NEXT:    [[TMP0:%.*]] = sub i64 [[J]], 1
-; CHECK-NEXT:    br i1 [[NOT_CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
-; CHECK:       select.false:
+; CHECK-NEXT:    [[CMP3_FROZEN:%.*]] = freeze i1 [[CMP3]]
+; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE_SINK:%.*]]
+; CHECK:       select.false.sink:
+; CHECK-NEXT:    [[TMP0:%.*]] = sub nsw i64 [[J]], 1
 ; CHECK-NEXT:    br label [[SELECT_END]]
 ; CHECK:       select.end:
-; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[J]], [[LOOP]] ], [ [[TMP0]], [[SELECT_FALSE_SINK]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1



More information about the llvm-commits mailing list