[Mlir-commits] [mlir] 11bbee9 - Adding to execute_region_op some missing support (#164159)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 23 04:46:02 PDT 2025
Author: ddubov100
Date: 2025-10-23T11:45:58Z
New Revision: 11bbee9d9fbaa98978ed7704e799d6b56fb47295
URL: https://github.com/llvm/llvm-project/commit/11bbee9d9fbaa98978ed7704e799d6b56fb47295
DIFF: https://github.com/llvm/llvm-project/commit/11bbee9d9fbaa98978ed7704e799d6b56fb47295.diff
LOG: Adding to execute_region_op some missing support (#164159)
Add a canonicalization pattern for the case where the scf.yield ops of an
execute_region op yield values that are defined outside the execute_region.
The pattern simplifies the op so that it returns only internally defined
values; uses of the dropped results are replaced by the external values.
The pattern is applied only when all yield ops within the execute_region op
have the same operands.
---------
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
Added:
Modified:
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9bd13f3236cfc..744a5951330a3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -27,6 +27,7 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DebugLog.h"
@@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
}
};
+// Canonicalization pattern that drops ExecuteRegionOp results which merely
+// forward values defined outside the region; their uses are rewired to those
+// values. With multiple scf.yield terminators, the pattern only applies when
+// all of them have identical operands.
+struct ExecuteRegionForwardingEliminator
+ : public OpRewritePattern<ExecuteRegionOp> {
+ using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExecuteRegionOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getNumResults() == 0)
+ return failure();
+
+ SmallVector<Operation *> yieldOps;
+ for (Block &block : op.getRegion()) {
+ if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
+ yieldOps.push_back(yield.getOperation());
+ }
+
+ if (yieldOps.empty())
+ return failure();
+
+ // Bail out unless every yield has exactly the same operands as the first.
+ auto yieldOpsOperands = yieldOps[0]->getOperands();
+ for (auto *yieldOp : yieldOps) {
+ if (yieldOp->getOperands() != yieldOpsOperands)
+ return failure();
+ }
+
+ SmallVector<Value> externalValues;
+ SmallVector<Value> internalValues;
+ SmallVector<Value> opResultsToReplaceWithExternalValues;
+ SmallVector<Value> opResultsToKeep;
+ for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
+ if (isValueFromInsideRegion(yieldedValue, op)) {
+ internalValues.push_back(yieldedValue);
+ opResultsToKeep.push_back(op.getResult(index));
+ } else {
+ externalValues.push_back(yieldedValue);
+ opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
+ }
+ }
+ // Every yielded value is defined inside the region - nothing to forward.
+ if (externalValues.empty())
+ return failure();
+
+ // Some yielded values are external - build a replacement execute_region
+ // whose result types are those of the internal values only.
+ SmallVector<Type> resultTypes;
+ for (Value value : internalValues)
+ resultTypes.push_back(value.getType());
+ auto newOp =
+ ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
+ newOp->setAttrs(op->getAttrs());
+
+ // Move the old op's region body into the new operation.
+ rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
+ newOp.getRegion().end());
+
+ // Rewrite each collected yield to return only the internal values; a yield
+ // with no operands is still emitted when no internal values remain.
+ for (auto *yieldOp : yieldOps) {
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
+ ValueRange(internalValues));
+ }
+
+ // Users of the forwarded results now use the external values directly.
+ rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
+ externalValues);
+ // The old operation's remaining results map, in order, onto the new
+ // operation's results.
+ rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+private:
+ bool isValueFromInsideRegion(Value value,
+ ExecuteRegionOp executeRegionOp) const {
+ // Op result: inside iff its defining op's parent region is the op's region.
+ if (Operation *defOp = value.getDefiningOp())
+ return &executeRegionOp.getRegion() == defOp->getParentRegion();
+
+ // Block argument: inside iff its parent region is the op's region.
+ if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
+ return &executeRegionOp.getRegion() == blockArg.getParentRegion();
+
+ return false; // No defining op and not a block argument: not internal.
+ }
+};
+
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+ results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
+ ExecuteRegionForwardingEliminator>(context);
}
void ExecuteRegionOp::getSuccessorRegions(
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 2bec63672e783..084c3fc065de3 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1604,6 +1604,148 @@ func.func @func_execute_region_inline_multi_yield() {
// -----
+// Test case with a single scf.yield op inside execute_region whose operand is defined outside the execute_region op.
+// The scf.execute_region is rewritten to return nothing.
+
+// CHECK: scf.execute_region no_inline {
+// CHECK: func.call @foo() : () -> ()
+// CHECK: scf.yield
+// CHECK: }
+
+module {
+func.func private @foo()->()
+func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %1 = scf.execute_region -> memref<1x60xui8> no_inline {
+ func.call @foo():()->()
+ scf.yield %alloc: memref<1x60xui8>
+ }
+ return %1 : memref<1x60xui8>
+}
+}
+
+// -----
+
+// Test case with scf.yield op inside execute_region with multiple operands.
+// One of operands is defined outside the execute_region op.
+// Remove just this operand from the op results.
+
+// CHECK: %[[VAL_1:.*]] = scf.execute_region -> memref<1x120xui8> no_inline {
+// CHECK: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+// CHECK: func.call @foo() : () -> ()
+// CHECK: scf.yield %[[VAL_2]] : memref<1x120xui8>
+// CHECK: }
+module {
+func.func private @foo()->()
+func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ func.call @foo():()->()
+ scf.yield %alloc, %alloc_1: memref<1x60xui8>, memref<1x120xui8>
+ }
+ return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
+}
+}
+
+// -----
+
+// Test case with multiple scf.yield ops inside execute_region that have the same operands, all defined outside the execute_region op.
+// The scf.execute_region is rewritten to return nothing.
+// The scf.yield ops must remain, because scf.execute_region can't be empty.
+
+// CHECK: scf.execute_region no_inline {
+// CHECK: %[[VAL_3:.*]] = "test.cmp"() : () -> i1
+// CHECK: cf.cond_br %[[VAL_3]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: scf.yield
+// CHECK: ^bb2:
+// CHECK: scf.yield
+// CHECK: }
+
+module {
+ func.func private @foo()->()
+ func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+ %c = "test.cmp"() : () -> i1
+ cf.cond_br %c, ^bb2, ^bb3
+ ^bb2:
+ func.call @foo():()->()
+ scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
+ ^bb3:
+ func.call @foo():()->()
+ scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
+ }
+ return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
+ }
+}
+
+// -----
+
+// Test case with multiple scf.yield ops that differ in at least one operand: no change.
+
+// CHECK: %[[VAL_3:.*]]:2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+// CHECK: ^bb1:
+// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
+// CHECK: ^bb2:
+// CHECK: scf.yield %{{.*}}, %{{.*}} : memref<1x60xui8>, memref<1x120xui8>
+// CHECK: }
+
+module {
+ func.func private @foo()->()
+ func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+ %c = "test.cmp"() : () -> i1
+ cf.cond_br %c, ^bb2, ^bb3
+ ^bb2:
+ func.call @foo():()->()
+ scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
+ ^bb3:
+ func.call @foo():()->()
+ scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8>
+ }
+ return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
+ }
+}
+
+// -----
+
+// Test case with multiple scf.yield ops, each with a different operand.
+// In this case scf.execute_region isn't changed.
+
+// CHECK: %[[VAL_2:.*]] = scf.execute_region -> memref<1x60xui8> no_inline {
+// CHECK: ^bb1:
+// CHECK: scf.yield %{{.*}} : memref<1x60xui8>
+// CHECK: ^bb2:
+// CHECK: scf.yield %{{.*}} : memref<1x60xui8>
+// CHECK: }
+
+module {
+func.func private @foo()->()
+func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %1 = scf.execute_region -> (memref<1x60xui8>) no_inline {
+ %c = "test.cmp"() : () -> i1
+ cf.cond_br %c, ^bb2, ^bb3
+ ^bb2:
+ func.call @foo():()->()
+ scf.yield %alloc : memref<1x60xui8>
+ ^bb3:
+ func.call @foo():()->()
+ scf.yield %alloc_1 : memref<1x60xui8>
+ }
+ return %1 : memref<1x60xui8>
+}
+}
+
+// -----
+
// CHECK-LABEL: func @canonicalize_parallel_insert_slice_indices(
// CHECK-SAME: %[[arg0:.*]]: tensor<1x5xf32>, %[[arg1:.*]]: tensor<?x?xf32>
func.func @canonicalize_parallel_insert_slice_indices(
More information about the Mlir-commits
mailing list