[Mlir-commits] [mlir] [MLIR][Transform] Make TransformState constructor public (PR #101186)
Amy Wang
llvmlistbot at llvm.org
Wed Aug 21 09:29:09 PDT 2024
https://github.com/kaitingwang updated https://github.com/llvm/llvm-project/pull/101186
>From a7f7b2459344edae711c11d8863abd236db8f356 Mon Sep 17 00:00:00 2001
From: Amy Wang <kai.ting.wang at huawei.com>
Date: Wed, 21 Aug 2024 12:27:15 -0400
Subject: [PATCH] [MLIR][Transform] Allow passing state info from a pass into
the transform framework and back to the pass.
---
.../Interfaces/TransformInterfaces.h | 8 +-
.../Interfaces/TransformInterfaces.cpp | 12 ++-
...transform-state-extension-initializer.mlir | 19 ++++
.../test/lib/Dialect/Transform/CMakeLists.txt | 1 +
.../TestPassStateExtensionCommunication.cpp | 101 ++++++++++++++++++
.../TestTransformDialectExtension.cpp | 22 ++++
.../TestTransformDialectExtension.td | 9 ++
.../Transform/TestTransformStateExtension.h | 25 +++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
9 files changed, 195 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir
create mode 100644 mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
index 842e244dcde56c..0bb6037a77a16d 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
@@ -135,7 +135,15 @@ LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping = {},
const TransformOptions &options = TransformOptions(),
- bool enforceToplevelTransformOp = true);
+                bool enforceToplevelTransformOp = true,
+                // Optional callback run on the newly constructed TransformState
+                // before `transform` is applied, e.g. to attach extensions that
+                // carry data from the caller.
+                function_ref<void(TransformState &)> stateInitializer = nullptr,
+                // Optional callback run only if applying `transform` succeeded;
+                // its result is returned from this function.
+                function_ref<LogicalResult(TransformState &)> stateExporter =
+                    nullptr);
/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
@@ -217,7 +225,9 @@ class TransformState {
friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
const RaggedArray<MappedValue> &,
- const TransformOptions &, bool);
+                                       const TransformOptions &, bool,
+                                       function_ref<void(TransformState &)>,
+                                       function_ref<LogicalResult(TransformState &)>);
friend TransformState
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index f8f85e4615c500..5bc6d4ee5033f1 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1999,7 +1999,9 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
LogicalResult transform::applyTransforms(
Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping,
- const TransformOptions &options, bool enforceToplevelTransformOp) {
+ const TransformOptions &options, bool enforceToplevelTransformOp,
+ function_ref<void(TransformState &)> stateInitializer,
+ function_ref<LogicalResult(TransformState &)> stateExporter) {
if (enforceToplevelTransformOp) {
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
@@ -2013,7 +2015,13 @@ LogicalResult transform::applyTransforms(
TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
options);
- return state.applyTransform(transform).checkAndReport();
+ // Give the caller a chance to seed the freshly constructed state (for
+ // example, by attaching extensions) before any transform runs.
+ if (stateInitializer)
+ stateInitializer(state);
+ if (state.applyTransform(transform).checkAndReport().failed())
+ return failure();
+ // The exporter only runs after successful application; its result becomes
+ // the overall result.
+ if (stateExporter)
+ return stateExporter(state);
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir b/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir
new file mode 100644
index 00000000000000..9fb4d1b8689164
--- /dev/null
+++ b/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -test-pass-state-extension-communication -verify-diagnostics | FileCheck %s
+
+// The pass seeds a TransformStateInitializerExtension with one entry through
+// the stateInitializer callback, each test_initializer_extension op appends
+// itself, and the stateExporter callback copies the four entries back.
+
+// CHECK: Printing opCollection before processing transform ops, size: 1
+// CHECK: PASS-TRANSFORMOP-PASS
+
+// CHECK: Printing opCollection after processing transform ops, size: 4
+// CHECK: PASS-TRANSFORMOP-PASS transform.test_initializer_extension_A transform.test_initializer_extension_B transform.test_initializer_extension_C
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+ // expected-remark @below {{Number of currently registered op: 1}}
+ transform.test_initializer_extension "A"
+ // expected-remark @below {{Number of currently registered op: 2}}
+ transform.test_initializer_extension "B"
+ // expected-remark @below {{Number of currently registered op: 3}}
+ transform.test_initializer_extension "C"
+ transform.yield
+ }
+}
diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
index e6ab915a657b6f..ca141d2778ee2d 100644
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -6,6 +6,7 @@ mlir_tablegen(TestTransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -type
add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)
add_mlir_library(MLIRTestTransformDialect
+ TestPassStateExtensionCommunication.cpp
TestTransformDialectExtension.cpp
TestTransformDialectInterpreter.cpp
TestTransformStateExtension.cpp
diff --git a/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
new file mode 100644
index 00000000000000..4b5958af21d014
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
@@ -0,0 +1,101 @@
+//===- TestPassStateExtensionCommunication.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a test pass that showcases how communication can be
+// conducted between a regular mlir pass and transform ops through the
+// transform state extension stateInitializer and stateExporter mechanism.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTransformStateExtension.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::test;
+
+namespace {
+/// Test pass that hands a string collection to the transform interpreter via
+/// the `stateInitializer` callback of `transform::applyTransforms`, lets the
+/// transform ops extend it through `TransformStateInitializerExtension`, and
+/// reads the result back through the `stateExporter` callback.
+struct TestPassStateExtensionCommunication
+    : public PassWrapper<TestPassStateExtensionCommunication,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestPassStateExtensionCommunication)
+
+  StringRef getArgument() const final {
+    return "test-pass-state-extension-communication";
+  }
+
+  StringRef getDescription() const final {
+    return "test state communication between a mlir pass and transform ops";
+  }
+
+  /// Prints "Printing opCollection<extraMessage>, size: N" to stdout, then
+  /// every entry of `opCollection` followed by a space, then a newline.
+  static void printVector(ArrayRef<std::string> opCollection,
+                          StringRef extraMessage = {}) {
+    outs() << "Printing opCollection" << extraMessage
+           << ", size: " << opCollection.size() << "\n";
+    for (const std::string &entry : opCollection)
+      outs() << entry << " ";
+    outs() << "\n";
+  }
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    // Seed the collection that the transform ops will extend.
+    SmallVector<std::string> opCollection = {"PASS-TRANSFORMOP-PASS "};
+    printVector(opCollection, " before processing transform ops");
+
+    // Attach an extension (counter starting at 0) holding a copy of the
+    // pass-side collection, unless the state already carries one.
+    auto stateInitializer = [&opCollection](transform::TransformState &state) {
+      if (!state.getExtension<TransformStateInitializerExtension>())
+        state.addExtension<TransformStateInitializerExtension>(0, opCollection);
+    };
+
+    // Copy the collection, as extended by the transform ops, back into the
+    // pass.
+    auto stateExporter =
+        [&opCollection](transform::TransformState &state) -> LogicalResult {
+      auto *ext = state.getExtension<TransformStateInitializerExtension>();
+      if (!ext) {
+        errs() << "Target transform state extension not found!\n";
+        return failure();
+      }
+      // Assignment replaces the previous contents; no explicit clear() needed.
+      opCollection = ext->getRegisteredOps();
+      return success();
+    };
+
+    // Interpret each transform op found directly in the module body, with
+    // top-level-op enforcement disabled, threading the collection through the
+    // transform state via the two callbacks.
+    for (auto op : module.getBody()->getOps<transform::TransformOpInterface>())
+      if (failed(transform::applyTransforms(
+              module, op, {}, transform::TransformOptions(),
+              /*enforceToplevelTransformOp=*/false, stateInitializer,
+              stateExporter)))
+        return signalPassFailure();
+
+    printVector(opCollection, " after processing transform ops");
+  }
+};
+} // namespace
+
+namespace mlir::test {
+/// Makes `-test-pass-state-extension-communication` available to mlir-opt;
+/// called from registerTestPasses().
+void registerTestPassStateExtensionCommunication() {
+  PassRegistration<TestPassStateExtensionCommunication>();
+}
+} // namespace mlir::test
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index c023aad4a3ee77..a0a7afce66d9a1 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -804,6 +804,28 @@ void mlir::test::TestProduceInvalidIR::getEffects(
transform::modifiesPayload(effects);
}
+DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  // Each op records itself as "<op name>_<type attribute>".
+  std::string opName = getOperationName().str() + "_" + getTypeAttr().str();
+
+  auto *ext = state.getExtension<TransformStateInitializerExtension>();
+  if (ext) {
+    // Bump the shared counter, record this op, and report the new state.
+    ext->setNumOp(ext->getNumOp() + 1);
+    ext->pushRegisteredOps(opName);
+    emitRemark() << "Number of currently registered op: " << ext->getNumOp()
+                 << "\n"
+                 << ext->printMessage() << "\n";
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // No extension yet: start a fresh one containing only this op.
+  emitRemark() << "\nSpecified extension not found, adding a new one!\n";
+  SmallVector<std::string> opCollection = {opName};
+  state.addExtension<TransformStateInitializerExtension>(1, opCollection);
+  return DiagnosedSilenceableFailure::success();
+}
+
namespace {
/// Test conversion pattern that replaces ops with the "replace_with_new_op"
/// attribute with "test.new_op".
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 4f2cf34f7d3347..76375dba369448 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -549,4 +549,13 @@ def TestProduceInvalidIR
}];
}
+// Test op that records "<op name>_<type>" in the
+// TransformStateInitializerExtension of the current transform state (creating
+// the extension when absent) and reports the result with a remark.
+def TestInitializerExtensionOp
+ : Op<Transform_Dialect, "test_initializer_extension",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ NoMemoryEffect]> {
+ let arguments = (ins StrAttr:$type);
+ let assemblyFormat = "$type attr-dict";
+ let cppNamespace = "::mlir::test";
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
index 0bfa6bed015c0f..bbcbabea010b33 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
+++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
@@ -34,6 +34,31 @@ class TestTransformStateExtension
private:
StringAttr message;
};
+
+/// Transform state extension used to exchange data between a pass and the
+/// transform ops it interprets: it holds a counter and a list of recorded op
+/// names that both sides can read and extend.
+class TransformStateInitializerExtension
+    : public transform::TransformState::Extension {
+public:
+  /// Creates the extension for `state` with the initial counter `numOp` and a
+  /// copy of `registeredOps`. The argument is only copied, so it is taken by
+  /// const reference (this also accepts temporaries).
+  TransformStateInitializerExtension(
+      transform::TransformState &state, int numOp,
+      const SmallVector<std::string> &registeredOps)
+      : Extension(state), numOp(numOp), registeredOps(registeredOps) {}
+
+  /// Returns the current counter value.
+  int getNumOp() const { return numOp; }
+  /// Overwrites the counter value.
+  void setNumOp(int num) { numOp = num; }
+  /// Returns the op names recorded so far, without copying them.
+  const SmallVector<std::string> &getRegisteredOps() const {
+    return registeredOps;
+  }
+  /// Appends `newOp` to the recorded op names.
+  void pushRegisteredOps(const std::string &newOp) {
+    registeredOps.push_back(newOp);
+  }
+  /// Returns "Registered transformOps are: " followed by each recorded name,
+  /// each one terminated by " | ".
+  std::string printMessage() const {
+    std::string message = "Registered transformOps are: ";
+    for (const std::string &op : registeredOps) {
+      message += op;
+      message += " | ";
+    }
+    return message;
+  }
+
+private:
+  /// Counter maintained by the users of this extension.
+  int numOp;
+  /// Op names recorded so far.
+  SmallVector<std::string> registeredOps;
+};
+
} // namespace test
} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 1842fa158e75a9..36b142484bb04a 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -148,6 +148,7 @@ void registerTestTensorCopyInsertionPass();
void registerTestTensorTransforms();
void registerTestTopologicalSortAnalysisPass();
void registerTestTransformDialectEraseSchedulePass();
+void registerTestPassStateExtensionCommunication();
void registerTestVectorLowerings();
void registerTestVectorReductionToSPIRVDotProd();
void registerTestWrittenToPass();
@@ -283,6 +284,7 @@ void registerTestPasses() {
mlir::test::registerTestTensorTransforms();
mlir::test::registerTestTopologicalSortAnalysisPass();
mlir::test::registerTestTransformDialectEraseSchedulePass();
+ mlir::test::registerTestPassStateExtensionCommunication();
mlir::test::registerTestVectorLowerings();
mlir::test::registerTestVectorReductionToSPIRVDotProd();
mlir::test::registerTestWrittenToPass();
More information about the Mlir-commits
mailing list