[Mlir-commits] [mlir] 9d6096c - [mlir][SCF][bufferize][NFC] Move scf.if buffer type computation to getBufferType
Matthias Springer
llvmlistbot at llvm.org
Tue Aug 30 07:49:43 PDT 2022
Author: Matthias Springer
Date: 2022-08-30T16:48:10+02:00
New Revision: 9d6096c56fcafbd882d5f688cbd8d62ec2f2ac71
URL: https://github.com/llvm/llvm-project/commit/9d6096c56fcafbd882d5f688cbd8d62ec2f2ac71
DIFF: https://github.com/llvm/llvm-project/commit/9d6096c56fcafbd882d5f688cbd8d62ec2f2ac71.diff
LOG: [mlir][SCF][bufferize][NFC] Move scf.if buffer type computation to getBufferType
Part of the functionality of `bufferize` is extracted into `getBufferType`. Also, bufferized scf.yields inside scf.if are now created with the correct bufferized type from the get-go.
Differential Revision: https://reviews.llvm.org/D132862
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index b92b131616fbc..7e458a1f37626 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -163,52 +163,22 @@ struct IfOpInterface
const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto ifOp = cast<scf::IfOp>(op);
- auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
- auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
- // Reconcile type mismatches between then/else branches by inserting memref
- // casts.
- SmallVector<Value> thenResults, elseResults;
- bool insertedCast = false;
- for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
- Value thenValue = thenYieldOp.getResults()[i];
- Value elseValue = elseYieldOp.getResults()[i];
- if (thenValue.getType() == elseValue.getType()) {
- thenResults.push_back(thenValue);
- elseResults.push_back(elseValue);
+ // Compute bufferized result types.
+ SmallVector<Type> newTypes;
+ for (Value result : ifOp.getResults()) {
+ if (!result.getType().isa<TensorType>()) {
+ newTypes.push_back(result.getType());
continue;
}
-
- // Type mismatch between then/else yield value. Cast both to a memref type
- // with a fully dynamic layout map.
- auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
- auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
- if (thenMemrefType.getMemorySpaceAsInt() !=
- elseMemrefType.getMemorySpaceAsInt())
- return op->emitError("inconsistent memory space on then/else branches");
- rewriter.setInsertionPoint(thenYieldOp);
- BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
- ifOp.getResultTypes()[i].cast<TensorType>(),
- thenMemrefType.getMemorySpaceAsInt());
- thenResults.push_back(rewriter.create<memref::CastOp>(
- thenYieldOp.getLoc(), memrefType, thenValue));
- rewriter.setInsertionPoint(elseYieldOp);
- elseResults.push_back(rewriter.create<memref::CastOp>(
- elseYieldOp.getLoc(), memrefType, elseValue));
- insertedCast = true;
- }
-
- if (insertedCast) {
- rewriter.setInsertionPoint(thenYieldOp);
- rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
- rewriter.setInsertionPoint(elseYieldOp);
- rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
+ auto bufferType = bufferization::getBufferType(result, options);
+ if (failed(bufferType))
+ return failure();
+ newTypes.push_back(*bufferType);
}
// Create new op.
rewriter.setInsertionPoint(ifOp);
- ValueRange resultsValueRange(thenResults);
- TypeRange newTypes(resultsValueRange);
auto newIfOp =
rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
/*withElseRegion=*/true);
@@ -223,6 +193,55 @@ struct IfOpInterface
return success();
}
+ FailureOr<BaseMemRefType>
+ getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ auto ifOp = cast<scf::IfOp>(op);
+ auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
+ auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
+ assert(value.getDefiningOp() == op && "invalid valid");
+
+ // Determine buffer types of the true/false branches.
+ auto opResult = value.cast<OpResult>();
+ auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
+ auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
+ BaseMemRefType thenBufferType, elseBufferType;
+ if (thenValue.getType().isa<BaseMemRefType>()) {
+ // True branch was already bufferized.
+ thenBufferType = thenValue.getType().cast<BaseMemRefType>();
+ } else {
+ auto maybeBufferType =
+ bufferization::getBufferType(thenValue, options, fixedTypes);
+ if (failed(maybeBufferType))
+ return failure();
+ thenBufferType = *maybeBufferType;
+ }
+ if (elseValue.getType().isa<BaseMemRefType>()) {
+ // False branch was already bufferized.
+ elseBufferType = elseValue.getType().cast<BaseMemRefType>();
+ } else {
+ auto maybeBufferType =
+ bufferization::getBufferType(elseValue, options, fixedTypes);
+ if (failed(maybeBufferType))
+ return failure();
+ elseBufferType = *maybeBufferType;
+ }
+
+ // Best case: Both branches have the exact same buffer type.
+ if (thenBufferType == elseBufferType)
+ return thenBufferType;
+
+ // Memory space mismatch.
+ if (thenBufferType.getMemorySpaceAsInt() !=
+ elseBufferType.getMemorySpaceAsInt())
+ return op->emitError("inconsistent memory space on then/else branches");
+
+ // Layout maps are different: Promote to fully dynamic layout map.
+ return getMemRefTypeWithFullyDynamicLayout(
+ opResult.getType().cast<TensorType>(),
+ thenBufferType.getMemorySpaceAsInt());
+ }
+
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
// IfOp results are equivalent to their corresponding yield values if both
@@ -973,9 +992,12 @@ struct YieldOpInterface
if (failed(maybeBuffer))
return failure();
Value buffer = *maybeBuffer;
- if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+ // In case of scf::ForOp / scf::IfOp, we may have to cast the value
+ // before yielding it.
+ // TODO: Do the same for scf::WhileOp.
+ if (isa<scf::ForOp, scf::IfOp>(yieldOp->getParentOp())) {
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
- forOp.getRegionIterArgs()[it.index()], options);
+ yieldOp->getParentOp()->getResult(it.index()), options);
if (failed(resultType))
return failure();
buffer = castBuffer(rewriter, buffer, *resultType);
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
index 52338d0701be3..66a9807fb1086 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
@@ -5,9 +5,9 @@ func.func @inconsistent_memory_space_scf_if(%c: i1) -> tensor<10xf32> {
// bufferized.
%0 = bufferization.alloc_tensor() {memory_space = 0 : ui64} : tensor<10xf32>
%1 = bufferization.alloc_tensor() {memory_space = 1 : ui64} : tensor<10xf32>
- // expected-error @+2 {{inconsistent memory space on then/else branches}}
- // expected-error @+1 {{failed to bufferize op}}
+ // expected-error @+1 {{inconsistent memory space on then/else branches}}
%r = scf.if %c -> tensor<10xf32> {
+ // expected-error @+1 {{failed to bufferize op}}
scf.yield %0 : tensor<10xf32>
} else {
scf.yield %1 : tensor<10xf32>
More information about the Mlir-commits
mailing list