[llvm] 47208f8 - [X86] matchAddressRecursively - support zext(and(shl(x,c1),c2)) -> shl(zext(and(x,c2 >> c1)),c1)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 12 01:49:25 PST 2023


Author: Simon Pilgrim
Date: 2023-03-12T09:49:11Z
New Revision: 47208f8d343791d47a487f0139743d607b8bb965

URL: https://github.com/llvm/llvm-project/commit/47208f8d343791d47a487f0139743d607b8bb965
DIFF: https://github.com/llvm/llvm-project/commit/47208f8d343791d47a487f0139743d607b8bb965.diff

LOG: [X86] matchAddressRecursively - support zext(and(shl(x,c1),c2)) -> shl(zext(and(x,c2 >> c1)),c1)

This came about while investigating ways to handle D145468 in a more generic manner, which involves trying harder to fold and(zext(x),c) -> zext(and(x,c)).

Alive2: https://alive2.llvm.org/ce/z/7fXtDt (generic fold)

Differential Revision: https://reviews.llvm.org/D145855

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
    llvm/test/CodeGen/X86/lea-dagdag.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index c13245ff7bea7..5e90a94819b6b 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -2489,34 +2489,60 @@ bool X86DAGToDAGISel::matchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
     // match the shift as a scale factor.
     if (AM.IndexReg.getNode() != nullptr || AM.Scale != 1)
       break;
-    if (N.getOperand(0).getOpcode() != ISD::SHL || !N.getOperand(0).hasOneUse())
+
+    // Peek through mask: zext(and(shl(x,c1),c2))
+    SDValue Src = N.getOperand(0);
+    APInt Mask = APInt::getAllOnes(Src.getScalarValueSizeInBits());
+    if (Src.getOpcode() == ISD::AND && Src.hasOneUse())
+      if (auto *MaskC = dyn_cast<ConstantSDNode>(Src.getOperand(1))) {
+        Mask = MaskC->getAPIntValue();
+        Src = Src.getOperand(0);
+      }
+
+    if (Src.getOpcode() != ISD::SHL || !Src.hasOneUse())
       break;
 
     // Give up if the shift is not a valid scale factor [1,2,3].
-    SDValue Shl = N.getOperand(0);
-    auto *ShAmtC = dyn_cast<ConstantSDNode>(Shl.getOperand(1));
-    if (!ShAmtC || ShAmtC->getZExtValue() > 3)
+    SDValue ShlSrc = Src.getOperand(0);
+    SDValue ShlAmt = Src.getOperand(1);
+    auto *ShAmtC = dyn_cast<ConstantSDNode>(ShlAmt);
+    if (!ShAmtC)
+      break;
+    unsigned ShAmtV = ShAmtC->getZExtValue();
+    if (ShAmtV > 3)
       break;
 
     // The narrow shift must only shift out zero bits (it must be 'nuw').
     // That makes it safe to widen to the destination type.
-    APInt HighZeros = APInt::getHighBitsSet(Shl.getValueSizeInBits(),
-                                            ShAmtC->getZExtValue());
-    if (!CurDAG->MaskedValueIsZero(Shl.getOperand(0), HighZeros))
+    APInt HighZeros =
+        APInt::getHighBitsSet(ShlSrc.getValueSizeInBits(), ShAmtV);
+    if (!CurDAG->MaskedValueIsZero(ShlSrc, HighZeros & Mask))
       break;
 
-    // zext (shl nuw i8 %x, C) to i32 --> shl (zext i8 %x to i32), (zext C)
+    // zext (shl nuw i8 %x, C1) to i32
+    // --> shl (zext i8 %x to i32), (zext C1)
+    // zext (and (shl nuw i8 %x, C1), C2) to i32
+    // --> shl (zext i8 (and %x, C2 >> C1) to i32), (zext C1)
+    MVT SrcVT = ShlSrc.getSimpleValueType();
     MVT VT = N.getSimpleValueType();
     SDLoc DL(N);
-    SDValue Zext = CurDAG->getNode(ISD::ZERO_EXTEND, DL, VT, Shl.getOperand(0));
-    SDValue NewShl = CurDAG->getNode(ISD::SHL, DL, VT, Zext, Shl.getOperand(1));
+
+    SDValue Res = ShlSrc;
+    if (!Mask.isAllOnes()) {
+      Res = CurDAG->getConstant(Mask.lshr(ShAmtV), DL, SrcVT);
+      insertDAGNode(*CurDAG, N, Res);
+      Res = CurDAG->getNode(ISD::AND, DL, SrcVT, ShlSrc, Res);
+      insertDAGNode(*CurDAG, N, Res);
+    }
+    SDValue Zext = CurDAG->getNode(ISD::ZERO_EXTEND, DL, VT, Res);
+    insertDAGNode(*CurDAG, N, Zext);
+    SDValue NewShl = CurDAG->getNode(ISD::SHL, DL, VT, Zext, ShlAmt);
+    insertDAGNode(*CurDAG, N, NewShl);
 
     // Convert the shift to scale factor.
-    AM.Scale = 1 << ShAmtC->getZExtValue();
+    AM.Scale = 1 << ShAmtV;
     AM.IndexReg = Zext;
 
-    insertDAGNode(*CurDAG, N, Zext);
-    insertDAGNode(*CurDAG, N, NewShl);
     CurDAG->ReplaceAllUsesWith(N, NewShl);
     CurDAG->RemoveDeadNode(N.getNode());
     return false;

diff  --git a/llvm/test/CodeGen/X86/lea-dagdag.ll b/llvm/test/CodeGen/X86/lea-dagdag.ll
index c5bf10d391add..2705bd00f5d2c 100644
--- a/llvm/test/CodeGen/X86/lea-dagdag.ll
+++ b/llvm/test/CodeGen/X86/lea-dagdag.ll
@@ -153,10 +153,9 @@ define i64 @and_i32_shl_zext_add_i64(i64 %t0, i32 %t1) {
 define i64 @shl_and_i8_zext_add_i64(i64 %t0, i8 %t1) {
 ; CHECK-LABEL: shl_and_i8_zext_add_i64:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    shlb $2, %sil
-; CHECK-NEXT:    andb $60, %sil
+; CHECK-NEXT:    andb $15, %sil
 ; CHECK-NEXT:    movzbl %sil, %eax
-; CHECK-NEXT:    addq %rdi, %rax
+; CHECK-NEXT:    leaq (%rdi,%rax,4), %rax
 ; CHECK-NEXT:    retq
   %s = shl i8 %t1, 2
   %m = and i8 %s, 60
@@ -169,9 +168,8 @@ define i64 @shl_and_i16_zext_add_i64(i64 %t0, i16 %t1) {
 ; CHECK-LABEL: shl_and_i16_zext_add_i64:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    # kill: def $esi killed $esi def $rsi
-; CHECK-NEXT:    leal (%rsi,%rsi), %eax
-; CHECK-NEXT:    andl $16, %eax
-; CHECK-NEXT:    addq %rdi, %rax
+; CHECK-NEXT:    andl $8, %esi
+; CHECK-NEXT:    leaq (%rdi,%rsi,2), %rax
 ; CHECK-NEXT:    retq
   %s = shl i16 %t1, 1
   %m = and i16 %s, 17
@@ -184,9 +182,8 @@ define i64 @shl_and_i32_zext_add_i64(i64 %t0, i32 %t1) {
 ; CHECK-LABEL: shl_and_i32_zext_add_i64:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    # kill: def $esi killed $esi def $rsi
-; CHECK-NEXT:    leal (,%rsi,8), %eax
-; CHECK-NEXT:    andl $5992, %eax # imm = 0x1768
-; CHECK-NEXT:    addq %rdi, %rax
+; CHECK-NEXT:    andl $749, %esi # imm = 0x2ED
+; CHECK-NEXT:    leaq (%rdi,%rsi,8), %rax
 ; CHECK-NEXT:    retq
   %s = shl i32 %t1, 3
   %m = and i32 %s, 5999


        


More information about the llvm-commits mailing list