[Mlir-commits] [llvm] [mlir] [MLIR][Transform] Allow stateInitializer and stateExporter for applyTransforms (PR #101186)

Amy Wang llvmlistbot at llvm.org
Sat Aug 24 16:27:55 PDT 2024


https://github.com/kaitingwang updated https://github.com/llvm/llvm-project/pull/101186

>From a369381e6af093b65403394a3a8ac2150b212c0c Mon Sep 17 00:00:00 2001
From: Amy Wang <kai.ting.wang at huawei.com>
Date: Wed, 21 Aug 2024 12:27:15 -0400
Subject: [PATCH 1/2] [MLIR][Transform] Allow passing state info from a pass
 into the transform framework and back to the pass.

---
 .../Interfaces/TransformInterfaces.h          |  20 ++--
 .../Interfaces/TransformInterfaces.cpp        |  12 ++-
 ...transform-state-extension-initializer.mlir |  19 ++++
 .../test/lib/Dialect/Transform/CMakeLists.txt |   1 +
 .../TestPassStateExtensionCommunication.cpp   | 101 ++++++++++++++++++
 .../TestTransformDialectExtension.cpp         |  22 ++++
 .../TestTransformDialectExtension.td          |   9 ++
 .../Transform/TestTransformStateExtension.h   |  28 +++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 9 files changed, 204 insertions(+), 10 deletions(-)
 create mode 100644 mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir
 create mode 100644 mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp

diff --git a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
index 842e244dcde56c..43193e4cd4cf63 100644
--- a/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h
@@ -131,11 +131,18 @@ class TransformOptions {
 /// will be executed following the internal logic of the operation. It must
 /// have the `PossibleTopLevelTransformOp` trait and not have any operands.
 /// This function internally keeps track of the transformation state.
-LogicalResult
-applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
-                const RaggedArray<MappedValue> &extraMapping = {},
-                const TransformOptions &options = TransformOptions(),
-                bool enforceToplevelTransformOp = true);
+/// If `stateInitializer` is provided, it is invoked on the newly created
+/// transform state before `transform` is applied, e.g. to attach state
+/// extensions carrying information from the caller. If `stateExporter` is
+/// provided, it is invoked only if applying `transform` succeeded, e.g. to
+/// read results back from state extensions; its result is then returned.
+LogicalResult applyTransforms(
+    Operation *payloadRoot, TransformOpInterface transform,
+    const RaggedArray<MappedValue> &extraMapping = {},
+    const TransformOptions &options = TransformOptions(),
+    bool enforceToplevelTransformOp = true,
+    function_ref<void(TransformState &)> stateInitializer = nullptr,
+    function_ref<LogicalResult(TransformState &)> stateExporter = nullptr);
 
 /// The state maintained across applications of various ops implementing the
 /// TransformOpInterface. The operations implementing this interface and the
@@ -215,9 +222,11 @@ class TransformState {
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
   };
 
-  friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
-                                       const RaggedArray<MappedValue> &,
-                                       const TransformOptions &, bool);
+  friend LogicalResult
+  applyTransforms(Operation *, TransformOpInterface,
+                  const RaggedArray<MappedValue> &, const TransformOptions &,
+                  bool, function_ref<void(TransformState &)>,
+                  function_ref<LogicalResult(TransformState &)>);
 
   friend TransformState
   detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index f8f85e4615c500..5bc6d4ee5033f1 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1999,7 +1999,9 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
 LogicalResult transform::applyTransforms(
     Operation *payloadRoot, TransformOpInterface transform,
     const RaggedArray<MappedValue> &extraMapping,
-    const TransformOptions &options, bool enforceToplevelTransformOp) {
+    const TransformOptions &options, bool enforceToplevelTransformOp,
+    function_ref<void(TransformState &)> stateInitializer,
+    function_ref<LogicalResult(TransformState &)> stateExporter) {
   if (enforceToplevelTransformOp) {
     if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
         transform->getNumOperands() != 0) {
@@ -2013,7 +2015,16 @@ LogicalResult transform::applyTransforms(
 
   TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
                        options);
-  return state.applyTransform(transform).checkAndReport();
+  // Give the caller a chance to seed the state, e.g. by attaching extensions,
+  // before any transform runs.
+  if (stateInitializer)
+    stateInitializer(state);
+  if (failed(state.applyTransform(transform).checkAndReport()))
+    return failure();
+  // Export only after a successful application.
+  if (stateExporter)
+    return stateExporter(state);
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir b/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir
new file mode 100644
index 00000000000000..9fb4d1b8689164
--- /dev/null
+++ b/mlir/test/Dialect/Transform/transform-state-extension-initializer.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -test-pass-state-extension-communication -verify-diagnostics | FileCheck %s
+// Before: only the entry seeded by the pass through its state initializer.
+// CHECK: Printing opCollection before processing transform ops, size: 1
+// CHECK: PASS-TRANSFORMOP-PASS
+// After: the exported collection holds the seeded entry plus one per op.
+// CHECK: Printing opCollection after processing transform ops, size: 4
+// CHECK: PASS-TRANSFORMOP-PASS transform.test_initializer_extension_A transform.test_initializer_extension_B transform.test_initializer_extension_C
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
+    // expected-remark @below {{Number of currently registered op: 1}}
+    transform.test_initializer_extension "A"
+    // expected-remark @below {{Number of currently registered op: 2}}
+    transform.test_initializer_extension "B"
+    // expected-remark @below {{Number of currently registered op: 3}}
+    transform.test_initializer_extension "C"
+    transform.yield
+  }
+}
diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
index e6ab915a657b6f..ca141d2778ee2d 100644
--- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt
@@ -6,6 +6,7 @@ mlir_tablegen(TestTransformDialectExtensionTypes.cpp.inc -gen-typedef-defs -type
 add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)
 
 add_mlir_library(MLIRTestTransformDialect
+  TestPassStateExtensionCommunication.cpp
   TestTransformDialectExtension.cpp
   TestTransformDialectInterpreter.cpp
   TestTransformStateExtension.cpp
diff --git a/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
new file mode 100644
index 00000000000000..9ec99f70630a82
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestPassStateExtensionCommunication.cpp
@@ -0,0 +1,103 @@
+//===- TestPassStateExtensionCommunication.cpp ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a test pass that showcases how communication can be
+// conducted between a regular mlir pass and transform ops through the
+// transform state extension stateInitializer and stateExporter mechanism.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTransformStateExtension.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::test;
+
+namespace {
+/// Test pass that seeds a TransformStateInitializerExtension with data owned by
+/// the pass, applies the module's top-level transform ops, and reads the
+/// extension's contents back once the transforms succeed.
+struct TestPassStateExtensionCommunication
+    : public PassWrapper<TestPassStateExtensionCommunication,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestPassStateExtensionCommunication)
+
+  StringRef getArgument() const final {
+    return "test-pass-state-extension-communication";
+  }
+
+  StringRef getDescription() const final {
+    return "test state communication between a mlir pass and transform ops";
+  }
+
+  /// Prints the size of `opCollection` followed by its space-separated
+  /// entries; `extraMessage` is appended to the header line.
+  static void printVector(const SmallVector<std::string> &opCollection,
+                          const std::string &extraMessage = {}) {
+    outs() << "Printing opCollection" << extraMessage
+           << ", size: " << opCollection.size() << "\n";
+    for (const std::string &entry : opCollection)
+      outs() << entry << " ";
+    outs() << "\n";
+  }
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    // Create an opCollection vector.
+    SmallVector<std::string> opCollection = {"PASS-TRANSFORMOP-PASS "};
+    printVector(opCollection, " before processing transform ops");
+
+    // Seed the transform state with a copy of the pass-owned collection
+    // unless an extension of this kind is already attached.
+    auto stateInitializer =
+        [&opCollection](mlir::transform::TransformState &state) -> void {
+      if (!state.getExtension<TransformStateInitializerExtension>())
+        state.addExtension<TransformStateInitializerExtension>(0, opCollection);
+    };
+
+    // Copy the collection, including entries added by transform ops, back
+    // into the pass.
+    auto stateExporter =
+        [&opCollection](
+            mlir::transform::TransformState &state) -> LogicalResult {
+      TransformStateInitializerExtension *ext =
+          state.getExtension<TransformStateInitializerExtension>();
+      if (!ext) {
+        errs() << "Target transform state extension not found!\n";
+        return failure();
+      }
+      opCollection = ext->getRegisteredOps();
+      return success();
+    };
+
+    // Process transform ops with stateInitializer and stateExporter.
+    for (auto op : module.getBody()->getOps<transform::TransformOpInterface>())
+      if (failed(transform::applyTransforms(
+              module, op, {}, mlir::transform::TransformOptions(), false,
+              stateInitializer, stateExporter)))
+        return signalPassFailure();
+
+    // Print the opCollection vector after processing transform ops.
+    printVector(opCollection, " after processing transform ops");
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers the test pass here.
+void registerTestPassStateExtensionCommunication() {
+  PassRegistration<TestPassStateExtensionCommunication> reg;
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index c023aad4a3ee77..a0a7afce66d9a1 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -804,6 +804,29 @@ void mlir::test::TestProduceInvalidIR::getEffects(
   transform::modifiesPayload(effects);
 }
 
+/// Registers this op in the TransformStateInitializerExtension of `state`:
+/// creates the extension if it is missing, otherwise bumps its op counter,
+/// appends this op's name and reports the updated contents as a remark.
+DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  std::string opName = getOperationName().str() + "_" + getTypeAttr().str();
+  TransformStateInitializerExtension *initExt =
+      state.getExtension<TransformStateInitializerExtension>();
+  if (!initExt) {
+    emitRemark() << "\nSpecified extension not found, adding a new one!\n";
+    SmallVector<std::string> opCollection = {opName};
+    state.addExtension<TransformStateInitializerExtension>(1, opCollection);
+  } else {
+    initExt->setNumOp(initExt->getNumOp() + 1);
+    initExt->pushRegisteredOps(opName);
+    emitRemark() << "Number of currently registered op: "
+                 << initExt->getNumOp() << "\n"
+                 << initExt->printMessage() << "\n";
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 namespace {
 /// Test conversion pattern that replaces ops with the "replace_with_new_op"
 /// attribute with "test.new_op".
 namespace {
 /// Test conversion pattern that replaces ops with the "replace_with_new_op"
 /// attribute with "test.new_op".
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 4f2cf34f7d3347..76375dba369448 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -549,4 +549,19 @@ def TestProduceInvalidIR
   }];
 }
 
+def TestInitializerExtensionOp
+  : Op<Transform_Dialect, "test_initializer_extension",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        NoMemoryEffect]> {
+  let summary = "Records its name in a TransformStateInitializerExtension";
+  let description = [{
+    Appends `<op name>_<type>` to the TransformStateInitializerExtension of the
+    transform state, creating the extension if it is missing, and emits a
+    remark describing what was done.
+  }];
+  let arguments = (ins StrAttr:$type);
+  let assemblyFormat = "$type attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
index 0bfa6bed015c0f..54d00c70cfd8d8 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
+++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
@@ -34,6 +34,44 @@ class TestTransformStateExtension
 private:
   StringAttr message;
 };
+
+/// Test transform state extension holding an op counter and a list of
+/// registered op names. A pass can seed it through the `stateInitializer` of
+/// `transform::applyTransforms` and read it back through the `stateExporter`.
+class TransformStateInitializerExtension
+    : public transform::TransformState::Extension {
+public:
+  TransformStateInitializerExtension(transform::TransformState &state,
+                                     int numOp,
+                                     ArrayRef<std::string> registeredOps)
+      : Extension(state), numOp(numOp),
+        registeredOps(registeredOps.begin(), registeredOps.end()) {}
+
+  int getNumOp() const { return numOp; }
+  void setNumOp(int num) { numOp = num; }
+  const SmallVector<std::string> &getRegisteredOps() const {
+    return registeredOps;
+  }
+  void pushRegisteredOps(const std::string &newOp) {
+    registeredOps.push_back(newOp);
+  }
+  /// Returns "Registered transformOps are: " followed by each registered op
+  /// name terminated by " | ".
+  std::string printMessage() const {
+    std::string message = "Registered transformOps are: ";
+    for (const std::string &op : registeredOps)
+      message.append(op).append(" | ");
+    return message;
+  }
+
+private:
+  /// Counter maintained by the users of the extension; it is not derived from
+  /// `registeredOps` and may differ from its size.
+  int numOp;
+  /// Names of the registered ops, in registration order.
+  SmallVector<std::string> registeredOps;
+};
+
 } // namespace test
 } // namespace mlir
 
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 1842fa158e75a9..36b142484bb04a 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -148,6 +148,7 @@ void registerTestTensorCopyInsertionPass();
 void registerTestTensorTransforms();
 void registerTestTopologicalSortAnalysisPass();
 void registerTestTransformDialectEraseSchedulePass();
+void registerTestPassStateExtensionCommunication();
 void registerTestVectorLowerings();
 void registerTestVectorReductionToSPIRVDotProd();
 void registerTestWrittenToPass();
@@ -283,6 +284,7 @@ void registerTestPasses() {
   mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestTopologicalSortAnalysisPass();
   mlir::test::registerTestTransformDialectEraseSchedulePass();
+  mlir::test::registerTestPassStateExtensionCommunication();
   mlir::test::registerTestVectorLowerings();
   mlir::test::registerTestVectorReductionToSPIRVDotProd();
   mlir::test::registerTestWrittenToPass();

>From 9937bf0c14be5284e45f1b4dd6e0bca94755fd6d Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Sat, 24 Aug 2024 16:20:38 -0700
Subject: [PATCH 2/2] [Mips] Remove a trivial variable (NFC) (#105940)

We assign I->getNumOperands() to J and immediately print that out as a
debug message.  We don't need to keep J across iterations.
---
 llvm/lib/Target/Mips/MipsConstantIslandPass.cpp | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/llvm/lib/Target/Mips/MipsConstantIslandPass.cpp b/llvm/lib/Target/Mips/MipsConstantIslandPass.cpp
index 0341af0caac46e..60bb10369df4fa 100644
--- a/llvm/lib/Target/Mips/MipsConstantIslandPass.cpp
+++ b/llvm/lib/Target/Mips/MipsConstantIslandPass.cpp
@@ -1630,8 +1630,6 @@ MipsConstantIslands::fixupConditionalBr(ImmBranch &Br) {
 }
 
 void MipsConstantIslands::prescanForConstants() {
-  unsigned J = 0;
-  (void)J;
   for (MachineBasicBlock &B : *MF) {
     for (MachineBasicBlock::instr_iterator I = B.instr_begin(),
                                            EB = B.instr_end();
@@ -1640,8 +1638,7 @@ void MipsConstantIslands::prescanForConstants() {
         case Mips::LwConstant32: {
           PrescannedForConstants = true;
           LLVM_DEBUG(dbgs() << "constant island constant " << *I << "\n");
-          J = I->getNumOperands();
-          LLVM_DEBUG(dbgs() << "num operands " << J << "\n");
+          LLVM_DEBUG(dbgs() << "num operands " << I->getNumOperands() << "\n");
           MachineOperand& Literal = I->getOperand(1);
           if (Literal.isImm()) {
             int64_t V = Literal.getImm();



More information about the Mlir-commits mailing list