[Mlir-commits] [mlir] [MLIR][Transform] Clean up the applyTransforms API (PR #107890)
Amy Wang
llvmlistbot at llvm.org
Mon Sep 9 10:03:19 PDT 2024
https://github.com/kaitingwang updated https://github.com/llvm/llvm-project/pull/107890
>From c0c00e1259052b20c0e1e2ab39a486c0f2433692 Mon Sep 17 00:00:00 2001
From: Amy Wang <kai.ting.wang at huawei.com>
Date: Mon, 9 Sep 2024 12:24:33 -0400
Subject: [PATCH] [MLIR][Transform] Clean up the applyTransforms API
Move the stateInitializer and stateExporter callbacks into
TransformOptions to declutter the applyTransforms API.
---
.../Interfaces/TransformInterfaces.h | 39 ++++++++++++-------
.../Interfaces/TransformInterfaces.cpp | 8 ++--
.../TestPassStateExtensionCommunication.cpp | 6 +--
3 files changed, 34 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
index 43193e4cd4cf63..a57feaef8549bf 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
@@ -115,12 +115,29 @@ class TransformOptions {
/// Returns true if the expensive checks are requested.
bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
- // Returns true if enforcing a single top-level transform op is requested.
+ /// Returns true if enforcing a single top-level transform op is requested.
bool getEnforceSingleToplevelTransformOp() const {
return enforceSingleToplevelTransformOp;
}
+ /// Sets the (non-owning) callbacks an MLIR pass uses to move data into/out of
+ /// the TransformState built by applyTransforms; they must outlive all uses.
+ void setStateInitializerExporter(
+ function_ref<void(TransformState &)> initializer = nullptr,
+ function_ref<LogicalResult(TransformState &)> exporter = nullptr) {
+ stateInitializer = initializer;
+ stateExporter = exporter;
+ }
+ function_ref<void(TransformState &)> getStateInitializer() const {
+ return stateInitializer;
+ }
+ function_ref<LogicalResult(TransformState &)> getStateExporter() const {
+ return stateExporter;
+ }
+
private:
+ function_ref<void(TransformState &)> stateInitializer = nullptr;
+ function_ref<LogicalResult(TransformState &)> stateExporter = nullptr;
bool expensiveChecksEnabled = true;
bool enforceSingleToplevelTransformOp = true;
};
@@ -131,13 +148,11 @@ class TransformOptions {
/// will be executed following the internal logic of the operation. It must
/// have the `PossibleTopLevelTransformOp` trait and not have any operands.
/// This function internally keeps track of the transformation state.
-LogicalResult applyTransforms(
- Operation *payloadRoot, TransformOpInterface transform,
- const RaggedArray<MappedValue> &extraMapping = {},
- const TransformOptions &options = TransformOptions(),
- bool enforceToplevelTransformOp = true,
- function_ref<void(TransformState &)> stateInitializer = nullptr,
- function_ref<LogicalResult(TransformState &)> stateExporter = nullptr);
+LogicalResult
+applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
+ const RaggedArray<MappedValue> &extraMapping = {},
+ const TransformOptions &options = TransformOptions(),
+ bool enforceToplevelTransformOp = true);
/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
@@ -217,11 +232,9 @@ class TransformState {
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};
- friend LogicalResult
- applyTransforms(Operation *, TransformOpInterface,
- const RaggedArray<MappedValue> &, const TransformOptions &,
- bool, function_ref<void(TransformState &)>,
- function_ref<LogicalResult(TransformState &)>);
+ friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
+ const RaggedArray<MappedValue> &,
+ const TransformOptions &, bool);
friend TransformState
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 5bc6d4ee5033f1..5b4bad99678b3c 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1999,9 +1999,7 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
LogicalResult transform::applyTransforms(
Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping,
- const TransformOptions &options, bool enforceToplevelTransformOp,
- function_ref<void(TransformState &)> stateInitializer,
- function_ref<LogicalResult(TransformState &)> stateExporter) {
+ const TransformOptions &options, bool enforceToplevelTransformOp) {
if (enforceToplevelTransformOp) {
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
@@ -2015,6 +2013,10 @@ LogicalResult transform::applyTransforms(
TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
options);
+ function_ref<void(TransformState &)> stateInitializer =
+ options.getStateInitializer();
+ function_ref<LogicalResult(TransformState &)> stateExporter =
+ options.getStateExporter();
if (stateInitializer)
stateInitializer(state);
if (state.applyTransform(transform).checkAndReport().failed())
diff --git a/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
index 9ec99f70630a82..6107bf65465365 100644
--- a/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
@@ -78,11 +78,11 @@ struct TestPassStateExtensionCommunication
return success();
};
+ auto options = mlir::transform::TransformOptions();
+ options.setStateInitializerExporter(stateInitializer, stateExporter);
// Process transform ops with stateInitializer and stateExporter.
for (auto op : module.getBody()->getOps<transform::TransformOpInterface>())
- if (failed(transform::applyTransforms(
- module, op, {}, mlir::transform::TransformOptions(), false,
- stateInitializer, stateExporter)))
+ if (failed(transform::applyTransforms(module, op, {}, options, false)))
return signalPassFailure();
// Print the opCollection vector after processing transform ops.
More information about the Mlir-commits
mailing list