[flang-commits] [flang] [mlir] [mlir][IR] Move `match` and `rewrite` functions into separate class (PR #129861)

Matthias Springer via flang-commits flang-commits at lists.llvm.org
Wed Mar 5 02:54:19 PST 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/129861

>From 18105b5255a6864f8bf2a61846383b0bbbb158c2 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 5 Mar 2025 10:16:25 +0100
Subject: [PATCH 1/2] add new class

---
 .../mlir/Conversion/LLVMCommon/Pattern.h      |  40 +----
 mlir/include/mlir/IR/PatternMatch.h           |  86 ++++++-----
 .../mlir/Transforms/DialectConversion.h       | 141 +++++++++---------
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           |  18 ++-
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  |  10 +-
 .../Transforms/EmulateUnsupportedFloats.cpp   |   5 +-
 .../Transforms/IntRangeOptimizations.cpp      |   6 +-
 .../Transforms/VectorTransferOpTransforms.cpp |   6 +-
 mlir/lib/IR/PatternMatch.cpp                  |   9 --
 mlir/unittests/IR/PatternMatchTest.cpp        |   5 +
 10 files changed, 147 insertions(+), 179 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 86ea87b55af1c..8f82176f3b75f 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
 /// during the entire pattern lifetime.
 class ConvertToLLVMPattern : public ConversionPattern {
 public:
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
+
   ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
                        const LLVMTypeConverter &typeConverter,
                        PatternBenefit benefit = 1);
@@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
 template <typename SourceOp>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
+  using OperationT = SourceOp;
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+  using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
+      ConvertOpToLLVMPattern<SourceOp>>;
 
   explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit = 1)
@@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                              benefit) {}
 
   /// Wrappers around the RewritePattern methods that pass the derived op type.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
-  }
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                            rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override rewrite or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, adaptor, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   }
 
 private:
-  using ConvertToLLVMPattern::match;
   using ConvertToLLVMPattern::matchAndRewrite;
 };
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ab0405043a54..9055dc6ed7fc1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -234,41 +234,50 @@ class Pattern {
 // RewritePattern
 //===----------------------------------------------------------------------===//
 
-/// RewritePattern is the common base class for all DAG to DAG replacements.
-/// There are two possible usages of this class:
-///   * Multi-step RewritePattern with "match" and "rewrite"
-///     - By overloading the "match" and "rewrite" functions, the user can
-///       separate the concerns of matching and rewriting.
-///   * Single-step RewritePattern with "matchAndRewrite"
-///     - By overloading the "matchAndRewrite" function, the user can perform
-///       the rewrite in the same call as the match.
-///
-class RewritePattern : public Pattern {
-public:
-  virtual ~RewritePattern() = default;
+namespace detail {
+/// Helper class that derives from a RewritePattern class and provides separate
+/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+template <typename PatternT>
+class SplitMatchAndRewriteImpl : public PatternT {
+  using PatternT::PatternT;
 
   /// Rewrite the IR rooted at the specified operation with the result of
   /// this pattern, generating any new operations with the specified
-  /// builder.  If an unexpected error is encountered (an internal
-  /// compiler error), it is emitted through the normal MLIR diagnostic
-  /// hooks and the IR is left in a valid state.
-  virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
+  /// rewriter.
+  virtual void rewrite(typename PatternT::OperationT op,
+                       PatternRewriter &rewriter) const = 0;
 
   /// Attempt to match against code rooted at the specified operation,
   /// which is the same operation code as getRootKind().
-  virtual LogicalResult match(Operation *op) const;
+  virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
 
-  /// Attempt to match against code rooted at the specified operation,
-  /// which is the same operation code as getRootKind(). If successful, this
-  /// function will automatically perform the rewrite.
-  virtual LogicalResult matchAndRewrite(Operation *op,
-                                        PatternRewriter &rewriter) const {
+  LogicalResult matchAndRewrite(typename PatternT::OperationT op,
+                                PatternRewriter &rewriter) const final {
     if (succeeded(match(op))) {
       rewrite(op, rewriter);
       return success();
     }
     return failure();
   }
+};
+} // namespace detail
+
+/// RewritePattern is the common base class for all DAG to DAG replacements.
+/// By overriding the "matchAndRewrite" function, the user can perform the
+/// rewrite in the same call as the match.
+///
+class RewritePattern : public Pattern {
+public:
+  using OperationT = Operation *;
+  using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
+
+  virtual ~RewritePattern() = default;
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind(). If successful, this
+  /// function will automatically perform the rewrite.
+  virtual LogicalResult matchAndRewrite(Operation *op,
+                                        PatternRewriter &rewriter) const = 0;
 
   /// This method provides a convenient interface for creating and initializing
   /// derived rewrite patterns of the given type `T`.
@@ -317,36 +326,19 @@ namespace detail {
 /// class or Interface.
 template <typename SourceOp>
 struct OpOrInterfaceRewritePatternBase : public RewritePattern {
+  using OperationT = SourceOp;
   using RewritePattern::RewritePattern;
 
-  /// Wrappers around the RewritePattern methods that pass the derived op type.
-  void rewrite(Operation *op, PatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), rewriter);
-  }
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
+  /// Wrapper around the RewritePattern method that passes the derived op type.
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final {
     return matchAndRewrite(cast<SourceOp>(op), rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
-  /// overridden by the derived pattern class.
-  virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
-    llvm_unreachable("must override rewrite or matchAndRewrite");
-  }
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
+  /// Method that operates on the SourceOp type. Must be overridden by the
+  /// derived pattern class.
   virtual LogicalResult matchAndRewrite(SourceOp op,
-                                        PatternRewriter &rewriter) const {
-    if (succeeded(match(op))) {
-      rewrite(op, rewriter);
-      return success();
-    }
-    return failure();
-  }
+                                        PatternRewriter &rewriter) const = 0;
 };
 } // namespace detail
 
@@ -356,6 +348,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
 template <typename SourceOp>
 struct OpRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  using SplitMatchAndRewrite =
+      detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
+
   /// Patterns must specify the root operation name they match against, and can
   /// also specify the benefit of the pattern matching and a list of generated
   /// ops.
@@ -371,6 +366,9 @@ struct OpRewritePattern
 template <typename SourceOp>
 struct OpInterfaceRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  using SplitMatchAndRewrite =
+      detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
+
   OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
       : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
             Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9a6975dcf8dfa..d705480cb137e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -528,24 +528,72 @@ class TypeConverter {
 // Conversion Patterns
 //===----------------------------------------------------------------------===//
 
+namespace detail {
+/// Helper class that derives from a ConversionPattern class and
+/// provides separate `match` and `rewrite` entry points instead of a combined
+/// `matchAndRewrite`.
+template <typename PatternT>
+class ConversionSplitMatchAndRewriteImpl : public PatternT {
+  using PatternT::PatternT;
+
+  /// Rewrite the IR rooted at the specified operation with the result of
+  /// this pattern, generating any new operations with the specified
+  /// rewriter.
+  virtual void rewrite(typename PatternT::OperationT op,
+                       typename PatternT::OpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    // One of the two `rewrite` functions must be implemented.
+    llvm_unreachable("rewrite is not implemented");
+  }
+
+  virtual void rewrite(typename PatternT::OperationT op,
+                       typename PatternT::OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    if constexpr (std::is_same<typename PatternT::OpAdaptor,
+                               ArrayRef<Value>>::value) {
+      rewrite(op, PatternT::getOneToOneAdaptorOperands(adaptor), rewriter);
+    } else {
+      SmallVector<Value> oneToOneOperands =
+          PatternT::getOneToOneAdaptorOperands(adaptor.getOperands());
+      rewrite(op, typename PatternT::OpAdaptor(oneToOneOperands, adaptor),
+              rewriter);
+    }
+  }
+
+  /// Attempt to match against code rooted at the specified operation,
+  /// which is the same operation code as getRootKind().
+  virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
+
+  LogicalResult
+  matchAndRewrite(typename PatternT::OperationT op,
+                  typename PatternT::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    llvm_unreachable("1:1 matchAndRewrite entry point is never used");
+  }
+
+  LogicalResult
+  matchAndRewrite(typename PatternT::OperationT op,
+                  typename PatternT::OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (succeeded(match(op))) {
+      rewrite(op, adaptor, rewriter);
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace detail
+
 /// Base class for the conversion patterns. This pattern class enables type
 /// conversions, and other uses specific to the conversion framework. As such,
 /// patterns of this type can only be used with the 'apply*' methods below.
 class ConversionPattern : public RewritePattern {
 public:
-  /// Hook for derived classes to implement rewriting. `op` is the (first)
-  /// operation matched by the pattern, `operands` is a list of the rewritten
-  /// operand values that are passed to `op`, `rewriter` can be used to emit the
-  /// new operations. This function should not fail. If some specific cases of
-  /// the operation are not supported, these cases should not be matched.
-  virtual void rewrite(Operation *op, ArrayRef<Value> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("unimplemented rewrite");
-  }
-  virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
-  }
+  using OperationT = Operation *;
+  using OpAdaptor = ArrayRef<Value>;
+  using OneToNOpAdaptor = ArrayRef<ValueRange>;
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;
 
   /// Hook for derived classes to implement combined matching and rewriting.
   /// This overload supports only 1:1 replacements. The 1:N overload is called
@@ -554,10 +602,7 @@ class ConversionPattern : public RewritePattern {
   virtual LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, operands, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
 
   /// Hook for derived classes to implement combined matching and rewriting.
@@ -606,9 +651,6 @@ class ConversionPattern : public RewritePattern {
 protected:
   /// An optional type converter for use by this pattern.
   const TypeConverter *typeConverter = nullptr;
-
-private:
-  using RewritePattern::rewrite;
 };
 
 /// OpConversionPattern is a wrapper around ConversionPattern that allows for
@@ -617,9 +659,12 @@ class ConversionPattern : public RewritePattern {
 template <typename SourceOp>
 class OpConversionPattern : public ConversionPattern {
 public:
+  using OperationT = SourceOp;
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+  using SplitMatchAndRewrite =
+      detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;
 
   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -630,19 +675,6 @@ class OpConversionPattern : public ConversionPattern {
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
-  LogicalResult match(Operation *op) const final {
-    return match(cast<SourceOp>(op));
-  }
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -657,28 +689,12 @@ class OpConversionPattern : public ConversionPattern {
                            rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override matchAndRewrite or a rewrite method");
-  }
-  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
-                       ConversionPatternRewriter &rewriter) const {
-    SmallVector<Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, adaptor, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -708,14 +724,6 @@ class OpInterfaceConversionPattern : public ConversionPattern {
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), operands, rewriter);
-  }
-  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
-               ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), operands, rewriter);
-  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -727,23 +735,12 @@ class OpInterfaceConversionPattern : public ConversionPattern {
     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override matchAndRewrite or a rewrite method");
-  }
-  virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
-  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, operands, rewriter);
-    return success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index cba71740f9380..734c4839f9a10 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -41,23 +41,25 @@ struct ArithToAMDGPUConversionPass final
   void runOnOperation() override;
 };
 
-struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
-  using OpRewritePattern::OpRewritePattern;
+struct ExtFOnFloat8RewritePattern final
+    : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
+  using SplitMatchAndRewrite::SplitMatchAndRewrite;
 
   Chipset chipset;
   ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+      : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx), chipset(chipset) {}
 
   LogicalResult match(arith::ExtFOp op) const override;
   void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
 };
 
-struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
+struct TruncFToFloat8RewritePattern final
+    : OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
   bool saturateFP8 = false;
   TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                                Chipset chipset)
-      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
-        chipset(chipset) {}
+      : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx),
+        saturateFP8(saturateFP8), chipset(chipset) {}
   Chipset chipset;
 
   LogicalResult match(arith::TruncFOp op) const override;
@@ -65,9 +67,9 @@ struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
 };
 
 struct TruncfToFloat16RewritePattern final
-    : public OpRewritePattern<arith::TruncFOp> {
+    : public OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
 
-  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+  using SplitMatchAndRewrite::SplitMatchAndRewrite;
 
   LogicalResult match(arith::TruncFOp op) const override;
   void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 3646416def810..80310ce56a51b 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -363,11 +363,6 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
   using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
   using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
   using Base = LoadStoreOpLowering<Derived>;
-
-  LogicalResult match(Derived op) const override {
-    MemRefType type = op.getMemRefType();
-    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
-  }
 };
 
 /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
@@ -662,8 +657,9 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
   }
 };
 
-struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
-  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;
+struct MemRefCastOpLowering
+    : public ConvertOpToLLVMPattern<memref::CastOp>::SplitMatchAndRewrite {
+  using SplitMatchAndRewrite::SplitMatchAndRewrite;
 
   LogicalResult match(memref::CastOp memRefCastOp) const override {
     Type srcType = memRefCastOp.getOperand().getType();
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 836ebb65e7d17..f105534626082 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -40,9 +40,10 @@ struct EmulateUnsupportedFloatsPass
   void runOnOperation() override;
 };
 
-struct EmulateFloatPattern final : ConversionPattern {
+struct EmulateFloatPattern final : ConversionPattern::SplitMatchAndRewrite {
   EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
-      : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
+      : ConversionPattern::SplitMatchAndRewrite(
+            converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
 
   LogicalResult match(Operation *op) const override;
   void rewrite(Operation *op, ArrayRef<Value> operands,
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 5982f5f55549e..6da28ddeede3c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -115,9 +115,11 @@ class DataFlowListener : public RewriterBase::Listener {
 /// and replace their uses with that constant. Return success() if all results
 /// where thus replaced and the operation is erased. Also replace any block
 /// arguments with their constant values.
-struct MaterializeKnownConstantValues : public RewritePattern {
+struct MaterializeKnownConstantValues
+    : public RewritePattern::SplitMatchAndRewrite {
   MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
-      : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
+      : RewritePattern::SplitMatchAndRewrite(Pattern::MatchAnyOpTypeTag(),
+                                             /*benefit=*/1, context),
         solver(s) {}
 
   LogicalResult match(Operation *op) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index f13e54901f690..2413a4126f3f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -772,14 +772,14 @@ class FlattenContiguousRowMajorTransferWritePattern
 /// `vector.extract` and `vector.extract_element`.
 template <class VectorExtractOp>
 class RewriteScalarExtractOfTransferReadBase
-    : public OpRewritePattern<VectorExtractOp> {
-  using Base = OpRewritePattern<VectorExtractOp>;
+    : public OpRewritePattern<VectorExtractOp>::SplitMatchAndRewrite {
+  using Base = typename OpRewritePattern<VectorExtractOp>::SplitMatchAndRewrite;
 
 public:
   RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                          PatternBenefit benefit,
                                          bool allowMultipleUses)
-      : Base::OpRewritePattern(context, benefit),
+      : Base::SplitMatchAndRewrite(context, benefit),
         allowMultipleUses(allowMultipleUses) {}
 
   LogicalResult match(VectorExtractOp extractOp) const override {
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 286f47ce69136..3e3c06bebf142 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -87,15 +87,6 @@ Pattern::Pattern(const void *rootValue, RootKind rootKind,
 // RewritePattern
 //===----------------------------------------------------------------------===//
 
-void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
-  llvm_unreachable("need to implement either matchAndRewrite or one of the "
-                   "rewrite functions!");
-}
-
-LogicalResult RewritePattern::match(Operation *op) const {
-  llvm_unreachable("need to implement either match or matchAndRewrite!");
-}
-
 /// Out-of-line vtable anchor.
 void RewritePattern::anchor() {}
 
diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp
index 75d5228c82d99..1c67bfc284d32 100644
--- a/mlir/unittests/IR/PatternMatchTest.cpp
+++ b/mlir/unittests/IR/PatternMatchTest.cpp
@@ -19,6 +19,11 @@ struct AnOpRewritePattern : OpRewritePattern<test::OpA> {
   AnOpRewritePattern(MLIRContext *context)
       : OpRewritePattern(context, /*benefit=*/1,
                          /*generatedNames=*/{test::OpB::getOperationName()}) {}
+
+  LogicalResult matchAndRewrite(test::OpA op,
+                                PatternRewriter &rewriter) const override {
+    return failure(); // Stub: base matchAndRewrite is now pure virtual.
+  }
 };
 TEST(OpRewritePatternTest, GetGeneratedNames) {
   MLIRContext context;

>From 68f068543919e0bd61d0da6fabc456168a6f7de4 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 5 Mar 2025 11:54:00 +0100
Subject: [PATCH 2/2] fix flang

---
 .../flang/Optimizer/CodeGen/FIROpPatterns.h   | 36 ++-----------------
 1 file changed, 2 insertions(+), 34 deletions(-)

diff --git a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
index 35749dae5d7e9..53d16323beddf 100644
--- a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
+++ b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
@@ -187,7 +187,6 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
 
   const fir::FIRToLLVMPassOptions &options;
 
-  using ConvertToLLVMPattern::match;
   using ConvertToLLVMPattern::matchAndRewrite;
 };
 
@@ -206,20 +205,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
                                 options, benefit) {}
 
   /// Wrappers around the RewritePattern methods that pass the derived op type.
-  void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
-               mlir::ConversionPatternRewriter &rewriter) const final {
-    rewrite(mlir::cast<SourceOp>(op),
-            OpAdaptor(operands, mlir::cast<SourceOp>(op)), rewriter);
-  }
-  void rewrite(mlir::Operation *op, mlir::ArrayRef<mlir::ValueRange> operands,
-               mlir::ConversionPatternRewriter &rewriter) const final {
-    auto sourceOp = llvm::cast<SourceOp>(op);
-    rewrite(llvm::cast<SourceOp>(op), OneToNOpAdaptor(operands, sourceOp),
-            rewriter);
-  }
-  llvm::LogicalResult match(mlir::Operation *op) const final {
-    return match(mlir::cast<SourceOp>(op));
-  }
   llvm::LogicalResult
   matchAndRewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
                   mlir::ConversionPatternRewriter &rewriter) const final {
@@ -235,28 +220,12 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
     return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
                            rewriter);
   }
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
+  /// Methods that operate on the SourceOp type. One of these must be
   /// overridden by the derived pattern class.
-  virtual llvm::LogicalResult match(SourceOp op) const {
-    llvm_unreachable("must override match or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OpAdaptor adaptor,
-                       mlir::ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override rewrite or matchAndRewrite");
-  }
-  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
-                       mlir::ConversionPatternRewriter &rewriter) const {
-    llvm::SmallVector<mlir::Value> oneToOneOperands =
-        getOneToOneAdaptorOperands(adaptor.getOperands());
-    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
-  }
   virtual llvm::LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const {
-    if (mlir::failed(match(op)))
-      return mlir::failure();
-    rewrite(op, adaptor, rewriter);
-    return mlir::success();
+    llvm_unreachable("matchAndRewrite is not implemented");
   }
   virtual llvm::LogicalResult
   matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -268,7 +237,6 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
 
 private:
   using ConvertFIRToLLVMPattern::matchAndRewrite;
-  using ConvertToLLVMPattern::match;
 };
 
 /// FIR conversion pattern template



More information about the flang-commits mailing list