[Mlir-commits] [mlir] [MLIR][Transform] Clean up the applyTransforms API (PR #107890)

Amy Wang llvmlistbot at llvm.org
Mon Sep 9 10:03:19 PDT 2024


https://github.com/kaitingwang updated https://github.com/llvm/llvm-project/pull/107890

>From c0c00e1259052b20c0e1e2ab39a486c0f2433692 Mon Sep 17 00:00:00 2001
From: Amy Wang <kai.ting.wang at huawei.com>
Date: Mon, 9 Sep 2024 12:24:33 -0400
Subject: [PATCH] [MLIR][Transform] Clean up the applyTransforms API

Move the stateInitializer and stateExporter callbacks into
TransformOptions to declutter the applyTransforms API.
---
 .../Interfaces/TransformInterfaces.h          | 39 ++++++++++++-------
 .../Interfaces/TransformInterfaces.cpp        |  8 ++--
 .../TestPassStateExtensionCommunication.cpp   |  6 +--
 3 files changed, 34 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
index 43193e4cd4cf63..a57feaef8549bf 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
@@ -115,12 +115,29 @@ class TransformOptions {
   /// Returns true if the expensive checks are requested.
   bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
 
-  // Returns true if enforcing a single top-level transform op is requested.
+  /// Returns true if enforcing a single top-level transform op is requested.
   bool getEnforceSingleToplevelTransformOp() const {
     return enforceSingleToplevelTransformOp;
   }
 
+  /// Sets the initializer and exporter methods for transferring data in and out
+  /// of a TransformState variable from an MLIR pass invoking applyTransforms.
+  void setStateInitializerExporter(
+      function_ref<void(TransformState &)> initializer = nullptr,
+      function_ref<LogicalResult(TransformState &)> exporter = nullptr) {
+    stateInitializer = initializer;
+    stateExporter = exporter;
+  }
+  function_ref<void(TransformState &)> getStateInitializer() const {
+    return stateInitializer;
+  }
+  function_ref<LogicalResult(TransformState &)> getStateExporter() const {
+    return stateExporter;
+  }
+
 private:
+  function_ref<void(TransformState &)> stateInitializer = nullptr;
+  function_ref<LogicalResult(TransformState &)> stateExporter = nullptr;
   bool expensiveChecksEnabled = true;
   bool enforceSingleToplevelTransformOp = true;
 };
@@ -131,13 +148,11 @@ class TransformOptions {
 /// will be executed following the internal logic of the operation. It must
 /// have the `PossibleTopLevelTransformOp` trait and not have any operands.
 /// This function internally keeps track of the transformation state.
-LogicalResult applyTransforms(
-    Operation *payloadRoot, TransformOpInterface transform,
-    const RaggedArray<MappedValue> &extraMapping = {},
-    const TransformOptions &options = TransformOptions(),
-    bool enforceToplevelTransformOp = true,
-    function_ref<void(TransformState &)> stateInitializer = nullptr,
-    function_ref<LogicalResult(TransformState &)> stateExporter = nullptr);
+LogicalResult
+applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
+                const RaggedArray<MappedValue> &extraMapping = {},
+                const TransformOptions &options = TransformOptions(),
+                bool enforceToplevelTransformOp = true);
 
 /// The state maintained across applications of various ops implementing the
 /// TransformOpInterface. The operations implementing this interface and the
@@ -217,11 +232,9 @@ class TransformState {
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
   };
 
-  friend LogicalResult
-  applyTransforms(Operation *, TransformOpInterface,
-                  const RaggedArray<MappedValue> &, const TransformOptions &,
-                  bool, function_ref<void(TransformState &)>,
-                  function_ref<LogicalResult(TransformState &)>);
+  friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
+                                       const RaggedArray<MappedValue> &,
+                                       const TransformOptions &, bool);
 
   friend TransformState
   detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 5bc6d4ee5033f1..5b4bad99678b3c 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1999,9 +1999,7 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
 LogicalResult transform::applyTransforms(
     Operation *payloadRoot, TransformOpInterface transform,
     const RaggedArray<MappedValue> &extraMapping,
-    const TransformOptions &options, bool enforceToplevelTransformOp,
-    function_ref<void(TransformState &)> stateInitializer,
-    function_ref<LogicalResult(TransformState &)> stateExporter) {
+    const TransformOptions &options, bool enforceToplevelTransformOp) {
   if (enforceToplevelTransformOp) {
     if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
         transform->getNumOperands() != 0) {
@@ -2015,6 +2013,10 @@ LogicalResult transform::applyTransforms(
 
   TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
                        options);
+  function_ref<void(TransformState &)> stateInitializer =
+      options.getStateInitializer();
+  function_ref<LogicalResult(TransformState &)> stateExporter =
+      options.getStateExporter();
   if (stateInitializer)
     stateInitializer(state);
   if (state.applyTransform(transform).checkAndReport().failed())
diff --git a/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
index 9ec99f70630a82..6107bf65465365 100644
--- a/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
@@ -78,11 +78,11 @@ struct TestPassStateExtensionCommunication
       return success();
     };
 
+    auto options = mlir::transform::TransformOptions();
+    options.setStateInitializerExporter(stateInitializer, stateExporter);
     // Process transform ops with stateInitializer and stateExporter.
     for (auto op : module.getBody()->getOps<transform::TransformOpInterface>())
-      if (failed(transform::applyTransforms(
-              module, op, {}, mlir::transform::TransformOptions(), false,
-              stateInitializer, stateExporter)))
+      if (failed(transform::applyTransforms(module, op, {}, options, false)))
         return signalPassFailure();
 
     // Print the opCollection vector after processing transform ops.



More information about the Mlir-commits mailing list