[Mlir-commits] [mlir] bdc3183 - [mlir][python] Implement more SymbolTable methods.

Stella Laurenzo llvmlistbot at llvm.org
Mon Nov 29 20:32:34 PST 2021


Author: Stella Laurenzo
Date: 2021-11-29T20:31:13-08:00
New Revision: bdc3183742f1e996d58bdf23b91966e64ad5e9a3

URL: https://github.com/llvm/llvm-project/commit/bdc3183742f1e996d58bdf23b91966e64ad5e9a3
DIFF: https://github.com/llvm/llvm-project/commit/bdc3183742f1e996d58bdf23b91966e64ad5e9a3.diff

LOG: [mlir][python] Implement more SymbolTable methods.

* set_symbol_name, get_symbol_name, set_visibility, get_visibility, replace_all_symbol_uses, walk_symbol_tables
* In integrations I've been doing, I've been reaching for all of these to do both general IR manipulation and module merging.
* I don't love the APIs underlying replace_all_symbol_uses, since they necessitate SYMBOL_COUNT walks (one IR walk per renamed symbol) and have various sharp edges. I'm hoping that whatever eventually emerges for this can still retain this simple API as a one-shot.

Differential Revision: https://reviews.llvm.org/D114687

Added: 
    mlir/test/python/ir/symbol_table.py

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 6c1a92cea01d0..1d884e634b2f2 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -754,6 +754,9 @@ MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
 /// symbol tables.
 MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName();
 
+/// Returns the name of the attribute used to store symbol visibility. This is
+/// the visibility counterpart of mlirSymbolTableGetSymbolAttributeName.
+MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName();
+
 /// Creates a symbol table for the given operation. If the operation does not
 /// have the SymbolTable trait, returns a null symbol table.
 MLIR_CAPI_EXPORTED MlirSymbolTable
@@ -787,6 +790,23 @@ mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation);
 MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable,
                                              MlirOperation operation);
 
+/// Attempts to replace all uses of the symbol `oldSymbol` that are nested
+/// within the operation `from` with `newSymbol`. This does not traverse into
+/// nested symbol tables. Will fail atomically if there are any unknown
+/// operations that may be potential symbol tables.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(
+    MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from);
+
+/// Walks all symbol table operations nested within, and including, `from`.
+/// For each symbol table operation, `callback` is invoked with the op, a
+/// boolean signifying if the symbols within that symbol table can be treated
+/// as if all uses within the IR are visible to the caller, and `userData`.
+/// `allSymUsesVisible` identifies whether all of the symbol uses of symbols
+/// within `from` are visible.
+MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(
+    MlirOperation from, bool allSymUsesVisible,
+    void (*callback)(MlirOperation, bool, void *userData), void *userData);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8a110fcc4218b..0d349143306c8 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1596,6 +1596,112 @@ PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
 }
 
+/// Returns the symbol name attribute of `symbol`. The op must already carry a
+/// symbol name attribute; otherwise a ValueError is raised.
+PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
+  PyOperation &op = symbol.getOperation();
+  op.checkValid();
+  MlirAttribute nameAttr = mlirOperationGetAttributeByName(
+      op.get(), mlirSymbolTableGetSymbolAttributeName());
+  if (!mlirAttributeIsNull(nameAttr))
+    return PyAttribute(op.getContext(), nameAttr);
+  throw py::value_error("Expected operation to have a symbol name.");
+}
+
+/// Replaces the symbol name of `symbol` with `name`. Only ops that already
+/// carry a symbol name attribute can be renamed; others raise ValueError.
+void PySymbolTable::setSymbolName(PyOperationBase &symbol,
+                                  const std::string &name) {
+  PyOperation &op = symbol.getOperation();
+  op.checkValid();
+  const MlirStringRef symNameKey = mlirSymbolTableGetSymbolAttributeName();
+  const bool isSymbol = !mlirAttributeIsNull(
+      mlirOperationGetAttributeByName(op.get(), symNameKey));
+  if (!isSymbol)
+    throw py::value_error("Expected operation to have a symbol name.");
+  MlirContext ctx = op.getContext()->get();
+  mlirOperationSetAttributeByName(
+      op.get(), symNameKey, mlirStringAttrGet(ctx, toMlirStringRef(name)));
+}
+
+/// Returns the visibility of the symbol op `symbol` as a string attribute.
+/// MLIR treats a symbol without an explicit visibility attribute as public, so
+/// such symbols yield "public". A ValueError is raised only when the op has
+/// neither a visibility attribute nor a symbol name attribute.
+PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
+  PyOperation &operation = symbol.getOperation();
+  operation.checkValid();
+  MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
+  MlirAttribute existingVisAttr =
+      mlirOperationGetAttributeByName(operation.get(), attrName);
+  if (!mlirAttributeIsNull(existingVisAttr))
+    return PyAttribute(operation.getContext(), existingVisAttr);
+  // No explicit visibility: only ops that are not symbols at all are an error.
+  MlirAttribute nameAttr = mlirOperationGetAttributeByName(
+      operation.get(), mlirSymbolTableGetSymbolAttributeName());
+  if (mlirAttributeIsNull(nameAttr))
+    throw py::value_error("Expected operation to have a symbol visibility.");
+  MlirAttribute publicAttr =
+      mlirStringAttrGet(operation.getContext()->get(),
+                        mlirStringRefCreateFromCString("public"));
+  return PyAttribute(operation.getContext(), publicAttr);
+}
+
+/// Sets the visibility of the symbol op `symbol` to `visibility`, which must be
+/// one of "public", "private" or "nested" (anything else raises ValueError).
+/// Symbols without an explicit visibility attribute (implicitly public in
+/// MLIR) are accepted as long as the op carries a symbol name.
+void PySymbolTable::setVisibility(PyOperationBase &symbol,
+                                  const std::string &visibility) {
+  if (visibility != "public" && visibility != "private" &&
+      visibility != "nested")
+    throw py::value_error(
+        "Expected visibility to be 'public', 'private' or 'nested'");
+  PyOperation &operation = symbol.getOperation();
+  operation.checkValid();
+  MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
+  MlirAttribute existingVisAttr =
+      mlirOperationGetAttributeByName(operation.get(), attrName);
+  if (mlirAttributeIsNull(existingVisAttr)) {
+    // No explicit visibility: only reject the op if it is not a symbol at all.
+    MlirAttribute nameAttr = mlirOperationGetAttributeByName(
+        operation.get(), mlirSymbolTableGetSymbolAttributeName());
+    if (mlirAttributeIsNull(nameAttr))
+      throw py::value_error("Expected operation to have a symbol visibility.");
+  }
+  MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
+                                               toMlirStringRef(visibility));
+  mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
+}
+
+/// Replaces all uses of `oldSymbol` with `newSymbol` nested within `from`.
+/// Raises ValueError if the C API reports failure; see
+/// mlirSymbolTableReplaceAllSymbolUses for the caveats.
+void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
+                                         const std::string &newSymbol,
+                                         PyOperationBase &from) {
+  // Validate once and pass that same, checked handle to the C API.
+  PyOperation &fromOperation = from.getOperation();
+  fromOperation.checkValid();
+  MlirLogicalResult result = mlirSymbolTableReplaceAllSymbolUses(
+      toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
+      fromOperation.get());
+  if (mlirLogicalResultIsFailure(result))
+    throw py::value_error("Symbol rename failed");
+}
+
+/// Walks all symbol tables under and including `from`, calling the Python
+/// `callback(op, usesVisible)` for each one. Exceptions raised while handling
+/// an op are recorded rather than propagated (they must not unwind through
+/// the C API frames). Once one is recorded, further callbacks are skipped and
+/// the error is rethrown as std::runtime_error after the walk has finished.
+void PySymbolTable::walkSymbolTables(PyOperationBase &from,
+                                     bool allSymUsesVisible,
+                                     py::object callback) {
+  PyOperation &fromOperation = from.getOperation();
+  fromOperation.checkValid();
+  struct UserData {
+    PyMlirContextRef context;
+    py::object callback;
+    bool gotException;
+    std::string exceptionWhat;
+  };
+  UserData userData{fromOperation.getContext(), std::move(callback), false,
+                    {}};
+  mlirSymbolTableWalkSymbolTables(
+      fromOperation.get(), allSymUsesVisible,
+      [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
+        UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
+        // The C callback returns void, so the walk cannot be stopped early:
+        // after a failure, bail out before creating any more Python wrappers.
+        if (calleeUserData->gotException)
+          return;
+        try {
+          auto pyFoundOp =
+              PyOperation::forOperation(calleeUserData->context, foundOp);
+          calleeUserData->callback(pyFoundOp.getObject(), isVisible);
+        } catch (py::error_already_set &e) {
+          calleeUserData->gotException = true;
+          calleeUserData->exceptionWhat = e.what();
+        } catch (std::exception &e) {
+          // Any other C++ exception must not cross the C API boundary either.
+          calleeUserData->gotException = true;
+          calleeUserData->exceptionWhat = e.what();
+        }
+      },
+      static_cast<void *>(&userData));
+  if (userData.gotException) {
+    std::string message("Exception raised in callback: ");
+    message.append(userData.exceptionWhat);
+    throw std::runtime_error(std::move(message));
+  }
+}
+
 namespace {
 /// CRTP base class for Python MLIR values that subclass Value and should be
 /// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2773,10 +2879,26 @@ void mlir::python::populateIRCore(py::module &m) {
       .def("insert", &PySymbolTable::insert, py::arg("operation"))
       .def("erase", &PySymbolTable::erase, py::arg("operation"))
       .def("__delitem__", &PySymbolTable::dunderDel)
-      .def("__contains__", [](PySymbolTable &table, const std::string &name) {
-        return !mlirOperationIsNull(mlirSymbolTableLookup(
-            table, mlirStringRefCreate(name.data(), name.length())));
-      });
+      .def("__contains__",
+           [](PySymbolTable &table, const std::string &name) {
+             return !mlirOperationIsNull(mlirSymbolTableLookup(
+                 table, mlirStringRefCreate(name.data(), name.length())));
+           })
+      // Static helpers: called as SymbolTable.<name>(op, ...), no instance.
+      .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
+                  py::arg("symbol"), py::arg("name"))
+      .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
+                  py::arg("symbol"))
+      .def_static("get_visibility", &PySymbolTable::getVisibility,
+                  py::arg("symbol"))
+      .def_static("set_visibility", &PySymbolTable::setVisibility,
+                  py::arg("symbol"), py::arg("visibility"))
+      .def_static("replace_all_symbol_uses",
+                  &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
+                  py::arg("new_symbol"), py::arg("from_op"))
+      .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
+                  py::arg("from_op"), py::arg("all_sym_uses_visible"),
+                  py::arg("callback"));
 
   // Container bindings.
   PyBlockArgumentList::bind(m);

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index f0d0cc654eabb..d5e8eb4aece55 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -910,6 +910,25 @@ class PySymbolTable {
   /// the symbol trait.
   PyAttribute insert(PyOperationBase &symbol);
 
+  /// Gets and sets the symbol name attribute of a symbol op. Both require the
+  /// op to already carry a symbol name and raise ValueError otherwise.
+  static PyAttribute getSymbolName(PyOperationBase &symbol);
+  static void setSymbolName(PyOperationBase &symbol, const std::string &name);
+
+  /// Gets and sets the visibility attribute of a symbol op. `visibility` must
+  /// be one of "public", "private" or "nested"; other values raise ValueError.
+  static PyAttribute getVisibility(PyOperationBase &symbol);
+  static void setVisibility(PyOperationBase &symbol,
+                            const std::string &visibility);
+
+  /// Replaces all uses of `oldSymbol` with `newSymbol` nested within `from`,
+  /// raising ValueError if the rename fails. See the API
+  /// mlirSymbolTableReplaceAllSymbolUses for all caveats.
+  static void replaceAllSymbolUses(const std::string &oldSymbol,
+                                   const std::string &newSymbol,
+                                   PyOperationBase &from);
+
+  /// Walks all symbol tables under and including 'from', invoking
+  /// `callback(op, usesVisible)` for each one. If the callback raises, it is
+  /// not invoked again and the error is rethrown as a std::runtime_error once
+  /// the walk has completed.
+  static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
+                               pybind11::object callback);
+
   /// Casts the bindings class into the C API structure.
   operator MlirSymbolTable() { return symbolTable; }
 

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 35a059275ffb2..424bbae179c33 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -786,6 +786,10 @@ MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
   return wrap(SymbolTable::getSymbolAttrName());
 }
 
+MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
+  // Mirrors mlirSymbolTableGetSymbolAttributeName for the visibility attribute.
+  llvm::StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
+  return wrap(visibilityAttrName);
+}
+
 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
   if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>())
     return wrap(static_cast<SymbolTable *>(nullptr));
@@ -810,3 +814,25 @@ void mlirSymbolTableErase(MlirSymbolTable symbolTable,
                           MlirOperation operation) {
   unwrap(symbolTable)->erase(unwrap(operation));
 }
+
+/// Replaces uses of `oldSymbol` with `newSymbol` nested within `from`. Both
+/// names are turned into StringAttrs in the context that owns `from`.
+MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
+                                                      MlirStringRef newSymbol,
+                                                      MlirOperation from) {
+  Operation *cppFrom = unwrap(from);
+  MLIRContext *context = cppFrom->getContext();
+  auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
+  auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
+  return wrap(
+      SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, cppFrom));
+}
+
+void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
+                                     void (*callback)(MlirOperation, bool,
+                                                      void *userData),
+                                     void *userData) {
+  // Adapt the C callback and its opaque user data to the C++ walker.
+  auto forwardToC = [callback, userData](Operation *symbolTableOp,
+                                         bool allUsesVisible) {
+    callback(wrap(symbolTableOp), allUsesVisible, userData);
+  };
+  SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible, forwardToC);
+}

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 47ebeb291c35f..3c7653feb75a9 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -7,7 +7,7 @@
 #   * Local edits to signatures and types that MyPy did not auto detect (or
 #     detected incorrectly).
 
-from typing import Any, ClassVar, List, Optional
+from typing import Any, Callable, ClassVar, List, Optional
 
 from typing import overload
 
@@ -90,38 +90,34 @@ __all__ = [
     "_OperationBase",
 ]
 
-
-class AffineAddExpr(AffineBinaryExpr):
-    def __init__(self, expr: AffineExpr) -> None: ...
-    def get(self, *args, **kwargs) -> Any: ...
-    def isinstance(self, *args, **kwargs) -> Any: ...
-
-class AffineBinaryExpr(AffineExpr):
-    def __init__(self, expr: AffineExpr) -> None: ...
-    def isinstance(self, *args, **kwargs) -> Any: ...
+# Base classes are declared first, ahead of the classes that derive from them.
+class _OperationBase:
+    def __init__(self, *args, **kwargs) -> None: ...
+    def detach_from_parent(self) -> object: ...
+    def get_asm(self, binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> object: ...
+    def move_after(self, other: _OperationBase) -> None: ...
+    def move_before(self, other: _OperationBase) -> None: ...
+    def print(self, file: object = ..., binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> None: ...
+    def verify(self) -> bool: ...
+    @overload
+    def __eq__(self, arg0: _OperationBase) -> bool: ...
+    @overload
+    def __eq__(self, arg0: object) -> bool: ...
+    def __hash__(self) -> int: ...
     @property
-    def lhs(self) -> AffineExpr: ...
+    def _CAPIPtr(self) -> object: ...
     @property
-    def rhs(self) -> AffineExpr: ...
-
-class AffineCeilDivExpr(AffineBinaryExpr):
-    def __init__(self, expr: AffineExpr) -> None: ...
-    def get(self, *args, **kwargs) -> Any: ...
-    def isinstance(self, *args, **kwargs) -> Any: ...
-
-class AffineConstantExpr(AffineExpr):
-    def __init__(self, expr: AffineExpr) -> None: ...
-    def get(self, *args, **kwargs) -> Any: ...
-    def isinstance(self, *args, **kwargs) -> Any: ...
+    def attributes(self) -> Any: ...
     @property
-    def value(self) -> int: ...
-
-class AffineDimExpr(AffineExpr):
-    def __init__(self, expr: AffineExpr) -> None: ...
-    def get(self, *args, **kwargs) -> Any: ...
-    def isinstance(self, *args, **kwargs) -> Any: ...
+    def location(self) -> Location: ...
     @property
-    def position(self) -> int: ...
+    def operands(self) -> Any: ...
+    @property
+    def regions(self) -> Any: ...
+    @property
+    def result(self) -> Any: ...
+    @property
+    def results(self) -> Any: ...
 
 class AffineExpr:
     def __init__(self, *args, **kwargs) -> None: ...
@@ -154,6 +150,91 @@ class AffineExpr:
     @property
     def context(self) -> object: ...
 
+class Attribute:
+    def __init__(self, cast_from_type: Attribute) -> None: ...
+    def _CAPICreate(self) -> Attribute: ...
+    def dump(self) -> None: ...
+    def get_named(self, *args, **kwargs) -> Any: ...
+    def parse(self, *args, **kwargs) -> Any: ...
+    @overload
+    def __eq__(self, arg0: Attribute) -> bool: ...
+    @overload
+    def __eq__(self, arg0: object) -> bool: ...
+    def __hash__(self) -> int: ...
+    @property
+    def _CAPIPtr(self) -> object: ...
+    @property
+    def context(self) -> object: ...
+    @property
+    def type(self) -> Any: ...
+
+class Type:
+    def __init__(self, cast_from_type: Type) -> None: ...
+    def _CAPICreate(self) -> Type: ...
+    def dump(self) -> None: ...
+    def parse(self, *args, **kwargs) -> Any: ...
+    @overload
+    def __eq__(self, arg0: Type) -> bool: ...
+    @overload
+    def __eq__(self, arg0: object) -> bool: ...
+    def __hash__(self) -> int: ...
+    @property
+    def _CAPIPtr(self) -> object: ...
+    @property
+    def context(self) -> object: ...
+
+class Value:
+    def __init__(self, *args, **kwargs) -> None: ...
+    def _CAPICreate(self) -> Value: ...
+    def dump(self) -> None: ...
+    @overload
+    def __eq__(self, arg0: Value) -> bool: ...
+    @overload
+    def __eq__(self, arg0: object) -> bool: ...
+    def __hash__(self) -> int: ...
+    @property
+    def _CAPIPtr(self) -> object: ...
+    @property
+    def context(self) -> Any: ...
+    @property
+    def owner(self) -> object: ...
+    @property
+    def type(self) -> Type: ...
+
+
+# The remaining classes have no ordering constraints and are listed alphabetically.
+class AffineAddExpr(AffineBinaryExpr):
+    def __init__(self, expr: AffineExpr) -> None: ...
+    def get(self, *args, **kwargs) -> Any: ...
+    def isinstance(self, *args, **kwargs) -> Any: ...
+
+class AffineBinaryExpr(AffineExpr):
+    def __init__(self, expr: AffineExpr) -> None: ...
+    def isinstance(self, *args, **kwargs) -> Any: ...
+    @property
+    def lhs(self) -> AffineExpr: ...
+    @property
+    def rhs(self) -> AffineExpr: ...
+
+class AffineCeilDivExpr(AffineBinaryExpr):
+    def __init__(self, expr: AffineExpr) -> None: ...
+    def get(self, *args, **kwargs) -> Any: ...
+    def isinstance(self, *args, **kwargs) -> Any: ...
+
+class AffineConstantExpr(AffineExpr):
+    def __init__(self, expr: AffineExpr) -> None: ...
+    def get(self, *args, **kwargs) -> Any: ...
+    def isinstance(self, *args, **kwargs) -> Any: ...
+    @property
+    def value(self) -> int: ...
+
+class AffineDimExpr(AffineExpr):
+    def __init__(self, expr: AffineExpr) -> None: ...
+    def get(self, *args, **kwargs) -> Any: ...
+    def isinstance(self, *args, **kwargs) -> Any: ...
+    @property
+    def position(self) -> int: ...
+
 class AffineExprList:
     def __init__(self, *args, **kwargs) -> None: ...
     def __add__(self, arg0: AffineExprList) -> List[AffineExpr]: ...
@@ -245,24 +326,6 @@ class ArrayAttributeIterator:
     def __iter__(self) -> ArrayAttributeIterator: ...
     def __next__(self) -> Attribute: ...
 
-class Attribute:
-    def __init__(self, cast_from_type: Attribute) -> None: ...
-    def _CAPICreate(self) -> Attribute: ...
-    def dump(self) -> None: ...
-    def get_named(self, *args, **kwargs) -> Any: ...
-    def parse(self, *args, **kwargs) -> Any: ...
-    @overload
-    def __eq__(self, arg0: Attribute) -> bool: ...
-    @overload
-    def __eq__(self, arg0: object) -> bool: ...
-    def __hash__(self) -> int: ...
-    @property
-    def _CAPIPtr(self) -> object: ...
-    @property
-    def context(self) -> object: ...
-    @property
-    def type(self) -> Any: ...
-
 class BF16Type(Type):
     def __init__(self, cast_from_type: Type) -> None: ...
     def get(self, *args, **kwargs) -> Any: ...
@@ -751,7 +814,19 @@ class StringAttr(Attribute):
 class SymbolTable:
     def __init__(self, arg0: _OperationBase) -> None: ...
     def erase(self, operation: _OperationBase) -> None: ...
+    @staticmethod
+    def get_symbol_name(symbol: _OperationBase) -> Attribute: ...
+    @staticmethod
+    def get_visibility(symbol: _OperationBase) -> Attribute: ...
     def insert(self, operation: _OperationBase) -> Attribute: ...
+    @staticmethod
+    def replace_all_symbol_uses(old_symbol: str, new_symbol: str, from_op: _OperationBase) -> None: ...
+    @staticmethod
+    def set_symbol_name(symbol: _OperationBase, name: str) -> None: ...
+    @staticmethod
+    def set_visibility(symbol: _OperationBase, visibility: str) -> None: ...
+    @staticmethod
+    def walk_symbol_tables(from_op: _OperationBase, all_sym_uses_visible: bool, callback: Callable[[_OperationBase, bool], None) -> None: ...
     def __contains__(self, arg0: str) -> bool: ...
     def __delitem__(self, arg0: str) -> None: ...
     def __getitem__(self, arg0: str) -> object: ...
@@ -764,21 +839,6 @@ class TupleType(Type):
     @property
     def num_types(self) -> int: ...
 
-class Type:
-    def __init__(self, cast_from_type: Type) -> None: ...
-    def _CAPICreate(self) -> Type: ...
-    def dump(self) -> None: ...
-    def parse(self, *args, **kwargs) -> Any: ...
-    @overload
-    def __eq__(self, arg0: Type) -> bool: ...
-    @overload
-    def __eq__(self, arg0: object) -> bool: ...
-    def __hash__(self) -> int: ...
-    @property
-    def _CAPIPtr(self) -> object: ...
-    @property
-    def context(self) -> object: ...
-
 class TypeAttr(Attribute):
     def __init__(self, cast_from_attr: Attribute) -> None: ...
     def get(self, *args, **kwargs) -> Any: ...
@@ -807,24 +867,6 @@ class UnrankedTensorType(ShapedType):
     def get(self, *args, **kwargs) -> Any: ...
     def isinstance(self, *args, **kwargs) -> Any: ...
 
-class Value:
-    def __init__(self, *args, **kwargs) -> None: ...
-    def _CAPICreate(self) -> Value: ...
-    def dump(self) -> None: ...
-    @overload
-    def __eq__(self, arg0: Value) -> bool: ...
-    @overload
-    def __eq__(self, arg0: object) -> bool: ...
-    def __hash__(self) -> int: ...
-    @property
-    def _CAPIPtr(self) -> object: ...
-    @property
-    def context(self) -> Any: ...
-    @property
-    def owner(self) -> object: ...
-    @property
-    def type(self) -> Type: ...
-
 class VectorType(ShapedType):
     def __init__(self, cast_from_type: Type) -> None: ...
     def get(self, *args, **kwargs) -> Any: ...
@@ -833,31 +875,3 @@ class VectorType(ShapedType):
 class _GlobalDebug:
     flag: ClassVar[bool] = ...
     def __init__(self, *args, **kwargs) -> None: ...
-
-class _OperationBase:
-    def __init__(self, *args, **kwargs) -> None: ...
-    def detach_from_parent(self) -> object: ...
-    def get_asm(self, binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> object: ...
-    def move_after(self, other: _OperationBase) -> None: ...
-    def move_before(self, other: _OperationBase) -> None: ...
-    def print(self, file: object = ..., binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> None: ...
-    def verify(self) -> bool: ...
-    @overload
-    def __eq__(self, arg0: _OperationBase) -> bool: ...
-    @overload
-    def __eq__(self, arg0: object) -> bool: ...
-    def __hash__(self) -> int: ...
-    @property
-    def _CAPIPtr(self) -> object: ...
-    @property
-    def attributes(self) -> Any: ...
-    @property
-    def location(self) -> Location: ...
-    @property
-    def operands(self) -> Any: ...
-    @property
-    def regions(self) -> Any: ...
-    @property
-    def result(self) -> Any: ...
-    @property
-    def results(self) -> Any: ...

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 133edc2e1aee5..db8acc82ac3ab 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -835,79 +835,6 @@ def testDetachFromParent():
     # CHECK-NOT: func private @foo
 
 
-# CHECK-LABEL: TEST: testSymbolTable
- at run
-def testSymbolTable():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-    m1 = Module.parse("""
-      func private @foo()
-      func private @bar()""")
-    m2 = Module.parse("""
-      func private @qux()
-      func private @foo()
-      "foo.bar"() : () -> ()""")
-
-    symbol_table = SymbolTable(m1.operation)
-
-    # CHECK: func private @foo
-    # CHECK: func private @bar
-    assert "foo" in symbol_table
-    print(symbol_table["foo"])
-    assert "bar" in symbol_table
-    bar = symbol_table["bar"]
-    print(symbol_table["bar"])
-
-    assert "qux" not in symbol_table
-
-    del symbol_table["bar"]
-    try:
-      symbol_table.erase(symbol_table["bar"])
-    except KeyError:
-      pass
-    else:
-      assert False, "expected KeyError"
-
-    # CHECK: module
-    # CHECK:   func private @foo()
-    print(m1)
-    assert "bar" not in symbol_table
-
-    try:
-      print(bar)
-    except RuntimeError as e:
-      if "the operation has been invalidated" not in str(e):
-        raise
-    else:
-      assert False, "expected RuntimeError due to invalidated operation"
-
-    qux = m2.body.operations[0]
-    m1.body.append(qux)
-    symbol_table.insert(qux)
-    assert "qux" in symbol_table
-
-    # Check that insertion actually renames this symbol in the symbol table.
-    foo2 = m2.body.operations[0]
-    m1.body.append(foo2)
-    updated_name = symbol_table.insert(foo2)
-    assert foo2.name.value != "foo"
-    assert foo2.name == updated_name
-
-    # CHECK: module
-    # CHECK:   func private @foo()
-    # CHECK:   func private @qux()
-    # CHECK:   func private @foo{{.*}}
-    print(m1)
-
-    try:
-      symbol_table.insert(m2.body.operations[0])
-    except ValueError as e:
-      if "Expected operation to have a symbol name" not in str(e):
-        raise
-    else:
-      assert False, "exepcted ValueError when adding a non-symbol"
-
-
 # CHECK-LABEL: TEST: testOperationHash
 @run
 def testOperationHash():

diff  --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
new file mode 100644
index 0000000000000..af8eafb605279
--- /dev/null
+++ b/mlir/test/python/ir/symbol_table.py
@@ -0,0 +1,156 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+import io
+import itertools
+from mlir.ir import *
+
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+  return f
+
+
+# CHECK-LABEL: TEST: testSymbolTableInsert
+ at run
+def testSymbolTableInsert():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    m1 = Module.parse("""
+      func private @foo()
+      func private @bar()""")
+    m2 = Module.parse("""
+      func private @qux()
+      func private @foo()
+      "foo.bar"() : () -> ()""")
+
+    symbol_table = SymbolTable(m1.operation)
+
+    # CHECK: func private @foo
+    # CHECK: func private @bar
+    assert "foo" in symbol_table
+    print(symbol_table["foo"])
+    assert "bar" in symbol_table
+    bar = symbol_table["bar"]
+    print(symbol_table["bar"])
+
+    assert "qux" not in symbol_table
+
+    del symbol_table["bar"]
+    try:
+      symbol_table.erase(symbol_table["bar"])
+    except KeyError:
+      pass
+    else:
+      assert False, "expected KeyError"
+
+    # CHECK: module
+    # CHECK:   func private @foo()
+    print(m1)
+    assert "bar" not in symbol_table
+
+    try:
+      print(bar)
+    except RuntimeError as e:
+      if "the operation has been invalidated" not in str(e):
+        raise
+    else:
+      assert False, "expected RuntimeError due to invalidated operation"
+
+    qux = m2.body.operations[0]
+    m1.body.append(qux)
+    symbol_table.insert(qux)
+    assert "qux" in symbol_table
+
+    # Check that insertion actually renames this symbol in the symbol table.
+    foo2 = m2.body.operations[0]
+    m1.body.append(foo2)
+    updated_name = symbol_table.insert(foo2)
+    assert foo2.name.value != "foo"
+    assert foo2.name == updated_name
+
+    # CHECK: module
+    # CHECK:   func private @foo()
+    # CHECK:   func private @qux()
+    # CHECK:   func private @foo{{.*}}
+    print(m1)
+
+    try:
+      symbol_table.insert(m2.body.operations[0])
+    except ValueError as e:
+      if "Expected operation to have a symbol name" not in str(e):
+        raise
+    else:
+      assert False, "exepcted ValueError when adding a non-symbol"
+
+
+# CHECK-LABEL: testSymbolTableRAUW
+ at run
+def testSymbolTableRAUW():
+  with Context() as ctx:
+    m = Module.parse("""
+      func private @foo() {
+        call @bar() : () -> ()
+        return
+      }
+      func private @bar()
+      """)
+    foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
+    SymbolTable.set_symbol_name(bar, "bam")
+    # Note that module.operation counts as a "nested symbol table" which won't
+    # be traversed into, so it is necessary to traverse its children.
+    SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
+    # CHECK: call @bam()
+    # CHECK: func private @bam
+    print(m)
+    # CHECK: Foo symbol: "foo"
+    # CHECK: Bar symbol: "bam"
+    print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
+    print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")
+
+
+# CHECK-LABEL: testSymbolTableVisibility
+ at run
+def testSymbolTableVisibility():
+  with Context() as ctx:
+    m = Module.parse("""
+      func private @foo() {
+        return
+      }
+      """)
+    foo = m.operation.regions[0].blocks[0].operations[0]
+    # CHECK: Existing visibility: "private"
+    print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
+    SymbolTable.set_visibility(foo, "public")
+    # CHECK: func public @foo
+    print(m)
+
+
+# CHECK: testWalkSymbolTables
+ at run
+def testWalkSymbolTables():
+  with Context() as ctx:
+    m = Module.parse("""
+      module @outer {
+        module @inner{
+        }
+      }
+      """)
+    def callback(symbol_table_op, uses_visible):
+      print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")
+    # CHECK: SYMBOL TABLE: True: module @inner
+    # CHECK: SYMBOL TABLE: True: module @outer
+    SymbolTable.walk_symbol_tables(m.operation, True, callback)
+
+    # Make sure exceptions in the callback are handled.
+    def error_callback(symbol_table_op, uses_visible):
+      assert False, "Raised from python"
+    try:
+      SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
+    except RuntimeError as e:
+      # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
+      print(f"GOT EXCEPTION: {e}")
+


        


More information about the Mlir-commits mailing list