[Mlir-commits] [mlir] f809eb4 - [mlir] Argument and result attribute handling during inlining.

Tobias Gysi llvmlistbot at llvm.org
Wed Mar 22 01:08:40 PDT 2023


Author: Tobias Gysi
Date: 2023-03-22T09:02:15+01:00
New Revision: f809eb4db2d14a5a529f9f440b849b7489292976

URL: https://github.com/llvm/llvm-project/commit/f809eb4db2d14a5a529f9f440b849b7489292976
DIFF: https://github.com/llvm/llvm-project/commit/f809eb4db2d14a5a529f9f440b849b7489292976.diff

LOG: [mlir] Argument and result attribute handling during inlining.

The revision adds the handleArgument and handleResult handlers that
allow users of the inlining interface to implement argument and result
conversions that take argument and result attributes into account. The
motivating use cases for this revision are taken from the LLVM dialect
inliner, which has to copy arguments that are marked as byval and that
also has to consider zeroext / signext when converting integers.

All type conversions are currently handled by the
materializeCallConversion hook. It runs before isLegalToInline and
supports only the introduction of a single cast operation since it may
have to roll back. The new handlers run shortly before and after
inlining and cannot fail. As a result, they can introduce more complex
IR such as copying a struct argument. At the moment, the new hooks
cannot be used to perform type conversions since all type conversions
have to be done using the materializeCallConversion hook. A follow-up
revision will either relax this constraint or drop
materializeCallConversion in favor of the new and more flexible
handlers.

The revision also extends the CallableOpInterface to provide access
to the argument and result attributes if available.

Reviewed By: rriddle, Dinistro

Differential Revision: https://reviews.llvm.org/D145582

Added: 
    

Modified: 
    mlir/docs/Interfaces.md
    mlir/docs/Tutorials/Toy/Ch-4.md
    mlir/examples/toy/Ch4/mlir/Dialect.cpp
    mlir/examples/toy/Ch5/mlir/Dialect.cpp
    mlir/examples/toy/Ch6/mlir/Dialect.cpp
    mlir/examples/toy/Ch7/mlir/Dialect.cpp
    mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
    mlir/include/mlir/Dialect/Func/IR/FuncOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/include/mlir/Interfaces/CallInterfaces.td
    mlir/include/mlir/Transforms/InliningUtils.h
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Transforms/Utils/InliningUtils.cpp
    mlir/test/Transforms/inlining.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index 6bb507013863..b51adec4fc4f 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -731,6 +731,8 @@ interface section goes as follows:
 *   `CallableOpInterface` - Used to represent the target callee of call.
     -   `Region * getCallableRegion()`
     -   `ArrayRef<Type> getCallableResults()`
+    -   `ArrayAttr getCallableArgAttrs()`
+    -   `ArrayAttr getCallableResAttrs()`
 
 ##### RegionKindInterfaces
 

diff  --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
index 77a52163774f..f462274fa592 100644
--- a/mlir/docs/Tutorials/Toy/Ch-4.md
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -169,6 +169,18 @@ Region *FuncOp::getCallableRegion() { return &getBody(); }
 /// executed.
 ArrayRef<Type> FuncOp::getCallableResults() { return getType().getResults(); }
 
+/// Returns the argument attributes for all callable region arguments or
+/// null if there are none.
+ArrayAttr FuncOp::getCallableArgAttrs() {
+  return getArgAttrs().value_or(nullptr);
+}
+
+/// Returns the result attributes for all callable region results or
+/// null if there are none.
+ArrayAttr FuncOp::getCallableResAttrs() {
+  return getResAttrs().value_or(nullptr);
+}
+
 // ....
 
 /// Return the callee of the generic call operation, this is required by the

diff  --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
index 17a42d69c8f4..f5258eb5cff1 100644
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -307,6 +307,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
   return getFunctionType().getResults();
 }
 
+/// Returns the argument attributes for all callable region arguments or
+/// null if there are none.
+ArrayAttr FuncOp::getCallableArgAttrs() {
+  return getArgAttrs().value_or(nullptr);
+}
+
+/// Returns the result attributes for all callable region results or
+/// null if there are none.
+ArrayAttr FuncOp::getCallableResAttrs() {
+  return getResAttrs().value_or(nullptr);
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
index 77ceb636e17f..a959969c0449 100644
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -307,6 +307,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
   return getFunctionType().getResults();
 }
 
+/// Returns the argument attributes for all callable region arguments or
+/// null if there are none.
+ArrayAttr FuncOp::getCallableArgAttrs() {
+  return getArgAttrs().value_or(nullptr);
+}
+
+/// Returns the result attributes for all callable region results or
+/// null if there are none.
+ArrayAttr FuncOp::getCallableResAttrs() {
+  return getResAttrs().value_or(nullptr);
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
index 77ceb636e17f..a959969c0449 100644
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -307,6 +307,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
   return getFunctionType().getResults();
 }
 
+/// Returns the argument attributes for all callable region arguments or
+/// null if there are none.
+ArrayAttr FuncOp::getCallableArgAttrs() {
+  return getArgAttrs().value_or(nullptr);
+}
+
+/// Returns the result attributes for all callable region results or
+/// null if there are none.
+ArrayAttr FuncOp::getCallableResAttrs() {
+  return getResAttrs().value_or(nullptr);
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
index 188b94fc2dfe..d332411b63bb 100644
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -336,6 +336,18 @@ llvm::ArrayRef<mlir::Type> FuncOp::getCallableResults() {
   return getFunctionType().getResults();
 }
 
+/// Returns the argument attributes for all callable region arguments or
+/// null if there are none.
+ArrayAttr FuncOp::getCallableArgAttrs() {
+  return getArgAttrs().value_or(nullptr);
+}
+
+/// Returns the result attributes for all callable region results or
+/// null if there are none.
+ArrayAttr FuncOp::getCallableResAttrs() {
+  return getResAttrs().value_or(nullptr);
+}
+
 //===----------------------------------------------------------------------===//
 // GenericCallOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index 2cf5ee810b7a..30147b8b6a30 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -168,6 +168,18 @@ def Async_FuncOp : Async_Op<"func",
     ArrayRef<Type> getCallableResults() { return getFunctionType()
                                                     .getResults(); }
 
+    /// Returns the argument attributes for all callable region arguments or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableArgAttrs() {
+      return getArgAttrs().value_or(nullptr);
+    }
+
+    /// Returns the result attributes for all callable region results or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableResAttrs() {
+      return getResAttrs().value_or(nullptr);
+    }
+
     //===------------------------------------------------------------------===//
     // FunctionOpInterface Methods
     //===------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
index 45ec8a9e0b7e..1a06d6533b2d 100644
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -299,6 +299,18 @@ def FuncOp : Func_Op<"func", [
     /// executed.
     ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
 
+    /// Returns the argument attributes for all callable region arguments or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableArgAttrs() {
+      return getArgAttrs().value_or(nullptr);
+    }
+
+    /// Returns the result attributes for all callable region results or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableResAttrs() {
+      return getResAttrs().value_or(nullptr);
+    }
+
     //===------------------------------------------------------------------===//
     // FunctionOpInterface Methods
     //===------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index c2bb2f34a463..1bbc32f3d291 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1583,6 +1583,10 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     /// Returns the result types of this function.
     ArrayRef<Type> getResultTypes() { return getFunctionType().getReturnTypes(); }
 
+    //===------------------------------------------------------------------===//
+    // CallableOpInterface
+    //===------------------------------------------------------------------===//
+
     /// Returns the callable region, which is the function body. If the function
     /// is external, returns null.
     Region *getCallableRegion();
@@ -1596,6 +1600,17 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
       return getFunctionType().getReturnTypes();
     }
 
+    /// Returns the argument attributes for all callable region arguments or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableArgAttrs() {
+      return getArgAttrs().value_or(nullptr);
+    }
+
+    /// Returns the result attributes for all callable region results or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableResAttrs() {
+      return getResAttrs().value_or(nullptr);
+    }
   }];
 
   let hasCustomAssemblyFormat = 1;

diff  --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
index db6c7733130c..7984b9744513 100644
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
@@ -73,6 +73,18 @@ def MLProgram_FuncOp : MLProgram_Op<"func", [
     /// executed.
     ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
 
+    /// Returns the argument attributes for all callable region arguments or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableArgAttrs() {
+      return getArgAttrs().value_or(nullptr);
+    }
+
+    /// Returns the result attributes for all callable region results or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableResAttrs() {
+      return getResAttrs().value_or(nullptr);
+    }
+
     //===------------------------------------------------------------------===//
     // FunctionOpInterface Methods
     //===------------------------------------------------------------------===//
@@ -422,6 +434,18 @@ def MLProgram_SubgraphOp : MLProgram_Op<"subgraph", [
     /// executed.
     ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
 
+    /// Returns the argument attributes for all callable region arguments or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableArgAttrs() {
+      return getArgAttrs().value_or(nullptr);
+    }
+
+    /// Returns the result attributes for all callable region results or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableResAttrs() {
+      return getResAttrs().value_or(nullptr);
+    }
+
     //===------------------------------------------------------------------===//
     // FunctionOpInterface Methods
     //===------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index ae84b07acab2..47918b46dddc 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -1149,6 +1149,18 @@ def Shape_FuncOp : Shape_Op<"func",
       return getFunctionType().getResults();
     }
 
+    /// Returns the argument attributes for all callable region arguments or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableArgAttrs() {
+      return getArgAttrs().value_or(nullptr);
+    }
+
+    /// Returns the result attributes for all callable region results or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableResAttrs() {
+      return getResAttrs().value_or(nullptr);
+    }
+
     //===------------------------------------------------------------------===//
     // FunctionOpInterface Methods
     //===------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 3ffc3f71433c..46dea7454635 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -394,6 +394,12 @@ def NamedSequenceOp : TransformDialectOp<"named_sequence",
     ::llvm::ArrayRef<::mlir::Type> getCallableResults() {
       return getFunctionType().getResults();
     }
+    ::mlir::ArrayAttr getCallableArgAttrs() {
+      return getArgAttrs().value_or(nullptr);
+    }
+    ::mlir::ArrayAttr getCallableResAttrs() {
+      return getResAttrs().value_or(nullptr);
+    }
   }];
 }
 

diff  --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
index 96540675f833..cd37222cbc27 100644
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -84,6 +84,18 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
       }],
       "::llvm::ArrayRef<::mlir::Type>", "getCallableResults"
     >,
+    InterfaceMethod<[{
+        Returns the argument attributes for all callable region arguments or
+        null if there are none.
+      }],
+      "::mlir::ArrayAttr", "getCallableArgAttrs"
+    >,
+    InterfaceMethod<[{
+        Returns the result attributes for all callable region results or null
+        if there are none.
+      }],
+      "::mlir::ArrayAttr", "getCallableResAttrs"
+    >
   ];
 }
 

diff  --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
index 241983ef8c3d..63aba6a08e39 100644
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_TRANSFORMS_INLININGUTILS_H
 #define MLIR_TRANSFORMS_INLININGUTILS_H
 
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Region.h"
@@ -141,6 +142,40 @@ class DialectInlinerInterface
     return nullptr;
   }
 
+  /// Hook to transform the call arguments before using them to replace the
+  /// callee arguments. It returns the transformation result or `argument`
+  /// itself if the hook did not change anything. The type of the returned value
+  /// has to match `targetType`, and the `argumentAttrs` dictionary is non-null
+  /// even if no attribute is present. The hook is called after converting the
+  /// callsite argument types using the materializeCallConversion callback, and
+  /// right before inlining the callee region. Any operations created using the
+  /// provided `builder` are inserted right before the inlined callee region.
+  /// Example use cases are the insertion of copies for by-value arguments, or
+  /// integer conversions that require signedness information.
+  virtual Value handleArgument(OpBuilder &builder, Operation *call,
+                               Operation *callable, Value argument,
+                               Type targetType,
+                               DictionaryAttr argumentAttrs) const {
+    return argument;
+  }
+
+  /// Hook to transform the callee results before using them to replace the call
+  /// results. It returns the transformation result or the `result` itself if
+  /// the hook did not change anything. The type of the returned value has to
+  /// match `targetType`, and the `resultAttrs` dictionary is non-null even if
+  /// no attribute is present. The hook is called right before handling
+  /// terminators, and obtains the callee result before converting its type
+  /// using the `materializeCallConversion` callback. Any operations created
+  /// using the provided `builder` are inserted right after the inlined callee
+  /// region. Example use cases are the insertion of copies for by-value results
+  /// or integer conversions that require signedness information.
+  /// NOTE: This hook is invoked after inlining the `callable` region.
+  virtual Value handleResult(OpBuilder &builder, Operation *call,
+                             Operation *callable, Value result, Type targetType,
+                             DictionaryAttr resultAttrs) const {
+    return result;
+  }
+
   /// Process a set of blocks that have been inlined for a call. This callback
   /// is invoked before inlined terminator operations have been processed.
   virtual void processInlinedCallBlocks(
@@ -183,6 +218,15 @@ class InlinerInterface
   virtual void handleTerminator(Operation *op, Block *newDest) const;
   virtual void handleTerminator(Operation *op,
                                 ArrayRef<Value> valuesToRepl) const;
+
+  virtual Value handleArgument(OpBuilder &builder, Operation *call,
+                               Operation *callable, Value argument,
+                               Type targetType,
+                               DictionaryAttr argumentAttrs) const;
+  virtual Value handleResult(OpBuilder &builder, Operation *call,
+                             Operation *callable, Value result, Type targetType,
+                             DictionaryAttr resultAttrs) const;
+
   virtual void processInlinedCallBlocks(
       Operation *call, iterator_range<Region::iterator> inlinedBlocks) const;
 };

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index f6865b410709..bb3ad91ce620 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2469,6 +2469,16 @@ ArrayRef<Type> spirv::FuncOp::getCallableResults() {
   return getFunctionType().getResults();
 }
 
+// CallableOpInterface
+::mlir::ArrayAttr spirv::FuncOp::getCallableArgAttrs() {
+  return getArgAttrs().value_or(nullptr);
+}
+
+// CallableOpInterface
+::mlir::ArrayAttr spirv::FuncOp::getCallableResAttrs() {
+  return getResAttrs().value_or(nullptr);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.FunctionCall
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index f9dc69caea47..8856fd59abf9 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -103,6 +103,26 @@ void InlinerInterface::handleTerminator(Operation *op,
   handler->handleTerminator(op, valuesToRepl);
 }
 
+Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
+                                       Operation *callable, Value argument,
+                                       Type targetType,
+                                       DictionaryAttr argumentAttrs) const {
+  auto *handler = getInterfaceFor(callable);
+  assert(handler && "expected valid dialect handler");
+  return handler->handleArgument(builder, call, callable, argument, targetType,
+                                 argumentAttrs);
+}
+
+Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
+                                     Operation *callable, Value result,
+                                     Type targetType,
+                                     DictionaryAttr resultAttrs) const {
+  auto *handler = getInterfaceFor(callable);
+  assert(handler && "expected valid dialect handler");
+  return handler->handleResult(builder, call, callable, result, targetType,
+                               resultAttrs);
+}
+
 void InlinerInterface::processInlinedCallBlocks(
     Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
   auto *handler = getInterfaceFor(call);
@@ -141,6 +161,71 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
 // Inline Methods
 //===----------------------------------------------------------------------===//
 
+static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
+                               CallOpInterface call,
+                               CallableOpInterface callable,
+                               IRMapping &mapper) {
+  // Unpack the argument attributes if there are any.
+  SmallVector<DictionaryAttr> argAttrs(
+      callable.getCallableRegion()->getNumArguments(),
+      builder.getDictionaryAttr({}));
+  if (ArrayAttr arrayAttr = callable.getCallableArgAttrs()) {
+    assert(arrayAttr.size() == argAttrs.size());
+    for (auto [idx, attr] : llvm::enumerate(arrayAttr))
+      argAttrs[idx] = cast<DictionaryAttr>(attr);
+  }
+
+  // Run the argument attribute handler for the given argument and attribute.
+  for (auto [blockArg, argAttr] :
+       llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
+    Value newArgument = interface.handleArgument(builder, call, callable,
+                                                 mapper.lookup(blockArg),
+                                                 blockArg.getType(), argAttr);
+    assert(newArgument.getType() == blockArg.getType() &&
+           "expected the handled argument type to match the target type");
+
+    // Update the mapping to point to the new argument returned by the handler.
+    mapper.map(blockArg, newArgument);
+  }
+}
+
+static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
+                             CallOpInterface call, CallableOpInterface callable,
+                             ValueRange results) {
+  // Unpack the result attributes if there are any.
+  SmallVector<DictionaryAttr> resAttrs(results.size(),
+                                       builder.getDictionaryAttr({}));
+  if (ArrayAttr arrayAttr = callable.getCallableResAttrs()) {
+    assert(arrayAttr.size() == resAttrs.size());
+    for (auto [idx, attr] : llvm::enumerate(arrayAttr))
+      resAttrs[idx] = cast<DictionaryAttr>(attr);
+  }
+
+  // Run the result attribute handler for the given result and attribute.
+  for (auto [result, resAttr] : llvm::zip(results, resAttrs)) {
+    // Store the original result users before running the handler.
+    DenseSet<Operation *> resultUsers;
+    for (Operation *user : result.getUsers())
+      resultUsers.insert(user);
+
+    // TODO: Use the type of the call result to replace once the hook can be
+    // used for type conversions. At the moment, all type conversions have to be
+    // done using materializeCallConversion.
+    Type targetType = result.getType();
+
+    Value newResult = interface.handleResult(builder, call, callable, result,
+                                             targetType, resAttr);
+    assert(newResult.getType() == targetType &&
+           "expected the handled result type to match the target type");
+
+    // Replace the result uses except for the ones introduced by the handler.
+    // NOTE(review): may also rewrite uses outside the inlined body.
+    result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
+      return resultUsers.count(operand.getOwner());
+    });
+  }
+}
+
 static LogicalResult
 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                  Block::iterator inlinePoint, IRMapping &mapper,
@@ -166,6 +251,12 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                        mapper))
     return failure();
 
+  // Run the argument attribute handler before inlining the callable region.
+  OpBuilder builder(inlineBlock, inlinePoint);
+  auto callable = dyn_cast<CallableOpInterface>(src->getParentOp());
+  if (call && callable)
+    handleArgumentImpl(interface, builder, call, callable, mapper);
+
   // Check to see if the region is being cloned, or moved inline. In either
   // case, move the new blocks after the 'insertBlock' to improve IR
   // readability.
@@ -199,8 +290,14 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
 
   // Handle the case where only a single block was inlined.
   if (std::next(newBlocks.begin()) == newBlocks.end()) {
+    // Run the result attribute handler on the terminator operands.
+    Operation *firstBlockTerminator = firstNewBlock->getTerminator();
+    builder.setInsertionPoint(firstBlockTerminator);
+    if (call && callable)
+      handleResultImpl(interface, builder, call, callable,
+                       firstBlockTerminator->getOperands());
+
     // Have the interface handle the terminator of this block.
-    auto *firstBlockTerminator = firstNewBlock->getTerminator();
     interface.handleTerminator(firstBlockTerminator,
                                llvm::to_vector<6>(resultsToReplace));
     firstBlockTerminator->erase();
@@ -218,6 +315,12 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                                        resultToRepl.value().getLoc()));
     }
 
+    // Run the result attribute handler on the post insertion block arguments.
+    builder.setInsertionPointToStart(postInsertBlock);
+    if (call && callable)
+      handleResultImpl(interface, builder, call, callable,
+                       postInsertBlock->getArguments());
+
     /// Handle the terminators for each of the new blocks.
     for (auto &newBlock : newBlocks)
       interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);

diff  --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
index b102c210f056..f7eaa478cdbb 100644
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -226,3 +226,40 @@ func.func @func_with_block_args_location_callee2(%arg0 : i32) {
   call @func_with_block_args_location(%arg0) : (i32) -> ()
   return
 }
+
+// Check that we can handle argument and result attributes.
+test.conversion_func_op @handle_attr_callee_fn_multi_arg(%arg0 : i16, %arg1 : i16 {"test.handle_argument"}) -> (i16 {"test.handle_result"}, i16) {
+  %0 = arith.addi %arg0, %arg1 : i16
+  %1 = arith.subi %arg0, %arg1 : i16
+  "test.return"(%0, %1) : (i16, i16) -> ()
+}
+test.conversion_func_op @handle_attr_callee_fn(%arg0 : i32 {"test.handle_argument"}) -> (i32 {"test.handle_result"}) {
+  "test.return"(%arg0) : (i32) -> ()
+}
+
+// CHECK-LABEL: func @inline_handle_attr_call
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+func.func @inline_handle_attr_call(%arg0 : i16, %arg1 : i16) -> (i16, i16) {
+
+  // CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[ARG1]]) : (i16) -> i16
+  // CHECK: %[[SUM:.*]] = arith.addi %[[ARG0]], %[[CHANGE_INPUT]]
+  // CHECK: %[[DIFF:.*]] = arith.subi %[[ARG0]], %[[CHANGE_INPUT]]
+  // CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[SUM]]) : (i16) -> i16
+  // CHECK-NEXT: return %[[CHANGE_RESULT]], %[[DIFF]]
+  %res0, %res1 = "test.conversion_call_op"(%arg0, %arg1) { callee=@handle_attr_callee_fn_multi_arg } : (i16, i16) -> (i16, i16)
+  return %res0, %res1 : i16, i16
+}
+
+// CHECK-LABEL: func @inline_convert_and_handle_attr_call
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+func.func @inline_convert_and_handle_attr_call(%arg0 : i16) -> (i16) {
+
+  // CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[ARG0]]) : (i16) -> i32
+  // CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[CAST_INPUT]]) : (i32) -> i32
+  // CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[CHANGE_INPUT]]) : (i32) -> i32
+  // CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CHANGE_RESULT]]) : (i32) -> i16
+  // CHECK: return %[[CAST_RESULT]]
+  %res = "test.conversion_call_op"(%arg0) { callee=@handle_attr_callee_fn } : (i16) -> (i16)
+  return %res : i16
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 97c77b0eb489..36e2b9882be4 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
@@ -354,6 +355,24 @@ struct TestInlinerInterface : public DialectInlinerInterface {
     return builder.create<TestCastOp>(conversionLoc, resultType, input);
   }
 
+  Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
+                       Value argument, Type targetType,
+                       DictionaryAttr argumentAttrs) const final {
+    if (!argumentAttrs.contains("test.handle_argument"))
+      return argument;
+    return builder.create<TestTypeChangerOp>(call->getLoc(), targetType,
+                                             argument);
+  }
+
+  Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
+                     Value result, Type targetType,
+                     DictionaryAttr resultAttrs) const final {
+    if (!resultAttrs.contains("test.handle_result"))
+      return result;
+    return builder.create<TestTypeChangerOp>(call->getLoc(), targetType,
+                                             result);
+  }
+
   void processInlinedCallBlocks(
       Operation *call,
       iterator_range<Region::iterator> inlinedBlocks) const final {
@@ -650,6 +669,29 @@ LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConversionFuncOp
+//===----------------------------------------------------------------------===//
+
+ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  // Builds a plain FunctionType; the unnamed parameters are ignored.
+  auto makeFuncType = [](Builder &b, ArrayRef<Type> in, ArrayRef<Type> out,
+                         function_interface_impl::VariadicFlag, std::string &) {
+    return b.getFunctionType(in, out);
+  };
+  return function_interface_impl::parseFunctionOp(
+      parser, result, /*allowVariadic=*/false,
+      getFunctionTypeAttrName(result.name), makeFuncType,
+      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
+}
+
+void ConversionFuncOp::print(OpAsmPrinter &p) {
+  namespace fi = function_interface_impl; // Non-variadic, matching parse().
+  fi::printFunctionOp(p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
+                      getArgAttrsAttrName(), getResAttrsAttrName());
+}
+
 //===----------------------------------------------------------------------===//
 // TestFoldToCallOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3f642b8a87ea..e747d4bddfd7 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -14,6 +14,7 @@ include "TestInterfaces.td"
 include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/IR/EnumAttr.td"
+include "mlir/IR/FunctionInterfaces.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/PatternBase.td"
@@ -482,6 +483,66 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
   }];
 }
 
+def ConversionFuncOp : TEST_Op<"conversion_func_op", [CallableOpInterface,
+                                                      FunctionOpInterface]> {
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttrOf<FunctionType>:$function_type,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs,
+                       OptionalAttr<StrAttr>:$sym_visibility);
+  let regions = (region AnyRegion:$body);
+
+  let extraClassDeclaration = [{
+    //===------------------------------------------------------------------===//
+    // CallableOpInterface
+    //===------------------------------------------------------------------===//
+
+    /// Returns the body region that is callable, or null when this function
+    /// is external (has no body), as reported by isExternal() from
+    /// FunctionOpInterface.
+    ::mlir::Region *getCallableRegion() {
+      return isExternal() ? nullptr : &getBody();
+    }
+
+    /// Returns the result types that the callable region produces when
+    /// executed.
+    ::mlir::ArrayRef<::mlir::Type> getCallableResults() {
+      return getFunctionType().getResults();
+    }
+
+    /// Returns the argument attributes for all callable region arguments or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableArgAttrs() {
+      return getArgAttrs().value_or(nullptr);
+    }
+
+    /// Returns the result attributes for all callable region results or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableResAttrs() {
+      return getResAttrs().value_or(nullptr);
+    }
+
+    //===------------------------------------------------------------------===//
+    // FunctionOpInterface Methods
+    //===------------------------------------------------------------------===//
+
+    /// Returns the argument types of this function.
+    ::mlir::ArrayRef<::mlir::Type> getArgumentTypes() {
+      return getFunctionType().getInputs();
+    }
+
+    /// Returns the result types of this function.
+    ::mlir::ArrayRef<::mlir::Type> getResultTypes() {
+      return getFunctionType().getResults();
+    }
+
+    /// Returns the number of results of this function.
+    unsigned getNumResults() {return getResultTypes().size();}
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 def FunctionalRegionOp : TEST_Op<"functional_region_op",
     [CallableOpInterface]> {
   let regions = (region AnyRegion:$body);
@@ -492,6 +553,12 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
     ::llvm::ArrayRef<::mlir::Type> getCallableResults() {
       return getType().cast<::mlir::FunctionType>().getResults();
     }
+
+    /// Returns null: this op exposes no callable argument attributes.
+    ::mlir::ArrayAttr getCallableArgAttrs() { return nullptr; }
+
+    /// Returns null: this op exposes no callable result attributes.
+    ::mlir::ArrayAttr getCallableResAttrs() { return nullptr; }
   }];
 }
 


        


More information about the Mlir-commits mailing list