[Mlir-commits] [mlir] [MLIR][Affine] Fix fusion in the presence of cyclic deps in source nests (PR #128397)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 22 20:17:12 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Uday Bondhugula (bondhugula)

<details>
<summary>Changes</summary>

Fixes: https://github.com/llvm/llvm-project/issues/61820

Fix affine fusion in the presence of cyclic dependences in the source nest. In such cases, the nest being fused can't be (even partially) executed multiple times, i.e., fusion must not introduce any redundant computation. Add a utility to check for dependence cycles and use it in fusion. This fixes both sibling and producer-consumer fusion, where nests with cyclic dependences (typically reductions) were in some cases being incorrectly fused.

The test case also exercises (and required) a fix to the check that the redundant computation is within the specified threshold: the comparison against the threshold is now inclusive.

---

Patch is 24.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128397.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h (+7) 
- (modified) mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp (+160-2) 
- (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+167-44) 
- (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+27) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
index ed3c21d952a01..4ae571e26f0b4 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
@@ -119,6 +119,13 @@ bool isOpwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts);
 /// any dependence component is negative along any of `loops`.
 bool isTilingValid(ArrayRef<AffineForOp> loops);
 
+/// Returns true if the affine nest rooted at `root` has a cyclic dependence
+/// among its affine memory accesses (reads and writes). Dependences carried by
+/// `root` or by any loop nested inside it, as well as loop-independent ones
+/// within loop bodies (blocks), are considered. Dependences carried by loops
+/// outer to `root` aren't relevant.
+bool hasCyclicDependence(AffineForOp root);
+
 } // namespace affine
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index 0d4b0ea1668e0..6f014e5074f3f 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -16,7 +16,7 @@
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "llvm/Support/MathExtras.h"
 
@@ -28,10 +28,138 @@
 #include <optional>
 #include <type_traits>
 
+#define DEBUG_TYPE "affine-loop-analysis"
+
 using namespace mlir;
 using namespace mlir::affine;
 
-#define DEBUG_TYPE "affine-loop-analysis"
+namespace {
+
+/// A directed graph to model relationships between MLIR Operations. Each
+/// node wraps an Operation; edges are stored as adjacency lists of node
+/// indices, so they stay valid if nodes are added after edges.
+class DirectedOpGraph {
+public:
+  /// Add a node for `op` to the graph; `op` must not already be present.
+  void addNode(Operation *op) {
+    assert(!hasNode(op) && "node already added");
+    // Record the index before appending so that it refers to the new node.
+    indexOf[op] = nodes.size();
+    nodes.emplace_back(op);
+  }
+
+  /// Add an edge between `src` and `dest`.
+  void addEdge(Operation *src, Operation *dest) {
+    // This is a multi-graph.
+    assert(hasNode(src) && "src node does not exist in graph");
+    assert(hasNode(dest) && "dest node does not exist in graph");
+    getNode(src).succs.push_back(indexOf.lookup(dest));
+  }
+
+  /// Returns true if there is a (directed) cycle in the graph.
+  bool hasCycle() { return dfsImpl(/*cycleCheck=*/true); }
+
+  /// Print every node and its outgoing edges, in node insertion order.
+  void printEdges() {
+    for (const DGNode &node : nodes) {
+      llvm::dbgs() << *node.op << " (" << node.op << ")"
+                   << " has " << node.succs.size() << " edges:\n";
+      for (unsigned succ : node.succs)
+        llvm::dbgs() << '\t' << *nodes[succ].op << '\n';
+    }
+  }
+
+private:
+  /// A node of a directed graph between MLIR Operations to model various
+  /// relationships. This is meant to be used internally.
+  struct DGNode {
+    explicit DGNode(Operation *op) : op(op) {}
+    Operation *op;
+
+    /// Indices into `nodes` of the destinations of this node's outgoing
+    /// edges. Unlike pointers, indices survive reallocation of `nodes`.
+    SmallVector<unsigned> succs;
+
+    // Start and finish visit numbers are standard in DFS to implement things
+    // like strongly connected components. These numbers are modified during
+    // analyses on the graph and so seemingly const API methods are non-const.
+
+    /// Start visit number.
+    int vn = -1;
+
+    /// Finish visit number.
+    int fn = -1;
+  };
+
+  /// Get internal node corresponding to `op`.
+  DGNode &getNode(Operation *op) {
+    auto it = indexOf.find(op);
+    assert(it != indexOf.end() && "node doesn't exist in graph");
+    return nodes[it->second];
+  }
+
+  /// Returns true if `key` is in the graph.
+  bool hasNode(Operation *key) const { return indexOf.contains(key); }
+
+  /// Perform a depth-first traversal of the graph setting visited and finished
+  /// numbers. If `cycleCheck` is set, detects cycles and returns true as soon
+  /// as the first cycle is detected, and false if there are no cycles. If
+  /// `cycleCheck` is not set, completes the DFS and the `return` value doesn't
+  /// have a meaning.
+  bool dfsImpl(bool cycleCheck = false) {
+    // Reset both visit numbers: a stale finish number from an earlier
+    // traversal would otherwise hide a cycle.
+    for (DGNode &node : nodes) {
+      node.vn = 0;
+      node.fn = -1;
+    }
+
+    unsigned time = 0;
+    for (DGNode &node : nodes) {
+      // A traversal that doesn't return early finishes every node it visits,
+      // so nodes visited from an earlier root can simply be skipped.
+      if (node.vn != 0)
+        continue;
+      if (dfsNode(node, cycleCheck, time) && cycleCheck)
+        return true;
+    }
+    return false;
+  }
+
+  /// Perform depth-first traversal starting at `node`. Return true
+  /// as soon as a cycle is found if `cycleCheck` was set. Update `time`.
+  bool dfsNode(DGNode &node, bool cycleCheck, unsigned &time) {
+    // Depth first search from a given vertex.
+    ++time;
+    node.vn = time;
+
+    for (unsigned succ : node.succs) {
+      DGNode &neighbour = nodes[succ];
+      if (neighbour.vn == 0) {
+        if (dfsNode(neighbour, cycleCheck, time) && cycleCheck)
+          return true;
+      } else if (cycleCheck && neighbour.fn == -1) {
+        // We have encountered a node whose visit has started but it's not
+        // finished, i.e., it is on the current DFS path. So we have a cycle.
+        return true;
+      }
+    }
+
+    ++time;
+    // Update finish time.
+    node.fn = time;
+
+    return false;
+  }
+
+  // The list of nodes. The storage is owned by this class.
+  SmallVector<DGNode> nodes;
+
+  // Maps each operation in the graph to the index of its node in `nodes`.
+  DenseMap<Operation *, unsigned> indexOf;
+};
+
+} // namespace
 
 /// Returns the trip count of the loop as an affine expression if the latter is
 /// expressible as an affine expression, and nullptr otherwise. The trip count
@@ -447,3 +575,33 @@ bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) {
 
   return true;
 }
+
+bool mlir::affine::hasCyclicDependence(AffineForOp root) {
+  // Collect all the affine memory accesses in the nest as graph nodes.
+  DirectedOpGraph graph;
+  SmallVector<MemRefAccess> accesses;
+  root->walk([&](Operation *op) {
+    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
+      accesses.emplace_back(op);
+      graph.addNode(op);
+    }
+  });
+
+  // Add one edge accA -> accB if they may be dependent at any relevant depth.
+  unsigned rootDepth = getNestingDepth(root);
+  for (const auto &accA : accesses) {
+    for (const auto &accB : accesses) {
+      if (accA.memref != accB.memref)
+        continue;
+      unsigned numCommonLoops =
+          getNumCommonSurroundingLoops(*accA.opInst, *accB.opInst);
+      for (unsigned d = rootDepth + 1; d <= numCommonLoops + 1; ++d) {
+        if (noDependence(checkMemrefAccessDependence(accA, accB, d)))
+          continue;
+        graph.addEdge(accA.opInst, accB.opInst);
+        break;
+      }
+    }
+  }
+  return graph.hasCycle();
+}
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 5add7df849286..b97f11a963828 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -274,6 +274,58 @@ getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
   return firstAncestor;
 }
 
+/// Returns the amount of additional (redundant) computation that will be done
+/// as a fraction of the total computation if `srcForOp` is fused into
+/// `dstForOp` at the (1-based) depth `depth`, using the slice union
+/// `depthSliceUnions[depth - 1]`. On success, the fused nest's compute cost is
+/// returned in `fusedLoopNestComputeCost`. The slice's own compute cost isn't
+/// computed yet and `sliceCost` is set to -1. Returns std::nullopt on failure.
+static std::optional<double> getAdditionalComputeFraction(
+    AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
+    ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
+    int64_t &fusedLoopNestComputeCost) {
+  LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";);
+  // TODO: compute the slice cost; until then, never leave it indeterminate.
+  sliceCost = -1;
+  // Check the slice (cheap) before walking both nests to collect stats.
+  if (depth == 0 || depth > depthSliceUnions.size() ||
+      depthSliceUnions[depth - 1].isEmpty()) {
+    LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n");
+    return std::nullopt;
+  }
+  const ComputationSliceState &slice = depthSliceUnions[depth - 1];
+
+  // Walk the src and dst loop nests and collect stats.
+  LoopNestStats srcLoopNestStats;
+  if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
+    LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n");
+    return std::nullopt;
+  }
+  LoopNestStats dstLoopNestStats;
+  if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
+    LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n");
+    return std::nullopt;
+  }
+
+  // Op instance counts of the src and dst nests without iteration slicing.
+  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
+  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
+  double unfusedCost = static_cast<double>(srcLoopNestCost) + dstLoopNestCost;
+  // Avoid 0 / 0 below: any comparison against the resulting NaN is false.
+  if (unfusedCost == 0) {
+    LLVM_DEBUG(llvm::dbgs() << "Unfused compute cost is zero.\n");
+    return std::nullopt;
+  }
+
+  if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
+                            dstLoopNestStats, slice,
+                            &fusedLoopNestComputeCost)) {
+    LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
+    return std::nullopt;
+  }
+  return fusedLoopNestComputeCost / unfusedCost - 1;
+}
+
 // Creates and returns a private (single-user) memref for fused loop rooted at
 // 'forOp', with (potentially reduced) memref size based on the memref region
 // written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
@@ -384,20 +436,19 @@ static Value createPrivateMemRef(AffineForOp forOp,
 }
 
 // Checks the profitability of fusing a backwards slice of the loop nest
-// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
-// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
-// the memref being produced and consumed, which is an input to the cost model.
-// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
-// 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse
-// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
-// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
-// unique store op in the src node, which will be used to check that the write
-// region is the same after input-reuse fusion. Computation slices are provided
-// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
-// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
-// profitable to fuse the candidate loop nests. Returns false otherwise.
-// `dstLoopDepth` is set to the most profitable depth at which to materialize
-// the source loop nest slice.
+// `srcForOp` into the loop nest rooted at `dstForOp`. The argument
+// `srcStoreOpInst` is used to calculate the storage reduction on the memref
+// being produced and consumed, which is an input to the cost model. For
+// producer-consumer fusion, `srcStoreOpInst` is the producer store with
+// respect to which the slice is computed. For input-reuse (sibling) fusion,
+// it is the unique store op in the src node. Computation slices are provided
+// in `depthSliceUnions` for each legal fusion depth. The maximal depth at
+// which fusion is legal is provided in `maxLegalFusionDepth`, and the allowed
+// fraction of redundant computation in `computeToleranceThreshold`. Returns
+// true if it is profitable to fuse the candidate loop nests, false otherwise.
+// `dstLoopDepth` is set to the most profitable depth at which to materialize
+// the source loop nest slice. Below, 'srcOpInst' denotes the op in `srcForOp`
+// with respect to which the source slice is computed.
 // The profitability model executes the following steps:
 // *) Computes the backward computation slice at 'srcOpInst'. This
 //    computation slice of the loop nest surrounding 'srcOpInst' is
@@ -422,15 +473,16 @@ static Value createPrivateMemRef(AffineForOp forOp,
 //    is lower.
 // TODO: Extend profitability analysis to support scenarios with multiple
 // stores.
-static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
+static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
                                AffineForOp dstForOp,
                                ArrayRef<ComputationSliceState> depthSliceUnions,
                                unsigned maxLegalFusionDepth,
                                unsigned *dstLoopDepth,
                                double computeToleranceThreshold) {
   LLVM_DEBUG({
-    llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
-    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
+    llvm::dbgs()
+        << "Checking whether fusion is profitable between source nest:\n";
+    llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n";
     llvm::dbgs() << dstForOp << "\n";
   });
 
@@ -440,12 +492,10 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   }
 
   // Compute cost of sliced and unsliced src loop nest.
-  SmallVector<AffineForOp, 4> srcLoopIVs;
-  getAffineForIVs(*srcOpInst, &srcLoopIVs);
 
   // Walk src loop nest and collect stats.
   LoopNestStats srcLoopNestStats;
-  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
+  if (!getLoopNestStats(srcForOp, &srcLoopNestStats))
     return false;
 
   // Compute cost of dst loop nest.
@@ -467,7 +517,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   std::optional<unsigned> bestDstLoopDepth;
 
   // Compute op instance count for the src loop nest without iteration slicing.
-  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
+  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
 
   // Compute src loop nest write region size.
   MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
@@ -494,18 +544,21 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     if (slice.isEmpty())
       continue;
 
+    // Compute cost of the slice separately, i.e, the compute cost of the slice
+    // if all outer trip counts are one.
+    int64_t sliceCost;
+
     int64_t fusedLoopNestComputeCost;
-    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
-                              dstLoopNestStats, slice,
-                              &fusedLoopNestComputeCost)) {
-      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
+
+    auto mayAdditionalComputeFraction =
+        getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
+                                     sliceCost, fusedLoopNestComputeCost);
+    if (!mayAdditionalComputeFraction) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Can't determine additional compute fraction.\n");
       continue;
     }
-
-    double additionalComputeFraction =
-        fusedLoopNestComputeCost /
-            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
-        1;
+    double additionalComputeFraction = *mayAdditionalComputeFraction;
 
     // Determine what the slice write MemRefRegion would be, if the src loop
     // nest slice 'slice' were to be inserted into the dst loop nest at loop
@@ -530,14 +583,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     }
     int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
 
-    // If we are fusing for reuse, check that write regions remain the same.
-    // TODO: Write region check should check sizes and offsets in
-    // each dimension, so that we are sure they are covering the same memref
-    // region. Also, move this out to a isMemRefRegionSuperSet helper function.
-    if (srcOpInst != srcStoreOpInst &&
-        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
-      continue;
-
     double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                               static_cast<double>(sliceWriteRegionSizeBytes);
 
@@ -560,7 +605,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     // (as per computeToleranceThreshold), we will simply pick the one that
     // reduces the intermediary size the most.
     if ((storageReduction > maxStorageReduction) &&
-        (additionalComputeFraction < computeToleranceThreshold)) {
+        (additionalComputeFraction <= computeToleranceThreshold)) {
       maxStorageReduction = storageReduction;
       bestDstLoopDepth = i;
       minFusedLoopNestComputeCost = fusedLoopNestComputeCost;
@@ -595,7 +640,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                    << minFusedLoopNestComputeCost << "\n");
 
   auto dstMemSize = getMemoryFootprintBytes(dstForOp);
-  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
+  auto srcMemSize = getMemoryFootprintBytes(srcForOp);
 
   std::optional<double> storageReduction;
 
@@ -840,6 +885,8 @@ struct GreedyFusion {
         LLVM_DEBUG(llvm::dbgs()
                    << "Trying to fuse producer loop nest " << srcId
                    << " with consumer loop nest " << dstId << "\n");
+        LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: "
+                                << computeToleranceThreshold << '\n');
         LLVM_DEBUG(llvm::dbgs()
                    << "Producer loop nest:\n"
                    << *srcNode->op << "\n and consumer loop nest:\n"
@@ -926,6 +973,46 @@ struct GreedyFusion {
           continue;
         }
 
+        LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
+                                << maxLegalFusionDepth << '\n');
+
+        double computeToleranceThresholdToUse = computeToleranceThreshold;
+
+        // Cyclic dependences in the source nest may be violated when performing
+        // slicing-based fusion. They aren't actually violated in cases where no
+        // redundant execution of the source happens (1:1 pointwise dep on the
+        // producer-consumer memref access for example). Check this and allow
+        // fusion accordingly.
+        if (hasCyclicDependence(srcAffineForOp)) {
+          LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n");
+          // Maximal fusion does not check for compute tolerance threshold; so
+          // perform the maximal fusion only when the redundanation computation
+          // is zero.
+          if (maximalFusion) {
+            auto srcForOp = cast<AffineForOp>(srcNode->op);
+            auto dstForOp = cast<AffineForOp>(dstNode->op);
+            int64_t sliceCost;
+            int64_t fusedLoopNestComputeCost;
+            auto fraction = getAdditionalComputeFraction(
+                srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions,
+                sliceCost, fusedLoopNestComputeCost);
+            if (!fraction || fraction > 0) {
+              LLVM_DEBUG(
+                  llvm::dbgs()
+                  << "Can't perform maximal fusion with a cyclic dependence "
+                     "and non-zero additional compute.\n");
+              return;
+            }
+          } else {
+            // Set redundant computation tolerance to zero regardless of what
+            // the user specified. Without this, fusion would be invalid.
+            LLVM_DEBUG(llvm::dbgs()
+                       << "Setting compute tolerance to zero since "
+                          "source has a cylic dependence.\n");
+            computeToleranceThresholdToUse = 0;
+          }
+        }
+
         // Check if fusion would be profitable. We skip profitability analysis
         // for maximal fusion since we already know the maximal legal depth to
         // fuse.
@@ -948,10 +1035,10 @@ struct GreedyFusion {
           if (producerStores.size() > 1)
             LLVM_DEB...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/128397


More information about the Mlir-commits mailing list