[Mlir-commits] [mlir] 648f34a - Merge with mainline.

Chris Lattner llvmlistbot at llvm.org
Mon May 17 11:15:16 PDT 2021


Author: Chris Lattner
Date: 2021-05-17T11:15:10-07:00
New Revision: 648f34a2840b75f4081884052f2ccb11f62f8209

URL: https://github.com/llvm/llvm-project/commit/648f34a2840b75f4081884052f2ccb11f62f8209
DIFF: https://github.com/llvm/llvm-project/commit/648f34a2840b75f4081884052f2ccb11f62f8209.diff

LOG: Merge with mainline.

Differential Revision: https://reviews.llvm.org/D102636

Added: 
    mlir/test/Transforms/canonicalize-td.mlir

Modified: 
    mlir/docs/PatternRewriter.md
    mlir/include/mlir/Transforms/FoldUtils.h
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/include/mlir/Transforms/Passes.td
    mlir/lib/Transforms/Canonicalizer.cpp
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index f452230584c0d..de9b881bc0bb5 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -242,8 +242,8 @@ driver can be found [here](DialectConversion.md).
 
 ### Greedy Pattern Rewrite Driver
 
-This driver performs a post order traversal over the provided operations and
-greedily applies the patterns that locally have the most benefit. The benefit of
+This driver walks the provided operations and greedily applies the patterns that
+locally have the most benefit. The benefit of
 a pattern is decided solely by the benefit specified on the pattern, and the
 relative order of the pattern within the pattern list (when two patterns have
 the same local benefit). Patterns are iteratively applied to operations until a
@@ -252,5 +252,12 @@ used via the following: `applyPatternsAndFoldGreedily` and
 `applyOpPatternsAndFold`. The latter of which only applies patterns to the
 provided operation, and will not traverse the IR.
 
+The driver is configurable and supports two modes: 1) you may opt-in to a
+"top-down" traversal, which seeds the worklist with each operation top down and
+in a pre-order over the region tree.  This is generally more efficient in
+compile time.  2) the default is a "bottom up" traversal, which builds the
+initial worklist with a postorder traversal of the region tree.  This may
+match larger patterns with ambiguous pattern sets.
+
 Note: This driver is the one used by the [canonicalization](Canonicalization.md)
 [pass](Passes.md#-canonicalize-canonicalize-operations) in MLIR.

diff  --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 7f4166c12ded9..9f78f374a6fef 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -33,6 +33,11 @@ class OperationFolder {
 public:
   OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
 
+  /// Scan the specified region for constants that can be used in folding,
+  /// moving them to the entry block (or any custom insertion location specified
+  /// by shouldMaterializeInto), and add them to our known-constants table.
+  void processExistingConstants(Region &region);
+
   /// Tries to perform folding on the given `op`, including unifying
   /// deduplicated constants. If successful, replaces `op`'s uses with
   /// folded results, and returns success. `preReplaceAction` is invoked on `op`

diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 3a76fbd3e0b02..b67fc5ffa06b7 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -25,36 +25,45 @@ namespace mlir {
 /// Rewrite the regions of the specified operation, which must be isolated from
 /// above, by repeatedly applying the highest benefit patterns in a greedy
 /// work-list driven manner.
+///
 /// This variant may stop after a predefined number of iterations, see the
 /// alternative below to provide a specific number of iterations before stopping
 /// in absence of convergence.
+///
 /// Return success if the iterative process converged and no more patterns can
 /// be matched in the result operation regions.
+///
 /// Note: This does not apply patterns to the top-level operation itself.
 ///       These methods also perform folding and simple dead-code elimination
 ///       before attempting to match any of the provided patterns.
+///
+/// You may choose the order of initial traversal with the `useTopDownTraversal`
+/// boolean.  When set to true, it walks the operations top-down, which is
+/// generally more efficient in compile time.  When set to false, its initial
+/// traversal of the region tree is post-order, which may match larger patterns
+/// when given an ambiguous pattern set.
 LogicalResult
 applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternSet &patterns);
+                             const FrozenRewritePatternSet &patterns,
+                             bool useTopDownTraversal = false);
 
 /// Rewrite the regions of the specified operation, with a user-provided limit
 /// on iterations to attempt before reaching convergence.
-LogicalResult
-applyPatternsAndFoldGreedily(Operation *op,
-                             const FrozenRewritePatternSet &patterns,
-                             unsigned maxIterations);
+LogicalResult applyPatternsAndFoldGreedily(
+    Operation *op, const FrozenRewritePatternSet &patterns,
+    unsigned maxIterations, bool useTopDownTraversal = false);
 
 /// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const FrozenRewritePatternSet &patterns);
+                             const FrozenRewritePatternSet &patterns,
+                             bool useTopDownTraversal = false);
 
 /// Rewrite the given regions, with a user-provided limit on iterations to
 /// attempt before reaching convergence.
-LogicalResult
-applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                             const FrozenRewritePatternSet &patterns,
-                             unsigned maxIterations);
+LogicalResult applyPatternsAndFoldGreedily(
+    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+    unsigned maxIterations, bool useTopDownTraversal = false);
 
 /// Applies the specified patterns on `op` alone while also trying to fold it,
 /// by selecting the highest benefits patterns in a greedy manner. Returns

diff  --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index b833ab47b4149..34b1c950b674a 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -362,6 +362,11 @@ def Canonicalizer : Pass<"canonicalize"> {
     details.
   }];
   let constructor = "mlir::createCanonicalizerPass()";
+  let options = [
+    Option<"topDownProcessingEnabled", "top-down", "bool",
+           /*default=*/"false",
+           "Seed the worklist in general top-down order">
+  ];
 }
 
 def CSE : Pass<"cse"> {

diff  --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 27a0fe57c04b6..cc3a2170190d2 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -31,7 +31,10 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
     return success();
   }
   void runOnOperation() override {
-    (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns);
+    (void)applyPatternsAndFoldGreedily(
+        getOperation()->getRegions(), patterns,
+        /*maxIterations=*/10, /*useTopDownTraversal=*/
+        topDownProcessingEnabled);
   }
 
   FrozenRewritePatternSet patterns;

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 024ae1892861c..e6faa860bc76d 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -84,6 +84,85 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
 // OperationFolder
 //===----------------------------------------------------------------------===//
 
+/// Scan the specified region for constants that can be used in folding,
+/// moving them to the entry block (or any custom insertion location specified
+/// by shouldMaterializeInto), and add them to our known-constants table.
+void OperationFolder::processExistingConstants(Region &region) {
+  if (region.empty())
+    return;
+
+  // March the constant insertion point forward, moving all constants to the
+  // top of the block, but keeping them in their order of discovery.
+  Region *insertRegion = getInsertionRegion(interfaces, &region.front());
+  auto &uniquedConstants = foldScopes[insertRegion];
+
+  Block &insertBlock = insertRegion->front();
+  Block::iterator constantIterator = insertBlock.begin();
+
+  // Process each constant that we discover in this region.
+  auto processConstant = [&](Operation *op, Attribute value) {
+    assert(op->getNumResults() == 1 && "constants have one result");
+    // Check to see if we already have an instance of this constant.
+    Operation *&constOp = uniquedConstants[std::make_tuple(
+        op->getDialect(), value, op->getResult(0).getType())];
+
+    // If we already have an instance of this constant, CSE/delete this one as
+    // we go.
+    if (constOp) {
+      if (constantIterator == Block::iterator(op))
+        ++constantIterator; // Don't invalidate our iterator when scanning.
+      op->getResult(0).replaceAllUsesWith(constOp->getResult(0));
+      op->erase();
+      return;
+    }
+
+    // Otherwise, remember that we have this constant.
+    constOp = op;
+    referencedDialects[op].push_back(op->getDialect());
+
+    // If the constant isn't already at the insertion point then move it up.
+    if (constantIterator != Block::iterator(op))
+      op->moveBefore(&insertBlock, constantIterator);
+    else
+      ++constantIterator; // It was pointing at the constant.
+  };
+
+  // Collect all the constants for this region of isolation or insertion (as
+  // specified by the shouldMaterializeInto hook).  Collect any subregions of
+  // isolation/constant insertion for subsequent processing.
+  SmallVector<Operation *> insertionSubregionOps;
+  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+    // If this is a constant, process it.
+    Attribute value;
+    if (matchPattern(op, m_Constant(&value))) {
+      processConstant(op, value);
+      // We may have deleted the operation, don't check it for regions.
+      return WalkResult::skip();
+    }
+
+    // If the operation has regions and is isolated, don't recurse into it.
+    if (op->getNumRegions() != 0) {
+      auto hasDifferentInsertRegion = [&](Region &region) {
+        return !region.empty() &&
+               getInsertionRegion(interfaces, &region.front()) != insertRegion;
+      };
+      if (llvm::any_of(op->getRegions(), hasDifferentInsertRegion)) {
+        insertionSubregionOps.push_back(op);
+        return WalkResult::skip();
+      }
+    }
+
+    // Otherwise keep going.
+    return WalkResult::advance();
+  });
+
+  // Process regions in any isolated ops separately.
+  for (Operation *subregionOps : insertionSubregionOps) {
+    for (Region &region : subregionOps->getRegions())
+      processExistingConstants(region);
+  }
+}
+
 LogicalResult OperationFolder::tryToFold(
     Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
     function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
@@ -262,19 +341,19 @@ Operation *OperationFolder::tryGetOrCreateConstant(
     Attribute value, Type type, Location loc) {
   // Check if an existing mapping already exists.
   auto constKey = std::make_tuple(dialect, value, type);
-  auto *&constInst = uniquedConstants[constKey];
-  if (constInst)
-    return constInst;
+  Operation *&constOp = uniquedConstants[constKey];
+  if (constOp)
+    return constOp;
 
   // If one doesn't exist, try to materialize one.
-  if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
+  if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
     return nullptr;
 
   // Check to see if the generated constant is in the expected dialect.
-  auto *newDialect = constInst->getDialect();
+  auto *newDialect = constOp->getDialect();
   if (newDialect == dialect) {
-    referencedDialects[constInst].push_back(dialect);
-    return constInst;
+    referencedDialects[constOp].push_back(dialect);
+    return constOp;
   }
 
   // If it isn't, then we also need to make sure that the mapping for the new
@@ -284,13 +363,13 @@ Operation *OperationFolder::tryGetOrCreateConstant(
   // If an existing operation in the new dialect already exists, delete the
   // materialized operation in favor of the existing one.
   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
-    constInst->erase();
+    constOp->erase();
     referencedDialects[existingOp].push_back(dialect);
-    return constInst = existingOp;
+    return constOp = existingOp;
   }
 
   // Otherwise, update the new dialect to the materialized operation.
-  referencedDialects[constInst].assign({dialect, newDialect});
-  auto newIt = uniquedConstants.insert({newKey, constInst});
+  referencedDialects[constOp].assign({dialect, newDialect});
+  auto newIt = uniquedConstants.insert({newKey, constOp});
   return newIt.first->second;
 }

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index c82076bde0820..1d0a79d054060 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -37,8 +37,10 @@ namespace {
 class GreedyPatternRewriteDriver : public PatternRewriter {
 public:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
-                                      const FrozenRewritePatternSet &patterns)
-      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
+                                      const FrozenRewritePatternSet &patterns,
+                                      bool useTopDownTraversal)
+      : PatternRewriter(ctx), matcher(patterns), folder(ctx),
+        useTopDownTraversal(useTopDownTraversal) {
     worklist.reserve(64);
 
     // Apply a simple cost model based solely on pattern benefit.
@@ -134,6 +136,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 
   /// Non-pattern based folder for operations.
   OperationFolder folder;
+
+  /// Whether to use a top-down or bottom-up traversal to seed the initial
+  /// worklist.
+  bool useTopDownTraversal;
 };
 } // end anonymous namespace
 
@@ -141,15 +147,36 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 /// if the rewrite converges in `maxIterations`.
 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
                                           int maxIterations) {
-  // Add the given operation to the worklist.
-  auto collectOps = [this](Operation *op) { addToWorklist(op); };
+  // For maximum compatibility with existing passes, do not process existing
+  // constants unless we're performing a top-down traversal.
+  // TODO: This is just for compatibility with older MLIR, remove this.
+  if (useTopDownTraversal) {
+    // Perform a prepass over the IR to discover constants.
+    for (auto &region : regions)
+      folder.processExistingConstants(region);
+  }
 
   bool changed = false;
-  int i = 0;
+  int iteration = 0;
   do {
-    // Add all nested operations to the worklist.
+    worklist.clear();
+    worklistMap.clear();
+
+    // Add all nested operations to the worklist in preorder.
     for (auto &region : regions)
-      region.walk(collectOps);
+      if (useTopDownTraversal)
+        region.walk<WalkOrder::PreOrder>(
+            [this](Operation *op) { worklist.push_back(op); });
+      else
+        region.walk([this](Operation *op) { addToWorklist(op); });
+
+    if (useTopDownTraversal) {
+      // Reverse the list so our pop-back loop processes them in-order.
+      std::reverse(worklist.begin(), worklist.end());
+      // Remember the reverse index.
+      for (size_t i = 0, e = worklist.size(); i != e; ++i)
+        worklistMap[worklist[i]] = i;
+    }
 
     // These are scratch vectors used in the folding loop below.
     SmallVector<Value, 8> originalOperands, resultValues;
@@ -187,6 +214,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
         notifyOperationRemoved(op);
       };
 
+      // Add the given operation to the worklist.
+      auto collectOps = [this](Operation *op) { addToWorklist(op); };
+
       // Try to fold this op.
       bool inPlaceUpdate;
       if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
@@ -197,14 +227,16 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
       }
 
       // Try to match one of the patterns. The rewriter is automatically
-      // notified of any necessary changes, so there is nothing else to do here.
+      // notified of any necessary changes, so there is nothing else to do
+      // here.
       changed |= succeeded(matcher.matchAndRewrite(op, *this));
     }
 
-    // After applying patterns, make sure that the CFG of each of the regions is
-    // kept up to date.
+    // After applying patterns, make sure that the CFG of each of the regions
+    // is kept up to date.
     changed |= succeeded(simplifyRegions(*this, regions));
-  } while (changed && ++i < maxIterations);
+  } while (changed && ++iteration < maxIterations);
+
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return !changed;
 }
@@ -216,28 +248,29 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
 /// top-level operation itself.
 ///
 LogicalResult
-mlir::applyPatternsAndFoldGreedily(Operation *op,
-                                   const FrozenRewritePatternSet &patterns) {
-  return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations);
-}
-LogicalResult
 mlir::applyPatternsAndFoldGreedily(Operation *op,
                                    const FrozenRewritePatternSet &patterns,
-                                   unsigned maxIterations) {
-  return applyPatternsAndFoldGreedily(op->getRegions(), patterns,
-                                      maxIterations);
+                                   bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations,
+                                      useTopDownTraversal);
 }
-/// Rewrite the given regions, which must be isolated from above.
-LogicalResult
-mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                                   const FrozenRewritePatternSet &patterns) {
-  return applyPatternsAndFoldGreedily(regions, patterns,
-                                      maxPatternMatchIterations);
+LogicalResult mlir::applyPatternsAndFoldGreedily(
+    Operation *op, const FrozenRewritePatternSet &patterns,
+    unsigned maxIterations, bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations,
+                                      useTopDownTraversal);
 }
+/// Rewrite the given regions, which must be isolated from above.
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
                                    const FrozenRewritePatternSet &patterns,
-                                   unsigned maxIterations) {
+                                   bool useTopDownTraversal) {
+  return applyPatternsAndFoldGreedily(
+      regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
+}
+LogicalResult mlir::applyPatternsAndFoldGreedily(
+    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+    unsigned maxIterations, bool useTopDownTraversal) {
   if (regions.empty())
     return success();
 
@@ -252,7 +285,8 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
          "patterns can only be applied to operations IsolatedFromAbove");
 
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
+  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns,
+                                    useTopDownTraversal);
   bool converged = driver.simplify(regions, maxIterations);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "

diff  --git a/mlir/test/Transforms/canonicalize-td.mlir b/mlir/test/Transforms/canonicalize-td.mlir
new file mode 100644
index 0000000000000..caecb8cb2fa88
--- /dev/null
+++ b/mlir/test/Transforms/canonicalize-td.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize{top-down=true})' | FileCheck %s --check-prefix=TD
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='func(canonicalize)' | FileCheck %s --check-prefix=BU
+
+
+// BU-LABEL: func @default_insertion_position
+// TD-LABEL: func @default_insertion_position
+func @default_insertion_position(%cond: i1) {
+  // Constant should be folded into the entry block.
+
+  // BU: constant 2
+  // BU-NEXT: scf.if
+
+  // TD: constant 2
+  // TD-NEXT: scf.if
+  scf.if %cond {
+    %0 = constant 1 : i32
+    %2 = addi %0, %0 : i32
+    "foo.yield"(%2) : (i32) -> ()
+  }
+  return
+}
+
+// This shows that we don't pull the constant out of the region because it
+// wants to be the insertion point for the constant.
+// BU-LABEL: func @custom_insertion_position
+// TD-LABEL: func @custom_insertion_position
+func @custom_insertion_position() {
+  // BU: test.one_region_op
+  // BU-NEXT: constant 2
+
+  // TD: test.one_region_op
+  // TD-NEXT: constant 2
+  "test.one_region_op"() ({
+
+    %0 = constant 1 : i32
+    %2 = addi %0, %0 : i32
+    "foo.yield"(%2) : (i32) -> ()
+  }) : () -> ()
+  return
+}
+
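The difference between the two worklist seeding modes in `GreedyPatternRewriteDriver::simplify` can be illustrated outside of MLIR. The following is a minimal sketch, not the driver itself: `ToyOp`, `walkPreOrder`, `walkPostOrder`, `processingOrder`, and `exampleRegion` are hypothetical names introduced only for illustration, and folding, pattern matching, and the `worklistMap` bookkeeping are omitted. It only models the order in which a pop-back loop would first visit operations for each value of `useTopDownTraversal`.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation that may hold nested operations.
// (std::vector of an incomplete element type requires C++17.)
struct ToyOp {
  std::string name;
  std::vector<ToyOp> nested;
};

// Visit the op first, then everything nested inside it.
void walkPreOrder(const ToyOp &op, std::vector<std::string> &out) {
  out.push_back(op.name);
  for (const ToyOp &child : op.nested)
    walkPreOrder(child, out);
}

// Visit everything nested inside the op first, then the op itself.
void walkPostOrder(const ToyOp &op, std::vector<std::string> &out) {
  for (const ToyOp &child : op.nested)
    walkPostOrder(child, out);
  out.push_back(op.name);
}

// Seed a worklist the way the diff above does, then drain it from the back
// and report the order in which the ops were popped.
std::vector<std::string> processingOrder(const std::vector<ToyOp> &region,
                                         bool useTopDownTraversal) {
  std::vector<std::string> worklist;
  for (const ToyOp &op : region) {
    if (useTopDownTraversal)
      walkPreOrder(op, worklist);
    else
      walkPostOrder(op, worklist);
  }

  // Top-down mode reverses the list so that popping from the back yields
  // the ops in their original pre-order.
  if (useTopDownTraversal)
    std::reverse(worklist.begin(), worklist.end());

  std::vector<std::string> visited;
  while (!worklist.empty()) {
    visited.push_back(worklist.back());
    worklist.pop_back();
  }
  return visited;
}

// A toy region shaped like the first test case: an `scf.if` holding three
// ops, followed by a `return`.
std::vector<ToyOp> exampleRegion() {
  ToyOp ifOp{"scf.if", {{"constant", {}}, {"addi", {}}, {"foo.yield", {}}}};
  ToyOp returnOp{"return", {}};
  return {ifOp, returnOp};
}
```

In this toy model, top-down seeding visits `scf.if` first and `return` last, while the default bottom-up seeding pops the post-order list from the back and therefore starts with `return` and ends with `constant`. The real driver keeps adding operations to the worklist as rewrites happen, so this only describes the initial visitation order.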


        

