[llvm] r348862 - [InstCombine] try to convert x86 movmsk intrinsic to generic IR (PR39927)

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 11 08:38:03 PST 2018


Author: spatel
Date: Tue Dec 11 08:38:03 2018
New Revision: 348862

URL: http://llvm.org/viewvc/llvm-project?rev=348862&view=rev
Log:
[InstCombine] try to convert x86 movmsk intrinsic to generic IR (PR39927)

call iM movmsk(sext <N x i1> X) --> zext (bitcast <N x i1> X to iN) to iM

This has the potential to create scalar types narrower than 8 bits (e.g. i2
and i4), as shown in some of the test diffs, but it looks like the backend
knows how to deal with that in these patterns. This is the simple part of the
fix suggested in:
https://bugs.llvm.org/show_bug.cgi?id=39927

Differential Revision: https://reviews.llvm.org/D55529

Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/trunk/test/Transforms/InstCombine/X86/x86-movmsk.ll

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp?rev=348862&r1=348861&r2=348862&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp Tue Dec 11 08:38:03 2018
@@ -736,7 +736,8 @@ static Value *simplifyX86round(Intrinsic
   return Builder.CreateInsertElement(Dst, Res, (uint64_t)0);
 }
 
-static Value *simplifyX86movmsk(const IntrinsicInst &II) {
+static Value *simplifyX86movmsk(const IntrinsicInst &II,
+                                InstCombiner::BuilderTy &Builder) {
   Value *Arg = II.getArgOperand(0);
   Type *ResTy = II.getType();
   Type *ArgTy = Arg->getType();
@@ -749,29 +750,46 @@ static Value *simplifyX86movmsk(const In
   if (!ArgTy->isVectorTy())
     return nullptr;
 
-  auto *C = dyn_cast<Constant>(Arg);
-  if (!C)
-    return nullptr;
+  if (auto *C = dyn_cast<Constant>(Arg)) {
+    // Extract signbits of the vector input and pack into integer result.
+    APInt Result(ResTy->getPrimitiveSizeInBits(), 0);
+    for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) {
+      auto *COp = C->getAggregateElement(I);
+      if (!COp)
+        return nullptr;
+      if (isa<UndefValue>(COp))
+        continue;
 
-  // Extract signbits of the vector input and pack into integer result.
-  APInt Result(ResTy->getPrimitiveSizeInBits(), 0);
-  for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) {
-    auto *COp = C->getAggregateElement(I);
-    if (!COp)
-      return nullptr;
-    if (isa<UndefValue>(COp))
-      continue;
+      auto *CInt = dyn_cast<ConstantInt>(COp);
+      auto *CFp = dyn_cast<ConstantFP>(COp);
+      if (!CInt && !CFp)
+        return nullptr;
 
-    auto *CInt = dyn_cast<ConstantInt>(COp);
-    auto *CFp = dyn_cast<ConstantFP>(COp);
-    if (!CInt && !CFp)
-      return nullptr;
+      if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative()))
+        Result.setBit(I);
+    }
+    return Constant::getIntegerValue(ResTy, Result);
+  }
 
-    if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative()))
-      Result.setBit(I);
+  // Look for a sign-extended boolean source vector as the argument to this
+  // movmsk. If the argument is bitcast, look through that, but make sure the
+  // source of that bitcast is still a vector with the same number of elements.
+  // TODO: We can also convert a bitcast with wider elements, but that requires
+  // duplicating the bool source sign bits to match the number of elements
+  // expected by the movmsk call.
+  Arg = peekThroughBitcast(Arg);
+  Value *X;
+  if (Arg->getType()->isVectorTy() &&
+      Arg->getType()->getVectorNumElements() == ArgTy->getVectorNumElements() &&
+      match(Arg, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
+    // call iM movmsk(sext <N x i1> X) --> zext (bitcast <N x i1> X to iN) to iM
+    unsigned NumElts = X->getType()->getVectorNumElements();
+    Type *ScalarTy = Type::getIntNTy(Arg->getContext(), NumElts);
+    Value *BC = Builder.CreateBitCast(X, ScalarTy);
+    return Builder.CreateZExtOrTrunc(BC, ResTy);
   }
 
-  return Constant::getIntegerValue(ResTy, Result);
+  return nullptr;
 }
 
 static Value *simplifyX86insertps(const IntrinsicInst &II,
@@ -2543,7 +2561,7 @@ Instruction *InstCombiner::visitCallInst
   case Intrinsic::x86_avx_movmsk_pd_256:
   case Intrinsic::x86_avx_movmsk_ps_256:
   case Intrinsic::x86_avx2_pmovmskb:
-    if (Value *V = simplifyX86movmsk(*II))
+    if (Value *V = simplifyX86movmsk(*II, Builder))
       return replaceInstUsesWith(*II, V);
     break;
 

Modified: llvm/trunk/test/Transforms/InstCombine/X86/x86-movmsk.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/X86/x86-movmsk.ll?rev=348862&r1=348861&r2=348862&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/X86/x86-movmsk.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/X86/x86-movmsk.ll Tue Dec 11 08:38:03 2018
@@ -315,10 +315,9 @@ define i32 @fold_x86_avx2_pmovmskb() {
 
 define i32 @sext_sse_movmsk_ps(<4 x i1> %x) {
 ; CHECK-LABEL: @sext_sse_movmsk_ps(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <4 x i1> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x i32> [[SEXT]] to <4 x float>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i4 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <4 x i1> %x to <4 x i32>
   %bc = bitcast <4 x i32> %sext to <4 x float>
@@ -328,10 +327,9 @@ define i32 @sext_sse_movmsk_ps(<4 x i1>
 
 define i32 @sext_sse2_movmsk_pd(<2 x i1> %x) {
 ; CHECK-LABEL: @sext_sse2_movmsk_pd(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <2 x i1> [[X:%.*]] to <2 x i64>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <2 x i64> [[SEXT]] to <2 x double>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse2.movmsk.pd(<2 x double> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[X:%.*]] to i2
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i2 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <2 x i1> %x to <2 x i64>
   %bc = bitcast <2 x i64> %sext to <2 x double>
@@ -341,9 +339,9 @@ define i32 @sext_sse2_movmsk_pd(<2 x i1>
 
 define i32 @sext_sse2_pmovmskb_128(<16 x i1> %x) {
 ; CHECK-LABEL: @sext_sse2_pmovmskb_128(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <16 x i1> [[X:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8> [[SEXT]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <16 x i1> %x to <16 x i8>
   %r = call i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8> %sext)
@@ -352,10 +350,9 @@ define i32 @sext_sse2_pmovmskb_128(<16 x
 
 define i32 @sext_avx_movmsk_ps_256(<8 x i1> %x) {
 ; CHECK-LABEL: @sext_avx_movmsk_ps_256(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <8 x i1> [[X:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <8 x i32> [[SEXT]] to <8 x float>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.avx.movmsk.ps.256(<8 x float> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i8 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <8 x i1> %x to <8 x i32>
   %bc = bitcast <8 x i32> %sext to <8 x float>
@@ -365,10 +362,9 @@ define i32 @sext_avx_movmsk_ps_256(<8 x
 
 define i32 @sext_avx_movmsk_pd_256(<4 x i1> %x) {
 ; CHECK-LABEL: @sext_avx_movmsk_pd_256(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <4 x i1> [[X:%.*]] to <4 x i64>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x i64> [[SEXT]] to <4 x double>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.avx.movmsk.pd.256(<4 x double> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i4 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <4 x i1> %x to <4 x i64>
   %bc = bitcast <4 x i64> %sext to <4 x double>
@@ -378,15 +374,60 @@ define i32 @sext_avx_movmsk_pd_256(<4 x
 
 define i32 @sext_avx2_pmovmskb(<32 x i1> %x) {
 ; CHECK-LABEL: @sext_avx2_pmovmskb(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <32 x i1> [[X:%.*]] to <32 x i8>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.avx2.pmovmskb(<32 x i8> [[SEXT]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <32 x i1> [[X:%.*]] to i32
+; CHECK-NEXT:    ret i32 [[TMP1]]
 ;
   %sext = sext <32 x i1> %x to <32 x i8>
   %r = call i32 @llvm.x86.avx2.pmovmskb(<32 x i8> %sext)
   ret i32 %r
 }
 
+; Negative test - bitcast from scalar.
+
+define i32 @sext_sse_movmsk_ps_scalar_source(i1 %x) {
+; CHECK-LABEL: @sext_sse_movmsk_ps_scalar_source(
+; CHECK-NEXT:    [[SEXT:%.*]] = sext i1 [[X:%.*]] to i128
+; CHECK-NEXT:    [[BC:%.*]] = bitcast i128 [[SEXT]] to <4 x float>
+; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %sext = sext i1 %x to i128
+  %bc = bitcast i128 %sext to <4 x float>
+  %r = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> %bc)
+  ret i32 %r
+}
+
+; Negative test - bitcast from vector type with more elements.
+
+define i32 @sext_sse_movmsk_ps_too_many_elts(<8 x i1> %x) {
+; CHECK-LABEL: @sext_sse_movmsk_ps_too_many_elts(
+; CHECK-NEXT:    [[SEXT:%.*]] = sext <8 x i1> [[X:%.*]] to <8 x i16>
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <8 x i16> [[SEXT]] to <4 x float>
+; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %sext = sext <8 x i1> %x to <8 x i16>
+  %bc = bitcast <8 x i16> %sext to <4 x float>
+  %r = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> %bc)
+  ret i32 %r
+}
+
+; TODO: We could handle this by doing a bitcasted sign-bit test after the sext?
+; But need to make sure the backend handles that correctly.
+
+define i32 @sext_sse_movmsk_ps_must_replicate_bits(<2 x i1> %x) {
+; CHECK-LABEL: @sext_sse_movmsk_ps_must_replicate_bits(
+; CHECK-NEXT:    [[SEXT:%.*]] = sext <2 x i1> [[X:%.*]] to <2 x i64>
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <2 x i64> [[SEXT]] to <4 x float>
+; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %sext = sext <2 x i1> %x to <2 x i64>
+  %bc = bitcast <2 x i64> %sext to <4 x float>
+  %r = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> %bc)
+  ret i32 %r
+}
+
 declare i32 @llvm.x86.mmx.pmovmskb(x86_mmx)
 
 declare i32 @llvm.x86.sse.movmsk.ps(<4 x float>)




More information about the llvm-commits mailing list