[llvm] r275379 - Simplify llvm.masked.load w/ undef masks

David Majnemer via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 13 23:58:37 PDT 2016


Author: majnemer
Date: Thu Jul 14 01:58:37 2016
New Revision: 275379

URL: http://llvm.org/viewvc/llvm-project?rev=275379&view=rev
Log:
Simplify llvm.masked.load w/ undef masks

We can always pick the passthru value if the mask is undef: we are
permitted to treat the mask as if it were filled with zeros. The same
holds when every lane of a constant mask is either zero or undef.

Modified:
    llvm/trunk/lib/Analysis/ConstantFolding.cpp
    llvm/trunk/lib/Analysis/InstructionSimplify.cpp
    llvm/trunk/test/Transforms/InstSimplify/call.ll

Modified: llvm/trunk/lib/Analysis/ConstantFolding.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ConstantFolding.cpp?rev=275379&r1=275378&r2=275379&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ConstantFolding.cpp (original)
+++ llvm/trunk/lib/Analysis/ConstantFolding.cpp Thu Jul 14 01:58:37 2016
@@ -1854,32 +1854,44 @@ Constant *ConstantFoldVectorCall(StringR
     auto *SrcPtr = Operands[0];
     auto *Mask = Operands[2];
     auto *Passthru = Operands[3];
+
     Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, VTy, DL);
-    if (!VecData)
-      return nullptr;

     SmallVector<Constant *, 32> NewElements;
     for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) {
-      auto *MaskElt =
-          dyn_cast_or_null<ConstantInt>(Mask->getAggregateElement(I));
+      auto *MaskElt = Mask->getAggregateElement(I);
       if (!MaskElt)
         break;
-      if (MaskElt->isZero()) {
-        auto *PassthruElt = Passthru->getAggregateElement(I);
+      auto *PassthruElt = Passthru->getAggregateElement(I);
+      auto *VecElt = VecData ? VecData->getAggregateElement(I) : nullptr;
+      if (isa<UndefValue>(MaskElt)) {
+        // An undef lane may be treated as either disabled or enabled: prefer
+        // the passthru element, falling back to the loaded element.
+        if (PassthruElt)
+          NewElements.push_back(PassthruElt);
+        else if (VecElt)
+          NewElements.push_back(VecElt);
+        else
+          return nullptr;
+        // Undef is neither zero nor one, so skip the dispatch below, which
+        // would otherwise reject this lane and abandon the whole fold.
+        continue;
+      }
+      if (MaskElt->isNullValue()) {
         if (!PassthruElt)
-          break;
+          return nullptr;
         NewElements.push_back(PassthruElt);
-      } else {
-        assert(MaskElt->isOne());
-        auto *VecElt = VecData->getAggregateElement(I);
+      } else if (MaskElt->isOneValue()) {
         if (!VecElt)
-          break;
+          return nullptr;
         NewElements.push_back(VecElt);
+      } else {
+        return nullptr;
       }
     }
-    if (NewElements.size() == VTy->getNumElements())
-      return ConstantVector::get(NewElements);
-    return nullptr;
+    if (NewElements.size() != VTy->getNumElements())
+      return nullptr;
+    return ConstantVector::get(NewElements);
   }

   for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) {

Modified: llvm/trunk/lib/Analysis/InstructionSimplify.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/InstructionSimplify.cpp?rev=275379&r1=275378&r2=275379&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/InstructionSimplify.cpp (original)
+++ llvm/trunk/lib/Analysis/InstructionSimplify.cpp Thu Jul 14 01:58:37 2016
@@ -3944,6 +3944,22 @@ static Value *SimplifyRelativeLoad(Const
   return ConstantExpr::getBitCast(LoadedLHSPtr, Int8PtrTy);
 }
 
+static bool maskIsAllZeroOrUndef(Value *Mask) {
+  // True only for a constant mask whose every lane is zero or undef.
+  auto *C = dyn_cast<Constant>(Mask);
+  if (!C)
+    return false;
+  if (isa<UndefValue>(C) || C->isNullValue())
+    return true;
+  unsigned NumElts = C->getType()->getVectorNumElements();
+  for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
+    Constant *Elt = C->getAggregateElement(Idx);
+    if (!Elt || !(isa<UndefValue>(Elt) || Elt->isNullValue()))
+      return false;
+  }
+  return true;
+}
+
 template <typename IterTy>
 static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,
                                 const Query &Q, unsigned MaxRecurse) {
@@ -3993,11 +4009,11 @@ static Value *SimplifyIntrinsic(Function
 
   // Simplify calls to llvm.masked.load.*
   if (IID == Intrinsic::masked_load) {
-    IterTy MaskArg = ArgBegin + 2;
-    // If the mask is all zeros, the "passthru" argument is the result.
-    if (auto *ConstMask = dyn_cast<Constant>(*MaskArg))
-      if (ConstMask->isNullValue())
-        return ArgBegin[3];
+    // If every mask lane is zero or undef, every lane may be treated as
+    // disabled, so the result is the "passthru" argument.
+    Value *MaskOp = ArgBegin[2];
+    if (maskIsAllZeroOrUndef(MaskOp))
+      return ArgBegin[3];
   }
 
   // Perform idempotent optimizations

Modified: llvm/trunk/test/Transforms/InstSimplify/call.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstSimplify/call.ll?rev=275379&r1=275378&r2=275379&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstSimplify/call.ll (original)
+++ llvm/trunk/test/Transforms/InstSimplify/call.ll Thu Jul 14 01:58:37 2016
@@ -213,6 +213,13 @@ define <8 x i32> @partial_masked_load()
   ret <8 x i32> %masked.load
 }
 
+define <8 x i32> @masked_load_undef_mask(<8 x i32>* %ptr) {
+; CHECK-LABEL: @masked_load_undef_mask(
+; CHECK:         ret <8 x i32> <i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0>
+  %res = call <8 x i32> @llvm.masked.load.v8i32.p0v8i32(<8 x i32>* %ptr, i32 4, <8 x i1> undef, <8 x i32> <i32 1, i32 0, i32 1, i32 0, i32 1, i32 0, i32 1, i32 0>)
+  ret <8 x i32> %res
+}
+
 declare noalias i8* @malloc(i64)
 
 declare <8 x i32> @llvm.masked.load.v8i32.p0v8i32(<8 x i32>*, i32, <8 x i1>, <8 x i32>)




More information about the llvm-commits mailing list