[Mlir-commits] [mlir] 74f0660 - [mlir][Transform] NFC - Pass TransformState as an argument to applyToOne methods

Nicolas Vasilache llvmlistbot at llvm.org
Wed Jun 22 01:19:23 PDT 2022


Author: Nicolas Vasilache
Date: 2022-06-22T01:19:13-07:00
New Revision: 74f066016096c60e0cee07f0af8de193ecb2f6c3

URL: https://github.com/llvm/llvm-project/commit/74f066016096c60e0cee07f0af8de193ecb2f6c3
DIFF: https://github.com/llvm/llvm-project/commit/74f066016096c60e0cee07f0af8de193ecb2f6c3.diff

LOG: [mlir][Transform] NFC - Pass TransformState as an argument to applyToOne methods

This will allow implementing state-dependent behavior in the future.

Differential Revision: https://reviews.llvm.org/D128327

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15186ecd36940..2d8a4986e09d6 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -33,7 +33,7 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
 
   let extraClassDeclaration = [{
     ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
-        ::mlir::linalg::LinalgOp target);
+        ::mlir::linalg::LinalgOp target, TransformState &state);
   }];
 }
 
@@ -74,7 +74,7 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
 
   let extraClassDeclaration = [{
     ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
-        ::mlir::linalg::LinalgOp target);
+        ::mlir::linalg::LinalgOp target, TransformState &state);
   }];
 }
 
@@ -96,7 +96,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
 
   let extraClassDeclaration = [{
     ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
-        ::mlir::linalg::LinalgOp target);
+        ::mlir::linalg::LinalgOp target, TransformState &state);
   }];
 }
 
@@ -124,7 +124,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
 
   let extraClassDeclaration = [{
     ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
-        ::mlir::linalg::LinalgOp target);
+        ::mlir::linalg::LinalgOp target, TransformState &state);
   }];
 }
 
@@ -149,7 +149,7 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
 
   let extraClassDeclaration = [{
     ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
-        ::mlir::linalg::LinalgOp target);
+        ::mlir::linalg::LinalgOp target, TransformState &state);
   }];
 }
 
@@ -218,7 +218,7 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
 
   let extraClassDeclaration = [{
     ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
-        ::mlir::linalg::LinalgOp target);
+        ::mlir::linalg::LinalgOp target, TransformState &state);
   }];
 }
 
@@ -275,7 +275,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   let assemblyFormat = "$target attr-dict";
 
   let extraClassDeclaration = [{
-    ::mlir::FailureOr<Operation *> applyToOne(::mlir::Operation *target);
+    ::mlir::FailureOr<Operation *> applyToOne(
+      ::mlir::Operation *target, TransformState &state);
   }];
 }
 

diff  --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 47a56691b46ff..9be9fb0d3c20f 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -88,7 +88,8 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
   let assemblyFormat = "$target attr-dict";
 
   let extraClassDeclaration = [{
-    ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
+    ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
+      ::mlir::scf::ForOp loop, TransformState &state);
   }];
 }
  
@@ -115,7 +116,8 @@ def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
   let assemblyFormat = "$target attr-dict";
 
   let extraClassDeclaration = [{
-    ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
+    ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
+      ::mlir::scf::ForOp loop, TransformState &state);
   }];
 }
 
@@ -137,7 +139,8 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
   let assemblyFormat = "$target attr-dict";
 
   let extraClassDeclaration = [{
-    ::mlir::LogicalResult applyToOne(::mlir::scf::ForOp loop);
+    ::mlir::LogicalResult applyToOne(
+      ::mlir::scf::ForOp loop, TransformState &state);
   }];
 }
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 7392a289135f6..390d6c6240657 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -582,9 +582,9 @@ class PossibleTopLevelTransformOpTrait
 /// transformation to a single operation handle and producing one or multiple
 /// operation handles.
 /// The op must implement a method with one of the following signatures:
-///   - FailureOr<convertible-to-Operation*> applyToOne(OpTy)
-///   - FailureOr<SmallVector<convertible-to-Operation*>> applyToOne(OpTy)
-///   - LogicalResult applyToOne(OpTy)
+///   - FailureOr<convertible-to-Operation*> applyToOne(OpTy, state)
+///   - FailureOr<SmallVector<convertible-to-Operation*>>applyToOne(OpTy, state)
+///   - LogicalResult applyToOne(OpTy, state)
 /// to perform a transformation that is applied in turn to all payload IR
 /// operations that correspond to the handle of the transform IR operation.
 /// In the functions above, OpTy is either Operation * or a concrete payload IR
@@ -811,7 +811,7 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
   // produced.
   DiagnosedSilenceableFailure result = detail::applyTransformToEach(
       targets, results, [&](TransformOpType specificOp) {
-        return static_cast<OpTy *>(this)->applyToOne(specificOp);
+        return static_cast<OpTy *>(this)->applyToOne(specificOp, state);
       });
   if (!result.succeeded())
     return result;

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index db7b1808c658f..f8ce4701ab74d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -76,7 +76,8 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
 // DecomposeOp
 //===----------------------------------------------------------------------===//
 
-FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
+FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target,
+                                                       TransformState &state) {
   FailureOr<LinalgOp> windowed =
       tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
   if (succeeded(windowed))
@@ -220,7 +221,8 @@ LogicalResult transform::FuseOp::verify() {
 // GeneralizeOp
 //===----------------------------------------------------------------------===//
 
-FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
+FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target,
+                                                        TransformState &state) {
   // Exit early if no transformation is needed.
   if (isa<GenericOp>(target))
     return target;
@@ -236,7 +238,8 @@ FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
 // InterchangeOp
 //===----------------------------------------------------------------------===//
 
-FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
+FailureOr<LinalgOp>
+transform::InterchangeOp::applyToOne(LinalgOp target, TransformState &state) {
   SmallVector<unsigned> interchangeVector =
       extractUIntArray(getIteratorInterchange());
   // Exit early if no transformation is needed.
@@ -272,7 +275,8 @@ LogicalResult transform::InterchangeOp::verify() {
 // PadOp
 //===---------------------------------------------------------------------===//
 
-FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
+FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target,
+                                                 TransformState &state) {
   // Convert the integer packing flags to booleans.
   SmallVector<bool> packPaddings;
   for (int64_t packPadding : extractI64Array(getPackPaddings()))
@@ -377,7 +381,8 @@ LogicalResult transform::PadOp::verify() {
 // ScalarizeOp
 //===----------------------------------------------------------------------===//
 
-FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
+FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target,
+                                                       TransformState &state) {
   LinalgTilingOptions tilingOptions;
   tilingOptions.scalarizeDynamicDims();
   // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
@@ -399,7 +404,8 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
 //===----------------------------------------------------------------------===//
 
 FailureOr<SmallVector<Operation *>>
-transform::SplitReductionOp::applyToOne(LinalgOp target) {
+transform::SplitReductionOp::applyToOne(LinalgOp target,
+                                        TransformState &state) {
   ControlSplitReductionFn splitFn = [&](LinalgOp) {
     return std::pair<int64_t, unsigned>(getSplitFactor(),
                                         getInsertSplitDimension());
@@ -455,7 +461,8 @@ void TileOp::print(OpAsmPrinter &p) {
 // VectorizeOp
 //===----------------------------------------------------------------------===//
 
-FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target) {
+FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target,
+                                               TransformState &state) {
   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
     InFlightDiagnostic diag = emitOpError()
                               << "applies only to isolated-from-above targets";

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index f2c80d6673574..b08bbde33e2c8 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -127,7 +127,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
 // LoopPeelOp
 //===----------------------------------------------------------------------===//
 
-FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop) {
+FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop,
+                                                        TransformState &state) {
   scf::ForOp result;
   IRRewriter rewriter(loop->getContext());
   LogicalResult status =
@@ -180,7 +181,8 @@ loopScheduling(scf::ForOp forOp,
   }
 }
 
-FailureOr<scf::ForOp> transform::LoopPipelineOp::applyToOne(scf::ForOp loop) {
+FailureOr<scf::ForOp>
+transform::LoopPipelineOp::applyToOne(scf::ForOp loop, TransformState &state) {
   scf::PipeliningOption options;
   options.getScheduleFn =
       [this](scf::ForOp forOp,
@@ -203,7 +205,8 @@ FailureOr<scf::ForOp> transform::LoopPipelineOp::applyToOne(scf::ForOp loop) {
 // LoopUnrollOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop) {
+LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop,
+                                                  TransformState &state) {
   if (failed(loopUnrollByFactor(loop, getFactor())))
     return reportUnknownTransformError(loop);
   return success();

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 43c181651a42e..c48a7936adb12 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -227,7 +227,8 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
 }
 
 FailureOr<SmallVector<Operation *>>
-mlir::test::TestWrongNumberOfResultsOp::applyToOne(Operation *) {
+mlir::test::TestWrongNumberOfResultsOp::applyToOne(
+    Operation *, transform::TransformState &state) {
   return SmallVector<Operation *>{};
 }
 

diff  --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index d811a57d3c112..1b8ddb9649c34 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -140,7 +140,7 @@ def TestWrongNumberOfResultsOp
   let cppNamespace = "::mlir::test";
   let extraClassDeclaration = [{
     ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
-        ::mlir::Operation *target);
+        ::mlir::Operation *target, transform::TransformState &state);
   }];
 }
 


        


More information about the Mlir-commits mailing list