[Mlir-commits] [mlir] 0e760a0 - Add hook for dialect specializing processing blocks post inlining calls

Jacques Pienaar llvmlistbot at llvm.org
Wed Jun 16 12:53:37 PDT 2021


Author: Jacques Pienaar
Date: 2021-06-16T12:53:21-07:00
New Revision: 0e760a0870e61b0a150bdea24532ad054774ade4

URL: https://github.com/llvm/llvm-project/commit/0e760a0870e61b0a150bdea24532ad054774ade4
DIFF: https://github.com/llvm/llvm-project/commit/0e760a0870e61b0a150bdea24532ad054774ade4.diff

LOG: Add hook for dialect-specific processing of blocks after inlining calls

This allows dialects to perform different post-processing of inlined blocks depending on the call operation being inlined (my use case requires different attribute-propagation rules depending on the call op). This hook runs before the regular processInlinedBlocks method.

Differential Revision: https://reviews.llvm.org/D104399

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/InliningUtils.h
    mlir/lib/Transforms/Utils/InliningUtils.cpp
    mlir/test/Transforms/inlining.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
index a86a6b9cb08eb..8dcc1f5eb699d 100644
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -140,6 +140,11 @@ class DialectInlinerInterface
                                                Location conversionLoc) const {
     return nullptr;
   }
+
+  /// Process a set of blocks that have been inlined for a call. This callback
+  /// is invoked before inlined terminator operations have been processed.
+  virtual void processInlinedCallBlocks(
+      Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {}
 };
 
 /// This interface provides the hooks into the inlining interface.
@@ -178,6 +183,8 @@ class InlinerInterface
   virtual void handleTerminator(Operation *op, Block *newDest) const;
   virtual void handleTerminator(Operation *op,
                                 ArrayRef<Value> valuesToRepl) const;
+  virtual void processInlinedCallBlocks(
+      Operation *call, iterator_range<Region::iterator> inlinedBlocks) const;
 };
 
 //===----------------------------------------------------------------------===//
@@ -209,8 +216,7 @@ LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
 /// providing the set of operands ('inlinedOperands') that should be used
 /// in-favor of the region arguments when inlining.
 LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
-                           Operation *inlinePoint,
-                           ValueRange inlinedOperands,
+                           Operation *inlinePoint, ValueRange inlinedOperands,
                            ValueRange resultsToReplace,
                            Optional<Location> inlineLoc = llvm::None,
                            bool shouldCloneInlinedRegion = true);

diff  --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 7d18de076e4bf..5b50d212fb075 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -106,6 +106,13 @@ void InlinerInterface::handleTerminator(Operation *op,
   handler->handleTerminator(op, valuesToRepl);
 }
 
+void InlinerInterface::processInlinedCallBlocks(
+    Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
+  auto *handler = getInterfaceFor(call); // handler for the call op's dialect
+  assert(handler && "expected valid dialect handler");
+  handler->processInlinedCallBlocks(call, inlinedBlocks);
+}
+
 /// Utility to check that all of the operations within 'src' can be inlined.
 static bool isLegalToInline(InlinerInterface &interface, Region *src,
                             Region *insertRegion, bool shouldCloneInlinedRegion,
@@ -137,13 +144,12 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
 // Inline Methods
 //===----------------------------------------------------------------------===//
 
-LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
-                                 Operation *inlinePoint,
-                                 BlockAndValueMapping &mapper,
-                                 ValueRange resultsToReplace,
-                                 TypeRange regionResultTypes,
-                                 Optional<Location> inlineLoc,
-                                 bool shouldCloneInlinedRegion) {
+static LogicalResult
+inlineRegionImpl(InlinerInterface &interface, Region *src,
+                 Operation *inlinePoint, BlockAndValueMapping &mapper,
+                 ValueRange resultsToReplace, TypeRange regionResultTypes,
+                 Optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
+                 Operation *call) {
   assert(resultsToReplace.size() == regionResultTypes.size());
   // We expect the region to have at least one block.
   if (src->empty())
@@ -198,6 +204,8 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
     remapInlinedOperands(newBlocks, mapper);
 
   // Process the newly inlined blocks.
+  if (call)
+    interface.processInlinedCallBlocks(call, newBlocks);
   interface.processInlinedBlocks(newBlocks);
 
   // Handle the case where only a single block was inlined.
@@ -232,15 +240,11 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
   return success();
 }
 
-/// This function is an overload of the above 'inlineRegion' that allows for
-/// providing the set of operands ('inlinedOperands') that should be used
-/// in-favor of the region arguments when inlining.
-LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
-                                 Operation *inlinePoint,
-                                 ValueRange inlinedOperands,
-                                 ValueRange resultsToReplace,
-                                 Optional<Location> inlineLoc,
-                                 bool shouldCloneInlinedRegion) {
+static LogicalResult
+inlineRegionImpl(InlinerInterface &interface, Region *src,
+                 Operation *inlinePoint, ValueRange inlinedOperands,
+                 ValueRange resultsToReplace, Optional<Location> inlineLoc,
+                 bool shouldCloneInlinedRegion, Operation *call) {
   // We expect the region to have at least one block.
   if (src->empty())
     return failure();
@@ -261,9 +265,33 @@ LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
   }
 
   // Call into the main region inliner function.
-  return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
-                      resultsToReplace.getTypes(), inlineLoc,
-                      shouldCloneInlinedRegion);
+  return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace,
+                          resultsToReplace.getTypes(), inlineLoc,
+                          shouldCloneInlinedRegion, call);
+}
+
+LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
+                                 Operation *inlinePoint,
+                                 BlockAndValueMapping &mapper,
+                                 ValueRange resultsToReplace,
+                                 TypeRange regionResultTypes,
+                                 Optional<Location> inlineLoc,
+                                 bool shouldCloneInlinedRegion) {
+  return inlineRegionImpl(interface, src, inlinePoint, mapper, resultsToReplace,
+                          regionResultTypes, inlineLoc,
+                          shouldCloneInlinedRegion,
+                          /*call=*/nullptr); // no call => no call hook
+}
+
+LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
+                                 Operation *inlinePoint,
+                                 ValueRange inlinedOperands,
+                                 ValueRange resultsToReplace,
+                                 Optional<Location> inlineLoc,
+                                 bool shouldCloneInlinedRegion) {
+  return inlineRegionImpl(interface, src, inlinePoint, inlinedOperands,
+                          resultsToReplace, inlineLoc, shouldCloneInlinedRegion,
+                          /*call=*/nullptr); // no call => no call hook
+}
 
 /// Utility function used to generate a cast operation from the given interface,
@@ -371,9 +399,9 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
     return cleanupState();
 
   // Attempt to inline the call.
-  if (failed(inlineRegion(interface, src, call, mapper, callResults,
-                          callableResultTypes, call.getLoc(),
-                          shouldCloneInlinedRegion)))
+  if (failed(inlineRegionImpl(interface, src, call, mapper, callResults,
+                              callableResultTypes, call.getLoc(),
+                              shouldCloneInlinedRegion, call)))
     return cleanupState();
   return success();
 }

diff  --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
index d568be0429a9c..e0368b25a2d27 100644
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -140,9 +140,9 @@ func @convert_callee_fn_multiblock() -> i32 {
 
 // CHECK-LABEL: func @inline_convert_result_multiblock
 func @inline_convert_result_multiblock() -> i16 {
-// CHECK:   br ^bb1
+// CHECK:   br ^bb1 {inlined_conversion}
 // CHECK: ^bb1:
-// CHECK:   %[[C:.+]] = constant 0 : i32
+// CHECK:   %[[C:.+]] = constant {inlined_conversion} 0 : i32
 // CHECK:   br ^bb2(%[[C]] : i32)
 // CHECK: ^bb2(%[[BBARG:.+]]: i32):
 // CHECK:   %[[CAST_RESULT:.+]] = "test.cast"(%[[BBARG]]) : (i32) -> i16

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a21e32a12eff6..8ef6ec6000c6a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -171,6 +171,20 @@ struct TestInlinerInterface : public DialectInlinerInterface {
       return nullptr;
     return builder.create<TestCastOp>(conversionLoc, resultType, input);
   }
+
+  void processInlinedCallBlocks(
+      Operation *call,
+      iterator_range<Region::iterator> inlinedBlocks) const final {
+    if (!isa<ConversionCallOp>(call))
+      return;
+
+    // Set a unit attribute on all ops in the inlined blocks.
+    for (Block &block : inlinedBlocks) {
+      block.walk([&](Operation *op) {
+        op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
+      });
+    }
+  }
 };
 
 struct TestReductionPatternInterface : public DialectReductionPatternInterface {


        


More information about the Mlir-commits mailing list