[Mlir-commits] [mlir] 878950b - [mlir][bufferization] Simplify `getBufferType`
Matthias Springer
llvmlistbot at llvm.org
Wed Aug 16 06:05:11 PDT 2023
Author: Matthias Springer
Date: 2023-08-16T15:02:07+02:00
New Revision: 878950b82cb7727361b5ae13ea0326e39c7677fe
URL: https://github.com/llvm/llvm-project/commit/878950b82cb7727361b5ae13ea0326e39c7677fe
DIFF: https://github.com/llvm/llvm-project/commit/878950b82cb7727361b5ae13ea0326e39c7677fe.diff
LOG: [mlir][bufferization] Simplify `getBufferType`
`getBufferType` computes the bufferized type of an SSA value without bufferizing any IR. This is useful for predicting the bufferized type of iter_args of a loop.
To avoid endless recursion (e.g., in the case of "scf.for", the type of the iter_arg depends on the type of the init_arg and the type of the yielded value; the type of the yielded value depends on the type of the iter_arg again), `fixedTypes` was used to fall back to a "fixed" type. A simpler way is to maintain an "invocation stack". `getBufferType` implementations can then inspect the invocation stack to detect repetitive computations (typically when computing the bufferized type of a block argument).
Also improve error messages in case of inconsistent memory spaces inside of a loop.
Differential Revision: https://reviews.llvm.org/D158060
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 7f758e492d89eb..6fc487c1a11aa5 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -609,17 +609,18 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options);
/// Return the buffer type for a given Value (tensor) after bufferization
-/// without bufferizing any IR. If at any point during the type computation, the
-/// type of a value in `fixedTypes` in required, the mapped type is used.
+/// without bufferizing any IR. This function (and not the other overload
+/// without `invocationStack`) can be used from `getBufferType` implementations
+/// of the `BufferizableOpInterface`.
///
/// Note: It should be sufficient to call `getBuffer()->getType()` in most
/// cases. However, when a buffer type should be predicted without modifying any
/// IR, this function can be used.
///
-/// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType>
-getBufferType(Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes);
+/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
+FailureOr<BaseMemRefType> getBufferType(Value value,
+ const BufferizationOptions &options,
+ SmallVector<Value> &invocationStack);
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
@@ -691,7 +692,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// places.
FailureOr<BaseMemRefType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes);
+ SmallVector<Value> &invocationStack);
/// This is the default implementation of
/// BufferizableOpInterface::resultBufferizesToMemoryWrite. Should not be called
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 6a349f34ef1fbd..bd7a2d8b3f1eac 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -494,18 +494,31 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
This method is useful when the bufferized type of value must be
predicted before modifying any IR.
+
+ Implementations may call `bufferization::getBufferType` to compute the
+ bufferized type of another SSA value. The same (unmodified)
+ `invocationStack` must be passed to that function. The stack contains
+ all SSA values for which a buffer type computation is currently in
+ progress. Implementations may inspect the stack to detect repetitive
+ computations for the same SSA value. (E.g., when bufferized types of a
+ loop.)
+
+ Note: This interface method should never be called directly from user
+ code. Always use `bufferization::getBufferType`.
}],
/*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
- "const ::mlir::DenseMap<::mlir::Value, ::mlir::BaseMemRefType>":$fixedTypes),
+ "::llvm::SmallVector<::mlir::Value> &":$invocationStack),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(getOwnerOfValue(value) == $_op.getOperation() &&
"expected that value belongs to this op");
+ assert(invocationStack.back() == value &&
+ "inconsistant invocation stack");
return ::mlir::bufferization::detail::defaultGetBufferType(
- value, options, fixedTypes);
+ value, options, invocationStack);
}]
>,
InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index d62f781a01b003..fec07af349b3a8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -108,7 +108,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes);
+ SmallVector<Value> &invocationStack);
RankedTensorType getType() {
return ::llvm::cast<RankedTensorType>(getResult().getType());
@@ -388,7 +388,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+ SmallVector<Value> &invocationStack) {
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
}
}];
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index ddc5ce54040f65..1bfb0c7e102cd3 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -170,13 +170,13 @@ struct SelectOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
- options, fixedTypes);
+ options, invocationStack);
auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
- options, fixedTypes);
+ options, invocationStack);
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1427037619e591..a96cfedc9a4527 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
//===----------------------------------------------------------------------===//
@@ -728,27 +729,25 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
- DenseMap<Value, BaseMemRefType> fixedTypes;
- return getBufferType(value, options, fixedTypes);
+ SmallVector<Value> invocationStack;
+ return getBufferType(value, options, invocationStack);
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType> bufferization::getBufferType(
- Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+FailureOr<BaseMemRefType>
+bufferization::getBufferType(Value value, const BufferizationOptions &options,
+ SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) &&
"unexpected non-tensor type");
-
- // If the `value` is in `fixedTypes`, return the mapped type.
- const auto &it = fixedTypes.find(value);
- if (it != fixedTypes.end())
- return it->second;
+ invocationStack.push_back(value);
+ auto popFromStack =
+ llvm::make_scope_exit([&]() { invocationStack.pop_back(); });
// Try querying BufferizableOpInterface.
Operation *op = getOwnerOfValue(value);
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (bufferizableOp)
- return bufferizableOp.getBufferType(value, options, fixedTypes);
+ return bufferizableOp.getBufferType(value, options, invocationStack);
// Op is not bufferizable.
if (!options.defaultMemorySpace.has_value())
@@ -996,7 +995,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+ SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
// No further analysis is possible for a block argument.
@@ -1013,7 +1012,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
- return getBufferType(equivalentOperand, options, fixedTypes);
+ return getBufferType(equivalentOperand, options, invocationStack);
}
// If we do not know the memory space and there is no default memory space,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index e16cfcead1c37d..c8681374ccae11 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -230,9 +230,9 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
return {};
}
-FailureOr<BaseMemRefType> AllocTensorOp::getBufferType(
- Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+FailureOr<BaseMemRefType>
+AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
+ SmallVector<Value> &invocationStack) {
assert(value == getResult() && "invalid value");
// Compute memory space of this allocation.
@@ -241,7 +241,7 @@ FailureOr<BaseMemRefType> AllocTensorOp::getBufferType(
memorySpace = *getMemorySpace();
} else if (getCopy()) {
auto copyBufferType =
- bufferization::getBufferType(getCopy(), options, fixedTypes);
+ bufferization::getBufferType(getCopy(), options, invocationStack);
if (failed(copyBufferType))
return failure();
memorySpace = copyBufferType->getMemorySpace();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index b72c4f1999401b..10c704fc64dd51 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -199,7 +199,7 @@ struct CallOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto callOp = cast<func::CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
@@ -321,7 +321,7 @@ struct FuncOpInterface
: public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto funcOp = cast<FuncOp>(op);
auto bbArg = cast<BlockArgument>(value);
// Unstructured control flow is not supported.
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 11ee5415bb4cbc..ac01d264eb8fba 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct IfOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto ifOp = cast<scf::IfOp>(op);
auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
@@ -227,7 +227,7 @@ struct IfOpInterface
thenBufferType = cast<BaseMemRefType>(thenValue.getType());
} else {
auto maybeBufferType =
- bufferization::getBufferType(thenValue, options, fixedTypes);
+ bufferization::getBufferType(thenValue, options, invocationStack);
if (failed(maybeBufferType))
return failure();
thenBufferType = *maybeBufferType;
@@ -237,7 +237,7 @@ struct IfOpInterface
elseBufferType = cast<BaseMemRefType>(elseValue.getType());
} else {
auto maybeBufferType =
- bufferization::getBufferType(elseValue, options, fixedTypes);
+ bufferization::getBufferType(elseValue, options, invocationStack);
if (failed(maybeBufferType))
return failure();
elseBufferType = *maybeBufferType;
@@ -331,33 +331,34 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
-/// usually requires computing the bufferized type of the corresponding
-/// iter_arg; the implementation of getBufferType traces back the use-def chain
-/// of the given value and computes a buffer type along the way.) If both buffer
-/// types are equal, no casts are needed the computed buffer type can be used
-/// directly. Otherwise, the buffer types can only differ in their layout map
-/// and a cast must be inserted.
+/// bufferized yielded value type usually requires computing the bufferized type
+/// of the iter_arg again; the implementation of getBufferType traces back the
+/// use-def chain of the given value and computes a buffer type along the way.)
+/// If both buffer types are equal, no casts are needed the computed buffer type
+/// can be used directly. Otherwise, the buffer types can only differ in their
+/// layout map and a cast must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
- BlockArgument iterArg, Value initArg, Value yieldedValue,
- const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) {
+ Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
+ const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
// Determine the buffer type of the init_arg.
auto initArgBufferType =
- bufferization::getBufferType(initArg, options, fixedTypes);
+ bufferization::getBufferType(initArg, options, invocationStack);
if (failed(initArgBufferType))
return failure();
- // Fix the iter_arg type, so that recursive lookups return the buffer type
- // of the init_arg. This is to avoid infinite loops when calculating the
- // buffer type of the yielded value.
- //
- // Note: For more precise layout map computation, a fixpoint iteration could
- // be done (i.e., re-computing the yielded buffer type until the bufferized
- // iter_arg type no longer changes). This current implementation immediately
- // switches to a fully dynamic layout map when a mismatch between bufferized
- // init_arg type and bufferized yield value type is detected.
- DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
- newFixedTypes[iterArg] = *initArgBufferType;
+ if (llvm::count(invocationStack, iterArg) >= 2) {
+ // If the iter_arg is already twice on the invocation stack, just take the
+ // type of the init_arg. This is to avoid infinite loops when calculating
+ // the buffer type. This will most likely result in computing a memref type
+ // with a fully dynamic layout map.
+
+ // Note: For more precise layout map computation, a fixpoint iteration could
+ // be done (i.e., re-computing the yielded buffer type until the bufferized
+ // iter_arg type no longer changes). This current implementation immediately
+ // switches to a fully dynamic layout map when a mismatch between bufferized
+ // init_arg type and bufferized yield value type is detected.
+ return *initArgBufferType;
+ }
// Compute the buffer type of the yielded value.
BaseMemRefType yieldedValueBufferType;
@@ -365,8 +366,10 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
// scf.yield was already bufferized.
yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
} else {
+ // Note: This typically triggers a recursive call for the buffer type of
+ // the iter_arg.
auto maybeBufferType =
- bufferization::getBufferType(yieldedValue, options, newFixedTypes);
+ bufferization::getBufferType(yieldedValue, options, invocationStack);
if (failed(maybeBufferType))
return failure();
yieldedValueBufferType = *maybeBufferType;
@@ -376,20 +379,26 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
if (*initArgBufferType == yieldedValueBufferType)
return yieldedValueBufferType;
- // If there is a mismatch between the yielded buffer type and the iter_arg
+ // If there is a mismatch between the yielded buffer type and the init_arg
// buffer type, the buffer type must be promoted to a fully dynamic layout
// map.
- auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
+ auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
+ auto iterTensorType = cast<TensorType>(iterArg.getType());
+ auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
+ if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
+ return loopOp->emitOpError(
+ "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
- auto iterRanked = llvm::cast<MemRefType>(*initArgBufferType);
- assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
- "expected same shape");
- assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
- "expected same memory space");
+ if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
+ assert(
+ llvm::all_equal({yieldedRankedBufferType.getShape(),
+ cast<MemRefType>(initBufferType).getShape(),
+ cast<RankedTensorType>(iterTensorType).getShape()}) &&
+ "expected same shape");
+ }
#endif // NDEBUG
return getMemRefTypeWithFullyDynamicLayout(
- cast<RankedTensorType>(iterArg.getType()),
- yieldedRanked.getMemorySpace());
+ iterTensorType, yieldedBufferType.getMemorySpace());
}
/// Return `true` if the given loop may have 0 iterations.
@@ -513,29 +522,32 @@ struct ForOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto forOp = cast<scf::ForOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
assert(isa<TensorType>(value.getType()) && "expected tensor type");
- // Get result/argument number.
- unsigned resultNum;
- if (auto bbArg = dyn_cast<BlockArgument>(value)) {
- resultNum =
- forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
- .getResultNumber();
- } else {
- resultNum = cast<OpResult>(value).getResultNumber();
+ if (auto opResult = dyn_cast<OpResult>(value)) {
+ // The type of an OpResult must match the corresponding iter_arg type.
+ BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(
+ forOp.getOpOperandForResult(opResult));
+ return bufferization::getBufferType(bbArg, options, invocationStack);
}
+ // Compute result/argument number.
+ BlockArgument bbArg = cast<BlockArgument>(value);
+ unsigned resultNum =
+ forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
+ .getResultNumber();
+
// Compute the bufferized type.
auto yieldOp =
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
Value yieldedValue = yieldOp.getOperand(resultNum);
BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
Value initArg = forOp.getInitArgs()[resultNum];
- return computeLoopRegionIterArgBufferType(iterArg, initArg, yieldedValue,
- options, fixedTypes);
+ return computeLoopRegionIterArgBufferType(
+ op, iterArg, initArg, yieldedValue, options, invocationStack);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -838,7 +850,7 @@ struct WhileOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto whileOp = cast<scf::WhileOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
assert(isa<TensorType>(value.getType()) && "expected tensor type");
@@ -849,8 +861,8 @@ struct WhileOpInterface
Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
auto yieldOp = whileOp.getYieldOp();
Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
- return computeLoopRegionIterArgBufferType(bbArg, initArg, yieldedValue,
- options, fixedTypes);
+ return computeLoopRegionIterArgBufferType(
+ op, bbArg, initArg, yieldedValue, options, invocationStack);
}
}
@@ -872,7 +884,7 @@ struct WhileOpInterface
return cast<BaseMemRefType>(conditionYieldedVal.getType());
}
return bufferization::getBufferType(conditionYieldedVal, options,
- fixedTypes);
+ invocationStack);
}
/// Assert that yielded values of an scf.while op are equivalent to their
@@ -1104,20 +1116,20 @@ struct ForallOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto forallOp = cast<ForallOp>(op);
if (auto bbArg = dyn_cast<BlockArgument>(value))
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
return bufferization::getBufferType(
- forallOp.getTiedOpOperand(bbArg)->get(), options, fixedTypes);
+ forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);
// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
return bufferization::getBufferType(
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
- fixedTypes);
+ invocationStack);
}
bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index efcea2f6b45ca9..a67ea0334b22b9 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -48,10 +48,10 @@ struct CastOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto castOp = cast<tensor::CastOp>(op);
- auto maybeSrcBufferType =
- bufferization::getBufferType(castOp.getSource(), options, fixedTypes);
+ auto maybeSrcBufferType = bufferization::getBufferType(
+ castOp.getSource(), options, invocationStack);
if (failed(maybeSrcBufferType))
return failure();
Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
@@ -133,10 +133,10 @@ struct CollapseShapeOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
auto maybeSrcBufferType = bufferization::getBufferType(
- collapseShapeOp.getSrc(), options, fixedTypes);
+ collapseShapeOp.getSrc(), options, invocationStack);
if (failed(maybeSrcBufferType))
return failure();
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
@@ -302,10 +302,10 @@ struct ExpandShapeOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
auto maybeSrcBufferType = bufferization::getBufferType(
- expandShapeOp.getSrc(), options, fixedTypes);
+ expandShapeOp.getSrc(), options, invocationStack);
if (failed(maybeSrcBufferType))
return failure();
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
@@ -383,11 +383,11 @@ struct ExtractSliceOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
assert(value == extractSliceOp.getResult() && "invalid value");
auto srcMemrefType = bufferization::getBufferType(
- extractSliceOp.getSource(), options, fixedTypes);
+ extractSliceOp.getSource(), options, invocationStack);
if (failed(srcMemrefType))
return failure();
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
@@ -853,11 +853,11 @@ struct PadOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
// Infer memory space from the source tensor.
auto padOp = cast<tensor::PadOp>(op);
- auto maybeSrcBufferType =
- bufferization::getBufferType(padOp.getSource(), options, fixedTypes);
+ auto maybeSrcBufferType = bufferization::getBufferType(
+ padOp.getSource(), options, invocationStack);
if (failed(maybeSrcBufferType))
return failure();
MemRefLayoutAttrInterface layout;
@@ -1002,11 +1002,11 @@ struct ReshapeOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ SmallVector<Value> &invocationStack) const {
auto reshapeOp = cast<tensor::ReshapeOp>(op);
assert(value == reshapeOp.getResult() && "unexpected value provided");
auto maybeSourceBufferType = bufferization::getBufferType(
- reshapeOp.getSource(), options, fixedTypes);
+ reshapeOp.getSource(), options, invocationStack);
if (failed(maybeSourceBufferType))
return failure();
return getMemRefTypeWithStaticIdentityLayout(
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
index c77b6ce345e9ff..c8d6d506270a99 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
@@ -26,3 +26,16 @@ func.func @execute_region_multiple_blocks(%t: tensor<5xf32>) -> tensor<5xf32> {
}
func.return %0 : tensor<5xf32>
}
+
+// -----
+
+func.func @inconsistent_memory_space_scf_for(%lb: index, %ub: index, %step: index) -> tensor<10xf32> {
+ %0 = bufferization.alloc_tensor() {memory_space = 0 : ui64} : tensor<10xf32>
+ %1 = bufferization.alloc_tensor() {memory_space = 1 : ui64} : tensor<10xf32>
+ // expected-error @below{{init_arg and yielded value bufferize to inconsistent memory spaces}}
+ %2 = scf.for %iv = %lb to %ub step %step iter_args(%arg = %0) -> tensor<10xf32> {
+ // expected-error @below {{failed to bufferize op}}
+ scf.yield %1 : tensor<10xf32>
+ }
+ return %2 : tensor<10xf32>
+}
More information about the Mlir-commits
mailing list