[Mlir-commits] [mlir] [MLIR][Affine] Add missing check on fusion compute tolerance on a path (PR #128454)

Uday Bondhugula llvmlistbot at llvm.org
Sun Feb 23 19:44:42 PST 2025


https://github.com/bondhugula created https://github.com/llvm/llvm-project/pull/128454

When profitability analysis can't be performed, we should still respect the specified compute tolerance. Refactor to pull the computation of the additional compute fraction out into a helper, and check it against the tolerance on that path.

Fixes: https://github.com/llvm/llvm-project/issues/54541

>From a7349f867c40512230bbb96e28630a14875a9610 Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Sat, 22 Feb 2025 16:05:37 +0530
Subject: [PATCH] [MLIR][Affine] Add missing check on fusion compute tolerance
 on a path

When profitability analysis can't be performed, we should still respect
the specified compute tolerance. Refactor to pull the computation of the
additional compute fraction out into a helper, and check it against the
tolerance on that path.

Fixes: https://github.com/llvm/llvm-project/issues/54541
---
 .../Dialect/Affine/Transforms/LoopFusion.cpp  | 155 ++++++++++++------
 mlir/test/Dialect/Affine/loop-fusion-4.mlir   |  75 +++++++++
 2 files changed, 184 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 5add7df849286..7945e156dcd31 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Affine/Utils.h"
@@ -274,6 +273,58 @@ getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
   return firstAncestor;
 }
 
+/// Returns the amount of additional (redundant) computation that will be done
+/// as a fraction of the total computation if `srcForOp` is fused into
+/// `dstForOp` at depth `depth`. The method returns the compute cost of the
+/// slice and the fused nest's compute cost in the trailing output arguments.
+static std::optional<double> getAdditionalComputeFraction(
+    AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
+    ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
+    int64_t &fusedLoopNestComputeCost) {
+  LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";);
+  // Compute cost of sliced and unsliced src loop nest.
+  // Walk src loop nest and collect stats.
+  LoopNestStats srcLoopNestStats;
+  if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
+    LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n");
+    return std::nullopt;
+  }
+
+  // Compute cost of dst loop nest.
+  LoopNestStats dstLoopNestStats;
+  if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
+    LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n");
+    return std::nullopt;
+  }
+
+  // Compute op instance count for the src loop nest without iteration slicing.
+  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
+
+  // Compute op cost for the dst loop nest.
+  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
+
+  const ComputationSliceState &slice = depthSliceUnions[depth - 1];
+  // Skip slice union if it wasn't computed for this depth.
+  if (slice.isEmpty()) {
+    LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n");
+    return std::nullopt;
+  }
+
+  if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
+                            dstLoopNestStats, slice,
+                            &fusedLoopNestComputeCost)) {
+    LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
+    return std::nullopt;
+  }
+
+  double additionalComputeFraction =
+      fusedLoopNestComputeCost /
+          (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
+      1;
+
+  return additionalComputeFraction;
+}
+
 // Creates and returns a private (single-user) memref for fused loop rooted at
 // 'forOp', with (potentially reduced) memref size based on the memref region
 // written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
@@ -384,20 +435,19 @@ static Value createPrivateMemRef(AffineForOp forOp,
 }
 
 // Checks the profitability of fusing a backwards slice of the loop nest
-// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
-// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
-// the memref being produced and consumed, which is an input to the cost model.
-// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
-// 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse
-// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
-// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
-// unique store op in the src node, which will be used to check that the write
-// region is the same after input-reuse fusion. Computation slices are provided
-// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
-// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
-// profitable to fuse the candidate loop nests. Returns false otherwise.
-// `dstLoopDepth` is set to the most profitable depth at which to materialize
-// the source loop nest slice.
+// `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument
+// 'srcStoreOpInst' is used to calculate the storage reduction on the memref
+// being produced and consumed, which is an input to the cost model. For
+// producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst',
+// as we are slicing w.r.t to that producer. For input-reuse fusion, 'srcOpInst'
+// will be the src loop nest LoadOp which reads from the same memref as dst loop
+// nest load ops, and 'srcStoreOpInst' will be the unique store op in the src
+// node, which will be used to check that the write region is the same after
+// input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for
+// each legal fusion depth. The maximal depth at which fusion is legal is
+// provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse
+// the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to
+// the most profitable depth at which to materialize the source loop nest slice.
 // The profitability model executes the following steps:
 // *) Computes the backward computation slice at 'srcOpInst'. This
 //    computation slice of the loop nest surrounding 'srcOpInst' is
@@ -422,15 +472,16 @@ static Value createPrivateMemRef(AffineForOp forOp,
 //    is lower.
 // TODO: Extend profitability analysis to support scenarios with multiple
 // stores.
-static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
+static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
                                AffineForOp dstForOp,
                                ArrayRef<ComputationSliceState> depthSliceUnions,
                                unsigned maxLegalFusionDepth,
                                unsigned *dstLoopDepth,
                                double computeToleranceThreshold) {
   LLVM_DEBUG({
-    llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
-    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
+    llvm::dbgs()
+        << "Checking whether fusion is profitable between source nest:\n";
+    llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n";
     llvm::dbgs() << dstForOp << "\n";
   });
 
@@ -440,12 +491,10 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   }
 
   // Compute cost of sliced and unsliced src loop nest.
-  SmallVector<AffineForOp, 4> srcLoopIVs;
-  getAffineForIVs(*srcOpInst, &srcLoopIVs);
 
   // Walk src loop nest and collect stats.
   LoopNestStats srcLoopNestStats;
-  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
+  if (!getLoopNestStats(srcForOp, &srcLoopNestStats))
     return false;
 
   // Compute cost of dst loop nest.
@@ -467,7 +516,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   std::optional<unsigned> bestDstLoopDepth;
 
   // Compute op instance count for the src loop nest without iteration slicing.
-  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
+  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
 
   // Compute src loop nest write region size.
   MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
@@ -494,18 +543,21 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     if (slice.isEmpty())
       continue;
 
+    // Compute cost of the slice separately, i.e, the compute cost of the slice
+    // if all outer trip counts are one.
+    int64_t sliceCost;
+
     int64_t fusedLoopNestComputeCost;
-    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
-                              dstLoopNestStats, slice,
-                              &fusedLoopNestComputeCost)) {
-      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
+
+    auto mayAdditionalComputeFraction =
+        getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
+                                     sliceCost, fusedLoopNestComputeCost);
+    if (!mayAdditionalComputeFraction) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Can't determine additional compute fraction.\n");
       continue;
     }
-
-    double additionalComputeFraction =
-        fusedLoopNestComputeCost /
-            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
-        1;
+    double additionalComputeFraction = *mayAdditionalComputeFraction;
 
     // Determine what the slice write MemRefRegion would be, if the src loop
     // nest slice 'slice' were to be inserted into the dst loop nest at loop
@@ -530,14 +582,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     }
     int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
 
-    // If we are fusing for reuse, check that write regions remain the same.
-    // TODO: Write region check should check sizes and offsets in
-    // each dimension, so that we are sure they are covering the same memref
-    // region. Also, move this out to a isMemRefRegionSuperSet helper function.
-    if (srcOpInst != srcStoreOpInst &&
-        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
-      continue;
-
     double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                               static_cast<double>(sliceWriteRegionSizeBytes);
 
@@ -595,7 +639,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                    << minFusedLoopNestComputeCost << "\n");
 
   auto dstMemSize = getMemoryFootprintBytes(dstForOp);
-  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
+  auto srcMemSize = getMemoryFootprintBytes(srcForOp);
 
   std::optional<double> storageReduction;
 
@@ -840,6 +884,8 @@ struct GreedyFusion {
         LLVM_DEBUG(llvm::dbgs()
                    << "Trying to fuse producer loop nest " << srcId
                    << " with consumer loop nest " << dstId << "\n");
+        LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: "
+                                << computeToleranceThreshold << '\n');
         LLVM_DEBUG(llvm::dbgs()
                    << "Producer loop nest:\n"
                    << *srcNode->op << "\n and consumer loop nest:\n"
@@ -926,6 +972,9 @@ struct GreedyFusion {
           continue;
         }
 
+        LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
+                                << maxLegalFusionDepth << '\n');
+
         // Check if fusion would be profitable. We skip profitability analysis
         // for maximal fusion since we already know the maximal legal depth to
         // fuse.
@@ -945,14 +994,28 @@ struct GreedyFusion {
           // if only one of the stores is involved the producer-consumer
           // relationship of the candidate loops.
           assert(!producerStores.empty() && "Expected producer store");
-          if (producerStores.size() > 1)
+          if (producerStores.size() > 1) {
             LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
                                        "supported for this case\n");
-          else if (!isFusionProfitable(producerStores[0], producerStores[0],
-                                       dstAffineForOp, depthSliceUnions,
-                                       maxLegalFusionDepth, &bestDstLoopDepth,
-                                       computeToleranceThreshold))
+            // We will still fuse if fusion obeys the specified compute
+            // tolerance at the max legal depth.
+            int64_t sliceCost;
+            int64_t fusedLoopNestComputeCost;
+            auto fraction = getAdditionalComputeFraction(
+                srcAffineForOp, dstAffineForOp, maxLegalFusionDepth,
+                depthSliceUnions, sliceCost, fusedLoopNestComputeCost);
+            if (!fraction || fraction > computeToleranceThreshold) {
+              LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
+                                         "compute tolerance. Not fusing.\n");
+              continue;
+            }
+          }
+          if (!isFusionProfitable(srcAffineForOp, producerStores[0],
+                                  dstAffineForOp, depthSliceUnions,
+                                  maxLegalFusionDepth, &bestDstLoopDepth,
+                                  computeToleranceThreshold)) {
             continue;
+          }
         }
 
         assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
@@ -1169,7 +1232,7 @@ struct GreedyFusion {
         // load op is treated as the src "store" op for fusion profitability
         // purposes. The footprint of the load in the slice relative to the
         // unfused source's determines reuse.
-        if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
+        if (!isFusionProfitable(sibAffineForOp, sibLoadOpInst, dstAffineForOp,
                                 depthSliceUnions, maxLegalFusionDepth,
                                 &bestDstLoopDepth, computeToleranceThreshold))
           continue;
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 42d5ce632188e..1fca35836fcc2 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{compute-tolerance=0.0}))' -split-input-file | FileCheck %s --check-prefix=ZERO-TOLERANCE
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
 // All fusion: producer-consumer and sibling.
@@ -495,3 +496,77 @@ func.func @test_add_slice_bounds() {
   }
   return
 }
+
+// -----
+
+// From  https://github.com/llvm/llvm-project/issues/54541
+
+#map = affine_map<(d0) -> (d0 mod 65536)>
+// ZERO-TOLERANCE-LABEL: func @zero_tolerance
+func.func @zero_tolerance(%arg0: memref<65536xcomplex<f64>>, %arg1: memref<30x131072xi64>,
+%3 : memref<30xi64>,
+%4 : memref<30xi64>,
+%5 : memref<30xi64>,
+%6 : memref<30xi64>
+) {
+  %c65536 = arith.constant 65536 : index
+  %cst = arith.constant 0.000000e+00 : f64
+  %cst_0 = arith.constant 0x4320000000380004 : f64
+  %cst_1 = arith.constant 5.000000e-01 : f64
+  %0 = memref.alloc() {alignment = 128 : i64} : memref<30x131072xi64>
+  %1 = memref.alloc() {alignment = 128 : i64} : memref<131072xi1>
+  %2 = memref.alloc() {alignment = 128 : i64} : memref<131072xi128>
+  // The two nests shouldn't be fused when a zero tolerance is specified.
+  // ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 131072
+  affine.for %arg2 = 0 to 131072 {
+    %7 = affine.apply #map(%arg2)
+    %8 = affine.load %arg0[%7] : memref<65536xcomplex<f64>>
+    %9 = arith.cmpi ult, %arg2, %c65536 : index
+    %10 = complex.im %8 : complex<f64>
+    %11 = complex.re %8 : complex<f64>
+    %12 = arith.select %9, %11, %10 : f64
+    %13 = arith.cmpf olt, %12, %cst : f64
+    %14 = arith.negf %12 : f64
+    %15 = arith.select %13, %14, %12 : f64
+    %16 = arith.mulf %15, %cst_0 : f64
+    %17 = arith.addf %16, %cst_1 : f64
+    %18 = arith.fptosi %17 : f64 to i128
+    affine.store %18, %2[%arg2] : memref<131072xi128>
+    affine.store %13, %1[%arg2] : memref<131072xi1>
+  }
+  // ZERO-TOLERANCE:      affine.for %{{.*}} = 0 to 30
+  // ZERO-TOLERANCE-NEXT:   affine.for %{{.*}} = 0 to 131072
+  affine.for %arg2 = 0 to 30 {
+    affine.for %arg3 = 0 to 131072 {
+      %7 = affine.load %6[%arg2] : memref<30xi64>
+      %8 = affine.load %3[%arg2] : memref<30xi64>
+      %9 = affine.load %5[%arg2] : memref<30xi64>
+      %10 = affine.load %4[%arg2] : memref<30xi64>
+      %11 = affine.load %2[%arg3] : memref<131072xi128>
+      %12 = affine.load %1[%arg3] : memref<131072xi1>
+      %13 = func.call @__external_reduce_barrett(%7, %8, %9, %10, %11) {outputModFac = 1 : i64} : (i64, i64, i64, i64, i128) -> i64
+      %14 = arith.subi %7, %13 : i64
+      %15 = arith.select %12, %14, %13 : i64
+      affine.store %15, %0[%arg2, %arg3] : memref<30x131072xi64>
+    }
+  }
+  func.call @__external_levelwise_forward_ntt(%0) : (memref<30x131072xi64>) -> ()
+  // ZERO-TOLERANCE:      affine.for %{{.*}} = 0 to 30
+  // ZERO-TOLERANCE-NEXT:   affine.for %{{.*}} = 0 to 131072
+  affine.for %arg2 = 0 to 30 {
+    affine.for %arg3 = 0 to 131072 {
+      %7 = affine.load %0[%arg2, %arg3] : memref<30x131072xi64>
+      affine.store %7, %arg1[%arg2, %arg3] : memref<30x131072xi64>
+    }
+  }
+  // Under maximal fusion, just one nest.
+  // PRODUCER-CONSUMER-MAXIMAL:      affine.for %{{.*}} = 0 to 30
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT:   affine.for %{{.*}} = 0 to 131072
+  // PRODUCER-CONSUMER-MAXIMAL-NOT:  affine.for %{{.*}}
+  memref.dealloc %2 : memref<131072xi128>
+  memref.dealloc %1 : memref<131072xi1>
+  memref.dealloc %0 : memref<30x131072xi64>
+  return
+}
+func.func private @__external_levelwise_forward_ntt(memref<30x131072xi64>)
+func.func private @__external_reduce_barrett(i64, i64, i64, i64, i128) -> i64



More information about the Mlir-commits mailing list