[Mlir-commits] [mlir] [MLIR][Affine] Add missing check on fusion compute tolerance on a path (PR #128454)
Uday Bondhugula
llvmlistbot at llvm.org
Mon Feb 24 23:05:54 PST 2025
https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/128454
>From 6d31ee564ca81ea4b18ef0f51b82c025f553f8df Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Sat, 22 Feb 2025 16:05:37 +0530
Subject: [PATCH] [MLIR][Affine] Add missing check on fusion compute tolerance
on a path
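When profitability analysis can't be performed, we should still
respect the specified compute tolerance. Refactor to move the
multiple-producer-store handling into isFusionProfitable, where the
additional computation fraction is now computed and checked against
the compute tolerance.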
Fixes: https://github.com/llvm/llvm-project/issues/54541
---
.../Dialect/Affine/Transforms/LoopFusion.cpp | 59 +++++++++-----
mlir/test/Dialect/Affine/loop-fusion-4.mlir | 78 +++++++++++++++++++
2 files changed, 117 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index b97f11a963828..c1b919d0dcc7f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -15,7 +15,6 @@
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
@@ -473,7 +472,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
// is lower.
// TODO: Extend profitability analysis to support scenarios with multiple
// stores.
-static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
+static bool isFusionProfitable(AffineForOp srcForOp,
+ ArrayRef<Operation *> producerStores,
AffineForOp dstForOp,
ArrayRef<ComputationSliceState> depthSliceUnions,
unsigned maxLegalFusionDepth,
@@ -503,6 +503,34 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
return false;
+ // TODO: Support multiple producer stores in profitability
+ // analysis. We limit profitability analysis to only scenarios with
+ // a single producer store for now. Note that some multi-store
+ // producer scenarios will still go through profitability analysis
+ // if only one of the stores is involved in the producer-consumer
+ // relationship of the candidate loops.
+ if (producerStores.size() > 1) {
+ LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not "
+ "supported for multiple producer store case.\n");
+ int64_t sliceCost;
+ int64_t fusedLoopNestComputeCost;
+ // We will still fuse if fusion obeys the specified compute
+ // tolerance at the max legal depth.
+ auto fraction = getAdditionalComputeFraction(
+ srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
+ fusedLoopNestComputeCost);
+ if (!fraction || fraction > computeToleranceThreshold) {
+ LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
+ "compute tolerance. Not fusing.\n");
+ return false;
+ }
+ LLVM_DEBUG(llvm::dbgs()
+ << "Considering fusion profitable at max legal depth.\n");
+ return true;
+ }
+
+ Operation *srcStoreOp = producerStores.front();
+
// Search for min cost value for 'dstLoopDepth'. At each value of
// 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
// bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
@@ -520,8 +548,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
// Compute src loop nest write region size.
- MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
- if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
+ MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
+ if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) {
LLVM_DEBUG(llvm::dbgs()
<< "Unable to compute MemRefRegion for source operation\n");
return false;
@@ -563,9 +591,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
// Determine what the slice write MemRefRegion would be, if the src loop
// nest slice 'slice' were to be inserted into the dst loop nest at loop
// depth 'i'.
- MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
- if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
- &slice))) {
+ MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
+ if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to compute slice write region at loopDepth: " << i
<< "\n");
@@ -1025,21 +1052,13 @@ struct GreedyFusion {
cast<AffineWriteOpInterface>(op).getMemRef()))
producerStores.push_back(op);
- // TODO: Suppport multiple producer stores in profitability
- // analysis. We limit profitability analysis to only scenarios with
- // a single producer store for now. Note that some multi-store
- // producer scenarios will still go through profitability analysis
- // if only one of the stores is involved the producer-consumer
- // relationship of the candidate loops.
assert(!producerStores.empty() && "Expected producer store");
- if (producerStores.size() > 1)
- LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
- "supported for this case\n");
- else if (!isFusionProfitable(srcAffineForOp, producerStores[0],
- dstAffineForOp, depthSliceUnions,
- maxLegalFusionDepth, &bestDstLoopDepth,
- computeToleranceThresholdToUse))
+ if (!isFusionProfitable(srcAffineForOp, producerStores,
+ dstAffineForOp, depthSliceUnions,
+ maxLegalFusionDepth, &bestDstLoopDepth,
+ computeToleranceThresholdToUse)) {
continue;
+ }
}
assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index cf96a30a6e62a..b5b951ad5eb0e 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{compute-tolerance=0.0}))' -split-input-file | FileCheck %s --check-prefix=ZERO-TOLERANCE
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
// All fusion: producer-consumer and sibling.
@@ -544,3 +545,80 @@ func.func @sibling_reduction(%input : memref<10xf32>, %output : memref<10xf32>,
// SIBLING-MAXIMAL-NEXT: affine.store
return
}
+
+// -----
+
+// From https://github.com/llvm/llvm-project/issues/54541
+
+#map = affine_map<(d0) -> (d0 mod 65536)>
+// ZERO-TOLERANCE-LABEL: func @zero_tolerance
+func.func @zero_tolerance(%arg0: memref<65536xcomplex<f64>>, %arg1: memref<30x131072xi64>,
+%3 : memref<30xi64>,
+%4 : memref<30xi64>,
+%5 : memref<30xi64>,
+%6 : memref<30xi64>
+) {
+ %c65536 = arith.constant 65536 : index
+ %cst = arith.constant 0.000000e+00 : f64
+ %cst_0 = arith.constant 0x4320000000380004 : f64
+ %cst_1 = arith.constant 5.000000e-01 : f64
+ %0 = memref.alloc() {alignment = 128 : i64} : memref<30x131072xi64>
+ %1 = memref.alloc() {alignment = 128 : i64} : memref<131072xi1>
+ %2 = memref.alloc() {alignment = 128 : i64} : memref<131072xi128>
+ // This nest shouldn't be fused in when a zero tolerance is specified.
+ // ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 131072
+ affine.for %arg2 = 0 to 131072 {
+ %7 = affine.apply #map(%arg2)
+ %8 = affine.load %arg0[%7] : memref<65536xcomplex<f64>>
+ %9 = arith.cmpi ult, %arg2, %c65536 : index
+ %10 = complex.im %8 : complex<f64>
+ %11 = complex.re %8 : complex<f64>
+ %12 = arith.select %9, %11, %10 : f64
+ %13 = arith.cmpf olt, %12, %cst : f64
+ %14 = arith.negf %12 : f64
+ %15 = arith.select %13, %14, %12 : f64
+ %16 = arith.mulf %15, %cst_0 : f64
+ %17 = arith.addf %16, %cst_1 : f64
+ %18 = arith.fptosi %17 : f64 to i128
+ affine.store %18, %2[%arg2] : memref<131072xi128>
+ affine.store %13, %1[%arg2] : memref<131072xi1>
+ }
+ // The next two nests are fused.
+ // ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 30
+ // ZERO-TOLERANCE-NEXT: affine.for %{{.*}} = 0 to 131072
+ // ZERO-TOLERANCE: func.call @__external_reduce_barrett
+ // ZERO-TOLERANCE: affine.store
+ // ZERO-TOLERANCE: affine.load
+ // ZERO-TOLERANCE-NEXT: affine.store
+ affine.for %arg2 = 0 to 30 {
+ affine.for %arg3 = 0 to 131072 {
+ %7 = affine.load %6[%arg2] : memref<30xi64>
+ %8 = affine.load %3[%arg2] : memref<30xi64>
+ %9 = affine.load %5[%arg2] : memref<30xi64>
+ %10 = affine.load %4[%arg2] : memref<30xi64>
+ %11 = affine.load %2[%arg3] : memref<131072xi128>
+ %12 = affine.load %1[%arg3] : memref<131072xi1>
+ %13 = func.call @__external_reduce_barrett(%7, %8, %9, %10, %11) {outputModFac = 1 : i64} : (i64, i64, i64, i64, i128) -> i64
+ %14 = arith.subi %7, %13 : i64
+ %15 = arith.select %12, %14, %13 : i64
+ affine.store %15, %0[%arg2, %arg3] : memref<30x131072xi64>
+ }
+ }
+ func.call @__external_levelwise_forward_ntt(%0) : (memref<30x131072xi64>) -> ()
+ affine.for %arg2 = 0 to 30 {
+ affine.for %arg3 = 0 to 131072 {
+ %7 = affine.load %0[%arg2, %arg3] : memref<30x131072xi64>
+ affine.store %7, %arg1[%arg2, %arg3] : memref<30x131072xi64>
+ }
+ }
+ // Under maximal fusion, just one nest.
+ // PRODUCER-CONSUMER-MAXIMAL: affine.for %{{.*}} = 0 to 30
+ // PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 131072
+ // PRODUCER-CONSUMER-MAXIMAL-NOT: affine.for %{{.*}}
+ memref.dealloc %2 : memref<131072xi128>
+ memref.dealloc %1 : memref<131072xi1>
+ memref.dealloc %0 : memref<30x131072xi64>
+ return
+}
+func.func private @__external_levelwise_forward_ntt(memref<30x131072xi64>)
+func.func private @__external_reduce_barrett(i64, i64, i64, i64, i128) -> i64
More information about the Mlir-commits
mailing list