[Mlir-commits] [mlir] d820acd - [mlir][bufferize][NFC] Use custom walk instead of GreedyPatternRewriter

Matthias Springer llvmlistbot at llvm.org
Fri Apr 22 02:23:24 PDT 2022


Author: Matthias Springer
Date: 2022-04-22T18:23:09+09:00
New Revision: d820acdde1986788c88a1f68552ae96bb5b57431

URL: https://github.com/llvm/llvm-project/commit/d820acdde1986788c88a1f68552ae96bb5b57431
DIFF: https://github.com/llvm/llvm-project/commit/d820acdde1986788c88a1f68552ae96bb5b57431.diff

LOG: [mlir][bufferize][NFC] Use custom walk instead of GreedyPatternRewriter

The bufferization driver previously used a GreedyPatternRewriter. This was problematic because bufferization must traverse ops top-to-bottom. The GreedyPatternRewriter was configured via `useTopDownTraversal` to force that order, but this was a hack: that flag is meant only as a performance tuning option and should not affect the result of the rewrite.
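In essence, the driver now gathers bufferizable ops with a pre-order walk and processes them as an explicit worklist instead of handing patterns to the greedy rewriter. The sketch below condenses the new loop from Bufferize.cpp in this patch; `rewriter`, `options`, and `bufferizationState` stand for the driver's BufferizationRewriter, BufferizationOptions, and BufferizationState (the full patch additionally tracks erased ops and to_memref ops):

    // Gather all ops with tensor semantics in top-to-bottom (pre-order) order.
    SmallVector<Operation *> worklist;
    op->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) {
      if (hasTensorSemantics(nestedOp))
        worklist.push_back(nestedOp);
    });

    // Bufferize ops one by one; newly created bufferizable ops are appended
    // to the worklist by the custom rewriter.
    for (unsigned i = 0; i < worklist.size(); ++i) {
      Operation *nestedOp = worklist[i];
      auto bufferizableOp = dyn_cast<BufferizableOpInterface>(nestedOp);
      if (!bufferizableOp || !options.isOpAllowed(nestedOp))
        continue;
      rewriter.setInsertionPoint(nestedOp);
      (void)bufferizableOp.bufferize(rewriter, bufferizationState);
    }

This guarantees the top-to-bottom order that bufferization relies on, without depending on the traversal behavior of the greedy driver.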

Differential Revision: https://reviews.llvm.org/D123618

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
    mlir/test/Dialect/Tensor/bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index d3e6a5c7f5e3b..8cefd83eead38 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -46,6 +46,13 @@ namespace bufferization {
 /// with differing element types or memory spaces.
 FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
                                           MemRefType type);
+
+/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
+/// to_memref op are different, a memref.cast is needed.
+LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
+                                       ToMemrefOp toMemref,
+                                       bool allowSameType = true);
+
 } // namespace bufferization
 } // namespace mlir
 

diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 87f1d480ec340..5c13b9290b6b8 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -21,10 +21,6 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
                                               MemRefType destType) {
   auto srcType = value.getType().cast<MemRefType>();
 
-  // Casting to the same type, nothing to do.
-  if (srcType == destType)
-    return value;
-
   // Element type, rank and memory space must match.
   if (srcType.getElementType() != destType.getElementType())
     return failure();
@@ -79,6 +75,55 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
   return copy;
 }
 
+/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
+/// to_memref op are different, a memref.cast is needed.
+LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
+    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
+  if (!memrefToTensor)
+    return failure();
+
+  Type srcType = memrefToTensor.memref().getType();
+  Type destType = toMemref.getType();
+
+  // Directly rewrite if the type did not change.
+  if (srcType == destType) {
+    // Function can be configured to only handle cases where a cast is needed.
+    if (!allowSameType)
+      return failure();
+    rewriter.replaceOp(toMemref, memrefToTensor.memref());
+    return success();
+  }
+
+  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
+  auto rankedDestType = destType.dyn_cast<MemRefType>();
+  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
+
+  // Ranked memref -> Ranked memref cast.
+  if (rankedSrcType && rankedDestType) {
+    FailureOr<Value> replacement = castOrReallocMemRefValue(
+        rewriter, memrefToTensor.memref(), rankedDestType);
+    if (failed(replacement))
+      return failure();
+
+    rewriter.replaceOp(toMemref, *replacement);
+    return success();
+  }
+
+  // Unranked memref -> Ranked memref cast: May require a copy.
+  // TODO: Not implemented at the moment.
+  if (unrankedSrcType && rankedDestType)
+    return failure();
+
+  // Unranked memref -> unranked memref cast
+  // Ranked memref -> unranked memref cast: No copy needed.
+  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
+         "expected that types are cast compatible");
+  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
+                                              memrefToTensor.memref());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CloneOp
 //===----------------------------------------------------------------------===//
@@ -249,51 +294,6 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
   }
 };
 
-/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
-/// to_memref op are different, a memref.cast is needed.
-static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                              ToMemrefOp toMemref,
-                                              bool allowSameType = true) {
-  auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
-  if (!memrefToTensor)
-    return failure();
-
-  Type srcType = memrefToTensor.memref().getType();
-  Type destType = toMemref.getType();
-
-  // Function can be configured to only handle cases where a cast is needed.
-  if (!allowSameType && srcType == destType)
-    return failure();
-
-  auto rankedSrcType = srcType.dyn_cast<MemRefType>();
-  auto rankedDestType = destType.dyn_cast<MemRefType>();
-  auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
-
-  // Ranked memref -> Ranked memref cast.
-  if (rankedSrcType && rankedDestType) {
-    FailureOr<Value> replacement = castOrReallocMemRefValue(
-        rewriter, memrefToTensor.memref(), rankedDestType);
-    if (failed(replacement))
-      return failure();
-
-    rewriter.replaceOp(toMemref, *replacement);
-    return success();
-  }
-
-  // Unranked memref -> Ranked memref cast: May require a copy.
-  // TODO: Not implemented at the moment.
-  if (unrankedSrcType && rankedDestType)
-    return failure();
-
-  // Unranked memref -> unranked memref cast
-  // Ranked memref -> unranked memref cast: No copy needed.
-  assert(memref::CastOp::areCastCompatible(srcType, destType) &&
-         "expected that types are cast compatible");
-  rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
-                                              memrefToTensor.memref());
-  return success();
-}
-
 /// Canonicalize bufferization.to_tensor + bufferization.to_memref to
 /// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
 struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 12451ca5f2ced..8571617b2a677 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -242,65 +242,6 @@ static bool hasTensorSemantics(Operation *op) {
   return hasTensorResult || hasTensorOperand;
 }
 
-/// Rewrite pattern that bufferizes bufferizable ops.
-struct BufferizationPattern
-    : public OpInterfaceRewritePattern<BufferizableOpInterface> {
-  BufferizationPattern(MLIRContext *context, BufferizationState &state,
-                       PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
-        state(&state) {}
-
-  LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
-                                PatternRewriter &rewriter) const override {
-    const BufferizationOptions &options = state->getOptions();
-
-    // No tensors => no buffers.
-    if (!hasTensorSemantics(bufferizableOp.getOperation()))
-      return failure();
-    if (!options.isOpAllowed(bufferizableOp.getOperation()))
-      return failure();
-    return bufferizableOp.bufferize(rewriter, *state);
-  }
-
-private:
-  BufferizationState *const state;
-};
-
-/// Check the result of bufferization. Return an error if an op was not
-/// bufferized, unless partial bufferization is allowed.
-static LogicalResult
-checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
-  if (!options.allowUnknownOps) {
-    // Check if all ops were bufferized.
-    LogicalResult status = success();
-    op->walk([&](Operation *op) {
-      if (!hasTensorSemantics(op))
-        return WalkResult::advance();
-
-      // Bufferization dialect ops will canonicalize away if all other ops are
-      // bufferized.
-      if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
-        return WalkResult::advance();
-
-      // Ops that are not in the allow list can be ignored.
-      if (!options.isOpAllowed(op))
-        return WalkResult::advance();
-
-      // Ops without any uses and no side effects will fold away.
-      if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
-        return WalkResult::advance();
-
-      status = op->emitError("op was not bufferized");
-      return WalkResult::interrupt();
-    });
-
-    if (failed(status))
-      return status;
-  }
-
-  return success();
-}
-
 LogicalResult
 bufferization::finalizeBuffers(Operation *op,
                                const BufferizationOptions &options) {
@@ -335,35 +276,131 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
   return success();
 }
 
+namespace {
+/// A rewriter that keeps track of extra information during bufferization.
+class BufferizationRewriter : public IRRewriter {
+public:
+  BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
+                        DenseSet<Operation *> &toMemrefOps,
+                        SmallVector<Operation *> &worklist)
+      : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
+        worklist(worklist) {}
+
+protected:
+  void notifyOperationRemoved(Operation *op) override {
+    IRRewriter::notifyOperationRemoved(op);
+    erasedOps.insert(op);
+  }
+
+  void notifyOperationInserted(Operation *op) override {
+    IRRewriter::notifyOperationInserted(op);
+
+    // Keep track of to_memref ops.
+    if (isa<ToMemrefOp>(op)) {
+      toMemrefOps.insert(op);
+      return;
+    }
+
+    // Skip to_tensor ops.
+    if (isa<ToTensorOp>(op))
+      return;
+
+    // A new bufferizable op was inserted. Add it to the worklist.
+    if (hasTensorSemantics(op))
+      worklist.push_back(op);
+  }
+
+private:
+  /// A set of all erased ops.
+  DenseSet<Operation *> &erasedOps;
+
+  /// A set of all to_memref ops.
+  DenseSet<Operation *> &toMemrefOps;
+
+  /// The list of bufferizable ops.
+  SmallVector<Operation *> &worklist;
+};
+} // namespace
+
 LogicalResult
 bufferization::bufferizeOp(Operation *op,
                            BufferizationState &bufferizationState) {
-  // Bufferize the op and its nested ops.
-  RewritePatternSet patterns(op->getContext());
-  patterns.add<BufferizationPattern>(patterns.getContext(), bufferizationState);
-
-  // Bufferize ops top-to-bottom. When creating a new op, we should ideally
-  // know the exact memref type of all operands. Otherwise, we have to use a
-  // memref type with a fully dynamic layout map, which has to canonicalize
-  // away. This is less efficient.
+  const auto &options = bufferizationState.getOptions();
+
+  // Keep track of to_memref ops.
+  DenseSet<Operation *> toMemrefOps;
+  op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
+
+  // Gather all bufferizable ops in top-to-bottom order.
   //
-  // Note: If "fullyDynamicLayoutMaps = false", we may have to insert buffer
-  // copies to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-
-  // compatible layout maps when doing a traversal other than top-to-bottom.
-  // There are currently no canonicalization patterns to fold these away.
-  GreedyRewriteConfig config;
-  config.useTopDownTraversal = true;
-
-  // TODO: Perform a preorder walk instead of the greedy pattern rewriter. This
-  // would be more efficient because every bufferization pattern is guaranteed
-  // to apply only a single time (otherwise, an assertion would be triggered).
-  // However, there are restrictions wrt. erasing ops during a preorder walk,
-  // which would likely require a larger refactoring.
-  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
-    return failure();
+  // We should ideally know the exact memref type of all operands when
+  // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
+  // Otherwise, we have to use a memref type with a fully dynamic layout map,
+  // which has to canonicalize away. This is less efficient.
+  //
+  // If "fullyDynamicLayoutMaps = false", we would have to insert buffer copies
+  // to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-compatible
+  // layout maps when doing a traversal other than top-to-bottom. These would
+  // not easily fold away.
+  SmallVector<Operation *> worklist;
+  op->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (hasTensorSemantics(op))
+      worklist.push_back(op);
+  });
 
-  if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
-    return failure();
+  // Keep track of all erased ops.
+  DenseSet<Operation *> erasedOps;
+
+  // Bufferize all ops.
+  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
+                                 worklist);
+  for (unsigned i = 0; i < worklist.size(); ++i) {
+    Operation *op = worklist[i];
+    // Skip ops that were erased.
+    if (erasedOps.contains(op))
+      continue;
+    // Skip ops that are not bufferizable.
+    auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
+    if (!bufferizableOp)
+      continue;
+    // Continue ops that are not allowed.
+    if (!options.isOpAllowed(op))
+      continue;
+    // Bufferize the op.
+    rewriter.setInsertionPoint(op);
+    (void)bufferizableOp.bufferize(rewriter, bufferizationState);
+  }
+
+  // Fold all to_memref(to_tensor(x)) pairs.
+  for (Operation *op : toMemrefOps) {
+    if (erasedOps.contains(op))
+      continue;
+    rewriter.setInsertionPoint(op);
+    (void)bufferization::foldToMemrefToTensorPair(rewriter,
+                                                  cast<ToMemrefOp>(op));
+  }
+
+  /// Check the result of bufferization. Return an error if an op was not
+  /// bufferized, unless partial bufferization is allowed.
+  if (bufferizationState.getOptions().allowUnknownOps)
+    return success();
+
+  for (Operation *op : worklist) {
+    // Skip ops that are entirely gone.
+    if (erasedOps.contains(op))
+      continue;
+    // Ops that no longer have tensor semantics (because they were updated
+    // in-place) are allowed.
+    if (!hasTensorSemantics(op))
+      continue;
+    // Continue ops that are not allowed.
+    if (!options.isOpAllowed(op))
+      continue;
+    // Ops without any uses and no side effects will fold away.
+    if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
+      continue;
+    return op->emitError("op was not bufferized");
+  }
 
   return success();
 }

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 5fef35f3d4c19..774161376a841 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -884,8 +884,8 @@ func.func @scf_for_yield_non_equivalent(
 //       CHECK:   %[[cloned:.*]] = bufferization.clone %[[t]]
 //       CHECK:   %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[cloned]])
 // This alloc is for the linalg.init_tensor.
-//       CHECK:     %[[alloc2:.*]] = memref.alloc(%{{.*}})
-//       CHECK:     memref.dealloc %[[iter]]
+//   CHECK-DAG:     %[[alloc2:.*]] = memref.alloc(%{{.*}})
+//   CHECK-DAG:     memref.dealloc %[[iter]]
 // This alloc is for the scf.yield.
 //       CHECK:     %[[alloc3:.*]] = memref.alloc(%{{.*}})
 //       CHECK:     memref.copy %[[alloc2]], %[[alloc3]]

diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index e983d28af1fa5..83fb1f0bfd9e3 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
+// RUN: mlir-opt %s -tensor-bufferize -cse | FileCheck %s
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
@@ -72,14 +72,6 @@ func.func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
   return %0 : f32
 }
 
-// CHECK-LABEL:   func @tensor.from_elements_no_elements() -> tensor<0xindex> {
-// CHECK:           %[[RET:.*]] = arith.constant dense<> : tensor<0xindex>
-// CHECK:           return %[[RET]] : tensor<0xindex>
-func.func @tensor.from_elements_no_elements() -> tensor<0xindex> {
-  %0 = tensor.from_elements : tensor<0xindex>
-  return %0 : tensor<0xindex>
-}
-
 // CHECK-LABEL:   func @tensor.from_elements_0d(
 // CHECK-SAME:        %[[ELEM0:.*]]: index) -> tensor<index> {
 // CHECK:           %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
@@ -185,8 +177,8 @@ func.func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
 // CHECK-SAME:                                       %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK:           %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
-// CHECK:           %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
+// CHECK-DAG:       %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
+// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
 // CHECK:           scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
 // CHECK:             %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
 // CHECK:             store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
@@ -212,7 +204,7 @@ func.func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tenso
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
-// CHECK:           %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
+// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
 // CHECK:           scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
 // CHECK:             %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index
 // CHECK:             store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
@@ -278,8 +270,8 @@ func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
   // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
   // CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : memref<?x10xf32>
-  //     CHECK: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
-  //     CHECK: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
+  // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
+  // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
   //     CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
   //     CHECK: memref.copy %[[m1]], %[[alloc]]
   //     CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]

