[llvm] ae3f54c - [EarlyCSE] Handle masked loads and stores

Krzysztof Parzyszek via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 21 16:47:29 PDT 2020


Author: Krzysztof Parzyszek
Date: 2020-09-21T18:47:10-05:00
New Revision: ae3f54c1e909743a89d48a8a05e18d2c8fd652ba

URL: https://github.com/llvm/llvm-project/commit/ae3f54c1e909743a89d48a8a05e18d2c8fd652ba
DIFF: https://github.com/llvm/llvm-project/commit/ae3f54c1e909743a89d48a8a05e18d2c8fd652ba.diff

LOG: [EarlyCSE] Handle masked loads and stores

Extend the handling of memory intrinsics to also include
non-target-specific intrinsics, in particular masked loads and
stores.

Invent "isHandledNonTargetIntrinsic" to distinguish intrinsics that
should be handled natively from intrinsics that can be passed to
TTI.

Add code that handles masked loads and stores, and update the test
cases to reflect the results.

Differential Revision: https://reviews.llvm.org/D87340
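
Concretely, EarlyCSE can now forward the value of a masked store to a
later masked load of the same pointer when the load's mask is a subset
of the store's mask and its pass-through is undef. A minimal IR sketch
of that pattern, modeled on the @f0 and @fa tests updated below (the
function name here is illustrative, not part of the patch):

declare void @llvm.masked.store.v4i32.p0v4i32(<4 x i32>, <4 x i32>*, i32, <4 x i1>)
declare <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>*, i32, <4 x i1>, <4 x i32>)

define <4 x i32> @forward_store_to_load(<4 x i32> %v, <4 x i32>* %p) {
  call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %v, <4 x i32>* %p, i32 4, <4 x i1> <i1 true, i1 true, i1 false, i1 true>)
  ; The load's mask is a subset of the store's mask and its pass-through
  ; is undef, so EarlyCSE replaces %r with %v and removes the load.
  %r = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %p, i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>, <4 x i32> undef)
  ret <4 x i32> %r
}

Running the pass (e.g. opt -passes=early-cse -S) on this input should
reduce the function body to the store followed by "ret <4 x i32> %v",
matching the updated CHECK lines.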

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/EarlyCSE.cpp
    llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll
    llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index acdddcea4ae3..5eb2e12ff3e2 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -689,8 +689,33 @@ class EarlyCSE {
     ParseMemoryInst(Instruction *Inst, const TargetTransformInfo &TTI)
       : Inst(Inst) {
       if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
+        IntrID = II->getIntrinsicID();
         if (TTI.getTgtMemIntrinsic(II, Info))
-          IntrID = II->getIntrinsicID();
+          return;
+        if (isHandledNonTargetIntrinsic(IntrID)) {
+          switch (IntrID) {
+          case Intrinsic::masked_load:
+            Info.PtrVal = Inst->getOperand(0);
+            Info.MatchingId = Intrinsic::masked_load;
+            Info.ReadMem = true;
+            Info.WriteMem = false;
+            Info.IsVolatile = false;
+            break;
+          case Intrinsic::masked_store:
+            Info.PtrVal = Inst->getOperand(1);
+            // Use the ID of masked load as the "matching id". This will
+            // prevent matching non-masked loads/stores with masked ones
+            // (which could be done), but at the moment, the code here
+            // does not support matching intrinsics with non-intrinsics,
+            // so keep the MatchingIds specific to masked instructions
+            // for now (TODO).
+            Info.MatchingId = Intrinsic::masked_load;
+            Info.ReadMem = false;
+            Info.WriteMem = true;
+            Info.IsVolatile = false;
+            break;
+          }
+        }
       }
     }
 
@@ -747,11 +772,6 @@ class EarlyCSE {
       return false;
     }
 
-    bool isMatchingMemLoc(const ParseMemoryInst &Inst) const {
-      return (getPointerOperand() == Inst.getPointerOperand() &&
-              getMatchingId() == Inst.getMatchingId());
-    }
-
     bool isValid() const { return getPointerOperand() != nullptr; }
 
     // For regular (non-intrinsic) loads/stores, this is set to -1. For
@@ -788,6 +808,22 @@ class EarlyCSE {
     Instruction *Inst;
   };
 
+  // This function is to prevent accidentally passing a non-target
+  // intrinsic ID to TargetTransformInfo.
+  static bool isHandledNonTargetIntrinsic(Intrinsic::ID ID) {
+    switch (ID) {
+    case Intrinsic::masked_load:
+    case Intrinsic::masked_store:
+      return true;
+    }
+    return false;
+  }
+  static bool isHandledNonTargetIntrinsic(const Value *V) {
+    if (auto *II = dyn_cast<IntrinsicInst>(V))
+      return isHandledNonTargetIntrinsic(II->getIntrinsicID());
+    return false;
+  }
+
   bool processNode(DomTreeNode *Node);
 
   bool handleBranchCondition(Instruction *CondInst, const BranchInst *BI,
@@ -796,14 +832,30 @@ class EarlyCSE {
   Value *getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
                           unsigned CurrentGeneration);
 
+  bool overridingStores(const ParseMemoryInst &Earlier,
+                        const ParseMemoryInst &Later);
+
   Value *getOrCreateResult(Value *Inst, Type *ExpectedType) const {
     if (auto *LI = dyn_cast<LoadInst>(Inst))
       return LI;
     if (auto *SI = dyn_cast<StoreInst>(Inst))
       return SI->getValueOperand();
     assert(isa<IntrinsicInst>(Inst) && "Instruction not supported");
-    return TTI.getOrCreateResultFromMemIntrinsic(cast<IntrinsicInst>(Inst),
-                                                 ExpectedType);
+    auto *II = cast<IntrinsicInst>(Inst);
+    if (isHandledNonTargetIntrinsic(II->getIntrinsicID()))
+      return getOrCreateResultNonTargetMemIntrinsic(II, ExpectedType);
+    return TTI.getOrCreateResultFromMemIntrinsic(II, ExpectedType);
+  }
+
+  Value *getOrCreateResultNonTargetMemIntrinsic(IntrinsicInst *II,
+                                                Type *ExpectedType) const {
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::masked_load:
+      return II;
+    case Intrinsic::masked_store:
+      return II->getOperand(0);
+    }
+    return nullptr;
   }
 
   /// Return true if the instruction is known to only operate on memory
@@ -813,6 +865,101 @@ class EarlyCSE {
   bool isSameMemGeneration(unsigned EarlierGeneration, unsigned LaterGeneration,
                            Instruction *EarlierInst, Instruction *LaterInst);
 
+  bool isNonTargetIntrinsicMatch(const IntrinsicInst *Earlier,
+                                 const IntrinsicInst *Later) {
+    auto IsSubmask = [](const Value *Mask0, const Value *Mask1) {
+      // Is Mask0 a submask of Mask1?
+      if (Mask0 == Mask1)
+        return true;
+      if (isa<UndefValue>(Mask0) || isa<UndefValue>(Mask1))
+        return false;
+      auto *Vec0 = dyn_cast<ConstantVector>(Mask0);
+      auto *Vec1 = dyn_cast<ConstantVector>(Mask1);
+      if (!Vec0 || !Vec1)
+        return false;
+      assert(Vec0->getType() == Vec1->getType() &&
+             "Masks should have the same type");
+      for (int i = 0, e = Vec0->getNumOperands(); i != e; ++i) {
+        Constant *Elem0 = Vec0->getOperand(i);
+        Constant *Elem1 = Vec1->getOperand(i);
+        auto *Int0 = dyn_cast<ConstantInt>(Elem0);
+        if (Int0 && Int0->isZero())
+          continue;
+        auto *Int1 = dyn_cast<ConstantInt>(Elem1);
+        if (Int1 && !Int1->isZero())
+          continue;
+        if (isa<UndefValue>(Elem0) || isa<UndefValue>(Elem1))
+          return false;
+        if (Elem0 == Elem1)
+          continue;
+        return false;
+      }
+      return true;
+    };
+    auto PtrOp = [](const IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(0);
+      if (II->getIntrinsicID() == Intrinsic::masked_store)
+        return II->getOperand(1);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+    auto MaskOp = [](const IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(2);
+      if (II->getIntrinsicID() == Intrinsic::masked_store)
+        return II->getOperand(3);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+    auto ThruOp = [](const IntrinsicInst *II) {
+      if (II->getIntrinsicID() == Intrinsic::masked_load)
+        return II->getOperand(3);
+      llvm_unreachable("Unexpected IntrinsicInst");
+    };
+
+    if (PtrOp(Earlier) != PtrOp(Later))
+      return false;
+
+    Intrinsic::ID IDE = Earlier->getIntrinsicID();
+    Intrinsic::ID IDL = Later->getIntrinsicID();
+    // We could really use specific intrinsic classes for masked loads
+    // and stores in IntrinsicInst.h.
+    if (IDE == Intrinsic::masked_load && IDL == Intrinsic::masked_load) {
+      // Trying to replace later masked load with the earlier one.
+      // Check that the pointers are the same, and
+      // - masks and pass-throughs are the same, or
+      // - replacee's pass-through is "undef" and replacer's mask is a
+      //   super-set of the replacee's mask.
+      if (MaskOp(Earlier) == MaskOp(Later) && ThruOp(Earlier) == ThruOp(Later))
+        return true;
+      if (!isa<UndefValue>(ThruOp(Later)))
+        return false;
+      return IsSubmask(MaskOp(Later), MaskOp(Earlier));
+    }
+    if (IDE == Intrinsic::masked_store && IDL == Intrinsic::masked_load) {
+      // Trying to replace a load of a stored value with the store's value.
+      // Check that the pointers are the same, and
+      // - load's mask is a subset of store's mask, and
+      // - load's pass-through is "undef".
+      if (!IsSubmask(MaskOp(Later), MaskOp(Earlier)))
+        return false;
+      return isa<UndefValue>(ThruOp(Later));
+    }
+    if (IDE == Intrinsic::masked_load && IDL == Intrinsic::masked_store) {
+      // Trying to remove a store of the loaded value.
+      // Check that the pointers are the same, and
+      // - store's mask is a subset of the load's mask.
+      return IsSubmask(MaskOp(Later), MaskOp(Earlier));
+    }
+    if (IDE == Intrinsic::masked_store && IDL == Intrinsic::masked_store) {
+      // Trying to remove a dead store (earlier).
+      // Check that the pointers are the same,
+      // - the to-be-removed store's mask is a subset of the other store's
+      //   mask.
+      return IsSubmask(MaskOp(Earlier), MaskOp(Later));
+    }
+    return false;
+  }
+
   void removeMSSA(Instruction &Inst) {
     if (!MSSA)
       return;
@@ -978,6 +1125,17 @@ Value *EarlyCSE::getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
   Instruction *Matching = MemInstMatching ? MemInst.get() : InVal.DefInst;
   Instruction *Other = MemInstMatching ? InVal.DefInst : MemInst.get();
 
+  // Deal with non-target memory intrinsics.
+  bool MatchingNTI = isHandledNonTargetIntrinsic(Matching);
+  bool OtherNTI = isHandledNonTargetIntrinsic(Other);
+  if (OtherNTI != MatchingNTI)
+    return nullptr;
+  if (OtherNTI && MatchingNTI) {
+    if (!isNonTargetIntrinsicMatch(cast<IntrinsicInst>(InVal.DefInst),
+                                   cast<IntrinsicInst>(MemInst.get())))
+      return nullptr;
+  }
+
   if (!isOperatingOnInvariantMemAt(MemInst.get(), InVal.Generation) &&
       !isSameMemGeneration(InVal.Generation, CurrentGeneration, InVal.DefInst,
                            MemInst.get()))
@@ -985,6 +1143,37 @@ Value *EarlyCSE::getMatchingValue(LoadValue &InVal, ParseMemoryInst &MemInst,
   return getOrCreateResult(Matching, Other->getType());
 }
 
+bool EarlyCSE::overridingStores(const ParseMemoryInst &Earlier,
+                                const ParseMemoryInst &Later) {
+  // Can we remove Earlier store because of Later store?
+
+  assert(Earlier.isUnordered() && !Earlier.isVolatile() &&
+         "Violated invariant");
+  if (Earlier.getPointerOperand() != Later.getPointerOperand())
+    return false;
+  if (Earlier.getMatchingId() != Later.getMatchingId())
+    return false;
+  // At the moment, we don't remove ordered stores, but do remove
+  // unordered atomic stores.  There's no special requirement (for
+  // unordered atomics) about removing atomic stores only in favor of
+  // other atomic stores since we were going to execute the non-atomic
+  // one anyway and the atomic one might never have become visible.
+  if (!Earlier.isUnordered() || !Later.isUnordered())
+    return false;
+
+  // Deal with non-target memory intrinsics.
+  bool ENTI = isHandledNonTargetIntrinsic(Earlier.get());
+  bool LNTI = isHandledNonTargetIntrinsic(Later.get());
+  if (ENTI && LNTI)
+    return isNonTargetIntrinsicMatch(cast<IntrinsicInst>(Earlier.get()),
+                                     cast<IntrinsicInst>(Later.get()));
+
+  // Because of the check above, at least one of them is false.
+  // For now disallow matching intrinsics with non-intrinsics,
+  // so assume that the stores match if neither is an intrinsic.
+  return ENTI == LNTI;
+}
+
 bool EarlyCSE::processNode(DomTreeNode *Node) {
   bool Changed = false;
   BasicBlock *BB = Node->getBlock();
@@ -1320,17 +1509,8 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
       if (MemInst.isValid() && MemInst.isStore()) {
         // We do a trivial form of DSE if there are two stores to the same
         // location with no intervening loads.  Delete the earlier store.
-        // At the moment, we don't remove ordered stores, but do remove
-        // unordered atomic stores.  There's no special requirement (for
-        // unordered atomics) about removing atomic stores only in favor of
-        // other atomic stores since we were going to execute the non-atomic
-        // one anyway and the atomic one might never have become visible.
         if (LastStore) {
-          ParseMemoryInst LastStoreMemInst(LastStore, TTI);
-          assert(LastStoreMemInst.isUnordered() &&
-                 !LastStoreMemInst.isVolatile() &&
-                 "Violated invariant");
-          if (LastStoreMemInst.isMatchingMemLoc(MemInst)) {
+          if (overridingStores(ParseMemoryInst(LastStore, TTI), MemInst)) {
             LLVM_DEBUG(dbgs() << "EarlyCSE DEAD STORE: " << *LastStore
                               << "  due to: " << Inst << '\n');
             if (!DebugCounter::shouldExecute(CSECounter)) {
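
The other new case is CSE between two masked loads: a later load whose
mask is a subset of an earlier load's mask, and whose pass-through is
undef, can simply reuse the earlier result. A minimal sketch of that
pattern, modeled on the @f3 test in the file below (the function name
here is illustrative):

declare <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>*, i32, <4 x i1>, <4 x i32>)

define <4 x i32> @reuse_wider_load(<4 x i32>* %p, <4 x i32> %t) {
  %v0 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %p, i32 4, <4 x i1> <i1 true, i1 true, i1 false, i1 true>, <4 x i32> %t)
  ; %v1 loads through a subset of %v0's mask with an undef pass-through,
  ; so EarlyCSE now replaces %v1 with %v0.
  %v1 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %p, i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>, <4 x i32> undef)
  %v2 = add <4 x i32> %v0, %v1
  ret <4 x i32> %v2
}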

diff --git a/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll b/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll
index 31c250ac30bf..cf5641d85551 100644
--- a/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll
+++ b/llvm/test/Transforms/EarlyCSE/masked-intrinsics-unequal-masks.ll
@@ -13,8 +13,7 @@
 define <4 x i32> @f3(<4 x i32>* %a0, <4 x i32> %a1) {
 ; CHECK-LABEL: @f3(
 ; CHECK-NEXT:    [[V0:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A0:%.*]], i32 4, <4 x i1> <i1 true, i1 true, i1 false, i1 true>, <4 x i32> [[A1:%.*]])
-; CHECK-NEXT:    [[V1:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A0]], i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>, <4 x i32> undef)
-; CHECK-NEXT:    [[V2:%.*]] = add <4 x i32> [[V0]], [[V1]]
+; CHECK-NEXT:    [[V2:%.*]] = add <4 x i32> [[V0]], [[V0]]
 ; CHECK-NEXT:    ret <4 x i32> [[V2]]
 ;
   %v0 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %a0, i32 4, <4 x i1> <i1 true, i1 true, i1 false, i1 true>, <4 x i32> %a1)
@@ -60,8 +59,7 @@ define <4 x i32> @f5(<4 x i32>* %a0, <4 x i32> %a1) {
 ; Expect the first store to be removed.
 define void @f6(<4 x i32> %a0, <4 x i32>* %a1) {
 ; CHECK-LABEL: @f6(
-; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0:%.*]], <4 x i32>* [[A1:%.*]], i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>)
-; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0]], <4 x i32>* [[A1]], i32 4, <4 x i1> <i1 true, i1 true, i1 false, i1 true>)
+; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0:%.*]], <4 x i32>* [[A1:%.*]], i32 4, <4 x i1> <i1 true, i1 true, i1 false, i1 true>)
 ; CHECK-NEXT:    ret void
 ;
   call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %a0, <4 x i32>* %a1, i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>)
@@ -90,7 +88,6 @@ define void @f7(<4 x i32> %a0, <4 x i32>* %a1) {
 define <4 x i32> @f8(<4 x i32>* %a0, <4 x i32> %a1) {
 ; CHECK-LABEL: @f8(
 ; CHECK-NEXT:    [[V0:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A0:%.*]], i32 4, <4 x i1> <i1 true, i1 true, i1 false, i1 true>, <4 x i32> [[A1:%.*]])
-; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[V0]], <4 x i32>* [[A0]], i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>)
 ; CHECK-NEXT:    ret <4 x i32> [[V0]]
 ;
   %v0 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %a0, i32 4, <4 x i1> <i1 true, i1 true, i1 false, i1 true>, <4 x i32> %a1)
@@ -119,8 +116,7 @@ define <4 x i32> @f9(<4 x i32>* %a0, <4 x i32> %a1) {
 define <4 x i32> @fa(<4 x i32> %a0, <4 x i32>* %a1) {
 ; CHECK-LABEL: @fa(
 ; CHECK-NEXT:    call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> [[A0:%.*]], <4 x i32>* [[A1:%.*]], i32 4, <4 x i1> <i1 true, i1 true, i1 false, i1 true>)
-; CHECK-NEXT:    [[V0:%.*]] = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* [[A1]], i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>, <4 x i32> undef)
-; CHECK-NEXT:    ret <4 x i32> [[V0]]
+; CHECK-NEXT:    ret <4 x i32> [[A0]]
 ;
   call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %a0, <4 x i32>* %a1, i32 4, <4 x i1> <i1 true, i1 true, i1 false, i1 true>)
   %v0 = call <4 x i32> @llvm.masked.load.v4i32.p0v4i32(<4 x i32>* %a1, i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>, <4 x i32> undef)
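
The patch also extends EarlyCSE's trivial DSE to masked stores via
overridingStores(): if the earlier store's mask is a subset of the
later store's mask (same pointer, no intervening loads), every lane it
writes is overwritten and the earlier store is deleted, as in @f6
above. A minimal sketch (the function name here is illustrative):

declare void @llvm.masked.store.v4i32.p0v4i32(<4 x i32>, <4 x i32>*, i32, <4 x i1>)

define void @drop_overridden_store(<4 x i32> %v, <4 x i32>* %p) {
  ; The first store's mask is a subset of the second store's mask, so
  ; EarlyCSE now removes the first store.
  call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %v, <4 x i32>* %p, i32 4, <4 x i1> <i1 true, i1 false, i1 false, i1 true>)
  call void @llvm.masked.store.v4i32.p0v4i32(<4 x i32> %v, <4 x i32>* %p, i32 4, <4 x i1> <i1 true, i1 true, i1 false, i1 true>)
  ret void
}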

diff --git a/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll b/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
index 77183ab97a6b..392a487f627d 100644
--- a/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
+++ b/llvm/test/Transforms/EarlyCSE/masked-intrinsics.ll
@@ -5,8 +5,7 @@ define <128 x i8> @f0(<128 x i8>* %a0, <128 x i8> %a1, <128 x i8> %a2) {
 ; CHECK-LABEL: @f0(
 ; CHECK-NEXT:    [[V0:%.*]] = icmp eq <128 x i8> [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT:    call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[A1]], <128 x i8>* [[A0:%.*]], i32 4, <128 x i1> [[V0]])
-; CHECK-NEXT:    [[V1:%.*]] = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[A0]], i32 4, <128 x i1> [[V0]], <128 x i8> undef)
-; CHECK-NEXT:    ret <128 x i8> [[V1]]
+; CHECK-NEXT:    ret <128 x i8> [[A1]]
 ;
   %v0 = icmp eq <128 x i8> %a1, %a2
   call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> %a1, <128 x i8>* %a0, i32 4, <128 x i1> %v0)
@@ -18,7 +17,6 @@ define <128 x i8> @f1(<128 x i8>* %a0, <128 x i8> %a1, <128 x i8> %a2) {
 ; CHECK-LABEL: @f1(
 ; CHECK-NEXT:    [[V0:%.*]] = icmp eq <128 x i8> [[A1:%.*]], [[A2:%.*]]
 ; CHECK-NEXT:    [[V1:%.*]] = call <128 x i8> @llvm.masked.load.v128i8.p0v128i8(<128 x i8>* [[A0:%.*]], i32 4, <128 x i1> [[V0]], <128 x i8> undef)
-; CHECK-NEXT:    call void @llvm.masked.store.v128i8.p0v128i8(<128 x i8> [[V1]], <128 x i8>* [[A0]], i32 4, <128 x i1> [[V0]])
 ; CHECK-NEXT:    ret <128 x i8> [[V1]]
 ;
   %v0 = icmp eq <128 x i8> %a1, %a2


        

