[Mlir-commits] [mlir] 0e760a0 - Add hook for dialect specializing processing blocks post inlining calls
Jacques Pienaar
llvmlistbot at llvm.org
Wed Jun 16 12:53:37 PDT 2021
Author: Jacques Pienaar
Date: 2021-06-16T12:53:21-07:00
New Revision: 0e760a0870e61b0a150bdea24532ad054774ade4
URL: https://github.com/llvm/llvm-project/commit/0e760a0870e61b0a150bdea24532ad054774ade4
DIFF: https://github.com/llvm/llvm-project/commit/0e760a0870e61b0a150bdea24532ad054774ade4.diff
LOG: Add hook for dialect specializing processing blocks post inlining calls
This allows dialects to perform different post-processing depending on the call operation being inlined (my use case requires different attribute propagation rules depending on the call op). This hook runs before the regular processInlinedBlocks method.
Differential Revision: https://reviews.llvm.org/D104399
Added:
Modified:
mlir/include/mlir/Transforms/InliningUtils.h
mlir/lib/Transforms/Utils/InliningUtils.cpp
mlir/test/Transforms/inlining.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
index a86a6b9cb08eb..8dcc1f5eb699d 100644
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -140,6 +140,11 @@ class DialectInlinerInterface
Location conversionLoc) const {
return nullptr;
}
+
+ /// Process blocks inlined for `call` (only called by inlineCall). Invoked
+ /// before processInlinedBlocks and before inlined terminators are processed.
+ virtual void processInlinedCallBlocks(
+ Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {}
};
/// This interface provides the hooks into the inlining interface.
@@ -178,6 +183,8 @@ class InlinerInterface
virtual void handleTerminator(Operation *op, Block *newDest) const;
virtual void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const;
+ virtual void processInlinedCallBlocks(
+ Operation *call, iterator_range<Region::iterator> inlinedBlocks) const;
};
//===----------------------------------------------------------------------===//
@@ -209,8 +216,7 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
/// providing the set of operands ('inlinedOperands') that should be used
/// in-favor of the region arguments when inlining.
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
- Operation *inlinePoint,
- ValueRange inlinedOperands,
+ Operation *inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace,
Optional<Location> inlineLoc = llvm::None,
bool shouldCloneInlinedRegion = true);
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 7d18de076e4bf..5b50d212fb075 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -106,6 +106,13 @@ void InlinerInterface::handleTerminator(Operation *op,
handler->handleTerminator(op, valuesToRepl);
}
+void InlinerInterface::processInlinedCallBlocks(
+ Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
+ auto *handler = getInterfaceFor(call);
+ assert(handler && "expected valid dialect handler");
+ handler->processInlinedCallBlocks(call, inlinedBlocks);
+}
+
/// Utility to check that all of the operations within 'src' can be inlined.
static bool isLegalToInline(InlinerInterface &interface, Region *src,
Region *insertRegion, bool shouldCloneInlinedRegion,
@@ -137,13 +144,12 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
// Inline Methods
//===----------------------------------------------------------------------===//
-LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
- Operation *inlinePoint,
- BlockAndValueMapping &mapper,
- ValueRange resultsToReplace,
- TypeRange regionResultTypes,
- Optional<Location> inlineLoc,
- bool shouldCloneInlinedRegion) {
+static LogicalResult
+inlineRegionImpl(InlinerInterface &interface, Region *src,
+ Operation *inlinePoint, BlockAndValueMapping &mapper,
+ ValueRange resultsToReplace, TypeRange regionResultTypes,
+ Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
+ Operation *call) {
assert(resultsToReplace.size() == regionResultTypes.size());
// We expect the region to have at least one block.
if (src->empty())
@@ -198,6 +204,8 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
remapInlinedOperands(newBlocks, mapper);
// Process the newly inlined blocks.
+ if (call)
+ interface.processInlinedCallBlocks(call, newBlocks);
interface.processInlinedBlocks(newBlocks);
// Handle the case where only a single block was inlined.
@@ -232,15 +240,11 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
return success();
}
-/// This function is an overload of the above 'inlineRegion' that allows for
-/// providing the set of operands ('inlinedOperands') that should be used
-/// in-favor of the region arguments when inlining.
-LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
- Operation *inlinePoint,
- ValueRange inlinedOperands,
- ValueRange resultsToReplace,
- Optional<Location> inlineLoc,
- bool shouldCloneInlinedRegion) {
+static LogicalResult
+inlineRegionImpl(InlinerInterface &interface, Region *src,
+ Operation *inlinePoint, ValueRange inlinedOperands,
+ ValueRange resultsToReplace, Optional<Location> inlineLoc,
+ bool shouldCloneInlinedRegion, Operation *call) {
// We expect the region to have at least one block.
if (src->empty())
return failure();
@@ -261,9 +265,33 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
}
// Call into the main region inliner function.
- return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
- resultsToReplace.getTypes(), inlineLoc,
- shouldCloneInlinedRegion);
+ return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace,
+ resultsToReplace.getTypes(), inlineLoc,
+ shouldCloneInlinedRegion, call);
+}
+
+LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
+ Operation *inlinePoint,
+ BlockAndValueMapping &mapper,
+ ValueRange resultsToReplace,
+ TypeRange regionResultTypes,
+ Optional<Location> inlineLoc,
+ bool shouldCloneInlinedRegion) {
+ return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace,
+ regionResultTypes, inlineLoc,
+ shouldCloneInlinedRegion,
+ /*call=*/nullptr);
+}
+
+LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
+ Operation *inlinePoint,
+ ValueRange inlinedOperands,
+ ValueRange resultsToReplace,
+ Optional<Location> inlineLoc,
+ bool shouldCloneInlinedRegion) {
+ return inlineRegionImpl(interface, src, inlinePoint, inlinedOperands,
+ resultsToReplace, inlineLoc, shouldCloneInlinedRegion,
+ /*call=*/nullptr);
}
/// Utility function used to generate a cast operation from the given interface,
@@ -371,9 +399,9 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
return cleanupState();
// Attempt to inline the call.
- if (failed(inlineRegion(interface, src, call, mapper, callResults,
- callableResultTypes, call.getLoc(),
- shouldCloneInlinedRegion)))
+ if (failed(inlineRegionImpl(interface, src, call, mapper, callResults,
+ callableResultTypes, call.getLoc(),
+ shouldCloneInlinedRegion, call)))
return cleanupState();
return success();
}
diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
index d568be0429a9c..e0368b25a2d27 100644
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -140,9 +140,9 @@ func @convert_callee_fn_multiblock() -> i32 {
// CHECK-LABEL: func @inline_convert_result_multiblock
func @inline_convert_result_multiblock() -> i16 {
-// CHECK: br ^bb1
+// CHECK: br ^bb1 {inlined_conversion}
// CHECK: ^bb1:
-// CHECK: %[[C:.+]] = constant 0 : i32
+// CHECK: %[[C:.+]] = constant {inlined_conversion} 0 : i32
// CHECK: br ^bb2(%[[C]] : i32)
// CHECK: ^bb2(%[[BBARG:.+]]: i32):
// CHECK: %[[CAST_RESULT:.+]] = "test.cast"(%[[BBARG]]) : (i32) -> i16
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a21e32a12eff6..8ef6ec6000c6a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -171,6 +171,20 @@ struct TestInlinerInterface : public DialectInlinerInterface {
return nullptr;
return builder.create<TestCastOp>(conversionLoc, resultType, input);
}
+
+ void processInlinedCallBlocks(
+ Operation *call,
+ iterator_range<Region::iterator> inlinedBlocks) const final {
+ if (!isa<ConversionCallOp>(call))
+ return; // Only blocks inlined from a ConversionCallOp are annotated.
+
+ // Set the 'inlined_conversion' unit attribute on all ops in the blocks.
+ for (Block &block : inlinedBlocks) {
+ block.walk([&](Operation *op) {
+ op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
+ });
+ }
+ }
};
struct TestReductionPatternInterface : public DialectReductionPatternInterface {
More information about the Mlir-commits
mailing list