[Mlir-commits] [mlir] [MLIR][Affine] Add missing check on fusion compute tolerance on a path (PR #128454)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Feb 23 19:45:16 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-affine
Author: Uday Bondhugula (bondhugula)
<details>
<summary>Changes</summary>
When profitability analysis can't be performed, we should still respect the specified compute tolerance. Refactor to pull the computation of the additional compute fraction out into a helper, and add the missing check against the tolerance.
Fixes: https://github.com/llvm/llvm-project/issues/54541
---
Full diff: https://github.com/llvm/llvm-project/pull/128454.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+109-46)
- (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+75)
``````````diff
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 5add7df849286..7945e156dcd31 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -15,7 +15,6 @@
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
@@ -274,6 +273,58 @@ getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
return firstAncestor;
}
+/// Returns the amount of additional (redundant) computation that will be done
+/// as a fraction of the total computation if `srcForOp` is fused into
+/// `dstForOp` at depth `depth`, or std::nullopt if it can't be determined. The
+/// fused nest's compute cost is returned in `fusedLoopNestComputeCost`;
+/// `sliceCost` is currently not written by this method.
+static std::optional<double> getAdditionalComputeFraction(
+ AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
+ ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
+ int64_t &fusedLoopNestComputeCost) {
+ LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";);
+ // Compute cost of sliced and unsliced src loop nest.
+ // Walk src loop nest and collect stats.
+ LoopNestStats srcLoopNestStats;
+ if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
+ LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n");
+ return std::nullopt;
+ }
+
+ // Compute cost of dst loop nest.
+ LoopNestStats dstLoopNestStats;
+ if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
+ LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n");
+ return std::nullopt;
+ }
+
+ // Compute op instance count for the src loop nest without iteration slicing.
+ uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
+
+ // Compute op cost for the dst loop nest.
+ uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
+
+ const ComputationSliceState &slice = depthSliceUnions[depth - 1];
+ // Skip slice union if it wasn't computed for this depth.
+ if (slice.isEmpty()) {
+ LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n");
+ return std::nullopt;
+ }
+
+ if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
+ dstLoopNestStats, slice,
+ &fusedLoopNestComputeCost)) {
+ LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
+ return std::nullopt;
+ }
+
+ double additionalComputeFraction =
+ fusedLoopNestComputeCost /
+ (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
+ 1;
+
+ return additionalComputeFraction;
+}
+
// Creates and returns a private (single-user) memref for fused loop rooted at
// 'forOp', with (potentially reduced) memref size based on the memref region
// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
@@ -384,20 +435,19 @@ static Value createPrivateMemRef(AffineForOp forOp,
}
// Checks the profitability of fusing a backwards slice of the loop nest
-// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
-// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
-// the memref being produced and consumed, which is an input to the cost model.
-// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
-// 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse
-// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
-// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
-// unique store op in the src node, which will be used to check that the write
-// region is the same after input-reuse fusion. Computation slices are provided
-// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
-// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
-// profitable to fuse the candidate loop nests. Returns false otherwise.
-// `dstLoopDepth` is set to the most profitable depth at which to materialize
-// the source loop nest slice.
+// `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument
+// 'srcStoreOpInst' is used to calculate the storage reduction on the memref
+// being produced and consumed, which is an input to the cost model. For
+// producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst',
+// as we are slicing w.r.t. that producer. For input-reuse fusion, 'srcOpInst'
+// will be the src loop nest LoadOp which reads from the same memref as dst loop
+// nest load ops, and 'srcStoreOpInst' will be the unique store op in the src
+// node, which will be used to check that the write region is the same after
+// input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for
+// each legal fusion depth. The maximal depth at which fusion is legal is
+// provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse
+// the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to
+// the most profitable depth at which to materialize the source loop nest slice.
// The profitability model executes the following steps:
// *) Computes the backward computation slice at 'srcOpInst'. This
// computation slice of the loop nest surrounding 'srcOpInst' is
@@ -422,15 +472,16 @@ static Value createPrivateMemRef(AffineForOp forOp,
// is lower.
// TODO: Extend profitability analysis to support scenarios with multiple
// stores.
-static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
+static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
AffineForOp dstForOp,
ArrayRef<ComputationSliceState> depthSliceUnions,
unsigned maxLegalFusionDepth,
unsigned *dstLoopDepth,
double computeToleranceThreshold) {
LLVM_DEBUG({
- llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
- llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
+ llvm::dbgs()
+ << "Checking whether fusion is profitable between source nest:\n";
+ llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n";
llvm::dbgs() << dstForOp << "\n";
});
@@ -440,12 +491,10 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
}
// Compute cost of sliced and unsliced src loop nest.
- SmallVector<AffineForOp, 4> srcLoopIVs;
- getAffineForIVs(*srcOpInst, &srcLoopIVs);
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
- if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
+ if (!getLoopNestStats(srcForOp, &srcLoopNestStats))
return false;
// Compute cost of dst loop nest.
@@ -467,7 +516,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
std::optional<unsigned> bestDstLoopDepth;
// Compute op instance count for the src loop nest without iteration slicing.
- uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
+ uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
// Compute src loop nest write region size.
MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
@@ -494,18 +543,21 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
if (slice.isEmpty())
continue;
+ // Compute cost of the slice separately, i.e., the compute cost of the slice
+ // if all outer trip counts are one.
+ int64_t sliceCost;
+
int64_t fusedLoopNestComputeCost;
- if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
- dstLoopNestStats, slice,
- &fusedLoopNestComputeCost)) {
- LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
+
+ auto mayAdditionalComputeFraction =
+ getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
+ sliceCost, fusedLoopNestComputeCost);
+ if (!mayAdditionalComputeFraction) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Can't determine additional compute fraction.\n");
continue;
}
-
- double additionalComputeFraction =
- fusedLoopNestComputeCost /
- (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
- 1;
+ double additionalComputeFraction = *mayAdditionalComputeFraction;
// Determine what the slice write MemRefRegion would be, if the src loop
// nest slice 'slice' were to be inserted into the dst loop nest at loop
@@ -530,14 +582,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
}
int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
- // If we are fusing for reuse, check that write regions remain the same.
- // TODO: Write region check should check sizes and offsets in
- // each dimension, so that we are sure they are covering the same memref
- // region. Also, move this out to a isMemRefRegionSuperSet helper function.
- if (srcOpInst != srcStoreOpInst &&
- sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
- continue;
-
double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
static_cast<double>(sliceWriteRegionSizeBytes);
@@ -595,7 +639,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
<< minFusedLoopNestComputeCost << "\n");
auto dstMemSize = getMemoryFootprintBytes(dstForOp);
- auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
+ auto srcMemSize = getMemoryFootprintBytes(srcForOp);
std::optional<double> storageReduction;
@@ -840,6 +884,8 @@ struct GreedyFusion {
LLVM_DEBUG(llvm::dbgs()
<< "Trying to fuse producer loop nest " << srcId
<< " with consumer loop nest " << dstId << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: "
+ << computeToleranceThreshold << '\n');
LLVM_DEBUG(llvm::dbgs()
<< "Producer loop nest:\n"
<< *srcNode->op << "\n and consumer loop nest:\n"
@@ -926,6 +972,9 @@ struct GreedyFusion {
continue;
}
+ LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
+ << maxLegalFusionDepth << '\n');
+
// Check if fusion would be profitable. We skip profitability analysis
// for maximal fusion since we already know the maximal legal depth to
// fuse.
@@ -945,14 +994,28 @@ struct GreedyFusion {
// if only one of the stores is involved in the producer-consumer
// relationship of the candidate loops.
assert(!producerStores.empty() && "Expected producer store");
- if (producerStores.size() > 1)
+ if (producerStores.size() > 1) {
LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
"supported for this case\n");
- else if (!isFusionProfitable(producerStores[0], producerStores[0],
- dstAffineForOp, depthSliceUnions,
- maxLegalFusionDepth, &bestDstLoopDepth,
- computeToleranceThreshold))
+ // We will still fuse if fusion obeys the specified compute
+ // tolerance at the max legal depth.
+ int64_t sliceCost;
+ int64_t fusedLoopNestComputeCost;
+ auto fraction = getAdditionalComputeFraction(
+ srcAffineForOp, dstAffineForOp, maxLegalFusionDepth,
+ depthSliceUnions, sliceCost, fusedLoopNestComputeCost);
+ if (!fraction || fraction > computeToleranceThreshold) {
+ LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
+ "compute tolerance. Not fusing.\n");
+ continue;
+ }
+ }
+ if (!isFusionProfitable(srcAffineForOp, producerStores[0],
+ dstAffineForOp, depthSliceUnions,
+ maxLegalFusionDepth, &bestDstLoopDepth,
+ computeToleranceThreshold)) {
continue;
+ }
}
assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
@@ -1169,7 +1232,7 @@ struct GreedyFusion {
// load op is treated as the src "store" op for fusion profitability
// purposes. The footprint of the load in the slice relative to the
// unfused source's determines reuse.
- if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
+ if (!isFusionProfitable(sibAffineForOp, sibLoadOpInst, dstAffineForOp,
depthSliceUnions, maxLegalFusionDepth,
&bestDstLoopDepth, computeToleranceThreshold))
continue;
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 42d5ce632188e..1fca35836fcc2 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{compute-tolerance=0.0}))' -split-input-file | FileCheck %s --check-prefix=ZERO-TOLERANCE
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
// All fusion: producer-consumer and sibling.
@@ -495,3 +496,77 @@ func.func @test_add_slice_bounds() {
}
return
}
+
+// -----
+
+// From https://github.com/llvm/llvm-project/issues/54541
+
+#map = affine_map<(d0) -> (d0 mod 65536)>
+// ZERO-TOLERANCE-LABEL: func @zero_tolerance
+func.func @zero_tolerance(%arg0: memref<65536xcomplex<f64>>, %arg1: memref<30x131072xi64>,
+%3 : memref<30xi64>,
+%4 : memref<30xi64>,
+%5 : memref<30xi64>,
+%6 : memref<30xi64>
+) {
+ %c65536 = arith.constant 65536 : index
+ %cst = arith.constant 0.000000e+00 : f64
+ %cst_0 = arith.constant 0x4320000000380004 : f64
+ %cst_1 = arith.constant 5.000000e-01 : f64
+ %0 = memref.alloc() {alignment = 128 : i64} : memref<30x131072xi64>
+ %1 = memref.alloc() {alignment = 128 : i64} : memref<131072xi1>
+ %2 = memref.alloc() {alignment = 128 : i64} : memref<131072xi128>
+ // The two nests shouldn't be fused when a zero tolerance is specified.
+ // ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 131072
+ affine.for %arg2 = 0 to 131072 {
+ %7 = affine.apply #map(%arg2)
+ %8 = affine.load %arg0[%7] : memref<65536xcomplex<f64>>
+ %9 = arith.cmpi ult, %arg2, %c65536 : index
+ %10 = complex.im %8 : complex<f64>
+ %11 = complex.re %8 : complex<f64>
+ %12 = arith.select %9, %11, %10 : f64
+ %13 = arith.cmpf olt, %12, %cst : f64
+ %14 = arith.negf %12 : f64
+ %15 = arith.select %13, %14, %12 : f64
+ %16 = arith.mulf %15, %cst_0 : f64
+ %17 = arith.addf %16, %cst_1 : f64
+ %18 = arith.fptosi %17 : f64 to i128
+ affine.store %18, %2[%arg2] : memref<131072xi128>
+ affine.store %13, %1[%arg2] : memref<131072xi1>
+ }
+ // ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 30
+ // ZERO-TOLERANCE-NEXT: affine.for %{{.*}} = 0 to 131072
+ affine.for %arg2 = 0 to 30 {
+ affine.for %arg3 = 0 to 131072 {
+ %7 = affine.load %6[%arg2] : memref<30xi64>
+ %8 = affine.load %3[%arg2] : memref<30xi64>
+ %9 = affine.load %5[%arg2] : memref<30xi64>
+ %10 = affine.load %4[%arg2] : memref<30xi64>
+ %11 = affine.load %2[%arg3] : memref<131072xi128>
+ %12 = affine.load %1[%arg3] : memref<131072xi1>
+ %13 = func.call @__external_reduce_barrett(%7, %8, %9, %10, %11) {outputModFac = 1 : i64} : (i64, i64, i64, i64, i128) -> i64
+ %14 = arith.subi %7, %13 : i64
+ %15 = arith.select %12, %14, %13 : i64
+ affine.store %15, %0[%arg2, %arg3] : memref<30x131072xi64>
+ }
+ }
+ func.call @__external_levelwise_forward_ntt(%0) : (memref<30x131072xi64>) -> ()
+ // ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 30
+ // ZERO-TOLERANCE-NEXT: affine.for %{{.*}} = 0 to 131072
+ affine.for %arg2 = 0 to 30 {
+ affine.for %arg3 = 0 to 131072 {
+ %7 = affine.load %0[%arg2, %arg3] : memref<30x131072xi64>
+ affine.store %7, %arg1[%arg2, %arg3] : memref<30x131072xi64>
+ }
+ }
+ // Under maximal fusion, just one nest.
+ // PRODUCER-CONSUMER-MAXIMAL: affine.for %{{.*}} = 0 to 30
+ // PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 131072
+ // PRODUCER-CONSUMER-MAXIMAL-NOT: affine.for %{{.*}}
+ memref.dealloc %2 : memref<131072xi128>
+ memref.dealloc %1 : memref<131072xi1>
+ memref.dealloc %0 : memref<30x131072xi64>
+ return
+}
+func.func private @__external_levelwise_forward_ntt(memref<30x131072xi64>)
+func.func private @__external_reduce_barrett(i64, i64, i64, i64, i128) -> i64
``````````
</details>
https://github.com/llvm/llvm-project/pull/128454
More information about the Mlir-commits
mailing list