[Mlir-commits] [mlir] Fix scf.forall inlining: add shared outputs (PR #132197)

Prakhar Dixit llvmlistbot at llvm.org
Thu Mar 20 05:23:29 PDT 2025


https://github.com/Prakhar-Dixit created https://github.com/llvm/llvm-project/pull/132197

Fixes #108164

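This patch fixes a crash in the scf-forall-to-for conversion pass by ensuring that the replacement vector used during inlining contains both the induction variables and the shared outputs. Previously, only the induction variables were passed, causing a mismatch with the number of block arguments in the forall op's body (its induction variables followed by its shared outputs). The fix appends the shared outputs (retrieved via getOutputs()) after the induction variables and then replaces the forall op's results with those shared outputs, which avoids the crash without introducing regressions. Note, however, that the scf.forall.in_parallel terminator, including any tensor.parallel_insert_slice ops it contains, is still erased before inlining. The forall results therefore become the unmodified initial values, as the new test's `return %[[VAL_0]]` check shows, and the parallel writes are not carried over into the generated scf.for loops.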

Minimal example IR:

```
module {
  func.func @parallel_insert_slice(%arg0: tensor<100xf32>) -> tensor<100xf32> {
    %c100 = arith.constant 100 : index
    %res = scf.forall (%i) in (%c100) shared_outs(%s = %arg0) -> (tensor<100xf32>) {
      %t = "test.foo"() : () -> tensor<100xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %t into %s[%i] [100] [1] : tensor<100xf32> into tensor<100xf32>
      }
    }
    return %res : tensor<100xf32>
  }
}
```

>From c1ab3700149bac18146aa4792434a75a235ae021 Mon Sep 17 00:00:00 2001
From: Prakhar Dixit <dixitprakhar11 at gmail.com>
Date: Thu, 20 Mar 2025 17:48:48 +0530
Subject: [PATCH] Fix scf.forall inlining: add shared outputs

---
 .../Dialect/SCF/Transforms/ForallToFor.cpp    |  7 ++++--
 mlir/test/Dialect/SCF/forall-to-for.mlir      | 23 +++++++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index a2f03f1e1056e..a1df366cef132 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -40,12 +40,15 @@ mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
   SmallVector<Value> ivs = llvm::map_to_vector(
       loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
 
+  SmallVector<Value> replacementVals = ivs;
+  for (Value shared : forallOp.getOutputs())
+    replacementVals.push_back(shared);
   Block *innermostBlock = loopNest.loops.back().getBody();
   rewriter.eraseOp(forallOp.getBody()->getTerminator());
   rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
                              innermostBlock->getTerminator()->getIterator(),
-                             ivs);
-  rewriter.eraseOp(forallOp);
+                             replacementVals);
+  rewriter.replaceOp(forallOp, forallOp.getOutputs());
 
   if (results) {
     llvm::move(loopNest.loops, std::back_inserter(*results));
diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
index e7d183fb9d2b5..17598a154fefd 100644
--- a/mlir/test/Dialect/SCF/forall-to-for.mlir
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -55,3 +55,26 @@ func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
   }
   return
 }
+
+// -----
+
+  func.func @parallel_insert_slice(%arg0: tensor<100xf32>) -> tensor<100xf32> {
+    %c100 = arith.constant 100 : index
+    %res = scf.forall (%i) in (%c100) shared_outs(%s = %arg0) -> (tensor<100xf32>) {
+      %t = "test.foo"() : () -> tensor<100xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %t into %s[%i] [100] [1] : tensor<100xf32> into tensor<100xf32>
+      }
+    }
+    return %res : tensor<100xf32>
+  }
+// CHECK-LABEL:   func.func @parallel_insert_slice(
+// CHECK-SAME:      %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<100xf32>) -> tensor<100xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 100 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:           scf.for %[[VAL_4:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_5:.*]] = "test.foo"() : () -> tensor<100xf32>
+// CHECK:           }
+// CHECK:           return %[[VAL_0]] : tensor<100xf32>
+// CHECK:         }
\ No newline at end of file



More information about the Mlir-commits mailing list