[llvm] Combine (X ^ Y) and (X == Y) where appropriate (PR #130922)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 12 00:47:55 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Ryan Buchner (bababuck)

<details>
<summary>Changes</summary>

Fixes #<!-- -->130510.

In the RISC-V backend, modify the fold of ((X ^ Y) == 0) -> (X == Y) to account for cases where the result of (X ^ Y) is reused.

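When the XOR result is reused and the XOR uses a constant, keep the XOR followed by a comparison against 0 only if the constant fits in the signed 12-bit immediate field of XORI (-2048 to 2047). Otherwise, the constant has to be materialized in a register anyway, so comparing for equality directly is more efficient than comparing the XOR result against 0, as the following shows: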
```
# %bb.0:
        lui     a1, 5
        addiw   a1, a1, 1365
        xor     a0, a0, a1
        beqz    a0, .LBB0_2
# %bb.1: 
        ret
.LBB0_2: 
```

```
# %bb.0:
        lui     a1, 5
        addiw   a1, a1, 1365
        beq    a0, a1, .LBB0_2
# %bb.1: 
        xor     a0, a0, a1
        ret
.LBB0_2: 
```

Similarly, if the XOR is between 1 and a value known to be 0 or 1 (a single-bit value), the XOR should still be folded away, because the resulting comparison can be lowered as a comparison against 0.
```
# %bb.0:
        slt a0, a0, a1
        xor  a0, a0, 1
        beqz    a0, .LBB0_2
# %bb.1: 
        ret
.LBB0_2: 
```

```
# %bb.0:
        slt a0, a0, a1
        bnez    a0, .LBB0_2
# %bb.1: 
        xor  a0, a0, 1
        ret
.LBB0_2: 
```

One open question about my code: I hard-coded the width of a RISC-V ALU immediate. Is there a way to obtain this from the `context`? I was unable to find one.

---
Full diff: https://github.com/llvm/llvm-project/pull/130922.diff


3 Files Affected:

- (modified) llvm/lib/CodeGen/CodeGenPrepare.cpp (+2-1) 
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+40-1) 
- (modified) llvm/test/CodeGen/RISCV/select-constant-xor.ll (+40) 


``````````diff
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index d5fbd4c380746..2acb7cb321d07 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -8578,7 +8578,9 @@ static bool optimizeBranch(BranchInst *Branch, const TargetLowering &TLI,
     }
     if (Cmp->isEquality() &&
         (match(UI, m_Add(m_Specific(X), m_SpecificInt(-CmpC))) ||
-         match(UI, m_Sub(m_Specific(X), m_SpecificInt(CmpC))))) {
+         match(UI, m_Sub(m_Specific(X), m_SpecificInt(CmpC))) ||
+         // Like X - C, X ^ C is zero exactly when X == C.
+         match(UI, m_Xor(m_Specific(X), m_SpecificInt(CmpC))))) {
       IRBuilder<> Builder(Branch);
       if (UI->getParent() != Branch->getParent())
         UI->moveBefore(Branch->getIterator());
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 27a4bbce1f5fc..3abc835376f7b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -17194,8 +17194,39 @@ static bool combine_CC(SDValue &LHS, SDValue &RHS, SDValue &CC, const SDLoc &DL,
     return true;
   }
 
+  // Returns true if Op is a constant that fits in the signed 12-bit
+  // immediate field of XORI, so (xor X, Op) is a single instruction.
+  auto Is12BitConstant = [](const SDValue &Op) -> bool {
+    if (const auto *C = dyn_cast<ConstantSDNode>(Op))
+      return isInt<12>(C->getSExtValue());
+    return false;
+  };
+  // Returns true for (xor X, 1) where X is known to be 0 or 1. Then
+  // ((xor X, 1) == 0) is just (X != 0), so the XOR is always worth folding.
+  auto SingleBitOp = [&DAG](const SDValue &VarOp,
+                            const SDValue &ConstOp) -> bool {
+    if (!isOneConstant(ConstOp))
+      return false;
+    const APInt Mask = APInt::getBitsSetFrom(VarOp.getValueSizeInBits(), 1);
+    return DAG.MaskedValueIsZero(VarOp, Mask);
+  };
+  // Returns true if every user of Op is a SELECT_CC or BR_CC, i.e. the XOR
+  // value is not needed anywhere else once the comparison is folded.
+  auto OnlyUsedBySelectOrBR = [](const SDValue &Op) -> bool {
+    return llvm::all_of(Op->uses(), [](const SDUse &Use) {
+      const unsigned Opcode = Use.getUser()->getOpcode();
+      return Opcode == RISCVISD::SELECT_CC || Opcode == RISCVISD::BR_CC;
+    });
+  };
+
   // Fold ((xor X, Y), 0, eq/ne) -> (X, Y, eq/ne)
-  if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS)) {
+  // Keep the XOR only if its result is reused and its constant fits in XORI:
+  // the XORI has to be emitted anyway, so comparing its result against zero
+  // is free. A larger constant must be materialized in a register either
+  // way, so comparing X against it directly is no worse.
+  if (LHS.getOpcode() == ISD::XOR && isNullConstant(RHS) &&
+      (OnlyUsedBySelectOrBR(LHS) || !Is12BitConstant(LHS.getOperand(1)) ||
+       SingleBitOp(LHS.getOperand(0), LHS.getOperand(1)))) {
     RHS = LHS.getOperand(1);
     LHS = LHS.getOperand(0);
     return true;
diff --git a/llvm/test/CodeGen/RISCV/select-constant-xor.ll b/llvm/test/CodeGen/RISCV/select-constant-xor.ll
index 2e26ae78e2dd8..f24e03ecd7d67 100644
--- a/llvm/test/CodeGen/RISCV/select-constant-xor.ll
+++ b/llvm/test/CodeGen/RISCV/select-constant-xor.ll
@@ -239,3 +239,45 @@ define i32 @oneusecmp(i32 %a, i32 %b, i32 %d) {
   %x = add i32 %s, %s2
   ret i32 %x
 }
+
+; The XOR result is reused on the fall-through path and -1365 fits in XORI's
+; 12-bit immediate, so the XORI + BEQZ sequence is kept instead of BEQ.
+define i32 @xor_branch_ret(i32 %x) {
+; RV32-LABEL: xor_branch_ret:
+; RV32:       # %bb.0: # %entry
+; RV32-NEXT:    xori a0, a0, -1365
+; RV32-NEXT:    beqz a0, .LBB11_2
+; RV32-NEXT:  # %bb.1: # %if.then
+; RV32-NEXT:    ret
+; RV32-NEXT:  .LBB11_2: # %if.end
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    call abort
+;
+; RV64-LABEL: xor_branch_ret:
+; RV64:       # %bb.0: # %entry
+; RV64-NEXT:    xori a0, a0, -1365
+; RV64-NEXT:    sext.w a1, a0
+; RV64-NEXT:    beqz a1, .LBB11_2
+; RV64-NEXT:  # %bb.1: # %if.then
+; RV64-NEXT:    ret
+; RV64-NEXT:  .LBB11_2: # %if.end
+; RV64-NEXT:    addi sp, sp, -16
+; RV64-NEXT:    .cfi_def_cfa_offset 16
+; RV64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    call abort
+entry:
+  %cmp.not = icmp eq i32 %x, -1365
+  br i1 %cmp.not, label %if.end, label %if.then
+if.then:
+  %xor = xor i32 %x, -1365
+  ret i32 %xor
+if.end:
+  tail call void @abort()
+  unreachable
+}
+
+declare void @abort()

``````````

</details>


https://github.com/llvm/llvm-project/pull/130922


More information about the llvm-commits mailing list