[llvm] 7428739 - [X86] matchAddressRecursively - peek through ZEXT nodes to match foldMaskAndShiftToExtract

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 9 07:54:21 PDT 2023


Author: Simon Pilgrim
Date: 2023-07-09T15:41:38+01:00
New Revision: 7428739ea81e1508b8da92e1561574216857a897

URL: https://github.com/llvm/llvm-project/commit/7428739ea81e1508b8da92e1561574216857a897
DIFF: https://github.com/llvm/llvm-project/commit/7428739ea81e1508b8da92e1561574216857a897.diff

LOG: [X86] matchAddressRecursively - peek through ZEXT nodes to match foldMaskAndShiftToExtract

Handle (zero_extend (and (srl X, C1), C2)) patterns so that foldMaskAndShiftToExtract can match h-register extractions from smaller types.

Ideally, matchAddressRecursively should be able to recurse through ZEXT/SEXT nodes in general, but for now we just handle specific cases as they occur.

Addresses regressions in D146121

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
    llvm/test/CodeGen/X86/h-register-addressing-64.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index bedeb783bcf6d3..ca064ee284f30c 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -1932,14 +1932,16 @@ static bool foldMaskAndShiftToExtract(SelectionDAG &DAG, SDValue N,
       Mask != (0xffu << ScaleLog))
     return true;
 
+  MVT XVT = X.getSimpleValueType();
   MVT VT = N.getSimpleValueType();
   SDLoc DL(N);
   SDValue Eight = DAG.getConstant(8, DL, MVT::i8);
-  SDValue NewMask = DAG.getConstant(0xff, DL, VT);
-  SDValue Srl = DAG.getNode(ISD::SRL, DL, VT, X, Eight);
-  SDValue And = DAG.getNode(ISD::AND, DL, VT, Srl, NewMask);
+  SDValue NewMask = DAG.getConstant(0xff, DL, XVT);
+  SDValue Srl = DAG.getNode(ISD::SRL, DL, XVT, X, Eight);
+  SDValue And = DAG.getNode(ISD::AND, DL, XVT, Srl, NewMask);
   SDValue ShlCount = DAG.getConstant(ScaleLog, DL, MVT::i8);
-  SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, And, ShlCount);
+  SDValue Ext = DAG.getZExtOrTrunc(And, DL, VT);
+  SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, Ext, ShlCount);
 
   // Insert the new nodes into the topological ordering. We must do this in
   // a valid topological ordering as nothing is going to go back and re-sort
@@ -1951,10 +1953,12 @@ static bool foldMaskAndShiftToExtract(SelectionDAG &DAG, SDValue N,
   insertDAGNode(DAG, N, NewMask);
   insertDAGNode(DAG, N, And);
   insertDAGNode(DAG, N, ShlCount);
+  if (Ext != And)
+    insertDAGNode(DAG, N, Ext);
   insertDAGNode(DAG, N, Shl);
   DAG.ReplaceAllUsesWith(N, Shl);
   DAG.RemoveDeadNode(N.getNode());
-  AM.IndexReg = And;
+  AM.IndexReg = Ext;
   AM.Scale = (1 << ScaleLog);
   return false;
 }
@@ -2508,53 +2512,60 @@ bool X86DAGToDAGISel::matchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
         Src = Src.getOperand(0);
       }
 
-    if (Src.getOpcode() != ISD::SHL || !Src.hasOneUse())
-      break;
+    if (Src.getOpcode() == ISD::SHL && Src.hasOneUse()) {
+      // Give up if the shift is not a valid scale factor [1,2,3].
+      SDValue ShlSrc = Src.getOperand(0);
+      SDValue ShlAmt = Src.getOperand(1);
+      auto *ShAmtC = dyn_cast<ConstantSDNode>(ShlAmt);
+      if (!ShAmtC)
+        break;
+      unsigned ShAmtV = ShAmtC->getZExtValue();
+      if (ShAmtV > 3)
+        break;
 
-    // Give up if the shift is not a valid scale factor [1,2,3].
-    SDValue ShlSrc = Src.getOperand(0);
-    SDValue ShlAmt = Src.getOperand(1);
-    auto *ShAmtC = dyn_cast<ConstantSDNode>(ShlAmt);
-    if (!ShAmtC)
-      break;
-    unsigned ShAmtV = ShAmtC->getZExtValue();
-    if (ShAmtV > 3)
-      break;
+      // The narrow shift must only shift out zero bits (it must be 'nuw').
+      // That makes it safe to widen to the destination type.
+      APInt HighZeros =
+          APInt::getHighBitsSet(ShlSrc.getValueSizeInBits(), ShAmtV);
+      if (!CurDAG->MaskedValueIsZero(ShlSrc, HighZeros & Mask))
+        break;
 
-    // The narrow shift must only shift out zero bits (it must be 'nuw').
-    // That makes it safe to widen to the destination type.
-    APInt HighZeros =
-        APInt::getHighBitsSet(ShlSrc.getValueSizeInBits(), ShAmtV);
-    if (!CurDAG->MaskedValueIsZero(ShlSrc, HighZeros & Mask))
-      break;
+      // zext (shl nuw i8 %x, C1) to i32
+      // --> shl (zext i8 %x to i32), (zext C1)
+      // zext (and (shl nuw i8 %x, C1), C2) to i32
+      // --> shl (zext i8 (and %x, C2 >> C1) to i32), (zext C1)
+      MVT SrcVT = ShlSrc.getSimpleValueType();
+      MVT VT = N.getSimpleValueType();
+      SDLoc DL(N);
+
+      SDValue Res = ShlSrc;
+      if (!Mask.isAllOnes()) {
+        Res = CurDAG->getConstant(Mask.lshr(ShAmtV), DL, SrcVT);
+        insertDAGNode(*CurDAG, N, Res);
+        Res = CurDAG->getNode(ISD::AND, DL, SrcVT, ShlSrc, Res);
+        insertDAGNode(*CurDAG, N, Res);
+      }
+      SDValue Zext = CurDAG->getNode(ISD::ZERO_EXTEND, DL, VT, Res);
+      insertDAGNode(*CurDAG, N, Zext);
+      SDValue NewShl = CurDAG->getNode(ISD::SHL, DL, VT, Zext, ShlAmt);
+      insertDAGNode(*CurDAG, N, NewShl);
 
-    // zext (shl nuw i8 %x, C1) to i32
-    // --> shl (zext i8 %x to i32), (zext C1)
-    // zext (and (shl nuw i8 %x, C1), C2) to i32
-    // --> shl (zext i8 (and %x, C2 >> C1) to i32), (zext C1)
-    MVT SrcVT = ShlSrc.getSimpleValueType();
-    MVT VT = N.getSimpleValueType();
-    SDLoc DL(N);
-
-    SDValue Res = ShlSrc;
-    if (!Mask.isAllOnes()) {
-      Res = CurDAG->getConstant(Mask.lshr(ShAmtV), DL, SrcVT);
-      insertDAGNode(*CurDAG, N, Res);
-      Res = CurDAG->getNode(ISD::AND, DL, SrcVT, ShlSrc, Res);
-      insertDAGNode(*CurDAG, N, Res);
+      // Convert the shift to scale factor.
+      AM.Scale = 1 << ShAmtV;
+      AM.IndexReg = Zext;
+
+      CurDAG->ReplaceAllUsesWith(N, NewShl);
+      CurDAG->RemoveDeadNode(N.getNode());
+      return false;
     }
-    SDValue Zext = CurDAG->getNode(ISD::ZERO_EXTEND, DL, VT, Res);
-    insertDAGNode(*CurDAG, N, Zext);
-    SDValue NewShl = CurDAG->getNode(ISD::SHL, DL, VT, Zext, ShlAmt);
-    insertDAGNode(*CurDAG, N, NewShl);
 
-    // Convert the shift to scale factor.
-    AM.Scale = 1 << ShAmtV;
-    AM.IndexReg = Zext;
+    // Try to fold the mask and shift into an extract and scale.
+    if (Src.getOpcode() == ISD::SRL && !Mask.isAllOnes() &&
+        !foldMaskAndShiftToExtract(*CurDAG, N, Mask.getZExtValue(), Src,
+                                   Src.getOperand(0), AM))
+      return false;
 
-    CurDAG->ReplaceAllUsesWith(N, NewShl);
-    CurDAG->RemoveDeadNode(N.getNode());
-    return false;
+    break;
   }
   }
 

diff  --git a/llvm/test/CodeGen/X86/h-register-addressing-64.ll b/llvm/test/CodeGen/X86/h-register-addressing-64.ll
index e65c2c85291a15..b36a9484d8561a 100644
--- a/llvm/test/CodeGen/X86/h-register-addressing-64.ll
+++ b/llvm/test/CodeGen/X86/h-register-addressing-64.ll
@@ -104,9 +104,8 @@ define i8 @bar2(ptr nocapture inreg %p, i64 inreg %x) nounwind readonly {
 define double @ext8(ptr nocapture inreg %p, i32 inreg %x) nounwind readonly {
 ; CHECK-LABEL: ext8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    # kill: def $esi killed $esi def $rsi
-; CHECK-NEXT:    shrl $5, %esi
-; CHECK-NEXT:    andl $2040, %esi # imm = 0x7F8
+; CHECK-NEXT:    movl %esi, %eax
+; CHECK-NEXT:    movzbl %ah, %eax
 ; CHECK-NEXT:    movsd {{.*#+}} xmm0 = mem[0],zero
 ; CHECK-NEXT:    retq
   %t0 = lshr i32 %x, 5
@@ -120,9 +119,8 @@ define double @ext8(ptr nocapture inreg %p, i32 inreg %x) nounwind readonly {
 define float @ext4(ptr nocapture inreg %p, i32 inreg %x) nounwind readonly {
 ; CHECK-LABEL: ext4:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    # kill: def $esi killed $esi def $rsi
-; CHECK-NEXT:    shrl $6, %esi
-; CHECK-NEXT:    andl $1020, %esi # imm = 0x3FC
+; CHECK-NEXT:    movl %esi, %eax
+; CHECK-NEXT:    movzbl %ah, %eax
 ; CHECK-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
 ; CHECK-NEXT:    retq
   %t0 = lshr i32 %x, 6
@@ -136,10 +134,9 @@ define float @ext4(ptr nocapture inreg %p, i32 inreg %x) nounwind readonly {
 define i8 @ext2(ptr nocapture inreg %p, i32 inreg %x) nounwind readonly {
 ; CHECK-LABEL: ext2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    # kill: def $esi killed $esi def $rsi
-; CHECK-NEXT:    shrl $7, %esi
-; CHECK-NEXT:    andl $510, %esi # imm = 0x1FE
-; CHECK-NEXT:    movzbl (%rdi,%rsi), %eax
+; CHECK-NEXT:    movl %esi, %eax
+; CHECK-NEXT:    movzbl %ah, %eax
+; CHECK-NEXT:    movzbl (%rdi,%rax,2), %eax
 ; CHECK-NEXT:    retq
   %t0 = lshr i32 %x, 7
   %t1 = and i32 %t0, 510


        


More information about the llvm-commits mailing list