[llvm] [DAGCombiner] Add sra-xor-sra pattern fold (PR #166777)
guan jian via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 6 06:05:50 PST 2025
https://github.com/rez5427 created https://github.com/llvm/llvm-project/pull/166777
Add `fold (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3)`
alive2: https://alive2.llvm.org/ce/z/yxRQf9
From f075b94d10f86cd5279c901d3695b44731bc5b42 Mon Sep 17 00:00:00 2001
From: rez5427 <guanjian at stu.cdut.edu.cn>
Date: Thu, 6 Nov 2025 22:04:01 +0800
Subject: [PATCH] [DAGCombiner] add sra xor sra pattern fold
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 20 ++++++++++++
llvm/test/CodeGen/RISCV/sra-xor-sra.ll | 32 +++++++++++++++++++
2 files changed, 52 insertions(+)
create mode 100644 llvm/test/CodeGen/RISCV/sra-xor-sra.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d2ea6525e1116..7ab460bef019e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10968,6 +10968,26 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
}
}
+ // fold (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3)
+ if (N0.getOpcode() == ISD::XOR && N0.hasOneUse() &&
+ isAllOnesConstant(N0.getOperand(1))) {
+ SDValue Inner = N0.getOperand(0);
+ if (Inner.getOpcode() == ISD::SRA && N1C) {
+ if (ConstantSDNode *InnerShiftAmt = isConstOrConstSplat(Inner.getOperand(1))) {
+ APInt c1 = InnerShiftAmt->getAPIntValue();
+ APInt c2 = N1C->getAPIntValue();
+ zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+ APInt Sum = c1 + c2;
+ unsigned ShiftSum =
+ Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
+ SDValue NewShift = DAG.getNode(ISD::SRA, DL, VT, Inner.getOperand(0),
+ DAG.getConstant(ShiftSum, DL, N1.getValueType()));
+ return DAG.getNode(ISD::XOR, DL, VT, NewShift,
+ DAG.getAllOnesConstant(DL, VT));
+ }
+ }
+ }
+
// fold (sra (shl X, m), (sub result_size, n))
// -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
// result_size - n != m.
diff --git a/llvm/test/CodeGen/RISCV/sra-xor-sra.ll b/llvm/test/CodeGen/RISCV/sra-xor-sra.ll
new file mode 100644
index 0000000000000..b04f0a29d07f3
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/sra-xor-sra.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s | FileCheck %s
+
+; Test folding of: (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3)
+; Original motivating example: should merge sra+sra across xor
+define i16 @not_invert_signbit_splat_mask(i8 %x, i16 %y) {
+; CHECK-LABEL: not_invert_signbit_splat_mask:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a0, a0, 56
+; CHECK-NEXT: srai a0, a0, 62
+; CHECK-NEXT: not a0, a0
+; CHECK-NEXT: and a0, a0, a1
+; CHECK-NEXT: ret
+ %a = ashr i8 %x, 6
+ %n = xor i8 %a, -1
+ %s = sext i8 %n to i16
+ %r = and i16 %s, %y
+ ret i16 %r
+}
+
+; Edge case
+define i16 @sra_xor_sra_overflow(i8 %x, i16 %y) {
+; CHECK-LABEL: sra_xor_sra_overflow:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 0
+; CHECK-NEXT: ret
+ %a = ashr i8 %x, 10
+ %n = xor i8 %a, -1
+ %s = sext i8 %n to i16
+ %r = and i16 %s, %y
+ ret i16 %r
+}
More information about the llvm-commits
mailing list