[Mlir-commits] [mlir] 123c4b0 - [mlir][SCF][bufferize] Support different iter_arg/init_arg types (scf.for)
Matthias Springer
llvmlistbot at llvm.org
Tue Aug 30 07:35:44 PDT 2022
Author: Matthias Springer
Date: 2022-08-30T16:35:32+02:00
New Revision: 123c4b02517865b11af1079d206bc838edad79a6
URL: https://github.com/llvm/llvm-project/commit/123c4b02517865b11af1079d206bc838edad79a6
DIFF: https://github.com/llvm/llvm-project/commit/123c4b02517865b11af1079d206bc838edad79a6.diff
LOG: [mlir][SCF][bufferize] Support different iter_arg/init_arg types (scf.for)
Even though iter_arg and init_arg of an scf.for loop may have the same tensor type, their bufferized memref types are not necessarily equal. It is sometimes necessary to insert a cast in case of differing layout maps.
Differential Revision: https://reviews.llvm.org/D132860
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index f82bf26c19ef1..f22fe002ec86e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -495,6 +495,19 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options);
+/// Return the buffer type for a given Value (tensor) after bufferization
+/// without bufferizing any IR. If at any point during the type computation, the
+/// type of a value in `fixedTypes` is required, the mapped type is used.
+///
+/// Note: It should be sufficient to call `getBuffer()->getType()` in most
+/// cases. However, when a buffer type should be predicted without modifying any
+/// IR, this function can be used.
+///
+/// This function is a wrapper around BufferizableOpInterface::getBufferType.
+FailureOr<BaseMemRefType>
+getBufferType(Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes);
+
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
@@ -551,7 +564,8 @@ namespace detail {
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
FailureOr<BaseMemRefType>
-defaultGetBufferType(Value value, const BufferizationOptions &options);
+defaultGetBufferType(Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes);
} // namespace detail
} // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index ce88e01bff0a9..c28e3f6745539 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -350,12 +350,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"FailureOr<BaseMemRefType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "Value":$value,
- "const BufferizationOptions &":$options),
+ "const BufferizationOptions &":$options,
+ "const DenseMap<Value, BaseMemRefType>":$fixedTypes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(getOwnerOfValue(value) == $_op.getOperation() &&
"expected that value belongs to this op");
- return bufferization::detail::defaultGetBufferType(value, options);
+ return bufferization::detail::defaultGetBufferType(
+ value, options, fixedTypes);
}]
>,
];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 22d5ef27bdbe8..07e1f53ab6a79 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -92,7 +92,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
OpOperand &opOperand, const AnalysisState &state);
FailureOr<BaseMemRefType> getBufferType(
- Value value, const BufferizationOptions &options);
+ Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes);
RankedTensorType getType() {
return getResult().getType().cast<RankedTensorType>();
@@ -323,7 +324,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
}
FailureOr<BaseMemRefType> getBufferType(
- Value value, const BufferizationOptions &options) {
+ Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) {
return getMemref().getType().cast<BaseMemRefType>();
}
}];
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 3d34ee4d3e2ee..265c417b6ce0c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -568,7 +568,8 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
}
FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
- Value value, const BufferizationOptions &options) {
+ Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) {
assert(value.getType().isa<TensorType>() && "expected tensor type");
// No further analysis is possible for a block argument.
@@ -587,7 +588,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliasingOperands.front()->get();
- return getBufferType(equivalentOperand, options);
+ return getBufferType(equivalentOperand, options, fixedTypes);
}
// If we do not know the memory space and there is no default memory space,
@@ -602,11 +603,26 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
+ DenseMap<Value, BaseMemRefType> fixedTypes;
+ return getBufferType(value, options, fixedTypes);
+}
+
+/// Return the buffer type for a given Value (tensor) after bufferization.
+FailureOr<BaseMemRefType> bufferization::getBufferType(
+ Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) {
assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
+
+ // If the `value` is in `fixedTypes`, return the mapped type.
+ const auto &it = fixedTypes.find(value);
+ if (it != fixedTypes.end())
+ return it->second;
+
+ // Try querying BufferizableOpInterface.
Operation *op = getOwnerOfValue(value);
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (bufferizableOp)
- return bufferizableOp.getBufferType(value, options);
+ return bufferizableOp.getBufferType(value, options, fixedTypes);
// Op is not bufferizable.
if (!options.defaultMemorySpace.has_value())
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ee9591b2a01b8..cf29d721002bd 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -170,7 +170,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
}
// Create memory allocation.
- auto allocType = getBufferType(getResult(), options);
+ auto allocType = bufferization::getBufferType(getResult(), options);
if (failed(allocType))
return failure();
SmallVector<Value> dynamicDims = getDynamicSizes();
@@ -233,8 +233,9 @@ AllocTensorOp::getAliasingOpResult(OpOperand &opOperand,
return {};
}
-FailureOr<BaseMemRefType>
-AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options) {
+FailureOr<BaseMemRefType> AllocTensorOp::getBufferType(
+ Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) {
assert(value == getResult() && "invalid value");
// Compute memory space of this allocation.
@@ -242,7 +243,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options) {
if (getMemorySpace().has_value()) {
memorySpace = *getMemorySpace();
} else if (getCopy()) {
- auto copyBufferType = bufferization::getBufferType(getCopy(), options);
+ auto copyBufferType =
+ bufferization::getBufferType(getCopy(), options, fixedTypes);
if (failed(copyBufferType))
return failure();
memorySpace = copyBufferType->getMemorySpaceAsInt();
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 0b5939e60c1bd..b92b131616fbc 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -472,15 +472,76 @@ struct ForOpInterface
}
FailureOr<BaseMemRefType>
- getBufferType(Operation *op, Value value,
- const BufferizationOptions &options) const {
+ getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
auto forOp = cast<scf::ForOp>(op);
- // TODO: Only block arguments supported at the moment.
- if (value.isa<OpResult>())
+ assert(getOwnerOfValue(value) == op && "invalid value");
+ assert(value.getType().isa<TensorType>() && "expected tensor type");
+
+ // Get result/argument number.
+ unsigned resultNum;
+ if (auto bbArg = value.dyn_cast<BlockArgument>()) {
+ resultNum =
+ forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
+ .getResultNumber();
+ } else {
+ resultNum = value.cast<OpResult>().getResultNumber();
+ }
+
+ // Determine the buffer type of the init_arg.
+ Value initArg = forOp.getInitArgs()[resultNum];
+ auto initArgBufferType =
+ bufferization::getBufferType(initArg, options, fixedTypes);
+ if (failed(initArgBufferType))
return failure();
- auto bbArg = value.cast<BlockArgument>();
- return bufferization::getBufferType(
- forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
+
+ // Fix the iter_arg type, so that recursive lookups return the buffer type
+ // of the init_arg. This is to avoid infinite loops when calculating the
+ // buffer type of the yielded value.
+ //
+ // Note: For more precise layout map computation, a fixpoint iteration could
+ // be done (i.e., re-computing the yielded buffer type until the bufferized
+ // iter_arg type no longer changes). This current implementation immediately
+ // switches to a fully dynamic layout map when a mismatch between bufferized
+ // init_arg type and bufferized yield value type is detected.
+ DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
+ newFixedTypes[forOp.getRegionIterArgs()[resultNum]] = *initArgBufferType;
+
+ // Compute the buffer type of the yielded value.
+ auto yieldOp =
+ cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
+ Value yieldedValue = yieldOp.getOperand(resultNum);
+ BaseMemRefType yieldedValueBufferType;
+ if (yieldedValue.getType().isa<BaseMemRefType>()) {
+ // scf.yield was already bufferized.
+ yieldedValueBufferType = yieldedValue.getType().cast<BaseMemRefType>();
+ } else {
+ auto maybeBufferType =
+ bufferization::getBufferType(yieldedValue, options, newFixedTypes);
+ if (failed(maybeBufferType))
+ return failure();
+ yieldedValueBufferType = *maybeBufferType;
+ }
+
+ // If yielded type and init_arg type are the same, use that type directly.
+ if (*initArgBufferType == yieldedValueBufferType)
+ return yieldedValueBufferType;
+
+ // If there is a mismatch between the yielded buffer type and the iter_arg
+ // buffer type, the buffer type must be promoted to a fully dynamic layout
+ // map.
+ auto yieldedRanked = yieldedValueBufferType.cast<MemRefType>();
+#ifndef NDEBUG
+ auto iterRanked = initArgBufferType->cast<MemRefType>();
+ assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
+ "expected same shape");
+ assert(yieldedRanked.getMemorySpaceAsInt() ==
+ iterRanked.getMemorySpaceAsInt() &&
+ "expected same memory space");
+#endif // NDEBUG
+ return getMemRefTypeWithFullyDynamicLayout(
+ forOp.getRegionIterArgs()[resultNum].getType().cast<RankedTensorType>(),
+ yieldedRanked.getMemorySpaceAsInt());
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -499,13 +560,22 @@ struct ForOpInterface
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;
+ // Cast init_args if necessary.
+ SmallVector<Value> castedInitArgs;
+ for (const auto &it : llvm::enumerate(initArgs)) {
+ Value initArg = it.value();
+ auto targetType =
+ bufferization::getBufferType(forOp->getResult(it.index()), options);
+ if (failed(targetType))
+ return failure();
+ castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
+ }
+
// Construct a new scf.for op with memref instead of tensor values.
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), initArgs);
+ forOp.getStep(), castedInitArgs);
newForOp->setAttrs(forOp->getAttrs());
- ValueRange initArgsRange(initArgs);
- TypeRange initArgsTypes(initArgsRange);
Block *loopBody = &newForOp.getLoopBody().front();
// Set up new iter_args. The loop body uses tensors, so wrap the (memref)
@@ -904,10 +974,8 @@ struct YieldOpInterface
return failure();
Value buffer = *maybeBuffer;
if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
- FailureOr<BaseMemRefType> resultType =
- cast<BufferizableOpInterface>(forOp.getOperation())
- .getBufferType(forOp.getRegionIterArgs()[it.index()],
- options);
+ FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
+ forOp.getRegionIterArgs()[it.index()], options);
if (failed(resultType))
return failure();
buffer = castBuffer(rewriter, buffer, *resultType);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 2e23a167cd0b2..35010f520e022 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -293,7 +293,7 @@ struct ExtractSliceOpInterface
// Take a subview of the source buffer.
auto resultMemrefType =
- getBufferType(op, extractSliceOp.getResult(), options);
+ bufferization::getBufferType(extractSliceOp.getResult(), options);
if (failed(resultMemrefType))
return failure();
Value subView = rewriter.create<memref::SubViewOp>(
@@ -305,12 +305,12 @@ struct ExtractSliceOpInterface
}
FailureOr<BaseMemRefType>
- getBufferType(Operation *op, Value value,
- const BufferizationOptions &options) const {
+ getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
assert(value == extractSliceOp.getResult() && "invalid value");
- auto srcMemrefType =
- bufferization::getBufferType(extractSliceOp.getSource(), options);
+ auto srcMemrefType = bufferization::getBufferType(
+ extractSliceOp.getSource(), options, fixedTypes);
if (failed(srcMemrefType))
return failure();
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index 71ee9922f8f52..c8bf6125e252e 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -742,3 +742,25 @@ func.func @scf_for_yield_alias_of_non_equivalent(%sz: index) -> tensor<?xf32> {
}
return %r : tensor<?xf32>
}
+
+// -----
+
+// We just check that this example bufferizes to valid IR.
+
+// CHECK-LABEL: func @scf_for_buffer_type_mismatch
+func.func @scf_for_buffer_type_mismatch(%sz: index, %sz2: index) -> f32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
+ %e2 = tensor.extract_slice %0[1][%sz2][1] : tensor<?xf32> to tensor<?xf32>
+ // init_arg and iter_arg have different buffer types. This must be resolved
+ // with casts.
+ %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t = %e2) -> tensor<?xf32> {
+ %s = "test.dummy"() : () -> (index)
+ %e = tensor.extract_slice %t[1][%s][1] : tensor<?xf32> to tensor<?xf32>
+ scf.yield %e : tensor<?xf32>
+ }
+ %x = tensor.extract %r[%c1] : tensor<?xf32>
+ return %x : f32
+}
More information about the Mlir-commits
mailing list