[Mlir-commits] [mlir] 621d7cc - [mlir] Add a new BranchOpInterface to allow for opaquely interfacing with branching terminator operations.
River Riddle
llvmlistbot at llvm.org
Thu Mar 5 12:58:47 PST 2020
Author: River Riddle
Date: 2020-03-05T12:50:35-08:00
New Revision: 621d7cca3751f934f991e34fe0e26187c33314f4
URL: https://github.com/llvm/llvm-project/commit/621d7cca3751f934f991e34fe0e26187c33314f4
DIFF: https://github.com/llvm/llvm-project/commit/621d7cca3751f934f991e34fe0e26187c33314f4.diff
LOG: [mlir] Add a new BranchOpInterface to allow for opaquely interfacing with branching terminator operations.
This interface contains the necessary components to provide the same builtin behavior that terminators have. This will be used in future revisions to remove many of the hardcoded constraints placed on successors and successor operands. The interface initially contains three methods:
```c++
// Return a set of values corresponding to the operands for successor 'index', or None if the operands do not correspond to materialized values.
Optional<OperandRange> getSuccessorOperands(unsigned index);
// Return true if this terminator can have its successor operands erased.
bool canEraseSuccessorOperand();
// Erase the operand of a successor. This is only valid to call if 'canEraseSuccessorOperand' returns true.
void eraseSuccessorOperand(unsigned succIdx, unsigned opIdx);
```
Differential Revision: https://reviews.llvm.org/D75314
Added:
mlir/include/mlir/Analysis/ControlFlowInterfaces.h
mlir/include/mlir/Analysis/ControlFlowInterfaces.td
mlir/lib/Analysis/ControlFlowInterfaces.cpp
Modified:
mlir/include/mlir/Analysis/CMakeLists.txt
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/Analysis/CMakeLists.txt
mlir/lib/Dialect/LLVMIR/CMakeLists.txt
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/CMakeLists.txt
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/StandardOps/CMakeLists.txt
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/OperationSupport.cpp
mlir/test/lib/TestDialect/CMakeLists.txt
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestDialect.h
mlir/test/lib/TestDialect/TestOps.td
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/CMakeLists.txt b/mlir/include/mlir/Analysis/CMakeLists.txt
index 3d9a7ed36979..cc8c4939b73b 100644
--- a/mlir/include/mlir/Analysis/CMakeLists.txt
+++ b/mlir/include/mlir/Analysis/CMakeLists.txt
@@ -3,6 +3,11 @@ mlir_tablegen(CallInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(CallInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRCallOpInterfacesIncGen)
+set(LLVM_TARGET_DEFINITIONS ControlFlowInterfaces.td)
+mlir_tablegen(ControlFlowInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ControlFlowInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRControlFlowInterfacesIncGen)
+
set(LLVM_TARGET_DEFINITIONS InferTypeOpInterface.td)
mlir_tablegen(InferTypeOpInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(InferTypeOpInterface.cpp.inc -gen-op-interface-defs)
diff --git a/mlir/include/mlir/Analysis/ControlFlowInterfaces.h b/mlir/include/mlir/Analysis/ControlFlowInterfaces.h
new file mode 100644
index 000000000000..87f82040ed61
--- /dev/null
+++ b/mlir/include/mlir/Analysis/ControlFlowInterfaces.h
@@ -0,0 +1,43 @@
+//===- ControlFlowInterfaces.h - ControlFlow Interfaces ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the branch interfaces defined in
+// `ControlFlowInterfaces.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H
+#define MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+class BranchOpInterface;
+
+namespace detail {
+/// Erase an operand from a branch operation that is used as a successor
+/// operand. `operandIndex` is the operand within `operands` to be erased.
+/// If `op` carries an `operand_segment_sizes` attribute, the segment that
+/// contained the erased operand is shrunk so the attribute stays consistent.
+void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex,
+ Operation *op);
+
+/// Return the `BlockArgument` corresponding to operand `operandIndex` in some
+/// successor if `operandIndex` is within the range of `operands`, or None if
+/// `operandIndex` isn't a successor operand index.
+/// Note: `operandIndex` is an index into the owning operation's full operand
+/// list, not into `operands`. None is also returned when `operands` is None
+/// or empty.
+Optional<BlockArgument>
+getBranchSuccessorArgument(Optional<OperandRange> operands,
+ unsigned operandIndex, Block *successor);
+
+/// Verify that the given operands match those of the given successor block.
+/// The count and the types of `operands` are compared against the arguments
+/// of successor `succNo` of `op`; a mismatch emits an error on `op`. When
+/// `operands` is None (non-materialized operands) verification succeeds.
+LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
+ Optional<OperandRange> operands);
+} // namespace detail
+
+#include "mlir/Analysis/ControlFlowInterfaces.h.inc"
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_CONTROLFLOWINTERFACES_H
diff --git a/mlir/include/mlir/Analysis/ControlFlowInterfaces.td b/mlir/include/mlir/Analysis/ControlFlowInterfaces.td
new file mode 100644
index 000000000000..b34b7a9d7e81
--- /dev/null
+++ b/mlir/include/mlir/Analysis/ControlFlowInterfaces.td
@@ -0,0 +1,85 @@
+//===-- ControlFlowInterfaces.td - ControlFlow Interfaces --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces that can be used to define information
+// about control flow operations, e.g. branches.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_CONTROLFLOWINTERFACES
+#define MLIR_ANALYSIS_CONTROLFLOWINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// BranchOpInterface
+//===----------------------------------------------------------------------===//
+
+def BranchOpInterface : OpInterface<"BranchOpInterface"> {
+ let description = [{
+ This interface provides information for branching terminator operations,
+ i.e. terminator operations with successors.
+ }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns a set of values that correspond to the arguments to the
+ successor at the given index. Returns None if the operands to the
+ successor are non-materialized values, i.e. they are internal to the
+ operation.
+ }],
+ "Optional<OperandRange>", "getSuccessorOperands", (ins "unsigned":$index)
+ >,
+ InterfaceMethod<[{
+ Return true if this operation can erase an operand to a successor block.
+ }],
+ "bool", "canEraseSuccessorOperand"
+ >,
+ // Has a default implementation: it looks up the successor's operand range
+ // via getSuccessorOperands and forwards to
+ // detail::eraseBranchSuccessorOperand, so implementing ops need not define it.
+ InterfaceMethod<[{
+ Erase the operand at `operandIndex` from the `index`-th successor. This
+ should only be called if `canEraseSuccessorOperand` returns true.
+ }],
+ "void", "eraseSuccessorOperand",
+ (ins "unsigned":$index, "unsigned":$operandIndex), [{}],
+ /*defaultImplementation=*/[{
+ ConcreteOp *op = static_cast<ConcreteOp *>(this);
+ Optional<OperandRange> operands = op->getSuccessorOperands(index);
+ assert(operands && "unable to query operands for successor");
+ detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op);
+ }]
+ >,
+ // Provided with a fixed body: successors are scanned in order and the first
+ // matching argument is returned. `operandIndex` indexes the op's full
+ // operand list.
+ InterfaceMethod<[{
+ Returns the `BlockArgument` corresponding to operand `operandIndex` in
+ some successor, or None if `operandIndex` isn't a successor operand
+ index.
+ }],
+ "Optional<BlockArgument>", "getSuccessorBlockArgument",
+ (ins "unsigned":$operandIndex), [{
+ Operation *opaqueOp = op;
+ for (unsigned i = 0, e = opaqueOp->getNumSuccessors(); i != e; ++i) {
+ if (Optional<BlockArgument> arg = detail::getBranchSuccessorArgument(
+ op.getSuccessorOperands(i), operandIndex,
+ opaqueOp->getSuccessor(i)))
+ return arg;
+ }
+ return llvm::None;
+ }]
+ >
+ ];
+
+ // Check each successor's operands against its block arguments. Successors
+ // whose operands are not materialized (None) are accepted by the helper.
+ let verify = [{
+ auto concreteOp = cast<ConcreteOpType>($_op);
+ for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) {
+ Optional<OperandRange> operands = concreteOp.getSuccessorOperands(i);
+ if (failed(detail::verifyBranchSuccessorOperands($_op, i, operands)))
+ return failure();
+ }
+ return success();
+ }];
+}
+
+#endif // MLIR_ANALYSIS_CONTROLFLOWINTERFACES
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index d1378b827c52..a8b0be0793d6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
#define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
+#include "mlir/Analysis/ControlFlowInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b4585407edce..2ab9041bc2a0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -14,6 +14,7 @@
#define LLVMIR_OPS
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
class LLVM_Builder<string builder> {
string llvmBuilder = builder;
@@ -315,7 +316,9 @@ def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt">;
def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">;
// Call-related operations.
-def LLVM_InvokeOp : LLVM_Op<"invoke", [Terminator]>,
+def LLVM_InvokeOp : LLVM_Op<"invoke", [
+ DeclareOpInterfaceMethods<BranchOpInterface>, Terminator
+ ]>,
Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>)>,
Results<(outs Variadic<LLVM_Type>)> {
@@ -458,11 +461,13 @@ def LLVM_FreezeOp : LLVM_OneResultOp<"freeze", [SameOperandsAndResultType]>,
}
// Terminators.
-def LLVM_BrOp : LLVM_TerminatorOp<"br", []> {
+// Unconditional branch to `$dest`. The BranchOpInterface methods declared here
+// are implemented in LLVMDialect.cpp.
+def LLVM_BrOp : LLVM_TerminatorOp<"br",
+ [DeclareOpInterfaceMethods<BranchOpInterface>]> {
let successors = (successor AnySuccessor:$dest);
let assemblyFormat = "$dest attr-dict";
}
-def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", []> {
+def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
+ [DeclareOpInterfaceMethods<BranchOpInterface>]> {
let arguments = (ins LLVMI1:$condition);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
let assemblyFormat = "$condition `,` successors attr-dict";
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index 03884afe3e95..5ef825af4ea1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -16,10 +16,13 @@
include "mlir/Dialect/SPIRV/SPIRVBase.td"
include "mlir/Analysis/CallInterfaces.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
// -----
-def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
+def SPV_BranchOp : SPV_Op<"Branch",[
+ DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope,
+ Terminator]> {
let summary = "Unconditional branch to target block.";
let description = [{
@@ -75,8 +78,9 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
// -----
-def SPV_BranchConditionalOp : SPV_Op<"BranchConditional",
- [InFunctionScope, Terminator]> {
+def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [
+ DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope,
+ Terminator]> {
let summary = [{
If Condition is true, branch to true block, otherwise branch to false
block.
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
index ea541c056a11..e223f17933f3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
@@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SPIRV_SPIRVOPS_H_
#define MLIR_DIALECT_SPIRV_SPIRVOPS_H_
+#include "mlir/Analysis/ControlFlowInterfaces.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Function.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 1e19c0270416..10822764ff6d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -15,6 +15,7 @@
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
#include "mlir/Analysis/CallInterfaces.h"
+#include "mlir/Analysis/ControlFlowInterfaces.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 2fe0365408e2..e44f8ff18a3a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -14,6 +14,7 @@
#define STANDARD_OPS
include "mlir/Analysis/CallInterfaces.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
def Std_Dialect : Dialect {
@@ -331,7 +332,8 @@ def AtomicRMWOp : Std_Op<"atomic_rmw", [
// BranchOp
//===----------------------------------------------------------------------===//
-def BranchOp : Std_Op<"br", [Terminator]> {
+def BranchOp : Std_Op<"br",
+ [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator]> {
let summary = "branch operation";
let description = [{
The "br" operation represents a branch operation in a function.
@@ -668,7 +670,8 @@ def CmpIOp : Std_Op<"cmpi",
// CondBranchOp
//===----------------------------------------------------------------------===//
-def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
+def CondBranchOp : Std_Op<"cond_br",
+ [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator]> {
let summary = "conditional branch operation";
let description = [{
The "cond_br" operation represents a conditional branch operation in a
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 96545d6be2b0..7735dd176b08 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -639,6 +639,10 @@ class OperandRange final
type_range getTypes() const { return {begin(), end()}; }
auto getType() const { return getTypes(); }
+ /// Return the operand index of the first element of this range. The range
+ /// must not be empty.
+ unsigned getBeginOperandIndex() const;
+
private:
/// See `detail::indexed_accessor_range_base` for details.
static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 9eccde56e6ee..44ec43c96233 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
AffineAnalysis.cpp
AffineStructures.cpp
CallGraph.cpp
+ ControlFlowInterfaces.cpp
Dominance.cpp
InferTypeOpInterface.cpp
Liveness.cpp
@@ -14,6 +15,7 @@ set(LLVM_OPTIONAL_SOURCES
add_llvm_library(MLIRAnalysis
CallGraph.cpp
+ ControlFlowInterfaces.cpp
InferTypeOpInterface.cpp
Liveness.cpp
SliceAnalysis.cpp
@@ -26,6 +28,7 @@ add_llvm_library(MLIRAnalysis
add_dependencies(MLIRAnalysis
MLIRAffineOps
MLIRCallOpInterfacesIncGen
+ MLIRControlFlowInterfacesIncGen
MLIRTypeInferOpInterfaceIncGen
MLIRLoopOps
)
@@ -45,6 +48,7 @@ add_llvm_library(MLIRLoopAnalysis
add_dependencies(MLIRLoopAnalysis
MLIRAffineOps
MLIRCallOpInterfacesIncGen
+ MLIRControlFlowInterfacesIncGen
MLIRTypeInferOpInterfaceIncGen
MLIRLoopOps
)
diff --git a/mlir/lib/Analysis/ControlFlowInterfaces.cpp b/mlir/lib/Analysis/ControlFlowInterfaces.cpp
new file mode 100644
index 000000000000..7d98f29d7cf6
--- /dev/null
+++ b/mlir/lib/Analysis/ControlFlowInterfaces.cpp
@@ -0,0 +1,101 @@
+//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/ControlFlowInterfaces.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ControlFlowInterfaces
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/ControlFlowInterfaces.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// BranchOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Erase an operand from a branch operation that is used as a successor
+/// operand. 'operandIndex' is the operand within 'operands' to be erased.
+/// If the operation carries an `operand_segment_sizes` attribute, the size of
+/// the segment that contained the erased operand is decremented so that the
+/// attribute keeps describing the operation's (now shorter) operand list.
+void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands,
+                                               unsigned operandIndex,
+                                               Operation *op) {
+  assert(operandIndex < operands.size() &&
+         "invalid index for successor operands");
+
+  // Erase the operand from the operation. `operandIndex` is relative to
+  // `operands`, so offset it to obtain the index in the full operand list.
+  unsigned fullOperandIndex = operands.getBeginOperandIndex() + operandIndex;
+  op->eraseOperand(fullOperandIndex);
+
+  // If this operation has an OperandSegmentSizeAttr, keep it up to date.
+  auto operandSegmentAttr =
+      op->getAttrOfType<DenseElementsAttr>("operand_segment_sizes");
+  if (!operandSegmentAttr)
+    return;
+
+  // Find the segment containing the full operand index and decrement it. The
+  // segment sizes still describe the operand list as it was *before* the
+  // erase, so the running total is compared against the original index.
+  // TODO: This seems like a general utility that could be added somewhere.
+  SmallVector<int32_t, 4> values(operandSegmentAttr.getValues<int32_t>());
+  unsigned currentSize = 0;
+  bool foundSegment = false;
+  for (int32_t &segmentSize : values) {
+    currentSize += segmentSize;
+    if (fullOperandIndex < currentSize) {
+      --segmentSize;
+      foundSegment = true;
+      break;
+    }
+  }
+  // A miss means the attribute did not cover the erased operand, i.e. it was
+  // already inconsistent with the operand list; don't silently keep it.
+  assert(foundSegment &&
+         "erased operand is not covered by 'operand_segment_sizes'");
+  (void)foundSegment;
+  op->setAttr("operand_segment_sizes",
+              DenseIntElementsAttr::get(operandSegmentAttr.getType(), values));
+}
+
+/// Map `operandIndex`, an index into the owning operation's full operand
+/// list, onto the matching argument of `successor`. Yields None when
+/// `operands` is unavailable or empty, or when `operandIndex` lies outside
+/// the span of operand indices covered by `operands`.
+Optional<BlockArgument> mlir::detail::getBranchSuccessorArgument(
+    Optional<OperandRange> operands, unsigned operandIndex, Block *successor) {
+  // Only a materialized, non-empty range has a first operand index to anchor
+  // the lookup against.
+  if (operands && !operands->empty()) {
+    unsigned first = operands->getBeginOperandIndex();
+    if (operandIndex >= first && operandIndex < first + operands->size())
+      return successor->getArgument(operandIndex - first);
+  }
+  return llvm::None;
+}
+
+/// Check that `operands`, the values forwarded to successor `succNo` of `op`,
+/// agree in count and in type with that block's arguments. Non-materialized
+/// operands (None) are accepted without checking.
+LogicalResult
+mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
+                                            Optional<OperandRange> operands) {
+  // Operands internal to the operation cannot be checked here.
+  if (!operands)
+    return success();
+
+  Block *dest = op->getSuccessor(succNo);
+  unsigned numOperands = operands->size();
+  unsigned numArgs = dest->getNumArguments();
+  if (numOperands != numArgs)
+    return op->emitError() << "branch has " << numOperands
+                           << " operands for successor #" << succNo
+                           << ", but target block has " << numArgs;
+
+  // Counts agree, so walk operands and block arguments in lockstep.
+  unsigned argNo = 0;
+  for (auto operand : *operands) {
+    if (operand.getType() != dest->getArgument(argNo).getType())
+      return op->emitError() << "type mismatch for bb argument #" << argNo
+                             << " of successor #" << succNo;
+    ++argNo;
+  }
+  return success();
+}
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
index 8eafbd8a2179..97b70cae9366 100644
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -4,7 +4,7 @@ add_mlir_dialect_library(MLIRLLVMIR
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
)
-add_dependencies(MLIRLLVMIR MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport)
+add_dependencies(MLIRLLVMIR MLIRControlFlowInterfacesIncGen MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport)
target_link_libraries(MLIRLLVMIR LLVMAsmParser LLVMCore LLVMSupport LLVMFrontendOpenMP MLIROpenMP MLIRIR)
add_mlir_dialect_library(MLIRNVVMIR
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 75369142991b..567ddee94d7d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -153,6 +153,28 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
return success();
}
+//===----------------------------------------------------------------------===//
+// LLVM::BrOp
+//===----------------------------------------------------------------------===//
+
+/// This op has a single successor (`$dest`), which receives every operand.
+Optional<OperandRange> BrOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  auto forwarded = getOperands();
+  return forwarded;
+}
+
+/// Always reports that successor operands may be erased.
+bool BrOp::canEraseSuccessorOperand() {
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// LLVM::CondBrOp
+//===----------------------------------------------------------------------===//
+
+/// Successor 0 (`$trueDest`) receives the true-branch operands and successor 1
+/// (`$falseDest`) the false-branch operands.
+Optional<OperandRange> CondBrOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  if (index == 0)
+    return trueDestOperands();
+  return falseDestOperands();
+}
+
+/// Always reports that successor operands may be erased.
+bool CondBrOp::canEraseSuccessorOperand() {
+  return true;
+}
+
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
@@ -229,9 +251,16 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
}
///===----------------------------------------------------------------------===//
-/// Verifying/Printing/Parsing for LLVM::InvokeOp.
+/// LLVM::InvokeOp
///===----------------------------------------------------------------------===//
+/// Index 0 selects the operands forwarded to the normal destination; any other
+/// valid index selects those forwarded to the unwind destination.
+Optional<OperandRange> InvokeOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  if (index != 0)
+    return unwindDestOperands();
+  return normalDestOperands();
+}
+
+/// Always reports that successor operands may be erased.
+bool InvokeOp::canEraseSuccessorOperand() {
+  return true;
+}
+
static LogicalResult verify(InvokeOp op) {
if (op.getNumResults() > 1)
return op.emitOpError("must have 0 or 1 result");
diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
index ad1bb4df2b5b..7e77f3e7866e 100644
--- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRSPIRV
)
add_dependencies(MLIRSPIRV
+ MLIRControlFlowInterfacesIncGen
MLIRSPIRVAvailabilityIncGen
MLIRSPIRVCanonicalizationIncGen
MLIRSPIRVEnumAvailabilityIncGen
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 6a638673d584..907f8f82f62f 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -942,10 +942,30 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
return success();
}
+//===----------------------------------------------------------------------===//
+// spv.BranchOp
+//===----------------------------------------------------------------------===//
+
+/// Only successor 0 exists, and it is handed every operand of the op.
+Optional<OperandRange> spirv::BranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  auto forwarded = getOperands();
+  return forwarded;
+}
+
+/// Always reports that successor operands may be erased.
+bool spirv::BranchOp::canEraseSuccessorOperand() {
+  return true;
+}
+
//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//
+/// Successor `kTrueIndex` receives the true-branch block arguments; the other
+/// successor receives the false-branch block arguments. The bound is checked
+/// against getNumSuccessors() like every other BranchOpInterface
+/// implementation, instead of a hard-coded successor count.
+Optional<OperandRange>
+spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return index == kTrueIndex ? getTrueBlockArguments()
+                             : getFalseBlockArguments();
+}
+
+/// Always reports that successor operands may be erased.
+bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; }
+
static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
OperationState &state) {
auto &builder = parser.getBuilder();
diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
index c8af4702fbc1..9b8ffcdb1980 100644
--- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRStandardOps
add_dependencies(MLIRStandardOps
MLIRCallOpInterfacesIncGen
+ MLIRControlFlowInterfacesIncGen
MLIREDSC
MLIRIR
MLIRStandardOpsIncGen
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 1059e66d1fc5..6cb1f21ccda5 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -482,7 +482,7 @@ Block *BranchOp::getDest() { return getSuccessor(); }
void BranchOp::setDest(Block *block) { return setSuccessor(block); }
+/// Every operand of `br` is a successor operand (getSuccessorOperands returns
+/// getOperands() for successor 0), so the operand can be erased directly from
+/// the operation's operand list.
void BranchOp::eraseOperand(unsigned index) {
- getOperation()->eraseSuccessorOperand(0, index);
+ getOperation()->eraseOperand(index);
}
void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
@@ -490,6 +490,13 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<SimplifyBrToBlockWithSinglePred>(context);
}
+/// The only successor of `br` receives every operand of the op.
+Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  Optional<OperandRange> forwarded = getOperands();
+  return forwarded;
+}
+
+/// Always reports that successor operands may be erased.
+bool BranchOp::canEraseSuccessorOperand() {
+  return true;
+}
+
//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
@@ -749,6 +756,13 @@ void CondBranchOp::getCanonicalizationPatterns(
results.insert<SimplifyConstCondBranchPred>(context);
}
+/// Successor `trueIndex` receives the true operands; the remaining successor
+/// receives the false operands.
+Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  if (index != trueIndex)
+    return getFalseOperands();
+  return getTrueOperands();
+}
+
+bool CondBranchOp::canEraseSuccessorOperand() {
+  // Successor operands of `cond_br` may always be erased.
+  return true;
+}
+
//===----------------------------------------------------------------------===//
// Constant*Op
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index bfd4b40b317b..2af425d65406 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -950,37 +950,13 @@ LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
return success();
}
-static LogicalResult verifySuccessor(Operation *op, unsigned succNo) {
- Operation::operand_range operands = op->getSuccessorOperands(succNo);
- unsigned operandCount = op->getNumSuccessorOperands(succNo);
- Block *destBB = op->getSuccessor(succNo);
- if (operandCount != destBB->getNumArguments())
- return op->emitError() << "branch has " << operandCount
- << " operands for successor #" << succNo
- << ", but target block has "
- << destBB->getNumArguments();
-
- auto operandIt = operands.begin();
- for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) {
- if ((*operandIt).getType() != destBB->getArgument(i).getType())
- return op->emitError() << "type mismatch for bb argument #" << i
- << " of successor #" << succNo;
- }
-
- return success();
-}
-
static LogicalResult verifyTerminatorSuccessors(Operation *op) {
auto *parent = op->getParentRegion();
-// Verify that the operands lines up with the BB arguments in the successor.
+ // Verify that every successor block lives in the same region as this op.
+ // Successor operands are no longer checked against block arguments here;
+ // ops implementing BranchOpInterface are checked by that interface's verifier.
- for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
- auto *succ = op->getSuccessor(i);
+ for (Block *succ : op->getSuccessors())
if (succ->getParent() != parent)
return op->emitError("reference to block defined in another region");
- if (failed(verifySuccessor(op, i)))
- return failure();
- }
return success();
}
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 107fc483b96d..25859a562cea 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -183,6 +183,13 @@ Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
OperandRange::OperandRange(Operation *op)
: OperandRange(op->getOpOperands().data(), op->getNumOperands()) {}
+/// Position, within the owning operation's operand list, of the first operand
+/// held by this range. An empty range has no first operand, so it is rejected.
+unsigned OperandRange::getBeginOperandIndex() const {
+  assert(!empty() && "range must not be empty");
+  OpOperand *firstOperand = base;
+  return firstOperand->getOperandNumber();
+}
+
//===----------------------------------------------------------------------===//
// ResultRange
diff --git a/mlir/test/lib/TestDialect/CMakeLists.txt b/mlir/test/lib/TestDialect/CMakeLists.txt
index 15459b9abaa1..d81500912f4d 100644
--- a/mlir/test/lib/TestDialect/CMakeLists.txt
+++ b/mlir/test/lib/TestDialect/CMakeLists.txt
@@ -16,6 +16,7 @@ add_llvm_library(MLIRTestDialect
TestPatterns.cpp
)
add_dependencies(MLIRTestDialect
+ MLIRControlFlowInterfacesIncGen
MLIRTestOpsIncGen
MLIRTypeInferOpInterfaceIncGen
)
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index 3ded7b95b9f9..649b547626d9 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -163,6 +163,17 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
return success();
}
+//===----------------------------------------------------------------------===//
+// TestBranchOp
+//===----------------------------------------------------------------------===//
+
+/// The single successor (`$target`) is passed every operand of the op.
+Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  auto allOperands = getOperands();
+  return allOperands;
+}
+
+/// Always reports that successor operands may be erased.
+bool TestBranchOp::canEraseSuccessorOperand() {
+  return true;
+}
+
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/TestDialect/TestDialect.h b/mlir/test/lib/TestDialect/TestDialect.h
index 300327630073..8228f31434f5 100644
--- a/mlir/test/lib/TestDialect/TestDialect.h
+++ b/mlir/test/lib/TestDialect/TestDialect.h
@@ -15,6 +15,7 @@
#define MLIR_TESTDIALECT_H
#include "mlir/Analysis/CallInterfaces.h"
+#include "mlir/Analysis/ControlFlowInterfaces.h"
#include "mlir/Analysis/InferTypeOpInterface.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Dialect.h"
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index 26205e7eb855..5ee4a46eb818 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -11,6 +11,7 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
+include "mlir/Analysis/ControlFlowInterfaces.td"
include "mlir/Analysis/CallInterfaces.td"
include "mlir/Analysis/InferTypeOpInterface.td"
@@ -446,6 +447,11 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
]> {
let arguments = (ins AnyTensor, AnyTensor);
let results = (outs AnyTensor);
+
+ let extraClassDeclaration = [{
+ LogicalResult reifyReturnTypeShapes(OpBuilder &builder,
+ SmallVectorImpl<Value> &shapes);
+ }];
}
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
@@ -454,7 +460,8 @@ def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
(I32ElementsAttrOp ConstantAttr<I32ElementsAttr, "0">),
[(IsNotScalar $attr)]>;
-def TestBranchOp : TEST_Op<"br", [Terminator]> {
+// Single-successor test terminator. Its BranchOpInterface methods are
+// implemented in TestDialect.cpp.
+def TestBranchOp : TEST_Op<"br",
+ [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator]> {
let successors = (successor AnySuccessor:$target);
}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index ebd82f9fb8bf..7aa51bdbf1cf 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1155,8 +1155,8 @@ void OpEmitter::genOpInterfaceMethods() {
continue;
auto interface = opTrait->getOpInterface();
for (auto method : interface.getMethods()) {
- // Don't declare if the method has a body.
- if (method.getBody())
+ // Don't declare if the method has a body or a default implementation.
+ if (method.getBody() || method.getDefaultImplementation())
continue;
std::string args;
llvm::raw_string_ostream os(args);
More information about the Mlir-commits
mailing list