[Mlir-commits] [mlir] 04b5274 - [MLIR] Introduce applyOpPatternsAndFold for op local rewrites

Uday Bondhugula llvmlistbot at llvm.org
Wed Apr 15 01:41:47 PDT 2020


Author: Uday Bondhugula
Date: 2020-04-15T14:10:01+05:30
New Revision: 04b5274ede3ebc1de98c47e34cb762bae474696b

URL: https://github.com/llvm/llvm-project/commit/04b5274ede3ebc1de98c47e34cb762bae474696b
DIFF: https://github.com/llvm/llvm-project/commit/04b5274ede3ebc1de98c47e34cb762bae474696b.diff

LOG: [MLIR] Introduce applyOpPatternsAndFold for op local rewrites

Introduce mlir::applyOpPatternsAndFold, which applies patterns as well as
any folding only to a specified op (in contrast to
applyPatternsAndFoldGreedily, which applies patterns only to the regions
of an op that is isolated from above). The caller is informed whether the
op was folded away or erased.

Depends on D77485.

Differential Revision: https://reviews.llvm.org/D77487

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
    mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp
    mlir/test/Dialect/Affine/affine-data-copy.mlir
    mlir/test/Dialect/Affine/simplify-affine-structures.mlir
    mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 4679d9871922..6dbc5b9664d0 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -455,6 +455,15 @@ bool applyPatternsAndFoldGreedily(Operation *op,
 /// Rewrite the given regions, which must be isolated from above.
 bool applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
                                   const OwningRewritePatternList &patterns);
+
+/// Applies the specified patterns on `op` alone while also trying to fold it,
+/// by selecting the highest benefit patterns in a greedy manner. Returns true
+/// if no more patterns can be matched. `erased` is set to true if `op` was
+/// folded away or erased as a result of becoming dead. Note: This does not
+/// apply any patterns recursively to the regions of `op`.
+bool applyOpPatternsAndFold(Operation *op,
+                            const OwningRewritePatternList &patterns,
+                            bool *erased = nullptr);
 } // end namespace mlir
 
 #endif // MLIR_PATTERN_MATCH_H

diff  --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index c861b214d3b3..78128ff4b0df 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -211,20 +211,25 @@ void AffineDataCopyGeneration::runOnFunction() {
   for (auto &block : f)
     runOnBlock(&block, copyNests);
 
-  // Promote any single iteration loops in the copy nests.
+  // Promote any single iteration loops in the copy nests and collect
+  // load/stores to simplify.
+  SmallVector<Operation *, 4> copyOps;
   for (auto nest : copyNests)
-    nest->walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); });
+    // With a post order walk, the erasure of loops does not affect
+    // continuation of the walk or the collection of load/store ops.
+    nest->walk([&](Operation *op) {
+      if (auto forOp = dyn_cast<AffineForOp>(op))
+        promoteIfSingleIteration(forOp);
+      else if (isa<AffineLoadOp>(op) || isa<AffineStoreOp>(op))
+        copyOps.push_back(op);
+    });
 
   // Promoting single iteration loops could lead to simplification of
-  // load's/store's. We will run canonicalization patterns on load/stores.
-  // TODO: this whole function load/store canonicalization should be replaced by
-  // canonicalization that is limited to only the load/store ops
-  // introduced/touched by this pass (those inside 'copyNests'). This would be
-  // possible once the necessary support is available in the pattern rewriter.
-  if (!copyNests.empty()) {
-    OwningRewritePatternList patterns;
-    AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
-    AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
-    applyPatternsAndFoldGreedily(f, std::move(patterns));
-  }
+  // contained load's/store's, and the latter could anyway also be
+  // canonicalized. The same pattern list is reused for every op.
+  OwningRewritePatternList patterns;
+  AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
+  AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
+  for (auto op : copyOps)
+    applyOpPatternsAndFold(op, patterns);
 }

diff  --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 0df4ea0d3f87..fada39aa1cf2 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -6,14 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a pass to simplify affine structures.
+// This file implements a pass to simplify affine structures in operations.
 //
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
 #include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/Utils.h"
 
 #define DEBUG_TYPE "simplify-affine-structure"
@@ -77,13 +79,22 @@ mlir::createSimplifyAffineStructuresPass() {
 void SimplifyAffineStructures::runOnFunction() {
   auto func = getFunction();
   simplifiedAttributes.clear();
-  func.walk([&](Operation *opInst) {
-    for (auto attr : opInst->getAttrs()) {
+  OwningRewritePatternList patterns;
+  AffineForOp::getCanonicalizationPatterns(patterns, func.getContext());
+  AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext());
+  AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
+  func.walk([&](Operation *op) {
+    for (auto attr : op->getAttrs()) {
       if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
-        simplifyAndUpdateAttribute(opInst, attr.first, mapAttr);
+        simplifyAndUpdateAttribute(op, attr.first, mapAttr);
       else if (auto setAttr = attr.second.dyn_cast<IntegerSetAttr>())
-        simplifyAndUpdateAttribute(opInst, attr.first, setAttr);
+        simplifyAndUpdateAttribute(op, attr.first, setAttr);
     }
+
+    // Simplifying the attributes may enable further simplification of the
+    // op itself, so try to fold / canonicalize affine.for/if/apply ops.
+    if (isa<AffineForOp>(op) || isa<AffineIfOp>(op) || isa<AffineApplyOp>(op))
+      applyOpPatternsAndFold(op, patterns);
   });
 
   // Turn memrefs' non-identity layouts maps into ones with identity. Collect

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 53c8e9fbd1c2..256c1340d0c6 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -26,6 +26,10 @@ using namespace mlir;
 /// The max number of iterations scanning for pattern match.
 static unsigned maxPatternMatchIterations = 10;
 
+//===----------------------------------------------------------------------===//
+// GreedyPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
 namespace {
 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
 /// applies the locally optimal patterns in a roughly "bottom up" way.
@@ -37,8 +41,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
     worklist.reserve(64);
   }
 
-  /// Perform the rewrites while folding and erasing any dead ops. Return true
-  /// if the rewrite converges in `maxIterations`.
   bool simplify(MutableArrayRef<Region> regions, int maxIterations);
 
   void addToWorklist(Operation *op) {
@@ -248,3 +250,115 @@ bool mlir::applyPatternsAndFoldGreedily(
   });
   return converged;
 }
+
+//===----------------------------------------------------------------------===//
+// OpPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This is a simple driver for the PatternMatcher to apply patterns and perform
+/// folding on a single op. It repeatedly applies locally optimal patterns.
+class OpPatternRewriteDriver : public PatternRewriter {
+public:
+  explicit OpPatternRewriteDriver(MLIRContext *ctx,
+                                  const OwningRewritePatternList &patterns)
+      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {}
+
+  bool simplifyLocally(Operation *op, int maxIterations, bool &erased);
+
+  /// No additional action needed other than inserting the op.
+  Operation *insert(Operation *op) override { return OpBuilder::insert(op); }
+
+  // These are hooks implemented for PatternRewriter.
+protected:
+  /// If an operation is about to be removed, mark it so that we can let clients
+  /// know.
+  void notifyOperationRemoved(Operation *op) override {
+    opErasedViaPatternRewrites = true;
+  }
+
+  // When a root is going to be replaced, its removal will be notified as well.
+  // So there is nothing to do here.
+  void notifyRootReplaced(Operation *op) override {}
+
+private:
+  /// The low-level pattern matcher.
+  RewritePatternMatcher matcher;
+
+  /// Non-pattern based folder for operations.
+  OperationFolder folder;
+
+  /// Set to true if the operation has been erased via pattern rewrites.
+  bool opErasedViaPatternRewrites = false;
+};
+
+} // anonymous namespace
+
+/// Performs the rewrites and folding only on `op`. The simplification converges
+/// if the op is erased as a result of being folded, replaced, or dead, or no
+/// more changes happen in an iteration. Returns true if the rewrite converges
+/// in `maxIterations`. `erased` is set to true if `op` gets erased.
+bool OpPatternRewriteDriver::simplifyLocally(Operation *op, int maxIterations,
+                                             bool &erased) {
+  bool changed = false;
+  erased = false;
+  opErasedViaPatternRewrites = false;
+  int i = 0;
+  // Iterate until convergence or until maxIterations. Deletion of the op as
+  // a result of being dead or folded is convergence.
+  do {
+    // Reset so that convergence reflects only the latest iteration.
+    changed = false;
+
+    // If the operation is trivially dead - remove it.
+    if (isOpTriviallyDead(op)) {
+      op->erase();
+      erased = true;
+      return true;
+    }
+
+    // Try to fold this op.
+    bool inPlaceUpdate;
+    if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
+                                   /*preReplaceAction=*/nullptr,
+                                   &inPlaceUpdate))) {
+      changed = true;
+      if (!inPlaceUpdate) {
+        erased = true;
+        return true;
+      }
+    }
+
+    // Make sure that any new operations are inserted at this point.
+    setInsertionPoint(op);
+
+    // Try to match one of the patterns. The rewriter is automatically
+    // notified of any necessary changes, so there is nothing else to do here.
+    changed |= matcher.matchAndRewrite(op, *this);
+    if ((erased = opErasedViaPatternRewrites))
+      return true;
+  } while (changed && ++i < maxIterations);
+
+  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
+  return !changed;
+}
+
+/// Rewrites only `op` using the supplied canonicalization patterns and
+/// folding. `erased` is set to true if the op is erased as a result of being
+/// folded, replaced, or dead.
+bool mlir::applyOpPatternsAndFold(Operation *op,
+                                  const OwningRewritePatternList &patterns,
+                                  bool *erased) {
+  // Start the pattern driver.
+  OpPatternRewriteDriver driver(op->getContext(), patterns);
+  bool opErased;
+  bool converged =
+      driver.simplifyLocally(op, maxPatternMatchIterations, opErased);
+  if (erased)
+    *erased = opErased;
+  LLVM_DEBUG(if (!converged) {
+    llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
+                 << maxPatternMatchIterations << " times\n";
+  });
+  return converged;
+}

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 72f889e2315a..9fe96437e526 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/DenseMap.h"
@@ -312,9 +313,19 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
                                   opGroupQueue, /*offset=*/0, forOp, b);
         lbShift = d * step;
       }
-      if (!prologue && res)
-        prologue = res;
-      epilogue = res;
+
+      if (res) {
+        // Simplify/canonicalize the affine.for; this may erase it.
+        OwningRewritePatternList patterns;
+        AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
+        bool erased = false;
+        applyOpPatternsAndFold(res, patterns, &erased);
+
+        if (!erased && !prologue)
+          prologue = res;
+        if (!erased)
+          epilogue = res;
+      }
     } else {
       // Start of first interval.
       lbShift = d * step;
@@ -694,7 +705,8 @@ bool mlir::isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
 }
 
 /// Return true if `loops` is a perfect nest.
-static bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef<AffineForOp> loops) {
+static bool LLVM_ATTRIBUTE_UNUSED
+isPerfectlyNested(ArrayRef<AffineForOp> loops) {
   auto outerLoop = loops.front();
   for (auto loop : loops.drop_front()) {
     auto parentForOp = dyn_cast<AffineForOp>(loop.getParentOp());

diff  --git a/mlir/test/Dialect/Affine/affine-data-copy.mlir b/mlir/test/Dialect/Affine/affine-data-copy.mlir
index 52c60d7177f8..97d64a6d1b44 100644
--- a/mlir/test/Dialect/Affine/affine-data-copy.mlir
+++ b/mlir/test/Dialect/Affine/affine-data-copy.mlir
@@ -216,7 +216,7 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> {
   return %A : memref<4096xf32>
 }
 // CHECK:      affine.for %[[IV1:.*]] = 0 to 4096 step 100
-// CHECK-NEXT:   %[[BUF:.*]] = alloc() : memref<100xf32>
+// CHECK:        %[[BUF:.*]] = alloc() : memref<100xf32>
 // CHECK-NEXT:   affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) {
 // CHECK-NEXT:     affine.load %{{.*}}[%[[IV2]]] : memref<4096xf32>
 // CHECK-NEXT:     affine.store %{{.*}}, %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32>
@@ -226,7 +226,7 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> {
 // CHECK-NEXT:     mulf
 // CHECK-NEXT:     affine.store %{{.*}}, %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32>
 // CHECK-NEXT:   }
-// CHECK-NEXT:   affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) {
+// CHECK:        affine.for %[[IV2:.*]] = #[[MAP_IDENTITY]](%[[IV1]]) to min #[[MAP_MIN_UB1]](%[[IV1]]) {
 // CHECK-NEXT:     affine.load %[[BUF]][-%[[IV1]] + %[[IV2]]] : memref<100xf32>
 // CHECK-NEXT:     affine.store %{{.*}}, %{{.*}}[%[[IV2]]] : memref<4096xf32>
 // CHECK-NEXT:   }
@@ -239,8 +239,8 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> {
 // with multi-level tiling when the tile sizes used don't divide loop trip
 // counts.
 
-#lb = affine_map<(d0, d1) -> (d0 * 512, d1 * 6)>
-#ub = affine_map<(d0, d1) -> (d0 * 512 + 512, d1 * 6 + 6)>
+#lb = affine_map<()[s0, s1] -> (s0 * 512, s1 * 6)>
+#ub = affine_map<()[s0, s1] -> (s0 * 512 + 512, s1 * 6 + 6)>
 
 // CHECK-DAG: #[[LB:.*]] = affine_map<()[s0, s1] -> (s0 * 512, s1 * 6)>
 // CHECK-DAG: #[[UB:.*]] = affine_map<()[s0, s1] -> (s0 * 512 + 512, s1 * 6 + 6)>
@@ -250,7 +250,7 @@ func @min_upper_bound(%A: memref<4096xf32>) -> memref<4096xf32> {
 // CHECK-SAME: [[j:arg[0-9]+]]
 func @max_lower_bound(%M: memref<2048x516xf64>, %i : index, %j : index) {
   affine.for %ii = 0 to 2048 {
-    affine.for %jj = max #lb(%i, %j) to min #ub(%i, %j) {
+    affine.for %jj = max #lb()[%i, %j] to min #ub()[%i, %j] {
       affine.load %M[%ii, %jj] : memref<2048x516xf64>
     }
   }

diff  --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
index 832d723ab6f6..49fa339aa88a 100644
--- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
@@ -1,19 +1,19 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -simplify-affine-structures | FileCheck %s
 
-// CHECK-DAG: #[[SET_EMPTY_2D:.*]] = affine_set<(d0, d1) : (1 == 0)>
+// CHECK-DAG: #[[SET_EMPTY:.*]] = affine_set<() : (1 == 0)>
 // CHECK-DAG: #[[SET_2D:.*]] = affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0)>
-// CHECK-DAG: #[[SET_EMPTY_2D_2S:.*]] = affine_set<(d0, d1)[s0, s1] : (1 == 0)>
-// CHECK-DAG: #[[SET_2D_2S:.*]] = affine_set<(d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0)>
-// CHECK-DAG: #[[SET_EMPTY_1D:.*]] = affine_set<(d0) : (1 == 0)>
-// CHECK-DAG: #[[SET_EMPTY_1D_2S:.*]] = affine_set<(d0)[s0, s1] : (1 == 0)>
-// CHECK-DAG: #[[SET_EMPTY_3D:.*]] = affine_set<(d0, d1, d2) : (1 == 0)>
+// CHECK-DAG: #[[SET_7_11:.*]] = affine_set<(d0, d1) : (d0 * 7 + d1 * 5 + 88 == 0, d0 * 5 - d1 * 11 + 60 == 0, d0 * 11 + d1 * 7 - 24 == 0, d0 * 7 + d1 * 5 + 88 == 0)>
+
+// An external function that we will use in bodies to avoid DCE.
+func @external() -> ()
 
 // CHECK-LABEL: func @test_gaussian_elimination_empty_set0() {
 func @test_gaussian_elimination_empty_set0() {
   affine.for %arg0 = 1 to 10 {
     affine.for %arg1 = 1 to 100 {
-      // CHECK: [[SET_EMPTY_2D]](%arg0, %arg1)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1) : (2 == 0)>(%arg0, %arg1) {
+        call @external() : () -> ()
       }
     }
   }
@@ -24,8 +24,9 @@ func @test_gaussian_elimination_empty_set0() {
 func @test_gaussian_elimination_empty_set1() {
   affine.for %arg0 = 1 to 10 {
     affine.for %arg1 = 1 to 100 {
-      // CHECK: [[SET_EMPTY_2D]](%arg0, %arg1)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1) : (1 >= 0, -1 >= 0)> (%arg0, %arg1) {
+        call @external() : () -> ()
       }
     }
   }
@@ -38,6 +39,7 @@ func @test_gaussian_elimination_non_empty_set2() {
     affine.for %arg1 = 1 to 100 {
       // CHECK: #[[SET_2D]](%arg0, %arg1)
       affine.if affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)>(%arg0, %arg1) {
+        call @external() : () -> ()
       }
     }
   }
@@ -50,8 +52,9 @@ func @test_gaussian_elimination_empty_set3() {
   %c11 = constant 11 : index
   affine.for %arg0 = 1 to 10 {
     affine.for %arg1 = 1 to 100 {
-      // CHECK: #[[SET_EMPTY_2D_2S]](%arg0, %arg1)[%c7, %c11]
+      // CHECK: #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1)[s0, s1] : (d0 - s0 == 0, d0 + s0 == 0, s0 - 1 == 0)>(%arg0, %arg1)[%c7, %c11] {
+        call @external() : () -> ()
       }
     }
   }
@@ -70,8 +73,9 @@ func @test_gaussian_elimination_non_empty_set4() {
   %c11 = constant 11 : index
   affine.for %arg0 = 1 to 10 {
     affine.for %arg1 = 1 to 100 {
-      // CHECK: #[[SET_2D_2S]](%arg0, %arg1)[%c7, %c11]
+      // CHECK: #[[SET_7_11]](%arg0, %arg1)
       affine.if #set_2d_non_empty(%arg0, %arg1)[%c7, %c11] {
+        call @external() : () -> ()
       }
     }
   }
@@ -79,7 +83,6 @@ func @test_gaussian_elimination_non_empty_set4() {
 }
 
 // Add invalid constraints to previous non-empty set to make it empty.
-// Set for test case: test_gaussian_elimination_empty_set5
 #set_2d_empty = affine_set<(d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0,
                                        d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0,
                                        d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0,
@@ -92,8 +95,9 @@ func @test_gaussian_elimination_empty_set5() {
   %c11 = constant 11 : index
   affine.for %arg0 = 1 to 10 {
     affine.for %arg1 = 1 to 100 {
-      // CHECK: #[[SET_EMPTY_2D_2S]](%arg0, %arg1)[%c7, %c11]
+      // CHECK: #[[SET_EMPTY]]()
       affine.if #set_2d_empty(%arg0, %arg1)[%c7, %c11] {
+        call @external() : () -> ()
       }
     }
   }
@@ -147,6 +151,7 @@ func @test_fuzz_explosion(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i
   affine.for %arg4 = 1 to 10 {
     affine.for %arg5 = 1 to 100 {
       affine.if #set_fuzz_virus(%arg4, %arg5, %arg0, %arg1, %arg2, %arg3) {
+        call @external() : () -> ()
       }
     }
   }
@@ -157,33 +162,33 @@ func @test_fuzz_explosion(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : i
 func @test_empty_set(%N : index) {
   affine.for %i = 0 to 10 {
     affine.for %j = 0 to 10 {
-      // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1) : (d0 - d1 >= 0, d1 - d0 - 1 >= 0)>(%i, %j) {
         "foo"() : () -> ()
       }
-      // CHECK: affine.if #[[SET_EMPTY_1D]](%arg1)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0) : (d0 >= 0, -d0 - 1 >= 0)>(%i) {
         "bar"() : () -> ()
       }
-      // CHECK: affine.if #[[SET_EMPTY_1D]](%arg1)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0) : (d0 >= 0, -d0 - 1 >= 0)>(%i) {
         "foo"() : () -> ()
       }
-      // CHECK: affine.if #[[SET_EMPTY_1D_2S]](%arg1)[%arg0, %arg0]
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0)[s0, s1] : (d0 >= 0, -d0 + s0 - 1 >= 0, -s0 >= 0)>(%i)[%N, %N] {
         "bar"() : () -> ()
       }
-      // CHECK: affine.if #[[SET_EMPTY_3D]](%arg1, %arg2, %arg0)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       // The set below implies d0 = d1; so d1 >= d0, but d0 >= d1 + 1.
       affine.if affine_set<(d0, d1, d2) : (d0 - d1 == 0, d2 - d0 >= 0, d0 - d1 - 1 >= 0)>(%i, %j, %N) {
         "foo"() : () -> ()
       }
-      // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       // The set below has rational solutions but no integer solutions; GCD test catches it.
       affine.if affine_set<(d0, d1) : (d0*2 -d1*2 - 1 == 0, d0 >= 0, -d0 + 100 >= 0, d1 >= 0, -d1 + 100 >= 0)>(%i, %j) {
         "foo"() : () -> ()
       }
-      // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1) : (d1 == 0, d0 - 1 >= 0, - d0 - 1 >= 0)>(%i, %j) {
         "foo"() : () -> ()
       }
@@ -193,12 +198,12 @@ func @test_empty_set(%N : index) {
   affine.for %k = 0 to 10 {
     affine.for %l = 0 to 10 {
       // Empty because no multiple of 8 lies between 4 and 7.
-      // CHECK: affine.if #[[SET_EMPTY_1D]](%arg1)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0) : (8*d0 - 4 >= 0, -8*d0 + 7 >= 0)>(%k) {
         "foo"() : () -> ()
       }
       // Same as above but with equalities and inequalities.
-      // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1) : (d0 - 4*d1 == 0, 4*d1 - 5 >= 0, -4*d1 + 7 >= 0)>(%k, %l) {
         "foo"() : () -> ()
       }
@@ -206,12 +211,12 @@ func @test_empty_set(%N : index) {
       // 8*d1 here is a multiple of 4, and so can't lie between 9 and 11. GCD
       // tightening will tighten constraints to 4*d0 + 8*d1 >= 12 and 4*d0 +
       // 8*d1 <= 8; hence infeasible.
-      // CHECK: affine.if #[[SET_EMPTY_2D]](%arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1) : (4*d0 + 8*d1 - 9 >= 0, -4*d0 - 8*d1 + 11 >= 0)>(%k, %l) {
         "foo"() : () -> ()
       }
       // Same as above but with equalities added into the mix.
-      // CHECK: affine.if #[[SET_EMPTY_3D]](%arg1, %arg1, %arg2)
+      // CHECK: affine.if #[[SET_EMPTY]]()
       affine.if affine_set<(d0, d1, d2) : (d0 - 4*d2 == 0, d0 + 8*d1 - 9 >= 0, -d0 - 8*d1 + 11 >= 0)>(%k, %k, %l) {
         "foo"() : () -> ()
       }
@@ -219,7 +224,7 @@ func @test_empty_set(%N : index) {
   }
 
   affine.for %m = 0 to 10 {
-    // CHECK: affine.if #[[SET_EMPTY_1D]](%arg{{[0-9]+}})
+    // CHECK: affine.if #[[SET_EMPTY]]()
     affine.if affine_set<(d0) : (d0 mod 2 - 3 == 0)> (%m) {
       "foo"() : () -> ()
     }
@@ -230,20 +235,39 @@ func @test_empty_set(%N : index) {
 
 // -----
 
-// CHECK-DAG: #[[SET_2D:.*]] = affine_set<(d0, d1) : (d0 >= 0, -d0 + 50 >= 0)
-// CHECK-DAG: #[[SET_EMPTY:.*]] = affine_set<(d0, d1) : (1 == 0)
-// CHECK-DAG: #[[SET_UNIV:.*]] = affine_set<(d0, d1) : (0 == 0)
+// An external function that we will use in bodies to avoid DCE.
+func @external() -> ()
+
+// CHECK-DAG: #[[SET:.*]] = affine_set<()[s0] : (s0 >= 0, -s0 + 50 >= 0)
+// CHECK-DAG: #[[EMPTY_SET:.*]] = affine_set<() : (1 == 0)
+// CHECK-DAG: #[[UNIV_SET:.*]] = affine_set<() : (0 == 0)
 
 // CHECK-LABEL: func @simplify_set
 func @simplify_set(%a : index, %b : index) {
-  // CHECK: affine.if #[[SET_2D]]
+  // CHECK: affine.if #[[SET]]
   affine.if affine_set<(d0, d1) : (d0 - d1 + d1 + d0 >= 0, 2 >= 0, d0 >= 0, -d0 + 50 >= 0, -d0 + 100 >= 0)>(%a, %b) {
+    call @external() : () -> ()
   }
-  // CHECK: affine.if #[[SET_EMPTY]]
+  // CHECK: affine.if #[[EMPTY_SET]]
   affine.if affine_set<(d0, d1) : (d0 mod 2 - 1 == 0, d0 - 2 * (d0 floordiv 2) == 0)>(%a, %b) {
+    call @external() : () -> ()
   }
-  // CHECK: affine.if #[[SET_UNIV]]
+  // CHECK: affine.if #[[UNIV_SET]]
   affine.if affine_set<(d0, d1) : (1 >= 0, 3 >= 0)>(%a, %b) {
+    call @external() : () -> ()
   }
 	return
 }
+
+// -----
+
+// CHECK-DAG: -> (s0 * 2 + 1)
+
+// Test "op local" simplification on affine.apply. DCE on addi will not happen.
+func @affine.apply(%N : index) {
+  %v = affine.apply affine_map<(d0, d1) -> (d0 + d1 + 1)>(%N, %N)
+  addi %v, %v : index
+  // CHECK: affine.apply #map{{.*}}()[%arg0]
+  // CHECK-NEXT: addi
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index 6c8b546e6aef..b6df06ed9c54 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -88,16 +88,35 @@ void TestAffineDataCopy::runOnFunction() {
     generateCopyForMemRegion(region, loopNest, copyOptions, result);
   }
 
-  // Promote any single iteration loops in the copy nests.
+  // Promote any single iteration loops in the copy nests and simplify
+  // load/stores.
+  SmallVector<Operation *, 4> copyOps;
   for (auto nest : copyNests)
-    nest->walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); });
-
-  // Promoting single iteration loops could lead to simplification
-  // of load's/store's. We will run the canonicalization patterns again.
+    // With a post order walk, the erasure of loops does not affect
+    // continuation of the walk or the collection of load/store ops.
+    nest->walk([&](Operation *op) {
+      if (auto forOp = dyn_cast<AffineForOp>(op))
+        promoteIfSingleIteration(forOp);
+      else if (auto loadOp = dyn_cast<AffineLoadOp>(op))
+        copyOps.push_back(loadOp);
+      else if (auto storeOp = dyn_cast<AffineStoreOp>(op))
+        copyOps.push_back(storeOp);
+    });
+
+  // Promoting single iteration loops could lead to simplification of
+  // generated load's/store's, and the latter could anyway also be
+  // canonicalized.
   OwningRewritePatternList patterns;
-  AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
-  AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
-  applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  for (auto op : copyOps) {
+    patterns.clear();
+    if (isa<AffineLoadOp>(op)) {
+      AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
+    } else {
+      assert(isa<AffineStoreOp>(op) && "expected affine store op");
+      AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
+    }
+    applyOpPatternsAndFold(op, patterns);
+  }
 }
 
 namespace mlir {


        


More information about the Mlir-commits mailing list