[llvm] [RISCV] Fold (X & -(1 << C1) & 0xffffffff) == C2 << C1 to sraiw X, C1 == C2. (PR #157617)

Piotr Fusik via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 9 00:05:15 PDT 2025


================
@@ -16761,29 +16761,36 @@ static SDValue performSETCCCombine(SDNode *N,
           combineVectorSizedSetCCEquality(VT, N0, N1, Cond, dl, DAG, Subtarget))
     return V;
 
-  if (DCI.isAfterLegalizeDAG() && isNullConstant(N1) &&
+  if (DCI.isAfterLegalizeDAG() && isa<ConstantSDNode>(N1) &&
       N0.getOpcode() == ISD::AND && N0.hasOneUse() &&
       isa<ConstantSDNode>(N0.getOperand(1))) {
-    const APInt &AndRHSC =
-        cast<ConstantSDNode>(N0.getOperand(1))->getAPIntValue();
+    const APInt &AndRHSC = N0.getConstantOperandAPInt(1);
     // (X & -(1 << C)) == 0 -> (X >> C) == 0 if the AND constant can't use ANDI.
-    if (!isInt<12>(AndRHSC.getSExtValue()) && AndRHSC.isNegatedPowerOf2()) {
+    if (isNullConstant(N1) && !isInt<12>(AndRHSC.getSExtValue()) &&
+        AndRHSC.isNegatedPowerOf2()) {
       unsigned ShiftBits = AndRHSC.countr_zero();
       SDValue Shift = DAG.getNode(ISD::SRL, dl, OpVT, N0.getOperand(0),
                                   DAG.getConstant(ShiftBits, dl, OpVT));
       return DAG.getSetCC(dl, VT, Shift, N1, Cond);
     }
 
-    // Similar to above but handling the lower 32 bits by using srliw.
-    // FIXME: Handle the case where N1 is non-zero.
+    // Similar to above but handling the lower 32 bits by using sraiw. Allow
+    // comparing with constants other than 0 if the constant can be folded into
+    // addi or xori after shifting.
     if (OpVT == MVT::i64 && AndRHSC.getZExtValue() <= 0xffffffff &&
         isPowerOf2_32(-uint32_t(AndRHSC.getZExtValue()))) {
       unsigned ShiftBits = llvm::countr_zero(AndRHSC.getZExtValue());
-      SDValue And = DAG.getNode(ISD::AND, dl, OpVT, N0.getOperand(0),
-                                DAG.getConstant(0xffffffff, dl, OpVT));
-      SDValue Shift = DAG.getNode(ISD::SRL, dl, OpVT, And,
-                                  DAG.getConstant(ShiftBits, dl, OpVT));
-      return DAG.getSetCC(dl, VT, Shift, N1, Cond);
+      int64_t NewC = SignExtend64<32>(cast<ConstantSDNode>(N1)->getSExtValue());
+      NewC >>= ShiftBits;
+      if (isInt<12>(NewC) || NewC == 2048) {
----------------
pfusik wrote:

```suggestion
      if (NewC >= -2048 && NewC <= 2048) {
```
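This is equivalent (`isInt<12>` accepts exactly the range [-2048, 2047], so adding `NewC == 2048` gives [-2048, 2048]) and is easier to read for me.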

https://github.com/llvm/llvm-project/pull/157617


More information about the llvm-commits mailing list