[Mlir-commits] [mlir] [MLIR][Transform] Make TransformState constructor public (PR #101186)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 21 09:29:43 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Amy Wang (kaitingwang)

<details>
<summary>Changes</summary>

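This change adds optional `stateInitializer` and `stateExporter` callbacks to `transform::applyTransforms`, so that a caller can set up `TransformState` extensions before the transform ops run and read them back afterwards. The design is discussed in this RFC: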
https://discourse.llvm.org/t/rfc-making-the-constructor-of-the-transformstate-class-protected/80377

---
Full diff: https://github.com/llvm/llvm-project/pull/101186.diff


9 Files Affected:

- (modified) mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h (+6-2) 
- (modified) mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp (+10-2) 
- (added) mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir (+19) 
- (modified) mlir/test/lib/Dialect/Transform/CMakeLists.txt (+1) 
- (added) mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp (+101) 
- (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp (+22) 
- (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td (+9) 
- (modified) mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h (+25) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
index 842e244dcde56c..0bb6037a77a16d 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
@@ -135,7 +135,9 @@ LogicalResult
 applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
                 const RaggedArray<MappedValue> &extraMapping = {},
                 const TransformOptions &options = TransformOptions(),
-                bool enforceToplevelTransformOp = true);
+                bool enforceToplevelTransformOp = true,
+                function_ref<void (TransformState &)> stateInitializer = nullptr,
+                function_ref<LogicalResult (TransformState &)> stateExporter = nullptr);
 
 /// The state maintained across applications of various ops implementing the
 /// TransformOpInterface. The operations implementing this interface and the
@@ -217,7 +219,9 @@ class TransformState {
 
   friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
                                        const RaggedArray<MappedValue> &,
-                                       const TransformOptions &, bool);
+                                       const TransformOptions &, bool,
+                                       function_ref<void (TransformState &)>,
+                                       function_ref<LogicalResult (TransformState &)>);
 
   friend TransformState
   detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index f8f85e4615c500..5bc6d4ee5033f1 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1999,7 +1999,9 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
 LogicalResult transform::applyTransforms(
     Operation *payloadRoot, TransformOpInterface transform,
     const RaggedArray<MappedValue> &extraMapping,
-    const TransformOptions &options, bool enforceToplevelTransformOp) {
+    const TransformOptions &options, bool enforceToplevelTransformOp,
+    function_ref<void(TransformState &)> stateInitializer,
+    function_ref<LogicalResult(TransformState &)> stateExporter) {
   if (enforceToplevelTransformOp) {
     if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
         transform->getNumOperands() != 0) {
@@ -2013,7 +2015,13 @@ LogicalResult transform::applyTransforms(
 
   TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
                        options);
-  return state.applyTransform(transform).checkAndReport();
+  if (stateInitializer)
+    stateInitializer(state);
+  if (state.applyTransform(transform).checkAndReport().failed())
+    return failure();
+  if (stateExporter)
+    return stateExporter(state);
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir b/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir
new file mode 100644
index 00000000000000..9fb4d1b8689164
--- /dev/null
+++ b/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -test-pass-state-extension-communication -verify-diagnostics | FileCheck %s
+
+// CHECK: Printing opCollection before processing transform ops, size: 1
+// CHECK: PASS-TRANSFORMOP-PASS
+
+// CHECK: Printing opCollection after processing transform ops, size: 4
+// CHECK: PASS-TRANSFORMOP-PASS transform.test_initializer_extension_A transform.test_initializer_extension_B transform.test_initializer_extension_C
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-remark @below {{Number of currently registered op: 1}}
+    transform.test_initializer_extension "A"
+    // expected-remark @below {{Number of currently registered op: 2}}
+    transform.test_initializer_extension "B"
+    // expected-remark @below {{Number of currently registered op: 3}}
+    transform.test_initializer_extension "C"
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
index e6ab915a657b6f..ca141d2778ee2d 100644
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -6,6 +6,7 @@ mlir_tablegen(TestTransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -type
 add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)
 
 add_mlir_library(MLIRTestTransformDialect
+  TestPassStateExtensionCommunication.cpp
   TestTransformDialectExtension.cpp
   TestTransformDialectInterpreter.cpp
   TestTransformStateExtension.cpp
diff --git a/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
new file mode 100644
index 00000000000000..4b5958af21d014
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
@@ -0,0 +1,101 @@
+//===- TestPassStateExtensionCommunication.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a test pass that showcases how communication can be
+// conducted between a regular mlir pass and transform ops through the
+// transform state extension stateInitializer and stateExporter mechanism.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTransformStateExtension.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::test;
+
+namespace {
+template <typename Derived>
+class OpPassWrapper : public PassWrapper<Derived, OperationPass<>> {};
+
+struct TestPassStateExtensionCommunication
+    : public PassWrapper<TestPassStateExtensionCommunication,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestPassStateExtensionCommunication)
+
+  StringRef getArgument() const final {
+    return "test-pass-state-extension-communication";
+  }
+
+  StringRef getDescription() const final {
+    return "test state communciation between a mlir pass and transform ops";
+  }
+
+  static void printVector(const SmallVector<std::string> &opCollection,
+                          const std::string &extraMessage = {}) {
+    outs() << "Printing opCollection" << extraMessage
+           << ", size: " << opCollection.size() << "\n";
+    for (const auto &subVector : opCollection) {
+      outs() << subVector << " ";
+    }
+    outs() << "\n";
+  }
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    // Create an opCollection vector.
+    SmallVector<std::string> opCollection = {"PASS-TRANSFORMOP-PASS "};
+    printVector(opCollection, " before processing transform ops");
+
+    auto stateInitializer =
+        [&opCollection](mlir::transform::TransformState &state) -> void {
+      TransformStateInitializerExtension *ext =
+          state.getExtension<TransformStateInitializerExtension>();
+      if (!ext)
+        state.addExtension<TransformStateInitializerExtension>(0, opCollection);
+    };
+
+    auto stateExporter =
+        [&opCollection](
+            mlir::transform::TransformState &state) -> LogicalResult {
+      TransformStateInitializerExtension *ext =
+          state.getExtension<TransformStateInitializerExtension>();
+      if (!ext) {
+        errs() << "Target transform state extension not found!\n";
+        return failure();
+      }
+      opCollection.clear();
+      opCollection = ext->getRegisteredOps();
+      return success();
+    };
+
+    // Process transform ops with stateInitializer and stateExporter.
+    for (auto op : module.getBody()->getOps<transform::TransformOpInterface>())
+      if (failed(transform::applyTransforms(
+              module, op, {}, mlir::transform::TransformOptions(), false,
+              stateInitializer, stateExporter)))
+        return signalPassFailure();
+
+    // Print the opCollection vector after processing transform ops.
+    printVector(opCollection, " after processing transform ops");
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers the test pass here.
+void registerTestPassStateExtensionCommunication() {
+  PassRegistration<TestPassStateExtensionCommunication> reg;
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index c023aad4a3ee77..a0a7afce66d9a1 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -804,6 +804,28 @@ void mlir::test::TestProduceInvalidIR::getEffects(
   transform::modifiesPayload(effects);
 }
 
+DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  std::string opName =
+      this->getOperationName().str() + "_" + getTypeAttr().str();
+  TransformStateInitializerExtension *initExt =
+      state.getExtension<TransformStateInitializerExtension>();
+  if (!initExt) {
+    emitRemark() << "\nSpecified extension not found, adding a new one!\n";
+    SmallVector<std::string> opCollection = {opName};
+    state.addExtension<TransformStateInitializerExtension>(1, opCollection);
+  } else {
+    initExt->setNumOp(initExt->getNumOp() + 1);
+    initExt->pushRegisteredOps(opName);
+    InFlightDiagnostic diag = emitRemark()
+                              << "Number of currently registered op: "
+                              << initExt->getNumOp() << "\n"
+                              << initExt->printMessage() << "\n";
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test conversion pattern that replaces ops with the "replace_with_new_op"
 /// attribute with "test.new_op".
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 4f2cf34f7d3347..76375dba369448 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -549,4 +549,13 @@ def TestProduceInvalidIR
   }];
 }
 
+def TestInitializerExtensionOp
+  : Op<Transform_Dialect, "test_initializer_extension",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        NoMemoryEffect]> {
+  let arguments = (ins StrAttr:$type);
+  let assemblyFormat = "$type attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
index 0bfa6bed015c0f..bbcbabea010b33 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
+++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
@@ -34,6 +34,31 @@ class TestTransformStateExtension
 private:
   StringAttr message;
 };
+
+class TransformStateInitializerExtension
+    : public transform::TransformState::Extension {
+public:
+  TransformStateInitializerExtension(transform::TransformState &state,
+                              int numOp, SmallVector<std::string>& registeredOps)
+      : Extension(state), numOp(numOp), registeredOps(registeredOps) {}
+
+  int getNumOp() { return numOp; }
+  void setNumOp(int num) { numOp = num; }
+  SmallVector<std::string> getRegisteredOps() { return registeredOps; }
+  void pushRegisteredOps(const std::string& newOp) { registeredOps.push_back(newOp); }
+  std::string printMessage() const {
+    std::string message = "Registered transformOps are: ";
+    for (const auto& op : registeredOps) {
+      message += op + " | ";
+    }
+    return message;
+  }
+
+private:
+  int numOp;
+  SmallVector<std::string> registeredOps;
+};
+
 } // namespace test
 } // namespace mlir
 
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 1842fa158e75a9..36b142484bb04a 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -148,6 +148,7 @@ void registerTestTensorCopyInsertionPass();
 void registerTestTensorTransforms();
 void registerTestTopologicalSortAnalysisPass();
 void registerTestTransformDialectEraseSchedulePass();
+void registerTestPassStateExtensionCommunication();
 void registerTestVectorLowerings();
 void registerTestVectorReductionToSPIRVDotProd();
 void registerTestWrittenToPass();
@@ -283,6 +284,7 @@ void registerTestPasses() {
   mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestTopologicalSortAnalysisPass();
   mlir::test::registerTestTransformDialectEraseSchedulePass();
+  mlir::test::registerTestPassStateExtensionCommunication();
   mlir::test::registerTestVectorLowerings();
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestWrittenToPass();

``````````

</details>


https://github.com/llvm/llvm-project/pull/101186


More information about the Mlir-commits mailing list