[Mlir-commits] [mlir] [SCF] Add interface methods to `ParallelCombiningOp` for promotion (PR #159840)

Alan Li llvmlistbot at llvm.org
Mon Dec 8 19:03:35 PST 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/159840

>From 020e2c1d6b4a615d291100dff44c06031d1c8f11 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 18 Sep 2025 13:29:16 -0400
Subject: [PATCH] [MLIR][SCF] Add `promoteInParallelLoop` to
 `ParallelCombiningOpInterface`

This is a follow-up to #157736, which introduced `ParallelCombiningOpInterface`.

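This patch extends `ParallelCombiningOpInterface` with two new methods:
* `canPromoteInParallelLoop`: returns whether the op can be promoted out of
  its enclosing parallel loop (the default implementation returns `false`)
* `promoteInParallelLoop`: promotes the op out of its enclosing parallel loop
  and returns the value that replaces the destination updated by the op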

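The `scf::promote` function is refactored to use these interface methods
instead of hardcoding `tensor::ParallelInsertSliceOp` handling. It now returns
a `LogicalResult` instead of `void`, so callers can detect when a parallel
combining op cannot be promoted. This makes the promotion logic extensible to
other parallel combining ops.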
---
 mlir/include/mlir/Dialect/SCF/IR/SCF.h        |  2 +-
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  3 +-
 .../Interfaces/ParallelCombiningOpInterface.h |  2 +
 .../ParallelCombiningOpInterface.td           | 21 ++++++++
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 54 ++++++++++++-------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 19 +++++++
 mlir/test/Dialect/SCF/canonicalize.mlir       | 25 +++++++++
 7 files changed, 104 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..ac6f034dba728 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -58,7 +58,7 @@ ForallOp getForallOpThreadIndexOwner(Value val);
 bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);
 
 /// Promotes the loop body of a scf::ForallOp to its containing block.
-void promote(RewriterBase &rewriter, scf::ForallOp forallOp);
+LogicalResult promote(RewriterBase &rewriter, scf::ForallOp forallOp);
 
 /// An owning vector of values, handy to return from functions.
 using ValueVector = SmallVector<Value>;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d2b6007c628..b03380bf65b8b 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1483,7 +1483,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
        AttrSizedOperandSegments,
        OffsetSizeAndStrideOpInterface,
        DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
-          ["getUpdatedDestinations", "getIteratingParent"]>,
+          ["getUpdatedDestinations", "getIteratingParent",
+           "promoteInParallelLoop", "canPromoteInParallelLoop"]>,
        // TODO: Cannot use an interface here atm, verify this manually for now.
        // HasParent<"InParallelOpInterface">
   ]> {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index 82ab427699f64..85cc18c47a527 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -15,6 +15,8 @@
 #define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
 
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
 
 namespace mlir {
 namespace detail {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index ace26f723ef53..1a333d82d8468 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -106,6 +106,27 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
       /*methodName=*/"getIteratingParent",
       /*args=*/(ins)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Promotes this parallel combining op out of its enclosing parallel loop
+        and returns the value that should replace the destination updated by
+        this op.
+      }],
+      /*retTy=*/"::mlir::FailureOr<::mlir::Value>",
+      /*methodName=*/"promoteInParallelLoop",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if this op can be promoted out of its enclosing parallel
+        loop.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"canPromoteInParallelLoop",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
+    >,
   ];
 }
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c75528a76c999..b069ae90d4e68 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -777,8 +777,7 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
       return failure();
   }
 
-  promote(rewriter, *this);
-  return success();
+  return promote(rewriter, *this);
 }
 
 Block::BlockArgListType ForallOp::getRegionIterArgs() {
@@ -790,10 +789,28 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
 }
 
 /// Promotes the loop body of a scf::ForallOp to its containing block.
-void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
+LogicalResult mlir::scf::promote(RewriterBase &rewriter,
+                                 scf::ForallOp forallOp) {
   OpBuilder::InsertionGuard g(rewriter);
   scf::InParallelOp terminator = forallOp.getTerminator();
 
+  // Make sure we can promote all parallel combining ops in terminator:
+  unsigned numParallelCombiningOps = 0;
+  for (auto &yieldingOp : terminator.getYieldingOps()) {
+    auto parallelCombiningOp =
+        dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+    if (!parallelCombiningOp)
+      continue;
+    ++numParallelCombiningOps;
+    if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
+      return rewriter.notifyMatchFailure(
+          forallOp, "parallel combining op cannot be promoted");
+  }
+  if (numParallelCombiningOps != forallOp.getResults().size())
+    return rewriter.notifyMatchFailure(
+        forallOp,
+        "number of parallel combining ops does not match number of results");
+
   // Replace block arguments with lower bounds (replacements for IVs) and
   // outputs.
   SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
@@ -809,30 +826,26 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
   SmallVector<Value> results;
   results.reserve(forallOp.getResults().size());
   for (auto &yieldingOp : terminator.getYieldingOps()) {
-    auto parallelInsertSliceOp =
-        dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
-    if (!parallelInsertSliceOp)
+    auto parallelCombiningOp =
+        dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+    if (!parallelCombiningOp)
       continue;
 
-    Value dst = parallelInsertSliceOp.getDest();
-    Value src = parallelInsertSliceOp.getSource();
-    if (llvm::isa<TensorType>(src.getType())) {
-      results.push_back(tensor::InsertSliceOp::create(
-          rewriter, forallOp.getLoc(), dst.getType(), src, dst,
-          parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
-          parallelInsertSliceOp.getStrides(),
-          parallelInsertSliceOp.getStaticOffsets(),
-          parallelInsertSliceOp.getStaticSizes(),
-          parallelInsertSliceOp.getStaticStrides()));
-    } else {
-      llvm_unreachable("unsupported terminator");
-    }
+    assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter));
+
+    FailureOr<Value> promotedValue =
+        parallelCombiningOp.promoteInParallelLoop(rewriter);
+    if (failed(promotedValue))
+      return failure();
+
+    results.push_back(*promotedValue);
   }
   rewriter.replaceAllUsesWith(forallOp.getResults(), results);
 
   // Erase the old terminator and the loop.
   rewriter.eraseOp(terminator);
   rewriter.eraseOp(forallOp);
+  return success();
 }
 
 LoopNest mlir::scf::buildLoopNest(
@@ -1890,7 +1903,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
 
     // All of the loop dimensions perform a single iteration. Inline loop body.
     if (newMixedLowerBounds.empty()) {
-      promote(rewriter, op);
+      if (failed(promote(rewriter, op)))
+        return failure();
       return success();
     }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 204e9bb73e12c..bf3790cd023af 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3962,6 +3962,25 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
   return nullptr;
 }
 
+FailureOr<Value>
+ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
+  Value dst = getDest();
+  Value src = getSource();
+  if (!isa<TensorType>(src.getType()))
+    return failure();
+
+  Value inserted = tensor::InsertSliceOp::create(
+      rewriter, getLoc(), dst.getType(), src, dst, getOffsets(), getSizes(),
+      getStrides(), getStaticOffsets(), getStaticSizes(), getStaticStrides());
+
+  return inserted;
+}
+
+bool ParallelInsertSliceOp::canPromoteInParallelLoop(RewriterBase &) {
+  return isa<TensorType>(getSource().getType()) &&
+         isa<TensorType>(getDest().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..e99948fdc6926 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2171,3 +2171,28 @@ func.func @scf_for_all_step_size_0()  {
   }
   return
 }
+
+// -----
+
+// Test single-iteration forall with multiple parallel_insert_slice ops.
+func.func @inline_forall_loop_multiple_results(
+    %arg0: tensor<8x8xf32>, %arg1: tensor<4x4xf32>,
+    %s0: tensor<2x3xf32>, %s1: tensor<2x2xf32>) -> (tensor<8x8xf32>, tensor<4x4xf32>) {
+  %0:2 = scf.forall (%i) in (1) shared_outs (%out0 = %arg0, %out1 = %arg1)
+      -> (tensor<8x8xf32>, tensor<4x4xf32>) {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %s0 into %out0[0, 0] [2, 3] [1, 1]
+        : tensor<2x3xf32> into tensor<8x8xf32>
+      tensor.parallel_insert_slice %s1 into %out1[0, 0] [2, 2] [1, 1]
+        : tensor<2x2xf32> into tensor<4x4xf32>
+    }
+  }
+  return %0#0, %0#1 : tensor<8x8xf32>, tensor<4x4xf32>
+}
+// CHECK-LABEL: @inline_forall_loop_multiple_results
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<8x8xf32>, %[[ARG1:.*]]: tensor<4x4xf32>,
+// CHECK-SAME:    %[[S0:.*]]: tensor<2x3xf32>, %[[S1:.*]]: tensor<2x2xf32>
+// CHECK-NOT:     scf.forall
+// CHECK-DAG:     %[[R0:.*]] = tensor.insert_slice %[[S0]] into %[[ARG0]]
+// CHECK-DAG:     %[[R1:.*]] = tensor.insert_slice %[[S1]] into %[[ARG1]]
+// CHECK:         return %[[R0]], %[[R1]]



More information about the Mlir-commits mailing list