[llvm-commits] [llvm] r147936 - in /llvm/trunk: lib/CodeGen/SelectionDAG/DAGCombiner.cpp lib/Target/X86/X86ISelDAGToDAG.cpp test/CodeGen/X86/fold-and-shift.ll

Chandler Carruth chandlerc at gmail.com
Wed Jan 11 00:41:09 PST 2012


Author: chandlerc
Date: Wed Jan 11 02:41:08 2012
New Revision: 147936

URL: http://llvm.org/viewvc/llvm-project?rev=147936&view=rev
Log:
Teach the X86 instruction selection to do some heroic transforms to
detect a pattern which can be implemented with a small 'shl' embedded in
the addressing mode scale. This happens in real code as follows:

  unsigned x = my_accelerator_table[input >> 11];

Here we have some lookup table that we index into using the high bits of
'input'. Each entry in the table is 4 bytes, which means this
implicitly gets turned into (once lowered out of a GEP):

  *(unsigned*)((char*)my_accelerator_table + ((input >> 11) << 2));

The shift right followed by a shift left is canonicalized to a smaller
shift right and masking off the low bits. That hides the shift left,
which x86 has an addressing mode scale designed to support. We now detect
masks of this form, and produce the longer shift right followed by the
proper addressing mode. In addition to saving a (rather large)
instruction, this also reduces stalls in Intel chips on benchmarks I've
measured.

In order for all of this to work, one part of the DAG needs to be
canonicalized *still further* than it currently is. This involves
removing pointless 'trunc' nodes between a zextload and a zext. Without
that, we end up generating spurious masks and hiding the pattern.

Modified:
    llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp
    llvm/trunk/test/CodeGen/X86/fold-and-shift.ll

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp?rev=147936&r1=147935&r2=147936&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Wed Jan 11 02:41:08 2012
@@ -4254,6 +4254,29 @@
     return DAG.getNode(ISD::ZERO_EXTEND, N->getDebugLoc(), VT,
                        N0.getOperand(0));
 
+  // fold (zext (truncate x)) -> (zext x) or
+  //      (zext (truncate x)) -> (truncate x)
+  // This is valid when the truncated bits of x are already zero.
+  // FIXME: We should extend this to work for vectors too.
+  if (N0.getOpcode() == ISD::TRUNCATE && !VT.isVector()) {
+    SDValue Op = N0.getOperand(0);
+    APInt TruncatedBits
+      = APInt::getBitsSet(Op.getValueSizeInBits(),
+                          N0.getValueSizeInBits(),
+                          std::min(Op.getValueSizeInBits(),
+                                   VT.getSizeInBits()));
+    APInt KnownZero, KnownOne;
+    DAG.ComputeMaskedBits(Op, TruncatedBits, KnownZero, KnownOne);
+    if (TruncatedBits == KnownZero) {
+      if (VT.bitsGT(Op.getValueType()))
+        return DAG.getNode(ISD::ZERO_EXTEND, N->getDebugLoc(), VT, Op);
+      if (VT.bitsLT(Op.getValueType()))
+        return DAG.getNode(ISD::TRUNCATE, N->getDebugLoc(), VT, Op);
+
+      return Op;
+    }
+  }
+
   // fold (zext (truncate (load x))) -> (zext (smaller load x))
   // fold (zext (truncate (srl (load x), c))) -> (zext (small load (x+c/n)))
   if (N0.getOpcode() == ISD::TRUNCATE) {

Modified: llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp?rev=147936&r1=147935&r2=147936&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelDAGToDAG.cpp Wed Jan 11 02:41:08 2012
@@ -725,6 +725,140 @@
   return false;
 }
 
+// Detect shifts of masked values where the mask can be replaced by extending
+// the shift and undoing that in the addressing mode scale. Patterns such as
+// (shl (srl x, c1), c2) are canonicalized into (and (srl x, SHIFT), MASK) by
+// DAGCombines that don't know the shl can be done in the addressing mode.
+// Returns false (and sets AM.IndexReg / AM.Scale) on success, true otherwise.
+// With this fold, code such as:
+//   int f(unsigned short *y, int *lookup_table) {
+//     ...
+//     return *y + lookup_table[*y >> 11];
+//   }
+//
+// is selected as:
+//   movzwl (%rdi), %eax
+//   movl %eax, %ecx
+//   shrl $11, %ecx
+//   addl (%rsi,%rcx,4), %eax
+//
+// instead of:
+//   movzwl (%rdi), %eax
+//   movl %eax, %ecx
+//   shrl $9, %ecx
+//   andl $124, %ecx
+//   addl (%rsi,%rcx), %eax
+//
+static bool FoldMaskAndShiftToScale(SelectionDAG &DAG, SDValue N,
+                                    X86ISelAddressMode &AM) {
+  // Bail out if an index register or a non-unit scale is already in use.
+  if (AM.IndexReg.getNode() != 0 || AM.Scale != 1) return true;
+
+  SDValue Shift = N;
+  SDValue And = N.getOperand(0);
+  if (N.getOpcode() != ISD::SRL)
+    std::swap(Shift, And);
+  if (Shift.getOpcode() != ISD::SRL || And.getOpcode() != ISD::AND ||
+      !Shift.hasOneUse() ||
+      !isa<ConstantSDNode>(Shift.getOperand(1)) ||
+      !isa<ConstantSDNode>(And.getOperand(1)))
+    return true;
+  SDValue X = (N == Shift ? And.getOperand(0) : Shift.getOperand(0));
+
+  // We only handle up to 64-bit values here as those are what matter for
+  // addressing mode optimizations.
+  if (X.getValueSizeInBits() > 64) return true;
+
+  uint64_t Mask = And.getConstantOperandVal(1);
+  unsigned ShiftAmt = Shift.getConstantOperandVal(1);
+  unsigned MaskLZ = CountLeadingZeros_64(Mask);
+  unsigned MaskTZ = CountTrailingZeros_64(Mask);
+
+  // The amount of shift we're trying to fit into the addressing mode is taken
+  // from the trailing zeros of the mask. If the mask is pre-shift, we subtract
+  // the shift amount.
+  int AMShiftAmt = MaskTZ - (N == Shift ? ShiftAmt : 0);
+
+  // There is nothing we can do here unless the mask is removing some bits.
+  // Also, the addressing mode can only represent shifts of 1, 2, or 3 bits.
+  if (AMShiftAmt <= 0 || AMShiftAmt > 3) return true;
+
+  // We also need to ensure that mask is a contiguous run of bits.
+  if (CountTrailingOnes_64(Mask >> MaskTZ) + MaskTZ + MaskLZ != 64) return true;
+
+  // Scale the leading zero count down by the actual size of the value and,
+  // when the shift was applied before the mask, by the shift amount.
+  // NOTE(review): MaskLZ is unsigned and unguarded here; confirm it can't wrap.
+  MaskLZ -= (64 - X.getValueSizeInBits()) + (N == Shift ? 0 : ShiftAmt);
+
+  // The final check is to ensure that any masked out high bits of X are
+  // already known to be zero. Otherwise, the mask has a semantic impact
+  // other than masking out a couple of low bits. Unfortunately, because of
+  // the mask, zero extensions will be removed from operands in some cases.
+  // This code works extra hard to look through extensions because we can
+  // replace them with zero extensions cheaply if necessary.
+  bool ReplacingAnyExtend = false;
+  if (X.getOpcode() == ISD::ANY_EXTEND) {
+    unsigned ExtendBits =
+      X.getValueSizeInBits() - X.getOperand(0).getValueSizeInBits();
+    // Assume that we'll replace the any-extend with a zero-extend, and
+    // narrow the search to the extended value.
+    X = X.getOperand(0);
+    MaskLZ = ExtendBits > MaskLZ ? 0 : MaskLZ - ExtendBits;
+    ReplacingAnyExtend = true;
+  }
+  APInt MaskedHighBits = APInt::getHighBitsSet(X.getValueSizeInBits(),
+                                               MaskLZ);
+  APInt KnownZero, KnownOne;
+  DAG.ComputeMaskedBits(X, MaskedHighBits, KnownZero, KnownOne);
+  if (MaskedHighBits != KnownZero) return true;
+
+  // We've identified a pattern that can be transformed into a single shift
+  // and an addressing mode. Make it so.
+  EVT VT = N.getValueType();
+  if (ReplacingAnyExtend) {
+    assert(X.getValueType() != VT);
+    // We looked through an ANY_EXTEND node, insert a ZERO_EXTEND.
+    SDValue NewX = DAG.getNode(ISD::ZERO_EXTEND, X.getDebugLoc(), VT, X);
+    if (NewX.getNode()->getNodeId() == -1 ||
+        NewX.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+      DAG.RepositionNode(N.getNode(), NewX.getNode());
+      NewX.getNode()->setNodeId(N.getNode()->getNodeId());
+    }
+    X = NewX;
+  }
+  DebugLoc DL = N.getDebugLoc();
+  SDValue NewSRLAmt = DAG.getConstant(ShiftAmt + AMShiftAmt, MVT::i8);
+  SDValue NewSRL = DAG.getNode(ISD::SRL, DL, VT, X, NewSRLAmt);
+  SDValue NewSHLAmt = DAG.getConstant(AMShiftAmt, MVT::i8);
+  SDValue NewSHL = DAG.getNode(ISD::SHL, DL, VT, NewSRL, NewSHLAmt);
+  if (NewSRLAmt.getNode()->getNodeId() == -1 ||
+      NewSRLAmt.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSRLAmt.getNode());
+    NewSRLAmt.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  if (NewSRL.getNode()->getNodeId() == -1 ||
+      NewSRL.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSRL.getNode());
+    NewSRL.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  if (NewSHLAmt.getNode()->getNodeId() == -1 ||
+      NewSHLAmt.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSHLAmt.getNode());
+    NewSHLAmt.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  if (NewSHL.getNode()->getNodeId() == -1 ||
+      NewSHL.getNode()->getNodeId() > N.getNode()->getNodeId()) {
+    DAG.RepositionNode(N.getNode(), NewSHL.getNode());
+    NewSHL.getNode()->setNodeId(N.getNode()->getNodeId());
+  }
+  DAG.ReplaceAllUsesWith(N, NewSHL);
+
+  AM.Scale = 1 << AMShiftAmt;
+  AM.IndexReg = NewSRL;
+  return false;
+}
+
 bool X86DAGToDAGISel::MatchAddressRecursively(SDValue N, X86ISelAddressMode &AM,
                                               unsigned Depth) {
   DebugLoc dl = N.getDebugLoc();
@@ -814,6 +948,13 @@
     break;
     }
 
+  case ISD::SRL:
+    // Try to fold the mask and shift into the scale, and return false if we
+    // succeed.
+    if (!FoldMaskAndShiftToScale(*CurDAG, N, AM))
+      return false;
+    break;
+
   case ISD::SMUL_LOHI:
   case ISD::UMUL_LOHI:
     // A mul_lohi where we need the low part can be folded as a plain multiply.
@@ -1047,6 +1188,11 @@
       }
     }
 
+    // Try to fold the mask and shift into the scale, and return false if we
+    // succeed.
+    if (!FoldMaskAndShiftToScale(*CurDAG, N, AM))
+      return false;
+
     // Handle "(X << C1) & C2" as "(X & (C2>>C1)) << C1" if safe and if this
     // allows us to fold the shift into this addressing mode.
     if (Shift.getOpcode() != ISD::SHL) break;

Modified: llvm/trunk/test/CodeGen/X86/fold-and-shift.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/fold-and-shift.ll?rev=147936&r1=147935&r2=147936&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/fold-and-shift.ll (original)
+++ llvm/trunk/test/CodeGen/X86/fold-and-shift.ll Wed Jan 11 02:41:08 2012
@@ -31,3 +31,47 @@
   %tmp9 = load i32* %tmp78
   ret i32 %tmp9
 }
+
+define i32 @t3(i16* %i.ptr, i32* %arr) {
+; This case is tricky. The lshr followed by a gep will produce an lshr followed
+; by an 'and' to remove the low bits. This can be simplified by doing the lshr
+; by a greater constant and using the addressing mode to scale the result back
+; up. To make matters worse, because of the two-phase zext of %i and their
+; reuse in the function, the DAG can get confusing trying to re-use both of
+; them and prevent easy analysis of the mask in order to match this.
+; CHECK: t3:
+; CHECK-NOT: and
+; CHECK: shrl
+; CHECK: addl (%{{...}},%{{...}},4),
+; CHECK: ret
+
+entry:
+  %i = load i16* %i.ptr
+  %i.zext = zext i16 %i to i32
+  %index = lshr i32 %i.zext, 11
+  %val.ptr = getelementptr inbounds i32* %arr, i32 %index
+  %val = load i32* %val.ptr
+  %sum = add i32 %val, %i.zext
+  ret i32 %sum
+}
+
+define i32 @t4(i16* %i.ptr, i32* %arr) {
+; A version of @t3 that has more zero extends and more re-use of intermediate
+; values. This exercises slightly different bits of canonicalization.
+; CHECK: t4:
+; CHECK-NOT: and
+; CHECK: shrl
+; CHECK: addl (%{{...}},%{{...}},4),
+; CHECK: ret
+
+entry:
+  %i = load i16* %i.ptr
+  %i.zext = zext i16 %i to i32
+  %index = lshr i32 %i.zext, 11
+  %index.zext = zext i32 %index to i64
+  %val.ptr = getelementptr inbounds i32* %arr, i64 %index.zext
+  %val = load i32* %val.ptr
+  %sum.1 = add i32 %val, %i.zext
+  %sum.2 = add i32 %sum.1, %index
+  ret i32 %sum.2
+}





More information about the llvm-commits mailing list