[llvm] 3c8bf29 - [DAG] Move "xor (X logical_shift ShiftC), XorC --> (not X) logical_shift ShiftC" fold into SimplifyDemandedBits

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 15 05:11:17 PDT 2022


Author: Simon Pilgrim
Date: 2022-07-15T13:10:15+01:00
New Revision: 3c8bf29696c3439a6279807dc56c4afcc21b8112

URL: https://github.com/llvm/llvm-project/commit/3c8bf29696c3439a6279807dc56c4afcc21b8112
DIFF: https://github.com/llvm/llvm-project/commit/3c8bf29696c3439a6279807dc56c4afcc21b8112.diff

LOG: [DAG] Move "xor (X logical_shift ShiftC), XorC --> (not X) logical_shift ShiftC" fold into SimplifyDemandedBits

SimplifyDemandedBits is called slightly later, which allows the not(sext(x)) -> sext(not(x)) fold to occur via foldLogicOfShifts.

As mentioned on D127115, we should be able to generalise this further based on the demanded bits.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/test/CodeGen/X86/xor.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2654c00929d8..cbabdc99cf0a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8488,28 +8488,6 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
     return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
   }
 
-  if ((N0Opcode == ISD::SRL || N0Opcode == ISD::SHL) && N0.hasOneUse()) {
-    ConstantSDNode *XorC = isConstOrConstSplat(N1);
-    ConstantSDNode *ShiftC = isConstOrConstSplat(N0.getOperand(1));
-    unsigned BitWidth = VT.getScalarSizeInBits();
-    if (XorC && ShiftC) {
-      // Don't crash on an oversized shift. We can not guarantee that a bogus
-      // shift has been simplified to undef.
-      uint64_t ShiftAmt = ShiftC->getLimitedValue();
-      if (ShiftAmt < BitWidth) {
-        APInt Ones = APInt::getAllOnes(BitWidth);
-        Ones = N0Opcode == ISD::SHL ? Ones.shl(ShiftAmt) : Ones.lshr(ShiftAmt);
-        if (XorC->getAPIntValue() == Ones) {
-          // If the xor constant is a shifted -1, do a 'not' before the shift:
-          // xor (X << ShiftC), XorC --> (not X) << ShiftC
-          // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
-          SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
-          return DAG.getNode(N0Opcode, DL, VT, Not, N0.getOperand(1));
-        }
-      }
-    }
-  }
-
   // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
   if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
     SDValue A = N0Opcode == ISD::ADD ? N0 : N1;

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index a31a5fa67996..b15383a787b5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1500,7 +1500,7 @@ bool TargetLowering::SimplifyDemandedBits(
     if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1));
 
-    ConstantSDNode* C = isConstOrConstSplat(Op1, DemandedElts);
+    ConstantSDNode *C = isConstOrConstSplat(Op1, DemandedElts);
     if (C) {
       // If one side is a constant, and all of the set bits in the constant are
       // also known set on the other side, turn this into an AND, as we know
@@ -1521,6 +1521,30 @@ bool TargetLowering::SimplifyDemandedBits(
         SDValue New = TLO.DAG.getNOT(dl, Op0, VT);
         return TLO.CombineTo(Op, New);
       }
+
+      unsigned Op0Opcode = Op0.getOpcode();
+      if ((Op0Opcode == ISD::SRL || Op0Opcode == ISD::SHL) && Op0.hasOneUse()) {
+        if (ConstantSDNode *ShiftC =
+                isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
+          // Don't crash on an oversized shift. We can not guarantee that a
+          // bogus shift has been simplified to undef.
+          if (ShiftC->getAPIntValue().ult(BitWidth)) {
+            uint64_t ShiftAmt = ShiftC->getZExtValue();
+            APInt Ones = APInt::getAllOnes(BitWidth);
+            Ones = Op0Opcode == ISD::SHL ? Ones.shl(ShiftAmt)
+                                         : Ones.lshr(ShiftAmt);
+            if (C->getAPIntValue() == Ones) {
+              // If the xor constant is a shifted -1, do a 'not' before the
+              // shift:
+              // xor (X << ShiftC), XorC --> (not X) << ShiftC
+              // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
+              SDValue Not = TLO.DAG.getNOT(dl, Op0.getOperand(0), VT);
+              return TLO.CombineTo(Op, TLO.DAG.getNode(Op0Opcode, dl, VT, Not,
+                                                       Op0.getOperand(1)));
+            }
+          }
+        }
+      }
     }
 
     // If we can't turn this into a 'not', try to shrink the constant.

diff  --git a/llvm/test/CodeGen/X86/xor.ll b/llvm/test/CodeGen/X86/xor.ll
index df10837c7a54..eccae2885edb 100644
--- a/llvm/test/CodeGen/X86/xor.ll
+++ b/llvm/test/CodeGen/X86/xor.ll
@@ -464,16 +464,16 @@ define ptr @test12(ptr %op, i64 %osbot, i64 %intval) {
 ;
 ; X64-LIN-LABEL: test12:
 ; X64-LIN:       # %bb.0:
+; X64-LIN-NEXT:    notl %edx
 ; X64-LIN-NEXT:    movslq %edx, %rax
-; X64-LIN-NEXT:    notq %rax
 ; X64-LIN-NEXT:    shlq $4, %rax
 ; X64-LIN-NEXT:    addq %rdi, %rax
 ; X64-LIN-NEXT:    retq
 ;
 ; X64-WIN-LABEL: test12:
 ; X64-WIN:       # %bb.0:
+; X64-WIN-NEXT:    notl %r8d
 ; X64-WIN-NEXT:    movslq %r8d, %rax
-; X64-WIN-NEXT:    notq %rax
 ; X64-WIN-NEXT:    shlq $4, %rax
 ; X64-WIN-NEXT:    addq %rcx, %rax
 ; X64-WIN-NEXT:    retq


        


More information about the llvm-commits mailing list