[llvm] r346479 - [ARM] Small reorganisation in ARMParallelDSP

Sam Parker via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 9 01:18:00 PST 2018


Author: sam_parker
Date: Fri Nov  9 01:18:00 2018
New Revision: 346479

URL: http://llvm.org/viewvc/llvm-project?rev=346479&view=rev
Log:
[ARM] Small reorganisation in ARMParallelDSP

A few code movement things:

- AreSymmetrical is now a method of BinOpChain.
- Created a lambda in CreateParallelMACPairs to reduce loop nesting.
- A Reduction object now gets passed around in a couple of places,
  including CreateParallelMACPairs, so that function no longer needs
  to return a value.

I've also added RecordSequentialLoads, which is run before the
transformation begins and caches the interesting loads. This cache can
then be queried later instead of cross-checking many load values.

Differential Revision: https://reviews.llvm.org/D54254

Modified:
    llvm/trunk/lib/Target/ARM/ARMParallelDSP.cpp

Modified: llvm/trunk/lib/Target/ARM/ARMParallelDSP.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/ARM/ARMParallelDSP.cpp?rev=346479&r1=346478&r2=346479&view=diff
==============================================================================
--- llvm/trunk/lib/Target/ARM/ARMParallelDSP.cpp (original)
+++ llvm/trunk/lib/Target/ARM/ARMParallelDSP.cpp Fri Nov  9 01:18:00 2018
@@ -99,6 +99,8 @@ namespace {
         for (auto *V : RHS)
           AllValues.push_back(V);
       }
+
+    bool AreSymmetrical(BinOpChain *Other);
   };
 
   struct Reduction {
@@ -106,9 +108,9 @@ namespace {
                                       // pattern matching.
     Instruction     *AccIntAdd;       // The accumulating integer add statement,
                                       // i.e, the reduction statement.
-
     OpChainList     MACCandidates;    // The MAC candidates associated with
                                       // this reduction statement.
+    PMACPairList    PMACPairs;
     Reduction (PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { };
   };
 
@@ -121,10 +123,13 @@ namespace {
     Loop              *L;
     const DataLayout  *DL;
     Module            *M;
+    std::map<LoadInst*, LoadInst*> LoadPairs;
+    std::map<LoadInst*, SmallVector<LoadInst*, 4>> SequentialLoads;
 
-    bool InsertParallelMACs(Reduction &Reduction, PMACPairList &PMACPairs);
+    bool RecordSequentialLoads(BasicBlock *Header);
+    bool InsertParallelMACs(Reduction &Reduction);
     bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
-    PMACPairList CreateParallelMACPairs(OpChainList &Candidates);
+    void CreateParallelMACPairs(Reduction &R);
     Instruction *CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
                                  Instruction *Acc, bool Exchange,
                                  Instruction *InsertAfter);
@@ -202,6 +207,12 @@ namespace {
 
       LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
       LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
+
+      if (!RecordSequentialLoads(Header)) {
+        LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
+        return false;
+      }
+
       Changes = MatchSMLAD(F);
       return Changes;
     }
@@ -254,58 +265,14 @@ static bool IsNarrowSequence(Value *V, V
   return false;
 }
 
-// Element-by-element comparison of Value lists returning true if they are
-// instructions with the same opcode or constants with the same value.
-static bool AreSymmetrical(const ValueList &VL0,
-                           const ValueList &VL1) {
-  if (VL0.size() != VL1.size()) {
-    LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
-                      << VL0.size() << " != " << VL1.size() << "\n");
-    return false;
-  }
-
-  const unsigned Pairs = VL0.size();
-  LLVM_DEBUG(dbgs() << "Number of operand pairs: " << Pairs << "\n");
-
-  for (unsigned i = 0; i < Pairs; ++i) {
-    const Value *V0 = VL0[i];
-    const Value *V1 = VL1[i];
-    const auto *Inst0 = dyn_cast<Instruction>(V0);
-    const auto *Inst1 = dyn_cast<Instruction>(V1);
-
-    LLVM_DEBUG(dbgs() << "Pair " << i << ":\n";
-               dbgs() << "mul1: "; V0->dump();
-               dbgs() << "mul2: "; V1->dump());
-
-    if (!Inst0 || !Inst1)
-      return false;
-
-    if (Inst0->isSameOperationAs(Inst1)) {
-      LLVM_DEBUG(dbgs() << "OK: same operation found!\n");
-      continue;
-    }
-
-    const APInt *C0, *C1;
-    if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && C0 == C1))
-      return false;
-  }
-
-  LLVM_DEBUG(dbgs() << "OK: found symmetrical operand lists.\n");
-  return true;
-}
-
 template<typename MemInst>
 static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
-                                  MemInstList &VecMem, const DataLayout &DL,
-                                  ScalarEvolution &SE) {
+                                  const DataLayout &DL, ScalarEvolution &SE) {
   if (!MemOp0->isSimple() || !MemOp1->isSimple()) {
     LLVM_DEBUG(dbgs() << "No, not touching volatile access\n");
     return false;
   }
   if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE)) {
-    VecMem.clear();
-    VecMem.push_back(MemOp0);
-    VecMem.push_back(MemOp1);
     LLVM_DEBUG(dbgs() << "OK: accesses are consecutive.\n");
     return true;
   }
@@ -328,16 +295,106 @@ bool ARMParallelDSP::AreSequentialLoads(
     return false;
   }
 
-  return AreSequentialAccesses<LoadInst>(Ld0, Ld1, VecMem, *DL, *SE);
+  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
+    return false;
+
+  VecMem.clear();
+  VecMem.push_back(Ld0);
+  VecMem.push_back(Ld1);
+  return true;
+}
+
+/// Iterate through the block and record base, offset pairs of loads as well as
+/// maximal sequences of sequential loads.
+bool ARMParallelDSP::RecordSequentialLoads(BasicBlock *Header) {
+  SmallVector<LoadInst*, 8> Loads;
+  for (auto &I : *Header) {
+    auto *Ld = dyn_cast<LoadInst>(&I);
+    if (!Ld)
+      continue;
+    Loads.push_back(Ld);
+  }
+
+  std::map<LoadInst*, LoadInst*> BaseLoads;
+
+  for (auto *Ld0 : Loads) {
+    for (auto *Ld1 : Loads) {
+      if (Ld0 == Ld1)
+        continue;
+
+      if (AreSequentialAccesses<LoadInst>(Ld0, Ld1, *DL, *SE)) {
+        LoadPairs[Ld0] = Ld1;
+        if (BaseLoads.count(Ld0)) {
+          LoadInst *Base = BaseLoads[Ld0];
+          BaseLoads[Ld1] = Base;
+          SequentialLoads[Base].push_back(Ld1);
+        } else {
+          BaseLoads[Ld1] = Ld0;
+          SequentialLoads[Ld0].push_back(Ld1);
+        }
+      }
+    }
+  }
+  return LoadPairs.size() > 1;
 }
 
-PMACPairList
-ARMParallelDSP::CreateParallelMACPairs(OpChainList &Candidates) {
+void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
+  OpChainList &Candidates = R.MACCandidates;
+  PMACPairList &PMACPairs = R.PMACPairs;
   const unsigned Elems = Candidates.size();
-  PMACPairList PMACPairs;
 
   if (Elems < 2)
-    return PMACPairs;
+    return;
+
+  auto CanPair = [&](BinOpChain *PMul0, BinOpChain *PMul1) {
+    if (!PMul0->AreSymmetrical(PMul1))
+      return false;
+
+    // The first elements of each vector should be loads with sexts. If we
+    // find that its two pairs of consecutive loads, then these can be
+    // transformed into two wider loads and the users can be replaced with
+    // DSP intrinsics.
+    for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
+      auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
+      auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
+      auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
+      auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);
+
+      if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
+        return false;
+
+      LLVM_DEBUG(dbgs() << "Looking at operands " << x << ":\n"
+                 << "\t Ld0: " << *Ld0 << "\n"
+                 << "\t Ld1: " << *Ld1 << "\n"
+                 << "and operands " << x + 2 << ":\n"
+                 << "\t Ld2: " << *Ld2 << "\n"
+                 << "\t Ld3: " << *Ld3 << "\n");
+
+      if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
+        if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
+          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
+          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
+          return true;
+        } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
+          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
+          LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
+          PMul1->Exchange = true;
+          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
+          return true;
+        }
+      } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
+                 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
+        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
+        LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
+        LLVM_DEBUG(dbgs() << "    and swapping muls\n");
+        PMul0->Exchange = true;
+        // Only the second operand can be exchanged, so swap the muls.
+        PMACPairs.push_back(std::make_pair(PMul1, PMul0));
+        return true;
+      }
+    }
+    return false;
+  };
 
   SmallPtrSet<const Instruction*, 4> Paired;
   for (unsigned i = 0; i < Elems; ++i) {
@@ -364,77 +421,21 @@ ARMParallelDSP::CreateParallelMACPairs(O
                  dbgs() << "- "; Mul0->dump();
                  dbgs() << "- "; Mul1->dump());
 
-      const ValueList &Mul0_LHS = PMul0->LHS;
-      const ValueList &Mul0_RHS = PMul0->RHS;
-      const ValueList &Mul1_LHS = PMul1->LHS;
-      const ValueList &Mul1_RHS = PMul1->RHS;
-
-      if (!AreSymmetrical(Mul0_LHS, Mul1_LHS) ||
-          !AreSymmetrical(Mul0_RHS, Mul1_RHS))
-        continue;
-
       LLVM_DEBUG(dbgs() << "OK: mul operands list match:\n");
-      // The first elements of each vector should be loads with sexts. If we
-      // find that its two pairs of consecutive loads, then these can be
-      // transformed into two wider loads and the users can be replaced with
-      // DSP intrinsics.
-      bool Found = false;
-      for (unsigned x = 0; x < Mul0_LHS.size(); x += 2) {
-        auto *Ld0 = dyn_cast<LoadInst>(Mul0_LHS[x]);
-        auto *Ld1 = dyn_cast<LoadInst>(Mul1_LHS[x]);
-        auto *Ld2 = dyn_cast<LoadInst>(Mul0_RHS[x]);
-        auto *Ld3 = dyn_cast<LoadInst>(Mul1_RHS[x]);
-
-        if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
-          continue;
-
-        LLVM_DEBUG(dbgs() << "Looking at operands " << x << ":\n"
-                   << "\t Ld0: " << *Ld0 << "\n"
-                   << "\t Ld1: " << *Ld1 << "\n"
-                   << "and operands " << x + 2 << ":\n"
-                   << "\t Ld2: " << *Ld2 << "\n"
-                   << "\t Ld3: " << *Ld3 << "\n");
-
-        if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
-          if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
-            LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
-            PMACPairs.push_back(std::make_pair(PMul0, PMul1));
-            Found = true;
-          } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
-            LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
-            LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
-            PMul1->Exchange = true;
-            PMACPairs.push_back(std::make_pair(PMul0, PMul1));
-            Found = true;
-          }
-        } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd)) {
-          if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
-            LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
-            LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
-            LLVM_DEBUG(dbgs() << "    and swapping muls\n");
-            PMul0->Exchange = true;
-            // Only the second operand can be exchanged, so swap the muls.
-            PMACPairs.push_back(std::make_pair(PMul1, PMul0));
-            Found = true;
-          }
-        }
-      }
-      if (Found) {
+      if (CanPair(PMul0, PMul1)) {
         Paired.insert(Mul0);
         Paired.insert(Mul1);
         break;
       }
     }
   }
-  return PMACPairs;
 }
 
-bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction,
-                                        PMACPairList &PMACPairs) {
+bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) {
   Instruction *Acc = Reduction.Phi;
   Instruction *InsertAfter = Reduction.AccIntAdd;
 
-  for (auto &Pair : PMACPairs) {
+  for (auto &Pair : Reduction.PMACPairs) {
     BinOpChain *PMul0 = Pair.first;
     BinOpChain *PMul1 = Pair.second;
     LLVM_DEBUG(dbgs() << "Found parallel MACs!!\n";
@@ -685,8 +686,8 @@ bool ARMParallelDSP::MatchSMLAD(Function
   for (auto &R : Reductions) {
     if (AreAliased(AA, Reads, Writes, R.MACCandidates))
       return false;
-    PMACPairList PMACPairs = CreateParallelMACPairs(R.MACCandidates);
-    Changed |= InsertParallelMACs(R, PMACPairs);
+    CreateParallelMACPairs(R);
+    Changed |= InsertParallelMACs(R);
   }
 
   LLVM_DEBUG(if (Changed) dbgs() << "Header block:\n"; Header->dump(););
@@ -733,6 +734,52 @@ Instruction *ARMParallelDSP::CreateSMLAD
   return Call;
 }
 
+// Compare the value lists in Other to this chain.
+bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
+  // Element-by-element comparison of Value lists returning true if they are
+  // instructions with the same opcode or constants with the same value.
+  auto CompareValueList = [](const ValueList &VL0,
+                             const ValueList &VL1) {
+    if (VL0.size() != VL1.size()) {
+      LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
+                        << VL0.size() << " != " << VL1.size() << "\n");
+      return false;
+    }
+
+    const unsigned Pairs = VL0.size();
+    LLVM_DEBUG(dbgs() << "Number of operand pairs: " << Pairs << "\n");
+
+    for (unsigned i = 0; i < Pairs; ++i) {
+      const Value *V0 = VL0[i];
+      const Value *V1 = VL1[i];
+      const auto *Inst0 = dyn_cast<Instruction>(V0);
+      const auto *Inst1 = dyn_cast<Instruction>(V1);
+
+      LLVM_DEBUG(dbgs() << "Pair " << i << ":\n";
+                dbgs() << "mul1: "; V0->dump();
+                dbgs() << "mul2: "; V1->dump());
+
+      if (!Inst0 || !Inst1)
+        return false;
+
+      if (Inst0->isSameOperationAs(Inst1)) {
+        LLVM_DEBUG(dbgs() << "OK: same operation found!\n");
+        continue;
+      }
+
+      const APInt *C0, *C1;
+      if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && C0 == C1))
+        return false;
+    }
+
+    LLVM_DEBUG(dbgs() << "OK: found symmetrical operand lists.\n");
+    return true;
+  };
+
+  return CompareValueList(LHS, Other->LHS) &&
+         CompareValueList(RHS, Other->RHS);
+}
+
 Pass *llvm::createARMParallelDSPPass() {
   return new ARMParallelDSP();
 }




More information about the llvm-commits mailing list