[Mlir-commits] [mlir] 7bc7d0a - [mlir] Optimize symbol related checks in SymbolDCE

River Riddle llvmlistbot at llvm.org
Fri Oct 16 12:09:23 PDT 2020


Author: River Riddle
Date: 2020-10-16T12:08:48-07:00
New Revision: 7bc7d0ac7ae2e2c578463758422214e80ce5e056

URL: https://github.com/llvm/llvm-project/commit/7bc7d0ac7ae2e2c578463758422214e80ce5e056
DIFF: https://github.com/llvm/llvm-project/commit/7bc7d0ac7ae2e2c578463758422214e80ce5e056.diff

LOG: [mlir] Optimize symbol related checks in SymbolDCE

This revision contains two optimizations related to symbol checking:
* Optimize SymbolOpInterface to only check for a name attribute if the operation is an optional symbol.
This removes an otherwise unnecessary attribute lookup from a majority of symbols.
* Add a new SymbolTableCollection class to represent a collection of SymbolTables.
This allows for performing non-flat symbol lookups in O(1) time by caching SymbolTables for symbol table operations. This class is very useful for algorithms that operate on multiple symbol tables, either recursively or not.

Differential Revision: https://reviews.llvm.org/D89505

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/SymbolInterfaces.td
    mlir/include/mlir/IR/SymbolTable.h
    mlir/lib/IR/SymbolTable.cpp
    mlir/lib/Transforms/SymbolDCE.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 9c1cf3d841dd..a9f706a228dc 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1536,7 +1536,7 @@ class OpInterface
   /// Inherit the base class constructor.
   using InterfaceBase::InterfaceBase;
 
-private:
+protected:
   /// Returns the impl interface instance for the given operation.
   static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
     // Access the raw interface from the abstract operation.

diff  --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 148551324868..7435e1aebda0 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -145,8 +145,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
   let extraClassDeclaration = [{
     /// Custom classof that handles the case where the symbol is optional.
     static bool classof(Operation *op) {
-      return Base::classof(op)
-        && op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
+      auto *concept = getInterfaceFor(op);
+      if (!concept)
+        return false;
+      return !concept->isOptionalSymbol(op) ||
+             op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
     }
   }];
 

diff  --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 4e41fe55d2c5..00b1573365a7 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -210,6 +210,40 @@ class SymbolTable {
   unsigned uniquingCounter = 0;
 };
 
+//===----------------------------------------------------------------------===//
+// SymbolTableCollection
+//===----------------------------------------------------------------------===//
+
+/// This class represents a collection of `SymbolTable`s. This simplifies
+/// certain algorithms that run recursively on nested symbol tables. Symbol
+/// tables are constructed lazily to reduce the upfront cost of constructing
+/// unnecessary tables.
+class SymbolTableCollection {
+public:
+  /// Look up a symbol with the specified name within the specified symbol table
+  /// operation, returning null if no such name exists.
+  Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol);
+  Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
+  /// Typed lookup; non-const because it forwards to the overloads above.
+  template <typename T, typename NameT>
+  T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) {
+    return dyn_cast_or_null<T>(
+        lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
+  }
+  /// A variant of 'lookupSymbolIn' that appends to `symbols` every symbol
+  /// referenced by `name` when resolved within the provided symbol table
+  /// operation. Returns failure if any nested reference cannot be resolved.
+  LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
+                               SmallVectorImpl<Operation *> &symbols);
+
+  /// Lookup, or create, a symbol table for an operation.
+  SymbolTable &getSymbolTable(Operation *op);
+
+private:
+  /// Symbol tables constructed so far, keyed by their symbol table operation.
+  DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
+};
+
 //===----------------------------------------------------------------------===//
 // SymbolTable Trait Types
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index e18e691f8cc8..5aeb9bdad65f 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -258,13 +258,16 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
   return resolvedSymbols.back();
 }
 
-LogicalResult
-SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
-                            SmallVectorImpl<Operation *> &symbols) {
+/// Internal implementation of `lookupSymbolIn` that allows for specialized
+/// implementations of the lookup function.
+static LogicalResult lookupSymbolInImpl(
+    Operation *symbolTableOp, SymbolRefAttr symbol,
+    SmallVectorImpl<Operation *> &symbols,
+    function_ref<Operation *(Operation *, StringRef)> lookupSymbolFn) {
   assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
 
   // Lookup the root reference for this symbol.
-  symbolTableOp = lookupSymbolIn(symbolTableOp, symbol.getRootReference());
+  symbolTableOp = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
   if (!symbolTableOp)
     return failure();
   symbols.push_back(symbolTableOp);
@@ -281,15 +284,24 @@ SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
   // Otherwise, lookup each of the nested non-leaf references and ensure that
   // each corresponds to a valid symbol table.
   for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
-    symbolTableOp = lookupSymbolIn(symbolTableOp, ref.getValue());
+    symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getValue());
     if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
       return failure();
     symbols.push_back(symbolTableOp);
   }
-  symbols.push_back(lookupSymbolIn(symbolTableOp, symbol.getLeafReference()));
+  symbols.push_back(lookupSymbolFn(symbolTableOp, symbol.getLeafReference()));
   return success(symbols.back());
 }
 
+LogicalResult
+SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
+                            SmallVectorImpl<Operation *> &symbols) {
+  auto lookupFn = [](Operation *symbolTableOp, StringRef symbol) {
+    return lookupSymbolIn(symbolTableOp, symbol);
+  };
+  return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
+}
+
 /// Returns the operation registered with the given symbol name within the
 /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
 /// nullptr if no valid symbol was found.
@@ -887,6 +899,42 @@ LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
   return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
 }
 
+//===----------------------------------------------------------------------===//
+// SymbolTableCollection
+//===----------------------------------------------------------------------===//
+
+Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                                 StringRef symbol) {
+  return getSymbolTable(symbolTableOp).lookup(symbol);
+}
+Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                                 SymbolRefAttr name) {
+  SmallVector<Operation *, 4> symbols;
+  if (failed(lookupSymbolIn(symbolTableOp, name, symbols)))
+    return nullptr;
+  return symbols.back();
+}
+/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced by
+/// a given SymbolRefAttr. Returns failure if any of the nested references could
+/// not be resolved.
+LogicalResult
+SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
+                                      SymbolRefAttr name,
+                                      SmallVectorImpl<Operation *> &symbols) {
+  auto lookupFn = [this](Operation *symbolTableOp, StringRef symbol) {
+    return lookupSymbolIn(symbolTableOp, symbol);
+  };
+  return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
+}
+
+/// Lookup, or create, a symbol table for an operation.
+SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
+  auto it = symbolTables.try_emplace(op, nullptr);
+  if (it.second)
+    it.first->second = std::make_unique<SymbolTable>(op);
+  return *it.first->second;
+}
+
 //===----------------------------------------------------------------------===//
 // Symbol Interfaces
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp
index 56997b6d2af7..5d94245489c0 100644
--- a/mlir/lib/Transforms/SymbolDCE.cpp
+++ b/mlir/lib/Transforms/SymbolDCE.cpp
@@ -24,6 +24,7 @@ struct SymbolDCE : public SymbolDCEBase<SymbolDCE> {
   /// `symbolTableIsHidden` is true if this symbol table is known to be
   /// unaccessible from operations in its parent regions.
   LogicalResult computeLiveness(Operation *symbolTableOp,
+                                SymbolTableCollection &symbolTable,
                                 bool symbolTableIsHidden,
                                 DenseSet<Operation *> &liveSymbols);
 };
@@ -49,7 +50,9 @@ void SymbolDCE::runOnOperation() {
 
   // Compute the set of live symbols within the symbol table.
   DenseSet<Operation *> liveSymbols;
-  if (failed(computeLiveness(symbolTableOp, symbolTableIsHidden, liveSymbols)))
+  SymbolTableCollection symbolTable;
+  if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
+                             liveSymbols)))
     return signalPassFailure();
 
   // After computing the liveness, delete all of the symbols that were found to
@@ -71,6 +74,7 @@ void SymbolDCE::runOnOperation() {
 /// `symbolTableIsHidden` is true if this symbol table is known to be
 /// unaccessible from operations in its parent regions.
 LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
+                                         SymbolTableCollection &symbolTable,
                                          bool symbolTableIsHidden,
                                          DenseSet<Operation *> &liveSymbols) {
   // A worklist of live operations to propagate uses from.
@@ -104,7 +108,7 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
       // symbol, or if it is a private symbol.
       SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
       bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
-      if (failed(computeLiveness(op, symIsHidden, liveSymbols)))
+      if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
         return failure();
     }
 
@@ -120,7 +124,7 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
     for (const SymbolTable::SymbolUse &use : *uses) {
       // Lookup the symbols referenced by this use.
       resolvedSymbols.clear();
-      if (failed(SymbolTable::lookupSymbolIn(
+      if (failed(symbolTable.lookupSymbolIn(
               op->getParentOp(), use.getSymbolRef(), resolvedSymbols))) {
         return use.getUser()->emitError()
                << "unable to resolve reference to symbol "


        


More information about the Mlir-commits mailing list