[llvm] r356590 - Simplify operands of masked stores and scatters based on demanded elements

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 20 11:44:58 PDT 2019


Author: reames
Date: Wed Mar 20 11:44:58 2019
New Revision: 356590

URL: http://llvm.org/viewvc/llvm-project?rev=356590&view=rev
Log:
Simplify operands of masked stores and scatters based on demanded elements

If we know a lane is not stored, we don't need to compute that lane's value. This could be improved by using the undef-element result to further prune the mask, but I want to keep that as a separate change, since it is relatively likely to expose other problems.

Differential Revision: https://reviews.llvm.org/D57247


Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp
    llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/trunk/test/Transforms/InstCombine/masked_intrinsics.ll

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp?rev=356590&r1=356589&r2=356590&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCalls.cpp Wed Mar 20 11:44:58 2019
@@ -25,6 +25,7 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -1175,6 +1176,20 @@ static bool maskIsAllOneOrUndef(Value *M
   return true;
 }
 
+/// Bit I of the result (bitwidth Y for a <Y x i1> Mask) is cleared only when
+/// lane I is known false.  TODO: common with known bits for vectors?
+static APInt possiblyDemandedEltsInMask(Value *Mask) {
+  const unsigned NumLanes = cast<VectorType>(Mask->getType())->getNumElements();
+  APInt MayBeActive = APInt::getAllOnesValue(NumLanes);
+  auto *ConstLanes = dyn_cast<ConstantVector>(Mask);
+  if (!ConstLanes)
+    return MayBeActive; // Not a ConstantVector: every lane may be live.
+  for (unsigned Lane = 0; Lane != NumLanes; ++Lane)
+    if (ConstLanes->getAggregateElement(Lane)->isNullValue())
+      MayBeActive.clearBit(Lane);
+  return MayBeActive;
+}
+
 // TODO, Obvious Missing Transforms:
 // * Dereferenceable address -> speculative load/select
 // * Narrow width by halfs excluding zero/undef lanes
@@ -1196,14 +1211,14 @@ static Value *simplifyMaskedLoad(const I
 // * SimplifyDemandedVectorElts
 // * Single constant active lane -> store
 // * Narrow width by halfs excluding zero/undef lanes
-static Instruction *simplifyMaskedStore(IntrinsicInst &II, InstCombiner &IC) {
+Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
   if (!ConstMask)
     return nullptr;
 
   // If the mask is all zeros, this instruction does nothing.
   if (ConstMask->isNullValue())
-    return IC.eraseInstFromFunction(II);
+    return eraseInstFromFunction(II);
 
   // If the mask is all ones, this is a plain vector store of the 1st argument.
   if (ConstMask->isAllOnesValue()) {
@@ -1212,6 +1227,15 @@ static Instruction *simplifyMaskedStore(
     return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment);
   }
 
+  // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
+  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  APInt UndefElts(DemandedElts.getBitWidth(), 0);
+  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0),
+                                            DemandedElts, UndefElts)) {
+    II.setOperand(0, V);
+    return &II;
+  }
+
   return nullptr;
 }
 
@@ -1268,11 +1292,28 @@ static Instruction *simplifyInvariantGro
 // * Single constant active lane -> store
 // * Adjacent vector addresses -> masked.store
 // * Narrow store width by halfs excluding zero/undef lanes
-static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) {
-  // If the mask is all zeros, a scatter does nothing.
+Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
-  if (ConstMask && ConstMask->isNullValue())
-    return IC.eraseInstFromFunction(II);
+  // Nothing below applies unless the mask is a compile-time constant.
+  if (!ConstMask)
+    return nullptr;
+  // An all-false mask means no lane is written, so the scatter is dead.
+  if (ConstMask->isNullValue())
+    return eraseInstFromFunction(II);
+
+  // Lanes whose mask bit is known false are never written, so the contents of
+  // those lanes in the value vector (operand 0) and in the pointer vector
+  // (operand 1) are irrelevant.  Offer each operand in turn to
+  // SimplifyDemandedVectorElts and report the first one that changes.
+  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  APInt UndefElts(DemandedElts.getBitWidth(), 0);
+  for (unsigned OpIdx : {0u, 1u}) {
+    if (Value *V = SimplifyDemandedVectorElts(II.getOperand(OpIdx),
+                                              DemandedElts, UndefElts)) {
+      II.setOperand(OpIdx, V);
+      return &II;
+    }
+  }
 
   return nullptr;
 }
@@ -1972,11 +2013,11 @@ Instruction *InstCombiner::visitCallInst
       return replaceInstUsesWith(CI, SimplifiedMaskedOp);
     break;
   case Intrinsic::masked_store:
-    return simplifyMaskedStore(*II, *this);
+    return simplifyMaskedStore(*II);
   case Intrinsic::masked_gather:
     return simplifyMaskedGather(*II, *this);
   case Intrinsic::masked_scatter:
-    return simplifyMaskedScatter(*II, *this);
+    return simplifyMaskedScatter(*II);
   case Intrinsic::launder_invariant_group:
   case Intrinsic::strip_invariant_group:
     if (auto *SkippedBarrier = simplifyInvariantGroupIntrinsic(*II, *this))

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h?rev=356590&r1=356589&r2=356590&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h Wed Mar 20 11:44:58 2019
@@ -474,6 +474,9 @@ private:
   Instruction *transformCallThroughTrampoline(CallBase &Call,
                                               IntrinsicInst &Tramp);
 
+  Instruction *simplifyMaskedStore(IntrinsicInst &II);
+  Instruction *simplifyMaskedScatter(IntrinsicInst &II);
+  
   /// Transform (zext icmp) to bitwise / integer operations in order to
   /// eliminate it.
   ///

Modified: llvm/trunk/test/Transforms/InstCombine/masked_intrinsics.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/masked_intrinsics.ll?rev=356590&r1=356589&r2=356590&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/masked_intrinsics.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/masked_intrinsics.ll Wed Mar 20 11:44:58 2019
@@ -80,8 +80,7 @@ define void @store_onemask(<2 x double>*
 
 define void @store_demandedelts(<2 x double>* %ptr, double %val)  {
 ; CHECK-LABEL: @store_demandedelts(
-; CHECK-NEXT:    [[VALVEC1:%.*]] = insertelement <2 x double> undef, double [[VAL:%.*]], i32 0
-; CHECK-NEXT:    [[VALVEC2:%.*]] = shufflevector <2 x double> [[VALVEC1]], <2 x double> undef, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[VALVEC2:%.*]] = insertelement <2 x double> undef, double [[VAL:%.*]], i32 0
 ; CHECK-NEXT:    call void @llvm.masked.store.v2f64.p0v2f64(<2 x double> [[VALVEC2]], <2 x double>* [[PTR:%.*]], i32 4, <2 x i1> <i1 true, i1 false>)
 ; CHECK-NEXT:    ret void
 ;
@@ -137,9 +136,8 @@ define void @scatter_zeromask(<2 x doubl
 
 define void @scatter_demandedelts(double* %ptr, double %val)  {
 ; CHECK-LABEL: @scatter_demandedelts(
-; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr double, double* [[PTR:%.*]], <2 x i64> <i64 0, i64 1>
-; CHECK-NEXT:    [[VALVEC1:%.*]] = insertelement <2 x double> undef, double [[VAL:%.*]], i32 0
-; CHECK-NEXT:    [[VALVEC2:%.*]] = shufflevector <2 x double> [[VALVEC1]], <2 x double> undef, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr double, double* [[PTR:%.*]], <2 x i64> <i64 0, i64 undef>
+; CHECK-NEXT:    [[VALVEC2:%.*]] = insertelement <2 x double> undef, double [[VAL:%.*]], i32 0
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> [[VALVEC2]], <2 x double*> [[PTRS]], i32 8, <2 x i1> <i1 true, i1 false>)
 ; CHECK-NEXT:    ret void
 ;




More information about the llvm-commits mailing list