[Mlir-commits] [mlir] 1cff4cb - [mlir][Transform] NFC - Various API cleanups and use RewriterBase in lieu of PatternRewriter

Nicolas Vasilache llvmlistbot at llvm.org
Tue Mar 14 04:23:19 PDT 2023


Author: Nicolas Vasilache
Date: 2023-03-14T04:23:12-07:00
New Revision: 1cff4cbda3051df77c4b2e2bf9432cac45fbd395

URL: https://github.com/llvm/llvm-project/commit/1cff4cbda3051df77c4b2e2bf9432cac45fbd395
DIFF: https://github.com/llvm/llvm-project/commit/1cff4cbda3051df77c4b2e2bf9432cac45fbd395.diff

LOG: [mlir][Transform] NFC - Various API cleanups and use RewriterBase in lieu of PatternRewriter

Depends on: D145685

Differential Revision: https://reviews.llvm.org/D145977

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
    mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
    mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
    mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/GPU/transform-gpu-failing.mlir
    mlir/test/Dialect/GPU/transform-gpu.mlir
    mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Removed: 
    mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
index 91352b28fc7b7..463736722e229 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -33,14 +33,15 @@ class DialectRegistry;
 namespace transform {
 namespace gpu {
 
-/// Searches `scf.forall` ops nested under `target` and maps each such
-/// op to GPU threads. Mapping is one-to-one and the induction variables of
-/// `scf.forall` are rewritten to gpu.thread_id according to the
-/// thread_dim_apping attribute. Sibling `scf.forall` are supported in
-/// which case, the union of the number of threads is computed and may result in
-/// predication. Dynamic, `scf.forall` trip counts are currently not
-/// supported. Dynamic block dim sizes are currently not supported.
-DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
+/// Search `scf.forall` ops nested under `target` and map each such op to GPU
+/// threads. Mapping is one-to-one and the induction variables of `scf.forall`
+/// are rewritten to gpu.thread_id according to the thread_dim_mapping
+/// attribute.
+/// Sibling `scf.forall` are supported in which case, the union of the number of
+/// threads is computed and may result in predication.
+/// Dynamic, `scf.forall` trip counts are currently not supported.
+/// Dynamic block dim sizes are currently not supported.
+DiagnosedSilenceableFailure mapNestedForallToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
     const SmallVectorImpl<int64_t> &blockDim,
     function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
@@ -48,19 +49,19 @@ DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
     bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
 
-/// Maps the top level `scf.forall` op to GPU Thread Blocks. Mapping is
-/// one-to-one and the induction variables of `scf.forall` are rewritten
-/// to gpu.block_id according to the thread_dim_apping attribute. Dynamic,
-/// `scf.forall` trip counts are currently not supported. Dynamic block
-/// dim sizes are currently not supported.
-DiagnosedSilenceableFailure mapForeachToBlocksImpl(
+/// Map the top level `scf.forall` op to GPU Thread Blocks.
+/// Mapping is one-to-one and the induction variables of `scf.forall` are
+/// rewritten to gpu.block_id according to the thread_dim_mapping attribute.
+/// Dynamic, `scf.forall` trip counts are currently not supported.
+/// Dynamic block dim sizes are currently not supported.
+DiagnosedSilenceableFailure mapForallToBlocksImpl(
     RewriterBase &rewriter, scf::ForallOp forallOp,
     function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
         blockIdGenerator,
     SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);
 
-/// Finds the top level scf::ForallOp of given target.
+/// Find the unique top level scf::ForallOp within a given target op.
 DiagnosedSilenceableFailure
 findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp,
                      TransformOpInterface transformOp);

diff  --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
index 76292b4683380..46f0e186741e8 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td
@@ -15,8 +15,8 @@ include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
-def MapNestedForeachToThreads :
-  Op<Transform_Dialect, "gpu.map_nested_foreach_to_threads",
+def MapNestedForallToThreads :
+  Op<Transform_Dialect, "gpu.map_nested_forall_to_threads",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
      TransformEachOpTrait,
@@ -118,8 +118,8 @@ def MapNestedForeachToThreads :
 }
 
 
-def MapForeachToBlocks :
-  Op<Transform_Dialect, "gpu.map_foreach_to_blocks",
+def MapForallToBlocks :
+  Op<Transform_Dialect, "gpu.map_forall_to_blocks",
     [FunctionalStyleTransformOpTrait,
      MemoryEffectsOpInterface,
      TransformOpInterface,

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index eaf9fec6e5cf5..8e38de4844e2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -570,15 +570,15 @@ LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
 
 /// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
-FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
+FailureOr<LinalgLoops> linalgOpToLoops(RewriterBase &rewriter,
                                        LinalgOp linalgOp);
 
 /// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
-FailureOr<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
+FailureOr<LinalgLoops> linalgOpToParallelLoops(RewriterBase &rewriter,
                                                LinalgOp linalgOp);
 
 /// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
-FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
+FailureOr<LinalgLoops> linalgOpToAffineLoops(RewriterBase &rewriter,
                                              LinalgOp linalgOp);
 
 /// Creates a number of ranges equal to the number of non-zero in `tileSizes`.
@@ -818,7 +818,7 @@ struct SplitReductionResult {
   LinalgOp resultCombiningLinalgOp;
 };
 FailureOr<SplitReductionResult>
-splitReduction(PatternRewriter &b, LinalgOp op,
+splitReduction(RewriterBase &b, LinalgOp op,
                const ControlSplitReductionFn &controlSplitReductionFn,
                bool useAlloc = false);
 
@@ -870,7 +870,7 @@ splitReduction(PatternRewriter &b, LinalgOp op,
 ///  return %4 : tensor<16x32xf32>
 /// ```
 FailureOr<SplitReductionResult>
-splitReductionByScaling(PatternRewriter &b, LinalgOp op,
+splitReductionByScaling(RewriterBase &b, LinalgOp op,
                         const ControlSplitReductionFn &controlSplitReductionFn,
                         bool useAlloc = false);
 
@@ -1097,7 +1097,7 @@ struct PadOpTransformationPattern : public OpRewritePattern<tensor::PadOp> {
 };
 
 using OptimizeCopyFn =
-    std::function<LogicalResult(PatternRewriter &, tensor::PadOp, Value)>;
+    std::function<LogicalResult(RewriterBase &, tensor::PadOp, Value)>;
 
 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
 /// InsertSliceOp. For now, only constant padding values are supported.
@@ -1113,7 +1113,7 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
 
 protected:
   OptimizeCopyFn optimizeCopyFn;
-  Value createFillOrGenerateOp(PatternRewriter &rewriter, tensor::PadOp padOp,
+  Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp,
                                Value dest,
                                const SmallVector<Value> &dynSizes) const;
 };

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index 4cb5befd3053c..598e8ba6d3af1 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -31,6 +31,10 @@ namespace scf {
 /// }
 /// S1(N) S2(N-1)                // Epilogue
 /// S2(N)                        // Epilogue
+FailureOr<ForOp> pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
+                                 const PipeliningOption &options);
+
+// TODO: such patterns should be auto-generated.
 class ForLoopPipeliningPattern : public OpRewritePattern<ForOp> {
 public:
   ForLoopPipeliningPattern(const PipeliningOption &options,
@@ -42,7 +46,9 @@ class ForLoopPipeliningPattern : public OpRewritePattern<ForOp> {
   }
 
   FailureOr<ForOp> returningMatchAndRewrite(ForOp forOp,
-                                            PatternRewriter &rewriter) const;
+                                            PatternRewriter &rewriter) const {
+    return pipelineForLoop(rewriter, forOp, options);
+  }
 
 protected:
   PipeliningOption options;

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1ed850faa17a9..5e03eccfc2f3a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -18,7 +18,7 @@
 
 namespace mlir {
 class Operation;
-class PatternRewriter;
+class RewriterBase;
 class TilingInterface;
 } // namespace mlir
 
@@ -243,7 +243,7 @@ struct SCFReductionTilingResult {
 ///   : tensor<7x4xf32> -> tensor<7xf32>
 /// ```
 FailureOr<scf::SCFReductionTilingResult>
-tileReductionUsingScf(PatternRewriter &b, PartialReductionOpInterface op,
+tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
                       ArrayRef<OpFoldResult> tileSize);
 
 } // namespace scf

diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 1314d8628dc55..bfeab9da7632f 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -152,7 +152,7 @@ struct PipeliningOption {
   // peeled. This takes the original operation, an i1 predicate value and the
   // pattern rewriter.
   using PredicateOpFn =
-      std::function<Operation *(Operation *, Value, PatternRewriter &)>;
+      std::function<Operation *(RewriterBase &, Operation *, Value)>;
   PredicateOpFn predicateFn = nullptr;
 
   // TODO: add option to decide if the prologue should be peeled.

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h b/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h
deleted file mode 100644
index ca3da2b682e70..0000000000000
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h
+++ /dev/null
@@ -1,28 +0,0 @@
-//===- TransformUtils.h - Transform Dialect Utils ---------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H
-#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H
-
-#include "mlir/IR/PatternMatch.h"
-
-namespace mlir {
-namespace transform {
-
-/// A simple pattern rewriter that can be constructed from a context. This is
-/// necessary to apply patterns to a specific op locally.
-class TrivialPatternRewriter : public PatternRewriter {
-public:
-  explicit TrivialPatternRewriter(MLIRContext *context)
-      : PatternRewriter(context) {}
-};
-
-} // namespace transform
-} // namespace mlir
-
-#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H

diff  --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 19ee04d10b19a..99dfaa9dee1d2 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -16,7 +16,6 @@
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;

diff  --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 22b8f5e15dde2..bd003b844a0b5 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -16,7 +16,6 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/IR/IRMapping.h"
 
 using namespace mlir;
@@ -26,13 +25,13 @@ using namespace mlir::transform;
 /// Check if given mapping attributes are one of the desired attributes
 static DiagnosedSilenceableFailure
 checkAttributeType(ArrayRef<DeviceMappingAttrInterface> threadMappingAttributes,
-                   const std::optional<ArrayAttr> &foreachMapping,
+                   const std::optional<ArrayAttr> &forallMapping,
                    std::optional<TransformOpInterface> transformOp) {
-  if (!foreachMapping.has_value())
+  if (!forallMapping.has_value())
     return transformOp->emitSilenceableError() << "mapping must be present";
 
   DenseSet<Attribute> seen;
-  for (Attribute map : foreachMapping->getValue()) {
+  for (Attribute map : forallMapping->getValue()) {
     if (!llvm::is_contained(threadMappingAttributes, map)) {
       return transformOp->emitDefiniteFailure()
              << "mapping must be one of " << threadMappingAttributes;
@@ -124,7 +123,7 @@ createGpuLaunch(RewriterBase &rewriter, Location loc,
 
 /// Alter kernel configuration of the given kernel.
 static DiagnosedSilenceableFailure
-alterGpuLaunch(TrivialPatternRewriter &rewriter, LaunchOp gpuLaunch,
+alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch,
                TransformOpInterface transformOp,
                std::optional<int64_t> gridDimX = std::nullopt,
                std::optional<int64_t> gridDimY = std::nullopt,
@@ -165,10 +164,10 @@ alterGpuLaunch(TrivialPatternRewriter &rewriter, LaunchOp gpuLaunch,
 }
 
 //===----------------------------------------------------------------------===//
-// MapForeachToBlocks
+// MapForallToBlocks
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
+DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
     RewriterBase &rewriter, scf::ForallOp forallOp,
     function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
         blockIdGenerator,
@@ -262,7 +261,7 @@ mlir::transform::gpu::findTopLevelForallOp(Operation *target,
     if (forallOp->getParentOfType<scf::ForallOp>())
       return WalkResult::advance();
     if (topLevelForallOp)
-      // TODO: Handle multiple foreach if there is no dependences between them
+      // TODO: Handle multiple forall if they are independent.
       return WalkResult::interrupt();
     topLevelForallOp = forallOp;
     return WalkResult::advance();
@@ -274,14 +273,13 @@ mlir::transform::gpu::findTopLevelForallOp(Operation *target,
   return DiagnosedSilenceableFailure::success();
 }
 
-/// This is a helper that is only used in
-/// rewriteTopLevelForallToGpuBlocks. It generates GPU dialects
-/// block_id.
-static void generateGpuBlockIds(RewriterBase &rewriter, scf::ForallOp foreachOp,
-                                SmallVectorImpl<Value> &blockOps) {
-  Location loc = foreachOp->getLoc();
+/// This is a helper that is only used in rewriteTopLevelForallToGpuBlocks.
+/// It generates GPU dialect block_id.
+static void createGpuBlockIds(RewriterBase &rewriter, scf::ForallOp forallOp,
+                              SmallVectorImpl<Value> &blockOps) {
+  Location loc = forallOp->getLoc();
   OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(foreachOp);
+  rewriter.setInsertionPoint(forallOp);
   IndexType indexType = rewriter.getIndexType();
   blockOps = SmallVector<Value>{
       rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
@@ -290,11 +288,11 @@ static void generateGpuBlockIds(RewriterBase &rewriter, scf::ForallOp foreachOp,
 }
 
 DiagnosedSilenceableFailure
-transform::MapForeachToBlocks::applyToOne(Operation *target,
-                                          ApplyToEachResultList &results,
-                                          transform::TransformState &state) {
+transform::MapForallToBlocks::applyToOne(Operation *target,
+                                         ApplyToEachResultList &results,
+                                         transform::TransformState &state) {
   LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
-  TrivialPatternRewriter rewriter(getContext());
+  IRRewriter rewriter(getContext());
   auto transformOp = cast<TransformOpInterface>(getOperation());
 
   if (!getGenerateGpuLaunch() && !gpuLaunch) {
@@ -339,8 +337,8 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
   diag = checkAttributeType(blockMappingAttributes,
                             topLevelForallOp.getMapping(), transformOp);
   if (diag.succeeded())
-    diag = mlir::transform::gpu::mapForeachToBlocksImpl(
-        rewriter, topLevelForallOp, generateGpuBlockIds, gridDim, transformOp,
+    diag = mlir::transform::gpu::mapForallToBlocksImpl(
+        rewriter, topLevelForallOp, createGpuBlockIds, gridDim, transformOp,
         blockMappingAttributes);
   if (diag.succeeded()) {
     diag = alterGpuLaunch(rewriter, gpuLaunch,
@@ -353,7 +351,7 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
 }
 
 //===----------------------------------------------------------------------===//
-// MapNestedForeachToThreads
+// MapNestedForallToThreads
 //===----------------------------------------------------------------------===//
 
 /// Searches `scf.forall` ops nested under `target` and maps each such
@@ -497,7 +495,7 @@ static DiagnosedSilenceableFailure rewriteOneForallToGpuThreads(
   return DiagnosedSilenceableFailure::success();
 }
 
-DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
+DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
     const SmallVectorImpl<int64_t> &blockDim,
     function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
@@ -527,7 +525,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
   return diag;
 }
 
-DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
+DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
     Operation *target, ApplyToEachResultList &results, TransformState &state) {
   LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
   auto transformOp = cast<TransformOpInterface>(getOperation());
@@ -548,7 +546,7 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
   }
 
   MLIRContext *ctx = getContext();
-  TrivialPatternRewriter rewriter(ctx);
+  IRRewriter rewriter(ctx);
   rewriter.setInsertionPoint(target);
 
   SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
@@ -565,7 +563,7 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
                       rewriter.create<ThreadIdOp>(forallOp->getLoc(), indexType,
                                                   Dimension::z)});
   };
-  diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
+  diag = mlir::transform::gpu::mapNestedForallToThreadsImpl(
       rewriter, target, blockDim, threadIdGenerator, getSyncAfterDistribute(),
       transformOp, threadMappingAttributes);
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3e6e1dfdc381d..8777439f54156 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -24,7 +24,6 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
-#include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/Dialect/Transform/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineMap.h"
@@ -68,6 +67,13 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
 
   // Apply the pattern directly to the op.
   PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
+  // We want to discourage direct use of PatternRewriter in APIs but in this
+  // very specific case, an IRRewriter is not enough.
+  struct TrivialPatternRewriter : public PatternRewriter {
+  public:
+    explicit TrivialPatternRewriter(MLIRContext *context)
+        : PatternRewriter(context) {}
+  };
   TrivialPatternRewriter rewriter(operation->getContext());
   rewriter.setInsertionPoint(operation);
   auto result = pattern.returningMatchAndRewrite(op, rewriter);
@@ -293,7 +299,7 @@ static LogicalResult applyTilingToAll(
     if (!tilingInterfaceOp)
       return transformOp->emitError("only TilingInterface ops are supported");
 
-    TrivialPatternRewriter rewriter(target->getContext());
+    IRRewriter rewriter(target->getContext());
     rewriter.setInsertionPoint(target);
     FailureOr<scf::SCFTileAndFuseResult> tiledResults =
         applyFn(tilingInterfaceOp);
@@ -377,7 +383,7 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
       tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
       [&](TilingInterface tilingInterfaceOp)
           -> FailureOr<scf::SCFTileAndFuseResult> {
-        TrivialPatternRewriter rewriter(getContext());
+        IRRewriter rewriter(getContext());
         return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
             rewriter, tilingInterfaceOp, tileAndFuseOptions);
       });
@@ -761,7 +767,9 @@ transform::GeneralizeOp::applyToOne(LinalgOp target,
     results.push_back(target);
     return DiagnosedSilenceableFailure::success();
   }
-  FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
+  IRRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
   if (succeeded(generic)) {
     results.push_back(generic->getOperation());
     return DiagnosedSilenceableFailure::success();
@@ -783,7 +791,7 @@ transform::InterchangeOp::applyToOne(GenericOp target,
     results.push_back(target);
     return DiagnosedSilenceableFailure::success();
   }
-  TrivialPatternRewriter rewriter(target->getContext());
+  IRRewriter rewriter(target->getContext());
   FailureOr<GenericOp> res =
       interchangeGenericOp(rewriter, target,
                            SmallVector<unsigned>(interchangeVector.begin(),
@@ -1867,7 +1875,7 @@ transform::PromoteOp::applyToOne(LinalgOp target,
   if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
     return emitDefaultDefiniteFailure(target);
 
-  TrivialPatternRewriter rewriter(target->getContext());
+  IRRewriter rewriter(target->getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
   if (failed(res))
@@ -1966,7 +1974,7 @@ transform::ScalarizeOp::applyToOne(LinalgOp target,
     return tileSizes;
   });
   SmallVector<int64_t> emptyTileSizes;
-  TrivialPatternRewriter rewriter(getContext());
+  IRRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
       rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
@@ -2018,7 +2026,7 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
                                            TransformState &state) {
   // Collect the dynamic split points if provided.
   ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
-  TrivialPatternRewriter rewriter(getContext());
+  IRRewriter rewriter(getContext());
   SmallVector<OpFoldResult> splitPoints;
   splitPoints.reserve(payload.size());
   if (getDynamicSplitPoint()) {
@@ -2225,7 +2233,7 @@ DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
                                          unsigned(getInsertSplitDimension()),
                                          bool(getInnerParallel())};
   };
-  TrivialPatternRewriter rewriter(getContext());
+  IRRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<SplitReductionResult> splitResult =
       (getUseScalingAlgorithm())
@@ -2265,7 +2273,7 @@ void transform::TileReductionUsingScfOp::build(
 DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
     LinalgOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  TrivialPatternRewriter rewriter(getContext());
+  IRRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
       rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
@@ -2308,7 +2316,7 @@ void transform::TileReductionUsingForallOp::build(
 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
     LinalgOp target, transform::ApplyToEachResultList &results,
     transform::TransformState &state) {
-  TrivialPatternRewriter rewriter(getContext());
+  IRRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   SmallVector<OpFoldResult> numThreads =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
@@ -2495,7 +2503,7 @@ transform::TileOp::apply(TransformResults &transformResults,
     }
 
     tilingOptions.setInterchange(getInterchange());
-    TrivialPatternRewriter rewriter(op->getContext());
+    IRRewriter rewriter(op->getContext());
     FailureOr<scf::SCFTilingResult> maybeTilingResult =
         tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
     if (failed(maybeTilingResult))
@@ -2888,7 +2896,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
     }
 
     tilingOptions.setInterchange(getInterchange());
-    TrivialPatternRewriter rewriter(tilingInterfaceOp.getContext());
+    IRRewriter rewriter(tilingInterfaceOp.getContext());
     FailureOr<scf::SCFTilingResult> tilingResult =
         tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions);
     if (failed(tilingResult))

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index ae680a1f2c837..583128a418125 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -175,8 +175,8 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
 
 /// Replace the index operations in the body of the loop nest by the matching
 /// induction variables.
-static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
-                                                PatternRewriter &rewriter,
+static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
+                                                LinalgOp linalgOp,
                                                 ArrayRef<Operation *> loopOps) {
   // Extract the induction variables of the loop nest from outer to inner.
   SmallVector<Value> allIvs;
@@ -206,7 +206,7 @@ static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
 }
 
 template <typename LoopTy>
-static FailureOr<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
+static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
                                                   LinalgOp linalgOp) {
   using LoadOpTy = std::conditional_t<std::is_same<LoopTy, AffineForOp>::value,
                                       AffineLoadOp, memref::LoadOp>;
@@ -247,7 +247,7 @@ static FailureOr<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
   }
   LinalgLoops loops(loopSet.begin(), loopSet.end());
   // Replace all index operations in the loop body.
-  replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
+  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
   return loops;
 }
 
@@ -367,20 +367,19 @@ mlir::createConvertLinalgToAffineLoopsPass() {
 
 /// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops>
-mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
-                                    LinalgOp linalgOp) {
+mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
   return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
 }
 
 /// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
-FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
+FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
                                                      LinalgOp linalgOp) {
   return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
 }
 
 /// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
 FailureOr<LinalgLoops>
-mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
+mlir::linalg::linalgOpToParallelLoops(RewriterBase &rewriter,
                                       LinalgOp linalgOp) {
   return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
index 26dee47b8926a..78b46a7c34e15 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -28,7 +28,7 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
-    PatternRewriter &b, LinalgOp op,
+    RewriterBase &b, LinalgOp op,
     const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(op);
@@ -238,7 +238,7 @@ static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
 
 /// Core rewrite implementation.
 FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
-    PatternRewriter &b, LinalgOp op,
+    RewriterBase &b, LinalgOp op,
     const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(op);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2ba1562064093..17c46182eb5d1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -857,7 +857,7 @@ PadOpTransformationPattern::matchAndRewrite(tensor::PadOp padOp,
 /// Filling `dest` using FillOp constant padding value if possible.
 /// Otherwise, generate a tensor::GenerateOp.
 Value GeneralizePadOpPattern::createFillOrGenerateOp(
-    PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
+    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
     const SmallVector<Value> &dynSizes) const {
   auto padValue = padOp.getConstantPaddingValue();
   if (padValue)

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 275fab1f7a519..6b27b412f183e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -31,8 +31,8 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
-#include <type_traits>
 #include <optional>
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1535,7 +1535,7 @@ struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
   /// Vectorize the copying of a tensor::PadOp's source. This is possible if
   /// each dimension size is statically know in the source type or the result
   /// type (or both).
-  static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter,
+  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
                                         tensor::PadOp padOp, Value dest) {
     auto sourceType = padOp.getSourceType();
     auto resultType = padOp.getResultType();
@@ -2592,15 +2592,17 @@ struct Conv1DGenerator
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsMulAcc(
-            rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
+                                                  lhsVals[linearIndex(kw, w)],
+                                                  rhsVals[kw], resVals[w]);
       }
     }
 
     // Its possible we failed to create the Fma.
     if (!llvm::all_of(resVals, [](Value v) { return v; })) {
       // Manually revert (in reverse order) to avoid leaving a bad IR state.
-      for (auto &collection : {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
+      for (auto &collection :
+           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
         for (Value v : collection)
           rewriter.eraseOp(v.getDefiningOp());
       return rewriter.notifyMatchFailure(op, "failed to create FMA");

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 3dafd3d739bf5..b35e1045b2ee0 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -16,7 +16,6 @@
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
 using namespace mlir;
@@ -89,7 +88,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
   for (Operation *target : state.getPayloadOps(getTarget())) {
     Location location = target->getLoc();
     Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
-    TrivialPatternRewriter rewriter(getContext());
+    IRRewriter rewriter(getContext());
     scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
     if (!exec) {
       DiagnosedSilenceableFailure diag = emitSilenceableError()
@@ -190,10 +189,10 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target,
                        getReadLatency());
       };
   scf::ForLoopPipeliningPattern pattern(options, target->getContext());
-  TrivialPatternRewriter rewriter(getContext());
+  IRRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<scf::ForOp> patternResult =
-      pattern.returningMatchAndRewrite(target, rewriter);
+      scf::pipelineForLoop(rewriter, target, options);
   if (succeeded(patternResult)) {
     results.push_back(*patternResult);
     return DiagnosedSilenceableFailure::success();

diff  --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 4a7175b109614..b76fe151c8e61 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -62,13 +62,13 @@ struct LoopPipelinerInternal {
   bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
   /// Emits the prologue, this creates `maxStage - 1` part which will contain
   /// operations from stages [0; i], where i is the part index.
-  void emitPrologue(PatternRewriter &rewriter);
+  void emitPrologue(RewriterBase &rewriter);
   /// Gather liverange information for Values that are used in a different stage
   /// than its definition.
   llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
   scf::ForOp createKernelLoop(
       const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
-      PatternRewriter &rewriter,
+      RewriterBase &rewriter,
       llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
   /// Emits the pipelined kernel. This clones loop operations following user
   /// order and remaps operands defined in a different stage as their use.
@@ -76,10 +76,10 @@ struct LoopPipelinerInternal {
       scf::ForOp newForOp,
       const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
       const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
-      PatternRewriter &rewriter);
+      RewriterBase &rewriter);
   /// Emits the epilogue, this creates `maxStage - 1` part which will contain
   /// operations from stages [i; maxStage], where i is the part index.
-  llvm::SmallVector<Value> emitEpilogue(PatternRewriter &rewriter);
+  llvm::SmallVector<Value> emitEpilogue(RewriterBase &rewriter);
 };
 
 bool LoopPipelinerInternal::initializeLoopInfo(
@@ -172,7 +172,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
   return clone;
 }
 
-void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
+void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
   // Initialize the iteration argument to the loop initiale values.
   for (BlockArgument &arg : forOp.getRegionIterArgs()) {
     OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
@@ -244,7 +244,7 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
 scf::ForOp LoopPipelinerInternal::createKernelLoop(
     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
         &crossStageValues,
-    PatternRewriter &rewriter,
+    RewriterBase &rewriter,
     llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
   // Creates the list of initial values associated to values used across
   // stages. The initial values come from the prologue created above.
@@ -299,7 +299,7 @@ void LoopPipelinerInternal::createKernel(
     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
         &crossStageValues,
     const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
-    PatternRewriter &rewriter) {
+    RewriterBase &rewriter) {
   valueMapping.clear();
 
   // Create the kernel, we clone instruction based on the order given by
@@ -380,7 +380,7 @@ void LoopPipelinerInternal::createKernel(
     }
 
     if (predicates[useStage]) {
-      newOp = predicateFn(newOp, predicates[useStage], rewriter);
+      newOp = predicateFn(rewriter, newOp, predicates[useStage]);
       // Remap the results to the new predicated one.
       for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
         mapping.map(std::get<0>(values), std::get<1>(values));
@@ -430,7 +430,7 @@ void LoopPipelinerInternal::createKernel(
 }
 
 llvm::SmallVector<Value>
-LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
   llvm::SmallVector<Value> returnValues(forOp->getNumResults());
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
@@ -495,9 +495,8 @@ void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
 
 } // namespace
 
-FailureOr<ForOp> ForLoopPipeliningPattern::returningMatchAndRewrite(
-    ForOp forOp, PatternRewriter &rewriter) const {
-
+FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
+                                            const PipeliningOption &options) {
   LoopPipelinerInternal pipeliner;
   if (!pipeliner.initializeLoopInfo(forOp, options))
     return failure();

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 003d2d055da2c..7a802379355fe 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -398,7 +398,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
 }
 
 FailureOr<scf::SCFReductionTilingResult>
-mlir::scf::tileReductionUsingScf(PatternRewriter &b,
+mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                  PartialReductionOpInterface op,
                                  ArrayRef<OpFoldResult> tileSize) {
   Location loc = op.getLoc();

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 21b3b2dd2045d..cc4382ac1e0a1 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -11,7 +11,6 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
-#include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
@@ -114,7 +113,14 @@ LogicalResult PatternApplicatorExtension::findAllMatches(
   }
 
   PatternApplicator applicator(it->second);
-  transform::TrivialPatternRewriter rewriter(root->getContext());
+  // We want to discourage direct use of PatternRewriter in APIs but In this
+  // very specific case, an IRRewriter is not enough.
+  struct TrivialPatternRewriter : public PatternRewriter {
+  public:
+    explicit TrivialPatternRewriter(MLIRContext *context)
+        : PatternRewriter(context) {}
+  };
+  TrivialPatternRewriter rewriter(root->getContext());
   applicator.applyDefaultCostModel();
   root->walk([&](Operation *op) {
     if (succeeded(applicator.matchAndRewrite(op, rewriter)))
@@ -453,7 +459,7 @@ transform::GetConsumersOfResult::apply(transform::TransformResults &results,
 
 DiagnosedSilenceableFailure
 transform::GetDefiningOp::apply(transform::TransformResults &results,
-                              transform::TransformState &state) {
+                                transform::TransformState &state) {
   SmallVector<Operation *> definingOps;
   for (Value v : state.getPayloadValues(getTarget())) {
     if (v.isa<BlockArgument>()) {

diff  --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
index 84b8a78c461a5..45fbf7695a446 100644
--- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file  -canonicalize -cse --verify-diagnostics %s
 
-func.func @map_nested_foreach_to_threads_not_gpu_launch() -> () {
+func.func @map_nested_forall_to_threads_not_gpu_launch() -> () {
   %1 = tensor.empty() : tensor<4xf32>
   return
 }
@@ -8,12 +8,12 @@ transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["tensor.empty"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{Given target is not gpu.launch}}
-  %1 = transform.gpu.map_nested_foreach_to_threads %funcop
+  %1 = transform.gpu.map_nested_forall_to_threads %funcop
 }
 
 // -----
 
-func.func @map_nested_foreach_to_threads_excessive_threads(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
+func.func @map_nested_forall_to_threads_excessive_threads(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
   %one = arith.constant 1 : index
   %c900 = arith.constant 900 : index
   %c9 = arith.constant 9 : index
@@ -49,12 +49,12 @@ transform.sequence failures(propagate) {
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{Trying to launch a GPU kernel with gridDim = (1, 1, 1) blockDim = (1200, 9, 1). It is larger than the limits.}}
   // expected-note @below {{"blockDim" is very large}}
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [1200, 9, 1] }
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [1200, 9, 1] }
 }
 
 // -----
 
-func.func @map_nested_foreach_to_threads_fewer_threads(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
+func.func @map_nested_forall_to_threads_fewer_threads(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
   %one = arith.constant 1 : index
   %c900 = arith.constant 900 : index
   %c9 = arith.constant 9 : index
@@ -90,12 +90,12 @@ transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{The requested GPU threads are fewer than the number of loop trip counts. Try to tile scf.forall before mapping or set small blockDim.}}
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [128, 4, 1] }
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] }
 }
 
 // -----
 
-func.func @map_nested_foreach_to_threads_dynamic_trip_count(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token, %c9 : index, %c7 : index) -> memref<2 x 32 x f32> {
+func.func @map_nested_forall_to_threads_dynamic_trip_count(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token, %c9 : index, %c7 : index) -> memref<2 x 32 x f32> {
   %one = arith.constant 1 : index
   %c900 = arith.constant 900 : index
   %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
@@ -116,12 +116,12 @@ transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{unsupported dynamic blockdim size}}
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [128, 4, 1] }
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] }
 }
 
 // -----
 
-func.func @map_nested_foreach_to_threads_not_buffer(%x: tensor<32x32xf32>, %y: tensor<32x32xf32>, %z: tensor<32x32xf32>, %stream : !gpu.async.token) {
+func.func @map_nested_forall_to_threads_not_buffer(%x: tensor<32x32xf32>, %y: tensor<32x32xf32>, %z: tensor<32x32xf32>, %stream : !gpu.async.token) {
   %one = arith.constant 1 : index
   %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
             threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
@@ -135,16 +135,16 @@ func.func @map_nested_foreach_to_threads_not_buffer(%x: tensor<32x32xf32>, %y: t
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  %foreach, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
+  %forall, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{only bufferized scf.forall lowers to gpu.thread_id}}
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [128, 4, 1] }
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [128, 4, 1] }
 }
 
 // -----
 
 
-func.func @map_foreach_to_blocks_not_gpu_launch() -> () {
+func.func @map_forall_to_blocks_not_gpu_launch() -> () {
   // expected-note @below {{when applied to this payload op}}
   %1 = tensor.empty() : tensor<4xf32>
   return
@@ -153,12 +153,12 @@ transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["tensor.empty"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{Given target is not gpu.launch}}
-  %1 = transform.gpu.map_foreach_to_blocks %funcop
+  %1 = transform.gpu.map_forall_to_blocks %funcop
 }
 
 // -----
 
-func.func @map_foreach_to_blocks_not_unique(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
+func.func @map_forall_to_blocks_not_unique(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
   %one = arith.constant 1 : index
   %c900 = arith.constant 900 : index
   %c9 = arith.constant 9 : index
@@ -190,13 +190,13 @@ transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{could not find a unique topLevel scf.forall}}
-  %1 = transform.gpu.map_foreach_to_blocks %funcop
+  %1 = transform.gpu.map_forall_to_blocks %funcop
 }
 
 // -----
 
 // expected-note @below {{when applied to this payload op}}
-func.func @map_foreach_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
+func.func @map_forall_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
   %one = arith.constant 1 : index
   %c65537 = arith.constant 65536 : index
   %c9 = arith.constant 9 : index
@@ -223,12 +223,12 @@ transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{could not find a unique topLevel scf.forall}}
-  %1 = transform.gpu.map_foreach_to_blocks %funcop { generate_gpu_launch }
+  %1 = transform.gpu.map_forall_to_blocks %funcop { generate_gpu_launch }
 }
 
 // -----
 
-func.func @map_foreach_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
+func.func @map_forall_to_blocks_large_loop(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
   %one = arith.constant 1 : index
   %c65535 = arith.constant 65535 : index
   scf.forall (%i, %j) in (%c65535, %c65535) {
@@ -244,7 +244,7 @@ transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{Trying to launch a GPU kernel with gridDim = (65535, 65535, 1) blockDim = (1, 1, 1). It is larger than the limits.}}
-  %1 = transform.gpu.map_foreach_to_blocks %funcop { generate_gpu_launch }
+  %1 = transform.gpu.map_forall_to_blocks %funcop { generate_gpu_launch }
 }
 
 // -----
@@ -271,7 +271,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{#gpu.thread<x> is duplicated, cannot map different loops to the same processor}}
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [32, 32]}
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [32, 32]}
 }
 
 // -----
@@ -301,5 +301,5 @@ transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!pdl.operation) -> !pdl.operation
   // expected-error @below {{transform.structured.tile_to_forall_op failed to apply}}
-  %foreach, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
+  %forall, %tiled = transform.structured.tile_to_forall_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z> ] )
 }

diff  --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
index 9ae28a5e64fc8..107d1fe1ff7d3 100644
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -33,7 +33,7 @@ func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_foreach_to_blocks %funcop { gridDim = [12, 9]}
+  transform.gpu.map_forall_to_blocks %funcop { gridDim = [12, 9]}
 }
 
 // -----
@@ -87,7 +87,7 @@ func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !g
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] }
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9] }
 }
 
 // -----
@@ -126,8 +126,8 @@ func.func @saxpy4d(%x: !type4d, %y: !type4d, %alpha : f32) -> !type4d {
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  %gpuLaunch = transform.gpu.map_foreach_to_blocks %funcop { generate_gpu_launch }
-  transform.gpu.map_nested_foreach_to_threads %gpuLaunch { blockDim = [32, 4, 1] }
+  %gpuLaunch = transform.gpu.map_forall_to_blocks %funcop { generate_gpu_launch }
+  transform.gpu.map_nested_forall_to_threads %gpuLaunch { blockDim = [32, 4, 1] }
 }
 
 // -----
@@ -160,7 +160,7 @@ func.func @saxpy2d_no_barrier(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false }
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false }
 }
 
 // -----
@@ -192,7 +192,7 @@ func.func @saxpy2d_singleloop(%x: !type, %y: !type, %stream : !gpu.async.token)
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [32]}
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [32]}
 }
 
 // -----
@@ -228,7 +228,7 @@ func.func @saxpy3d_fold_id_z(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %s
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false }
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false }
 }
 
 // -----
@@ -267,5 +267,5 @@ func.func @map_multi_level(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %str
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] }
+  transform.gpu.map_nested_forall_to_threads %funcop { blockDim = [12, 9] }
 }

diff  --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 3fb82acc78228..f35e774c9704f 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -150,8 +150,8 @@ struct TestSCFPipeliningPass
 
   /// Helper to generate "predicated" version of `op`. For simplicity we just
   /// wrap the operation in a scf.ifOp operation.
-  static Operation *predicateOp(Operation *op, Value pred,
-                                PatternRewriter &rewriter) {
+  static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
+                                Value pred) {
     Location loc = op->getLoc();
     auto ifOp =
         rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);


        


More information about the Mlir-commits mailing list