[Mlir-commits] [mlir] 7c221a7 - [mlir][Symbol] Change Symbol from a Trait into an OpInterface.

River Riddle llvmlistbot at llvm.org
Mon Apr 27 13:06:15 PDT 2020


Author: River Riddle
Date: 2020-04-27T13:04:49-07:00
New Revision: 7c221a7d4fbce512656d9df202972230eb088f37

URL: https://github.com/llvm/llvm-project/commit/7c221a7d4fbce512656d9df202972230eb088f37
DIFF: https://github.com/llvm/llvm-project/commit/7c221a7d4fbce512656d9df202972230eb088f37.diff

LOG: [mlir][Symbol] Change Symbol from a Trait into an OpInterface.

This provides a much cleaner interface into Symbols, and allows users to start injecting op-specific information. For example, a derived op can now specify when a symbol may be discarded once it has no remaining uses (i.e. when it is use_empty). This would let us drop unused external functions, which generally have public visibility.

This revision also adds a new `extraTraitClassDeclaration` field to ODS OpInterface to allow for injecting declarations into the trait class that gets attached to the operations.

Differential Revision: https://reviews.llvm.org/D78522

Added: 
    mlir/include/mlir/IR/SymbolInterfaces.td

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
    mlir/include/mlir/IR/CMakeLists.txt
    mlir/include/mlir/IR/Function.h
    mlir/include/mlir/IR/Module.h
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/SymbolTable.h
    mlir/include/mlir/TableGen/OpInterfaces.h
    mlir/lib/IR/CMakeLists.txt
    mlir/lib/IR/SymbolTable.cpp
    mlir/lib/TableGen/OpInterfaces.cpp
    mlir/lib/Transforms/Inliner.cpp
    mlir/lib/Transforms/SymbolDCE.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/IR/TestSymbolUses.cpp
    mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 034534861b80..15d1bbf89001 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -14,6 +14,7 @@
 #ifndef FIR_DIALECT_FIR_OPS
 #define FIR_DIALECT_FIR_OPS
 
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffects.td"
 

diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index c65c6c5f44aa..e48455e06340 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/GPU/GPUBase.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/SideEffects.td"
 
 // Type constraint accepting standard integers, indices and wrapped LLVM integer

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index c9ee88c77010..1bf9f6fa36f9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -14,6 +14,7 @@
 #define LLVMIR_OPS
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffects.td"
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index c9f27b971bb7..f83afab138c1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -16,6 +16,7 @@
 #define SPIRV_STRUCTURE_OPS
 
 include "mlir/Dialect/SPIRV/SPIRVBase.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/SideEffects.td"
 

diff  --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 555b16fd29d0..c4e6c99ba640 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -2,3 +2,8 @@ set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
 mlir_tablegen(OpAsmInterface.h.inc -gen-op-interface-decls)
 mlir_tablegen(OpAsmInterface.cpp.inc -gen-op-interface-defs)
 add_public_tablegen_target(MLIROpAsmInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS SymbolInterfaces.td)
+mlir_tablegen(SymbolInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(SymbolInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRSymbolInterfacesIncGen)

diff  --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 1b2dd19eedab..0b725949576c 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -30,11 +30,10 @@ namespace mlir {
 /// implicitly capture global values, and all external references must use
 /// Function arguments or attributes that establish a symbolic connection(e.g.
 /// symbols referenced by name via a string attribute).
-class FuncOp
-    : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
-                OpTrait::IsIsolatedFromAbove, OpTrait::Symbol,
-                OpTrait::FunctionLike, OpTrait::AutomaticAllocationScope,
-                CallableOpInterface::Trait> {
+class FuncOp : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
+                         OpTrait::IsIsolatedFromAbove, OpTrait::FunctionLike,
+                         OpTrait::AutomaticAllocationScope,
+                         CallableOpInterface::Trait, SymbolOpInterface::Trait> {
 public:
   using Op::Op;
   using Op::print;

diff  --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h
index b02d10472b0d..c8adc15646ef 100644
--- a/mlir/include/mlir/IR/Module.h
+++ b/mlir/include/mlir/IR/Module.h
@@ -31,7 +31,8 @@ class ModuleOp
     : public Op<
           ModuleOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
           OpTrait::IsIsolatedFromAbove, OpTrait::SymbolTable,
-          OpTrait::SingleBlockImplicitTerminator<ModuleTerminatorOp>::Impl> {
+          OpTrait::SingleBlockImplicitTerminator<ModuleTerminatorOp>::Impl,
+          SymbolOpInterface::Trait> {
 public:
   using Op::Op;
   using Op::print;
@@ -95,6 +96,13 @@ class ModuleOp
       insertPt = Block::iterator(body->getTerminator());
     body->getOperations().insert(insertPt, op);
   }
+
+  //===--------------------------------------------------------------------===//
+  // SymbolOpInterface Methods
+  //===--------------------------------------------------------------------===//
+
+  /// A ModuleOp may optionally define a symbol.
+  bool isOptionalSymbol() { return true; }
 };
 
 /// The ModuleTerminatorOp is a special terminator operation for the body of a

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 7679d8b1008e..849ed1a1e6bc 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1658,10 +1658,6 @@ def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
 // Op has the same operand and result element type (or type itself, if scalar).
 def SameOperandsAndResultElementType :
   NativeOpTrait<"SameOperandsAndResultElementType">;
-// Op is a symbol.
-def Symbol : NativeOpTrait<"Symbol">;
-// Op defines a symbol table.
-def SymbolTable : NativeOpTrait<"SymbolTable">;
 // Op is a terminator.
 def Terminator : NativeOpTrait<"IsTerminator">;
 
@@ -1721,6 +1717,10 @@ class OpInterfaceTrait<string name, code verifyBody = [{}]> : NativeOpTrait<"">
   // Specify the body of the verification function. `$_op` will be replaced with
   // the operation being verified.
   code verify = verifyBody;
+
+  // An optional code block containing extra declarations to place in the
+  // interface trait declaration.
+  code extraTraitClassDeclaration = "";
 }
 
 // This class represents a single, optionally static, interface method.

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index c38c593ad6c5..d066c1599714 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1359,6 +1359,7 @@ class OpInterface : public Op<ConcreteType> {
 public:
   using Concept = typename Traits::Concept;
   template <typename T> using Model = typename Traits::template Model<T>;
+  using Base = OpInterface<ConcreteType, Traits>;
 
   OpInterface(Operation *op = nullptr)
       : Op<ConcreteType>(op), impl(op ? getInterfaceFor(op) : nullptr) {

diff  --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
new file mode 100644
index 000000000000..219ea6048f02
--- /dev/null
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -0,0 +1,155 @@
+//===- SymbolInterfaces.td - Interfaces for symbol ops -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces and traits that can be used to define
+// properties of symbol and symbol table operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_SYMBOLINTERFACES
+#define MLIR_IR_SYMBOLINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// SymbolOpInterface
+//===----------------------------------------------------------------------===//
+
+def Symbol : OpInterface<"SymbolOpInterface"> {
+  let description = [{
+    This interface describes an operation that may define a `Symbol`. A `Symbol`
+    operation resides immediately within a region that defines a `SymbolTable`.
+    See [Symbols and SymbolTables](SymbolsAndSymbolTables.md) for more details
+    and constraints on `Symbol` operations.
+  }];
+
+  let methods = [
+    InterfaceMethod<"Returns the name of this symbol.",
+      "StringRef", "getName", (ins), [{
+        // Don't rely on the trait implementation as optional symbol operations
+        // may override this.
+        return mlir::SymbolTable::getSymbolName(op);
+      }], /*defaultImplementation=*/[{
+        return mlir::SymbolTable::getSymbolName(this->getOperation());
+      }]
+    >,
+    InterfaceMethod<"Sets the name of this symbol.",
+      "void", "setName", (ins "StringRef":$name), [{}],
+      /*defaultImplementation=*/[{
+        this->getOperation()->setAttr(
+            mlir::SymbolTable::getSymbolAttrName(),
+            StringAttr::get(name, this->getOperation()->getContext()));
+      }]
+    >,
+    InterfaceMethod<"Gets the visibility of this symbol.",
+      "mlir::SymbolTable::Visibility", "getVisibility", (ins), [{}],
+      /*defaultImplementation=*/[{
+        return mlir::SymbolTable::getSymbolVisibility(this->getOperation());
+      }]
+    >,
+    InterfaceMethod<"Sets the visibility of this symbol.",
+      "void", "setVisibility", (ins "mlir::SymbolTable::Visibility":$vis), [{}],
+      /*defaultImplementation=*/[{
+        mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis);
+      }]
+    >,
+    InterfaceMethod<[{
+        Get all of the uses of the current symbol that are nested within the
+        given operation 'from'.
+        Note: See mlir::SymbolTable::getSymbolUses for more details.
+      }],
+      "Optional<::mlir::SymbolTable::UseRange>", "getSymbolUses",
+      (ins "Operation *":$from), [{}],
+      /*defaultImplementation=*/[{
+        return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from);
+      }]
+    >,
+    InterfaceMethod<[{
+        Return if the current symbol is known to have no uses that are nested
+        within the given operation 'from'.
+        Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details.
+      }],
+      "bool", "symbolKnownUseEmpty", (ins "Operation *":$from), [{}],
+      /*defaultImplementation=*/[{
+        return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(),
+                                                        from);
+      }]
+    >,
+    InterfaceMethod<[{
+        Attempt to replace all uses of the current symbol with the provided
+        symbol 'newSymbol' that are nested within the given operation 'from'.
+        Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
+      }],
+      "LogicalResult", "replaceAllSymbolUses", (ins "StringRef":$newSymbol,
+                                                    "Operation *":$from), [{}],
+      /*defaultImplementation=*/[{
+        return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
+                                                         newSymbol, from);
+      }]
+    >,
+    InterfaceMethod<[{
+        Returns true if this operation optionally defines a symbol based on the
+        presence of the symbol name.
+      }],
+      "bool", "isOptionalSymbol", (ins), [{}],
+      /*defaultImplementation=*/[{ return false; }]
+    >,
+    InterfaceMethod<[{
+        Returns true if this operation can be discarded if it has no remaining
+        symbol uses.
+      }],
+      "bool", "canDiscardOnUseEmpty", (ins), [{}],
+      /*defaultImplementation=*/[{
+        // By default, base this on the visibility alone. A symbol can be
+        // discarded as long as it is not public. Only public symbols may be
+        // visible from outside of the IR.
+        return getVisibility() != ::mlir::SymbolTable::Visibility::Public;
+      }]
+    >,
+  ];
+
+  let verify = [{
+    // If this is an optional symbol, bail out early if possible.
+    auto concreteOp = cast<ConcreteOp>($_op);
+    if (concreteOp.isOptionalSymbol()) {
+      if(!concreteOp.getAttr(::mlir::SymbolTable::getSymbolAttrName()))
+        return success();
+    }
+    return ::mlir::detail::verifySymbol($_op);
+  }];
+
+  let extraClassDeclaration = [{
+    using Visibility = mlir::SymbolTable::Visibility;
+
+    /// Custom classof that handles the case where the symbol is optional.
+    static bool classof(Operation *op) {
+      return Base::classof(op)
+        && op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
+    }
+
+    /// Returns true if this symbol has nested visibility.
+    bool isNested() { return getVisibility() == Visibility::Nested; }
+    /// Returns true if this symbol has private visibility.
+    bool isPrivate() { return getVisibility() == Visibility::Private; }
+    /// Returns true if this symbol has public visibility.
+    bool isPublic() { return getVisibility() == Visibility::Public; }
+  }];
+
+  let extraTraitClassDeclaration = [{
+    using Visibility = mlir::SymbolTable::Visibility;
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Symbol Traits
+//===----------------------------------------------------------------------===//
+
+// Op defines a symbol table.
+def SymbolTable : NativeOpTrait<"SymbolTable">;
+
+#endif // MLIR_IR_SYMBOLINTERFACES

diff  --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index c61efb066e39..216948b2b3df 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -72,9 +72,6 @@ class SymbolTable {
     Nested,
   };
 
-  /// Returns true if the given operation defines a symbol.
-  static bool isSymbol(Operation *op);
-
   /// Returns the name of the given symbol operation.
   static StringRef getSymbolName(Operation *symbol);
   /// Sets the name of the given symbol operation.
@@ -207,12 +204,12 @@ class SymbolTable {
 // SymbolTable Trait Types
 //===----------------------------------------------------------------------===//
 
-namespace OpTrait {
-namespace impl {
+namespace detail {
 LogicalResult verifySymbolTable(Operation *op);
 LogicalResult verifySymbol(Operation *op);
-} // namespace impl
+} // namespace detail
 
+namespace OpTrait {
 /// A trait used to provide symbol table functionalities to a region operation.
 /// This operation must hold exactly 1 region. Once attached, all operations
 /// that are directly within the region, i.e not including those within child
@@ -224,7 +221,7 @@ template <typename ConcreteType>
 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
 public:
   static LogicalResult verifyTrait(Operation *op) {
-    return impl::verifySymbolTable(op);
+    return ::mlir::detail::verifySymbolTable(op);
   }
 
   /// Look up a symbol with the specified name, returning null if no such
@@ -245,68 +242,11 @@ class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
   }
 };
 
-/// A trait used to define a symbol that can be used on operations within a
-/// symbol table. Operations using this trait must adhere to the following:
-///   * Have a StringAttr attribute named 'SymbolTable::getSymbolAttrName()'.
-template <typename ConcreteType>
-class Symbol : public TraitBase<ConcreteType, Symbol> {
-public:
-  using Visibility = mlir::SymbolTable::Visibility;
-
-  static LogicalResult verifyTrait(Operation *op) {
-    return impl::verifySymbol(op);
-  }
-
-  /// Returns the name of this symbol.
-  StringRef getName() {
-    return this->getOperation()
-        ->template getAttrOfType<StringAttr>(
-            mlir::SymbolTable::getSymbolAttrName())
-        .getValue();
-  }
-
-  /// Set the name of this symbol.
-  void setName(StringRef name) {
-    this->getOperation()->setAttr(
-        mlir::SymbolTable::getSymbolAttrName(),
-        StringAttr::get(name, this->getOperation()->getContext()));
-  }
-
-  /// Returns the visibility of the current symbol.
-  Visibility getVisibility() {
-    return mlir::SymbolTable::getSymbolVisibility(this->getOperation());
-  }
-
-  /// Sets the visibility of the current symbol.
-  void setVisibility(Visibility vis) {
-    mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis);
-  }
-
-  /// Get all of the uses of the current symbol that are nested within the given
-  /// operation 'from'.
-  /// Note: See mlir::SymbolTable::getSymbolUses for more details.
-  Optional<::mlir::SymbolTable::UseRange> getSymbolUses(Operation *from) {
-    return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from);
-  }
-
-  /// Return if the current symbol is known to have no uses that are nested
-  /// within the given operation 'from'.
-  /// Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details.
-  bool symbolKnownUseEmpty(Operation *from) {
-    return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(), from);
-  }
+} // end namespace OpTrait
 
-  /// Attempt to replace all uses of the current symbol with the provided symbol
-  /// 'newSymbol' that are nested within the given operation 'from'.
-  /// Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
-  LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol,
-                                                    Operation *from) {
-    return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
-                                                     newSymbol, from);
-  }
-};
+/// Include the generated symbol interfaces.
+#include "mlir/IR/SymbolInterfaces.h.inc"
 
-} // end namespace OpTrait
 } // end namespace mlir
 
 #endif // MLIR_IR_SYMBOLTABLE_H

diff  --git a/mlir/include/mlir/TableGen/OpInterfaces.h b/mlir/include/mlir/TableGen/OpInterfaces.h
index 2e1a63cf6636..0e1b943ce382 100644
--- a/mlir/include/mlir/TableGen/OpInterfaces.h
+++ b/mlir/include/mlir/TableGen/OpInterfaces.h
@@ -89,6 +89,9 @@ class OpInterface {
   // Return the interfaces extra class declaration code.
   llvm::Optional<StringRef> getExtraClassDeclaration() const;
 
+  // Return the traits extra class declaration code.
+  llvm::Optional<StringRef> getExtraTraitClassDeclaration() const;
+
   // Return the verify method body if it has one.
   llvm::Optional<StringRef> getVerify() const;
 

diff  --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 64998e4252c3..88c36eee4c77 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(MLIRIR
   DEPENDS
   MLIRCallInterfacesIncGen
   MLIROpAsmInterfacesIncGen
+  MLIRSymbolInterfacesIncGen
   )
 target_link_libraries(MLIRIR
   PUBLIC

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 2b1d99b0a363..487b51de8dc9 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -146,11 +146,6 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
   setSymbolName(symbol, nameBuffer);
 }
 
-/// Returns true if the given operation defines a symbol.
-bool SymbolTable::isSymbol(Operation *op) {
-  return op->hasTrait<OpTrait::Symbol>() || getNameIfSymbol(op).hasValue();
-}
-
 /// Returns the name of the given symbol operation.
 StringRef SymbolTable::getSymbolName(Operation *symbol) {
   Optional<StringRef> name = getNameIfSymbol(symbol);
@@ -286,7 +281,7 @@ Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
 // SymbolTable Trait Types
 //===----------------------------------------------------------------------===//
 
-LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
+LogicalResult detail::verifySymbolTable(Operation *op) {
   if (op->getNumRegions() != 1)
     return op->emitOpError()
            << "Operations with a 'SymbolTable' must have exactly one region";
@@ -316,7 +311,7 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
   return success();
 }
 
-LogicalResult OpTrait::impl::verifySymbol(Operation *op) {
+LogicalResult detail::verifySymbol(Operation *op) {
   // Verify the name attribute.
   if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
     return op->emitOpError() << "requires string attribute '"
@@ -866,3 +861,10 @@ LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
                                                 Region *from) {
   return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
 }
+
+//===----------------------------------------------------------------------===//
+// Symbol Interfaces
+//===----------------------------------------------------------------------===//
+
+/// Include the generated symbol interfaces.
+#include "mlir/IR/SymbolInterfaces.cpp.inc"

diff  --git a/mlir/lib/TableGen/OpInterfaces.cpp b/mlir/lib/TableGen/OpInterfaces.cpp
index c565547b2e09..be3782c78809 100644
--- a/mlir/lib/TableGen/OpInterfaces.cpp
+++ b/mlir/lib/TableGen/OpInterfaces.cpp
@@ -92,6 +92,12 @@ llvm::Optional<StringRef> OpInterface::getExtraClassDeclaration() const {
   return value.empty() ? llvm::Optional<StringRef>() : value;
 }
 
+// Return the traits extra class declaration code.
+llvm::Optional<StringRef> OpInterface::getExtraTraitClassDeclaration() const {
+  auto value = def->getValueAsString("extraTraitClassDeclaration");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
 // Return the body for this method if it has one.
 llvm::Optional<StringRef> OpInterface::getVerify() const {
   auto value = def->getValueAsString("verify");

diff  --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 10ad848a5bb9..28c8216f8333 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -31,26 +31,6 @@ using namespace mlir;
 // Symbol Use Tracking
 //===----------------------------------------------------------------------===//
 
-/// Returns true if this operation can be discarded if it is a symbol and has no
-/// uses. 'allUsesVisible' corresponds to if the parent symbol table is hidden
-/// from above.
-static bool canDiscardSymbolOnUseEmpty(Operation *op, bool allUsesVisible) {
-  if (!SymbolTable::isSymbol(op))
-    return false;
-
-  // TODO: This is essentially the same logic from SymbolDCE. Remove this when
-  // we have a 'Symbol' interface.
-  // Private symbols are always initially considered dead.
-  SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
-  if (visibility == mlir::SymbolTable::Visibility::Private)
-    return true;
-  // We only include nested visibility here if all uses are visible.
-  if (allUsesVisible && visibility == SymbolTable::Visibility::Nested)
-    return true;
-  // Otherwise, public symbols are never removable.
-  return false;
-}
-
 /// Walk all of the symbol table operations nested with 'op' along with a
 /// boolean signifying if the symbols within can be treated as if all uses are
 /// visible. The provided callback is invoked with the symbol table operation,
@@ -59,9 +39,8 @@ static bool canDiscardSymbolOnUseEmpty(Operation *op, bool allUsesVisible) {
 static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
                              function_ref<void(Operation *, bool)> callback) {
   if (op->hasTrait<OpTrait::SymbolTable>()) {
-    allSymUsesVisible = allSymUsesVisible || !SymbolTable::isSymbol(op) ||
-                        SymbolTable::getSymbolVisibility(op) ==
-                            SymbolTable::Visibility::Private;
+    SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
+    allSymUsesVisible = allSymUsesVisible || !symbol || symbol.isPrivate();
     callback(op, allSymUsesVisible);
   } else {
     // Otherwise if 'op' is not a symbol table, any nested symbols are
@@ -171,8 +150,11 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
         // If this is a callgraph operation, check to see if it is discardable.
         if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
           if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
-            if (canDiscardSymbolOnUseEmpty(&op, allUsesVisible))
+            SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
+            if (symbol && (allUsesVisible || symbol.isPrivate()) &&
+                symbol.canDiscardOnUseEmpty()) {
               discardableSymNodeUses.try_emplace(node, 0);
+            }
             continue;
           }
         }
@@ -224,7 +206,7 @@ void CGUseList::eraseNode(CallGraphNode *node) {
 bool CGUseList::isDead(CallGraphNode *node) const {
   // If the parent operation isn't a symbol, simply check normal SSA deadness.
   Operation *nodeOp = node->getCallableRegion()->getParentOp();
-  if (!SymbolTable::isSymbol(nodeOp))
+  if (!isa<SymbolOpInterface>(nodeOp))
     return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();
 
   // Otherwise, check the number of symbol uses.
@@ -235,7 +217,7 @@ bool CGUseList::isDead(CallGraphNode *node) const {
 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
   // If this isn't a symbol node, check for side-effects and SSA use count.
   Operation *nodeOp = node->getCallableRegion()->getParentOp();
-  if (!SymbolTable::isSymbol(nodeOp))
+  if (!isa<SymbolOpInterface>(nodeOp))
     return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();
 
   // Otherwise, check the number of symbol uses.

diff  --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp
index 581857a6a92e..56997b6d2af7 100644
--- a/mlir/lib/Transforms/SymbolDCE.cpp
+++ b/mlir/lib/Transforms/SymbolDCE.cpp
@@ -43,10 +43,9 @@ void SymbolDCE::runOnOperation() {
   // A flag that signals if the top level symbol table is hidden, i.e. not
   // accessible from parent scopes.
   bool symbolTableIsHidden = true;
-  if (symbolTableOp->getParentOp() && SymbolTable::isSymbol(symbolTableOp)) {
-    symbolTableIsHidden = SymbolTable::getSymbolVisibility(symbolTableOp) ==
-                          SymbolTable::Visibility::Private;
-  }
+  SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
+  if (symbolTableOp->getParentOp() && symbol)
+    symbolTableIsHidden = symbol.isPrivate();
 
   // Compute the set of live symbols within the symbol table.
   DenseSet<Operation *> liveSymbols;
@@ -61,7 +60,7 @@ void SymbolDCE::runOnOperation() {
     for (auto &block : nestedSymbolTable->getRegion(0)) {
       for (Operation &op :
            llvm::make_early_inc_range(block.without_terminator())) {
-        if (SymbolTable::isSymbol(&op) && !liveSymbols.count(&op))
+        if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op))
           op.erase();
       }
     }
@@ -80,30 +79,16 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
   // Walk the symbols within the current symbol table, marking the symbols that
   // are known to be live.
   for (auto &block : symbolTableOp->getRegion(0)) {
+    // Add all non-symbols or symbols that can't be discarded.
     for (Operation &op : block.without_terminator()) {
-      // Always add non symbol operations to the worklist.
-      if (!SymbolTable::isSymbol(&op)) {
+      SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
+      if (!symbol) {
         worklist.push_back(&op);
         continue;
       }
-
-      // Check the visibility to see if this symbol may be referenced
-      // externally.
-      SymbolTable::Visibility visibility =
-          SymbolTable::getSymbolVisibility(&op);
-
-      // Private symbols are always initially considered dead.
-      if (visibility == mlir::SymbolTable::Visibility::Private)
-        continue;
-      // We only include nested visibility here if the symbol table isn't
-      // hidden.
-      if (symbolTableIsHidden && visibility == SymbolTable::Visibility::Nested)
-        continue;
-
-      // TODO(riverriddle) Add hooks here to allow symbols to provide additional
-      // information, e.g. linkage can be used to drop some symbols that may
-      // otherwise be considered "live".
-      if (liveSymbols.insert(&op).second)
+      bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
+                           symbol.canDiscardOnUseEmpty();
+      if (!isDiscardable && liveSymbols.insert(&op).second)
         worklist.push_back(&op);
     }
   }
@@ -117,10 +102,9 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
     if (op->hasTrait<OpTrait::SymbolTable>()) {
       // The internal symbol table is hidden if the parent is, if its not a
       // symbol, or if it is a private symbol.
-      bool symbolIsHidden = symbolTableIsHidden || !SymbolTable::isSymbol(op) ||
-                            SymbolTable::getSymbolVisibility(op) ==
-                                SymbolTable::Visibility::Private;
-      if (failed(computeLiveness(op, symbolIsHidden, liveSymbols)))
+      SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
+      bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
+      if (failed(computeLiveness(op, symIsHidden, liveSymbols)))
         return failure();
     }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index d5259639f4b5..000a5722a76a 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/IR/OpBase.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/SideEffects.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"

diff  --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
index 13188485ec41..0ec7f8258050 100644
--- a/mlir/test/lib/IR/TestSymbolUses.cpp
+++ b/mlir/test/lib/IR/TestSymbolUses.cpp
@@ -66,7 +66,7 @@ struct SymbolUsesPass
     // Walk nested symbols.
     SmallVector<FuncOp, 4> deadFunctions;
     module.getBodyRegion().walk([&](Operation *nestedOp) {
-      if (SymbolTable::isSymbol(nestedOp))
+      if (isa<SymbolOpInterface>(nestedOp))
         return operateOnSymbol(nestedOp, module, deadFunctions);
       return WalkResult::advance();
     });

diff  --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index b988333c5254..12ba8d43c9c1 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -174,6 +174,8 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
     os << "  static LogicalResult verifyTrait(Operation* op) {\n"
        << std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n  }\n";
   }
+  if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
+    os << extraTraitDecls << "\n";
 
   os << "  };\n";
 }


        


More information about the Mlir-commits mailing list