[Mlir-commits] [mlir] 6c57b0d - [mlir] improve and test TransformState::Extension

Alex Zinenko llvmlistbot at llvm.org
Tue May 3 02:33:08 PDT 2022


Author: Alex Zinenko
Date: 2022-05-03T11:33:00+02:00
New Revision: 6c57b0debedaa5f211a39a8d6765ad4c74db4059

URL: https://github.com/llvm/llvm-project/commit/6c57b0debedaa5f211a39a8d6765ad4c74db4059
DIFF: https://github.com/llvm/llvm-project/commit/6c57b0debedaa5f211a39a8d6765ad4c74db4059.diff

LOG: [mlir] improve and test TransformState::Extension

Add the mechanism for TransformState extensions to update the mapping between
Transform IR values and Payload IR operations held by the state. The mechanism
is intentionally restrictive, similarly to how results of the transform op are
handled.

Introduce test ops that exercise a simple extension that maintains information
across the application of multiple transform ops.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D124778

Added: 
    mlir/test/Dialect/Transform/transform-state-extension.mlir
    mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
    mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
    mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index b4c8ada4e643b..d029c214b49e2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -74,6 +74,10 @@ class TransformState {
   /// This is helpful for transformations that apply to a particular handle.
   ArrayRef<Operation *> getPayloadOps(Value value) const;
 
+  /// Returns the Transform IR handle for the given Payload IR op if it exists
+  /// in the state, null otherwise.
+  Value getHandleForPayloadOp(Operation *op) const;
+
   /// Applies the transformation specified by the given transform op and updates
   /// the state accordingly.
   LogicalResult applyTransform(TransformOpInterface transform);
@@ -185,6 +189,10 @@ class TransformState {
     /// Provides read-only access to the parent TransformState object.
     const TransformState &getTransformState() const { return state; }
 
+    /// Replaces the given payload op with another op. If the replacement op is
+    /// null, removes the association of the payload op with its handle.
+    LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
+
   private:
     /// Back-reference to the state that is being extended.
     TransformState &state;
@@ -276,9 +284,17 @@ class TransformState {
   /// The callback function is called once per associated operation and is
   /// expected to return the modified operation or nullptr. In the latter case,
   /// the corresponding operation is no longer associated with the transform IR
-  /// value.
-  void updatePayloadOps(Value value,
-                        function_ref<Operation *(Operation *)> callback);
+  /// value. May fail if the operation produced by the update callback is
+  /// already associated with a different Transform IR handle value.
+  LogicalResult
+  updatePayloadOps(Value value,
+                   function_ref<Operation *(Operation *)> callback);
+
+  /// Attempts to record the mapping between the given Payload IR operation and
+  /// the given Transform IR handle. Fails and reports an error if the operation
+  /// is already tracked by another handle.
+  static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op,
+                                                Value handle);
 
   /// The mappings between transform IR values and payload IR ops, aggregated by
   /// the region in which the transform IR values are defined.

diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index c96f573363be0..3e11b6794bbde 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -41,6 +41,27 @@ transform::TransformState::getPayloadOps(Value value) const {
   return iter->getSecond();
 }
 
+Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
+  for (const Mappings &mapping : llvm::make_second_range(mappings)) {
+    if (Value handle = mapping.reverse.lookup(op))
+      return handle;
+  }
+  return Value();
+}
+
+LogicalResult transform::TransformState::tryEmplaceReverseMapping(
+    Mappings &map, Operation *operation, Value handle) {
+  auto insertionResult = map.reverse.insert({operation, handle});
+  if (!insertionResult.second) {
+    InFlightDiagnostic diag = operation->emitError()
+                              << "operation tracked by two handles";
+    diag.attachNote(handle.getLoc()) << "handle";
+    diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
+    return diag;
+  }
+  return success();
+}
+
 LogicalResult
 transform::TransformState::setPayloadOps(Value value,
                                          ArrayRef<Operation *> targets) {
@@ -63,14 +84,8 @@ transform::TransformState::setPayloadOps(Value value,
   // expressed using the dialect and may be constructed by valid API calls from
   // valid IR. Emit an error here.
   for (Operation *op : targets) {
-    auto insertionResult = mappings.reverse.insert({op, value});
-    if (!insertionResult.second) {
-      InFlightDiagnostic diag = op->emitError()
-                                << "operation tracked by two handles";
-      diag.attachNote(value.getLoc()) << "handle";
-      diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
-      return diag;
-    }
+    if (failed(tryEmplaceReverseMapping(mappings, op, value)))
+      return failure();
   }
 
   return success();
@@ -83,19 +98,26 @@ void transform::TransformState::removePayloadOps(Value value) {
   mappings.direct.erase(value);
 }
 
-void transform::TransformState::updatePayloadOps(
+LogicalResult transform::TransformState::updatePayloadOps(
     Value value, function_ref<Operation *(Operation *)> callback) {
-  auto it = getMapping(value).direct.find(value);
-  assert(it != getMapping(value).direct.end() && "unknown handle");
+  Mappings &mappings = getMapping(value);
+  auto it = mappings.direct.find(value);
+  assert(it != mappings.direct.end() && "unknown handle");
   SmallVector<Operation *> &association = it->getSecond();
   SmallVector<Operation *> updated;
   updated.reserve(association.size());
 
-  for (Operation *op : association)
-    if (Operation *updatedOp = callback(op))
+  for (Operation *op : association) {
+    mappings.reverse.erase(op);
+    if (Operation *updatedOp = callback(op)) {
       updated.push_back(updatedOp);
+      if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
+        return failure();
+    }
+  }
 
   std::swap(association, updated);
+  return success();
 }
 
 LogicalResult
@@ -132,8 +154,21 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TransformState::Extension
+//===----------------------------------------------------------------------===//
+
 transform::TransformState::Extension::~Extension() = default;
 
+LogicalResult
+transform::TransformState::Extension::replacePayloadOp(Operation *op,
+                                                       Operation *replacement) {
+  return state.updatePayloadOps(state.getHandleForPayloadOp(op),
+                                [&](Operation *current) {
+                                  return current == op ? replacement : current;
+                                });
+}
+
 //===----------------------------------------------------------------------===//
 // TransformResults
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir
new file mode 100644
index 0000000000000..c63d16d4c5d40
--- /dev/null
+++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -split-input-file
+
+// expected-note @below {{associated payload op}}
+module {
+  transform.sequence {
+  ^bb0(%arg0: !pdl.operation):
+    // expected-remark @below {{extension absent}}
+    test_check_if_test_extension_present %arg0
+    test_add_test_extension "A"
+    // expected-remark @below {{extension present, A}}
+    test_check_if_test_extension_present %arg0
+    test_remove_test_extension
+    // expected-remark @below {{extension absent}}
+    test_check_if_test_extension_present %arg0
+  }
+}
+
+// -----
+
+// expected-note @below {{associated payload op}}
+module {
+  transform.sequence {
+  ^bb0(%arg0: !pdl.operation):
+    test_add_test_extension "A"
+    test_remove_test_extension
+    test_add_test_extension "B"
+    // expected-remark @below {{extension present, B}}
+    test_check_if_test_extension_present %arg0
+  }
+}
+
+// -----
+
+// expected-note @below {{associated payload op}}
+module {
+  transform.sequence {
+  ^bb0(%arg0: !pdl.operation):
+    test_add_test_extension "A"
+    // expected-remark @below {{extension present, A}}
+    test_check_if_test_extension_present %arg0
+    // expected-note @below {{associated payload op}}
+    test_remap_operand_to_self %arg0
+    // expected-remark @below {{extension present, A}}
+    test_check_if_test_extension_present %arg0
+  }
+}

diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index a58f12d1176ec..c9687fa2d5bd6 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -12,10 +12,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestTransformDialectExtension.h"
+#include "TestTransformStateExtension.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/IR/Builders.h"
 #include "mlir/IR/OpImplementation.h"
 
 using namespace mlir;
@@ -142,6 +142,49 @@ LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply(
   return success();
 }
 
+LogicalResult
+mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
+                                          transform::TransformState &state) {
+  state.addExtension<TestTransformStateExtension>(getMessageAttr());
+  return success();
+}
+
+LogicalResult mlir::test::TestCheckIfTestExtensionPresentOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  auto *extension = state.getExtension<TestTransformStateExtension>();
+  if (!extension) {
+    emitRemark() << "extension absent";
+    return success();
+  }
+
+  InFlightDiagnostic diag = emitRemark()
+                            << "extension present, " << extension->getMessage();
+  for (Operation *payload : state.getPayloadOps(getOperand())) {
+    diag.attachNote(payload->getLoc()) << "associated payload op";
+    assert(state.getHandleForPayloadOp(payload) == getOperand() &&
+           "inconsistent mapping between transform IR handles and payload IR "
+           "operations");
+  }
+
+  return success();
+}
+
+LogicalResult mlir::test::TestRemapOperandPayloadToSelfOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  auto *extension = state.getExtension<TestTransformStateExtension>();
+  if (!extension)
+    return emitError() << "TestTransformStateExtension missing";
+
+  return extension->updateMapping(state.getPayloadOps(getOperand()).front(),
+                                  getOperation());
+}
+
+LogicalResult mlir::test::TestRemoveTestExtensionOp::apply(
+    transform::TransformResults &results, transform::TransformState &state) {
+  state.removeExtension<TestTransformStateExtension>();
+  return success();
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL

diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 6fe34ae064232..d33f7907f27f2 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -56,4 +56,41 @@ def TestPrintRemarkAtOperandOp
   let cppNamespace = "::mlir::test";
 }
 
+def TestAddTestExtensionOp
+  : Op<Transform_Dialect, "test_add_test_extension",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        NoSideEffect]> {
+  let arguments = (ins StrAttr:$message);
+  let assemblyFormat = "$message attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
+def TestCheckIfTestExtensionPresentOp
+  : Op<Transform_Dialect, "test_check_if_test_extension_present",
+       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins
+    Arg<PDL_Operation, "", [TransformMappingRead, PayloadIRRead]>:$operand);
+  let assemblyFormat = "$operand attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
+def TestRemapOperandPayloadToSelfOp
+  : Op<Transform_Dialect, "test_remap_operand_to_self",
+       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let arguments = (ins
+    Arg<PDL_Operation, "",
+        [TransformMappingRead, TransformMappingWrite, PayloadIRRead]>:$operand);
+  let assemblyFormat = "$operand attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
+def TestRemoveTestExtensionOp
+  : Op<Transform_Dialect, "test_remove_test_extension",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        NoSideEffect]> {
+  let assemblyFormat = "attr-dict";
+  let cppNamespace = "::mlir::test";
+}
+
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD

diff --git a/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
new file mode 100644
index 0000000000000..3b2eb7602a7b5
--- /dev/null
+++ b/mlir/test/lib/Dialect/Transform/TestTransformStateExtension.h
@@ -0,0 +1,42 @@
+//===- TestTransformStateExtension.h - Test Utility -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a TransformState extension for the purpose of testing the
+// relevant APIs.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H
+#define MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H
+
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace test {
+class TestTransformStateExtension
+    : public transform::TransformState::Extension {
+public:
+  TestTransformStateExtension(transform::TransformState &state,
+                              StringAttr message)
+      : Extension(state), message(message) {}
+
+  StringRef getMessage() const { return message.getValue(); }
+
+  LogicalResult updateMapping(Operation *previous, Operation *updated) {
+    return replacePayloadOp(previous, updated);
+  }
+
+private:
+  StringAttr message;
+};
+} // namespace test
+} // namespace mlir
+
+#endif // MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H


        


More information about the Mlir-commits mailing list