[Mlir-commits] [mlir] 02c9050 - [mlir] Tighten access of RewritePattern methods.

Christian Sigg llvmlistbot at llvm.org
Tue Dec 8 07:45:03 PST 2020


Author: Christian Sigg
Date: 2020-12-08T16:44:51+01:00
New Revision: 02c9050155dff70497b3423ae95ed7d2ab7675a8

URL: https://github.com/llvm/llvm-project/commit/02c9050155dff70497b3423ae95ed7d2ab7675a8
DIFF: https://github.com/llvm/llvm-project/commit/02c9050155dff70497b3423ae95ed7d2ab7675a8.diff

LOG: [mlir] Tighten access of RewritePattern methods.

In RewritePattern, only expose `matchAndRewrite` as a public function. `match` cannot be made private and must stay protected, because we want to call it from an override of `matchAndRewrite`. `rewrite` can be private.

For classes deriving from RewritePattern, all three functions (`match`, `rewrite`, and `matchAndRewrite`) can be private.

Side note: I didn't understand the need for the `using RewritePattern::matchAndRewrite` declarations in derived classes, so I started poking around. Those `using` declarations are gone now, and I think the result is (only very slightly) cleaner.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D92670

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/DialectConversion.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index bf41f29749de..5b605c165be6 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -571,11 +571,9 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                              &typeConverter.getContext(), typeConverter,
                              benefit) {}
 
-  /// Wrappers around the RewritePattern methods that pass the derived op type.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), operands, rewriter);
-  }
+private:
+  /// Wrappers around the ConversionPattern methods that pass the derived op
+  /// type.
   LogicalResult match(Operation *op) const final {
     return match(cast<SourceOp>(op));
   }
@@ -584,6 +582,10 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                   ConversionPatternRewriter &rewriter) const final {
     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
   }
+  void rewrite(Operation *op, ArrayRef<Value> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    rewrite(cast<SourceOp>(op), operands, rewriter);
+  }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
@@ -603,10 +605,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
     }
     return failure();
   }
-
-private:
-  using ConvertToLLVMPattern::match;
-  using ConvertToLLVMPattern::matchAndRewrite;
 };
 
 namespace LLVM {
@@ -636,6 +634,7 @@ class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;
 
+private:
   /// Converts the type of the result to an LLVM type, pass operands as is,
   /// preserve attributes.
   LogicalResult
@@ -655,6 +654,7 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;
 
+private:
   LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 0bbb2216ee7b..1739cfa4a80c 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -156,17 +156,6 @@ class RewritePattern : public Pattern {
 public:
   virtual ~RewritePattern() {}
 
-  /// Rewrite the IR rooted at the specified operation with the result of
-  /// this pattern, generating any new operations with the specified
-  /// builder.  If an unexpected error is encountered (an internal
-  /// compiler error), it is emitted through the normal MLIR diagnostic
-  /// hooks and the IR is left in a valid state.
-  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
-
-  /// Attempt to match against code rooted at the specified operation,
-  /// which is the same operation code as getRootKind().
-  virtual LogicalResult match(Operation *op) const;
-
   /// Attempt to match against code rooted at the specified operation,
   /// which is the same operation code as getRootKind(). If successful, this
   /// function will automatically perform the rewrite.
@@ -183,6 +172,18 @@ class RewritePattern : public Pattern {
   /// Inherit the base constructors from `Pattern`.
   using Pattern::Pattern;
 
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind().
+  virtual LogicalResult match(Operation *op) const;
+
+private:
+  /// Rewrite the IR rooted at the specified operation with the result of
+  /// this pattern, generating any new operations with the specified
+  /// builder.  If an unexpected error is encountered (an internal
+  /// compiler error), it is emitted through the normal MLIR diagnostic
+  /// hooks and the IR is left in a valid state.
+  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
+
   /// An anchor for the virtual table.
   virtual void anchor();
 };
@@ -190,12 +191,15 @@ class RewritePattern : public Pattern {
 /// OpRewritePattern is a wrapper around RewritePattern that allows for
 /// matching and rewriting against an instance of a derived operation class as
 /// opposed to a raw Operation.
-template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
+template <typename SourceOp>
+class OpRewritePattern : public RewritePattern {
+public:
   /// Patterns must specify the root operation name they match against, and can
   /// also specify the benefit of the pattern matching.
   OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
       : RewritePattern(SourceOp::getOperationName(), benefit, context) {}
 
+private:
   /// Wrappers around the RewritePattern methods that pass the derived op type.
   void rewrite(Operation *op, PatternRewriter &rewriter) const final {
     rewrite(cast<SourceOp>(op), rewriter);

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e02cf8fe4c0a..ecbb653f7ed9 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -313,6 +313,30 @@ class TypeConverter {
 /// patterns of this type can only be used with the 'apply*' methods below.
 class ConversionPattern : public RewritePattern {
 public:
+  /// Return the type converter held by this pattern, or nullptr if the pattern
+  /// does not require type conversion.
+  TypeConverter *getTypeConverter() const { return typeConverter; }
+
+protected:
+  /// See `RewritePattern::RewritePattern` for information on the other
+  /// available constructors.
+  using RewritePattern::RewritePattern;
+  /// Construct a conversion pattern that matches an operation with the given
+  /// root name. This constructor allows for providing a type converter to use
+  /// within the pattern.
+  ConversionPattern(StringRef rootName, PatternBenefit benefit,
+                    TypeConverter &typeConverter, MLIRContext *ctx)
+      : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {}
+  /// Construct a conversion pattern that matches any operation type. This
+  /// constructor allows for providing a type converter to use within the
+  /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
+  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
+  /// always be supplied here.
+  ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter,
+                    MatchAnyOpTypeTag tag)
+      : RewritePattern(benefit, tag), typeConverter(&typeConverter) {}
+
+private:
   /// Hook for derived classes to implement rewriting. `op` is the (first)
   /// operation matched by the pattern, `operands` is a list of the rewritten
   /// operand values that are passed to `op`, `rewriter` can be used to emit the
@@ -323,6 +347,10 @@ class ConversionPattern : public RewritePattern {
     llvm_unreachable("unimplemented rewrite");
   }
 
+  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
+    llvm_unreachable("never called");
+  }
+
   /// Hook for derived classes to implement combined matching and rewriting.
   virtual LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -337,42 +365,17 @@ class ConversionPattern : public RewritePattern {
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final;
 
-  /// Return the type converter held by this pattern, or nullptr if the pattern
-  /// does not require type conversion.
-  TypeConverter *getTypeConverter() const { return typeConverter; }
-
-protected:
-  /// See `RewritePattern::RewritePattern` for information on the other
-  /// available constructors.
-  using RewritePattern::RewritePattern;
-  /// Construct a conversion pattern that matches an operation with the given
-  /// root name. This constructor allows for providing a type converter to use
-  /// within the pattern.
-  ConversionPattern(StringRef rootName, PatternBenefit benefit,
-                    TypeConverter &typeConverter, MLIRContext *ctx)
-      : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {}
-  /// Construct a conversion pattern that matches any operation type. This
-  /// constructor allows for providing a type converter to use within the
-  /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
-  /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
-  /// always be supplied here.
-  ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter,
-                    MatchAnyOpTypeTag tag)
-      : RewritePattern(benefit, tag), typeConverter(&typeConverter) {}
-
 protected:
   /// An optional type converter for use by this pattern.
   TypeConverter *typeConverter = nullptr;
-
-private:
-  using RewritePattern::rewrite;
 };
 
 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
 /// matching and rewriting against an instance of a derived operation class as
 /// opposed to a raw Operation.
 template <typename SourceOp>
-struct OpConversionPattern : public ConversionPattern {
+class OpConversionPattern : public ConversionPattern {
+public:
   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
   OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
@@ -380,6 +383,7 @@ struct OpConversionPattern : public ConversionPattern {
       : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter,
                           context) {}
 
+private:
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
   void rewrite(Operation *op, ArrayRef<Value> operands,
@@ -409,9 +413,6 @@ struct OpConversionPattern : public ConversionPattern {
     rewrite(op, operands, rewriter);
     return success();
   }
-
-private:
-  using ConversionPattern::matchAndRewrite;
 };
 
 /// Add a pattern to the given pattern list to convert the signature of a FuncOp


        


More information about the Mlir-commits mailing list