[llvm] 3170d54 - [InstCombine][X86] Convert masked load/stores with (sign extended) bool vector masks to generic intrinsics.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 12 07:24:33 PDT 2020


Author: Simon Pilgrim
Date: 2020-09-12T15:09:28+01:00
New Revision: 3170d54842655d6d936aae32b7d0bc92fce7f22e

URL: https://github.com/llvm/llvm-project/commit/3170d54842655d6d936aae32b7d0bc92fce7f22e
DIFF: https://github.com/llvm/llvm-project/commit/3170d54842655d6d936aae32b7d0bc92fce7f22e.diff

LOG: [InstCombine][X86] Convert masked load/stores with (sign extended) bool vector masks to generic intrinsics.

As detailed in PR11210, if the mask is known to come from a (sign-extended) bool vector (e.g. comparisons), then we can represent it with a generic masked load/store without losing anything.

We already do something similar for the BLENDV -> SELECT conversion.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
    llvm/test/Transforms/InstCombine/X86/x86-masked-memops.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index d93f22d0365c..2390a9818369 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -32,6 +32,23 @@ static Constant *getNegativeIsTrueBoolVec(Constant *V) {
   return V;
 }
 
+/// Convert the x86 XMM integer vector mask to a vector of bools based on
+/// each element's most significant bit (the sign bit).
+static Value *getBoolVecFromMask(Value *Mask) {
+  // Fold Constant Mask.
+  if (auto *ConstantMask = dyn_cast<ConstantDataVector>(Mask))
+    return getNegativeIsTrueBoolVec(ConstantMask);
+
+  // Mask was extended from a boolean vector.
+  Value *ExtMask;
+  if (PatternMatch::match(
+          Mask, PatternMatch::m_SExt(PatternMatch::m_Value(ExtMask))) &&
+      ExtMask->getType()->isIntOrIntVectorTy(1))
+    return ExtMask;
+
+  return nullptr;
+}
+
 // TODO: If the x86 backend knew how to convert a bool vector mask back to an
 // XMM register mask efficiently, we could transform all x86 masked intrinsics
 // to LLVM masked intrinsics and remove the x86 masked intrinsic defs.
@@ -40,32 +57,26 @@ static Instruction *simplifyX86MaskedLoad(IntrinsicInst &II, InstCombiner &IC) {
   Value *Mask = II.getOperand(1);
   Constant *ZeroVec = Constant::getNullValue(II.getType());
 
-  // Special case a zero mask since that's not a ConstantDataVector.
-  // This masked load instruction creates a zero vector.
+  // Zero Mask - masked load instruction creates a zero vector.
   if (isa<ConstantAggregateZero>(Mask))
     return IC.replaceInstUsesWith(II, ZeroVec);
 
-  auto *ConstMask = dyn_cast<ConstantDataVector>(Mask);
-  if (!ConstMask)
-    return nullptr;
-
-  // The mask is constant. Convert this x86 intrinsic to the LLVM instrinsic
-  // to allow target-independent optimizations.
-
-  // First, cast the x86 intrinsic scalar pointer to a vector pointer to match
-  // the LLVM intrinsic definition for the pointer argument.
-  unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
-  PointerType *VecPtrTy = PointerType::get(II.getType(), AddrSpace);
-  Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec");
-
-  // Second, convert the x86 XMM integer vector mask to a vector of bools based
-  // on each element's most significant bit (the sign bit).
-  Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask);
+  // The mask is constant or extended from a bool vector. Convert this x86
+  // intrinsic to the LLVM intrinsic to allow target-independent optimizations.
+  if (Value *BoolMask = getBoolVecFromMask(Mask)) {
+    // First, cast the x86 intrinsic scalar pointer to a vector pointer to match
+    // the LLVM intrinsic definition for the pointer argument.
+    unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
+    PointerType *VecPtrTy = PointerType::get(II.getType(), AddrSpace);
+    Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec");
+
+    // The pass-through vector for an x86 masked load is a zero vector.
+    CallInst *NewMaskedLoad =
+        IC.Builder.CreateMaskedLoad(PtrCast, Align(1), BoolMask, ZeroVec);
+    return IC.replaceInstUsesWith(II, NewMaskedLoad);
+  }
 
-  // The pass-through vector for an x86 masked load is a zero vector.
-  CallInst *NewMaskedLoad =
-      IC.Builder.CreateMaskedLoad(PtrCast, Align(1), BoolMask, ZeroVec);
-  return IC.replaceInstUsesWith(II, NewMaskedLoad);
+  return nullptr;
 }
 
 // TODO: If the x86 backend knew how to convert a bool vector mask back to an
@@ -76,8 +87,7 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) {
   Value *Mask = II.getOperand(1);
   Value *Vec = II.getOperand(2);
 
-  // Special case a zero mask since that's not a ConstantDataVector:
-  // this masked store instruction does nothing.
+  // Zero Mask - this masked store instruction does nothing.
   if (isa<ConstantAggregateZero>(Mask)) {
     IC.eraseInstFromFunction(II);
     return true;
@@ -88,28 +98,21 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) {
   if (II.getIntrinsicID() == Intrinsic::x86_sse2_maskmov_dqu)
     return false;
 
-  auto *ConstMask = dyn_cast<ConstantDataVector>(Mask);
-  if (!ConstMask)
-    return false;
-
-  // The mask is constant. Convert this x86 intrinsic to the LLVM instrinsic
-  // to allow target-independent optimizations.
+  // The mask is constant or extended from a bool vector. Convert this x86
+  // intrinsic to the LLVM intrinsic to allow target-independent optimizations.
+  if (Value *BoolMask = getBoolVecFromMask(Mask)) {
+    unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
+    PointerType *VecPtrTy = PointerType::get(Vec->getType(), AddrSpace);
+    Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec");
 
-  // First, cast the x86 intrinsic scalar pointer to a vector pointer to match
-  // the LLVM intrinsic definition for the pointer argument.
-  unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
-  PointerType *VecPtrTy = PointerType::get(Vec->getType(), AddrSpace);
-  Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec");
+    IC.Builder.CreateMaskedStore(Vec, PtrCast, Align(1), BoolMask);
 
-  // Second, convert the x86 XMM integer vector mask to a vector of bools based
-  // on each element's most significant bit (the sign bit).
-  Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask);
-
-  IC.Builder.CreateMaskedStore(Vec, PtrCast, Align(1), BoolMask);
+    // 'Replace uses' doesn't work for stores. Erase the original masked store.
+    IC.eraseInstFromFunction(II);
+    return true;
+  }
 
-  // 'Replace uses' doesn't work for stores. Erase the original masked store.
-  IC.eraseInstFromFunction(II);
-  return true;
+  return false;
 }
 
 static Value *simplifyX86immShift(const IntrinsicInst &II,

diff  --git a/llvm/test/Transforms/InstCombine/X86/x86-masked-memops.ll b/llvm/test/Transforms/InstCombine/X86/x86-masked-memops.ll
index 2975b1c27479..ff4c05164d00 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-masked-memops.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-masked-memops.ll
@@ -14,14 +14,14 @@ define <4 x float> @mload(i8* %f, <4 x i32> %mask) {
   ret <4 x float> %ld
 }
 
-; TODO: If the mask comes from a comparison, convert to an LLVM intrinsic. The backend should optimize further.
+; If the mask comes from a comparison, convert to an LLVM intrinsic. The backend should optimize further.
 
 define <4 x float> @mload_v4f32_cmp(i8* %f, <4 x i32> %src) {
 ; CHECK-LABEL: @mload_v4f32_cmp(
 ; CHECK-NEXT:    [[ICMP:%.*]] = icmp ne <4 x i32> [[SRC:%.*]], zeroinitializer
-; CHECK-NEXT:    [[MASK:%.*]] = sext <4 x i1> [[ICMP]] to <4 x i32>
-; CHECK-NEXT:    [[LD:%.*]] = tail call <4 x float> @llvm.x86.avx.maskload.ps(i8* [[F:%.*]], <4 x i32> [[MASK]])
-; CHECK-NEXT:    ret <4 x float> [[LD]]
+; CHECK-NEXT:    [[CASTVEC:%.*]] = bitcast i8* [[F:%.*]] to <4 x float>*
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* [[CASTVEC]], i32 1, <4 x i1> [[ICMP]], <4 x float> zeroinitializer)
+; CHECK-NEXT:    ret <4 x float> [[TMP1]]
 ;
   %icmp = icmp ne <4 x i32> %src, zeroinitializer
   %mask = sext <4 x i1> %icmp to <4 x i32>
@@ -102,9 +102,9 @@ define <8 x float> @mload_v8f32_cmp(i8* %f, <8 x float> %src0, <8 x float> %src1
 ; CHECK-NEXT:    [[ICMP0:%.*]] = fcmp one <8 x float> [[SRC0:%.*]], zeroinitializer
 ; CHECK-NEXT:    [[ICMP1:%.*]] = fcmp one <8 x float> [[SRC1:%.*]], zeroinitializer
 ; CHECK-NEXT:    [[MASK1:%.*]] = and <8 x i1> [[ICMP0]], [[ICMP1]]
-; CHECK-NEXT:    [[MASK:%.*]] = sext <8 x i1> [[MASK1]] to <8 x i32>
-; CHECK-NEXT:    [[LD:%.*]] = tail call <8 x float> @llvm.x86.avx.maskload.ps.256(i8* [[F:%.*]], <8 x i32> [[MASK]])
-; CHECK-NEXT:    ret <8 x float> [[LD]]
+; CHECK-NEXT:    [[CASTVEC:%.*]] = bitcast i8* [[F:%.*]] to <8 x float>*
+; CHECK-NEXT:    [[TMP1:%.*]] = call <8 x float> @llvm.masked.load.v8f32.p0v8f32(<8 x float>* [[CASTVEC]], i32 1, <8 x i1> [[MASK1]], <8 x float> zeroinitializer)
+; CHECK-NEXT:    ret <8 x float> [[TMP1]]
 ;
   %icmp0 = fcmp one <8 x float> %src0, zeroinitializer
   %icmp1 = fcmp one <8 x float> %src1, zeroinitializer
@@ -193,13 +193,13 @@ define void @mstore(i8* %f, <4 x i32> %mask, <4 x float> %v) {
   ret void
 }
 
-; TODO: If the mask comes from a comparison, convert to an LLVM intrinsic. The backend should optimize further.
+; If the mask comes from a comparison, convert to an LLVM intrinsic. The backend should optimize further.
 
 define void @mstore_v4f32_cmp(i8* %f, <4 x i32> %src, <4 x float> %v) {
 ; CHECK-LABEL: @mstore_v4f32_cmp(
 ; CHECK-NEXT:    [[ICMP:%.*]] = icmp eq <4 x i32> [[SRC:%.*]], zeroinitializer
-; CHECK-NEXT:    [[MASK:%.*]] = sext <4 x i1> [[ICMP]] to <4 x i32>
-; CHECK-NEXT:    tail call void @llvm.x86.avx.maskstore.ps(i8* [[F:%.*]], <4 x i32> [[MASK]], <4 x float> [[V:%.*]])
+; CHECK-NEXT:    [[CASTVEC:%.*]] = bitcast i8* [[F:%.*]] to <4 x float>*
+; CHECK-NEXT:    call void @llvm.masked.store.v4f32.p0v4f32(<4 x float> [[V:%.*]], <4 x float>* [[CASTVEC]], i32 1, <4 x i1> [[ICMP]])
 ; CHECK-NEXT:    ret void
 ;
   %icmp = icmp eq <4 x i32> %src, zeroinitializer
@@ -348,8 +348,8 @@ define void @mstore_v4i64_cmp(i8* %f, <4 x i64> %src0, <4 x i64> %src1, <4 x i64
 ; CHECK-NEXT:    [[ICMP0:%.*]] = icmp eq <4 x i64> [[SRC0:%.*]], zeroinitializer
 ; CHECK-NEXT:    [[ICMP1:%.*]] = icmp ne <4 x i64> [[SRC1:%.*]], zeroinitializer
 ; CHECK-NEXT:    [[MASK1:%.*]] = and <4 x i1> [[ICMP0]], [[ICMP1]]
-; CHECK-NEXT:    [[MASK:%.*]] = sext <4 x i1> [[MASK1]] to <4 x i64>
-; CHECK-NEXT:    tail call void @llvm.x86.avx2.maskstore.q.256(i8* [[F:%.*]], <4 x i64> [[MASK]], <4 x i64> [[V:%.*]])
+; CHECK-NEXT:    [[CASTVEC:%.*]] = bitcast i8* [[F:%.*]] to <4 x i64>*
+; CHECK-NEXT:    call void @llvm.masked.store.v4i64.p0v4i64(<4 x i64> [[V:%.*]], <4 x i64>* [[CASTVEC]], i32 1, <4 x i1> [[MASK1]])
 ; CHECK-NEXT:    ret void
 ;
   %icmp0 = icmp eq <4 x i64> %src0, zeroinitializer


        


More information about the llvm-commits mailing list