[Mlir-commits] [mlir] 501fda0 - [mlir][Inliner] Add a new hook for checking if it is legal to inline a callable into a call

River Riddle llvmlistbot at llvm.org
Wed Oct 28 21:55:20 PDT 2020


Author: River Riddle
Date: 2020-10-28T21:49:28-07:00
New Revision: 501fda0167341f2db0da5198f70defb017a36178

URL: https://github.com/llvm/llvm-project/commit/501fda0167341f2db0da5198f70defb017a36178
DIFF: https://github.com/llvm/llvm-project/commit/501fda0167341f2db0da5198f70defb017a36178.diff

LOG: [mlir][Inliner] Add a new hook for checking if it is legal to inline a callable into a call

In certain situations it isn't legal to inline a call operation, but the current hooks provide no way (at least not an easy one) to prevent it. This revision adds a new hook, `isLegalToInline(Operation *call, Operation *callable)`, so that dialects whose call operations shouldn't be inlined can prevent it.

Differential Revision: https://reviews.llvm.org/D90359

Added: 
    

Modified: 
    mlir/docs/Tutorials/Toy/Ch-4.md
    mlir/examples/toy/Ch4/mlir/Dialect.cpp
    mlir/examples/toy/Ch5/mlir/Dialect.cpp
    mlir/examples/toy/Ch6/mlir/Dialect.cpp
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/Transforms/InliningUtils.h
    mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Transforms/Utils/InliningUtils.cpp
    mlir/test/Transforms/inlining.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
index 11e6ddf50120..058041309ef7 100644
--- a/mlir/docs/Tutorials/Toy/Ch-4.md
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -61,6 +61,13 @@ In this case, the interface is `DialectInlinerInterface`.
 struct ToyInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
 
+  /// This hook checks to see if the given callable operation is legal to inline
+  /// into the given call. For Toy this hook can simply return true, as the Toy
+  /// Call operation is always inlinable.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// This hook checks to see if the given operation is legal to inline into the
   /// given region. For Toy this hook can simply return true, as all Toy
   /// operations are inlinable.

diff  --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index ca568a55d8ea..462de2bb074e 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -34,6 +34,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {

diff  --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index d1a518ee8ed9..87bd185e0a47 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -34,6 +34,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {

diff  --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index d1a518ee8ed9..87bd185e0a47 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -34,6 +34,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {

diff  --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 046637f17eee..14d764ec71fa 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -35,6 +35,11 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within toy can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within toy can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 6de7677dbf05..d3dce868ca64 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -323,11 +323,20 @@ class Operation final
   template <typename AttrClass> AttrClass getAttrOfType(Identifier name) {
     return getAttr(name).dyn_cast_or_null<AttrClass>();
   }
-
   template <typename AttrClass> AttrClass getAttrOfType(StringRef name) {
     return getAttr(name).dyn_cast_or_null<AttrClass>();
   }
 
+  /// Return true if the operation has an attribute with the provided name,
+  /// false otherwise.
+  bool hasAttr(Identifier name) { return static_cast<bool>(getAttr(name)); }
+  bool hasAttr(StringRef name) { return static_cast<bool>(getAttr(name)); }
+  template <typename AttrClass, typename NameT>
+  bool hasAttrOfType(NameT &&name) {
+    return static_cast<bool>(
+        getAttrOfType<AttrClass>(std::forward<NameT>(name)));
+  }
+
   /// If the an attribute exists with the specified name, change it to the new
   /// value.  Otherwise, add a new attribute with the specified name/value.
   void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }

diff  --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
index a526d0f1f33a..9c4fdf255a36 100644
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -47,6 +47,14 @@ class DialectInlinerInterface
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// Returns true if the given operation 'callable', which implements the
+  /// `CallableOpInterface`, can be inlined at the position of the given call
+  /// operation 'call', which is registered to the current dialect and
+  /// implements the `CallOpInterface`.
+  virtual bool isLegalToInline(Operation *call, Operation *callable) const {
+    return false;
+  }
+
   /// Returns true if the given region 'src' can be inlined into the region
   /// 'dest' that is attached to an operation registered to the current dialect.
   /// 'valueMapping' contains any remapped values from within the 'src' region.
@@ -146,6 +154,7 @@ class InlinerInterface
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  virtual bool isLegalToInline(Operation *call, Operation *callable) const;
   virtual bool isLegalToInline(Region *dest, Region *src,
                                BlockAndValueMapping &valueMapping) const;
   virtual bool isLegalToInline(Operation *op, Region *dest,

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
index ac6d6150a826..874087681fd1 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
@@ -56,6 +56,11 @@ namespace {
 struct SPIRVInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
 
+  /// All call operations within SPIRV can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// Returns true if the given region 'src' can be inlined into the region
   /// 'dest' that is attached to an operation registered to the current dialect.
   bool isLegalToInline(Region *dest, Region *src,

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 48c3155cd105..9c8753d6d61e 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -46,6 +46,11 @@ struct StdInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  /// All call operations within standard ops can be inlined.
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    return true;
+  }
+
   /// All operations within standard ops can be inlined.
   bool isLegalToInline(Operation *, Region *,
                        BlockAndValueMapping &) const final {

diff  --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 4b7ae8024115..4e0251b2362e 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -57,6 +57,12 @@ static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks,
 // InlinerInterface
 //===----------------------------------------------------------------------===//
 
+bool InlinerInterface::isLegalToInline(Operation *call,
+                                       Operation *callable) const {
+  auto *handler = getInterfaceFor(call);
+  return handler ? handler->isLegalToInline(call, callable) : false;
+}
+
 bool InlinerInterface::isLegalToInline(
     Region *dest, Region *src, BlockAndValueMapping &valueMapping) const {
   // Regions can always be inlined into functions.
@@ -352,6 +358,10 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
     castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult);
   }
 
+  // Check that it is legal to inline the callable into the call.
+  if (!interface.isLegalToInline(call, callable))
+    return cleanupState();
+
   // Attempt to inline the call.
   if (failed(inlineRegion(interface, src, call, mapper, callResults,
                           callableResultTypes, call.getLoc(),

diff  --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
index 9c9ed7031042..54bf6c67d927 100644
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -183,3 +183,9 @@ func @inline_simplify() -> i32 {
   %res = call_indirect %fn() : () -> i32
   return %res : i32
 }
+
+// CHECK-LABEL: func @no_inline_invalid_call
+func @no_inline_invalid_call() -> i32 {
+  %res = "test.conversion_call_op"() { callee=@convert_callee_fn_multiblock, noinline } : () -> (i32)
+  return %res : i32
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index d2013d8c6941..8171367299af 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -77,6 +77,10 @@ struct TestInlinerInterface : public DialectInlinerInterface {
   // Analysis Hooks
   //===--------------------------------------------------------------------===//
 
+  bool isLegalToInline(Operation *call, Operation *callable) const final {
+    // Don't allow inlining calls that are marked `noinline`.
+    return !call->hasAttr("noinline");
+  }
   bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
     // Inlining into test dialect regions is legal.
     return true;


        


More information about the Mlir-commits mailing list