[llvm] [DAGCombiner] Add sra-xor-sra pattern fold (PR #166777)

guan jian via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 6 21:16:40 PST 2025


https://github.com/rez5427 updated https://github.com/llvm/llvm-project/pull/166777

>From 1d0ea8b887ee3a8b4dd18fb07e072e0ab549a8b2 Mon Sep 17 00:00:00 2001
From: rez5427 <guanjian at stu.cdut.edu.cn>
Date: Thu, 6 Nov 2025 22:04:01 +0800
Subject: [PATCH 1/4] [DAGCombiner] add sra xor sra pattern fold

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 22 +++++++++++++
 llvm/test/CodeGen/RISCV/sra-xor-sra.ll        | 32 +++++++++++++++++++
 2 files changed, 54 insertions(+)
 create mode 100644 llvm/test/CodeGen/RISCV/sra-xor-sra.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d2ea6525e1116..b7e195d44b031 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10968,6 +10968,28 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
     }
   }
 
+  // fold (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3)
+  if (N0.getOpcode() == ISD::XOR && N0.hasOneUse() &&
+      isAllOnesConstant(N0.getOperand(1))) {
+    SDValue Inner = N0.getOperand(0);
+    if (Inner.getOpcode() == ISD::SRA && N1C) {
+      if (ConstantSDNode *InnerShiftAmt =
+              isConstOrConstSplat(Inner.getOperand(1))) {
+        APInt c1 = InnerShiftAmt->getAPIntValue();
+        APInt c2 = N1C->getAPIntValue();
+        zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+        APInt Sum = c1 + c2;
+        unsigned ShiftSum =
+            Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
+        SDValue NewShift =
+            DAG.getNode(ISD::SRA, DL, VT, Inner.getOperand(0),
+                        DAG.getConstant(ShiftSum, DL, N1.getValueType()));
+        return DAG.getNode(ISD::XOR, DL, VT, NewShift,
+                           DAG.getAllOnesConstant(DL, VT));
+      }
+    }
+  }
+
   // fold (sra (shl X, m), (sub result_size, n))
   // -> (sign_extend (trunc (shl X, (sub (sub result_size, n), m)))) for
   // result_size - n != m.
diff --git a/llvm/test/CodeGen/RISCV/sra-xor-sra.ll b/llvm/test/CodeGen/RISCV/sra-xor-sra.ll
new file mode 100644
index 0000000000000..b04f0a29d07f3
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/sra-xor-sra.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s | FileCheck %s
+
+; Test folding of: (sra (xor (sra x, c1), -1), c2) -> (xor (sra x, c3), -1)
+; with c3 = min(c1 + c2, BW - 1). Motivating example: both sras merge across the xor.
+define i16 @not_invert_signbit_splat_mask(i8 %x, i16 %y) {
+; CHECK-LABEL: not_invert_signbit_splat_mask:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a0, a0, 56
+; CHECK-NEXT:    srai a0, a0, 62
+; CHECK-NEXT:    not a0, a0
+; CHECK-NEXT:    and a0, a0, a1
+; CHECK-NEXT:    ret
+  %a = ashr i8 %x, 6
+  %n = xor i8 %a, -1
+  %s = sext i8 %n to i16
+  %r = and i16 %s, %y
+  ret i16 %r
+}
+
+; Edge case: the inner shift amount (10) is >= the i8 bit width, so %a is poison and the result folds to 0.
+define i16 @sra_xor_sra_overflow(i8 %x, i16 %y) {
+; CHECK-LABEL: sra_xor_sra_overflow:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    li a0, 0
+; CHECK-NEXT:    ret
+  %a = ashr i8 %x, 10
+  %n = xor i8 %a, -1
+  %s = sext i8 %n to i16
+  %r = and i16 %s, %y
+  ret i16 %r
+}

>From 9f2e1e381b6b0e75fe17fb924c33f1f6ec13b7a8 Mon Sep 17 00:00:00 2001
From: rez5427 <guanjian at stu.cdut.edu.cn>
Date: Thu, 6 Nov 2025 22:42:06 +0800
Subject: [PATCH 2/4] Use sd_match

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 36 +++++++++----------
 1 file changed, 16 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b7e195d44b031..e7ab0a94c34fb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10968,26 +10968,22 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
     }
   }
 
-  // fold (sra (xor (sra x, c1), -1), c2) -> (sra (xor x, -1), c3)
-  if (N0.getOpcode() == ISD::XOR && N0.hasOneUse() &&
-      isAllOnesConstant(N0.getOperand(1))) {
-    SDValue Inner = N0.getOperand(0);
-    if (Inner.getOpcode() == ISD::SRA && N1C) {
-      if (ConstantSDNode *InnerShiftAmt =
-              isConstOrConstSplat(Inner.getOperand(1))) {
-        APInt c1 = InnerShiftAmt->getAPIntValue();
-        APInt c2 = N1C->getAPIntValue();
-        zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
-        APInt Sum = c1 + c2;
-        unsigned ShiftSum =
-            Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
-        SDValue NewShift =
-            DAG.getNode(ISD::SRA, DL, VT, Inner.getOperand(0),
-                        DAG.getConstant(ShiftSum, DL, N1.getValueType()));
-        return DAG.getNode(ISD::XOR, DL, VT, NewShift,
-                           DAG.getAllOnesConstant(DL, VT));
-      }
-    }
+  // fold (sra (xor (sra x, c1), -1), c2) -> (xor (sra x, c1+c2), -1)
+  // This allows merging two arithmetic shifts even when there's a NOT in
+  // between.
+  SDValue X;
+  APInt C1, C2;
+  if (sd_match(N0, m_OneUse(m_Xor(m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1))),
+                                  m_AllOnes()))) &&
+      sd_match(N1, m_ConstInt(C2))) {
+    zeroExtendToMatch(C1, C2, 1 /* Overflow Bit */);
+    APInt Sum = C1 + C2;
+    unsigned ShiftSum =
+        Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
+    SDValue NewShift = DAG.getNode(
+        ISD::SRA, DL, VT, X, DAG.getConstant(ShiftSum, DL, N1.getValueType()));
+    return DAG.getNode(ISD::XOR, DL, VT, NewShift,
+                       DAG.getAllOnesConstant(DL, VT));
   }
 
   // fold (sra (shl X, m), (sub result_size, n))

>From 66e06e70c2f204378ef97c2f5b6e45f3428297fc Mon Sep 17 00:00:00 2001
From: rez5427 <guanjian at stu.cdut.edu.cn>
Date: Fri, 7 Nov 2025 08:31:51 +0800
Subject: [PATCH 3/4] Use m_Not

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e7ab0a94c34fb..abf790280c408 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10968,14 +10968,15 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
     }
   }
 
-  // fold (sra (xor (sra x, c1), -1), c2) -> (xor (sra x, c1+c2), -1)
+  // fold (sra (not (sra x, c1)), c2) -> (not (sra x, min(c1+c2, BW-1)))
   // This allows merging two arithmetic shifts even when there's a NOT in
   // between.
   SDValue X;
-  APInt C1, C2;
-  if (sd_match(N0, m_OneUse(m_Xor(m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1))),
-                                  m_AllOnes()))) &&
-      sd_match(N1, m_ConstInt(C2))) {
+  APInt C1;
+  if (sd_match(N0,
+               m_OneUse(m_Not(m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1)))))) &&
+      N1C) {
+    APInt C2 = N1C->getAPIntValue();
     zeroExtendToMatch(C1, C2, 1 /* Overflow Bit */);
     APInt Sum = C1 + C2;
     unsigned ShiftSum =

>From 17eeb2c25095213cc02277283e4fa943c5c14f96 Mon Sep 17 00:00:00 2001
From: rez5427 <guanjian at stu.cdut.edu.cn>
Date: Fri, 7 Nov 2025 13:16:28 +0800
Subject: [PATCH 4/4] Address code review feedback

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index abf790280c408..34d7e72f72435 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10973,18 +10973,16 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
   // between.
   SDValue X;
   APInt C1;
-  if (sd_match(N0,
-               m_OneUse(m_Not(m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1)))))) &&
-      N1C) {
+  if (N1C && sd_match(N0, m_OneUse(m_Not(
+                              m_OneUse(m_Sra(m_Value(X), m_ConstInt(C1))))))) {
     APInt C2 = N1C->getAPIntValue();
-    zeroExtendToMatch(C1, C2, 1 /* Overflow Bit */);
+    zeroExtendToMatch(C1, C2, /*OverflowBit=*/1);
     APInt Sum = C1 + C2;
     unsigned ShiftSum =
         Sum.uge(OpSizeInBits) ? (OpSizeInBits - 1) : Sum.getZExtValue();
     SDValue NewShift = DAG.getNode(
         ISD::SRA, DL, VT, X, DAG.getConstant(ShiftSum, DL, N1.getValueType()));
-    return DAG.getNode(ISD::XOR, DL, VT, NewShift,
-                       DAG.getAllOnesConstant(DL, VT));
+    return DAG.getNOT(DL, NewShift, VT);
   }
 
   // fold (sra (shl X, m), (sub result_size, n))



More information about the llvm-commits mailing list