[llvm-branch-commits] [mlir] [mlir][Transforms] Support rolling back properties in dialect conversion (PR #82474)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Feb 21 01:41:00 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/82474

>From 1795b6dcfdc0b6447c5209000332432de49109cc Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 21 Feb 2024 09:37:48 +0000
Subject: [PATCH] [mlir][Transforms] Support rolling back properties in dialect
 conversion

The dialect conversion rolls back in-place op modifications upon failure. Before this commit, modifications of op properties were not rolled back.
---
 .../Transforms/Utils/DialectConversion.cpp    | 31 ++++++++++++++++++-
 mlir/test/Transforms/test-legalizer.mlir      | 12 +++++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 18 ++++++++++-
 3 files changed, 59 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 673bd0383809cb..75094a41012e8d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1002,12 +1002,33 @@ class ModifyOperationRewrite : public OperationRewrite {
       : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
         loc(op->getLoc()), attrs(op->getAttrDictionary()),
         operands(op->operand_begin(), op->operand_end()),
-        successors(op->successor_begin(), op->successor_end()) {}
+        successors(op->successor_begin(), op->successor_end()) {
+    if (OpaqueProperties prop = op->getPropertiesStorage()) {
+      // Make a copy of the properties.
+      propertiesStorage = operator new(op->getPropertiesStorageSize());
+      OpaqueProperties propCopy(propertiesStorage);
+      op->getName().initOpProperties(propCopy, /*init=*/prop);
+    }
+  }
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::ModifyOperation;
   }
 
+  ~ModifyOperationRewrite() override {
+    assert(!propertiesStorage &&
+           "rewrite was neither committed nor rolled back");
+  }
+
+  void commit() override {
+    if (propertiesStorage) {
+      OpaqueProperties propCopy(propertiesStorage);
+      op->getName().destroyOpProperties(propCopy);
+      operator delete(propertiesStorage);
+      propertiesStorage = nullptr;
+    }
+  }
+
   /// Discard the transaction state and reset the state of the original
   /// operation.
   void rollback() override {
@@ -1016,6 +1037,13 @@ class ModifyOperationRewrite : public OperationRewrite {
     op->setOperands(operands);
     for (const auto &it : llvm::enumerate(successors))
       op->setSuccessor(it.value(), it.index());
+    if (propertiesStorage) {
+      OpaqueProperties propCopy(propertiesStorage);
+      op->copyProperties(propCopy);
+      op->getName().destroyOpProperties(propCopy);
+      operator delete(propertiesStorage);
+      propertiesStorage = nullptr;
+    }
   }
 
 private:
@@ -1023,6 +1051,7 @@ class ModifyOperationRewrite : public OperationRewrite {
   DictionaryAttr attrs;
   SmallVector<Value, 8> operands;
   SmallVector<Block *, 2> successors;
+  void *propertiesStorage = nullptr;
 };
 } // namespace
 
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 84fcc18ab7d370..62d776cd7573ee 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -334,3 +334,15 @@ func.func @test_move_op_before_rollback() {
   }) : () -> ()
   "test.return"() : () -> ()
 }
+
+// -----
+
+// CHECK-LABEL: func @test_properties_rollback()
+func.func @test_properties_rollback() {
+  // CHECK: test.with_properties <{a = 32 : i64,
+  // expected-remark @below{{op 'test.with_properties' is not legalizable}}
+  test.with_properties
+      <{a = 32 : i64, array = array<i64: 1, 2, 3, 4>, b = "foo"}>
+      {modify_inplace}
+  "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 1c02232b8adbb1..57e846294f8b9f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -806,6 +806,21 @@ struct TestUndoBlockErase : public ConversionPattern {
   }
 };
 
+/// A pattern that modifies a property in-place, but keeps the op illegal.
+struct TestUndoPropertiesModification : public ConversionPattern {
+  TestUndoPropertiesModification(MLIRContext *ctx)
+      : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (!op->hasAttr("modify_inplace"))
+      return failure();
+    rewriter.modifyOpInPlace(
+        op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); });
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Type-Conversion Rewrite Testing
 
@@ -1085,7 +1100,8 @@ struct TestLegalizePatternDriver
              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
              TestNonRootReplacement, TestBoundedRecursiveRewrite,
              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
-             TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext());
+             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+             TestUndoPropertiesModification>(&getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);



More information about the llvm-branch-commits mailing list