[Mlir-commits] [mlir] 71eeb5e - [mlir] Add a new SymbolUserOpInterface class

River Riddle llvmlistbot at llvm.org
Fri Oct 16 12:09:25 PDT 2020


Author: River Riddle
Date: 2020-10-16T12:08:48-07:00
New Revision: 71eeb5ec4d6edbbef31fec83fe75933d48f101df

URL: https://github.com/llvm/llvm-project/commit/71eeb5ec4d6edbbef31fec83fe75933d48f101df
DIFF: https://github.com/llvm/llvm-project/commit/71eeb5ec4d6edbbef31fec83fe75933d48f101df.diff

LOG: [mlir] Add a new SymbolUserOpInterface class

The initial goal of this interface is to fix the current problems with verifying symbol user operations, but it can be extended beyond that in the future. The current problems with the verification of symbol uses are:
* Extremely inefficient:
Most current symbol users perform the symbol lookup using the slow O(N) string-compare methods, which can lead to extremely long verification times in large modules.
* Invalid; breaks the constraints of the verification pass:
If the symbol reference is non-flat (and, in some cases, even if it is flat), a verifier for an operation is not permitted to touch the referenced operation, because that operation may be in the process of being mutated by a different thread within the pass manager.

The new SymbolUserOpInterface exposes a method `verifySymbolUses` that will be invoked from the parent symbol table to allow for verifying the constraints of any referenced symbols. This method is passed a `SymbolTableCollection` to allow for O(1) lookups of any necessary symbol operation.
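
To illustrate the efficiency argument, here is a minimal standalone sketch that does not use MLIR. It contrasts a linear string-compare scan with a lazily built, cached name-to-symbol map per table. This is only a simplified model of why a shared `SymbolTableCollection` gives O(1) lookups. It is not MLIR's implementation, and the names `Symbol`, `Table`, `linearLookup`, and `CachedTables` are hypothetical.

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins for a symbol operation and a symbol table.
struct Symbol {
  std::string name;
};
struct Table {
  std::vector<Symbol> symbols;  // symbols defined directly in this table
};

// O(N) per query: compare the name against every symbol in the table.
const Symbol *linearLookup(const Table &table, const std::string &name) {
  for (const Symbol &sym : table.symbols)
    if (sym.name == name)
      return &sym;
  return nullptr;
}

// Simplified model of a symbol table collection. The name -> symbol map for
// each table is built once, on first use, and reused by later queries, so
// every subsequent lookup is an (expected) O(1) hash probe.
// Assumption: tables are not mutated while the cache is alive.
class CachedTables {
 public:
  const Symbol *lookup(const Table &table, const std::string &name) {
    const Map &names = getOrBuild(table);
    auto it = names.find(name);
    return it == names.end() ? nullptr : it->second;
  }
  std::size_t numBuiltTables() const { return cache_.size(); }

 private:
  using Map = std::unordered_map<std::string, const Symbol *>;

  Map &getOrBuild(const Table &table) {
    auto [it, inserted] = cache_.try_emplace(&table);
    if (inserted)
      for (const Symbol &sym : table.symbols)
        it->second.emplace(sym.name, &sym);  // first definition wins
    return it->second;
  }

  std::unordered_map<const Table *, Map> cache_;
};
```

The first lookup in a table pays O(N) to build its map. That cost is amortized over all later queries, which is the pattern the verifier benefits from when many users reference symbols in the same table.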

Differential Revision: https://reviews.llvm.org/D89512

Added: 
    

Modified: 
    mlir/docs/Interfaces.md
    mlir/docs/SymbolsAndSymbolTables.md
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/SymbolInterfaces.td
    mlir/include/mlir/IR/SymbolTable.h
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/IR/SymbolTable.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
index 633e43e42da4..9e32a70888aa 100644
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -231,4 +231,12 @@ format of the header for each interface section goes as follows:
 
 ##### SymbolInterfaces
 
-*   `SymbolOpInterface` - Used to represent [`Symbol`](SymbolsAndSymbolTables.md#symbol) operations which reside immediately within a region that defines a [`SymbolTable`](SymbolsAndSymbolTables.md#symbol-table).
+*   `SymbolOpInterface` - Used to represent
+    [`Symbol`](SymbolsAndSymbolTables.md#symbol) operations which reside
+    immediately within a region that defines a
+    [`SymbolTable`](SymbolsAndSymbolTables.md#symbol-table).
+
+*   `SymbolUserOpInterface` - Used to represent operations that reference
+    [`Symbol`](SymbolsAndSymbolTables.md#symbol) operations. This provides the
+    ability to perform safe and efficient verification of symbol uses, as well
+    as additional functionality.

diff  --git a/mlir/docs/SymbolsAndSymbolTables.md b/mlir/docs/SymbolsAndSymbolTables.md
index 2b4301e43d2f..a2cd828fd4d1 100644
--- a/mlir/docs/SymbolsAndSymbolTables.md
+++ b/mlir/docs/SymbolsAndSymbolTables.md
@@ -142,6 +142,10 @@ See the `LangRef` definition of the
 [`SymbolRefAttr`](LangRef.md#symbol-reference-attribute) for more information
 about the structure of this attribute.
 
+Operations that reference a `Symbol` and want to perform verification and
+general mutation of the symbol should implement the `SymbolUserOpInterface` to
+ensure that symbol accesses are legal and efficient.
+
 ### Manipulating a Symbol
 
 As described above, `SymbolRefs` act as an auxiliary way of defining uses of

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 97ab739890be..42bfa100e96e 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -733,7 +734,9 @@ def BranchOp : Std_Op<"br",
 // CallOp
 //===----------------------------------------------------------------------===//
 
-def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
+def CallOp : Std_Op<"call",
+    [CallOpInterface, MemRefsNormalizable,
+     DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = "call operation";
   let description = [{
     The `call` operation represents a direct call to a function that is within
@@ -788,6 +791,7 @@ def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
   let assemblyFormat = [{
     $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
   }];
+  let verifier = ?;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 7435e1aebda0..c5f2bb8f631d 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -158,6 +158,27 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SymbolUserOpInterface
+//===----------------------------------------------------------------------===//
+
+def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
+  let description = [{
+    This interface describes an operation that may use a `Symbol`. This
+    interface allows for users of symbols to hook into verification and other
+    symbol related utilities that are either costly or otherwise disallowed
+    within a traditional operation.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<"Verify the symbol uses held by this operation.",
+      "LogicalResult", "verifySymbolUses",
+      (ins "::mlir::SymbolTableCollection &":$symbolTable)
+    >,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Symbol Traits
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 00b1573365a7..d54d7590f2bb 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -236,6 +236,21 @@ class SymbolTableCollection {
   LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
                                SmallVectorImpl<Operation *> &symbols);
 
+  /// Returns the operation registered with the given symbol name within the
+  /// closest parent operation of, or including, 'from' with the
+  /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
+  /// found.
+  Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
+  Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
+  template <typename T>
+  T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
+    return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
+  }
+  template <typename T>
+  T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
+    return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
+  }
+
   /// Lookup, or create, a symbol table for an operation.
   SymbolTable &getSymbolTable(Operation *op);
 

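The new `lookupNearestSymbolFrom` overloads first find the closest operation with the `OpTrait::SymbolTable` trait, searching from `from` itself and then its parents. They then look the name up in that table and return nullptr if no table or symbol is found. The standalone sketch below models only that two-step behavior with a hypothetical `Node` tree type. It is not MLIR code, and it ignores nested (non-flat) references, unknown symbol tables, and the caching the real collection performs.

```cpp
#include <string>
#include <vector>

// Hypothetical model of an operation: a parent pointer, an optional symbol
// name, and a flag standing in for the OpTrait::SymbolTable trait.
struct Node {
  Node *parent = nullptr;
  std::string symName;           // empty if this node defines no symbol
  bool isSymbolTable = false;
  std::vector<Node *> children;  // nodes directly nested in this node
};

// Closest node, starting with `from` itself, that is a symbol table.
Node *getNearestSymbolTable(Node *from) {
  for (Node *n = from; n; n = n->parent)
    if (n->isSymbolTable)
      return n;
  return nullptr;
}

// Look up `name` among the symbols defined directly in `table`.
Node *lookupSymbolIn(Node *table, const std::string &name) {
  for (Node *child : table->children)
    if (!child->symName.empty() && child->symName == name)
      return child;
  return nullptr;
}

// Same shape as the new helper: find the nearest table, then look up in it.
Node *lookupNearestSymbolFrom(Node *from, const std::string &name) {
  Node *table = getNearestSymbolTable(from);
  return table ? lookupSymbolIn(table, name) : nullptr;
}
```

In the model, a call nested inside a function resolves a sibling function by walking up to the enclosing table. A node with no enclosing table simply yields nullptr.
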
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index b5d1429829e5..444b729ee751 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -740,34 +740,33 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
 // CallOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(CallOp op) {
+LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   // Check that the callee attribute was specified.
-  auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
+  auto fnAttr = getAttrOfType<FlatSymbolRefAttr>("callee");
   if (!fnAttr)
-    return op.emitOpError("requires a 'callee' symbol reference attribute");
-  auto fn =
-      op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
   if (!fn)
-    return op.emitOpError() << "'" << fnAttr.getValue()
-                            << "' does not reference a valid function";
+    return emitOpError() << "'" << fnAttr.getValue()
+                         << "' does not reference a valid function";
 
   // Verify that the operand and result types match the callee.
   auto fnType = fn.getType();
-  if (fnType.getNumInputs() != op.getNumOperands())
-    return op.emitOpError("incorrect number of operands for callee");
+  if (fnType.getNumInputs() != getNumOperands())
+    return emitOpError("incorrect number of operands for callee");
 
   for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
-    if (op.getOperand(i).getType() != fnType.getInput(i))
-      return op.emitOpError("operand type mismatch: expected operand type ")
+    if (getOperand(i).getType() != fnType.getInput(i))
+      return emitOpError("operand type mismatch: expected operand type ")
              << fnType.getInput(i) << ", but provided "
-             << op.getOperand(i).getType() << " for operand number " << i;
+             << getOperand(i).getType() << " for operand number " << i;
 
-  if (fnType.getNumResults() != op.getNumResults())
-    return op.emitOpError("incorrect number of results for callee");
+  if (fnType.getNumResults() != getNumResults())
+    return emitOpError("incorrect number of results for callee");
 
   for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
-    if (op.getResult(i).getType() != fnType.getResult(i))
-      return op.emitOpError("result type mismatch");
+    if (getResult(i).getType() != fnType.getResult(i))
+      return emitOpError("result type mismatch");
 
   return success();
 }
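
The rewritten `CallOp::verifySymbolUses` keeps the checks the old `verify` performed. The callee attribute must be present and resolve to a function, and the operand and result counts and types must match the callee's signature. The only change is that the lookup now goes through the `SymbolTableCollection`. The sketch below restates that sequence of checks in plain C++ and returns the first error message, or an empty string on success. It assumes a hypothetical `FnSignature` type, models types as strings, and uses a `std::map` in place of the symbol table. It mirrors only the order of the checks, not the MLIR API.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical function signature; types are represented as strings.
struct FnSignature {
  std::vector<std::string> inputs;
  std::vector<std::string> results;
};

// Returns an empty string on success, otherwise the first error found.
std::string verifyCall(const std::map<std::string, FnSignature> &functions,
                       const std::string &callee,
                       const std::vector<std::string> &operandTypes,
                       const std::vector<std::string> &resultTypes) {
  if (callee.empty())
    return "requires a 'callee' symbol reference attribute";
  auto it = functions.find(callee);
  if (it == functions.end())
    return "'" + callee + "' does not reference a valid function";
  const FnSignature &fn = it->second;

  // Operands must match the callee's inputs in number and type.
  if (fn.inputs.size() != operandTypes.size())
    return "incorrect number of operands for callee";
  for (std::size_t i = 0; i < fn.inputs.size(); ++i)
    if (operandTypes[i] != fn.inputs[i])
      return "operand type mismatch: expected operand type " + fn.inputs[i] +
             ", but provided " + operandTypes[i] + " for operand number " +
             std::to_string(i);

  // Results must match the callee's results in number and type.
  if (fn.results.size() != resultTypes.size())
    return "incorrect number of results for callee";
  for (std::size_t i = 0; i < fn.results.size(); ++i)
    if (resultTypes[i] != fn.results[i])
      return "result type mismatch";

  return "";
}
```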

diff  --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 5aeb9bdad65f..8c84c69d7db4 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -68,6 +68,30 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
   return success();
 }
 
+/// Walk all of the operations within the given set of regions, without
+/// traversing into any nested symbol tables. Stops walking if the result of the
+/// callback is anything other than `WalkResult::advance`.
+static Optional<WalkResult>
+walkSymbolTable(MutableArrayRef<Region> regions,
+                function_ref<Optional<WalkResult>(Operation *)> callback) {
+  SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
+  while (!worklist.empty()) {
+    for (Operation &op : worklist.pop_back_val()->getOps()) {
+      Optional<WalkResult> result = callback(&op);
+      if (result != WalkResult::advance())
+        return result;
+
+      // If this op defines a new symbol table scope, we can't traverse. Any
+      // symbol references nested within 'op' are different semantically.
+      if (!op.hasTrait<OpTrait::SymbolTable>()) {
+        for (Region &region : op.getRegions())
+          worklist.push_back(&region);
+      }
+    }
+  }
+  return WalkResult::advance();
+}
+
 //===----------------------------------------------------------------------===//
 // SymbolTable
 //===----------------------------------------------------------------------===//
@@ -347,7 +371,18 @@ LogicalResult detail::verifySymbolTable(Operation *op) {
             .append("see existing symbol definition here");
     }
   }
-  return success();
+
+  // Verify any nested symbol user operations.
+  SymbolTableCollection symbolTable;
+  auto verifySymbolUserFn = [&](Operation *op) -> Optional<WalkResult> {
+    if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
+      return WalkResult(user.verifySymbolUses(symbolTable));
+    return WalkResult::advance();
+  };
+
+  Optional<WalkResult> result =
+      walkSymbolTable(op->getRegions(), verifySymbolUserFn);
+  return success(result && !result->wasInterrupted());
 }
 
 LogicalResult detail::verifySymbol(Operation *op) {
@@ -452,25 +487,13 @@ static WalkResult walkSymbolRefs(
 static Optional<WalkResult> walkSymbolUses(
     MutableArrayRef<Region> regions,
     function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
-  SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
-  while (!worklist.empty()) {
-    for (Operation &op : worklist.pop_back_val()->getOps()) {
-      if (walkSymbolRefs(&op, callback).wasInterrupted())
-        return WalkResult::interrupt();
-
-      // Check that this isn't a potentially unknown symbol table.
-      if (isPotentiallyUnknownSymbolTable(&op))
-        return llvm::None;
+  return walkSymbolTable(regions, [&](Operation *op) -> Optional<WalkResult> {
+    // Check that this isn't a potentially unknown symbol table.
+    if (isPotentiallyUnknownSymbolTable(op))
+      return llvm::None;
 
-      // If this op defines a new symbol table scope, we can't traverse. Any
-      // symbol references nested within 'op' are different semantically.
-      if (!op.hasTrait<OpTrait::SymbolTable>()) {
-        for (Region &region : op.getRegions())
-          worklist.push_back(&region);
-      }
-    }
-  }
-  return WalkResult::advance();
+    return walkSymbolRefs(op, callback);
+  });
 }
 /// Walk all of the uses, for any symbol, that are nested within the given
 /// operation 'from', invoking the provided callback for each. This does not
@@ -927,6 +950,22 @@ SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
   return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
 }
 
+/// Returns the operation registered with the given symbol name within the
+/// closest parent operation of, or including, 'from' with the
+/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
+/// found.
+Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
+                                                          StringRef symbol) {
+  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
+  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
+}
+Operation *
+SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
+                                               SymbolRefAttr symbol) {
+  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
+  return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
+}
+
 /// Lookup, or create, a symbol table for an operation.
 SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
   auto it = symbolTables.try_emplace(op, nullptr);
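
The new `walkSymbolTable` helper is a worklist walk over regions with three properties:
* It visits every operation in those regions.
* It stops as soon as the callback returns anything other than `WalkResult::advance()`, including an empty `Optional`.
* It does not descend into operations that define their own symbol table.

`verifySymbolTable` uses it to call `verifySymbolUses` on each nested symbol user, sharing one `SymbolTableCollection`. The standalone model below approximates that structure under some loudly stated assumptions:
* A hypothetical `Op` type flattens all of an op's regions into one `nested` list.
* `std::optional<Walk>` replaces `Optional<WalkResult>`.
* Two boolean flags replace the `dyn_cast<SymbolUserOpInterface>` dispatch and its result.

```cpp
#include <functional>
#include <optional>
#include <vector>

enum class Walk { Advance, Interrupt };

// Hypothetical operation. All of its regions are flattened into `nested`.
struct Op {
  bool isSymbolTable = false;  // stands in for OpTrait::SymbolTable
  bool isSymbolUser = false;   // stands in for SymbolUserOpInterface
  bool usesAreValid = true;    // what its verifySymbolUses would return
  std::vector<Op *> nested;
};

// Visit `roots` and everything below them, but never descend into an op that
// starts a new symbol table. Returns the first non-Advance callback result.
std::optional<Walk> walkSymbolTable(
    const std::vector<Op *> &roots,
    const std::function<std::optional<Walk>(Op *)> &callback) {
  std::vector<const std::vector<Op *> *> worklist{&roots};
  while (!worklist.empty()) {
    const std::vector<Op *> *ops = worklist.back();
    worklist.pop_back();
    for (Op *op : *ops) {
      std::optional<Walk> result = callback(op);
      if (result != Walk::Advance)
        return result;  // Interrupt, or nullopt meaning "unknown"
      if (!op->isSymbolTable)
        worklist.push_back(&op->nested);
    }
  }
  return Walk::Advance;
}

// Model of the new tail of verifySymbolTable: verify every symbol user in
// this table's scope and stop at the first failure.
bool verifySymbolUsers(Op &table) {
  std::optional<Walk> result =
      walkSymbolTable(table.nested, [](Op *op) -> std::optional<Walk> {
        if (op->isSymbolUser)
          return op->usesAreValid ? Walk::Advance : Walk::Interrupt;
        return Walk::Advance;
      });
  return result && *result != Walk::Interrupt;
}
```

Stopping at nested symbol tables means each user is verified exactly once, by its closest enclosing table. Sharing a single collection across the walk lets each referenced table's lookup map be built at most once per verification.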


        

