[Mlir-commits] [mlir] d820acd - [mlir][bufferize][NFC] Use custom walk instead of GreedyPatternRewriter
Matthias Springer
llvmlistbot at llvm.org
Fri Apr 22 02:23:24 PDT 2022
Author: Matthias Springer
Date: 2022-04-22T18:23:09+09:00
New Revision: d820acdde1986788c88a1f68552ae96bb5b57431
URL: https://github.com/llvm/llvm-project/commit/d820acdde1986788c88a1f68552ae96bb5b57431
DIFF: https://github.com/llvm/llvm-project/commit/d820acdde1986788c88a1f68552ae96bb5b57431.diff
LOG: [mlir][bufferize][NFC] Use custom walk instead of GreedyPatternRewriter
The bufferization driver previously used a GreedyPatternRewriter. This was problematic because bufferization must traverse ops top-to-bottom. The rewriter was configured via `useTopDownTraversal` to force that order, but this was a hack: that flag is meant only as a performance tuning knob and should not affect the result of the rewrite.
BEGIN_PUBLIC
No public commit message needed.
END_PUBLIC
Differential Revision: https://reviews.llvm.org/D123618
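For context, here is a minimal sketch of the two driver shapes this change moves between, paraphrasing the Bufferize.cpp hunks below. The wrapper function names and the trimmed bodies are illustrative only, not the exact upstream code:

#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Before: greedy pattern driver, forced into top-down order through a
// config flag that is really just a performance knob.
LogicalResult bufferizeWithGreedyDriver(Operation *root,
                                        RewritePatternSet &&patterns) {
  GreedyRewriteConfig config;
  config.useTopDownTraversal = true; // hack: correctness relies on the order
  return applyPatternsAndFoldGreedily(root, std::move(patterns), config);
}

// After: an explicit pre-order walk gathers the ops first, then a plain
// IRRewriter visits the worklist top-to-bottom exactly once.
LogicalResult bufferizeWithCustomWalk(Operation *root) {
  SmallVector<Operation *> worklist;
  root->walk<WalkOrder::PreOrder>(
      [&](Operation *nested) { worklist.push_back(nested); });

  IRRewriter rewriter(root->getContext());
  for (Operation *nested : worklist) {
    rewriter.setInsertionPoint(nested);
    // ... bufferize `nested` here (see the Bufferize.cpp changes below) ...
  }
  return success();
}

The key difference is that the walk-based driver makes the top-to-bottom order an explicit property of the algorithm rather than a side effect of the greedy driver's configuration.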
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
mlir/test/Dialect/Tensor/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index d3e6a5c7f5e3b..8cefd83eead38 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -46,6 +46,13 @@ namespace bufferization {
/// with differing element types or memory spaces.
FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value,
MemRefType type);
+
+/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
+/// to_memref op are different, a memref.cast is needed.
+LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
+ ToMemrefOp toMemref,
+ bool allowSameType = true);
+
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 87f1d480ec340..5c13b9290b6b8 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -21,10 +21,6 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
MemRefType destType) {
auto srcType = value.getType().cast<MemRefType>();
- // Casting to the same type, nothing to do.
- if (srcType == destType)
- return value;
-
// Element type, rank and memory space must match.
if (srcType.getElementType() != destType.getElementType())
return failure();
@@ -79,6 +75,55 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
return copy;
}
+/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
+/// to_memref op are different, a memref.cast is needed.
+LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
+ RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+ auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
+ if (!memrefToTensor)
+ return failure();
+
+ Type srcType = memrefToTensor.memref().getType();
+ Type destType = toMemref.getType();
+
+ // Directly rewrite if the type did not change.
+ if (srcType == destType) {
+ // Function can be configured to only handle cases where a cast is needed.
+ if (!allowSameType)
+ return failure();
+ rewriter.replaceOp(toMemref, memrefToTensor.memref());
+ return success();
+ }
+
+ auto rankedSrcType = srcType.dyn_cast<MemRefType>();
+ auto rankedDestType = destType.dyn_cast<MemRefType>();
+ auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
+
+ // Ranked memref -> Ranked memref cast.
+ if (rankedSrcType && rankedDestType) {
+ FailureOr<Value> replacement = castOrReallocMemRefValue(
+ rewriter, memrefToTensor.memref(), rankedDestType);
+ if (failed(replacement))
+ return failure();
+
+ rewriter.replaceOp(toMemref, *replacement);
+ return success();
+ }
+
+ // Unranked memref -> Ranked memref cast: May require a copy.
+ // TODO: Not implemented at the moment.
+ if (unrankedSrcType && rankedDestType)
+ return failure();
+
+ // Unranked memref -> unranked memref cast
+ // Ranked memref -> unranked memref cast: No copy needed.
+ assert(memref::CastOp::areCastCompatible(srcType, destType) &&
+ "expected that types are cast compatible");
+ rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
+ memrefToTensor.memref());
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//
@@ -249,51 +294,6 @@ struct ToMemrefOfCast : public OpRewritePattern<ToMemrefOp> {
}
};
-/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
-/// to_memref op are different, a memref.cast is needed.
-static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
- ToMemrefOp toMemref,
- bool allowSameType = true) {
- auto memrefToTensor = toMemref.tensor().getDefiningOp<ToTensorOp>();
- if (!memrefToTensor)
- return failure();
-
- Type srcType = memrefToTensor.memref().getType();
- Type destType = toMemref.getType();
-
- // Function can be configured to only handle cases where a cast is needed.
- if (!allowSameType && srcType == destType)
- return failure();
-
- auto rankedSrcType = srcType.dyn_cast<MemRefType>();
- auto rankedDestType = destType.dyn_cast<MemRefType>();
- auto unrankedSrcType = srcType.dyn_cast<UnrankedMemRefType>();
-
- // Ranked memref -> Ranked memref cast.
- if (rankedSrcType && rankedDestType) {
- FailureOr<Value> replacement = castOrReallocMemRefValue(
- rewriter, memrefToTensor.memref(), rankedDestType);
- if (failed(replacement))
- return failure();
-
- rewriter.replaceOp(toMemref, *replacement);
- return success();
- }
-
- // Unranked memref -> Ranked memref cast: May require a copy.
- // TODO: Not implemented at the moment.
- if (unrankedSrcType && rankedDestType)
- return failure();
-
- // Unranked memref -> unranked memref cast
- // Ranked memref -> unranked memref cast: No copy needed.
- assert(memref::CastOp::areCastCompatible(srcType, destType) &&
- "expected that types are cast compatible");
- rewriter.replaceOpWithNewOp<memref::CastOp>(toMemref, destType,
- memrefToTensor.memref());
- return success();
-}
-
/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
/// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 12451ca5f2ced..8571617b2a677 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -242,65 +242,6 @@ static bool hasTensorSemantics(Operation *op) {
return hasTensorResult || hasTensorOperand;
}
-/// Rewrite pattern that bufferizes bufferizable ops.
-struct BufferizationPattern
- : public OpInterfaceRewritePattern<BufferizableOpInterface> {
- BufferizationPattern(MLIRContext *context, BufferizationState &state,
- PatternBenefit benefit = 1)
- : OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
- state(&state) {}
-
- LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
- PatternRewriter &rewriter) const override {
- const BufferizationOptions &options = state->getOptions();
-
- // No tensors => no buffers.
- if (!hasTensorSemantics(bufferizableOp.getOperation()))
- return failure();
- if (!options.isOpAllowed(bufferizableOp.getOperation()))
- return failure();
- return bufferizableOp.bufferize(rewriter, *state);
- }
-
-private:
- BufferizationState *const state;
-};
-
-/// Check the result of bufferization. Return an error if an op was not
-/// bufferized, unless partial bufferization is allowed.
-static LogicalResult
-checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
- if (!options.allowUnknownOps) {
- // Check if all ops were bufferized.
- LogicalResult status = success();
- op->walk([&](Operation *op) {
- if (!hasTensorSemantics(op))
- return WalkResult::advance();
-
- // Bufferization dialect ops will canonicalize away if all other ops are
- // bufferized.
- if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
- return WalkResult::advance();
-
- // Ops that are not in the allow list can be ignored.
- if (!options.isOpAllowed(op))
- return WalkResult::advance();
-
- // Ops without any uses and no side effects will fold away.
- if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
- return WalkResult::advance();
-
- status = op->emitError("op was not bufferized");
- return WalkResult::interrupt();
- });
-
- if (failed(status))
- return status;
- }
-
- return success();
-}
-
LogicalResult
bufferization::finalizeBuffers(Operation *op,
const BufferizationOptions &options) {
@@ -335,35 +276,131 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
return success();
}
+namespace {
+/// A rewriter that keeps track of extra information during bufferization.
+class BufferizationRewriter : public IRRewriter {
+public:
+ BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
+ DenseSet<Operation *> &toMemrefOps,
+ SmallVector<Operation *> &worklist)
+ : IRRewriter(ctx), erasedOps(erasedOps), toMemrefOps(toMemrefOps),
+ worklist(worklist) {}
+
+protected:
+ void notifyOperationRemoved(Operation *op) override {
+ IRRewriter::notifyOperationRemoved(op);
+ erasedOps.insert(op);
+ }
+
+ void notifyOperationInserted(Operation *op) override {
+ IRRewriter::notifyOperationInserted(op);
+
+ // Keep track of to_memref ops.
+ if (isa<ToMemrefOp>(op)) {
+ toMemrefOps.insert(op);
+ return;
+ }
+
+ // Skip to_tensor ops.
+ if (isa<ToTensorOp>(op))
+ return;
+
+ // A new bufferizable op was inserted. Add it to the worklist.
+ if (hasTensorSemantics(op))
+ worklist.push_back(op);
+ }
+
+private:
+ /// A set of all erased ops.
+ DenseSet<Operation *> &erasedOps;
+
+ /// A set of all to_memref ops.
+ DenseSet<Operation *> &toMemrefOps;
+
+ /// The list of bufferizable ops.
+ SmallVector<Operation *> &worklist;
+};
+} // namespace
+
LogicalResult
bufferization::bufferizeOp(Operation *op,
BufferizationState &bufferizationState) {
- // Bufferize the op and its nested ops.
- RewritePatternSet patterns(op->getContext());
- patterns.add<BufferizationPattern>(patterns.getContext(), bufferizationState);
-
- // Bufferize ops top-to-bottom. When creating a new op, we should ideally
- // know the exact memref type of all operands. Otherwise, we have to use a
- // memref type with a fully dynamic layout map, which has to canonicalize
- // away. This is less efficient.
+ const auto &options = bufferizationState.getOptions();
+
+ // Keep track of to_memref ops.
+ DenseSet<Operation *> toMemrefOps;
+ op->walk([&](ToMemrefOp toMemrefOp) { toMemrefOps.insert(toMemrefOp); });
+
+ // Gather all bufferizable ops in top-to-bottom order.
//
- // Note: If "fullyDynamicLayoutMaps = false", we may have to insert buffer
- // copies to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-
- // compatible layout maps when doing a traversal other than top-to-bottom.
- // There are currently no canonicalization patterns to fold these away.
- GreedyRewriteConfig config;
- config.useTopDownTraversal = true;
-
- // TODO: Perform a preorder walk instead of the greedy pattern rewriter. This
- // would be more efficient because every bufferization pattern is guaranteed
- // to apply only a single time (otherwise, an assertion would be triggered).
- // However, there are restrictions wrt. erasing ops during a preorder walk,
- // which would likely require a larger refactoring.
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
- return failure();
+ // We should ideally know the exact memref type of all operands when
+ // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
+ // Otherwise, we have to use a memref type with a fully dynamic layout map,
+ // which has to canonicalize away. This is less efficient.
+ //
+ // If "fullyDynamicLayoutMaps = false", we would have to insert buffer copies
+ // to fold ("finalize") to_memref(to_tensor(x)) ops with non-cast-compatible
+ // layout maps when doing a traversal other than top-to-bottom. These would
+ // not easily fold away.
+ SmallVector<Operation *> worklist;
+ op->walk<WalkOrder::PreOrder>([&](Operation *op) {
+ if (hasTensorSemantics(op))
+ worklist.push_back(op);
+ });
- if (failed(checkBufferizationResult(op, bufferizationState.getOptions())))
- return failure();
+ // Keep track of all erased ops.
+ DenseSet<Operation *> erasedOps;
+
+ // Bufferize all ops.
+ BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
+ worklist);
+ for (unsigned i = 0; i < worklist.size(); ++i) {
+ Operation *op = worklist[i];
+ // Skip ops that were erased.
+ if (erasedOps.contains(op))
+ continue;
+ // Skip ops that are not bufferizable.
+ auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
+ if (!bufferizableOp)
+ continue;
+ // Skip ops that are not allowed.
+ if (!options.isOpAllowed(op))
+ continue;
+ // Bufferize the op.
+ rewriter.setInsertionPoint(op);
+ (void)bufferizableOp.bufferize(rewriter, bufferizationState);
+ }
+
+ // Fold all to_memref(to_tensor(x)) pairs.
+ for (Operation *op : toMemrefOps) {
+ if (erasedOps.contains(op))
+ continue;
+ rewriter.setInsertionPoint(op);
+ (void)bufferization::foldToMemrefToTensorPair(rewriter,
+ cast<ToMemrefOp>(op));
+ }
+
+ /// Check the result of bufferization. Return an error if an op was not
+ /// bufferized, unless partial bufferization is allowed.
+ if (bufferizationState.getOptions().allowUnknownOps)
+ return success();
+
+ for (Operation *op : worklist) {
+ // Skip ops that are entirely gone.
+ if (erasedOps.contains(op))
+ continue;
+ // Ops that no longer have tensor semantics (because they were updated
+ // in-place) are allowed.
+ if (!hasTensorSemantics(op))
+ continue;
+ // Skip ops that are not allowed.
+ if (!options.isOpAllowed(op))
+ continue;
+ // Ops without any uses and no side effects will fold away.
+ if (op->getUses().empty() && MemoryEffectOpInterface::hasNoEffect(op))
+ continue;
+ return op->emitError("op was not bufferized");
+ }
return success();
}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 5fef35f3d4c19..774161376a841 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -884,8 +884,8 @@ func.func @scf_for_yield_non_equivalent(
// CHECK: %[[cloned:.*]] = bufferization.clone %[[t]]
// CHECK: %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[cloned]])
// This alloc is for the linalg.init_tensor.
-// CHECK: %[[alloc2:.*]] = memref.alloc(%{{.*}})
-// CHECK: memref.dealloc %[[iter]]
+// CHECK-DAG: %[[alloc2:.*]] = memref.alloc(%{{.*}})
+// CHECK-DAG: memref.dealloc %[[iter]]
// This alloc is for the scf.yield.
// CHECK: %[[alloc3:.*]] = memref.alloc(%{{.*}})
// CHECK: memref.copy %[[alloc2]], %[[alloc3]]
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index e983d28af1fa5..83fb1f0bfd9e3 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
+// RUN: mlir-opt %s -tensor-bufferize -cse | FileCheck %s
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
@@ -72,14 +72,6 @@ func.func @tensor.extract(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
return %0 : f32
}
-// CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> {
-// CHECK: %[[RET:.*]] = arith.constant dense<> : tensor<0xindex>
-// CHECK: return %[[RET]] : tensor<0xindex>
-func.func @tensor.from_elements_no_elements() -> tensor<0xindex> {
- %0 = tensor.from_elements : tensor<0xindex>
- return %0 : tensor<0xindex>
-}
-
// CHECK-LABEL: func @tensor.from_elements_0d(
// CHECK-SAME: %[[ELEM0:.*]]: index) -> tensor<index> {
// CHECK: %[[MEMREF:.*]] = memref.alloc() {{.*}} : memref<index>
@@ -185,8 +177,8 @@ func.func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
// CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
-// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
+// CHECK-DAG: %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
+// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
// CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
@@ -212,7 +204,7 @@ func.func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tenso
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
+// CHECK-DAG: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
// CHECK: %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index
// CHECK: store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
@@ -278,8 +270,8 @@ func.func @tensor.insert_slice(%t1: tensor<?x?xf32>, %t2: tensor<?x10xf32>,
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
// CHECK-DAG: %[[m2:.*]] = bufferization.to_memref %[[t2]] : memref<?x10xf32>
- // CHECK: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
- // CHECK: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
+ // CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
+ // CHECK-DAG: %[[dim1:.*]] = memref.dim %[[m1]], %[[c1]]
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]], %[[dim1]])
// CHECK: memref.copy %[[m1]], %[[alloc]]
// CHECK: %[[subview:.*]] = memref.subview %[[alloc]][%[[idx1]], 5] [%[[idx2]], 10] [1, 1]
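Usage note: the previously static foldToMemrefToTensorPair helper is now exported from Bufferization.h (first hunk above) so the walk-based driver can call it directly. A minimal sketch of calling it outside the driver follows; the wrapper function name is illustrative only:

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Fold every to_memref(to_tensor(x)) pair under `root`, inserting a
// memref.cast where the producer and result memref types differ.
void foldAllToMemrefToTensorPairs(Operation *root) {
  // Collect first, then rewrite, so the walk itself never mutates the IR.
  SmallVector<bufferization::ToMemrefOp> toMemrefOps;
  root->walk([&](bufferization::ToMemrefOp op) { toMemrefOps.push_back(op); });

  IRRewriter rewriter(root->getContext());
  for (bufferization::ToMemrefOp toMemref : toMemrefOps) {
    rewriter.setInsertionPoint(toMemref);
    // Failure just means there was no to_tensor producer or the fold would
    // need an unimplemented unranked -> ranked copy; ignoring it is fine.
    (void)bufferization::foldToMemrefToTensorPair(rewriter, toMemref);
  }
}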