[Mlir-commits] [mlir] 825b23a - NFC. Refactor/update some affine fusion pass code for readability

Uday Bondhugula llvmlistbot at llvm.org
Sun Dec 11 11:48:43 PST 2022


Author: Uday Bondhugula
Date: 2022-12-12T01:18:17+05:30
New Revision: 825b23a71194958a0c2bb91814b5b371e6b4206c

URL: https://github.com/llvm/llvm-project/commit/825b23a71194958a0c2bb91814b5b371e6b4206c
DIFF: https://github.com/llvm/llvm-project/commit/825b23a71194958a0c2bb91814b5b371e6b4206c.diff

LOG: NFC. Refactor/update some affine fusion pass code for readability

NFC. Refactor some affine fusion pass code for readability. Some of its
methods are too long. This is the first among some NFC changes before new
features/related updates are posted. Add missing code comments at a couple of
places.

Differential Revision: https://reviews.llvm.org/D139255

Added: 
    

Modified: 
    mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
    mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 63b7ad70f9c8..df6b30ee91f9 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -12,7 +12,6 @@
 
 #include "mlir/Dialect/Affine/Passes.h"
 
-#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
@@ -224,7 +223,7 @@ struct MemRefDependenceGraph {
   // Returns the graph node for 'forOp'.
   Node *getForOpNode(AffineForOp forOp) {
     for (auto &idAndNode : nodes)
-      if (idAndNode.second.op == forOp.getOperation())
+      if (idAndNode.second.op == forOp)
         return &idAndNode.second;
     return nullptr;
   }
@@ -711,27 +710,35 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                                 producerConsumerMemrefs);
 }
 
+/// A memref escapes the function if either:
+///   1. it is a function argument, or
+///   2. it is used by a non-affine op (e.g., std load/store, std
+///   call, etc.)
+/// FIXME: Support alias creating ops like memref view ops.
+static bool isEscapingMemref(Value memref) {
+  // Check if 'memref' escapes because it's a block argument.
+  if (memref.isa<BlockArgument>())
+    return true;
+
+  // Check if 'memref' escapes through a non-affine op (e.g., std load/store,
+  // call op, etc.). This already covers aliases created from this.
+  for (Operation *user : memref.getUsers())
+    if (!isa<AffineMapAccessInterface>(*user))
+      return true;
+  return false;
+}
+
 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
-/// that escape the function. A memref escapes the function if either:
-///   1. It's a function argument, or
-///   2. It's used by a non-affine op (e.g., std load/store, std call, etc.)
+/// that escape the function.
 void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
                            DenseSet<Value> &escapingMemRefs) {
   auto *node = mdg->getNode(id);
-  for (auto *storeOpInst : node->stores) {
-    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+  for (Operation *storeOp : node->stores) {
+    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
     if (escapingMemRefs.count(memref))
       continue;
-    // Check if 'memref' escapes because it's a block argument.
-    if (memref.isa<BlockArgument>()) {
+    if (isEscapingMemref(memref))
       escapingMemRefs.insert(memref);
-      continue;
-    }
-    // Check if 'memref' escapes through a non-affine op (e.g., std load/store,
-    // call op, etc.).
-    for (Operation *user : memref.getUsers())
-      if (!isa<AffineMapAccessInterface>(*user))
-        escapingMemRefs.insert(memref);
   }
 }
 
@@ -743,6 +750,8 @@ void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
 // dependence graph at a different depth.
 bool MemRefDependenceGraph::init(func::FuncOp f) {
   LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
+  // Map from a memref to the set of ids of the nodes that have ops accessing
+  // the memref.
   DenseMap<Value, SetVector<unsigned>> memrefAccesses;
 
   // TODO: support multi-block functions.
@@ -832,8 +841,8 @@ bool MemRefDependenceGraph::init(func::FuncOp f) {
         getLoopIVs(*user, &loops);
         if (loops.empty())
           continue;
-        assert(forToNodeMap.count(loops[0].getOperation()) > 0);
-        unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()];
+        assert(forToNodeMap.count(loops[0]) > 0);
+        unsigned userLoopNestId = forToNodeMap[loops[0]];
         addEdge(node.id, userLoopNestId, value);
       }
     }
@@ -866,7 +875,7 @@ bool MemRefDependenceGraph::init(func::FuncOp f) {
 static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   assert(isa<AffineForOp>(node->op));
   AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
-  node->op = newRootForOp.getOperation();
+  node->op = newRootForOp;
 }
 
 //  TODO: improve/complete this when we have target data.
@@ -893,7 +902,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                  unsigned dstLoopDepth,
                                  Optional<unsigned> fastMemorySpace,
                                  uint64_t localBufSizeThreshold) {
-  auto *forInst = forOp.getOperation();
+  Operation *forInst = forOp.getOperation();
 
   // Create builder to insert alloc op just before 'forOp'.
   OpBuilder b(forInst);
@@ -1418,6 +1427,10 @@ struct GreedyFusion {
     eraseUnusedMemRefAllocations();
   }
 
+  /// Visit each node in the graph, and for each node, attempt to fuse it with
+  /// producer-consumer candidates. No fusion is performed when producers with a
+  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
+  /// are encountered.
   void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
     LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
     init();
@@ -1628,8 +1641,8 @@ struct GreedyFusion {
                      << dstAffineForOp << "\n");
 
           // Move 'dstAffineForOp' before 'insertPointInst' if needed.
-          if (fusedLoopInsPoint != dstAffineForOp.getOperation())
-            dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint);
+          if (fusedLoopInsPoint != dstAffineForOp)
+            dstAffineForOp->moveBefore(fusedLoopInsPoint);
 
           // Update edges between 'srcNode' and 'dstNode'.
           mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
@@ -1642,8 +1655,7 @@ struct GreedyFusion {
             dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
               Value storeMemRef = storeOp.getMemRef();
               if (privateMemrefs.count(storeMemRef) > 0)
-                privateMemRefToStores[storeMemRef].push_back(
-                    storeOp.getOperation());
+                privateMemRefToStores[storeMemRef].push_back(storeOp);
             });
 
             // Replace original memrefs with private memrefs. Note that all the
@@ -1672,7 +1684,7 @@ struct GreedyFusion {
 
           // Collect dst loop stats after memref privatization transformation.
           LoopNestStateCollector dstLoopCollector;
-          dstLoopCollector.collect(dstAffineForOp.getOperation());
+          dstLoopCollector.collect(dstAffineForOp);
 
           // Clear and add back loads and stores.
           mdg->clearNodeLoadAndStores(dstNode->id);
@@ -1798,7 +1810,7 @@ struct GreedyFusion {
 
       auto dstForInst = cast<AffineForOp>(dstNode->op);
       // Update operation position of fused loop nest (if needed).
-      if (insertPointInst != dstForInst.getOperation()) {
+      if (insertPointInst != dstForInst) {
         dstForInst->moveBefore(insertPointInst);
       }
       // Update data dependence graph state post fusion.
@@ -1939,7 +1951,7 @@ struct GreedyFusion {
     // Collect dst loop stats after memref privatization transformation.
     auto dstForInst = cast<AffineForOp>(dstNode->op);
     LoopNestStateCollector dstLoopCollector;
-    dstLoopCollector.collect(dstForInst.getOperation());
+    dstLoopCollector.collect(dstForInst);
     // Clear and add back loads and stores
     mdg->clearNodeLoadAndStores(dstNode->id);
     mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,

diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 7378418dc0af..889c9078c530 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -13,20 +13,13 @@
 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
-#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Operation.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -113,7 +106,7 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
         }
         return WalkResult::advance();
       }
-      for (auto value : op->getResults()) {
+      for (Value value : op->getResults()) {
         for (Operation *user : value.getUsers()) {
           SmallVector<AffineForOp, 4> loops;
           // Check if any loop in loop nest surrounding 'user' is 'opB'.
@@ -137,15 +130,12 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
 // dependences. Returns nullptr if no such insertion point is found.
 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
                                                  AffineForOp dstForOp) {
-  bool isSrcForOpBeforeDstForOp =
-      srcForOp->isBeforeInBlock(dstForOp.getOperation());
+  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
 
-  auto *firstDepOpA =
-      getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
-  auto *lastDepOpB =
-      getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
+  Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
+  Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
   // Block:
   //      ...
   //  |-- opA
@@ -170,7 +160,7 @@ static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
   }
   // No dependences from 'opA' to operation in range ('opA', 'opB'), return
   // 'opB' insertion point.
-  return forOpB.getOperation();
+  return forOpB;
 }
 
 // Gathers all load and store ops in loop nest rooted at 'forOp' into
@@ -281,8 +271,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   }
 
   // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
-  bool isSrcForOpBeforeDstForOp =
-      srcForOp->isBeforeInBlock(dstForOp.getOperation());
+  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
   // 'forOpA' executes before 'forOpB' in 'block'.
   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
@@ -315,8 +304,8 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   }
 
   // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
-  unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
-      *srcForOp.getOperation(), *dstForOp.getOperation());
+  unsigned numCommonLoops =
+      mlir::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);
 
   // Filter out ops in 'opsA' to compute the slice union based on the
   // assumptions made by the fusion strategy.
@@ -539,8 +528,8 @@ static int64_t getComputeCostHelper(
   int64_t opCount = stats.opCountMap[forOp] - 1;
   if (stats.loopMap.count(forOp) > 0) {
     for (auto childForOp : stats.loopMap[forOp]) {
-      opCount += getComputeCostHelper(childForOp.getOperation(), stats,
-                                      tripCountOverrideMap, computeCostMap);
+      opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
+                                      computeCostMap);
     }
   }
   // Add in additional op instances from slice (if specified in map).
@@ -567,7 +556,7 @@ static int64_t getComputeCostHelper(
 /// instance count (i.e. total number of operations in the loop body * loop
 /// trip count) for the entire loop nest.
 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
-  return getComputeCostHelper(forOp.getOperation(), stats,
+  return getComputeCostHelper(forOp, stats,
                               /*tripCountOverrideMap=*/nullptr,
                               /*computeCostMap=*/nullptr);
 }
@@ -611,8 +600,8 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
       computeCostMap[insertPointParent] = -storeCount;
     // Subtract out any load users of 'storeMemrefs' nested below
     // 'insertPointParent'.
-    for (auto value : storeMemrefs) {
-      for (auto *user : value.getUsers()) {
+    for (Value memref : storeMemrefs) {
+      for (auto *user : memref.getUsers()) {
         if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
           SmallVector<AffineForOp, 4> loops;
           // Check if any loop in loop nest surrounding 'user' is
@@ -633,13 +622,13 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
 
   // Compute op instance count for the src loop nest with iteration slicing.
   int64_t sliceComputeCost = getComputeCostHelper(
-      srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
+      srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);
 
   // Compute cost of fusion for this depth.
   computeCostMap[insertPointParent] = sliceComputeCost;
 
   *computeCost =
-      getComputeCostHelper(dstForOp.getOperation(), dstStats,
+      getComputeCostHelper(dstForOp, dstStats,
                            /*tripCountOverrideMap=*/nullptr, &computeCostMap);
   return true;
 }


        


More information about the Mlir-commits mailing list