[llvm] r193096 - X86 vector element shift-by-immediate instructions take i8 immediates. Make

Lang Hames lhames at gmail.com
Mon Oct 21 10:51:24 PDT 2013


Author: lhames
Date: Mon Oct 21 12:51:24 2013
New Revision: 193096

URL: http://llvm.org/viewvc/llvm-project?rev=193096&view=rev
Log:
X86 vector element shift-by-immediate instructions take i8 immediates. Make
the instruction definitions and ISEL reflect this.

Prior to this patch these instructions took an i32i8imm, and the high bits were
dropped during encoding. This led to incorrect behavior for shifts by
immediates higher than 255. This patch fixes that issue by detecting large
immediate shifts and returning constant zero (for logical shifts) or capping
the shift amount at an encodable value (for arithmetic shifts).

Fixes <rdar://problem/14968098>


Modified:
    llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
    llvm/trunk/lib/Target/X86/X86InstrAVX512.td
    llvm/trunk/lib/Target/X86/X86InstrSSE.td
    llvm/trunk/test/CodeGen/X86/avx2-vector-shifts.ll
    llvm/trunk/test/CodeGen/X86/sse2-vector-shifts.ll

Modified: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86ISelLowering.cpp?rev=193096&r1=193095&r2=193096&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp Mon Oct 21 12:51:24 2013
@@ -10952,6 +10952,26 @@ static SDValue LowerVACOPY(SDValue Op, c
                        MachinePointerInfo(DstSV), MachinePointerInfo(SrcSV));
 }
 
+// getTargetVShiftByConstNode - Handle vector element shifts where the shift
+// amount is a constant. Takes immediate version of shift as input.
+static SDValue getTargetVShiftByConstNode(unsigned Opc, SDLoc dl, EVT VT,
+                                          SDValue SrcOp, uint64_t ShiftAmt,
+                                          SelectionDAG &DAG) {
+
+  // Check for ShiftAmt >= element width
+  if (ShiftAmt >= VT.getVectorElementType().getSizeInBits()) {
+    if (Opc == X86ISD::VSRAI)
+      ShiftAmt = VT.getVectorElementType().getSizeInBits() - 1;
+    else
+      return DAG.getConstant(0, VT);
+  }
+
+  assert((Opc == X86ISD::VSHLI || Opc == X86ISD::VSRLI || Opc == X86ISD::VSRAI)
+         && "Unknown target vector shift-by-constant node");
+
+  return DAG.getNode(Opc, dl, VT, SrcOp, DAG.getConstant(ShiftAmt, MVT::i8));
+}
+
 // getTargetVShiftNode - Handle vector element shifts where the shift amount
 // may or may not be a constant. Takes immediate version of shift as input.
 static SDValue getTargetVShiftNode(unsigned Opc, SDLoc dl, EVT VT,
@@ -10959,18 +10979,10 @@ static SDValue getTargetVShiftNode(unsig
                                    SelectionDAG &DAG) {
   assert(ShAmt.getValueType() == MVT::i32 && "ShAmt is not i32");
 
-  if (isa<ConstantSDNode>(ShAmt)) {
-    // Constant may be a TargetConstant. Use a regular constant.
-    uint32_t ShiftAmt = cast<ConstantSDNode>(ShAmt)->getZExtValue();
-    switch (Opc) {
-      default: llvm_unreachable("Unknown target vector shift node");
-      case X86ISD::VSHLI:
-      case X86ISD::VSRLI:
-      case X86ISD::VSRAI:
-        return DAG.getNode(Opc, dl, VT, SrcOp,
-                           DAG.getConstant(ShiftAmt, MVT::i32));
-    }
-  }
+  // Catch shift-by-constant.
+  if (ConstantSDNode *CShAmt = dyn_cast<ConstantSDNode>(ShAmt))
+    return getTargetVShiftByConstNode(Opc, dl, VT, SrcOp,
+                                      CShAmt->getZExtValue(), DAG);
 
   // Change opcode to non-immediate version
   switch (Opc) {
@@ -12416,10 +12428,8 @@ static SDValue LowerMUL(SDValue Op, cons
   //  AhiBlo = psllqi(AhiBlo, 32);
   //  return AloBlo + AloBhi + AhiBlo;
 
-  SDValue ShAmt = DAG.getConstant(32, MVT::i32);
-
-  SDValue Ahi = DAG.getNode(X86ISD::VSRLI, dl, VT, A, ShAmt);
-  SDValue Bhi = DAG.getNode(X86ISD::VSRLI, dl, VT, B, ShAmt);
+  SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG);
+  SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG);
 
   // Bit cast to 32-bit vectors for MULUDQ
   EVT MulVT = (VT == MVT::v2i64) ? MVT::v4i32 :
@@ -12433,8 +12443,8 @@ static SDValue LowerMUL(SDValue Op, cons
   SDValue AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, Bhi);
   SDValue AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, B);
 
-  AloBhi = DAG.getNode(X86ISD::VSHLI, dl, VT, AloBhi, ShAmt);
-  AhiBlo = DAG.getNode(X86ISD::VSHLI, dl, VT, AhiBlo, ShAmt);
+  AloBhi = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AloBhi, 32, DAG);
+  AhiBlo = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AhiBlo, 32, DAG);
 
   SDValue Res = DAG.getNode(ISD::ADD, dl, VT, AloBlo, AloBhi);
   return DAG.getNode(ISD::ADD, dl, VT, Res, AhiBlo);
@@ -12462,7 +12472,7 @@ static SDValue LowerSDIV(SDValue Op, Sel
 
   if ((SplatValue != 0) &&
       (SplatValue.isPowerOf2() || (-SplatValue).isPowerOf2())) {
-    unsigned lg2 = SplatValue.countTrailingZeros();
+    unsigned Lg2 = SplatValue.countTrailingZeros();
     // Splat the sign bit.
     SmallVector<SDValue, 16> Sz(NumElts,
                                 DAG.getConstant(EltTy.getSizeInBits() - 1,
@@ -12472,13 +12482,13 @@ static SDValue LowerSDIV(SDValue Op, Sel
                                           NumElts));
     // Add (N0 < 0) ? abs2 - 1 : 0;
     SmallVector<SDValue, 16> Amt(NumElts,
-                                 DAG.getConstant(EltTy.getSizeInBits() - lg2,
+                                 DAG.getConstant(EltTy.getSizeInBits() - Lg2,
                                                  EltTy));
     SDValue SRL = DAG.getNode(ISD::SRL, dl, VT, SGN,
                               DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Amt[0],
                                           NumElts));
     SDValue ADD = DAG.getNode(ISD::ADD, dl, VT, N0, SRL);
-    SmallVector<SDValue, 16> Lg2Amt(NumElts, DAG.getConstant(lg2, EltTy));
+    SmallVector<SDValue, 16> Lg2Amt(NumElts, DAG.getConstant(Lg2, EltTy));
     SDValue SRA = DAG.getNode(ISD::SRA, dl, VT, ADD,
                               DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Lg2Amt[0],
                                           NumElts));
@@ -12514,21 +12524,22 @@ static SDValue LowerScalarImmediateShift
           (Subtarget->hasAVX512() &&
            (VT == MVT::v8i64 || VT == MVT::v16i32))) {
         if (Op.getOpcode() == ISD::SHL)
-          return DAG.getNode(X86ISD::VSHLI, dl, VT, R,
-                             DAG.getConstant(ShiftAmt, MVT::i32));
+          return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
+                                            DAG);
         if (Op.getOpcode() == ISD::SRL)
-          return DAG.getNode(X86ISD::VSRLI, dl, VT, R,
-                             DAG.getConstant(ShiftAmt, MVT::i32));
+          return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
+                                            DAG);
         if (Op.getOpcode() == ISD::SRA && VT != MVT::v2i64 && VT != MVT::v4i64)
-          return DAG.getNode(X86ISD::VSRAI, dl, VT, R,
-                             DAG.getConstant(ShiftAmt, MVT::i32));
+          return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
+                                            DAG);
       }
 
       if (VT == MVT::v16i8) {
         if (Op.getOpcode() == ISD::SHL) {
           // Make a large shift.
-          SDValue SHL = DAG.getNode(X86ISD::VSHLI, dl, MVT::v8i16, R,
-                                    DAG.getConstant(ShiftAmt, MVT::i32));
+          SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl,
+                                                   MVT::v8i16, R, ShiftAmt,
+                                                   DAG); 
           SHL = DAG.getNode(ISD::BITCAST, dl, VT, SHL);
           // Zero out the rightmost bits.
           SmallVector<SDValue, 16> V(16,
@@ -12539,8 +12550,9 @@ static SDValue LowerScalarImmediateShift
         }
         if (Op.getOpcode() == ISD::SRL) {
           // Make a large shift.
-          SDValue SRL = DAG.getNode(X86ISD::VSRLI, dl, MVT::v8i16, R,
-                                    DAG.getConstant(ShiftAmt, MVT::i32));
+          SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl,
+                                                   MVT::v8i16, R, ShiftAmt,
+                                                   DAG);
           SRL = DAG.getNode(ISD::BITCAST, dl, VT, SRL);
           // Zero out the leftmost bits.
           SmallVector<SDValue, 16> V(16,
@@ -12571,8 +12583,9 @@ static SDValue LowerScalarImmediateShift
       if (Subtarget->hasInt256() && VT == MVT::v32i8) {
         if (Op.getOpcode() == ISD::SHL) {
           // Make a large shift.
-          SDValue SHL = DAG.getNode(X86ISD::VSHLI, dl, MVT::v16i16, R,
-                                    DAG.getConstant(ShiftAmt, MVT::i32));
+          SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl,
+                                                   MVT::v16i16, R, ShiftAmt,
+                                                   DAG);
           SHL = DAG.getNode(ISD::BITCAST, dl, VT, SHL);
           // Zero out the rightmost bits.
           SmallVector<SDValue, 32> V(32,
@@ -12583,8 +12596,9 @@ static SDValue LowerScalarImmediateShift
         }
         if (Op.getOpcode() == ISD::SRL) {
           // Make a large shift.
-          SDValue SRL = DAG.getNode(X86ISD::VSRLI, dl, MVT::v16i16, R,
-                                    DAG.getConstant(ShiftAmt, MVT::i32));
+          SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl,
+                                                   MVT::v16i16, R, ShiftAmt,
+                                                   DAG);
           SRL = DAG.getNode(ISD::BITCAST, dl, VT, SRL);
           // Zero out the leftmost bits.
           SmallVector<SDValue, 32> V(32,
@@ -12649,14 +12663,14 @@ static SDValue LowerScalarImmediateShift
     default:
       llvm_unreachable("Unknown shift opcode!");
     case ISD::SHL:
-      return DAG.getNode(X86ISD::VSHLI, dl, VT, R,
-                         DAG.getConstant(ShiftAmt, MVT::i32));
+      return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
+                                        DAG);
     case ISD::SRL:
-      return DAG.getNode(X86ISD::VSRLI, dl, VT, R,
-                         DAG.getConstant(ShiftAmt, MVT::i32));
+      return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
+                                        DAG);
     case ISD::SRA:
-      return DAG.getNode(X86ISD::VSRAI, dl, VT, R,
-                         DAG.getConstant(ShiftAmt, MVT::i32));
+      return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
+                                        DAG);
     }
   }
 
@@ -12869,8 +12883,7 @@ static SDValue LowerShift(SDValue Op, co
 
     // r = VSELECT(r, psllw(r & (char16)15, 4), a);
     SDValue M = DAG.getNode(ISD::AND, dl, VT, R, CM1);
-    M = getTargetVShiftNode(X86ISD::VSHLI, dl, MVT::v8i16, M,
-                            DAG.getConstant(4, MVT::i32), DAG);
+    M = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, MVT::v8i16, M, 4, DAG);
     M = DAG.getNode(ISD::BITCAST, dl, VT, M);
     R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R);
 
@@ -12881,8 +12894,7 @@ static SDValue LowerShift(SDValue Op, co
 
     // r = VSELECT(r, psllw(r & (char16)63, 2), a);
     M = DAG.getNode(ISD::AND, dl, VT, R, CM2);
-    M = getTargetVShiftNode(X86ISD::VSHLI, dl, MVT::v8i16, M,
-                            DAG.getConstant(2, MVT::i32), DAG);
+    M = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, MVT::v8i16, M, 2, DAG);
     M = DAG.getNode(ISD::BITCAST, dl, VT, M);
     R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R);
 
@@ -13025,7 +13037,6 @@ SDValue X86TargetLowering::LowerSIGN_EXT
 
   unsigned BitsDiff = VT.getScalarType().getSizeInBits() -
                       ExtraVT.getScalarType().getSizeInBits();
-  SDValue ShAmt = DAG.getConstant(BitsDiff, MVT::i32);
 
   switch (VT.getSimpleVT().SimpleTy) {
     default: return SDValue();
@@ -13075,8 +13086,10 @@ SDValue X86TargetLowering::LowerSIGN_EXT
       }
 
       // If the above didn't work, then just use Shift-Left + Shift-Right.
-      Tmp1 = getTargetVShiftNode(X86ISD::VSHLI, dl, VT, Op0, ShAmt, DAG);
-      return getTargetVShiftNode(X86ISD::VSRAI, dl, VT, Tmp1, ShAmt, DAG);
+      Tmp1 = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Op0, BitsDiff,
+                                        DAG);
+      return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, Tmp1, BitsDiff,
+                                        DAG);
     }
   }
 }

Modified: llvm/trunk/lib/Target/X86/X86InstrAVX512.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86InstrAVX512.td?rev=193096&r1=193095&r2=193096&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86InstrAVX512.td (original)
+++ llvm/trunk/lib/Target/X86/X86InstrAVX512.td Mon Oct 21 12:51:24 2013
@@ -1845,22 +1845,22 @@ multiclass avx512_shift_rmi<bits<8> opc,
                          ValueType vt, X86MemOperand x86memop, PatFrag mem_frag,
                          RegisterClass KRC> {
   def ri : AVX512BIi8<opc, ImmFormR, (outs RC:$dst),
-       (ins RC:$src1, i32i8imm:$src2),
+       (ins RC:$src1, i8imm:$src2),
            !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
-       [(set RC:$dst, (vt (OpNode RC:$src1, (i32 imm:$src2))))],
+       [(set RC:$dst, (vt (OpNode RC:$src1, (i8 imm:$src2))))],
         SSE_INTSHIFT_ITINS_P.rr>, EVEX_4V;
   def rik : AVX512BIi8<opc, ImmFormR, (outs RC:$dst),
-       (ins KRC:$mask, RC:$src1, i32i8imm:$src2),
+       (ins KRC:$mask, RC:$src1, i8imm:$src2),
            !strconcat(OpcodeStr,
                 "\t{$src2, $src1, $dst {${mask}}|$dst {${mask}}, $src1, $src2}"),
        [], SSE_INTSHIFT_ITINS_P.rr>, EVEX_4V, EVEX_K;
   def mi: AVX512BIi8<opc, ImmFormM, (outs RC:$dst),
-       (ins x86memop:$src1, i32i8imm:$src2),
+       (ins x86memop:$src1, i8imm:$src2),
            !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
        [(set RC:$dst, (OpNode (mem_frag addr:$src1),
-                     (i32 imm:$src2)))], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V;
+                     (i8 imm:$src2)))], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V;
   def mik: AVX512BIi8<opc, ImmFormM, (outs RC:$dst),
-       (ins KRC:$mask, x86memop:$src1, i32i8imm:$src2),
+       (ins KRC:$mask, x86memop:$src1, i8imm:$src2),
            !strconcat(OpcodeStr,
                 "\t{$src2, $src1, $dst {${mask}}|$dst {${mask}}, $src1, $src2}"),
        [], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V, EVEX_K;

Modified: llvm/trunk/lib/Target/X86/X86InstrSSE.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86InstrSSE.td?rev=193096&r1=193095&r2=193096&view=diff
==============================================================================
--- llvm/trunk/lib/Target/X86/X86InstrSSE.td (original)
+++ llvm/trunk/lib/Target/X86/X86InstrSSE.td Mon Oct 21 12:51:24 2013
@@ -3744,11 +3744,11 @@ multiclass PDI_binop_rmi<bits<8> opc, bi
                        (bc_frag (memopv2i64 addr:$src2)))))], itins.rm>,
       Sched<[WriteVecShiftLd, ReadAfterLd]>;
   def ri : PDIi8<opc2, ImmForm, (outs RC:$dst),
-       (ins RC:$src1, i32i8imm:$src2),
+       (ins RC:$src1, i8imm:$src2),
        !if(Is2Addr,
            !strconcat(OpcodeStr, "\t{$src2, $dst|$dst, $src2}"),
            !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}")),
-       [(set RC:$dst, (DstVT (OpNode2 RC:$src1, (i32 imm:$src2))))], itins.ri>,
+       [(set RC:$dst, (DstVT (OpNode2 RC:$src1, (i8 imm:$src2))))], itins.ri>,
        Sched<[WriteVecShift]>;
 }
 
@@ -5064,12 +5064,12 @@ multiclass SS3I_unop_rm_int_y<bits<8> op
 // Helper fragments to match sext vXi1 to vXiY.
 def v16i1sextv16i8 : PatLeaf<(v16i8 (X86pcmpgt (bc_v16i8 (v4i32 immAllZerosV)),
                                                VR128:$src))>;
-def v8i1sextv8i16  : PatLeaf<(v8i16 (X86vsrai VR128:$src, (i32 15)))>;
-def v4i1sextv4i32  : PatLeaf<(v4i32 (X86vsrai VR128:$src, (i32 31)))>;
+def v8i1sextv8i16  : PatLeaf<(v8i16 (X86vsrai VR128:$src, (i8 15)))>;
+def v4i1sextv4i32  : PatLeaf<(v4i32 (X86vsrai VR128:$src, (i8 31)))>;
 def v32i1sextv32i8 : PatLeaf<(v32i8 (X86pcmpgt (bc_v32i8 (v8i32 immAllZerosV)),
                                                VR256:$src))>;
-def v16i1sextv16i16: PatLeaf<(v16i16 (X86vsrai VR256:$src, (i32 15)))>;
-def v8i1sextv8i32  : PatLeaf<(v8i32 (X86vsrai VR256:$src, (i32 31)))>;
+def v16i1sextv16i16: PatLeaf<(v16i16 (X86vsrai VR256:$src, (i8 15)))>;
+def v8i1sextv8i32  : PatLeaf<(v8i32 (X86vsrai VR256:$src, (i8 31)))>;
 
 let Predicates = [HasAVX] in {
   defm VPABSB  : SS3I_unop_rm_int<0x1C, "vpabsb",

Modified: llvm/trunk/test/CodeGen/X86/avx2-vector-shifts.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/avx2-vector-shifts.ll?rev=193096&r1=193095&r2=193096&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/avx2-vector-shifts.ll (original)
+++ llvm/trunk/test/CodeGen/X86/avx2-vector-shifts.ll Mon Oct 21 12:51:24 2013
@@ -121,7 +121,7 @@ entry:
 }
 
 ; CHECK-LABEL: test_sraw_3:
-; CHECK: vpsraw  $16, %ymm0, %ymm0
+; CHECK: vpsraw  $15, %ymm0, %ymm0
 ; CHECK: ret
 
 define <8 x i32> @test_srad_1(<8 x i32> %InVec) {
@@ -151,7 +151,7 @@ entry:
 }
 
 ; CHECK-LABEL: test_srad_3:
-; CHECK: vpsrad  $32, %ymm0, %ymm0
+; CHECK: vpsrad  $31, %ymm0, %ymm0
 ; CHECK: ret
 
 ; SSE Logical Shift Right

Modified: llvm/trunk/test/CodeGen/X86/sse2-vector-shifts.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/X86/sse2-vector-shifts.ll?rev=193096&r1=193095&r2=193096&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/X86/sse2-vector-shifts.ll (original)
+++ llvm/trunk/test/CodeGen/X86/sse2-vector-shifts.ll Mon Oct 21 12:51:24 2013
@@ -121,7 +121,7 @@ entry:
 }
 
 ; CHECK-LABEL: test_sraw_3:
-; CHECK: psraw   $16, %xmm0
+; CHECK: psraw   $15, %xmm0
 ; CHECK-NEXT: ret
 
 define <4 x i32> @test_srad_1(<4 x i32> %InVec) {
@@ -151,7 +151,7 @@ entry:
 }
 
 ; CHECK-LABEL: test_srad_3:
-; CHECK: psrad   $32, %xmm0
+; CHECK: psrad   $31, %xmm0
 ; CHECK-NEXT: ret
 
 ; SSE Logical Shift Right





More information about the llvm-commits mailing list