[Mlir-commits] [mlir] [SCF] Add interface methods to `ParallelCombiningOp` for promotion (PR #159840)

Alan Li llvmlistbot at llvm.org
Fri Sep 19 12:45:21 PDT 2025


https://github.com/lialan created https://github.com/llvm/llvm-project/pull/159840

`ParallelCombiningOp` makes the parallel insertion performed by an `scf.forall.in_parallel` op extensible.

This patch adds interface methods for the optimizer to promote ops.
* `canPromoteInParallelLoop` decides whether the op can be folded/promoted in trivial-iteration cases.
* `promoteInParallelLoop` does the actual work.

>From a9073c056efe40e7a36917aec3025c97b668ed35 Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 18 Sep 2025 13:29:16 -0400
Subject: [PATCH 1/2] Adding a missing op for ParallelCombiningOpInterface

---
 mlir/include/mlir/Dialect/SCF/IR/SCF.h        |  2 +-
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  3 +-
 .../Interfaces/ParallelCombiningOpInterface.h |  3 ++
 .../ParallelCombiningOpInterface.td           | 20 +++++++
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 52 ++++++++++++-------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 15 ++++++
 6 files changed, 73 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index ba648181daecb..830b49321c2e4 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -58,7 +58,7 @@ ForallOp getForallOpThreadIndexOwner(Value val);
 bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);
 
 /// Promotes the loop body of a scf::ForallOp to its containing block.
-void promote(RewriterBase &rewriter, scf::ForallOp forallOp);
+LogicalResult promote(RewriterBase &rewriter, scf::ForallOp forallOp);
 
 /// An owning vector of values, handy to return from functions.
 using ValueVector = SmallVector<Value>;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 2453cf5b5b5a4..be04c3a4aebbe 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1474,7 +1474,8 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
        AttrSizedOperandSegments,
        OffsetSizeAndStrideOpInterface,
        DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
-          ["getUpdatedDestinations", "getIteratingParent"]>,
+          ["getUpdatedDestinations", "getIteratingParent",
+           "promoteInParallelLoop"]>,
        // TODO: Cannot use an interface here atm, verify this manually for now.
        // HasParent<"InParallelOpInterface">
   ]> {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index 82ab427699f64..ff4e5a87d05c7 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -15,6 +15,9 @@
 #define MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE_H_
 
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 namespace detail {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index ace26f723ef53..632371b2777fd 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -106,6 +106,26 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
       /*methodName=*/"getIteratingParent",
       /*args=*/(ins)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Promotes this parallel combining op out of its enclosing parallel loop
+        and returns the values that should replace the destinations updated by
+        this op.
+      }],
+      /*retTy=*/"::mlir::FailureOr<::llvm::SmallVector<::mlir::Value>>",
+      /*methodName=*/"promoteInParallelLoop",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns true if this op can be promoted out of its enclosing parallel
+        loop.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"canPromoteInParallelLoop",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+      /*methodBody=*/[{ return true; }]
+    >,
   ];
 }
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c35989ecba6cd..4115ca00f64b5 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -651,8 +651,7 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
       return failure();
   }
 
-  promote(rewriter, *this);
-  return success();
+  return promote(rewriter, *this);
 }
 
 Block::BlockArgListType ForallOp::getRegionIterArgs() {
@@ -664,10 +663,23 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
 }
 
 /// Promotes the loop body of a scf::ForallOp to its containing block.
-void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
+LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
   OpBuilder::InsertionGuard g(rewriter);
   scf::InParallelOp terminator = forallOp.getTerminator();
 
+  // Make sure we can promote all parallel combining ops in terminator:
+  for (auto &yieldingOp : terminator.getYieldingOps()) {
+    auto parallelCombiningOp =
+        dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+    if (!parallelCombiningOp)
+      return rewriter.notifyMatchFailure(
+          forallOp, "terminator has non-parallel-combining op");
+    if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
+      return rewriter.notifyMatchFailure(
+          forallOp, "parallel combining op cannot be promoted");
+  }
+
+
   // Replace block arguments with lower bounds (replacements for IVs) and
   // outputs.
   SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
@@ -683,30 +695,29 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
   SmallVector<Value> results;
   results.reserve(forallOp.getResults().size());
   for (auto &yieldingOp : terminator.getYieldingOps()) {
-    auto parallelInsertSliceOp =
-        dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
-    if (!parallelInsertSliceOp)
+    auto parallelCombiningOp =
+        dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
+    if (!parallelCombiningOp)
       continue;
 
-    Value dst = parallelInsertSliceOp.getDest();
-    Value src = parallelInsertSliceOp.getSource();
-    if (llvm::isa<TensorType>(src.getType())) {
-      results.push_back(tensor::InsertSliceOp::create(
-          rewriter, forallOp.getLoc(), dst.getType(), src, dst,
-          parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(),
-          parallelInsertSliceOp.getStrides(),
-          parallelInsertSliceOp.getStaticOffsets(),
-          parallelInsertSliceOp.getStaticSizes(),
-          parallelInsertSliceOp.getStaticStrides()));
-    } else {
-      llvm_unreachable("unsupported terminator");
-    }
+    assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter));
+
+    FailureOr<SmallVector<Value>> promotedValues =
+        parallelCombiningOp.promoteInParallelLoop(rewriter);
+    if (failed(promotedValues))
+      return failure();
+
+    results.append(promotedValues->begin(), promotedValues->end());
   }
+  if (results.size() != forallOp.getResults().size())
+    return rewriter.notifyMatchFailure(
+        forallOp, "failed to materialize replacements for all results");
   rewriter.replaceAllUsesWith(forallOp.getResults(), results);
 
   // Erase the old terminator and the loop.
   rewriter.eraseOp(terminator);
   rewriter.eraseOp(forallOp);
+  return success();
 }
 
 LoopNest mlir::scf::buildLoopNest(
@@ -1789,7 +1800,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
 
     // All of the loop dimensions perform a single iteration. Inline loop body.
     if (newMixedLowerBounds.empty()) {
-      promote(rewriter, op);
+      if (failed(promote(rewriter, op)))
+        return failure();
       return success();
     }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fa97b49a41d97..2932000b85b3b 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3947,6 +3947,21 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
   return nullptr;
 }
 
+FailureOr<SmallVector<Value>>
+ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
+  Value dst = getDest();
+  Value src = getSource();
+  if (!isa<TensorType>(src.getType()))
+    return rewriter.notifyMatchFailure(getOperation(),
+                                       "expected tensor source");
+
+  Value inserted = tensor::InsertSliceOp::create(
+      rewriter, getLoc(), dst.getType(), src, dst, getOffsets(), getSizes(),
+      getStrides(), getStaticOffsets(), getStaticSizes(), getStaticStrides());
+
+  return SmallVector<Value>{inserted};
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//

>From 551f35e3d6c4890b2c5f16e48bd2c1d1d9645bed Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Thu, 18 Sep 2025 14:00:55 -0400
Subject: [PATCH 2/2] Update. Related test: test/Dialect/SCF/transform-ops.mlir

---
 mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td     |  2 +-
 .../mlir/Interfaces/ParallelCombiningOpInterface.h   |  1 -
 .../mlir/Interfaces/ParallelCombiningOpInterface.td  |  7 ++++---
 mlir/lib/Dialect/SCF/IR/SCF.cpp                      | 10 ++++------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp             | 12 ++++++++----
 5 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index be04c3a4aebbe..4fb4cc8410230 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1475,7 +1475,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
        OffsetSizeAndStrideOpInterface,
        DeclareOpInterfaceMethods<ParallelCombiningOpInterface,
           ["getUpdatedDestinations", "getIteratingParent",
-           "promoteInParallelLoop"]>,
+           "promoteInParallelLoop", "canPromoteInParallelLoop"]>,
        // TODO: Cannot use an interface here atm, verify this manually for now.
        // HasParent<"InParallelOpInterface">
   ]> {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index ff4e5a87d05c7..85cc18c47a527 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -17,7 +17,6 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
-#include "llvm/ADT/SmallVector.h"
 
 namespace mlir {
 namespace detail {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index 632371b2777fd..1a333d82d8468 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -109,10 +109,10 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
     InterfaceMethod<
       /*desc=*/[{
         Promotes this parallel combining op out of its enclosing parallel loop
-        and returns the values that should replace the destinations updated by
+        and returns the value that should replace the destination updated by
         this op.
       }],
-      /*retTy=*/"::mlir::FailureOr<::llvm::SmallVector<::mlir::Value>>",
+      /*retTy=*/"::mlir::FailureOr<::mlir::Value>",
       /*methodName=*/"promoteInParallelLoop",
       /*args=*/(ins "::mlir::RewriterBase &":$rewriter)
     >,
@@ -124,7 +124,8 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
       /*retTy=*/"bool",
       /*methodName=*/"canPromoteInParallelLoop",
       /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
-      /*methodBody=*/[{ return true; }]
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return false; }]
     >,
   ];
 }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 4115ca00f64b5..04737738d8593 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -672,14 +672,12 @@ LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp)
     auto parallelCombiningOp =
         dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
     if (!parallelCombiningOp)
-      return rewriter.notifyMatchFailure(
-          forallOp, "terminator has non-parallel-combining op");
+      continue;
     if (!parallelCombiningOp.canPromoteInParallelLoop(rewriter))
       return rewriter.notifyMatchFailure(
           forallOp, "parallel combining op cannot be promoted");
   }
 
-
   // Replace block arguments with lower bounds (replacements for IVs) and
   // outputs.
   SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter);
@@ -702,12 +700,12 @@ LogicalResult mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp)
 
     assert(parallelCombiningOp.canPromoteInParallelLoop(rewriter));
 
-    FailureOr<SmallVector<Value>> promotedValues =
+    FailureOr<Value> promotedValue =
         parallelCombiningOp.promoteInParallelLoop(rewriter);
-    if (failed(promotedValues))
+    if (failed(promotedValue))
       return failure();
 
-    results.append(promotedValues->begin(), promotedValues->end());
+    results.push_back(*promotedValue);
   }
   if (results.size() != forallOp.getResults().size())
     return rewriter.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 2932000b85b3b..f05c58a40fde0 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3947,19 +3947,23 @@ Operation *ParallelInsertSliceOp::getIteratingParent() {
   return nullptr;
 }
 
-FailureOr<SmallVector<Value>>
+FailureOr<Value>
 ParallelInsertSliceOp::promoteInParallelLoop(RewriterBase &rewriter) {
   Value dst = getDest();
   Value src = getSource();
   if (!isa<TensorType>(src.getType()))
-    return rewriter.notifyMatchFailure(getOperation(),
-                                       "expected tensor source");
+    return failure();
 
   Value inserted = tensor::InsertSliceOp::create(
       rewriter, getLoc(), dst.getType(), src, dst, getOffsets(), getSizes(),
       getStrides(), getStaticOffsets(), getStaticSizes(), getStaticStrides());
 
-  return SmallVector<Value>{inserted};
+  return inserted;
+}
+
+bool ParallelInsertSliceOp::canPromoteInParallelLoop(RewriterBase &) {
+  return isa<TensorType>(getSource().getType()) &&
+         isa<TensorType>(getDest().getType());
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list