[Mlir-commits] [mlir] ba9d886 - [mlir][bufferization][NFC] Bufferize with PostOrder traversal
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 27 03:46:05 PDT 2022
Author: Matthias Springer
Date: 2022-06-27T12:42:41+02:00
New Revision: ba9d886db4fbb2dfd6787bfa073811e77eacbfe7
URL: https://github.com/llvm/llvm-project/commit/ba9d886db4fbb2dfd6787bfa073811e77eacbfe7
DIFF: https://github.com/llvm/llvm-project/commit/ba9d886db4fbb2dfd6787bfa073811e77eacbfe7.diff
LOG: [mlir][bufferization][NFC] Bufferize with PostOrder traversal
Bufferizing ops in post-order (nested ops before their parent ops) is useful because the result type of an op can sometimes be inferred from its body (e.g., `scf.if`). Subsequent changes will make use of this.
This change also introduces a new `getBufferType` interface method on BufferizableOpInterface. The method computes the bufferized type of a block argument based on the OpOperand types of the parent op.
Differential Revision: https://reviews.llvm.org/D128420
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index ccc9d1d706261..4c56a8196d455 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -337,7 +337,24 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*defaultImplementation=*/[{
return success();
}]
- >
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the bufferized type of the given tensor block argument. The
+ block argument is guaranteed to belong to a block of this op.
+ }],
+ /*retType=*/"BaseMemRefType",
+ /*methodName=*/"getBufferType",
+ /*args=*/(ins "BlockArgument":$bbArg,
+ "const BufferizationOptions &":$options),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(bbArg.getOwner()->getParentOp() == $_op &&
+ "bbArg must belong to this op");
+ auto tensorType = bbArg.getType().cast<TensorType>();
+ return bufferization::getMemRefType(tensorType, options);
+ }]
+ >,
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 3fe32bc293ea8..8a01acaf9374a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -482,8 +482,10 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options) {
+#ifndef NDEBUG
auto tensorType = value.getType().dyn_cast<TensorType>();
assert(tensorType && "unexpected non-tensor type");
+#endif // NDEBUG
// Replace "%t = to_tensor %m" with %m.
if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
@@ -492,7 +494,7 @@ Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
// Insert to_memref op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- Type memrefType = getMemRefType(tensorType, options);
+ Type memrefType = getBufferType(value, options);
ensureToMemrefOpIsValid(value, memrefType);
return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
value);
@@ -507,6 +509,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options) {
if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
return toTensorOp.getMemref().getType().cast<BaseMemRefType>();
+ if (auto bbArg = value.dyn_cast<BlockArgument>())
+ if (auto bufferizableOp =
+ options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
+ return bufferizableOp.getBufferType(bbArg, options);
+
return getMemRefType(tensorType, options);
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index f9c809081723e..c68d1d120be6a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -393,9 +393,16 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
// Otherwise, we have to use a memref type with a fully dynamic layout map to
// avoid copies. We are currently missing patterns for layout maps to
// canonicalize away (or canonicalize to more precise layouts).
+ //
+ // FuncOps must be bufferized before their bodies, so add them to the worklist
+ // first.
SmallVector<Operation *> worklist;
- op->walk<WalkOrder::PreOrder>([&](Operation *op) {
- if (hasTensorSemantics(op))
+ op->walk([&](func::FuncOp funcOp) {
+ if (hasTensorSemantics(funcOp))
+ worklist.push_back(funcOp);
+ });
+ op->walk<WalkOrder::PostOrder>([&](Operation *op) {
+ if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
worklist.push_back(op);
});
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 26f5e2b0b5519..785cd0d7806dc 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -725,7 +725,7 @@ struct WhileOpInterface
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::to_vector(
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
- return getBufferType(bbArg, options).cast<Type>();
+ return bufferization::getBufferType(bbArg, options).cast<Type>();
}));
// Construct a new scf.while op with memref instead of tensor values.
@@ -1107,7 +1107,7 @@ struct ParallelInsertSliceOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &b,
const BufferizationOptions &options) const {
// Will be bufferized as part of ForeachThreadOp.
- return failure();
+ return success();
}
// TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 177d820ccefcb..02a073de45699 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -154,7 +154,7 @@ struct AssumingYieldOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
// Op is bufferized as part of AssumingOp.
- return failure();
+ return success();
}
};
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index d067704112754..cc0357a055d71 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -313,10 +313,10 @@ func.func @scf_for_swapping_yields(
// CHECK: %[[alloc2:.*]] = memref.alloc(%{{.*}})
// CHECK: memref.copy %[[iter2]], %[[alloc2]]
// CHECK: memref.dealloc %[[iter2]]
-// CHECK: %[[casted2:.*]] = memref.cast %[[alloc2]]
// CHECK: %[[alloc1:.*]] = memref.alloc(%{{.*}})
// CHECK: memref.copy %[[iter1]], %[[alloc1]]
// CHECK: memref.dealloc %[[iter1]]
+// CHECK: %[[casted2:.*]] = memref.cast %[[alloc2]]
// CHECK: %[[casted1:.*]] = memref.cast %[[alloc1]]
// CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
// CHECK: memref.dealloc %[[alloc1]]
@@ -384,10 +384,10 @@ func.func @scf_while_non_equiv_condition(%arg0: tensor<5xi1>,
// CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w1]], %[[a1]]
// CHECK: memref.dealloc %[[w1]]
- // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
// CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w0]], %[[a0]]
// CHECK: memref.dealloc %[[w0]]
+ // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
// CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
// CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
// CHECK: memref.dealloc %[[a0]]
@@ -437,10 +437,10 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
// CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w1]], %[[a1]]
// CHECK: memref.dealloc %[[w1]]
- // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
// CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w0]], %[[a0]]
// CHECK: memref.dealloc %[[w0]]
+ // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
// CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
// CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
// CHECK: memref.dealloc %[[a0]]
@@ -457,9 +457,9 @@ func.func @scf_while_non_equiv_condition_and_body(%arg0: tensor<5xi1>,
// CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[b1]], %[[a3]]
// CHECK: memref.dealloc %[[b1]]
- // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
// CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[b0]], %[[a2]]
+ // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
// CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
// CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]]
// CHECK: memref.dealloc %[[a2]]
More information about the Mlir-commits
mailing list