[llvm] [ARM][TargetLowering] Combine Level should not be a factor in shouldFoldConstantShiftPairToMask (PR #156949)

via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 4 11:56:27 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-arm

@llvm/pr-subscribers-backend-mips

Author: AZero13 (AZero13)

<details>
<summary>Changes</summary>

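Whether `shouldFoldConstantShiftPairToMask` allows the fold should depend on the value type and the instructions the target has available, not on the DAG combine level. Only the Thumb1 path in the ARM backend consulted the combine level anyway, so the `CombineLevel` parameter is removed from the hook and all of its overrides.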

---
Full diff: https://github.com/llvm/llvm-project/pull/156949.diff


10 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+1-2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+2-2) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+1-1) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+1-2) 
- (modified) llvm/lib/Target/ARM/ARMISelLowering.cpp (+3-2) 
- (modified) llvm/lib/Target/ARM/ARMISelLowering.h (+1-2) 
- (modified) llvm/lib/Target/Mips/MipsISelLowering.cpp (+1-1) 
- (modified) llvm/lib/Target/Mips/MipsISelLowering.h (+1-2) 
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+2-2) 
- (modified) llvm/lib/Target/X86/X86ISelLowering.h (+1-2) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 438b6ff55c85f..119209fd0a6fb 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -854,8 +854,7 @@ class LLVM_ABI TargetLoweringBase {
   /// This is usually true on most targets. But some targets, like Thumb1,
   /// have immediate shift instructions, but no immediate "and" instruction;
   /// this makes the fold unprofitable.
-  virtual bool shouldFoldConstantShiftPairToMask(const SDNode *N,
-                                                 CombineLevel Level) const {
+  virtual bool shouldFoldConstantShiftPairToMask(const SDNode *N) const {
     return true;
   }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bed3c42473e27..ec5ce31fc8777 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10521,7 +10521,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
     // folding this will increase the total number of instructions.
     if (N0.getOpcode() == ISD::SRL &&
         (N0.getOperand(1) == N1 || N0.hasOneUse()) &&
-        TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
+        TLI.shouldFoldConstantShiftPairToMask(N)) {
       if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
                                     /*AllowUndefs*/ false,
                                     /*AllowTypeMismatch*/ true)) {
@@ -11100,7 +11100,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
     // fold (srl (shl x, c1), c2) -> (and (shl x, (sub c1, c2), MASK) or
     //                               (and (srl x, (sub c2, c1), MASK)
     if ((N0.getOperand(1) == N1 || N0->hasOneUse()) &&
-        TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
+        TLI.shouldFoldConstantShiftPairToMask(N)) {
       auto MatchShiftAmount = [OpSizeInBits](ConstantSDNode *LHS,
                                              ConstantSDNode *RHS) {
         const APInt &LHSC = LHS->getAPIntValue();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b7011e0ea1669..77bc780295a88 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18280,7 +18280,7 @@ bool AArch64TargetLowering::isDesirableToCommuteXorWithShift(
 }
 
 bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask(
-    const SDNode *N, CombineLevel Level) const {
+    const SDNode *N) const {
   assert(((N->getOpcode() == ISD::SHL &&
            N->getOperand(0).getOpcode() == ISD::SRL) ||
           (N->getOpcode() == ISD::SRL &&
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 46738365080f9..6ad7c6f2a5975 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -297,8 +297,7 @@ class AArch64TargetLowering : public TargetLowering {
   bool isDesirableToCommuteXorWithShift(const SDNode *N) const override;
 
   /// Return true if it is profitable to fold a pair of shifts into a mask.
-  bool shouldFoldConstantShiftPairToMask(const SDNode *N,
-                                         CombineLevel Level) const override;
+  bool shouldFoldConstantShiftPairToMask(const SDNode *N) const override;
 
   bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
                                             unsigned SelectOpcode, SDValue X,
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index b5c01eafcf108..cb39a40f88d31 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -13878,7 +13878,7 @@ bool ARMTargetLowering::isDesirableToCommuteXorWithShift(
 }
 
 bool ARMTargetLowering::shouldFoldConstantShiftPairToMask(
-    const SDNode *N, CombineLevel Level) const {
+    const SDNode *N) const {
   assert(((N->getOpcode() == ISD::SHL &&
            N->getOperand(0).getOpcode() == ISD::SRL) ||
           (N->getOpcode() == ISD::SRL &&
@@ -13888,7 +13888,8 @@ bool ARMTargetLowering::shouldFoldConstantShiftPairToMask(
   if (!Subtarget->isThumb1Only())
     return true;
 
-  if (Level == BeforeLegalizeTypes)
+  EVT VT = N->getValueType(0);
+  if (VT.getScalarSizeInBits() > 32)
     return true;
 
   return false;
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
index 196ecb1b9f678..fbf1d9b48bb4a 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -770,8 +770,7 @@ class VectorType;
 
     bool isDesirableToCommuteXorWithShift(const SDNode *N) const override;
 
-    bool shouldFoldConstantShiftPairToMask(const SDNode *N,
-                                           CombineLevel Level) const override;
+    bool shouldFoldConstantShiftPairToMask(const SDNode *N) const override;
 
     bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
                                               unsigned SelectOpcode, SDValue X,
diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
index 1491300e37d3e..956625487be99 100644
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -1306,7 +1306,7 @@ bool MipsTargetLowering::hasBitTest(SDValue X, SDValue Y) const {
 }
 
 bool MipsTargetLowering::shouldFoldConstantShiftPairToMask(
-    const SDNode *N, CombineLevel Level) const {
+    const SDNode *N) const {
   assert(((N->getOpcode() == ISD::SHL &&
            N->getOperand(0).getOpcode() == ISD::SRL) ||
           (N->getOpcode() == ISD::SRL &&
diff --git a/llvm/lib/Target/Mips/MipsISelLowering.h b/llvm/lib/Target/Mips/MipsISelLowering.h
index c65c76ccffc75..25a0bf9b797d5 100644
--- a/llvm/lib/Target/Mips/MipsISelLowering.h
+++ b/llvm/lib/Target/Mips/MipsISelLowering.h
@@ -290,8 +290,7 @@ class TargetRegisterClass;
     bool isCheapToSpeculateCttz(Type *Ty) const override;
     bool isCheapToSpeculateCtlz(Type *Ty) const override;
     bool hasBitTest(SDValue X, SDValue Y) const override;
-    bool shouldFoldConstantShiftPairToMask(const SDNode *N,
-                                           CombineLevel Level) const override;
+    bool shouldFoldConstantShiftPairToMask(const SDNode *N) const override;
 
     /// Return the register type for a given MVT, ensuring vectors are treated
     /// as a series of gpr sized integers.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 08ae0d52d795e..d69fd099df875 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3638,7 +3638,7 @@ bool X86TargetLowering::preferScalarizeSplat(SDNode *N) const {
 }
 
 bool X86TargetLowering::shouldFoldConstantShiftPairToMask(
-    const SDNode *N, CombineLevel Level) const {
+    const SDNode *N) const {
   assert(((N->getOpcode() == ISD::SHL &&
            N->getOperand(0).getOpcode() == ISD::SRL) ||
           (N->getOpcode() == ISD::SRL &&
@@ -3653,7 +3653,7 @@ bool X86TargetLowering::shouldFoldConstantShiftPairToMask(
     // the fold for non-splats yet.
     return N->getOperand(1) == N->getOperand(0).getOperand(1);
   }
-  return TargetLoweringBase::shouldFoldConstantShiftPairToMask(N, Level);
+  return TargetLoweringBase::shouldFoldConstantShiftPairToMask(N);
 }
 
 bool X86TargetLowering::shouldFoldMaskToVariableShiftPair(SDValue Y) const {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index d888f9f593ee7..2fcf33938a596 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1240,8 +1240,7 @@ namespace llvm {
     getJumpConditionMergingParams(Instruction::BinaryOps Opc, const Value *Lhs,
                                   const Value *Rhs) const override;
 
-    bool shouldFoldConstantShiftPairToMask(const SDNode *N,
-                                           CombineLevel Level) const override;
+    bool shouldFoldConstantShiftPairToMask(const SDNode *N) const override;
 
     bool shouldFoldMaskToVariableShiftPair(SDValue Y) const override;
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/156949


More information about the llvm-commits mailing list