[llvm] r350222 - [X86] Support SHLD/SHRD masked shift-counts (PR34641)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 2 09:05:37 PST 2019


Author: rksimon
Date: Wed Jan  2 09:05:37 2019
New Revision: 350222

URL: http://llvm.org/viewvc/llvm-project?rev=350222&view=rev
Log:
[X86] Support SHLD/SHRD masked shift-counts (PR34641)

Peek through shift-amount modulo masks (an AND with BitWidth - 1) while matching double-shift patterns.

I was hoping to delay this until I could replace the X86-specific code with generic funnel shift matching (PR40081), but this will do for now.

Differential Revision: https://reviews.llvm.org/D56199

Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/test/CodeGen/X86/shift-double.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=350222&r1=350221&r2=350222&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Wed Jan  2 09:05:37 2019
@@ -36514,6 +36514,7 @@ static SDValue combineOr(SDNode *N, Sele
 
   // fold (or (x << c) | (y >> (64 - c))) ==> (shld64 x, y, c)
   bool OptForSize = DAG.getMachineFunction().getFunction().optForSize();
+  unsigned Bits = VT.getScalarSizeInBits();
 
   // SHLD/SHRD instructions have lower register pressure, but on some
   // platforms they have higher latency than the equivalent
@@ -36536,6 +36537,23 @@ static SDValue combineOr(SDNode *N, Sele
   SDValue ShAmt1 = N1.getOperand(1);
   if (ShAmt1.getValueType() != MVT::i8)
     return SDValue();
+
+  // Peek through any modulo shift masks.
+  SDValue ShMsk0;
+  if (ShAmt0.getOpcode() == ISD::AND &&
+      isa<ConstantSDNode>(ShAmt0.getOperand(1)) &&
+      ShAmt0.getConstantOperandVal(1) == (Bits - 1)) {
+    ShMsk0 = ShAmt0;
+    ShAmt0 = ShAmt0.getOperand(0);
+  }
+  SDValue ShMsk1;
+  if (ShAmt1.getOpcode() == ISD::AND &&
+      isa<ConstantSDNode>(ShAmt1.getOperand(1)) &&
+      ShAmt1.getConstantOperandVal(1) == (Bits - 1)) {
+    ShMsk1 = ShAmt1;
+    ShAmt1 = ShAmt1.getOperand(0);
+  }
+
   if (ShAmt0.getOpcode() == ISD::TRUNCATE)
     ShAmt0 = ShAmt0.getOperand(0);
   if (ShAmt1.getOpcode() == ISD::TRUNCATE)
@@ -36550,24 +36568,26 @@ static SDValue combineOr(SDNode *N, Sele
     Opc = X86ISD::SHRD;
     std::swap(Op0, Op1);
     std::swap(ShAmt0, ShAmt1);
+    std::swap(ShMsk0, ShMsk1);
   }
 
   // OR( SHL( X, C ), SRL( Y, 32 - C ) ) -> SHLD( X, Y, C )
   // OR( SRL( X, C ), SHL( Y, 32 - C ) ) -> SHRD( X, Y, C )
   // OR( SHL( X, C ), SRL( SRL( Y, 1 ), XOR( C, 31 ) ) ) -> SHLD( X, Y, C )
   // OR( SRL( X, C ), SHL( SHL( Y, 1 ), XOR( C, 31 ) ) ) -> SHRD( X, Y, C )
-  unsigned Bits = VT.getScalarSizeInBits();
+  // OR( SHL( X, AND( C, 31 ) ), SRL( Y, AND( 0 - C, 31 ) ) ) -> SHLD( X, Y, C )
+  // OR( SRL( X, AND( C, 31 ) ), SHL( Y, AND( 0 - C, 31 ) ) ) -> SHRD( X, Y, C )
   if (ShAmt1.getOpcode() == ISD::SUB) {
     SDValue Sum = ShAmt1.getOperand(0);
     if (auto *SumC = dyn_cast<ConstantSDNode>(Sum)) {
       SDValue ShAmt1Op1 = ShAmt1.getOperand(1);
       if (ShAmt1Op1.getOpcode() == ISD::TRUNCATE)
         ShAmt1Op1 = ShAmt1Op1.getOperand(0);
-      if (SumC->getSExtValue() == Bits && ShAmt1Op1 == ShAmt0)
-        return DAG.getNode(Opc, DL, VT,
-                           Op0, Op1,
-                           DAG.getNode(ISD::TRUNCATE, DL,
-                                       MVT::i8, ShAmt0));
+      if ((SumC->getAPIntValue() == Bits ||
+           (SumC->getAPIntValue() == 0 && ShMsk1)) &&
+          ShAmt1Op1 == ShAmt0)
+        return DAG.getNode(Opc, DL, VT, Op0, Op1,
+                           DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0));
     }
   } else if (auto *ShAmt1C = dyn_cast<ConstantSDNode>(ShAmt1)) {
     auto *ShAmt0C = dyn_cast<ConstantSDNode>(ShAmt0);
@@ -36583,7 +36603,8 @@ static SDValue combineOr(SDNode *N, Sele
       SDValue ShAmt1Op0 = ShAmt1.getOperand(0);
       if (ShAmt1Op0.getOpcode() == ISD::TRUNCATE)
         ShAmt1Op0 = ShAmt1Op0.getOperand(0);
-      if (MaskC->getSExtValue() == (Bits - 1) && ShAmt1Op0 == ShAmt0) {
+      if (MaskC->getSExtValue() == (Bits - 1) &&
+          (ShAmt1Op0 == ShAmt0 || ShAmt1Op0 == ShMsk0)) {
         if (Op1.getOpcode() == InnerShift &&
             isa<ConstantSDNode>(Op1.getOperand(1)) &&
             Op1.getConstantOperandVal(1) == 1) {
@@ -36594,7 +36615,7 @@ static SDValue combineOr(SDNode *N, Sele
         if (InnerShift == ISD::SHL && Op1.getOpcode() == ISD::ADD &&
             Op1.getOperand(0) == Op1.getOperand(1)) {
           return DAG.getNode(Opc, DL, VT, Op0, Op1.getOperand(0),
-                     DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0));
+                             DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0));
         }
       }
     }

Modified: llvm/trunk/test/CodeGen/X86/shift-double.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/shift-double.ll?rev=350222&r1=350221&r2=350222&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/shift-double.ll (original)
+++ llvm/trunk/test/CodeGen/X86/shift-double.ll Wed Jan  2 09:05:37 2019
@@ -460,24 +460,18 @@ define i32 @test17(i32 %hi, i32 %lo, i32
 define i32 @shld_safe_i32(i32, i32, i32) {
 ; X86-LABEL: shld_safe_i32:
 ; X86:       # %bb.0:
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    movb {{[0-9]+}}(%esp), %cl
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %edx
-; X86-NEXT:    shll %cl, %edx
-; X86-NEXT:    negb %cl
-; X86-NEXT:    shrl %cl, %eax
-; X86-NEXT:    orl %edx, %eax
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    shldl %cl, %edx, %eax
 ; X86-NEXT:    retl
 ;
 ; X64-LABEL: shld_safe_i32:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movl %edx, %ecx
-; X64-NEXT:    movl %esi, %eax
-; X64-NEXT:    shll %cl, %edi
-; X64-NEXT:    negb %cl
+; X64-NEXT:    movl %edi, %eax
 ; X64-NEXT:    # kill: def $cl killed $cl killed $ecx
-; X64-NEXT:    shrl %cl, %eax
-; X64-NEXT:    orl %edi, %eax
+; X64-NEXT:    shldl %cl, %esi, %eax
 ; X64-NEXT:    retq
   %4 = and i32 %2, 31
   %5 = shl i32 %0, %4
@@ -491,24 +485,18 @@ define i32 @shld_safe_i32(i32, i32, i32)
 define i32 @shrd_safe_i32(i32, i32, i32) {
 ; X86-LABEL: shrd_safe_i32:
 ; X86:       # %bb.0:
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-NEXT:    movb {{[0-9]+}}(%esp), %cl
 ; X86-NEXT:    movl {{[0-9]+}}(%esp), %edx
-; X86-NEXT:    shrl %cl, %edx
-; X86-NEXT:    negb %cl
-; X86-NEXT:    shll %cl, %eax
-; X86-NEXT:    orl %edx, %eax
+; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; X86-NEXT:    shrdl %cl, %edx, %eax
 ; X86-NEXT:    retl
 ;
 ; X64-LABEL: shrd_safe_i32:
 ; X64:       # %bb.0:
 ; X64-NEXT:    movl %edx, %ecx
-; X64-NEXT:    movl %esi, %eax
-; X64-NEXT:    shrl %cl, %edi
-; X64-NEXT:    negb %cl
+; X64-NEXT:    movl %edi, %eax
 ; X64-NEXT:    # kill: def $cl killed $cl killed $ecx
-; X64-NEXT:    shll %cl, %eax
-; X64-NEXT:    orl %edi, %eax
+; X64-NEXT:    shrdl %cl, %esi, %eax
 ; X64-NEXT:    retq
   %4 = and i32 %2, 31
   %5 = lshr i32 %0, %4




More information about the llvm-commits mailing list