[Mlir-commits] [mlir] f24d931 - [mlir][linalg][bufferize][NFC] Specify bufferize traversal in `bufferize`
Matthias Springer
llvmlistbot at llvm.org
Tue Nov 23 04:36:40 PST 2021
Author: Matthias Springer
Date: 2021-11-23T21:33:19+09:00
New Revision: f24d9313cc9fe9f6cd70f606c1dc8f8213587468
URL: https://github.com/llvm/llvm-project/commit/f24d9313cc9fe9f6cd70f606c1dc8f8213587468
DIFF: https://github.com/llvm/llvm-project/commit/f24d9313cc9fe9f6cd70f606c1dc8f8213587468.diff
LOG: [mlir][linalg][bufferize][NFC] Specify bufferize traversal in `bufferize`
The interface method `bufferize` controls how (and in what order) nested ops are traversed. This simplifies bufferization of scf::ForOps and scf::IfOps, which used to need special rules in scf::YieldOp.
Differential Revision: https://reviews.llvm.org/D114057
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 491f8a56eb609..881f1edb11c47 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -297,10 +297,16 @@ struct BufferizationState {
/// bufferization is necessary.
Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
+/// Bufferize all ops in the given region.
+LogicalResult bufferize(Region *region, BufferizationState &state);
+
+/// Bufferize all ops in the given block.
+LogicalResult bufferize(Block *block, BufferizationState &state);
+
/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
/// function returns immediately. Otherwise, it calls the `bufferize` interface
/// method of `BufferizableOpInterface`.
-LogicalResult bufferizeOp(Operation *op, BufferizationState &state);
+LogicalResult bufferize(Operation *op, BufferizationState &state);
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
/// executed after the analysis, but before bufferization. They can be used
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index 757eca50eb5f0..ca0454549bcab 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -163,8 +163,13 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
InterfaceMethod<
/*desc=*/[{
Bufferize this op, i.e., rewrite it into a memref-based equivalent.
- `bvm` maps tensor values to memref values and this method should map
- tensor results to memref results after creating/modifying ops.
+ Tensor values should be mapped to buffer values using `state`.
+
Implementations are required to bufferize nested ops
+ before returning. Otherwise, nested ops will not be bufferized.
+
+ This method will never be called on ops that do not have at least one
+ tensor operand or result.
}],
/*retType=*/"LogicalResult",
/*methodName=*/"bufferize",
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 3897734a1898c..fc9f414f7cd9d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -392,8 +392,26 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
}
LogicalResult
-mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
- BufferizationState &state) {
+mlir::linalg::comprehensive_bufferize::bufferize(Region *region,
+ BufferizationState &state) {
+ for (Block &block : *region)
+ if (failed(bufferize(&block, state)))
+ return failure();
+ return success();
+}
+
+LogicalResult
+mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
+ BufferizationState &state) {
+ for (Operation &op : *block)
+ if (failed(bufferize(&op, state)))
+ return failure();
+ return success();
+}
+
+LogicalResult
+mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
+ BufferizationState &state) {
OpBuilder b(op->getContext());
// Skip BufferCast and TensorLoad ops.
@@ -404,15 +422,22 @@ mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op,
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
- if (!hasTensorResult && !hasTensorOperand)
+
+ // No tensor results or operands: Simply bufferize all nested ops.
+ if (!hasTensorResult && !hasTensorOperand) {
+ for (Region &region : op->getRegions())
+ if (failed(bufferize(&region, state)))
+ return failure();
return success();
+ }
- // Bufferize using `BufferizableOpInterface`.
+ // Bufferize using `BufferizableOpInterface`. Interface implementations are
+ // responsible for bufferizing nested ops.
b.setInsertionPoint(op);
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
return bufferizableOp.bufferize(b, state);
- // Other op with tensors. No bufferization method specified.
+ // Emit error if tensor op is not bufferizable.
return op->emitError() << "unsupported op with tensors";
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index ac67d86ac9b89..d062bbab4ad0c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -784,11 +784,12 @@ static Value createNewAllocDeallocPairForShapedValue(
//===----------------------------------------------------------------------===//
/// FuncOp always creates TensorToMemRef ops.
-static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
- BufferizationState &state) {
+static LogicalResult bufferizeFuncOp(FuncOp funcOp, BufferizationState &state) {
// Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
+ OpBuilder b(funcOp->getContext());
b.setInsertionPointToStart(&funcOp.body().front());
+
+ // Create BufferCastOps for function args.
for (auto bbArg : funcOp.getArguments()) {
auto tensorType = bbArg.getType().dyn_cast<TensorType>();
if (!tensorType)
@@ -804,7 +805,9 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
state.mapBuffer(bbArg, bufferCast);
}
- return success();
+
+ // Bufferize function body.
+ return bufferize(&funcOp.body(), state);
}
//===----------------------------------------------------------------------===//
@@ -923,37 +926,6 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
return res;
}
-//===----------------------------------------------------------------------===//
-// Bufferization entry-point for functions.
-//===----------------------------------------------------------------------===//
-
-static LogicalResult bufferizeFuncOpInternals(FuncOp funcOp,
- BufferizationState &state) {
- LLVM_DEBUG(llvm::dbgs() << "\n\n");
- LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
- OpBuilder b(funcOp->getContext());
-
- // Start by bufferizing `funcOp` arguments.
- if (failed(bufferize(b, funcOp, state)))
- return failure();
-
- auto walkFunc = [&](Operation *op) -> WalkResult {
- if (failed(bufferizeOp(op, state)))
- return failure();
- return success();
- };
-
- // Bufferize ops pre-order, i.e., bufferize ops first, then their children.
- // This is needed for ops with blocks that have BlockArguments. These must be
- // mapped before bufferizing the children.
- if (funcOp.walk<WalkOrder::PreOrder>(walkFunc).wasInterrupted())
- return failure();
-
- LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// Bufferization entry-point for modules.
//===----------------------------------------------------------------------===//
@@ -1380,7 +1352,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
// Bufferization phase.
if (!options.testAnalysisOnly) {
// Bufferize all ops in funcOp.
- if (failed(bufferizeFuncOpInternals(funcOp, state)))
+ if (failed(bufferizeFuncOp(funcOp, state)))
return failure();
// Erase all obsolete ops.
@@ -1547,12 +1519,13 @@ struct ExecuteRegionOpInterface
BufferizationState &state) const {
// TODO: Add bufferization support when needed. scf.execute_region should be
// bufferized similar to scf.if.
+ auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
bool hasTensorReturnType = any_of(
op->getResultTypes(), [](Type t) { return t.isa<TensorType>(); });
if (hasTensorReturnType)
return op->emitError(
"scf.execute_region with tensor result not supported");
- return success();
+ return comprehensive_bufferize::bufferize(&executeRegionOp.region(), state);
}
};
@@ -1609,37 +1582,33 @@ struct IfOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
- // scf::IfOp is bufferized after scf::YieldOp in the else branch.
- return success();
- }
-};
+ auto ifOp = cast<scf::IfOp>(op);
-/// Bufferize the scf::IfOp. This function is called after the YieldOp was
-/// bufferized.
-static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b,
- BufferizationState &state) {
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(ifOp);
+ // Bufferize then/else blocks.
+ if (failed(comprehensive_bufferize::bufferize(ifOp.thenBlock(), state)))
+ return failure();
+ if (failed(comprehensive_bufferize::bufferize(ifOp.elseBlock(), state)))
+ return failure();
- for (OpResult opResult : ifOp->getResults()) {
- if (!opResult.getType().isa<TensorType>())
- continue;
- // TODO: Atm we bail on unranked TensorType because we don't know how to
- // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
- assert(opResult.getType().isa<RankedTensorType>() &&
- "unsupported unranked tensor");
+ for (OpResult opResult : ifOp->getResults()) {
+ if (!opResult.getType().isa<TensorType>())
+ continue;
+ // TODO: Atm we bail on unranked TensorType because we don't know how to
+ // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
+ assert(opResult.getType().isa<RankedTensorType>() &&
+ "unsupported unranked tensor");
- Value resultBuffer = getResultBuffer(b, opResult, state);
- if (!resultBuffer)
- return failure();
+ Value resultBuffer = getResultBuffer(b, opResult, state);
+ if (!resultBuffer)
+ return failure();
- state.aliasInfo.createAliasInfoEntry(resultBuffer);
- state.mapBuffer(opResult, resultBuffer);
- }
+ state.aliasInfo.createAliasInfoEntry(resultBuffer);
+ state.mapBuffer(opResult, resultBuffer);
+ }
- return success();
-}
+ return success();
+ }
+};
struct ForOpInterface
: public BufferizableOpInterface::ExternalModel<ForOpInterface,
@@ -1687,9 +1656,6 @@ struct ForOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
- // Note: This method is just setting up the mappings for the block arguments
- // and the result buffer. The op is bufferized after the scf::YieldOp.
-
auto forOp = cast<scf::ForOp>(op);
// Take a guard before anything else.
@@ -1716,41 +1682,39 @@ struct ForOpInterface
state.mapBuffer(opResult, resultBuffer);
}
- return success();
- }
-};
+ // Bufferize loop body.
+ if (failed(comprehensive_bufferize::bufferize(&forOp.region(), state)))
+ return failure();
-/// Bufferize the scf::ForOp. This function is called after the YieldOp was
-/// bufferized.
-static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b,
- BufferizationState &state) {
- auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
- for (OpOperand &operand : yieldOp->getOpOperands()) {
- auto tensorType = operand.get().getType().dyn_cast<TensorType>();
- if (!tensorType)
- continue;
+ // Finish bufferizing scf::ForOp.
+ auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back());
+ for (OpOperand &operand : yieldOp->getOpOperands()) {
+ auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+ if (!tensorType)
+ continue;
- OpOperand &forOperand = forOp.getOpOperandForResult(
- forOp->getResult(operand.getOperandNumber()));
- auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
- Value yieldedBuffer = state.lookupBuffer(operand.get());
- Value bbArgBuffer = state.lookupBuffer(bbArg);
- if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
- bbArgBuffer)) {
- // TODO: this could get resolved with copies but it can also turn into
- // swaps so we need to be careful about order of copies.
- return yieldOp->emitError()
- << "Yield operand #" << operand.getOperandNumber()
- << " does not bufferize to an equivalent buffer to the matching"
- << " enclosing scf::for operand";
- }
+ OpOperand &forOperand = forOp.getOpOperandForResult(
+ forOp->getResult(operand.getOperandNumber()));
+ auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+ Value yieldedBuffer = state.lookupBuffer(operand.get());
+ Value bbArgBuffer = state.lookupBuffer(bbArg);
+ if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer,
+ bbArgBuffer)) {
+ // TODO: this could get resolved with copies but it can also turn into
+ // swaps so we need to be careful about order of copies.
+ return yieldOp->emitError()
+ << "Yield operand #" << operand.getOperandNumber()
+ << " does not bufferize to an equivalent buffer to the matching"
+ << " enclosing scf::for operand";
+ }
- // Buffers are equivalent so the work is already done and we just yield
- // the bbArg so that it later canonicalizes away.
- operand.set(bbArg);
+ // Buffers are equivalent so the work is already done and we just yield
+ // the bbArg so that it later canonicalizes away.
+ operand.set(bbArg);
+ }
+ return success();
}
- return success();
-}
+};
struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
@@ -1774,27 +1738,10 @@ struct YieldOpInterface
LogicalResult bufferize(Operation *op, OpBuilder &b,
BufferizationState &state) const {
auto yieldOp = cast<scf::YieldOp>(op);
-
- if (auto execOp = dyn_cast<scf::ExecuteRegionOp>(yieldOp->getParentOp())) {
- if (execOp->getNumResults() != 0)
- return execOp->emitError(
- "expected result-less scf.execute_region containing op");
- return success();
- }
-
- // Bufferize scf::IfOp after bufferizing the scf::YieldOp in the else
- // branch.
- if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
- if (ifOp.elseYield() != yieldOp)
- return success();
- return bufferizeIfOp(ifOp, b, state);
- }
-
- // Bufferize scf::ForOp after bufferizing the scf::YieldOp.
- if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
- return bufferizeForOp(forOp, b, state);
-
- return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp");
+ if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
+ yieldOp->getParentOp()))
+ return yieldOp->emitError("unsupported scf::YieldOp parent");
+ return success();
}
};
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 83fde817ef842..52fed305d06cd 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -340,7 +340,8 @@ struct TiledLoopOpInterface
static_cast<int>(oldInputs.size()) + numNewInputBuffers,
static_cast<int>(oldOutputs.size()) + numNewOutputBuffers}));
- return success();
+ // Bufferize loop body.
+ return comprehensive_bufferize::bufferize(&tiledLoopOp.region(), state);
}
};
More information about the Mlir-commits
mailing list