[Mlir-commits] [mlir] 0a7a1ac - [mlir] Support FuncOpSignatureConversion for more FunctionLike ops.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 21 17:35:17 PST 2021


Author: mikeurbach
Date: 2021-01-21T18:35:09-07:00
New Revision: 0a7a1ac73d095eacd4499e889ce35191a9d1c648

URL: https://github.com/llvm/llvm-project/commit/0a7a1ac73d095eacd4499e889ce35191a9d1c648
DIFF: https://github.com/llvm/llvm-project/commit/0a7a1ac73d095eacd4499e889ce35191a9d1c648.diff

LOG: [mlir] Support FuncOpSignatureConversion for more FunctionLike ops.

This extracts the implementation of getType, setType, and getBody from
FunctionSupport.h into the mlir::impl namespace and defines them
generically in FunctionSupport.cpp. This allows them to be used
elsewhere for any FunctionLike ops that use FunctionType for their
type signature.

Using the new helpers, FuncOpSignatureConversion is generalized to
work with all such FunctionLike ops. Convenience helpers are added to
configure the pattern for a given concrete FunctionLike op type.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D95021

Added: 
    

Modified: 
    mlir/include/mlir/IR/FunctionSupport.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/IR/FunctionSupport.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index 481dc5a6986f..be8a68979203 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -80,6 +80,13 @@ void eraseFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
 void eraseFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
                           unsigned originalNumResults, Type newType);
 
+/// Get/set a FunctionLike op's FunctionType signature (stored as a TypeAttr).
+FunctionType getFunctionType(Operation *op);
+void setFunctionType(Operation *op, FunctionType newType);
+
+/// Get a FunctionLike operation's body, i.e. its first region (index 0).
+Region &getFunctionBody(Operation *op);
+
 } // namespace impl
 
 namespace OpTrait {
@@ -134,7 +141,9 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
   /// Returns true if this function is external, i.e. it has no body.
   bool isExternal() { return empty(); }
 
-  Region &getBody() { return this->getOperation()->getRegion(0); }
+  Region &getBody() {
+    return ::mlir::impl::getFunctionBody(this->getOperation());
+  }
 
   /// Delete all blocks from this function.
   void eraseBody() {
@@ -198,7 +207,7 @@ class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
   /// hide this one if the concrete class does not use FunctionType for the
   /// function type under the hood.
   FunctionType getType() {
-    return getTypeAttr().getValue().template cast<FunctionType>();
+    return ::mlir::impl::getFunctionType(this->getOperation());
   }
 
   /// Return the type of this function without the specified arguments and
@@ -542,15 +551,7 @@ Block *FunctionLike<ConcreteType>::addBlock() {
 
 template <typename ConcreteType>
 void FunctionLike<ConcreteType>::setType(FunctionType newType) {
-  SmallVector<char, 16> nameBuf;
-  auto oldType = getType();
-  auto *concreteOp = static_cast<ConcreteType *>(this);
-
-  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
-    concreteOp->removeAttr(getArgAttrName(i, nameBuf));
-  for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; i++)
-    concreteOp->removeAttr(getResultAttrName(i, nameBuf));
-  (*concreteOp)->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+  ::mlir::impl::setFunctionType(this->getOperation(), newType);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index ca28c175fbdd..ae2e2d73cf58 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -421,6 +421,21 @@ struct OpConversionPattern : public ConversionPattern {
   using ConversionPattern::matchAndRewrite;
 };
 
+/// Add a pattern to the given pattern list to convert the signature of the
+/// FunctionLike op named \p functionLikeOpName with the given type converter.
+/// Only FunctionLike ops that use FunctionType for their type are supported.
+void populateFunctionLikeTypeConversionPattern(
+    StringRef functionLikeOpName, OwningRewritePatternList &patterns,
+    MLIRContext *ctx, TypeConverter &converter);
+/// Same as above, with the op name taken from FuncOpT::getOperationName().
+template <typename FuncOpT>
+void populateFunctionLikeTypeConversionPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx,
+    TypeConverter &converter) {
+  populateFunctionLikeTypeConversionPattern(FuncOpT::getOperationName(),
+                                            patterns, ctx, converter);
+}
+
 /// Add a pattern to the given pattern list to convert the signature of a FuncOp
 /// with the given type converter.
 void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,

diff  --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
index 772a95ddd9de..347ea155d5e7 100644
--- a/mlir/lib/IR/FunctionSupport.cpp
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -99,3 +99,35 @@ void mlir::impl::eraseFunctionResults(Operation *op,
       op->removeAttr(nameAttr);
   }
 }
+
+//===----------------------------------------------------------------------===//
+// Function type signature, stored as a FunctionType inside a TypeAttr.
+//===----------------------------------------------------------------------===//
+// NOTE(review): assumes the op carries the type attr; it is not null-checked.
+FunctionType mlir::impl::getFunctionType(Operation *op) {
+  assert(op->hasTrait<OpTrait::FunctionLike>());
+  return op->getAttrOfType<TypeAttr>(mlir::impl::getTypeAttrName())
+      .getValue()
+      .cast<FunctionType>();
+}
+
+void mlir::impl::setFunctionType(Operation *op, FunctionType newType) {
+  assert(op->hasTrait<OpTrait::FunctionLike>());
+  SmallVector<char, 16> nameBuf;
+  FunctionType oldType = getFunctionType(op);
+  // Drop attrs of argument/result indices past the new signature's arity.
+  for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e; i++)
+    op->removeAttr(getArgAttrName(i, nameBuf));
+  for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e; i++)
+    op->removeAttr(getResultAttrName(i, nameBuf));
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+}
+
+//===----------------------------------------------------------------------===//
+// Function body: the first region (index 0) of the operation.
+//===----------------------------------------------------------------------===//
+
+Region &mlir::impl::getFunctionBody(Operation *op) {
+  assert(op->hasTrait<OpTrait::FunctionLike>());
+  return op->getRegion(0);
+}

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a97c461a8e9c..ae62a63b1228 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/FunctionSupport.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/SetVector.h"
@@ -2515,41 +2516,52 @@ auto TypeConverter::convertBlockSignature(Block *block)
 }
 
 /// Create a default conversion pattern that rewrites the type signature of a
-/// FuncOp.
+/// FunctionLike op. This only supports FunctionLike ops which use FunctionType
+/// to represent their type.
 namespace {
-struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
-  FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
-      : OpConversionPattern(converter, ctx) {}
+struct FunctionLikeSignatureConversion : public ConversionPattern {
+  FunctionLikeSignatureConversion(StringRef functionLikeOpName,
+                                  MLIRContext *ctx, TypeConverter &converter)
+      : ConversionPattern(functionLikeOpName, /*benefit=*/1, converter, ctx) {}
 
-  /// Hook for derived classes to implement combined matching and rewriting.
+  /// Convert the op's FunctionType and body region types, updating in place.
   LogicalResult
-  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    FunctionType type = funcOp.getType();
+    FunctionType type = mlir::impl::getFunctionType(op);
 
     // Convert the original function types.
     TypeConverter::SignatureConversion result(type.getNumInputs());
     SmallVector<Type, 1> newResults;
     if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
         failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
-        failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter,
-                                           &result)))
+        failed(rewriter.convertRegionTypes(&mlir::impl::getFunctionBody(op),
+                                           *typeConverter, &result)))
       return failure();
 
     // Update the function signature in-place.
-    rewriter.updateRootInPlace(funcOp, [&] {
-      funcOp.setType(FunctionType::get(funcOp.getContext(),
-                                       result.getConvertedTypes(), newResults));
-    });
+    auto newType = FunctionType::get(rewriter.getContext(),
+                                     result.getConvertedTypes(), newResults);
+    // NOTE(review): per-index arg/result attrs are not remapped to new args.
+    rewriter.updateRootInPlace(
+        op, [&] { mlir::impl::setFunctionType(op, newType); });
+
     return success();
   }
 };
 } // end anonymous namespace
 
+void mlir::populateFunctionLikeTypeConversionPattern(
+    StringRef functionLikeOpName, OwningRewritePatternList &patterns,
+    MLIRContext *ctx, TypeConverter &converter) {
+  patterns.insert<FunctionLikeSignatureConversion>(functionLikeOpName, ctx,
+                                                   converter);
+}
+
 void mlir::populateFuncOpTypeConversionPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx,
     TypeConverter &converter) {
-  patterns.insert<FuncOpSignatureConversion>(ctx, converter);
+  populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, ctx, converter);
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list