[Mlir-commits] [mlir] aeec945 - [mlir][inliner] Add doClone and canHandleMultipleBlocks callbacks to Inliner Config (#131226)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Apr 5 13:56:58 PDT 2025
Author: junfengd-nv
Date: 2025-04-05T22:56:55+02:00
New Revision: aeec94500a5dbd576e5d2d16895fe00fa0b1e154
URL: https://github.com/llvm/llvm-project/commit/aeec94500a5dbd576e5d2d16895fe00fa0b1e154
DIFF: https://github.com/llvm/llvm-project/commit/aeec94500a5dbd576e5d2d16895fe00fa0b1e154.diff
LOG: [mlir][inliner] Add doClone and canHandleMultipleBlocks callbacks to Inliner Config (#131226)
The current inliner disables inlining when the caller is in a region with the
SingleBlock trait while the callee function contains multiple blocks.
The SingleBlock trait is used by operations such as do/while loops and
conditionals, for example fir.do_loop, fir.iterate_while, and fir.if. Calls
within loops are typically good candidates for inlining. However, functions
with multiple blocks are also common: for example, any function with an early
return ("if () then return") results in multiple blocks in MLIR.
This change gives a customized inliner the flexibility to handle such
cases.
doClone (set via InlinerConfig::setCloneCallback): clones instructions and
other information from the callee function into the caller function.
canHandleMultipleBlocks (set via InlinerConfig::setCanHandleMultipleBlocks):
whether functions with multiple blocks may be inlined into a region with the
SingleBlock trait.
The default behavior of the inliner remains unchanged.
---------
Co-authored-by: jeanPerier <jean.perier.polytechnique at gmail.com>
Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
Added:
mlir/test/Transforms/test-inlining-callback.mlir
mlir/test/lib/Transforms/TestInliningCallback.cpp
Modified:
mlir/include/mlir/Transforms/Inliner.h
mlir/include/mlir/Transforms/InliningUtils.h
mlir/lib/Transforms/Utils/Inliner.cpp
mlir/lib/Transforms/Utils/InliningUtils.cpp
mlir/test/lib/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/TestInlining.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/Inliner.h b/mlir/include/mlir/Transforms/Inliner.h
index ec77319d6ac88..506b4455af646 100644
--- a/mlir/include/mlir/Transforms/Inliner.h
+++ b/mlir/include/mlir/Transforms/Inliner.h
@@ -27,6 +27,11 @@ class InlinerConfig {
public:
using DefaultPipelineTy = std::function<void(OpPassManager &)>;
using OpPipelinesTy = llvm::StringMap<OpPassManager>;
+ using CloneCallbackSigTy = void(OpBuilder &builder, Region *src,
+ Block *inlineBlock, Block *postInsertBlock,
+ IRMapping &mapper,
+ bool shouldCloneInlinedRegion);
+ using CloneCallbackTy = std::function<CloneCallbackSigTy>;
InlinerConfig() = default;
InlinerConfig(DefaultPipelineTy defaultPipeline,
@@ -39,6 +44,9 @@ class InlinerConfig {
}
const OpPipelinesTy &getOpPipelines() const { return opPipelines; }
unsigned getMaxInliningIterations() const { return maxInliningIterations; }
+ const CloneCallbackTy &getCloneCallback() const { return cloneCallback; }
+ bool getCanHandleMultipleBlocks() const { return canHandleMultipleBlocks; }
+
void setDefaultPipeline(DefaultPipelineTy pipeline) {
defaultPipeline = std::move(pipeline);
}
@@ -46,6 +54,12 @@ class InlinerConfig {
opPipelines = std::move(pipelines);
}
void setMaxInliningIterations(unsigned max) { maxInliningIterations = max; }
+ void setCloneCallback(CloneCallbackTy callback) {
+ cloneCallback = std::move(callback);
+ }
+ void setCanHandleMultipleBlocks(bool value = true) {
+ canHandleMultipleBlocks = value;
+ }
private:
/// An optional function that constructs an optimization pipeline for
@@ -60,6 +74,28 @@ class InlinerConfig {
/// For SCC-based inlining algorithms, specifies maximum number of iterations
/// when inlining within an SCC.
unsigned maxInliningIterations{0};
+ /// Callback for cloning operations during inlining
+ CloneCallbackTy cloneCallback = [](OpBuilder &builder, Region *src,
+ Block *inlineBlock, Block *postInsertBlock,
+ IRMapping &mapper,
+ bool shouldCloneInlinedRegion) {
+ // Check to see if the region is being cloned, or moved inline. In
+ // either case, move the new blocks after the 'insertBlock' to improve
+ // IR readability.
+ Region *insertRegion = inlineBlock->getParent();
+ if (shouldCloneInlinedRegion)
+ src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
+ else
+ insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
+ src->getBlocks(), src->begin(),
+ src->end());
+ };
+ /// Determine if the inliner can inline a function containing multiple
+ /// blocks into a region that requires a single block. By default, it is
+ /// not allowed. If it is true, cloneCallback should perform the extra
+ /// transformation. see the example in
+ /// mlir/test/lib/Transforms/TestInliningCallback.cpp
+ bool canHandleMultipleBlocks{false};
};
/// This is an implementation of the inliner
diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
index becfe9b047ef4..552030983d724 100644
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -18,6 +18,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/Inliner.h"
#include <optional>
namespace mlir {
@@ -253,33 +254,39 @@ class InlinerInterface
/// provided, will be used to update the inlined operations' location
/// information. 'shouldCloneInlinedRegion' corresponds to whether the source
/// region should be cloned into the 'inlinePoint' or spliced directly.
-LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
- Operation *inlinePoint, IRMapping &mapper,
- ValueRange resultsToReplace,
- TypeRange regionResultTypes,
- std::optional<Location> inlineLoc = std::nullopt,
- bool shouldCloneInlinedRegion = true);
-LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
- Block *inlineBlock, Block::iterator inlinePoint,
- IRMapping &mapper, ValueRange resultsToReplace,
- TypeRange regionResultTypes,
- std::optional<Location> inlineLoc = std::nullopt,
- bool shouldCloneInlinedRegion = true);
+LogicalResult
+inlineRegion(InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
+ Region *src, Operation *inlinePoint, IRMapping &mapper,
+ ValueRange resultsToReplace, TypeRange regionResultTypes,
+ std::optional<Location> inlineLoc = std::nullopt,
+ bool shouldCloneInlinedRegion = true);
+LogicalResult
+inlineRegion(InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
+ Region *src, Block *inlineBlock, Block::iterator inlinePoint,
+ IRMapping &mapper, ValueRange resultsToReplace,
+ TypeRange regionResultTypes,
+ std::optional<Location> inlineLoc = std::nullopt,
+ bool shouldCloneInlinedRegion = true);
/// This function is an overload of the above 'inlineRegion' that allows for
/// providing the set of operands ('inlinedOperands') that should be used
/// in-favor of the region arguments when inlining.
-LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
- Operation *inlinePoint, ValueRange inlinedOperands,
- ValueRange resultsToReplace,
- std::optional<Location> inlineLoc = std::nullopt,
- bool shouldCloneInlinedRegion = true);
-LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
- Block *inlineBlock, Block::iterator inlinePoint,
- ValueRange inlinedOperands,
- ValueRange resultsToReplace,
- std::optional<Location> inlineLoc = std::nullopt,
- bool shouldCloneInlinedRegion = true);
+LogicalResult
+inlineRegion(InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
+ Region *src, Operation *inlinePoint, ValueRange inlinedOperands,
+ ValueRange resultsToReplace,
+ std::optional<Location> inlineLoc = std::nullopt,
+ bool shouldCloneInlinedRegion = true);
+LogicalResult
+inlineRegion(InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
+ Region *src, Block *inlineBlock, Block::iterator inlinePoint,
+ ValueRange inlinedOperands, ValueRange resultsToReplace,
+ std::optional<Location> inlineLoc = std::nullopt,
+ bool shouldCloneInlinedRegion = true);
/// This function inlines a given region, 'src', of a callable operation,
/// 'callable', into the location defined by the given call operation. This
@@ -287,9 +294,11 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
-LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call,
- CallableOpInterface callable, Region *src,
- bool shouldCloneInlinedRegion = true);
+LogicalResult
+inlineCall(InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
+ CallOpInterface call, CallableOpInterface callable, Region *src,
+ bool shouldCloneInlinedRegion = true);
} // namespace mlir
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index f511504594cfa..54b5c788a3526 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -652,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
LogicalResult inlineResult =
- inlineCall(inlinerIface, call,
+ inlineCall(inlinerIface, inliner.config.getCloneCallback(), call,
cast<CallableOpInterface>(targetRegion->getParentOp()),
targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
if (failed(inlineResult)) {
@@ -730,19 +730,22 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
// Don't allow inlining if the callee has multiple blocks (unstructured
// control flow) but we cannot be sure that the caller region supports that.
- bool calleeHasMultipleBlocks =
- llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
- // If both parent ops have the same type, it is safe to inline. Otherwise,
- // decide based on whether the op has the SingleBlock trait or not.
- // Note: This check does currently not account for SizedRegion/MaxSizedRegion.
- auto callerRegionSupportsMultipleBlocks = [&]() {
- return callableRegion->getParentOp()->getName() ==
- resolvedCall.call->getParentOp()->getName() ||
- !resolvedCall.call->getParentOp()
- ->mightHaveTrait<OpTrait::SingleBlock>();
- };
- if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
- return false;
+ if (!inliner.config.getCanHandleMultipleBlocks()) {
+ bool calleeHasMultipleBlocks =
+ llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
+ // If both parent ops have the same type, it is safe to inline. Otherwise,
+ // decide based on whether the op has the SingleBlock trait or not.
+ // Note: This check does currently not account for
+ // SizedRegion/MaxSizedRegion.
+ auto callerRegionSupportsMultipleBlocks = [&]() {
+ return callableRegion->getParentOp()->getName() ==
+ resolvedCall.call->getParentOp()->getName() ||
+ !resolvedCall.call->getParentOp()
+ ->mightHaveTrait<OpTrait::SingleBlock>();
+ };
+ if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
+ return false;
+ }
if (!inliner.isProfitableToInline(resolvedCall))
return false;
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index e113389b26ae7..3dd95d2845715 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/InliningUtils.h"
+#include "mlir/Transforms/Inliner.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
@@ -266,10 +267,11 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
}
static LogicalResult
-inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
- Block::iterator inlinePoint, IRMapping &mapper,
- ValueRange resultsToReplace, TypeRange regionResultTypes,
- std::optional<Location> inlineLoc,
+inlineRegionImpl(InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
+ Region *src, Block *inlineBlock, Block::iterator inlinePoint,
+ IRMapping &mapper, ValueRange resultsToReplace,
+ TypeRange regionResultTypes, std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
assert(resultsToReplace.size() == regionResultTypes.size());
// We expect the region to have at least one block.
@@ -296,16 +298,10 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
if (call && callable)
handleArgumentImpl(interface, builder, call, callable, mapper);
- // Check to see if the region is being cloned, or moved inline. In either
- // case, move the new blocks after the 'insertBlock' to improve IR
- // readability.
+ // Clone the callee's source into the caller.
Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
- if (shouldCloneInlinedRegion)
- src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
- else
- insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
- src->getBlocks(), src->begin(),
- src->end());
+ cloneCallback(builder, src, inlineBlock, postInsertBlock, mapper,
+ shouldCloneInlinedRegion);
// Get the range of newly inserted blocks.
auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
@@ -374,9 +370,11 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
}
static LogicalResult
-inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
- Block::iterator inlinePoint, ValueRange inlinedOperands,
- ValueRange resultsToReplace, std::optional<Location> inlineLoc,
+inlineRegionImpl(InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
+ Region *src, Block *inlineBlock, Block::iterator inlinePoint,
+ ValueRange inlinedOperands, ValueRange resultsToReplace,
+ std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
// We expect the region to have at least one block.
if (src->empty())
@@ -398,53 +396,54 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
}
// Call into the main region inliner function.
- return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
- resultsToReplace, resultsToReplace.getTypes(),
- inlineLoc, shouldCloneInlinedRegion, call);
+ return inlineRegionImpl(interface, cloneCallback, src, inlineBlock,
+ inlinePoint, mapper, resultsToReplace,
+ resultsToReplace.getTypes(), inlineLoc,
+ shouldCloneInlinedRegion, call);
}
-LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
- Operation *inlinePoint, IRMapping &mapper,
- ValueRange resultsToReplace,
- TypeRange regionResultTypes,
- std::optional<Location> inlineLoc,
- bool shouldCloneInlinedRegion) {
- return inlineRegion(interface, src, inlinePoint->getBlock(),
+LogicalResult mlir::inlineRegion(
+ InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
+ Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace,
+ TypeRange regionResultTypes, std::optional<Location> inlineLoc,
+ bool shouldCloneInlinedRegion) {
+ return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
++inlinePoint->getIterator(), mapper, resultsToReplace,
regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
}
-LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
- Block *inlineBlock,
- Block::iterator inlinePoint, IRMapping &mapper,
- ValueRange resultsToReplace,
- TypeRange regionResultTypes,
- std::optional<Location> inlineLoc,
- bool shouldCloneInlinedRegion) {
- return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
- resultsToReplace, regionResultTypes, inlineLoc,
- shouldCloneInlinedRegion);
+
+LogicalResult mlir::inlineRegion(
+ InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
+ Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper,
+ ValueRange resultsToReplace, TypeRange regionResultTypes,
+ std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
+ return inlineRegionImpl(
+ interface, cloneCallback, src, inlineBlock, inlinePoint, mapper,
+ resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
}
-LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
- Operation *inlinePoint,
- ValueRange inlinedOperands,
- ValueRange resultsToReplace,
- std::optional<Location> inlineLoc,
- bool shouldCloneInlinedRegion) {
- return inlineRegion(interface, src, inlinePoint->getBlock(),
+LogicalResult mlir::inlineRegion(
+ InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
+ Operation *inlinePoint, ValueRange inlinedOperands,
+ ValueRange resultsToReplace, std::optional<Location> inlineLoc,
+ bool shouldCloneInlinedRegion) {
+ return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
++inlinePoint->getIterator(), inlinedOperands,
resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
}
-LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
- Block *inlineBlock,
- Block::iterator inlinePoint,
- ValueRange inlinedOperands,
- ValueRange resultsToReplace,
- std::optional<Location> inlineLoc,
- bool shouldCloneInlinedRegion) {
- return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
- inlinedOperands, resultsToReplace, inlineLoc,
- shouldCloneInlinedRegion);
+
+LogicalResult mlir::inlineRegion(
+ InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
+ Block *inlineBlock, Block::iterator inlinePoint, ValueRange inlinedOperands,
+ ValueRange resultsToReplace, std::optional<Location> inlineLoc,
+ bool shouldCloneInlinedRegion) {
+ return inlineRegionImpl(interface, cloneCallback, src, inlineBlock,
+ inlinePoint, inlinedOperands, resultsToReplace,
+ inlineLoc, shouldCloneInlinedRegion);
}
/// Utility function used to generate a cast operation from the given interface,
@@ -475,10 +474,11 @@ static Value materializeConversion(const DialectInlinerInterface *interface,
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
/// corresponds to whether the source region should be cloned into the 'call' or
/// spliced directly.
-LogicalResult mlir::inlineCall(InlinerInterface &interface,
- CallOpInterface call,
- CallableOpInterface callable, Region *src,
- bool shouldCloneInlinedRegion) {
+LogicalResult
+mlir::inlineCall(InlinerInterface &interface,
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
+ CallOpInterface call, CallableOpInterface callable,
+ Region *src, bool shouldCloneInlinedRegion) {
// We expect the region to have at least one block.
if (src->empty())
return failure();
@@ -552,7 +552,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
return cleanupState();
// Attempt to inline the call.
- if (failed(inlineRegionImpl(interface, src, call->getBlock(),
+ if (failed(inlineRegionImpl(interface, cloneCallback, src, call->getBlock(),
++call->getIterator(), mapper, callResults,
callableResultTypes, call.getLoc(),
shouldCloneInlinedRegion, call)))
diff --git a/mlir/test/Transforms/test-inlining-callback.mlir b/mlir/test/Transforms/test-inlining-callback.mlir
new file mode 100644
index 0000000000000..c012c31e7e490
--- /dev/null
+++ b/mlir/test/Transforms/test-inlining-callback.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -test-inline-callback | FileCheck %s
+
+// Test inlining with multiple blocks and scf.execute_region transformation
+// CHECK-LABEL: func @test_inline_multiple_blocks
+func.func @test_inline_multiple_blocks(%arg0: i32) -> i32 {
+ // CHECK: %[[RES:.*]] = scf.execute_region -> i32
+ // CHECK-NEXT: %[[ADD1:.*]] = arith.addi %arg0, %arg0
+ // CHECK-NEXT: cf.br ^bb1(%[[ADD1]] : i32)
+ // CHECK: ^bb1(%[[ARG:.*]]: i32):
+ // CHECK-NEXT: %[[ADD2:.*]] = arith.addi %[[ARG]], %[[ARG]]
+ // CHECK-NEXT: scf.yield %[[ADD2]]
+ // CHECK: return %[[RES]]
+ %fn = "test.functional_region_op"() ({
+ ^bb0(%a : i32):
+ %b = arith.addi %a, %a : i32
+ cf.br ^bb1(%b: i32)
+ ^bb1(%c: i32):
+ %d = arith.addi %c, %c : i32
+ "test.return"(%d) : (i32) -> ()
+ }) : () -> ((i32) -> i32)
+
+ %0 = call_indirect %fn(%arg0) : (i32) -> i32
+ return %0 : i32
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index c053fd4b20473..76041cd6cd791 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -29,6 +29,7 @@ add_mlir_library(MLIRTestTransforms
TestConstantFold.cpp
TestControlFlowSink.cpp
TestInlining.cpp
+ TestInliningCallback.cpp
TestMakeIsolatedFromAbove.cpp
TestTransformsOps.cpp
${MLIRTestTransformsPDLSrc}
diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp
index 223cc78dd1e21..ae904a92a5d68 100644
--- a/mlir/test/lib/Transforms/TestInlining.cpp
+++ b/mlir/test/lib/Transforms/TestInlining.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Inliner.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSet.h"
@@ -25,8 +26,9 @@ using namespace mlir;
using namespace test;
namespace {
-struct Inliner : public PassWrapper<Inliner, OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Inliner)
+struct InlinerTest
+ : public PassWrapper<InlinerTest, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InlinerTest)
StringRef getArgument() const final { return "test-inline"; }
StringRef getDescription() const final {
@@ -34,6 +36,8 @@ struct Inliner : public PassWrapper<Inliner, OperationPass<func::FuncOp>> {
}
void runOnOperation() override {
+ InlinerConfig config;
+
auto function = getOperation();
// Collect each of the direct function calls within the module.
@@ -54,8 +58,8 @@ struct Inliner : public PassWrapper<Inliner, OperationPass<func::FuncOp>> {
// Inline the functional region operation, but only clone the internal
// region if there is more than one use.
if (failed(inlineRegion(
- interface, &callee.getBody(), caller, caller.getArgOperands(),
- caller.getResults(), caller.getLoc(),
+ interface, config.getCloneCallback(), &callee.getBody(), caller,
+ caller.getArgOperands(), caller.getResults(), caller.getLoc(),
/*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse())))
continue;
@@ -71,6 +75,6 @@ struct Inliner : public PassWrapper<Inliner, OperationPass<func::FuncOp>> {
namespace mlir {
namespace test {
-void registerInliner() { PassRegistration<Inliner>(); }
+void registerInliner() { PassRegistration<InlinerTest>(); }
} // namespace test
} // namespace mlir
diff --git a/mlir/test/lib/Transforms/TestInliningCallback.cpp b/mlir/test/lib/Transforms/TestInliningCallback.cpp
new file mode 100644
index 0000000000000..012d62b7b1b42
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestInliningCallback.cpp
@@ -0,0 +1,151 @@
+//===- TestInliningCallback.cpp - Pass to inline calls in the test dialect
+//--------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This file implements a pass to test inlining callbacks including
+// canHandleMultipleBlocks and doClone.
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "TestOps.h"
+#include "mlir/Analysis/CallGraph.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Inliner.h"
+#include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/StringSet.h"
+
+using namespace mlir;
+using namespace test;
+
+namespace {
+struct InlinerCallback
+ : public PassWrapper<InlinerCallback, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InlinerCallback)
+
+ StringRef getArgument() const final { return "test-inline-callback"; }
+ StringRef getDescription() const final {
+ return "Test inlining region calls with call back functions";
+ }
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<scf::SCFDialect>();
+ }
+
+ static LogicalResult runPipelineHelper(Pass &pass, OpPassManager &pipeline,
+ Operation *op) {
+ return mlir::cast<InlinerCallback>(pass).runPipeline(pipeline, op);
+ }
+
+ // Customize the implementation of Inliner::doClone
+ // Wrap the callee into scf.execute_region operation
+ static void testDoClone(OpBuilder &builder, Region *src, Block *inlineBlock,
+ Block *postInsertBlock, IRMapping &mapper,
+ bool shouldCloneInlinedRegion) {
+ // Create a new scf.execute_region operation
+ mlir::Operation &call = inlineBlock->back();
+ builder.setInsertionPointAfter(&call);
+
+ auto executeRegionOp = builder.create<mlir::scf::ExecuteRegionOp>(
+ call.getLoc(), call.getResultTypes());
+ mlir::Region ®ion = executeRegionOp.getRegion();
+
+ // Move the inlined blocks into the region
+ src->cloneInto(®ion, mapper);
+
+ // Split block before scf operation.
+ Block *continueBlock =
+ inlineBlock->splitBlock(executeRegionOp.getOperation());
+
+ // Replace all test.return with scf.yield
+ for (mlir::Block &block : region) {
+
+ for (mlir::Operation &op : llvm::make_early_inc_range(block)) {
+ if (test::TestReturnOp returnOp =
+ llvm::dyn_cast<test::TestReturnOp>(&op)) {
+ mlir::OpBuilder returnBuilder(returnOp);
+ returnBuilder.create<mlir::scf::YieldOp>(returnOp.getLoc(),
+ returnOp.getOperands());
+ returnOp.erase();
+ }
+ }
+ }
+
+ // Add test.return after scf.execute_region
+ builder.setInsertionPointAfter(executeRegionOp);
+ builder.create<test::TestReturnOp>(executeRegionOp.getLoc(),
+ executeRegionOp.getResults());
+ }
+
+ void runOnOperation() override {
+ InlinerConfig config;
+ CallGraph &cg = getAnalysis<CallGraph>();
+
+ func::FuncOp function = getOperation();
+
+ // By default, assume that any inlining is profitable.
+ auto profitabilityCb = [&](const mlir::Inliner::ResolvedCall &call) {
+ return true;
+ };
+
+ // Set the clone callback in the config
+ config.setCloneCallback([](OpBuilder &builder, Region *src,
+ Block *inlineBlock, Block *postInsertBlock,
+ IRMapping &mapper,
+ bool shouldCloneInlinedRegion) {
+ return testDoClone(builder, src, inlineBlock, postInsertBlock, mapper,
+ shouldCloneInlinedRegion);
+ });
+
+ // Set canHandleMultipleBlocks to true in the config
+ config.setCanHandleMultipleBlocks();
+
+ // Get an instance of the inliner.
+ Inliner inliner(function, cg, *this, getAnalysisManager(),
+ runPipelineHelper, config, profitabilityCb);
+
+ // Collect each of the direct function calls within the module.
+ SmallVector<func::CallIndirectOp> callers;
+ function.walk(
+ [&](func::CallIndirectOp caller) { callers.push_back(caller); });
+
+ // Build the inliner interface.
+ InlinerInterface interface(&getContext());
+
+ // Try to inline each of the call operations.
+ for (auto caller : callers) {
+ auto callee = dyn_cast_or_null<FunctionalRegionOp>(
+ caller.getCallee().getDefiningOp());
+ if (!callee)
+ continue;
+
+ // Inline the functional region operation, but only clone the internal
+ // region if there is more than one use.
+ if (failed(inlineRegion(
+ interface, config.getCloneCallback(), &callee.getBody(), caller,
+ caller.getArgOperands(), caller.getResults(), caller.getLoc(),
+ /*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse())))
+ continue;
+
+ // If the inlining was successful then erase the call and callee if
+ // possible.
+ caller.erase();
+ if (callee.use_empty())
+ callee.erase();
+ }
+ }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerInlinerCallback() { PassRegistration<InlinerCallback>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index d06ff8070e7cf..ca4706e96787f 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -73,6 +73,7 @@ void registerCommutativityUtils();
void registerConvertCallOpPass();
void registerConvertFuncOpPass();
void registerInliner();
+void registerInlinerCallback();
void registerMemRefBoundCheck();
void registerPatternsTestPass();
void registerSimpleParametricTilingPass();
@@ -215,6 +216,7 @@ void registerTestPasses() {
mlir::test::registerConvertCallOpPass();
mlir::test::registerConvertFuncOpPass();
mlir::test::registerInliner();
+ mlir::test::registerInlinerCallback();
mlir::test::registerMemRefBoundCheck();
mlir::test::registerPatternsTestPass();
mlir::test::registerSimpleParametricTilingPass();
More information about the Mlir-commits
mailing list