[Mlir-commits] [mlir] [MLIR][Affine] Fix fusion private memref creation for multiple producer stores (PR #130365)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 7 15:06:14 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: Uday Bondhugula (bondhugula)

<details>
<summary>Changes</summary>

Fix private memref creation in affine fusion for the multiple producer store
case. This scenario was unsupported, but the lack of support was not properly
checked.

Fixes: https://github.com/llvm/llvm-project/issues/120227


---
Full diff: https://github.com/llvm/llvm-project/pull/130365.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+26-5) 
- (modified) mlir/test/Dialect/Affine/loop-fusion-2.mlir (+10-6) 
- (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+44) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 21cd0e2e49c2a..bcba17bb21544 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -328,7 +328,9 @@ static std::optional<double> getAdditionalComputeFraction(
 // Creates and returns a private (single-user) memref for fused loop rooted at
 // 'forOp', with (potentially reduced) memref size based on the memref region
 // written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
-// specifies the block in which the slice was/will be inserted.
+// specifies the block in which the slice was/will be inserted. The method
+// expects that all store ops to the memref have the same access function.
+// Returns nullptr if the creation failed.
 static Value createPrivateMemRef(AffineForOp forOp,
                                  ArrayRef<Operation *> storeOps,
                                  unsigned dstLoopDepth,
@@ -336,6 +338,24 @@ static Value createPrivateMemRef(AffineForOp forOp,
                                  Block *sliceInsertionBlock,
                                  uint64_t localBufSizeThreshold) {
   assert(!storeOps.empty() && "no source stores supplied");
+
+  // Check if all stores have the same access function; we only support this
+  // case.
+  // TODO: Use union of memref write regions to compute private memref footprint
+  // for store ops with different access functions.
+  if (storeOps.size() > 1 &&
+      !std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),
+                  [](Operation *a, Operation *b) {
+                    MemRefAccess aM(cast<AffineWriteOpInterface>(a));
+                    MemRefAccess bM(cast<AffineWriteOpInterface>(b));
+                    return aM == bM;
+                  })) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Private memref creation unsupported for multiple producer "
+                  "stores with different access functions.\n");
+    return nullptr;
+  }
+
   Operation *srcStoreOp = storeOps[0];
 
   // Create builder to insert alloc op just before 'forOp'.
@@ -432,6 +452,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
   assert(succeeded(res) &&
          "replaceAllMemrefUsesWith should always succeed here");
   (void)res;
+  LLVM_DEBUG(llvm::dbgs() << "Created private memref of type: " << newMemRefType
+                          << '\n');
   return newMemRef;
 }
 
@@ -1123,13 +1145,12 @@ struct GreedyFusion {
           // loads and stores. Any reference to the original ones becomes
           // invalid after this point.
           for (auto &memrefToStoresPair : privateMemRefToStores) {
-            // TODO: Use union of memref write regions to compute
-            // private memref footprint.
-            SmallVector<Operation *, 4> &storesForMemref =
-                memrefToStoresPair.second;
+            ArrayRef<Operation *> storesForMemref = memrefToStoresPair.second;
             Value newMemRef = createPrivateMemRef(
                 dstAffineForOp, storesForMemref, bestDstLoopDepth,
                 fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
+            if (!newMemRef)
+              continue;
             // Create new node in dependence graph for 'newMemRef' alloc op.
             unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
             // Add edge from 'newMemRef' node to dstNode.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-2.mlir b/mlir/test/Dialect/Affine/loop-fusion-2.mlir
index 9648504b7f960..b26a539b2f7d5 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-2.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-2.mlir
@@ -39,16 +39,20 @@ func.func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4
 
   // We can fuse source loop nest '%i0' into dst loop nest '%i2', but the
   // depth at which we can insert the src loop nest slice into the dst loop
-  // lest must be decreased because of a loop carried dependence on loop '%i3'.
+  // nest must be decreased because of a loop carried dependence on loop '%i3'.
   // As a result, the source loop nest is inserted at dst loop nest depth 1,
   // just above the loop with the carried dependence. In addition, the source
   // loop nest iteration bounds on its loop '%i1' are reduced to 1, so the
-  // memref size can be reduced to 128x1xf32.
+  // memref size can be reduced to 64x1xf32.
 
-  // CHECK:       memref.alloc() : memref<64x1xf32>
+  // In this case, since we have multiple producer stores (for %out) with
+  // different access functions and we don't yet support private memref
+  // computation in such cases, a 64x1 private memref isn't created.
+
+  // CHECK:       memref.alloc() : memref<64x4xf32>
   // CHECK:       affine.for %{{.*}} = 0 to 4 {
   // CHECK-NEXT:    affine.for %{{.*}} = 0 to 64 {
-  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}, 0] : memref<64x1xf32>
+  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x4xf32>
   // CHECK-NEXT:    }
   // CHECK-NEXT:    affine.for %{{.*}} = 0 to 4 {
   // CHECK-NEXT:      affine.for %{{.*}} = 0 to 16 {
@@ -62,9 +66,9 @@ func.func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4
   // CHECK-NEXT:        }
   // CHECK-NEXT:        affine.for %{{.*}} = 0 to 16 {
   // CHECK-NEXT:          %{{.*}} = "op2"() : () -> f32
-  // CHECK:               affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32>
+  // CHECK:               affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, %{{.*}}] : memref<64x4xf32>
   // CHECK-NEXT:          arith.addf %{{.*}}, %{{.*}} : f32
-  // CHECK:               affine.store %{{.*}}, %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32>
+  // CHECK:               affine.store %{{.*}}, %{{.*}}[%{{.*}} * 16 + %{{.*}}, %{{.*}}] : memref<64x4xf32>
   // CHECK-NEXT:        }
   // CHECK-NEXT:      }
   // CHECK-NEXT:    }
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index b5b951ad5eb0e..4b9eca45492fb 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -622,3 +622,47 @@ func.func @zero_tolerance(%arg0: memref<65536xcomplex<f64>>, %arg1: memref<30x13
 }
 func.func private @__external_levelwise_forward_ntt(memref<30x131072xi64>)
 func.func private @__external_reduce_barrett(i64, i64, i64, i64, i128) -> i64
+
+// An unrolled loop nest. Fusion here should correctly fuse while preserving
+// dependences between store-load pairs of the same memref. A private memref
+// of size 1x1x1 can't be created.
+
+// PRODUCER-CONSUMER-MAXIMAL-LABEL: func @unrolled
+func.func @unrolled(%arg0: memref<2x4xf32>, %arg1: memref<1x2x4xf32>) {
+  %alloc = memref.alloc() : memref<1x2x4xf32>
+  affine.for %i = 0 to 1 {
+    %0 = affine.load %arg0[0, 0] : memref<2x4xf32>
+    %1 = affine.load %arg0[0, 1] : memref<2x4xf32>
+    %2 = affine.load %arg0[0, 2] : memref<2x4xf32>
+    %3 = affine.load %arg0[0, 3] : memref<2x4xf32>
+    %4 = affine.load %arg0[1, 0] : memref<2x4xf32>
+    %5 = affine.load %arg0[1, 1] : memref<2x4xf32>
+    %6 = affine.load %arg0[1, 2] : memref<2x4xf32>
+    %7 = affine.load %arg0[1, 3] : memref<2x4xf32>
+
+    affine.store %0, %alloc[0, 0, 0] : memref<1x2x4xf32>
+    affine.store %1, %alloc[0, 0, 1] : memref<1x2x4xf32>
+    affine.store %2, %alloc[0, 0, 2] : memref<1x2x4xf32>
+    affine.store %3, %alloc[0, 0, 3] : memref<1x2x4xf32>
+    affine.store %4, %alloc[0, 1, 0] : memref<1x2x4xf32>
+    affine.store %5, %alloc[0, 1, 1] : memref<1x2x4xf32>
+    affine.store %6, %alloc[0, 1, 2] : memref<1x2x4xf32>
+    affine.store %7, %alloc[0, 1, 3] : memref<1x2x4xf32>
+  }
+
+  affine.for %i = 0 to 2 {
+    affine.for %j = 0 to 4 {
+      %8 = affine.load %alloc[0, %i, %j] : memref<1x2x4xf32>
+      %9 = arith.negf %8 : f32
+      affine.store %9, %arg1[0, %i, %j] : memref<1x2x4xf32>
+    }
+  }
+  // PRODUCER-CONSUMER-MAXIMAL:      affine.for %{{.*}} = 0 to 2 {
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT:   affine.for %{{.*}} = 0 to 4 {
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT:     affine.load %{{.*}}[0, 0]
+  // PRODUCER-CONSUMER-MAXIMAL:          affine.load %{{.*}}[1, 3]
+  // PRODUCER-CONSUMER-MAXIMAL:          affine.store %{{.*}}[0, 0, 0]
+  // PRODUCER-CONSUMER-MAXIMAL:          affine.store %{{.*}}[0, 1, 3]
+  // PRODUCER-CONSUMER-MAXIMAL:          affine.load %{{.*}}[0, %{{.*}}, %{{.*}}]
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/130365


More information about the Mlir-commits mailing list