[Mlir-commits] [mlir] [MLIR][Affine] Fix double reduction when cyclic loop can't be removed after fusion (PR #189236)

llvmlistbot at llvm.org
Sun Mar 29 06:05:05 PDT 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-affine

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

When a source loop has a cyclic dependence (e.g., a reduction that reads and writes the same memref) and the fusion cannot remove the source after fusing it into the consumer, skip the fusion. Without this check, the cyclic computation would execute twice: once in the original source loop and once in the fused copy, producing incorrect results.

This was the root cause of issue #174580, where affine-loop-fusion produced wrong results for a tosa.rsqrt → tosa.reduce_sum → tosa.sigmoid pipeline. The reinterpret_cast on the accumulator memref caused it to be treated as an "escaping" memref, forcing an isMaximal check. The slice covered only a subset of the source's iteration space (non-maximal), so removeSrcNode was false. The cyclic source reduction was then cloned into the consumer but not removed, causing the reduction to run twice.

Fixes #174580

Assisted-by: Claude Code

---
Full diff: https://github.com/llvm/llvm-project/pull/189236.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+13-1) 
- (modified) mlir/test/Dialect/Affine/loop-fusion-3.mlir (+86) 


``````````diff
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 1ec5fbfef50c3..c7d88e64c7d0a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1015,7 +1015,8 @@ struct GreedyFusion {
         // redundant execution of the source happens (1:1 pointwise dep on the
         // producer-consumer memref access for example). Check this and allow
         // fusion accordingly.
-        if (hasCyclicDependence(srcAffineForOp)) {
+        bool srcHasCyclicDep = hasCyclicDependence(srcAffineForOp);
+        if (srcHasCyclicDep) {
           LDBG() << "Source nest has a cyclic dependence.";
           // Maximal fusion does not check for compute tolerance threshold; so
           // perform the maximal fusion only when the redundanation computation
@@ -1075,6 +1076,17 @@ struct GreedyFusion {
             srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
             *mdg);
 
+        // If the source loop has a cyclic dependence (e.g., a reduction that
+        // reads and writes the same memref) and cannot be removed after fusion,
+        // skip this fusion. Fusing a cyclic source without removing it would
+        // result in its cyclic computation executing twice: once in the
+        // original source and once in the fused copy.
+        if (srcHasCyclicDep && !removeSrcNode) {
+          LDBG() << "Can't fuse: source has cyclic dependence and "
+                 << "can't be removed after fusion";
+          continue;
+        }
+
         DenseSet<Value> privateMemrefs;
         for (Value memref : producerConsumerMemrefs) {
           if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId,
diff --git a/mlir/test/Dialect/Affine/loop-fusion-3.mlir b/mlir/test/Dialect/Affine/loop-fusion-3.mlir
index 70d6c82105543..4d329d54fcaa8 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-3.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-3.mlir
@@ -1294,5 +1294,91 @@ func.func @unknown_memref_def_op() {
 }
 func.func private @bar() -> memref<10xf32>
 
+// -----
+
+// CHECK-LABEL: func @no_double_reduction_cyclic_src_non_removable
+// Test that a cyclic source loop (reduction) is not fused as a separate copy
+// into a consumer when the fusion cannot remove the source. Without this check,
+// the reduction would run twice, producing incorrect results.
+// The reinterpret_cast makes %acc an "escaping" memref, which forces an
+// isMaximal check. The slice is non-maximal (consumer only covers a subset of
+// the producer's iteration space), making the source non-removable.
+// The fix ensures fusion is skipped in that case, allowing the loops to be
+// correctly combined later when the source can be fully removed.
+//
+// CHECK:         affine.for
+// CHECK-NOT:     affine.for
+// CHECK:           affine.for
+// CHECK-NOT:         affine.for
+// CHECK:               affine.for
+// CHECK-NOT:             affine.for
+// CHECK:                   affine.for
+// CHECK-NOT:                 affine.for
+// CHECK:                   }
+// CHECK-NOT:             affine.for
+// CHECK:               }
+// CHECK-NOT:         affine.for
+// CHECK:             }
+// CHECK-NOT:       affine.for
+// CHECK:           }
+// CHECK-NOT:     affine.for
+// CHECK:         }
+// CHECK-NOT: affine.for
+func.func private @printMemrefF32(memref<*xf32>)
+func.func @no_double_reduction_cyclic_src_non_removable() {
+  %cst = arith.constant 5.000000e-01 : f32
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %cst_1 = arith.constant 0.000000e+00 : f32
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<3x1x7x5xf32>
+  // Init loop (producer)
+  affine.for %arg0 = 0 to 3 {
+    affine.for %arg1 = 0 to 1 {
+      affine.for %arg2 = 0 to 7 {
+        affine.for %arg3 = 0 to 5 {
+          affine.store %cst, %alloc[%arg0, %arg1, %arg2, %arg3] : memref<3x1x7x5xf32>
+        }
+      }
+    }
+  }
+  // Accumulator via reinterpret_cast (makes it "escaping", triggers isMaximal check)
+  %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<7x5xf32>
+  %reinterpret_cast = memref.reinterpret_cast %alloc_2 to offset: [0], sizes: [1, 7, 5], strides: [35, 5, 1] : memref<7x5xf32> to memref<1x7x5xf32>
+  // Zero-init accumulator
+  affine.for %arg0 = 0 to 1 {
+    affine.for %arg1 = 0 to 7 {
+      affine.for %arg2 = 0 to 5 {
+        affine.store %cst_1, %reinterpret_cast[%arg0, %arg1, %arg2] : memref<1x7x5xf32>
+      }
+    }
+  }
+  // Cyclic reduction loop: reads and writes %reinterpret_cast
+  affine.for %arg0 = 0 to 3 {
+    affine.for %arg1 = 0 to 1 {
+      affine.for %arg2 = 0 to 7 {
+        affine.for %arg3 = 0 to 5 {
+          %0 = affine.load %alloc[%arg0, %arg1, %arg2, %arg3] : memref<3x1x7x5xf32>
+          %1 = affine.load %reinterpret_cast[%arg1, %arg2, %arg3] : memref<1x7x5xf32>
+          %2 = arith.addf %0, %1 : f32
+          affine.store %2, %reinterpret_cast[%arg1, %arg2, %arg3] : memref<1x7x5xf32>
+        }
+      }
+    }
+  }
+  // Consumer loop (sigmoid)
+  %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<1x7x5xf32>
+  affine.for %arg0 = 0 to 1 {
+    affine.for %arg1 = 0 to 7 {
+      affine.for %arg2 = 0 to 5 {
+        %0 = affine.load %reinterpret_cast[%arg0, %arg1, %arg2] : memref<1x7x5xf32>
+        %1 = arith.negf %0 : f32
+        %2 = math.exp %1 : f32
+        %3 = arith.addf %2, %cst_0 : f32
+        %4 = arith.divf %cst_0, %3 : f32
+        affine.store %4, %alloc_3[%arg0, %arg1, %arg2] : memref<1x7x5xf32>
+      }
+    }
+  }
+  return
+}
 
 // Add further tests in mlir/test/Transforms/loop-fusion-4.mlir

``````````

</details>


https://github.com/llvm/llvm-project/pull/189236
