[Mlir-commits] [mlir] 74f0660 - [mlir][Transform] NFC - Pass TransformState as an argument to applyToOne methods
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Jun 22 01:19:23 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-22T01:19:13-07:00
New Revision: 74f066016096c60e0cee07f0af8de193ecb2f6c3
URL: https://github.com/llvm/llvm-project/commit/74f066016096c60e0cee07f0af8de193ecb2f6c3
DIFF: https://github.com/llvm/llvm-project/commit/74f066016096c60e0cee07f0af8de193ecb2f6c3.diff
LOG: [mlir][Transform] NFC - Pass TransformState as an argument to applyToOne methods
This will allow implementing state-dependent behavior in the future.
Differential Revision: https://reviews.llvm.org/D128327
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15186ecd36940..2d8a4986e09d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -33,7 +33,7 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
- ::mlir::linalg::LinalgOp target);
+ ::mlir::linalg::LinalgOp target, TransformState &state);
}];
}
@@ -74,7 +74,7 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
- ::mlir::linalg::LinalgOp target);
+ ::mlir::linalg::LinalgOp target, TransformState &state);
}];
}
@@ -96,7 +96,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
- ::mlir::linalg::LinalgOp target);
+ ::mlir::linalg::LinalgOp target, TransformState &state);
}];
}
@@ -124,7 +124,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
- ::mlir::linalg::LinalgOp target);
+ ::mlir::linalg::LinalgOp target, TransformState &state);
}];
}
@@ -149,7 +149,7 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
- ::mlir::linalg::LinalgOp target);
+ ::mlir::linalg::LinalgOp target, TransformState &state);
}];
}
@@ -218,7 +218,7 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
let extraClassDeclaration = [{
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
- ::mlir::linalg::LinalgOp target);
+ ::mlir::linalg::LinalgOp target, TransformState &state);
}];
}
@@ -275,7 +275,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
- ::mlir::FailureOr<Operation *> applyToOne(::mlir::Operation *target);
+ ::mlir::FailureOr<Operation *> applyToOne(
+ ::mlir::Operation *target, TransformState &state);
}];
}
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 47a56691b46ff..9be9fb0d3c20f 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -88,7 +88,8 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
+ ::mlir::scf::ForOp loop, TransformState &state);
}];
}
@@ -115,7 +116,8 @@ def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
- ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
+ ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
+ ::mlir::scf::ForOp loop, TransformState &state);
}];
}
@@ -137,7 +139,8 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
- ::mlir::LogicalResult applyToOne(::mlir::scf::ForOp loop);
+ ::mlir::LogicalResult applyToOne(
+ ::mlir::scf::ForOp loop, TransformState &state);
}];
}
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 7392a289135f6..390d6c6240657 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -582,9 +582,9 @@ class PossibleTopLevelTransformOpTrait
/// transformation to a single operation handle and producing one or multiple
/// operation handles.
/// The op must implement a method with one of the following signatures:
-/// - FailureOr<convertible-to-Operation*> applyToOne(OpTy)
-/// - FailureOr<SmallVector<convertible-to-Operation*>> applyToOne(OpTy)
-/// - LogicalResult applyToOne(OpTy)
+/// - FailureOr<convertible-to-Operation*> applyToOne(OpTy, state)
+/// - FailureOr<SmallVector<convertible-to-Operation*>> applyToOne(OpTy, state)
+/// - LogicalResult applyToOne(OpTy, state)
/// to perform a transformation that is applied in turn to all payload IR
/// operations that correspond to the handle of the transform IR operation.
/// In the functions above, OpTy is either Operation * or a concrete payload IR
@@ -811,7 +811,7 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
// produced.
DiagnosedSilenceableFailure result = detail::applyTransformToEach(
targets, results, [&](TransformOpType specificOp) {
- return static_cast<OpTy *>(this)->applyToOne(specificOp);
+ return static_cast<OpTy *>(this)->applyToOne(specificOp, state);
});
if (!result.succeeded())
return result;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index db7b1808c658f..f8ce4701ab74d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -76,7 +76,8 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
// DecomposeOp
//===----------------------------------------------------------------------===//
-FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
+FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target,
+ TransformState &state) {
FailureOr<LinalgOp> windowed =
tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
if (succeeded(windowed))
@@ -220,7 +221,8 @@ LogicalResult transform::FuseOp::verify() {
// GeneralizeOp
//===----------------------------------------------------------------------===//
-FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
+FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target,
+ TransformState &state) {
// Exit early if no transformation is needed.
if (isa<GenericOp>(target))
return target;
@@ -236,7 +238,8 @@ FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
// InterchangeOp
//===----------------------------------------------------------------------===//
-FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
+FailureOr<LinalgOp>
+transform::InterchangeOp::applyToOne(LinalgOp target, TransformState &state) {
SmallVector<unsigned> interchangeVector =
extractUIntArray(getIteratorInterchange());
// Exit early if no transformation is needed.
@@ -272,7 +275,8 @@ LogicalResult transform::InterchangeOp::verify() {
// PadOp
//===---------------------------------------------------------------------===//
-FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
+FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target,
+ TransformState &state) {
// Convert the integer packing flags to booleans.
SmallVector<bool> packPaddings;
for (int64_t packPadding : extractI64Array(getPackPaddings()))
@@ -377,7 +381,8 @@ LogicalResult transform::PadOp::verify() {
// ScalarizeOp
//===----------------------------------------------------------------------===//
-FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
+FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target,
+ TransformState &state) {
LinalgTilingOptions tilingOptions;
tilingOptions.scalarizeDynamicDims();
// Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
@@ -399,7 +404,8 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
//===----------------------------------------------------------------------===//
FailureOr<SmallVector<Operation *>>
-transform::SplitReductionOp::applyToOne(LinalgOp target) {
+transform::SplitReductionOp::applyToOne(LinalgOp target,
+ TransformState &state) {
ControlSplitReductionFn splitFn = [&](LinalgOp) {
return std::pair<int64_t, unsigned>(getSplitFactor(),
getInsertSplitDimension());
@@ -455,7 +461,8 @@ void TileOp::print(OpAsmPrinter &p) {
// VectorizeOp
//===----------------------------------------------------------------------===//
-FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target) {
+FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target,
+ TransformState &state) {
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
InFlightDiagnostic diag = emitOpError()
<< "applies only to isolated-from-above targets";
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index f2c80d6673574..b08bbde33e2c8 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -127,7 +127,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
// LoopPeelOp
//===----------------------------------------------------------------------===//
-FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop) {
+FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop,
+ TransformState &state) {
scf::ForOp result;
IRRewriter rewriter(loop->getContext());
LogicalResult status =
@@ -180,7 +181,8 @@ loopScheduling(scf::ForOp forOp,
}
}
-FailureOr<scf::ForOp> transform::LoopPipelineOp::applyToOne(scf::ForOp loop) {
+FailureOr<scf::ForOp>
+transform::LoopPipelineOp::applyToOne(scf::ForOp loop, TransformState &state) {
scf::PipeliningOption options;
options.getScheduleFn =
[this](scf::ForOp forOp,
@@ -203,7 +205,8 @@ FailureOr<scf::ForOp> transform::LoopPipelineOp::applyToOne(scf::ForOp loop) {
// LoopUnrollOp
//===----------------------------------------------------------------------===//
-LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop) {
+LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop,
+ TransformState &state) {
if (failed(loopUnrollByFactor(loop, getFactor())))
return reportUnknownTransformError(loop);
return success();
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 43c181651a42e..c48a7936adb12 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -227,7 +227,8 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
}
FailureOr<SmallVector<Operation *>>
-mlir::test::TestWrongNumberOfResultsOp::applyToOne(Operation *) {
+mlir::test::TestWrongNumberOfResultsOp::applyToOne(
+ Operation *, transform::TransformState &state) {
return SmallVector<Operation *>{};
}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index d811a57d3c112..1b8ddb9649c34 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -140,7 +140,7 @@ def TestWrongNumberOfResultsOp
let cppNamespace = "::mlir::test";
let extraClassDeclaration = [{
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
- ::mlir::Operation *target);
+ ::mlir::Operation *target, transform::TransformState &state);
}];
}
More information about the Mlir-commits mailing list