[Mlir-commits] [mlir] Fix scf.forall inlining: add shared outputs (PR #132197)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 20 05:24:05 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Prakhar Dixit (Prakhar-Dixit)

<details>
<summary>Changes</summary>

Fixes #<!-- -->108164

This patch fixes a crash in the scf-forall-to-for conversion pass by ensuring that the replacement vector used during inlining contains both the induction variables and the shared outputs. Previously, only the induction variables were passed, so the number of replacement values did not match the number of block arguments in the forall op’s body. The fix appends the shared outputs (retrieved via `getOutputs()`) after the induction variables and then, instead of erasing the forall op, replaces it with its shared outputs, preserving the intended semantics without introducing regressions.

Minimal Example IR:
```
module {
  func.func @parallel_insert_slice(%arg0: tensor<100xf32>) -> tensor<100xf32> {
    %c100 = arith.constant 100 : index
    %res = scf.forall (%i) in (%c100) shared_outs(%s = %arg0) -> (tensor<100xf32>) {
      %t = "test.foo"() : () -> tensor<100xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %t into %s[%i] [100] [1] : tensor<100xf32> into tensor<100xf32>
      }
    }
    return %res : tensor<100xf32>
  }
}
```

---
Full diff: https://github.com/llvm/llvm-project/pull/132197.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp (+5-2) 
- (modified) mlir/test/Dialect/SCF/forall-to-for.mlir (+23) 


``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index a2f03f1e1056e..a1df366cef132 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -40,12 +40,15 @@ mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
   SmallVector<Value> ivs = llvm::map_to_vector(
       loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
 
+  SmallVector<Value> replacementVals = ivs;
+  for (Value shared : forallOp.getOutputs())
+    replacementVals.push_back(shared);
   Block *innermostBlock = loopNest.loops.back().getBody();
   rewriter.eraseOp(forallOp.getBody()->getTerminator());
   rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
                              innermostBlock->getTerminator()->getIterator(),
-                             ivs);
-  rewriter.eraseOp(forallOp);
+                             replacementVals);
+  rewriter.replaceOp(forallOp, forallOp.getOutputs());
 
   if (results) {
     llvm::move(loopNest.loops, std::back_inserter(*results));
diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
index e7d183fb9d2b5..17598a154fefd 100644
--- a/mlir/test/Dialect/SCF/forall-to-for.mlir
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -55,3 +55,26 @@ func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
   }
   return
 }
+
+// -----
+
+  func.func @parallel_insert_slice(%arg0: tensor<100xf32>) -> tensor<100xf32> {
+    %c100 = arith.constant 100 : index
+    %res = scf.forall (%i) in (%c100) shared_outs(%s = %arg0) -> (tensor<100xf32>) {
+      %t = "test.foo"() : () -> tensor<100xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %t into %s[%i] [100] [1] : tensor<100xf32> into tensor<100xf32>
+      }
+    }
+    return %res : tensor<100xf32>
+  }
+// CHECK-LABEL:   func.func @parallel_insert_slice(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<100xf32>) -> tensor<100xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:           scf.for %[[VAL_4:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_5:.*]] = "test.foo"() : () -> tensor<100xf32>
+// CHECK:           }
+// CHECK:           return %[[VAL_0]] : tensor<100xf32>
+// CHECK:         }
\ No newline at end of file

``````````

</details>


https://github.com/llvm/llvm-project/pull/132197


More information about the Mlir-commits mailing list