[Mlir-commits] [mlir] bd1ccfe - [mlir] Add a new RewritePattern::hasBoundedRewriteRecursion hook.

River Riddle llvmlistbot at llvm.org
Thu Apr 9 12:42:38 PDT 2020


Author: River Riddle
Date: 2020-04-09T12:42:28-07:00
New Revision: bd1ccfe6df24203c494685c82b83124303d99ce0

URL: https://github.com/llvm/llvm-project/commit/bd1ccfe6df24203c494685c82b83124303d99ce0
DIFF: https://github.com/llvm/llvm-project/commit/bd1ccfe6df24203c494685c82b83124303d99ce0.diff

LOG: [mlir] Add a new RewritePattern::hasBoundedRewriteRecursion hook.

Summary: Some pattern rewriters, like dialect conversion, prohibit the unbounded recursion (or reapplication) of patterns on generated IR. Most patterns are not written with recursive application in mind, so they will generally explode the stack if uncaught. This revision adds a hook to RewritePattern, `hasBoundedRewriteRecursion`, to signal that the pattern can safely be applied to the generated IR of a previous application of the same pattern. This establishes a contract between the pattern and the rewriter: the pattern knows about, and can handle, potential recursive application.

Differential Revision: https://reviews.llvm.org/D77782

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index ef3c9fa62aa8..457a4b11d816 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -131,6 +131,12 @@ class RewritePattern : public Pattern {
     return failure();
   }
 
+  /// Returns true if this pattern is known to result in recursive application,
+  /// i.e. this pattern may generate IR that also matches this pattern, but is
+  /// known to bound the recursion. This signals to a rewriter that it is safe
+  /// to apply this pattern recursively to generated IR.
+  virtual bool hasBoundedRewriteRecursion() const { return false; }
+
   /// Return a list of operations that may be generated when rewriting an
   /// operation instance with this pattern.
   ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 2fc9b223a57c..eb4bf3b6d0ef 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -789,23 +789,10 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
         Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
         // 3. Reduce the problem to lowering a new InsertStridedSlice op with
         // smaller rank.
-        InsertStridedSliceOp insertStridedSliceOp =
-            rewriter.create<InsertStridedSliceOp>(
-                loc, extractedSource, extractedDest,
-                getI64SubArray(op.offsets(), /* dropFront=*/1),
-                getI64SubArray(op.strides(), /* dropFront=*/1));
-        // Call matchAndRewrite recursively from within the pattern. This
-        // circumvents the current limitation that a given pattern cannot
-        // be called multiple times by the PatternRewrite infrastructure (to
-        // avoid infinite recursion, but in this case, infinite recursion
-        // cannot happen because the rank is strictly decreasing).
-        // TODO(rriddle, nicolasvasilache) Implement something like a hook for
-        // a potential function that must decrease and allow the same pattern
-        // multiple times.
-        auto success = matchAndRewrite(insertStridedSliceOp, rewriter);
-        (void)success;
-        assert(succeeded(success) && "Unexpected failure");
-        extractedSource = insertStridedSliceOp;
+        extractedSource = rewriter.create<InsertStridedSliceOp>(
+            loc, extractedSource, extractedDest,
+            getI64SubArray(op.offsets(), /* dropFront=*/1),
+            getI64SubArray(op.strides(), /* dropFront=*/1));
       }
       // 4. Insert the extractedSource into the res vector.
       res = insertOne(rewriter, loc, extractedSource, res, off);
@@ -814,6 +801,9 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
     rewriter.replaceOp(op, res);
     return success();
   }
+  /// This pattern creates recursive InsertStridedSliceOp, but the recursion is
+  /// bounded as the rank is strictly decreasing.
+  bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
@@ -1068,28 +1058,19 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
          off += stride, ++idx) {
       Value extracted = extractOne(rewriter, loc, op.vector(), off);
       if (op.offsets().getValue().size() > 1) {
-        StridedSliceOp stridedSliceOp = rewriter.create<StridedSliceOp>(
+        extracted = rewriter.create<StridedSliceOp>(
             loc, extracted, getI64SubArray(op.offsets(), /* dropFront=*/1),
             getI64SubArray(op.sizes(), /* dropFront=*/1),
             getI64SubArray(op.strides(), /* dropFront=*/1));
-        // Call matchAndRewrite recursively from within the pattern. This
-        // circumvents the current limitation that a given pattern cannot
-        // be called multiple times by the PatternRewrite infrastructure (to
-        // avoid infinite recursion, but in this case, infinite recursion
-        // cannot happen because the rank is strictly decreasing).
-        // TODO(rriddle, nicolasvasilache) Implement something like a hook for
-        // a potential function that must decrease and allow the same pattern
-        // multiple times.
-        auto success = matchAndRewrite(stridedSliceOp, rewriter);
-        (void)success;
-        assert(succeeded(success) && "Unexpected failure");
-        extracted = stridedSliceOp;
       }
       res = insertOne(rewriter, loc, extracted, res, idx);
     }
     rewriter.replaceOp(op, {res});
     return success();
   }
+  /// This pattern creates recursive StridedSliceOp, but the recursion is
+  /// bounded as the rank is strictly decreasing.
+  bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
 } // namespace

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index ab6924f091ad..e153e5c2f24a 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1256,10 +1256,9 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
   });
 
   // Ensure that we don't cycle by not allowing the same pattern to be
-  // applied twice in the same recursion stack.
-  // TODO(riverriddle) We could eventually converge, but that requires more
-  // complicated analysis.
-  if (!appliedPatterns.insert(pattern).second) {
+  // applied twice in the same recursion stack if it is not known to be safe.
+  if (!pattern->hasBoundedRewriteRecursion() &&
+      !appliedPatterns.insert(pattern).second) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern was already applied"));
     return failure();
   }

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 3305e017d5b3..557908d2b1a4 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -143,6 +143,13 @@ func @create_block() {
   return
 }
 
+// CHECK-LABEL: @bounded_recursion
+func @bounded_recursion() {
+  // CHECK: test.recursive_rewrite 0
+  test.recursive_rewrite 3
+  return
+}
+
 // -----
 
 func @fail_to_convert_illegal_op() -> i32 {

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 8859d50342af..8eedd1ff6bb8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1061,6 +1061,12 @@ def TestRewriteOp : TEST_Op<"rewrite">,
   Arguments<(ins AnyType)>, Results<(outs AnyType)>;
 def : Pat<(TestRewriteOp $input), (replaceWithValue $input)>;
 
+// Check that patterns can specify bounded recursion when rewriting.
+def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> {
+  let arguments = (ins I64Attr:$depth);
+  let assemblyFormat = "$depth attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // Test Type Legalization
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 39b3fc1e5f4b..90b34d9fe70f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -360,6 +360,28 @@ struct TestNonRootReplacement : public RewritePattern {
     return success();
   }
 };
+
+//===----------------------------------------------------------------------===//
+// Recursive Rewrite Testing
+/// This pattern is applied to the same operation multiple times, but has a
+/// bounded recursion.
+struct TestBoundedRecursiveRewrite
+    : public OpRewritePattern<TestRecursiveRewriteOp> {
+  using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
+                                PatternRewriter &rewriter) const final {
+    // Decrement the depth of the op in-place.
+    rewriter.updateRootInPlace(op, [&] {
+      op.setAttr("depth",
+                 rewriter.getI64IntegerAttr(op.depth().getSExtValue() - 1));
+    });
+    return success();
+  }
+
+  /// The conversion target handles bounding the recursion of this pattern.
+  bool hasBoundedRewriteRecursion() const final { return true; }
+};
 } // namespace
 
 namespace {
@@ -414,7 +436,7 @@ struct TestLegalizePatternDriver
         TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
         TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
         TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
-        TestNonRootReplacement>(&getContext());
+        TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext());
     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);
@@ -449,6 +471,10 @@ struct TestLegalizePatternDriver
           op->getAttrOfType<UnitAttr>("test.recursively_legal"));
     });
 
+    // Mark the bound recursion operation as dynamically legal.
+    target.addDynamicallyLegalOp<TestRecursiveRewriteOp>(
+        [](TestRecursiveRewriteOp op) { return op.depth() == 0; });
+
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       (void)applyPartialConversion(getOperation(), target, patterns,


        


More information about the Mlir-commits mailing list