[Mlir-commits] [mlir] a885031 - [mlir][Transform][NFC] Use a single rewriter instead of duplicating it everywhere
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Dec 1 03:54:59 PST 2022
Author: Nicolas Vasilache
Date: 2022-12-01T03:54:31-08:00
New Revision: a8850312c106d71f9e35fd902f9dcd3c4ac0a690
URL: https://github.com/llvm/llvm-project/commit/a8850312c106d71f9e35fd902f9dcd3c4ac0a690
DIFF: https://github.com/llvm/llvm-project/commit/a8850312c106d71f9e35fd902f9dcd3c4ac0a690.diff
LOG: [mlir][Transform][NFC] Use a single rewriter instead of duplicating it everywhere
Differential Revision: https://reviews.llvm.org/D139094
Added:
mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h
Modified:
mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h b/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h
new file mode 100644
index 0000000000000..512c9151ce4cc
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformUtils.h
@@ -0,0 +1,29 @@
+//===- TransformUtils.h - Transform Dialect Utils ----------------*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H
+#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace transform {
+
+/// A PatternRewriter with no additional behavior that can be constructed
+/// from a context; used to apply patterns or rewrites to one op locally.
+class TrivialPatternRewriter : public PatternRewriter {
+public:
+ explicit TrivialPatternRewriter(MLIRContext *context)
+ : PatternRewriter(context) {}
+};
+
+} // namespace transform
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMUTILS_H
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 605c07f33ad44..eafc6d9eafabc 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -11,17 +11,10 @@
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
using namespace mlir;
-namespace {
-/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter {
-public:
- SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index a3d0261e1a033..e035a26a84c07 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -16,24 +16,13 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/Value.h"
-#include "llvm/ADT/None.h"
-#include "llvm/ADT/Optional.h"
using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
-namespace {
-/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter {
-public:
- SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
-};
-} // namespace
-
/// Check if given mapping attributes are one of the desired attributes
static DiagnosedSilenceableFailure
checkAttributeType(ArrayRef<DeviceMappingAttrInterface> threadMappingAttributes,
@@ -135,7 +124,7 @@ createGpuLaunch(RewriterBase &rewriter, Location loc,
/// Alter kernel configuration of the given kernel.
static DiagnosedSilenceableFailure
-alterGpuLaunch(SimpleRewriter &rewriter, LaunchOp gpuLaunch,
+alterGpuLaunch(TrivialPatternRewriter &rewriter, LaunchOp gpuLaunch,
TransformOpInterface transformOp,
Optional<int64_t> gridDimX = llvm::None,
Optional<int64_t> gridDimY = llvm::None,
@@ -305,7 +294,7 @@ transform::MapForeachToBlocks::applyToOne(Operation *target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
auto transformOp = cast<TransformOpInterface>(getOperation());
if (!getGenerateGpuLaunch() && !gpuLaunch) {
@@ -555,7 +544,7 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
}
MLIRContext *ctx = getContext();
- SimpleRewriter rewriter(ctx);
+ TrivialPatternRewriter rewriter(ctx);
rewriter.setInsertionPoint(target);
SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 53709a26f58ef..4bfa5b4a2e8a1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -41,14 +42,6 @@ static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
return result;
}
-namespace {
-/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter {
-public:
- SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
-};
-} // namespace
-
/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
/// function that returns the "main" result or failure. Returns failure if the
@@ -65,7 +58,7 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
// Apply the pattern directly to the op.
PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
- SimpleRewriter rewriter(operation->getContext());
+ TrivialPatternRewriter rewriter(operation->getContext());
rewriter.setInsertionPoint(operation);
auto result = pattern.returningMatchAndRewrite(op, rewriter);
if (failed(result))
@@ -125,7 +118,7 @@ static LogicalResult applyTilingToAll(
if (!tilingInterfaceOp)
return transformOp->emitError("only TilingInterface ops are supported");
- SimpleRewriter rewriter(target->getContext());
+ TrivialPatternRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
FailureOr<scf::SCFTileAndFuseResult> tiledResults =
applyFn(tilingInterfaceOp);
@@ -209,7 +202,7 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
[&](TilingInterface tilingInterfaceOp)
-> FailureOr<scf::SCFTileAndFuseResult> {
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
rewriter, tilingInterfaceOp, tileAndFuseOptions);
});
@@ -601,7 +594,7 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
results.push_back(target);
return DiagnosedSilenceableFailure(success());
}
- SimpleRewriter rewriter(target->getContext());
+ TrivialPatternRewriter rewriter(target->getContext());
FailureOr<GenericOp> res =
interchangeGenericOp(rewriter, target, interchangeVector);
if (failed(res))
@@ -875,7 +868,7 @@ transform::PromoteOp::applyToOne(linalg::LinalgOp target,
if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
- SimpleRewriter rewriter(target->getContext());
+ TrivialPatternRewriter rewriter(target->getContext());
rewriter.setInsertionPoint(target);
FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
if (failed(res))
@@ -974,7 +967,7 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
return tileSizes;
});
SmallVector<int64_t> emptyTileSizes;
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
@@ -993,7 +986,7 @@ DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
TransformState &state) {
// Collect the dynamic split points if provided.
ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
SmallVector<OpFoldResult> splitPoints;
splitPoints.reserve(payload.size());
if (getDynamicSplitPoint()) {
@@ -1122,8 +1115,7 @@ void SplitOp::print(OpAsmPrinter &printer) {
}
LogicalResult SplitOp::verify() {
- if ((static_cast<int64_t>(getStaticSplitPoint()) !=
- ShapedType::kDynamic) ^
+ if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
(getDynamicSplitPoint() == nullptr)) {
return emitOpError() << "expects either a dynamic or a static split "
"point to be provided";
@@ -1172,7 +1164,7 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
unsigned(getInsertSplitDimension()),
bool(getInnerParallel())};
};
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
FailureOr<SplitReductionResult> splitResult =
(getUseScalingAlgorithm())
@@ -1195,7 +1187,7 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
SmallVector<OpFoldResult> sizes;
@@ -1223,7 +1215,7 @@ DiagnosedSilenceableFailure
transform::TileReductionUsingForeachThreadOp::applyToOne(
linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
SmallVector<int64_t> numThreads = extractFromI64ArrayAttr(getNumThreads());
SmallVector<OpFoldResult> numThreadResults;
@@ -1321,7 +1313,7 @@ transform::TileOp::apply(TransformResults &transformResults,
}
tilingOptions.setInterchange(getInterchange());
- SimpleRewriter rewriter(linalgOp.getContext());
+ TrivialPatternRewriter rewriter(linalgOp.getContext());
FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
rewriter, cast<TilingInterface>(linalgOp.getOperation()),
tilingOptions);
@@ -1714,7 +1706,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults,
}
tilingOptions.setInterchange(getInterchange());
- SimpleRewriter rewriter(tilingInterfaceOp.getContext());
+ TrivialPatternRewriter rewriter(tilingInterfaceOp.getContext());
FailureOr<scf::SCFTilingResult> tilingResult =
tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions);
if (failed(tilingResult))
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 02c18c8d72a62..8777662ec3902 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -16,18 +16,11 @@
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
using namespace mlir;
-namespace {
-/// A simple pattern rewriter that implements no special logic.
-class SimpleRewriter : public PatternRewriter {
-public:
- SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
// GetParentForOp
//===----------------------------------------------------------------------===//
@@ -97,7 +90,7 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
for (Operation *target : state.getPayloadOps(getTarget())) {
Location location = target->getLoc();
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
if (!exec) {
DiagnosedSilenceableFailure diag = emitSilenceableError()
@@ -201,7 +194,7 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target,
getReadLatency());
};
scf::ForLoopPipeliningPattern pattern(options, target->getContext());
- SimpleRewriter rewriter(getContext());
+ TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
FailureOr<scf::ForOp> patternResult =
pattern.returningMatchAndRewrite(target, rewriter);
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 76e0c89adb7d4..98ab3f71f06a0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -33,14 +34,6 @@ using namespace mlir;
//===----------------------------------------------------------------------===//
namespace {
-/// A simple pattern rewriter that can be constructed from a context. This is
-/// necessary to apply patterns to a specific op locally.
-class TrivialPatternRewriter : public PatternRewriter {
-public:
- explicit TrivialPatternRewriter(MLIRContext *context)
- : PatternRewriter(context) {}
-};
-
/// A TransformState extension that keeps track of compiled PDL pattern sets.
/// This is intended to be used along the WithPDLPatterns op. The extension
/// can be constructed given an operation that has a SymbolTable trait and
@@ -109,7 +102,7 @@ LogicalResult PatternApplicatorExtension::findAllMatches(
}
PatternApplicator applicator(it->second);
- TrivialPatternRewriter rewriter(root->getContext());
+ transform::TrivialPatternRewriter rewriter(root->getContext());
applicator.applyDefaultCostModel();
root->walk([&](Operation *op) {
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
More information about the Mlir-commits mailing list