[Mlir-commits] [mlir] [MLIR][Transform] Clean up the applyTransforms API (PR #107890)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 9 09:37:06 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Amy Wang (kaitingwang)

<details>
<summary>Changes</summary>

Move the `stateInitializer` and `stateExporter` callbacks into `TransformOptions` to de-clutter the `applyTransforms` API.

---
Full diff: https://github.com/llvm/llvm-project/pull/107890.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h (+26-13) 
- (modified) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp (+5-3) 
- (modified) mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp (+3-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
index 43193e4cd4cf63..a57feaef8549bf 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
@@ -115,12 +115,29 @@ class TransformOptions {
   /// Returns true if the expensive checks are requested.
   bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
 
-  // Returns true if enforcing a single top-level transform op is requested.
+  /// Returns true if enforcing a single top-level transform op is requested.
   bool getEnforceSingleToplevelTransformOp() const {
     return enforceSingleToplevelTransformOp;
   }
 
+  /// Sets the initializer and exporter methods for transferring data in and out
+  /// of a TransformState variable from an MLIR pass invoking applyTransforms.
+  void setStateInitializerExporter(
+      function_ref<void(TransformState &)> initializer = nullptr,
+      function_ref<LogicalResult(TransformState &)> exporter = nullptr) {
+    stateInitializer = initializer;
+    stateExporter = exporter;
+  }
+  function_ref<void(TransformState &)> getStateInitializer() const {
+    return stateInitializer;
+  }
+  function_ref<LogicalResult(TransformState &)> getStateExporter() const {
+    return stateExporter;
+  }
+
 private:
+  function_ref<void(TransformState &)> stateInitializer = nullptr;
+  function_ref<LogicalResult(TransformState &)> stateExporter = nullptr;
   bool expensiveChecksEnabled = true;
   bool enforceSingleToplevelTransformOp = true;
 };
@@ -131,13 +148,11 @@ class TransformOptions {
 /// will be executed following the internal logic of the operation. It must
 /// have the `PossibleTopLevelTransformOp` trait and not have any operands.
 /// This function internally keeps track of the transformation state.
-LogicalResult applyTransforms(
-    Operation *payloadRoot, TransformOpInterface transform,
-    const RaggedArray<MappedValue> &extraMapping = {},
-    const TransformOptions &options = TransformOptions(),
-    bool enforceToplevelTransformOp = true,
-    function_ref<void(TransformState &)> stateInitializer = nullptr,
-    function_ref<LogicalResult(TransformState &)> stateExporter = nullptr);
+LogicalResult
+applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
+                const RaggedArray<MappedValue> &extraMapping = {},
+                const TransformOptions &options = TransformOptions(),
+                bool enforceToplevelTransformOp = true);
 
 /// The state maintained across applications of various ops implementing the
 /// TransformOpInterface. The operations implementing this interface and the
@@ -217,11 +232,9 @@ class TransformState {
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
   };
 
-  friend LogicalResult
-  applyTransforms(Operation *, TransformOpInterface,
-                  const RaggedArray<MappedValue> &, const TransformOptions &,
-                  bool, function_ref<void(TransformState &)>,
-                  function_ref<LogicalResult(TransformState &)>);
+  friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
+                                       const RaggedArray<MappedValue> &,
+                                       const TransformOptions &, bool);
 
   friend TransformState
   detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 5bc6d4ee5033f1..5b4bad99678b3c 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1999,9 +1999,7 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
 LogicalResult transform::applyTransforms(
     Operation *payloadRoot, TransformOpInterface transform,
     const RaggedArray<MappedValue> &extraMapping,
-    const TransformOptions &options, bool enforceToplevelTransformOp,
-    function_ref<void(TransformState &)> stateInitializer,
-    function_ref<LogicalResult(TransformState &)> stateExporter) {
+    const TransformOptions &options, bool enforceToplevelTransformOp) {
   if (enforceToplevelTransformOp) {
     if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
         transform->getNumOperands() != 0) {
@@ -2015,6 +2013,10 @@ LogicalResult transform::applyTransforms(
 
   TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
                        options);
+  function_ref<void(TransformState &)> stateInitializer =
+      options.getStateInitializer();
+  function_ref<LogicalResult(TransformState &)> stateExporter =
+      options.getStateExporter();
   if (stateInitializer)
     stateInitializer(state);
   if (state.applyTransform(transform).checkAndReport().failed())
diff --git a/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
index 9ec99f70630a82..a5d95d719af6e1 100644
--- a/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
@@ -78,11 +78,12 @@ struct TestPassStateExtensionCommunication
       return success();
     };
 
+    auto options = mlir::transform::TransformOptions();
+    options.setStateInitializerExporter(stateInitializer, stateExporter);
     // Process transform ops with stateInitializer and stateExporter.
     for (auto op : module.getBody()->getOps<transform::TransformOpInterface>())
       if (failed(transform::applyTransforms(
-              module, op, {}, mlir::transform::TransformOptions(), false,
-              stateInitializer, stateExporter)))
+              module, op, {}, options, false)))
         return signalPassFailure();
 
     // Print the opCollection vector after processing transform ops.

``````````

</details>


https://github.com/llvm/llvm-project/pull/107890


More information about the Mlir-commits mailing list