[Mlir-commits] [mlir] [MLIR][Affine] Add missing check on fusion compute tolerance on a path (PR #128454)

Uday Bondhugula llvmlistbot at llvm.org
Mon Feb 24 23:05:54 PST 2025


https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/128454

>From 6d31ee564ca81ea4b18ef0f51b82c025f553f8df Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Sat, 22 Feb 2025 16:05:37 +0530
Subject: [PATCH] [MLIR][Affine] Add missing check on fusion compute tolerance
 on a path

When profitability analysis can't be performed, we should still
respect the specified compute tolerance. Refactor to pull out the
computation and check of the additional compute fraction.

Fixes: https://github.com/llvm/llvm-project/issues/54541
---
 .../Dialect/Affine/Transforms/LoopFusion.cpp  | 59 +++++++++-----
 mlir/test/Dialect/Affine/loop-fusion-4.mlir   | 78 +++++++++++++++++++
 2 files changed, 117 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index b97f11a963828..c1b919d0dcc7f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Affine/Utils.h"
@@ -473,7 +472,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
 //    is lower.
 // TODO: Extend profitability analysis to support scenarios with multiple
 // stores.
-static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
+static bool isFusionProfitable(AffineForOp srcForOp,
+                               ArrayRef<Operation *> producerStores,
                                AffineForOp dstForOp,
                                ArrayRef<ComputationSliceState> depthSliceUnions,
                                unsigned maxLegalFusionDepth,
@@ -503,6 +503,34 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
   if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
     return false;
 
+  // TODO: Support multiple producer stores in profitability
+  // analysis. We limit profitability analysis to only scenarios with
+  // a single producer store for now. Note that some multi-store
+  // producer scenarios will still go through profitability analysis
+  // if only one of the stores is involved in the producer-consumer
+  // relationship of the candidate loops.
+  if (producerStores.size() > 1) {
+    LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not "
+                               "supported for multiple producer store case.\n");
+    int64_t sliceCost;
+    int64_t fusedLoopNestComputeCost;
+    // We will still fuse if fusion obeys the specified compute
+    // tolerance at the max legal depth.
+    auto fraction = getAdditionalComputeFraction(
+        srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
+        fusedLoopNestComputeCost);
+    if (!fraction || fraction > computeToleranceThreshold) {
+      LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
+                                 "compute tolerance. Not fusing.\n");
+      return false;
+    }
+    LLVM_DEBUG(llvm::dbgs()
+               << "Considering fusion profitable at max legal depth.\n");
+    return true;
+  }
+
+  Operation *srcStoreOp = producerStores.front();
+
   // Search for min cost value for 'dstLoopDepth'. At each value of
   // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
   // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
@@ -520,8 +548,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
   uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
 
   // Compute src loop nest write region size.
-  MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
-  if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
+  MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
+  if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) {
     LLVM_DEBUG(llvm::dbgs()
                << "Unable to compute MemRefRegion for source operation\n");
     return false;
@@ -563,9 +591,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
     // Determine what the slice write MemRefRegion would be, if the src loop
     // nest slice 'slice' were to be inserted into the dst loop nest at loop
     // depth 'i'.
-    MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
-    if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
-                                        &slice))) {
+    MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
+    if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) {
       LLVM_DEBUG(llvm::dbgs()
                  << "Failed to compute slice write region at loopDepth: " << i
                  << "\n");
@@ -1025,21 +1052,13 @@ struct GreedyFusion {
                     cast<AffineWriteOpInterface>(op).getMemRef()))
               producerStores.push_back(op);
 
-          // TODO: Suppport multiple producer stores in profitability
-          // analysis. We limit profitability analysis to only scenarios with
-          // a single producer store for now. Note that some multi-store
-          // producer scenarios will still go through profitability analysis
-          // if only one of the stores is involved the producer-consumer
-          // relationship of the candidate loops.
           assert(!producerStores.empty() && "Expected producer store");
-          if (producerStores.size() > 1)
-            LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
-                                       "supported for this case\n");
-          else if (!isFusionProfitable(srcAffineForOp, producerStores[0],
-                                       dstAffineForOp, depthSliceUnions,
-                                       maxLegalFusionDepth, &bestDstLoopDepth,
-                                       computeToleranceThresholdToUse))
+          if (!isFusionProfitable(srcAffineForOp, producerStores,
+                                  dstAffineForOp, depthSliceUnions,
+                                  maxLegalFusionDepth, &bestDstLoopDepth,
+                                  computeToleranceThresholdToUse)) {
             continue;
+          }
         }
 
         assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index cf96a30a6e62a..b5b951ad5eb0e 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{compute-tolerance=0.0}))' -split-input-file | FileCheck %s --check-prefix=ZERO-TOLERANCE
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
 // All fusion: producer-consumer and sibling.
@@ -544,3 +545,80 @@ func.func @sibling_reduction(%input : memref<10xf32>, %output : memref<10xf32>,
   // SIBLING-MAXIMAL-NEXT:   affine.store
   return
 }
+
+// -----
+
+// From  https://github.com/llvm/llvm-project/issues/54541
+
+#map = affine_map<(d0) -> (d0 mod 65536)>
+// ZERO-TOLERANCE-LABEL: func @zero_tolerance
+func.func @zero_tolerance(%arg0: memref<65536xcomplex<f64>>, %arg1: memref<30x131072xi64>,
+%3 : memref<30xi64>,
+%4 : memref<30xi64>,
+%5 : memref<30xi64>,
+%6 : memref<30xi64>
+) {
+  %c65536 = arith.constant 65536 : index
+  %cst = arith.constant 0.000000e+00 : f64
+  %cst_0 = arith.constant 0x4320000000380004 : f64
+  %cst_1 = arith.constant 5.000000e-01 : f64
+  %0 = memref.alloc() {alignment = 128 : i64} : memref<30x131072xi64>
+  %1 = memref.alloc() {alignment = 128 : i64} : memref<131072xi1>
+  %2 = memref.alloc() {alignment = 128 : i64} : memref<131072xi128>
+  // This nest shouldn't be fused in when a zero tolerance is specified.
+  // ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 131072
+  affine.for %arg2 = 0 to 131072 {
+    %7 = affine.apply #map(%arg2)
+    %8 = affine.load %arg0[%7] : memref<65536xcomplex<f64>>
+    %9 = arith.cmpi ult, %arg2, %c65536 : index
+    %10 = complex.im %8 : complex<f64>
+    %11 = complex.re %8 : complex<f64>
+    %12 = arith.select %9, %11, %10 : f64
+    %13 = arith.cmpf olt, %12, %cst : f64
+    %14 = arith.negf %12 : f64
+    %15 = arith.select %13, %14, %12 : f64
+    %16 = arith.mulf %15, %cst_0 : f64
+    %17 = arith.addf %16, %cst_1 : f64
+    %18 = arith.fptosi %17 : f64 to i128
+    affine.store %18, %2[%arg2] : memref<131072xi128>
+    affine.store %13, %1[%arg2] : memref<131072xi1>
+  }
+  // The next two nests are fused.
+  // ZERO-TOLERANCE:      affine.for %{{.*}} = 0 to 30
+  // ZERO-TOLERANCE-NEXT:   affine.for %{{.*}} = 0 to 131072
+  // ZERO-TOLERANCE:          func.call @__external_reduce_barrett
+  // ZERO-TOLERANCE:          affine.store
+  // ZERO-TOLERANCE:          affine.load
+  // ZERO-TOLERANCE-NEXT:     affine.store
+  affine.for %arg2 = 0 to 30 {
+    affine.for %arg3 = 0 to 131072 {
+      %7 = affine.load %6[%arg2] : memref<30xi64>
+      %8 = affine.load %3[%arg2] : memref<30xi64>
+      %9 = affine.load %5[%arg2] : memref<30xi64>
+      %10 = affine.load %4[%arg2] : memref<30xi64>
+      %11 = affine.load %2[%arg3] : memref<131072xi128>
+      %12 = affine.load %1[%arg3] : memref<131072xi1>
+      %13 = func.call @__external_reduce_barrett(%7, %8, %9, %10, %11) {outputModFac = 1 : i64} : (i64, i64, i64, i64, i128) -> i64
+      %14 = arith.subi %7, %13 : i64
+      %15 = arith.select %12, %14, %13 : i64
+      affine.store %15, %0[%arg2, %arg3] : memref<30x131072xi64>
+    }
+  }
+  func.call @__external_levelwise_forward_ntt(%0) : (memref<30x131072xi64>) -> ()
+  affine.for %arg2 = 0 to 30 {
+    affine.for %arg3 = 0 to 131072 {
+      %7 = affine.load %0[%arg2, %arg3] : memref<30x131072xi64>
+      affine.store %7, %arg1[%arg2, %arg3] : memref<30x131072xi64>
+    }
+  }
+  // Under maximal fusion, just one nest.
+  // PRODUCER-CONSUMER-MAXIMAL:      affine.for %{{.*}} = 0 to 30
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT:   affine.for %{{.*}} = 0 to 131072
+  // PRODUCER-CONSUMER-MAXIMAL-NOT:  affine.for %{{.*}}
+  memref.dealloc %2 : memref<131072xi128>
+  memref.dealloc %1 : memref<131072xi1>
+  memref.dealloc %0 : memref<30x131072xi64>
+  return
+}
+func.func private @__external_levelwise_forward_ntt(memref<30x131072xi64>)
+func.func private @__external_reduce_barrett(i64, i64, i64, i64, i128) -> i64



More information about the Mlir-commits mailing list