[llvm] [SelectOpt] Support ADD and SUB with zext operands. (PR #115489)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 06:18:22 PST 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/115489

>From 7e5ca4eafa3caebaec20ce8d36dfec6c9c70d129 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 16 Oct 2024 19:26:59 +0100
Subject: [PATCH] [SelectOpt] Support add and sub with zext operands.

Extend the support for implicit selects in the form of OR with a
ZExt operand to support ADD and SUB binops as well. They can similarly
form implicit selects, which can be profitable to convert back to
branches.
---
 llvm/lib/CodeGen/SelectOptimize.cpp           | 44 ++++++++++-----
 .../AArch64/AArch64TargetTransformInfo.cpp    | 22 +++++---
 llvm/test/CodeGen/AArch64/selectopt-cast.ll   | 56 +++++++++++++------
 3 files changed, 84 insertions(+), 38 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectOptimize.cpp b/llvm/lib/CodeGen/SelectOptimize.cpp
index 9587534d1170fc..ce3c90d43391d9 100644
--- a/llvm/lib/CodeGen/SelectOptimize.cpp
+++ b/llvm/lib/CodeGen/SelectOptimize.cpp
@@ -145,11 +145,20 @@ class SelectOptimizeImpl {
 
       // An Or(zext(i1 X), Y) can also be treated like a select, with condition
       // C and values Y|1 and Y.
-      Value *X;
-      if (PatternMatch::match(
-              I, m_c_Or(m_OneUse(m_ZExt(m_Value(X))), m_Value())) &&
-          X->getType()->isIntegerTy(1))
-        return SelectLike(I);
+      switch (I->getOpcode()) {
+      case Instruction::Add:
+      case Instruction::Or:
+      case Instruction::Sub: {
+        Value *X;
+        if ((PatternMatch::match(I->getOperand(0),
+                                 m_OneUse(m_ZExt(m_Value(X)))) ||
+             PatternMatch::match(I->getOperand(1),
+                                 m_OneUse(m_ZExt(m_Value(X))))) &&
+            X->getType()->isIntegerTy(1))
+          return SelectLike(I);
+        break;
+      }
+      }
 
       return SelectLike(nullptr);
     }
@@ -250,19 +259,22 @@ class SelectOptimizeImpl {
                                          : Scaled64::getZero();
         }
 
-      // Or case - add the cost of an extra Or to the cost of the False case.
-      if (isa<BinaryOperator>(I))
-        if (auto I = dyn_cast<Instruction>(getFalseValue())) {
+      // BinaryOp case - add the cost of an extra BinOp to the cost of the False
+      // case.
+      if (isa<BinaryOperator>(I)) {
+        if (auto OpI = dyn_cast<Instruction>(getFalseValue())) {
           auto It = InstCostMap.find(I);
           if (It != InstCostMap.end()) {
             InstructionCost OrCost = TTI->getArithmeticInstrCost(
-                Instruction::Or, I->getType(), TargetTransformInfo::TCK_Latency,
+                I->getOpcode(), OpI->getType(),
+                TargetTransformInfo::TCK_Latency,
                 {TargetTransformInfo::OK_AnyValue,
                  TargetTransformInfo::OP_None},
                 {TTI::OK_UniformConstantValue, TTI::OP_PowerOf2});
             return It->second.NonPredCost + Scaled64::get(*OrCost.getValue());
           }
         }
+      }
 
       return Scaled64::getZero();
     }
@@ -548,12 +560,16 @@ getTrueOrFalseValue(SelectOptimizeImpl::SelectLike SI, bool isTrue,
       V = (!isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue());
   }
 
-  if (isa<BinaryOperator>(SI.getI())) {
-    assert(SI.getI()->getOpcode() == Instruction::Or &&
-           "Only currently handling Or instructions.");
+  if (auto *BinOp = dyn_cast<BinaryOperator>(SI.getI())) {
+    assert((BinOp->getOpcode() == Instruction::Add ||
+            BinOp->getOpcode() == Instruction::Or ||
+            BinOp->getOpcode() == Instruction::Sub) &&
+           "Only currently handling Add, Or and Sub instructions.");
     V = SI.getFalseValue();
-    if (isTrue)
-      V = IB.CreateOr(V, ConstantInt::get(V->getType(), 1));
+    if (isTrue) {
+      Constant *CI = ConstantInt::get(V->getType(), 1);
+      V = IB.CreateBinOp(BinOp->getOpcode(), V, CI);
+    }
   }
 
   assert(V && "Failed to get select true/false value");
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 71f9bbbbc35041..75eea291896ef7 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -4678,14 +4678,20 @@ AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
 }
 
 bool AArch64TTIImpl::shouldTreatInstructionLikeSelect(const Instruction *I) {
-  // For the binary operators (e.g. or) we need to be more careful than
-  // selects, here we only transform them if they are already at a natural
-  // break point in the code - the end of a block with an unconditional
-  // terminator.
-  if (EnableOrLikeSelectOpt && I->getOpcode() == Instruction::Or &&
-      isa<BranchInst>(I->getNextNode()) &&
-      cast<BranchInst>(I->getNextNode())->isUnconditional())
-    return true;
+  if (EnableOrLikeSelectOpt) {
+    // For the binary operators (e.g. or) we need to be more careful than
+    // selects, here we only transform them if they are already at a natural
+    // break point in the code - the end of a block with an unconditional
+    // terminator.
+    if (I->getOpcode() == Instruction::Or &&
+        isa<BranchInst>(I->getNextNode()) &&
+        cast<BranchInst>(I->getNextNode())->isUnconditional())
+      return true;
+
+    if (I->getOpcode() == Instruction::Add ||
+        I->getOpcode() == Instruction::Sub)
+      return true;
+  }
   return BaseT::shouldTreatInstructionLikeSelect(I);
 }
 
diff --git a/llvm/test/CodeGen/AArch64/selectopt-cast.ll b/llvm/test/CodeGen/AArch64/selectopt-cast.ll
index b1804a02ec1c82..3dfb72b299c2ea 100644
--- a/llvm/test/CodeGen/AArch64/selectopt-cast.ll
+++ b/llvm/test/CodeGen/AArch64/selectopt-cast.ll
@@ -7,16 +7,22 @@ define void @test_add_zext(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.star
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[LOOP]] ]
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[SELECT_END]] ]
 ; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
 ; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
 ; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
 ; CHECK-NEXT:    [[L_J:%.*]] = load ptr, ptr [[GEP_J]], align 8
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[CMP3]] to i64
-; CHECK-NEXT:    [[J_NEXT]] = add nsw i64 [[J]], [[DEC]]
+; CHECK-NEXT:    [[CMP3_FROZEN:%.*]] = freeze i1 [[CMP3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[J]], 1
+; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
+; CHECK:       select.false:
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
@@ -54,9 +60,9 @@ define void @test_add_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[LOOP]] ]
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[SELECT_END]] ]
 ; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
 ; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
 ; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
@@ -64,7 +70,13 @@ define void @test_add_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[NOT_CMP3:%.*]] = xor i1 [[CMP3]], true
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[NOT_CMP3]] to i64
-; CHECK-NEXT:    [[J_NEXT]] = add nsw i64 [[J]], [[DEC]]
+; CHECK-NEXT:    [[NOT_CMP3_FROZEN:%.*]] = freeze i1 [[NOT_CMP3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[J]], 1
+; CHECK-NEXT:    br i1 [[NOT_CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
+; CHECK:       select.false:
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
@@ -308,16 +320,22 @@ define void @test_sub_zext(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.star
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[LOOP]] ]
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[SELECT_END]] ]
 ; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
 ; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
 ; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
 ; CHECK-NEXT:    [[L_J:%.*]] = load ptr, ptr [[GEP_J]], align 8
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[CMP3]] to i64
-; CHECK-NEXT:    [[J_NEXT]] = sub nsw i64 [[J]], [[DEC]]
+; CHECK-NEXT:    [[CMP3_FROZEN:%.*]] = freeze i1 [[CMP3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = sub i64 [[J]], 1
+; CHECK-NEXT:    br i1 [[CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
+; CHECK:       select.false:
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1
@@ -355,9 +373,9 @@ define void @test_sub_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
-; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[LOOP]] ]
+; CHECK-NEXT:    [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[SELECT_END:%.*]] ]
+; CHECK-NEXT:    [[J:%.*]] = phi i64 [ [[HIGH:%.*]], [[ENTRY]] ], [ [[J_NEXT:%.*]], [[SELECT_END]] ]
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_START:%.*]], [[ENTRY]] ], [ [[J_NEXT]], [[SELECT_END]] ]
 ; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr inbounds ptr, ptr [[SRC:%.*]], i64 [[I]]
 ; CHECK-NEXT:    [[L_I:%.*]] = load ptr, ptr [[GEP_I]], align 8
 ; CHECK-NEXT:    [[GEP_J:%.*]] = getelementptr inbounds ptr, ptr [[SRC]], i64 [[J]]
@@ -365,7 +383,13 @@ define void @test_sub_zext_not(ptr %dst, ptr %src, i64 %j.start, i64 %p, i64 %i.
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult ptr [[L_I]], [[L_J]]
 ; CHECK-NEXT:    [[NOT_CMP3:%.*]] = xor i1 [[CMP3]], true
 ; CHECK-NEXT:    [[DEC:%.*]] = zext i1 [[NOT_CMP3]] to i64
-; CHECK-NEXT:    [[J_NEXT]] = sub nsw i64 [[J]], [[DEC]]
+; CHECK-NEXT:    [[NOT_CMP3_FROZEN:%.*]] = freeze i1 [[NOT_CMP3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = sub i64 [[J]], 1
+; CHECK-NEXT:    br i1 [[NOT_CMP3_FROZEN]], label [[SELECT_END]], label [[SELECT_FALSE:%.*]]
+; CHECK:       select.false:
+; CHECK-NEXT:    br label [[SELECT_END]]
+; CHECK:       select.end:
+; CHECK-NEXT:    [[J_NEXT]] = phi i64 [ [[TMP0]], [[LOOP]] ], [ [[J]], [[SELECT_FALSE]] ]
 ; CHECK-NEXT:    [[GEP_DST:%.*]] = getelementptr inbounds ptr, ptr [[DST:%.*]], i64 [[IV]]
 ; CHECK-NEXT:    store i64 [[J_NEXT]], ptr [[GEP_DST]], align 8
 ; CHECK-NEXT:    [[IV_NEXT]] = add i64 [[IV]], 1



More information about the llvm-commits mailing list