[Mlir-commits] [mlir] 6ecebb4 - [mlir][bufferization] Support unstructured control flow
Matthias Springer
llvmlistbot at llvm.org
Thu Aug 31 03:56:10 PDT 2023
Author: Matthias Springer
Date: 2023-08-31T12:55:53+02:00
New Revision: 6ecebb496cc6960e100a05375ab7f64e831dd933
URL: https://github.com/llvm/llvm-project/commit/6ecebb496cc6960e100a05375ab7f64e831dd933
DIFF: https://github.com/llvm/llvm-project/commit/6ecebb496cc6960e100a05375ab7f64e831dd933.diff
LOG: [mlir][bufferization] Support unstructured control flow
This revision adds support for unstructured control flow to the bufferization infrastructure. In particular: regions with multiple blocks, `cf.br`, `cf.cond_br`.
Two helper templates are added to `BufferizableOpInterface.h`, which can be implemented by ops that supported unstructured control flow in their regions (e.g., `func.func`) and ops that branch to another block (e.g., `cf.br`).
A block signature is always bufferized together with the op that owns the block.
Differential Revision: https://reviews.llvm.org/D158094
Added:
mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h
mlir/lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp
mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
mlir/test/Dialect/ControlFlow/one-shot-bufferize-analysis.mlir
mlir/test/Dialect/ControlFlow/one-shot-bufferize-invalid.mlir
mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
mlir/include/mlir/InitAllDialects.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/ControlFlow/CMakeLists.txt
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 6fc487c1a11aa5..b61994e8b9feea 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -411,6 +411,10 @@ struct TraversalConfig {
/// Specifies whether OpOperands with a different type that are not the result
/// of a CastOpInterface op should be followed.
bool followSameTypeOrCastsOnly = false;
+
+ /// Specifies whether already visited values should be visited again.
+ /// (Note: This can result in infinite looping.)
+ bool revisitAlreadyVisitedValues = false;
};
/// AnalysisState provides a variety of helper functions for dealing with
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 42aff77303e0d1..7433853717f24f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -415,6 +415,13 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
the input IR and returns `failure` in that case. If this op is
expected to survive bufferization, `success` should be returned
(together with `allow-unknown-ops` enabled).
+
+ Note: If this op supports unstructured control flow in its regions,
+ then this function should also bufferize all block signatures that
+ belong to this op. Branch ops (that branch to a block) are typically
+ bufferized together with the block signature (this is just a
+ suggestion to make sure IR is valid at every point in time and could
+ be done differently).
}],
/*retType=*/"::mlir::LogicalResult",
/*methodName=*/"bufferize",
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
new file mode 100644
index 00000000000000..78109770efab7a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -0,0 +1,179 @@
+//===- UnstructuredControlFlow.h - Op Interface Helpers ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// Helpers for Unstructured Control Flow
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace bufferization {
+
+namespace detail {
+/// Return a list of operands that are forwarded to the given block argument.
+/// I.e., find all predecessors of the block argument's owner and gather the
+/// operands that are equivalent to the block argument.
+SmallVector<OpOperand *> getCallerOpOperands(BlockArgument bbArg);
+} // namespace detail
+
+/// A template that provides a default implementation of `getAliasingOpOperands`
+/// for ops that support unstructured control flow within their regions.
+template <typename ConcreteModel, typename ConcreteOp>
+struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
+ : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
+
+ FailureOr<BaseMemRefType>
+ getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ SmallVector<Value> &invocationStack) const {
+ // Note: The user may want to override this function for OpResults in
+ // case the bufferized result type is different from the bufferized type of
+ // the aliasing OpOperand (if any).
+ if (isa<OpResult>(value))
+ return bufferization::detail::defaultGetBufferType(value, options,
+ invocationStack);
+
+ // Compute the buffer type of the block argument by computing the bufferized
+ // operand types of all forwarded values. If these are all the same type,
+ // take that type. Otherwise, take only the memory space and fall back to a
+ // buffer type with a fully dynamic layout map.
+ BaseMemRefType bufferType;
+ auto tensorType = cast<TensorType>(value.getType());
+ for (OpOperand *opOperand :
+ detail::getCallerOpOperands(cast<BlockArgument>(value))) {
+
+ // If the forwarded operand is already on the invocation stack, we ran
+ // into a loop and this operand cannot be used to compute the bufferized
+ // type.
+ if (llvm::find(invocationStack, opOperand->get()) !=
+ invocationStack.end())
+ continue;
+
+ // Compute the bufferized type of the forwarded operand.
+ BaseMemRefType callerType;
+ if (auto memrefType =
+ dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
+ // The operand was already bufferized. Take its type directly.
+ callerType = memrefType;
+ } else {
+ FailureOr<BaseMemRefType> maybeCallerType =
+ bufferization::getBufferType(opOperand->get(), options,
+ invocationStack);
+ if (failed(maybeCallerType))
+ return failure();
+ callerType = *maybeCallerType;
+ }
+
+ if (!bufferType) {
+ // This is the first buffer type that we computed.
+ bufferType = callerType;
+ continue;
+ }
+
+ if (bufferType == callerType)
+ continue;
+
+ // If the computed buffer type does not match the computed buffer type
+ // of the earlier forwarded operands, fall back to a buffer type with a
+ // fully dynamic layout map.
+#ifndef NDEBUG
+ if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
+ assert(bufferType.hasRank() && callerType.hasRank() &&
+ "expected ranked memrefs");
+ assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
+ rankedTensorType.getShape()}) &&
+ "expected same shape");
+ } else {
+ assert(!bufferType.hasRank() && !callerType.hasRank() &&
+ "expected unranked memrefs");
+ }
+#endif // NDEBUG
+
+ if (bufferType.getMemorySpace() != callerType.getMemorySpace())
+ return op->emitOpError("incoming operands of block argument have "
+ "inconsistent memory spaces");
+
+ bufferType = getMemRefTypeWithFullyDynamicLayout(
+ tensorType, bufferType.getMemorySpace());
+ }
+
+ if (!bufferType)
+ return op->emitOpError("could not infer buffer type of block argument");
+
+ return bufferType;
+ }
+
+protected:
+ /// Assuming that `bbArg` is a block argument of a block that belongs to the
+ /// given `op`, return all OpOperands of users of this block that are
+ /// aliasing with the given block argument.
+ AliasingOpOperandList
+ getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg,
+ const AnalysisState &state) const {
+ assert(bbArg.getOwner()->getParentOp() == op && "invalid bbArg");
+
+ // Gather aliasing OpOperands of all operations (callers) that link to
+ // this block.
+ AliasingOpOperandList result;
+ for (OpOperand *opOperand : detail::getCallerOpOperands(bbArg))
+ result.addAlias(
+ {opOperand, BufferRelation::Equivalent, /*isDefinite=*/false});
+
+ return result;
+ }
+};
+
+/// A template that provides a default implementation of `getAliasingValues`
+/// for ops that implement the `BranchOpInterface`.
+template <typename ConcreteModel, typename ConcreteOp>
+struct BranchOpBufferizableOpInterfaceExternalModel
+ : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
+ AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ AliasingValueList result;
+ auto branchOp = cast<BranchOpInterface>(op);
+ auto operandNumber = opOperand.getOperandNumber();
+
+ // Gather aliasing block arguments of blocks to which this op may branch to.
+ for (const auto &it : llvm::enumerate(op->getSuccessors())) {
+ Block *block = it.value();
+ SuccessorOperands operands = branchOp.getSuccessorOperands(it.index());
+ assert(operands.getProducedOperandCount() == 0 &&
+ "produced operands not supported");
+ if (operands.getForwardedOperands().empty())
+ continue;
+ // The first and last operands that are forwarded to this successor.
+ int64_t firstOperandIndex =
+ operands.getForwardedOperands().getBeginOperandIndex();
+ int64_t lastOperandIndex =
+ firstOperandIndex + operands.getForwardedOperands().size();
+ bool matchingDestination = operandNumber >= firstOperandIndex &&
+ operandNumber < lastOperandIndex;
+ // A branch op may have multiple successors. Find the ones that correspond
+ // to this OpOperand. (There is usually only one.)
+ if (!matchingDestination)
+ continue;
+ // Compute the matching block argument of the destination block.
+ BlockArgument bbArg =
+ block->getArgument(operandNumber - firstOperandIndex);
+ result.addAlias(
+ {bbArg, BufferRelation::Equivalent, /*isDefinite=*/false});
+ }
+
+ return result;
+ }
+};
+
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_UNSTRUCTUREDCONTROLFLOW_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 6b1994a5335f15..3d3316db6b0933 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -78,8 +78,15 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
const OpFilter *opFilter = nullptr,
BufferizationStatistics *statistics = nullptr);
-/// Bufferize the signature of `block`. All block argument types are changed to
-/// memref types.
+/// Bufferize the signature of `block` and its callers (i.e., ops that have the
+/// given block as a successor). All block argument types are changed to memref
+/// types. All corresponding operands of all callers are wrapped in
+/// bufferization.to_memref ops. All uses of bufferized tensor block arguments
+/// are wrapped in bufferization.to_tensor ops.
+///
+/// It is expected that all callers implement the `BranchOpInterface`.
+/// Otherwise, this function will fail. The `BranchOpInterface` is used to query
+/// the range of operands that are forwarded to this block.
///
/// It is expected that the parent op of this block implements the
/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 00000000000000..9b30ab4d98d27c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H
+#define MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace cf {
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
+} // namespace cf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_CONTROLFLOW_BUFFERIZABLEOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index aa0a580aceb828..54b39902b897d6 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -29,6 +29,7 @@
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -135,6 +136,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
builtin::registerCastOpInterfaceExternalModels(registry);
+ cf::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerTilingInterfaceExternalModels(registry);
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index a96cfedc9a4527..34a1625b0daa7b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -234,6 +234,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
});
if (aliasingValues.getNumAliases() == 1 &&
+ isa<OpResult>(aliasingValues.getAliases()[0].value) &&
!state.bufferizesToMemoryWrite(opOperand) &&
state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
.getNumAliases() == 1 &&
@@ -498,11 +499,16 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
bool AnalysisState::isValueRead(Value value) const {
assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
SmallVector<OpOperand *> workingSet;
+ DenseSet<OpOperand *> visited;
for (OpOperand &use : value.getUses())
workingSet.push_back(&use);
while (!workingSet.empty()) {
OpOperand *uMaybeReading = workingSet.pop_back_val();
+ if (visited.contains(uMaybeReading))
+ continue;
+ visited.insert(uMaybeReading);
+
// Skip over all ops that neither read nor write (but create an alias).
if (bufferizesToAliasOnly(*uMaybeReading))
for (AliasingValue alias : getAliasingValues(*uMaybeReading))
@@ -522,11 +528,21 @@ bool AnalysisState::isValueRead(Value value) const {
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
Value value, llvm::function_ref<bool(Value)> condition,
TraversalConfig config) const {
+ llvm::DenseSet<Value> visited;
llvm::SetVector<Value> result, workingSet;
workingSet.insert(value);
while (!workingSet.empty()) {
Value value = workingSet.pop_back_val();
+
+ if (!config.revisitAlreadyVisitedValues && visited.contains(value)) {
+ // Stop traversal if value was already visited.
+ if (config.alwaysIncludeLeaves)
+ result.insert(value);
+ continue;
+ }
+ visited.insert(value);
+
if (condition(value)) {
result.insert(value);
continue;
@@ -659,11 +675,15 @@ bool AnalysisState::isTensorYielded(Value tensor) const {
// preceding value, so we can follow SSA use-def chains and do a simple
// analysis.
SmallVector<OpOperand *> worklist;
+ DenseSet<OpOperand *> visited;
for (OpOperand &use : tensor.getUses())
worklist.push_back(&use);
while (!worklist.empty()) {
OpOperand *operand = worklist.pop_back_val();
+ if (visited.contains(operand))
+ continue;
+ visited.insert(operand);
Operation *op = operand->getOwner();
// If the op is not bufferizable, we can safely assume that the value is not
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 5b6e850643daec..6527c67f3a8816 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
BufferizableOpInterface.cpp
BufferizationOps.cpp
BufferizationDialect.cpp
+ UnstructuredControlFlow.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
diff --git a/mlir/lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp b/mlir/lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp
new file mode 100644
index 00000000000000..32a198c86167fb
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp
@@ -0,0 +1,32 @@
+//===- UnstructuredControlFlow.cpp - Op Interface Helpers ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
+
+using namespace mlir;
+
+SmallVector<OpOperand *>
+mlir::bufferization::detail::getCallerOpOperands(BlockArgument bbArg) {
+ SmallVector<OpOperand *> result;
+ Block *block = bbArg.getOwner();
+ for (Operation *caller : block->getUsers()) {
+ auto branchOp = dyn_cast<BranchOpInterface>(caller);
+ assert(branchOp && "expected that all callers implement BranchOpInterface");
+ auto it = llvm::find(caller->getSuccessors(), block);
+ assert(it != caller->getSuccessors().end() && "could not find successor");
+ int64_t successorIdx = std::distance(caller->getSuccessors().begin(), it);
+ SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
+ assert(operands.getProducedOperandCount() == 0 &&
+ "produced operands not supported");
+ int64_t operandIndex =
+ operands.getForwardedOperands().getBeginOperandIndex() +
+ bbArg.getArgNumber();
+ result.push_back(&caller->getOpOperand(operandIndex));
+ }
+ return result;
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c3950a553d7b5d..8fca041fe6aaa6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -357,6 +358,16 @@ static bool isaTensor(Type t) { return isa<TensorType>(t); }
/// Return true if the given op has a tensor result or a tensor operand.
static bool hasTensorSemantics(Operation *op) {
+ bool hasTensorBlockArgument = any_of(op->getRegions(), [](Region &r) {
+ return any_of(r.getBlocks(), [](Block &b) {
+ return any_of(b.getArguments(), [](BlockArgument bbArg) {
+ return isaTensor(bbArg.getType());
+ });
+ });
+ });
+ if (hasTensorBlockArgument)
+ return true;
+
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
@@ -618,6 +629,43 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
}
}
+ // Bufferize callers of the block.
+ for (Operation *op : block->getUsers()) {
+ auto branchOp = dyn_cast<BranchOpInterface>(op);
+ if (!branchOp)
+ return op->emitOpError("cannot bufferize ops with block references that "
+ "do not implement BranchOpInterface");
+
+ auto it = llvm::find(op->getSuccessors(), block);
+ assert(it != op->getSuccessors().end() && "could find successor");
+ int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);
+
+ SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
+ SmallVector<Value> newOperands;
+ for (auto [operand, type] :
+ llvm::zip(operands.getForwardedOperands(), newTypes)) {
+ if (operand.getType() == type) {
+ // Not a tensor type. Nothing to do for this operand.
+ newOperands.push_back(operand);
+ continue;
+ }
+ FailureOr<BaseMemRefType> operandBufferType =
+ bufferization::getBufferType(operand, options);
+ if (failed(operandBufferType))
+ return failure();
+ rewriter.setInsertionPointAfterValue(operand);
+ Value bufferizedOperand = rewriter.create<bufferization::ToMemrefOp>(
+ operand.getLoc(), *operandBufferType, operand);
+ // A cast is needed if the operand and the block argument have different
+ // bufferized types.
+ if (type != *operandBufferType)
+ bufferizedOperand = rewriter.create<memref::CastOp>(
+ operand.getLoc(), type, bufferizedOperand);
+ newOperands.push_back(bufferizedOperand);
+ }
+ operands.getMutableForwardedOperands().assign(newOperands);
+ }
+
return success();
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 568dde6919471c..8141e554961995 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -319,16 +320,45 @@ struct ReturnOpInterface
};
struct FuncOpInterface
- : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
+ : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
+ FuncOpInterface, FuncOp> {
+
+ static bool supportsUnstructuredControlFlow() { return true; }
+
+ AliasingOpOperandList
+ getAliasingOpOperands(Operation *op, Value value,
+ const AnalysisState &state) const {
+ return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
+ }
+
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto funcOp = cast<FuncOp>(op);
auto bbArg = cast<BlockArgument>(value);
- // Unstructured control flow is not supported.
- assert(bbArg.getOwner() == &funcOp.getBody().front() &&
- "expected that block argument belongs to first block");
- return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options);
+
+ // Function arguments are special.
+ if (bbArg.getOwner() == &funcOp.getBody().front())
+ return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
+ options);
+
+ return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
+ getBufferType(op, value, options, invocationStack);
+ }
+
+ LogicalResult verifyAnalysis(Operation *op,
+ const AnalysisState &state) const {
+ auto funcOp = cast<func::FuncOp>(op);
+ // TODO: func.func with multiple returns are not supported.
+ if (!getAssumedUniqueReturnOp(funcOp) && !funcOp.isExternal())
+ return op->emitOpError("op without unique func.return is not supported");
+ const auto &options =
+ static_cast<const OneShotBufferizationOptions &>(state.getOptions());
+ // allow-return-allocs is required for ops with multiple blocks.
+ if (options.allowReturnAllocs || funcOp.getRegion().getBlocks().size() <= 1)
+ return success();
+ return op->emitOpError(
+ "op cannot be bufferized without allow-return-allocs");
}
/// Rewrite function bbArgs and return values into buffer form. This function
@@ -358,7 +388,7 @@ struct FuncOpInterface
// Bodiless functions are assumed opaque and we cannot know the
// bufferization contract they want to enforce. As a consequence, only
// support functions that don't return any tensors atm.
- if (funcOp.getBody().empty()) {
+ if (funcOp.isExternal()) {
SmallVector<Type> retTypes;
for (Type resultType : funcType.getResults()) {
if (isa<TensorType>(resultType))
@@ -420,6 +450,11 @@ struct FuncOpInterface
BlockArgument bbArg = dyn_cast<BlockArgument>(value);
assert(bbArg && "expected BlockArgument");
+ // Non-entry block arguments are always writable. (They may alias with
+ // values that are not writable, which will turn them into read-only.)
+ if (bbArg.getOwner() != &funcOp.getBody().front())
+ return true;
+
// "bufferization.writable" overrides other writability decisions. This is
// currently used for testing only.
if (BoolAttr writable = funcOp.getArgAttrOfType<BoolAttr>(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index ba595bec0e6bdc..49b5ebdf722a1a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -309,8 +309,29 @@ static bool happensBefore(Operation *a, Operation *b,
return false;
}
+static bool isReachable(Block *from, Block *to, ArrayRef<Block *> except) {
+ DenseSet<Block *> visited;
+ SmallVector<Block *> worklist;
+ for (Block *succ : from->getSuccessors())
+ worklist.push_back(succ);
+ while (!worklist.empty()) {
+ Block *next = worklist.pop_back_val();
+ if (llvm::find(except, next) != except.end())
+ continue;
+ if (next == to)
+ return true;
+ if (visited.contains(next))
+ continue;
+ visited.insert(next);
+ for (Block *succ : next->getSuccessors())
+ worklist.push_back(succ);
+ }
+ return false;
+}
+
/// Return `true` if op dominance can be used to rule out a read-after-write
-/// conflicts based on the ordering of ops.
+/// conflicts based on the ordering of ops. Returns `false` if op dominance
+/// cannot be used due to region-based loops.
///
/// Generalized op dominance can often be used to rule out potential conflicts
/// due to "read happens before write". E.g., the following IR is not a RaW
@@ -383,9 +404,9 @@ static bool happensBefore(Operation *a, Operation *b,
/// regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE
/// or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.)
///
-static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
- const SetVector<Value> &definitions,
- AnalysisState &state) {
+static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite,
+ const SetVector<Value> &definitions,
+ AnalysisState &state) {
const BufferizationOptions &options = state.getOptions();
for (Value def : definitions) {
Region *rRead =
@@ -411,9 +432,53 @@ static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
if (rRead->getParentOp()->isAncestor(uWrite->getOwner()))
return false;
}
+
+ return true;
+}
+
+/// Return `true` if op dominance can be used to rule out a read-after-write
+/// conflicts based on the ordering of ops. Returns `false` if op dominance
+/// cannot be used due to block-based loops within a region.
+///
+/// Refer to the `canUseOpDominanceDueToRegions` documentation for details on
+/// how op dominance is used during RaW conflict detection.
+///
+/// On a high-level, there is a potential RaW in a program if there exists a
+/// possible program execution such that there is a sequence of DEF, followed
+/// by WRITE, followed by READ. Each additional DEF resets the sequence.
+///
+/// Op dominance cannot be used if there is a path from block(READ) to
+/// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should
+/// not appear on that path.
+static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite,
+ const SetVector<Value> &definitions,
+ AnalysisState &state) {
+ // Fast path: If READ and WRITE are in different regions, their block cannot
+ // be reachable just via unstructured control flow. (Loops due to regions are
+ // covered by `canUseOpDominanceDueToRegions`.)
+ if (uRead->getOwner()->getParentRegion() !=
+ uWrite->getOwner()->getParentRegion())
+ return true;
+
+ Block *readBlock = uRead->getOwner()->getBlock();
+ Block *writeBlock = uWrite->getOwner()->getBlock();
+ for (Value def : definitions) {
+ Block *defBlock = def.getParentBlock();
+ if (isReachable(readBlock, writeBlock, {defBlock}) &&
+ isReachable(writeBlock, readBlock, {defBlock}))
+ return false;
+ }
+
return true;
}
+static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
+ const SetVector<Value> &definitions,
+ AnalysisState &state) {
+ return canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state) &&
+ canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state);
+}
+
/// Annotate IR with details about the detected RaW conflict.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
Value definition) {
diff --git a/mlir/lib/Dialect/ControlFlow/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/CMakeLists.txt
index f33061b2d87cff..9f57627c321fb0 100644
--- a/mlir/lib/Dialect/ControlFlow/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/CMakeLists.txt
@@ -1 +1,2 @@
add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..3228872029a274
--- /dev/null
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,75 @@
+//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace mlir {
+namespace cf {
+namespace {
+
+template <typename ConcreteModel, typename ConcreteOp>
+struct BranchLikeOpInterface
+ : public BranchOpBufferizableOpInterfaceExternalModel<ConcreteModel,
+ ConcreteOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return false;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ return false;
+ }
+
+ LogicalResult verifyAnalysis(Operation *op,
+ const AnalysisState &state) const {
+ const auto &options =
+ static_cast<const OneShotBufferizationOptions &>(state.getOptions());
+ if (options.allowReturnAllocs)
+ return success();
+ return op->emitOpError(
+ "op cannot be bufferized without allow-return-allocs");
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const BufferizationOptions &options) const {
+ // The operands of this op are bufferized together with the block signature.
+ return success();
+ }
+};
+
+/// Bufferization of cf.br.
+struct BranchOpInterface
+ : public BranchLikeOpInterface<BranchOpInterface, cf::BranchOp> {};
+
+/// Bufferization of cf.cond_br.
+struct CondBranchOpInterface
+ : public BranchLikeOpInterface<CondBranchOpInterface, cf::CondBranchOp> {};
+
+} // namespace
+} // namespace cf
+} // namespace mlir
+
+void mlir::cf::registerBufferizableOpInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
+ cf::BranchOp::attachInterface<BranchOpInterface>(*ctx);
+ cf::CondBranchOp::attachInterface<CondBranchOpInterface>(*ctx);
+ });
+}
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
new file mode 100644
index 00000000000000..b2ef59887515e7
--- /dev/null
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRControlFlowTransforms
+ BufferizableOpInterfaceImpl.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms
+
+ LINK_LIBS PUBLIC
+ MLIRBufferizationDialect
+ MLIRBufferizationTransforms
+ MLIRControlFlowDialect
+ MLIRMemRefDialect
+ MLIRIR
+ )
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index ac01d264eb8fba..1a604c00e4321f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -10,6 +10,8 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
+#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -99,37 +101,74 @@ struct ConditionOpInterface
}
};
+/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
+/// return an empty op.
+static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
+ scf::YieldOp result;
+ for (Block &block : executeRegionOp.getRegion()) {
+ if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
+ if (result)
+ return {};
+ result = yieldOp;
+ }
+ }
+ return result;
+}
+
/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
- : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
- scf::ExecuteRegionOp> {
+ : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
+ ExecuteRegionOpInterface, scf::ExecuteRegionOp> {
+
+ static bool supportsUnstructuredControlFlow() { return true; }
+
+ bool isWritable(Operation *op, Value value,
+ const AnalysisState &state) const {
+ return true;
+ }
+
+ LogicalResult verifyAnalysis(Operation *op,
+ const AnalysisState &state) const {
+ auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
+ // TODO: scf.execute_region with multiple yields are not supported.
+ if (!getUniqueYieldOp(executeRegionOp))
+ return op->emitOpError("op without unique scf.yield is not supported");
+ const auto &options =
+ static_cast<const OneShotBufferizationOptions &>(state.getOptions());
+ // allow-return-allocs is required for ops with multiple blocks.
+ if (options.allowReturnAllocs ||
+ executeRegionOp.getRegion().getBlocks().size() == 1)
+ return success();
+ return op->emitOpError(
+ "op cannot be bufferized without allow-return-allocs");
+ }
+
AliasingOpOperandList
getAliasingOpOperands(Operation *op, Value value,
const AnalysisState &state) const {
+ if (auto bbArg = dyn_cast<BlockArgument>(value))
+ return getAliasingBranchOpOperands(op, bbArg, state);
+
// ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
// any SSA value that is in scope. To allow for use-def chain traversal
// through ExecuteRegionOps in the analysis, the corresponding yield value
// is considered to be aliasing with the result.
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
- size_t resultNum = std::distance(op->getOpResults().begin(),
- llvm::find(op->getOpResults(), value));
- // TODO: Support multiple blocks.
- assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
- "expected exactly 1 block");
- auto yieldOp = dyn_cast<scf::YieldOp>(
- executeRegionOp.getRegion().front().getTerminator());
- assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
+ auto it = llvm::find(op->getOpResults(), value);
+ assert(it != op->getOpResults().end() && "invalid value");
+ size_t resultNum = std::distance(op->getOpResults().begin(), it);
+ auto yieldOp = getUniqueYieldOp(executeRegionOp);
+ // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
+ if (!yieldOp)
+ return {};
return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
- assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
- "only 1 block supported");
- auto yieldOp =
- cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
+ auto yieldOp = getUniqueYieldOp(executeRegionOp);
TypeRange newResultTypes(yieldOp.getResults());
// Create new op and move over region.
@@ -137,6 +176,12 @@ struct ExecuteRegionOpInterface
rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
newOp.getRegion().takeBody(executeRegionOp.getRegion());
+ // Bufferize every block.
+ for (Block &block : newOp.getRegion())
+ if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
+ options)))
+ return failure();
+
// Update all uses of the old op.
rewriter.setInsertionPointAfter(newOp);
SmallVector<Value> newResults;
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
index 071ec6f5d28ef7..45edca756e0dc7 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -321,3 +321,12 @@ func.func @regression_scf_while() {
}
return
}
+
+// -----
+
+// expected-error @below{{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+func.func @func_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
+ func.return %t : tensor<5xf32>
+^bb1(%arg1 : tensor<5xf32>):
+ func.return %arg1 : tensor<5xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index af7d8916448edc..249ee3448b8904 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -659,3 +659,15 @@ func.func @to_memref_op_unsupported(
return %r1 : vector<5xf32>
}
+
+// -----
+
+// Note: The cf.br canonicalizes away, so there's nothing to check here. There
+// is a detailed test in ControlFlow/bufferize.mlir.
+
+// CHECK-LABEL: func @br_in_func(
+func.func @br_in_func(%t: tensor<5xf32>) -> tensor<5xf32> {
+ cf.br ^bb1(%t : tensor<5xf32>)
+^bb1(%arg1 : tensor<5xf32>):
+ func.return %arg1 : tensor<5xf32>
+}
diff --git a/mlir/test/Dialect/ControlFlow/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-analysis.mlir
new file mode 100644
index 00000000000000..84df4c8045a886
--- /dev/null
+++ b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-analysis.mlir
@@ -0,0 +1,221 @@
+// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs test-analysis-only dump-alias-sets bufferize-function-boundaries" -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @single_branch(
+// CHECK-SAME: {__bbarg_alias_set_attr__ = [{{\[}}[{{\[}}"%[[arg1:.*]]", "%[[t:.*]]"]], [{{\[}}"%[[arg1]]", "%[[t]]"]]]]}
+func.func @single_branch(%t: tensor<5xf32>) -> tensor<5xf32> {
+// CHECK: cf.br
+// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+ cf.br ^bb1(%t : tensor<5xf32>)
+// CHECK: ^{{.*}}(%[[arg1]]: tensor<5xf32>)
+^bb1(%arg1 : tensor<5xf32>):
+ func.return %arg1 : tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @diamond_branch(
+// CHECK-SAME: %{{.*}}: i1, %[[t0:.*]]: tensor<5xf32> {{.*}}, %[[t1:.*]]: tensor<5xf32> {{.*}}) -> tensor<5xf32>
+// CHECK-SAME: {__bbarg_alias_set_attr__ = [{{\[}}[{{\[}}"%[[arg1:.*]]", "%[[arg3:.*]]", "%[[arg2:.*]]", "%[[t0]]", "%[[t1]]"], [
+func.func @diamond_branch(%c: i1, %t0: tensor<5xf32>, %t1: tensor<5xf32>) -> tensor<5xf32> {
+// CHECK: cf.cond_br
+// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]}
+ cf.cond_br %c, ^bb1(%t0 : tensor<5xf32>), ^bb2(%t1 : tensor<5xf32>)
+// CHECK: ^{{.*}}(%[[arg1]]: tensor<5xf32>):
+^bb3(%arg1 : tensor<5xf32>):
+ func.return %arg1 : tensor<5xf32>
+// CHECK: ^{{.*}}(%[[arg2]]: tensor<5xf32>):
+^bb1(%arg2 : tensor<5xf32>):
+// CHECK: cf.br
+// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+ cf.br ^bb3(%arg2 : tensor<5xf32>)
+// CHECK: ^{{.*}}(%[[arg3]]: tensor<5xf32>):
+^bb2(%arg3 : tensor<5xf32>):
+// CHECK: cf.br
+// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+ cf.br ^bb3(%arg3 : tensor<5xf32>)
+}
+
+// -----
+
+// CHECK-LABEL: func @looping_branches(
+// CHECK-SAME: {__bbarg_alias_set_attr__ = [{{\[}}[], [{{\[}}"%[[arg2:.*]]", "%[[arg1:.*]]", "%[[inserted:.*]]", "%[[empty:.*]]"]], [
+func.func @looping_branches() -> tensor<5xf32> {
+// CHECK: %[[empty]] = tensor.empty()
+ %0 = tensor.empty() : tensor<5xf32>
+// CHECK: cf.br
+// CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
+ cf.br ^bb1(%0: tensor<5xf32>)
+// CHECK: ^{{.*}}(%[[arg1]]: tensor<5xf32>):
+^bb1(%arg1: tensor<5xf32>):
+ %pos = "test.foo"() : () -> (index)
+ %val = "test.bar"() : () -> (f32)
+// CHECK: %[[inserted]] = tensor.insert
+// CHECK-SAME: __inplace_operands_attr__ = ["none", "true", "none"]
+ %inserted = tensor.insert %val into %arg1[%pos] : tensor<5xf32>
+ %cond = "test.qux"() : () -> (i1)
+// CHECK: cf.cond_br
+// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]}
+ cf.cond_br %cond, ^bb1(%inserted: tensor<5xf32>), ^bb2(%inserted: tensor<5xf32>)
+^bb2(%arg2: tensor<5xf32>):
+ func.return %arg2 : tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @looping_branches_with_conflict(
+func.func @looping_branches_with_conflict(%f: f32) -> tensor<5xf32> {
+ %0 = tensor.empty() : tensor<5xf32>
+ %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+// CHECK: cf.br
+// CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
+ cf.br ^bb1(%filled: tensor<5xf32>)
+^bb2(%arg2: tensor<5xf32>):
+ %pos2 = "test.foo"() : () -> (index)
+ // One OpOperand cannot bufferize in-place because an "old" value is read.
+ %element = tensor.extract %filled[%pos2] : tensor<5xf32>
+ func.return %arg2 : tensor<5xf32>
+^bb1(%arg1: tensor<5xf32>):
+ %pos = "test.foo"() : () -> (index)
+ %val = "test.bar"() : () -> (f32)
+// CHECK: tensor.insert
+// CHECK-SAME: __inplace_operands_attr__ = ["none", "true", "none"]
+ %inserted = tensor.insert %val into %arg1[%pos] : tensor<5xf32>
+ %cond = "test.qux"() : () -> (i1)
+// CHECK: cf.cond_br
+// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]}
+ cf.cond_br %cond, ^bb1(%inserted: tensor<5xf32>), ^bb2(%inserted: tensor<5xf32>)
+}
+
+// -----
+
+// CHECK-LABEL: func @looping_branches_outside_def(
+func.func @looping_branches_outside_def(%f: f32) {
+// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor()
+ %0 = bufferization.alloc_tensor() : tensor<5xf32>
+// CHECK: %[[fill:.*]] = linalg.fill
+// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"], __opresult_alias_set_attr__ = [{{\[}}"%[[fill]]", "%[[alloc]]"]]}
+ %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+ cf.br ^bb1
+^bb1:
+ %pos = "test.foo"() : () -> (index)
+ %val = "test.bar"() : () -> (f32)
+// CHECK: tensor.insert
+// CHECK-SAME: __inplace_operands_attr__ = ["none", "false", "none"]
+ %inserted = tensor.insert %val into %filled[%pos] : tensor<5xf32>
+ %pos2 = "test.foo"() : () -> (index)
+ %read = tensor.extract %inserted[%pos2] : tensor<5xf32>
+ %cond = "test.qux"(%read) : (f32) -> (i1)
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb2:
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @looping_branches_outside_def2(
+func.func @looping_branches_outside_def2(%f: f32) {
+// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor()
+ %0 = bufferization.alloc_tensor() : tensor<5xf32>
+// CHECK: %[[fill:.*]] = linalg.fill
+// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"], __opresult_alias_set_attr__ = [{{\[}}"%[[arg0:.*]]", "%[[fill]]", "%[[alloc]]"]]}
+ %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+// CHECK: cf.br {{.*}}(%[[fill]] : tensor<5xf32>)
+// CHECK-SAME: __inplace_operands_attr__ = ["true"]
+ cf.br ^bb1(%filled: tensor<5xf32>)
+// CHECK: ^{{.*}}(%[[arg0]]: tensor<5xf32>):
+^bb1(%arg0: tensor<5xf32>):
+ %pos = "test.foo"() : () -> (index)
+ %val = "test.bar"() : () -> (f32)
+// CHECK: tensor.insert
+// CHECK-SAME: __inplace_operands_attr__ = ["none", "false", "none"]
+ %inserted = tensor.insert %val into %arg0[%pos] : tensor<5xf32>
+ %pos2 = "test.foo"() : () -> (index)
+ %read = tensor.extract %inserted[%pos2] : tensor<5xf32>
+ %cond = "test.qux"(%read) : (f32) -> (i1)
+// CHECK: cf.cond_br
+// CHECK-SAME: __inplace_operands_attr__ = ["none", "true"]
+ cf.cond_br %cond, ^bb1(%arg0: tensor<5xf32>), ^bb2
+^bb2:
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @looping_branches_outside_def3(
+func.func @looping_branches_outside_def3(%f: f32) {
+// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor()
+ %0 = bufferization.alloc_tensor() : tensor<5xf32>
+// CHECK: %[[fill:.*]] = linalg.fill
+// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"], __opresult_alias_set_attr__ = [{{\[}}"%[[arg0:.*]]", "%[[fill]]", "%[[alloc]]"]]}
+ %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+// CHECK: cf.br {{.*}}(%[[fill]] : tensor<5xf32>)
+// CHECK-SAME: __inplace_operands_attr__ = ["true"]
+ cf.br ^bb1(%filled: tensor<5xf32>)
+// CHECK: ^{{.*}}(%[[arg0]]: tensor<5xf32>):
+^bb1(%arg0: tensor<5xf32>):
+ %pos = "test.foo"() : () -> (index)
+ %val = "test.bar"() : () -> (f32)
+// CHECK: tensor.insert
+// CHECK-SAME: __inplace_operands_attr__ = ["none", "false", "none"]
+ %inserted = tensor.insert %val into %arg0[%pos] : tensor<5xf32>
+ %pos2 = "test.foo"() : () -> (index)
+ %read = tensor.extract %inserted[%pos2] : tensor<5xf32>
+ %cond = "test.qux"(%read) : (f32) -> (i1)
+// CHECK: cf.cond_br
+// CHECK-SAME: __inplace_operands_attr__ = ["none", "true"]
+ cf.cond_br %cond, ^bb1(%filled: tensor<5xf32>), ^bb2
+^bb2:
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @looping_branches_sequence_outside_def(
+func.func @looping_branches_sequence_outside_def(%f: f32) {
+// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor()
+ %0 = bufferization.alloc_tensor() : tensor<5xf32>
+// CHECK: %[[fill:.*]] = linalg.fill
+// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"], __opresult_alias_set_attr__ = [{{\[}}"%[[fill]]", "%[[alloc]]"]]}
+ %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+ cf.br ^bb1
+^bb1:
+ %pos = "test.foo"() : () -> (index)
+ %val = "test.bar"() : () -> (f32)
+// CHECK: tensor.insert
+// CHECK-SAME: __inplace_operands_attr__ = ["none", "false", "none"]
+ %inserted = tensor.insert %val into %filled[%pos] : tensor<5xf32>
+ cf.br ^bb2
+^bb2:
+ %pos2 = "test.foo"() : () -> (index)
+ %read = tensor.extract %inserted[%pos2] : tensor<5xf32>
+ %cond = "test.qux"(%read) : (f32) -> (i1)
+ cf.cond_br %cond, ^bb1, ^bb3
+^bb3:
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @looping_branches_sequence_inside_def(
+func.func @looping_branches_sequence_inside_def(%f: f32) {
+ cf.br ^bb1
+^bb1:
+// CHECK: %[[alloc:.*]] = bufferization.alloc_tensor()
+ %0 = bufferization.alloc_tensor() : tensor<5xf32>
+// CHECK: %[[fill:.*]] = linalg.fill
+// CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"], __opresult_alias_set_attr__ = [{{\[}}"%[[inserted:.*]]", "%[[fill]]", "%[[alloc]]"]]}
+ %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+ %pos = "test.foo"() : () -> (index)
+ %val = "test.bar"() : () -> (f32)
+// CHECK: %[[inserted]] = tensor.insert
+// CHECK-SAME: __inplace_operands_attr__ = ["none", "true", "none"]
+ %inserted = tensor.insert %val into %filled[%pos] : tensor<5xf32>
+ cf.br ^bb2
+^bb2:
+ %pos2 = "test.foo"() : () -> (index)
+ %read = tensor.extract %inserted[%pos2] : tensor<5xf32>
+ %cond = "test.qux"(%read) : (f32) -> (i1)
+ cf.cond_br %cond, ^bb1, ^bb3
+^bb3:
+ func.return
+}
diff --git a/mlir/test/Dialect/ControlFlow/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-invalid.mlir
new file mode 100644
index 00000000000000..7ff837540711ef
--- /dev/null
+++ b/mlir/test/Dialect/ControlFlow/one-shot-bufferize-invalid.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file %s -verify-diagnostics
+
+// expected-error @below{{failed to bufferize op}}
+// expected-error @below{{incoming operands of block argument have inconsistent memory spaces}}
+func.func @inconsistent_memory_space() -> tensor<5xf32> {
+ %0 = bufferization.alloc_tensor() {memory_space = 0 : ui64} : tensor<5xf32>
+ cf.br ^bb1(%0: tensor<5xf32>)
+^bb1(%arg1: tensor<5xf32>):
+ func.return %arg1 : tensor<5xf32>
+^bb2():
+ %1 = bufferization.alloc_tensor() {memory_space = 1 : ui64} : tensor<5xf32>
+ cf.br ^bb1(%1: tensor<5xf32>)
+}
+
+// -----
+
+// expected-error @below{{failed to bufferize op}}
+// expected-error @below{{could not infer buffer type of block argument}}
+func.func @cannot_infer_type() {
+ return
+ // The type of the block argument cannot be inferred.
+^bb1(%t: tensor<5xf32>):
+ cf.br ^bb1(%t: tensor<5xf32>)
+}
diff --git a/mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir b/mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir
new file mode 100644
index 00000000000000..482cb379d57a97
--- /dev/null
+++ b/mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -one-shot-bufferize="allow-return-allocs" -split-input-file %s | FileCheck %s --check-prefix=CHECK-NO-FUNC
+
+// CHECK-NO-FUNC-LABEL: func @br(
+// CHECK-NO-FUNC-SAME: %[[t:.*]]: tensor<5xf32>)
+// CHECK-NO-FUNC: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<5xf32, strided<[?], offset: ?>>
+// CHECK-NO-FUNC: %[[r:.*]] = scf.execute_region -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK-NO-FUNC: cf.br ^[[block:.*]](%[[m]]
+// CHECK-NO-FUNC: ^[[block]](%[[arg1:.*]]: memref<5xf32, strided<[?], offset: ?>>):
+// CHECK-NO-FUNC: scf.yield %[[arg1]]
+// CHECK-NO-FUNC: }
+// CHECK-NO-FUNC: return
+func.func @br(%t: tensor<5xf32>) {
+ %0 = scf.execute_region -> tensor<5xf32> {
+ cf.br ^bb1(%t : tensor<5xf32>)
+ ^bb1(%arg1 : tensor<5xf32>):
+ scf.yield %arg1 : tensor<5xf32>
+ }
+ return
+}
+
+// -----
+
+// CHECK-NO-FUNC-LABEL: func @cond_br(
+// CHECK-NO-FUNC-SAME: %[[t1:.*]]: tensor<5xf32>,
+// CHECK-NO-FUNC: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<5xf32, strided<[?], offset: ?>>
+// CHECK-NO-FUNC: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+// CHECK-NO-FUNC: %[[r:.*]] = scf.execute_region -> memref<5xf32, strided<[?], offset: ?>> {
+// CHECK-NO-FUNC: cf.cond_br %{{.*}}, ^[[block1:.*]](%[[m1]] : {{.*}}), ^[[block2:.*]](%[[alloc]] : {{.*}})
+// CHECK-NO-FUNC: ^[[block1]](%[[arg1:.*]]: memref<5xf32, strided<[?], offset: ?>>):
+// CHECK-NO-FUNC: scf.yield %[[arg1]]
+// CHECK-NO-FUNC: ^[[block2]](%[[arg2:.*]]: memref<5xf32>):
+// CHECK-NO-FUNC: %[[cast:.*]] = memref.cast %[[arg2]] : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>
+// CHECK-NO-FUNC: cf.br ^[[block1]](%[[cast]] : {{.*}})
+// CHECK-NO-FUNC: }
+// CHECK-NO-FUNC: return
+func.func @cond_br(%t1: tensor<5xf32>, %c: i1) {
+ // Use an alloc for the second block instead of a function block argument.
+ // A cast must be inserted because the two will have different layout maps.
+ %t0 = bufferization.alloc_tensor() : tensor<5xf32>
+ %0 = scf.execute_region -> tensor<5xf32> {
+ cf.cond_br %c, ^bb1(%t1 : tensor<5xf32>), ^bb2(%t0 : tensor<5xf32>)
+ ^bb1(%arg1 : tensor<5xf32>):
+ scf.yield %arg1 : tensor<5xf32>
+ ^bb2(%arg2 : tensor<5xf32>):
+ cf.br ^bb1(%arg2 : tensor<5xf32>)
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @looping_branches(
+func.func @looping_branches() -> tensor<5xf32> {
+// CHECK: %[[alloc:.*]] = memref.alloc
+ %0 = bufferization.alloc_tensor() : tensor<5xf32>
+// CHECK: cf.br {{.*}}(%[[alloc]]
+ cf.br ^bb1(%0: tensor<5xf32>)
+// CHECK: ^{{.*}}(%[[arg1:.*]]: memref<5xf32>):
+^bb1(%arg1: tensor<5xf32>):
+ %pos = "test.foo"() : () -> (index)
+ %val = "test.bar"() : () -> (f32)
+// CHECK: memref.store %{{.*}}, %[[arg1]]
+ %inserted = tensor.insert %val into %arg1[%pos] : tensor<5xf32>
+ %cond = "test.qux"() : () -> (i1)
+// CHECK: cf.cond_br {{.*}}(%[[arg1]] {{.*}}(%[[arg1]]
+ cf.cond_br %cond, ^bb1(%inserted: tensor<5xf32>), ^bb2(%inserted: tensor<5xf32>)
+// CHECK: ^{{.*}}(%[[arg2:.*]]: memref<5xf32>):
+^bb2(%arg2: tensor<5xf32>):
+// CHECK: return %[[arg2]]
+ func.return %arg2 : tensor<5xf32>
+}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
index c8d6d506270a99..0544656034b22c 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-invalid.mlir
@@ -17,10 +17,10 @@ func.func @inconsistent_memory_space_scf_if(%c: i1) -> tensor<10xf32> {
// -----
-func.func @execute_region_multiple_blocks(%t: tensor<5xf32>) -> tensor<5xf32> {
- // expected-error @below{{op or BufferizableOpInterface implementation does not support unstructured control flow, but at least one region has multiple blocks}}
+func.func @execute_region_multiple_yields(%t: tensor<5xf32>) -> tensor<5xf32> {
+ // expected-error @below{{op op without unique scf.yield is not supported}}
%0 = scf.execute_region -> tensor<5xf32> {
- cf.br ^bb1(%t : tensor<5xf32>)
+ scf.yield %t : tensor<5xf32>
^bb1(%arg1 : tensor<5xf32>):
scf.yield %arg1 : tensor<5xf32>
}
@@ -29,6 +29,20 @@ func.func @execute_region_multiple_blocks(%t: tensor<5xf32>) -> tensor<5xf32> {
// -----
+func.func @execute_region_no_yield(%t: tensor<5xf32>) -> tensor<5xf32> {
+ // expected-error @below{{op op without unique scf.yield is not supported}}
+ %0 = scf.execute_region -> tensor<5xf32> {
+ cf.br ^bb0(%t : tensor<5xf32>)
+ ^bb0(%arg0 : tensor<5xf32>):
+ cf.br ^bb1(%arg0: tensor<5xf32>)
+ ^bb1(%arg1 : tensor<5xf32>):
+ cf.br ^bb0(%arg1: tensor<5xf32>)
+ }
+ func.return %0 : tensor<5xf32>
+}
+
+// -----
+
func.func @inconsistent_memory_space_scf_for(%lb: index, %ub: index, %step: index) -> tensor<10xf32> {
%0 = bufferization.alloc_tensor() {memory_space = 0 : ui64} : tensor<10xf32>
%1 = bufferization.alloc_tensor() {memory_space = 1 : ui64} : tensor<10xf32>
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index c06fbe82e7362b..23bca584c83be1 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4022,6 +4022,24 @@ cc_library(
],
)
+cc_library(
+ name = "ControlFlowTransforms",
+ srcs = glob([
+ "lib/Dialect/ControlFlow/Transforms/*.cpp",
+ ]),
+ hdrs = glob([
+ "include/mlir/Dialect/ControlFlow/Transforms/*.h",
+ ]),
+ includes = ["include"],
+ deps = [
+ ":BufferizationDialect",
+ ":BufferizationTransforms",
+ ":ControlFlowDialect",
+ ":IR",
+ ":MemRefDialect",
+ ],
+)
+
cc_library(
name = "FuncDialect",
srcs = glob(
@@ -8203,6 +8221,7 @@ cc_library(
":ComplexToLibm",
":ComplexToSPIRV",
":ControlFlowDialect",
+ ":ControlFlowTransforms",
":ConversionPasses",
":ConvertToLLVM",
":DLTIDialect",
@@ -11907,11 +11926,13 @@ cc_library(
"lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
"lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
"lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
+ "lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp",
],
hdrs = [
"include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
"include/mlir/Dialect/Bufferization/IR/Bufferization.h",
"include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h",
+ "include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h",
],
includes = ["include"],
deps = [
More information about the Mlir-commits
mailing list