[Mlir-commits] [mlir] e4635e6 - [mlir][FoldUtils] Ensure the created constant dominates the replaced op

River Riddle llvmlistbot at llvm.org
Mon Aug 23 11:52:40 PDT 2021


Author: River Riddle
Date: 2021-08-23T18:48:24Z
New Revision: e4635e6328c88f1f403b316bd8340fe03f9835a8

URL: https://github.com/llvm/llvm-project/commit/e4635e6328c88f1f403b316bd8340fe03f9835a8
DIFF: https://github.com/llvm/llvm-project/commit/e4635e6328c88f1f403b316bd8340fe03f9835a8.diff

LOG: [mlir][FoldUtils] Ensure the created constant dominates the replaced op

This revision fixes a bug where an operation would get replaced with
a pre-existing constant that didn't dominate it. This can occur when
a pattern inserts operations to be folded at the beginning of the
constant insertion block. This revision fixes the bug by moving the
existing constant before the replaced operation in such cases. This is
fine because if a constant didn't already exist, a new one would have
been inserted before this operation anyway.

Differential Revision: https://reviews.llvm.org/D108498

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/test/Transforms/test-operation-folder.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index af415d5fc5228..f9e1bb45de567 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -233,6 +233,14 @@ LogicalResult OperationFolder::tryToFold(
     if (auto *constOp =
             tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
                                    res.getType(), op->getLoc())) {
+      // Ensure that this constant dominates the operation that it replaces.
+      // This may not automatically happen if the operation being folded was
+      // inserted before the constant within the insertion block.
+      if (constOp->getBlock() == op->getBlock() &&
+          !constOp->isBeforeInBlock(op)) {
+        constOp->moveBefore(op);
+      }
+
       results.push_back(constOp->getResult(0));
       continue;
     }

diff  --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir
index 7ef65303aa81a..f5e86ed4b9d8b 100644
--- a/mlir/test/Transforms/test-operation-folder.mlir
+++ b/mlir/test/Transforms/test-operation-folder.mlir
@@ -10,3 +10,15 @@ func @foo() -> i32 {
   %0 = "test.op_in_place_fold_anchor"(%c42) : (i32) -> (i32)
   return %0 : i32
 }
+
+func @test_fold_before_previously_folded_op() -> (i32, i32) {
+  // When folding, two constants will be generated and uniqued. Check that the
+  // uniqued constant properly dominates both uses.
+  // CHECK: %[[CST:.+]] = constant true
+  // CHECK-NEXT: "test.cast"(%[[CST]]) : (i1) -> i32
+  // CHECK-NEXT: "test.cast"(%[[CST]]) : (i1) -> i32
+
+  %0 = "test.cast"() {test_fold_before_previously_folded_op} : () -> (i32)
+  %1 = "test.cast"() {test_fold_before_previously_folded_op} : () -> (i32)
+  return %0, %1 : i32, i32
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 62bed7e0bba2c..de141dc72e981 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -99,6 +99,29 @@ struct FoldingPattern : public RewritePattern {
   }
 };
 
+/// This pattern creates a foldable operation at the entry point of the block.
+/// This tests the situation where the operation folder will need to replace an
+/// operation with a previously created constant that does not initially
+/// dominate the operation to replace.
+struct FolderInsertBeforePreviouslyFoldedConstantPattern
+    : public OpRewritePattern<TestCastOp> {
+public:
+  using OpRewritePattern<TestCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TestCastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op->hasAttr("test_fold_before_previously_folded_op"))
+      return failure();
+    rewriter.setInsertionPointToStart(op->getBlock());
+
+    auto constOp =
+        rewriter.create<ConstantOp>(op.getLoc(), rewriter.getBoolAttr(true));
+    rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(),
+                                            Value(constOp));
+    return success();
+  }
+};
+
 struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
   StringRef getArgument() const final { return "test-patterns"; }
   StringRef getDescription() const final { return "Run test dialect patterns"; }
@@ -107,7 +130,9 @@ struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
     populateWithGenerated(patterns);
 
     // Verify named pattern is generated with expected name.
-    patterns.add<FoldingPattern, TestNamedPatternRule>(&getContext());
+    patterns.add<FoldingPattern, TestNamedPatternRule,
+                 FolderInsertBeforePreviouslyFoldedConstantPattern>(
+        &getContext());
 
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }


        


More information about the Mlir-commits mailing list