[flang-commits] [flang] 7c221a7 - [mlir][Symbol] Change Symbol from a Trait into an OpInterface.
River Riddle via flang-commits
flang-commits at lists.llvm.org
Mon Apr 27 13:06:13 PDT 2020
Author: River Riddle
Date: 2020-04-27T13:04:49-07:00
New Revision: 7c221a7d4fbce512656d9df202972230eb088f37
URL: https://github.com/llvm/llvm-project/commit/7c221a7d4fbce512656d9df202972230eb088f37
DIFF: https://github.com/llvm/llvm-project/commit/7c221a7d4fbce512656d9df202972230eb088f37.diff
LOG: [mlir][Symbol] Change Symbol from a Trait into an OpInterface.
This provides a much cleaner interface to Symbols and allows users to start injecting op-specific information. For example, a derived op can now specify whether its symbol may be discarded when it has no remaining uses (i.e. is use_empty). This would let us drop unused external functions, which generally have public visibility.
This revision also adds a new `extraTraitClassDeclaration` field to the ODS OpInterface class, allowing extra declarations to be injected into the trait class that is attached to operations implementing the interface.
Differential Revision: https://reviews.llvm.org/D78522
Added:
mlir/include/mlir/IR/SymbolInterfaces.td
Modified:
flang/include/flang/Optimizer/Dialect/FIROps.td
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/include/mlir/IR/CMakeLists.txt
mlir/include/mlir/IR/Function.h
mlir/include/mlir/IR/Module.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/SymbolTable.h
mlir/include/mlir/TableGen/OpInterfaces.h
mlir/lib/IR/CMakeLists.txt
mlir/lib/IR/SymbolTable.cpp
mlir/lib/TableGen/OpInterfaces.cpp
mlir/lib/Transforms/Inliner.cpp
mlir/lib/Transforms/SymbolDCE.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/IR/TestSymbolUses.cpp
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 034534861b80..15d1bbf89001 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -14,6 +14,7 @@
#ifndef FIR_DIALECT_FIR_OPS
#define FIR_DIALECT_FIR_OPS
+include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffects.td"
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index c65c6c5f44aa..e48455e06340 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -15,6 +15,7 @@
include "mlir/Dialect/GPU/GPUBase.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffects.td"
// Type constraint accepting standard integers, indices and wrapped LLVM integer
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index c9ee88c77010..1bf9f6fa36f9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -14,6 +14,7 @@
#define LLVMIR_OPS
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffects.td"
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index c9f27b971bb7..f83afab138c1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -16,6 +16,7 @@
#define SPIRV_STRUCTURE_OPS
include "mlir/Dialect/SPIRV/SPIRVBase.td"
+include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/SideEffects.td"
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 555b16fd29d0..c4e6c99ba640 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -2,3 +2,8 @@ set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
mlir_tablegen(OpAsmInterface.h.inc -gen-op-interface-decls)
mlir_tablegen(OpAsmInterface.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIROpAsmInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS SymbolInterfaces.td)
+mlir_tablegen(SymbolInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(SymbolInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRSymbolInterfacesIncGen)
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 1b2dd19eedab..0b725949576c 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -30,11 +30,10 @@ namespace mlir {
/// implicitly capture global values, and all external references must use
/// Function arguments or attributes that establish a symbolic connection(e.g.
/// symbols referenced by name via a string attribute).
-class FuncOp
- : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
- OpTrait::IsIsolatedFromAbove, OpTrait::Symbol,
- OpTrait::FunctionLike, OpTrait::AutomaticAllocationScope,
- CallableOpInterface::Trait> {
+class FuncOp : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
+ OpTrait::IsIsolatedFromAbove, OpTrait::FunctionLike,
+ OpTrait::AutomaticAllocationScope,
+ CallableOpInterface::Trait, SymbolOpInterface::Trait> {
public:
using Op::Op;
using Op::print;
diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h
index b02d10472b0d..c8adc15646ef 100644
--- a/mlir/include/mlir/IR/Module.h
+++ b/mlir/include/mlir/IR/Module.h
@@ -31,7 +31,8 @@ class ModuleOp
: public Op<
ModuleOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
OpTrait::IsIsolatedFromAbove, OpTrait::SymbolTable,
- OpTrait::SingleBlockImplicitTerminator<ModuleTerminatorOp>::Impl> {
+ OpTrait::SingleBlockImplicitTerminator<ModuleTerminatorOp>::Impl,
+ SymbolOpInterface::Trait> {
public:
using Op::Op;
using Op::print;
@@ -95,6 +96,13 @@ class ModuleOp
insertPt = Block::iterator(body->getTerminator());
body->getOperations().insert(insertPt, op);
}
+
+ //===--------------------------------------------------------------------===//
+ // SymbolOpInterface Methods
+ //===--------------------------------------------------------------------===//
+
+ /// A ModuleOp may optionally define a symbol.
+ bool isOptionalSymbol() { return true; }
};
/// The ModuleTerminatorOp is a special terminator operation for the body of a
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 7679d8b1008e..849ed1a1e6bc 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1658,10 +1658,6 @@ def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
// Op has the same operand and result element type (or type itself, if scalar).
def SameOperandsAndResultElementType :
NativeOpTrait<"SameOperandsAndResultElementType">;
-// Op is a symbol.
-def Symbol : NativeOpTrait<"Symbol">;
-// Op defines a symbol table.
-def SymbolTable : NativeOpTrait<"SymbolTable">;
// Op is a terminator.
def Terminator : NativeOpTrait<"IsTerminator">;
@@ -1721,6 +1717,10 @@ class OpInterfaceTrait<string name, code verifyBody = [{}]> : NativeOpTrait<"">
// Specify the body of the verification function. `$_op` will be replaced with
// the operation being verified.
code verify = verifyBody;
+
+ // An optional code block containing extra declarations to place in the
+ // interface trait declaration.
+ code extraTraitClassDeclaration = "";
}
// This class represents a single, optionally static, interface method.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index c38c593ad6c5..d066c1599714 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1359,6 +1359,7 @@ class OpInterface : public Op<ConcreteType> {
public:
using Concept = typename Traits::Concept;
template <typename T> using Model = typename Traits::template Model<T>;
+ using Base = OpInterface<ConcreteType, Traits>;
OpInterface(Operation *op = nullptr)
: Op<ConcreteType>(op), impl(op ? getInterfaceFor(op) : nullptr) {
diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
new file mode 100644
index 000000000000..219ea6048f02
--- /dev/null
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -0,0 +1,155 @@
+//===- SymbolInterfaces.td - Interfaces for symbol ops -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces and traits that can be used to define
+// properties of symbol and symbol table operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_SYMBOLINTERFACES
+#define MLIR_IR_SYMBOLINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// SymbolOpInterface
+//===----------------------------------------------------------------------===//
+
+def Symbol : OpInterface<"SymbolOpInterface"> {
+ let description = [{
+ This interface describes an operation that may define a `Symbol`. A `Symbol`
+ operation resides immediately within a region that defines a `SymbolTable`.
+ See [Symbols and SymbolTables](SymbolsAndSymbolTables.md) for more details
+ and constraints on `Symbol` operations.
+ }];
+
+ let methods = [
+ InterfaceMethod<"Returns the name of this symbol.",
+ "StringRef", "getName", (ins), [{
+ // Don't rely on the trait implementation as optional symbol operations
+ // may override this.
+ return mlir::SymbolTable::getSymbolName(op);
+ }], /*defaultImplementation=*/[{
+ return mlir::SymbolTable::getSymbolName(this->getOperation());
+ }]
+ >,
+ InterfaceMethod<"Sets the name of this symbol.",
+ "void", "setName", (ins "StringRef":$name), [{}],
+ /*defaultImplementation=*/[{
+ this->getOperation()->setAttr(
+ mlir::SymbolTable::getSymbolAttrName(),
+ StringAttr::get(name, this->getOperation()->getContext()));
+ }]
+ >,
+ InterfaceMethod<"Gets the visibility of this symbol.",
+ "mlir::SymbolTable::Visibility", "getVisibility", (ins), [{}],
+ /*defaultImplementation=*/[{
+ return mlir::SymbolTable::getSymbolVisibility(this->getOperation());
+ }]
+ >,
+ InterfaceMethod<"Sets the visibility of this symbol.",
+ "void", "setVisibility", (ins "mlir::SymbolTable::Visibility":$vis), [{}],
+ /*defaultImplementation=*/[{
+ mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis);
+ }]
+ >,
+ InterfaceMethod<[{
+ Get all of the uses of the current symbol that are nested within the
+ given operation 'from'.
+ Note: See mlir::SymbolTable::getSymbolUses for more details.
+ }],
+ "Optional<::mlir::SymbolTable::UseRange>", "getSymbolUses",
+ (ins "Operation *":$from), [{}],
+ /*defaultImplementation=*/[{
+ return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from);
+ }]
+ >,
+ InterfaceMethod<[{
+ Return if the current symbol is known to have no uses that are nested
+ within the given operation 'from'.
+ Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details.
+ }],
+ "bool", "symbolKnownUseEmpty", (ins "Operation *":$from), [{}],
+ /*defaultImplementation=*/[{
+ return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(),
+ from);
+ }]
+ >,
+ InterfaceMethod<[{
+ Attempt to replace all uses of the current symbol with the provided
+ symbol 'newSymbol' that are nested within the given operation 'from'.
+ Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
+ }],
+ "LogicalResult", "replaceAllSymbolUses", (ins "StringRef":$newSymbol,
+ "Operation *":$from), [{}],
+ /*defaultImplementation=*/[{
+ return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
+ newSymbol, from);
+ }]
+ >,
+ InterfaceMethod<[{
+ Returns true if this operation optionally defines a symbol based on the
+ presence of the symbol name.
+ }],
+ "bool", "isOptionalSymbol", (ins), [{}],
+ /*defaultImplementation=*/[{ return false; }]
+ >,
+ InterfaceMethod<[{
+ Returns true if this operation can be discarded if it has no remaining
+ symbol uses.
+ }],
+ "bool", "canDiscardOnUseEmpty", (ins), [{}],
+ /*defaultImplementation=*/[{
+ // By default, base this on the visibility alone. A symbol can be
+ // discarded as long as it is not public. Only public symbols may be
+ // visible from outside of the IR.
+ return getVisibility() != ::mlir::SymbolTable::Visibility::Public;
+ }]
+ >,
+ ];
+
+ let verify = [{
+ // If this is an optional symbol, bail out early if possible.
+ auto concreteOp = cast<ConcreteOp>($_op);
+ if (concreteOp.isOptionalSymbol()) {
+ if(!concreteOp.getAttr(::mlir::SymbolTable::getSymbolAttrName()))
+ return success();
+ }
+ return ::mlir::detail::verifySymbol($_op);
+ }];
+
+ let extraClassDeclaration = [{
+ using Visibility = mlir::SymbolTable::Visibility;
+
+ /// Custom classof that handles the case where the symbol is optional.
+ static bool classof(Operation *op) {
+ return Base::classof(op)
+ && op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
+ }
+
+ /// Returns true if this symbol has nested visibility.
+ bool isNested() { return getVisibility() == Visibility::Nested; }
+ /// Returns true if this symbol has private visibility.
+ bool isPrivate() { return getVisibility() == Visibility::Private; }
+ /// Returns true if this symbol has public visibility.
+ bool isPublic() { return getVisibility() == Visibility::Public; }
+ }];
+
+ let extraTraitClassDeclaration = [{
+ using Visibility = mlir::SymbolTable::Visibility;
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Symbol Traits
+//===----------------------------------------------------------------------===//
+
+// Op defines a symbol table.
+def SymbolTable : NativeOpTrait<"SymbolTable">;
+
+#endif // MLIR_IR_SYMBOLINTERFACES
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index c61efb066e39..216948b2b3df 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -72,9 +72,6 @@ class SymbolTable {
Nested,
};
- /// Returns true if the given operation defines a symbol.
- static bool isSymbol(Operation *op);
-
/// Returns the name of the given symbol operation.
static StringRef getSymbolName(Operation *symbol);
/// Sets the name of the given symbol operation.
@@ -207,12 +204,12 @@ class SymbolTable {
// SymbolTable Trait Types
//===----------------------------------------------------------------------===//
-namespace OpTrait {
-namespace impl {
+namespace detail {
LogicalResult verifySymbolTable(Operation *op);
LogicalResult verifySymbol(Operation *op);
-} // namespace impl
+} // namespace detail
+namespace OpTrait {
/// A trait used to provide symbol table functionalities to a region operation.
/// This operation must hold exactly 1 region. Once attached, all operations
/// that are directly within the region, i.e not including those within child
@@ -224,7 +221,7 @@ template <typename ConcreteType>
class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
public:
static LogicalResult verifyTrait(Operation *op) {
- return impl::verifySymbolTable(op);
+ return ::mlir::detail::verifySymbolTable(op);
}
/// Look up a symbol with the specified name, returning null if no such
@@ -245,68 +242,11 @@ class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
}
};
-/// A trait used to define a symbol that can be used on operations within a
-/// symbol table. Operations using this trait must adhere to the following:
-/// * Have a StringAttr attribute named 'SymbolTable::getSymbolAttrName()'.
-template <typename ConcreteType>
-class Symbol : public TraitBase<ConcreteType, Symbol> {
-public:
- using Visibility = mlir::SymbolTable::Visibility;
-
- static LogicalResult verifyTrait(Operation *op) {
- return impl::verifySymbol(op);
- }
-
- /// Returns the name of this symbol.
- StringRef getName() {
- return this->getOperation()
- ->template getAttrOfType<StringAttr>(
- mlir::SymbolTable::getSymbolAttrName())
- .getValue();
- }
-
- /// Set the name of this symbol.
- void setName(StringRef name) {
- this->getOperation()->setAttr(
- mlir::SymbolTable::getSymbolAttrName(),
- StringAttr::get(name, this->getOperation()->getContext()));
- }
-
- /// Returns the visibility of the current symbol.
- Visibility getVisibility() {
- return mlir::SymbolTable::getSymbolVisibility(this->getOperation());
- }
-
- /// Sets the visibility of the current symbol.
- void setVisibility(Visibility vis) {
- mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis);
- }
-
- /// Get all of the uses of the current symbol that are nested within the given
- /// operation 'from'.
- /// Note: See mlir::SymbolTable::getSymbolUses for more details.
- Optional<::mlir::SymbolTable::UseRange> getSymbolUses(Operation *from) {
- return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from);
- }
-
- /// Return if the current symbol is known to have no uses that are nested
- /// within the given operation 'from'.
- /// Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details.
- bool symbolKnownUseEmpty(Operation *from) {
- return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(), from);
- }
+} // end namespace OpTrait
- /// Attempt to replace all uses of the current symbol with the provided symbol
- /// 'newSymbol' that are nested within the given operation 'from'.
- /// Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
- LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol,
- Operation *from) {
- return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
- newSymbol, from);
- }
-};
+/// Include the generated symbol interfaces.
+#include "mlir/IR/SymbolInterfaces.h.inc"
-} // end namespace OpTrait
} // end namespace mlir
#endif // MLIR_IR_SYMBOLTABLE_H
diff --git a/mlir/include/mlir/TableGen/OpInterfaces.h b/mlir/include/mlir/TableGen/OpInterfaces.h
index 2e1a63cf6636..0e1b943ce382 100644
--- a/mlir/include/mlir/TableGen/OpInterfaces.h
+++ b/mlir/include/mlir/TableGen/OpInterfaces.h
@@ -89,6 +89,9 @@ class OpInterface {
// Return the interfaces extra class declaration code.
llvm::Optional<StringRef> getExtraClassDeclaration() const;
+ // Return the traits extra class declaration code.
+ llvm::Optional<StringRef> getExtraTraitClassDeclaration() const;
+
// Return the verify method body if it has one.
llvm::Optional<StringRef> getVerify() const;
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 64998e4252c3..88c36eee4c77 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRIR
DEPENDS
MLIRCallInterfacesIncGen
MLIROpAsmInterfacesIncGen
+ MLIRSymbolInterfacesIncGen
)
target_link_libraries(MLIRIR
PUBLIC
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 2b1d99b0a363..487b51de8dc9 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -146,11 +146,6 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
setSymbolName(symbol, nameBuffer);
}
-/// Returns true if the given operation defines a symbol.
-bool SymbolTable::isSymbol(Operation *op) {
- return op->hasTrait<OpTrait::Symbol>() || getNameIfSymbol(op).hasValue();
-}
-
/// Returns the name of the given symbol operation.
StringRef SymbolTable::getSymbolName(Operation *symbol) {
Optional<StringRef> name = getNameIfSymbol(symbol);
@@ -286,7 +281,7 @@ Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
// SymbolTable Trait Types
//===----------------------------------------------------------------------===//
-LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
+LogicalResult detail::verifySymbolTable(Operation *op) {
if (op->getNumRegions() != 1)
return op->emitOpError()
<< "Operations with a 'SymbolTable' must have exactly one region";
@@ -316,7 +311,7 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
return success();
}
-LogicalResult OpTrait::impl::verifySymbol(Operation *op) {
+LogicalResult detail::verifySymbol(Operation *op) {
// Verify the name attribute.
if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
return op->emitOpError() << "requires string attribute '"
@@ -866,3 +861,10 @@ LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
Region *from) {
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
}
+
+//===----------------------------------------------------------------------===//
+// Symbol Interfaces
+//===----------------------------------------------------------------------===//
+
+/// Include the generated symbol interfaces.
+#include "mlir/IR/SymbolInterfaces.cpp.inc"
diff --git a/mlir/lib/TableGen/OpInterfaces.cpp b/mlir/lib/TableGen/OpInterfaces.cpp
index c565547b2e09..be3782c78809 100644
--- a/mlir/lib/TableGen/OpInterfaces.cpp
+++ b/mlir/lib/TableGen/OpInterfaces.cpp
@@ -92,6 +92,12 @@ llvm::Optional<StringRef> OpInterface::getExtraClassDeclaration() const {
return value.empty() ? llvm::Optional<StringRef>() : value;
}
+// Return the traits extra class declaration code.
+llvm::Optional<StringRef> OpInterface::getExtraTraitClassDeclaration() const {
+ auto value = def->getValueAsString("extraTraitClassDeclaration");
+ return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
// Return the body for this method if it has one.
llvm::Optional<StringRef> OpInterface::getVerify() const {
auto value = def->getValueAsString("verify");
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 10ad848a5bb9..28c8216f8333 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -31,26 +31,6 @@ using namespace mlir;
// Symbol Use Tracking
//===----------------------------------------------------------------------===//
-/// Returns true if this operation can be discarded if it is a symbol and has no
-/// uses. 'allUsesVisible' corresponds to if the parent symbol table is hidden
-/// from above.
-static bool canDiscardSymbolOnUseEmpty(Operation *op, bool allUsesVisible) {
- if (!SymbolTable::isSymbol(op))
- return false;
-
- // TODO: This is essentially the same logic from SymbolDCE. Remove this when
- // we have a 'Symbol' interface.
- // Private symbols are always initially considered dead.
- SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
- if (visibility == mlir::SymbolTable::Visibility::Private)
- return true;
- // We only include nested visibility here if all uses are visible.
- if (allUsesVisible && visibility == SymbolTable::Visibility::Nested)
- return true;
- // Otherwise, public symbols are never removable.
- return false;
-}
-
/// Walk all of the symbol table operations nested with 'op' along with a
/// boolean signifying if the symbols within can be treated as if all uses are
/// visible. The provided callback is invoked with the symbol table operation,
@@ -59,9 +39,8 @@ static bool canDiscardSymbolOnUseEmpty(Operation *op, bool allUsesVisible) {
static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
function_ref<void(Operation *, bool)> callback) {
if (op->hasTrait<OpTrait::SymbolTable>()) {
- allSymUsesVisible = allSymUsesVisible || !SymbolTable::isSymbol(op) ||
- SymbolTable::getSymbolVisibility(op) ==
- SymbolTable::Visibility::Private;
+ SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
+ allSymUsesVisible = allSymUsesVisible || !symbol || symbol.isPrivate();
callback(op, allSymUsesVisible);
} else {
// Otherwise if 'op' is not a symbol table, any nested symbols are
@@ -171,8 +150,11 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
// If this is a callgraph operation, check to see if it is discardable.
if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
- if (canDiscardSymbolOnUseEmpty(&op, allUsesVisible))
+ SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
+ if (symbol && (allUsesVisible || symbol.isPrivate()) &&
+ symbol.canDiscardOnUseEmpty()) {
discardableSymNodeUses.try_emplace(node, 0);
+ }
continue;
}
}
@@ -224,7 +206,7 @@ void CGUseList::eraseNode(CallGraphNode *node) {
bool CGUseList::isDead(CallGraphNode *node) const {
// If the parent operation isn't a symbol, simply check normal SSA deadness.
Operation *nodeOp = node->getCallableRegion()->getParentOp();
- if (!SymbolTable::isSymbol(nodeOp))
+ if (!isa<SymbolOpInterface>(nodeOp))
return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();
// Otherwise, check the number of symbol uses.
@@ -235,7 +217,7 @@ bool CGUseList::isDead(CallGraphNode *node) const {
bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
// If this isn't a symbol node, check for side-effects and SSA use count.
Operation *nodeOp = node->getCallableRegion()->getParentOp();
- if (!SymbolTable::isSymbol(nodeOp))
+ if (!isa<SymbolOpInterface>(nodeOp))
return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();
// Otherwise, check the number of symbol uses.
diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp
index 581857a6a92e..56997b6d2af7 100644
--- a/mlir/lib/Transforms/SymbolDCE.cpp
+++ b/mlir/lib/Transforms/SymbolDCE.cpp
@@ -43,10 +43,9 @@ void SymbolDCE::runOnOperation() {
// A flag that signals if the top level symbol table is hidden, i.e. not
// accessible from parent scopes.
bool symbolTableIsHidden = true;
- if (symbolTableOp->getParentOp() && SymbolTable::isSymbol(symbolTableOp)) {
- symbolTableIsHidden = SymbolTable::getSymbolVisibility(symbolTableOp) ==
- SymbolTable::Visibility::Private;
- }
+ SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
+ if (symbolTableOp->getParentOp() && symbol)
+ symbolTableIsHidden = symbol.isPrivate();
// Compute the set of live symbols within the symbol table.
DenseSet<Operation *> liveSymbols;
@@ -61,7 +60,7 @@ void SymbolDCE::runOnOperation() {
for (auto &block : nestedSymbolTable->getRegion(0)) {
for (Operation &op :
llvm::make_early_inc_range(block.without_terminator())) {
- if (SymbolTable::isSymbol(&op) && !liveSymbols.count(&op))
+ if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op))
op.erase();
}
}
@@ -80,30 +79,16 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
// Walk the symbols within the current symbol table, marking the symbols that
// are known to be live.
for (auto &block : symbolTableOp->getRegion(0)) {
+ // Add all non-symbols or symbols that can't be discarded.
for (Operation &op : block.without_terminator()) {
- // Always add non symbol operations to the worklist.
- if (!SymbolTable::isSymbol(&op)) {
+ SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
+ if (!symbol) {
worklist.push_back(&op);
continue;
}
-
- // Check the visibility to see if this symbol may be referenced
- // externally.
- SymbolTable::Visibility visibility =
- SymbolTable::getSymbolVisibility(&op);
-
- // Private symbols are always initially considered dead.
- if (visibility == mlir::SymbolTable::Visibility::Private)
- continue;
- // We only include nested visibility here if the symbol table isn't
- // hidden.
- if (symbolTableIsHidden && visibility == SymbolTable::Visibility::Nested)
- continue;
-
- // TODO(riverriddle) Add hooks here to allow symbols to provide additional
- // information, e.g. linkage can be used to drop some symbols that may
- // otherwise be considered "live".
- if (liveSymbols.insert(&op).second)
+ bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
+ symbol.canDiscardOnUseEmpty();
+ if (!isDiscardable && liveSymbols.insert(&op).second)
worklist.push_back(&op);
}
}
@@ -117,10 +102,9 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
if (op->hasTrait<OpTrait::SymbolTable>()) {
// The internal symbol table is hidden if the parent is, if its not a
// symbol, or if it is a private symbol.
- bool symbolIsHidden = symbolTableIsHidden || !SymbolTable::isSymbol(op) ||
- SymbolTable::getSymbolVisibility(op) ==
- SymbolTable::Visibility::Private;
- if (failed(computeLiveness(op, symbolIsHidden, liveSymbols)))
+ SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
+ bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
+ if (failed(computeLiveness(op, symIsHidden, liveSymbols)))
return failure();
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index d5259639f4b5..000a5722a76a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -11,6 +11,7 @@
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffects.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index 13188485ec41..0ec7f8258050 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -66,7 +66,7 @@ struct SymbolUsesPass
// Walk nested symbols.
SmallVector<FuncOp, 4> deadFunctions;
module.getBodyRegion().walk([&](Operation *nestedOp) {
- if (SymbolTable::isSymbol(nestedOp))
+ if (isa<SymbolOpInterface>(nestedOp))
return operateOnSymbol(nestedOp, module, deadFunctions);
return WalkResult::advance();
});
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index b988333c5254..12ba8d43c9c1 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -174,6 +174,8 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
os << " static LogicalResult verifyTrait(Operation* op) {\n"
<< std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n";
}
+ if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
+ os << extraTraitDecls << "\n";
os << " };\n";
}
More information about the flang-commits
mailing list