[Mlir-commits] [mlir] bdc3183 - [mlir][python] Implement more SymbolTable methods.
Stella Laurenzo
llvmlistbot at llvm.org
Mon Nov 29 20:32:34 PST 2021
Author: Stella Laurenzo
Date: 2021-11-29T20:31:13-08:00
New Revision: bdc3183742f1e996d58bdf23b91966e64ad5e9a3
URL: https://github.com/llvm/llvm-project/commit/bdc3183742f1e996d58bdf23b91966e64ad5e9a3
DIFF: https://github.com/llvm/llvm-project/commit/bdc3183742f1e996d58bdf23b91966e64ad5e9a3.diff
LOG: [mlir][python] Implement more SymbolTable methods.
* set_symbol_name, get_symbol_name, set_visibility, get_visibility, replace_all_symbol_uses, walk_symbol_tables
* In integrations I've been doing, I've been reaching for all of these to do both general IR manipulation and module merging.
* I don't love the replace_all_symbol_uses underlying APIs since they necessitate SYMBOL_COUNT walks and have various sharp edges. I'm hoping that whatever emerges eventually for this can still retain this simple API as a one-shot.
Differential Revision: https://reviews.llvm.org/D114687
Added:
mlir/test/python/ir/symbol_table.py
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/CAPI/IR/IR.cpp
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 6c1a92cea01d0..1d884e634b2f2 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -754,6 +754,9 @@ MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
/// symbol tables.
MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName();
+/// Returns the name of the attribute used to store symbol visibility.
+MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName();
+
/// Creates a symbol table for the given operation. If the operation does not
/// have the SymbolTable trait, returns a null symbol table.
MLIR_CAPI_EXPORTED MlirSymbolTable
@@ -787,6 +790,23 @@ mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation);
MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable,
MlirOperation operation);
+/// Attempts to replace all uses of the symbol 'oldSymbol' that are nested
+/// within the operation 'from' with 'newSymbol'. This does not traverse into
+/// nested symbol tables. Will fail atomically (returning a failure result) if
+/// there are any unknown operations that may be potential symbol tables.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(
+ MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from);
+
+/// Walks all symbol table operations nested within, and including, `from`.
+/// For each symbol table operation, the provided callback is invoked with the
+/// op, a boolean signifying if the symbols within that symbol table can be
+/// treated as if all uses within the IR are visible to the caller, and the
+/// `userData` pointer passed through unchanged. `allSymUsesVisible`
+/// identifies whether all of the symbol uses of symbols within `from` are
+/// visible. The callback returns nothing, so it cannot stop the walk early.
+MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables(
+ MlirOperation from, bool allSymUsesVisible,
+ void (*callback)(MlirOperation, bool, void *userData), void *userData);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8a110fcc4218b..0d349143306c8 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1596,6 +1596,112 @@ PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
}
+/// Checks that `operation` is still valid and returns its attribute named
+/// `attrName`. Raises a Python ValueError carrying `missingMessage` when the
+/// operation does not have that attribute.
+static MlirAttribute getRequiredSymbolAttr(PyOperation &operation,
+                                           MlirStringRef attrName,
+                                           const char *missingMessage) {
+  operation.checkValid();
+  MlirAttribute attr =
+      mlirOperationGetAttributeByName(operation.get(), attrName);
+  if (mlirAttributeIsNull(attr))
+    throw py::value_error(missingMessage);
+  return attr;
+}
+
+PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
+  // Op must already be a symbol.
+  PyOperation &operation = symbol.getOperation();
+  MlirAttribute existingNameAttr = getRequiredSymbolAttr(
+      operation, mlirSymbolTableGetSymbolAttributeName(),
+      "Expected operation to have a symbol name.");
+  return PyAttribute(operation.getContext(), existingNameAttr);
+}
+
+void PySymbolTable::setSymbolName(PyOperationBase &symbol,
+                                  const std::string &name) {
+  // Op must already be a symbol: this renames an existing symbol, it does not
+  // turn a non-symbol op into one.
+  PyOperation &operation = symbol.getOperation();
+  MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
+  getRequiredSymbolAttr(operation, attrName,
+                        "Expected operation to have a symbol name.");
+  MlirAttribute newNameAttr =
+      mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
+  mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
+}
+
+// NOTE(review): both visibility accessors require the visibility attribute to
+// already be present. MLIR's symbol documentation describes "public" as the
+// default visibility, which may not be materialized as an attribute, so these
+// likely raise for such symbols -- confirm this is the intended contract.
+PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
+  PyOperation &operation = symbol.getOperation();
+  MlirAttribute existingVisAttr = getRequiredSymbolAttr(
+      operation, mlirSymbolTableGetVisibilityAttributeName(),
+      "Expected operation to have a symbol visibility.");
+  return PyAttribute(operation.getContext(), existingVisAttr);
+}
+
+void PySymbolTable::setVisibility(PyOperationBase &symbol,
+                                  const std::string &visibility) {
+  // Validate the requested value before touching the operation.
+  if (visibility != "public" && visibility != "private" &&
+      visibility != "nested")
+    throw py::value_error(
+        "Expected visibility to be 'public', 'private' or 'nested'");
+  PyOperation &operation = symbol.getOperation();
+  MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
+  getRequiredSymbolAttr(operation, attrName,
+                        "Expected operation to have a symbol visibility.");
+  MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
+                                               toMlirStringRef(visibility));
+  mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
+}
+
+/// Renames every use of `oldSymbol` nested under `from` to `newSymbol` (see
+/// mlirSymbolTableReplaceAllSymbolUses for traversal caveats). Raises a
+/// Python ValueError if the C API reports failure.
+void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
+                                         const std::string &newSymbol,
+                                         PyOperationBase &from) {
+  PyOperation &fromOperation = from.getOperation();
+  fromOperation.checkValid();
+  // Use the operation that was just validated rather than re-fetching it.
+  if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
+          toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
+          fromOperation.get())))
+    throw py::value_error("Symbol rename failed");
+}
+
+/// Walks all symbol tables under and including `from`, calling
+/// `callback(op, isVisible)` for each. If the callback raises, the remaining
+/// tables are skipped and a std::runtime_error (RuntimeError in Python)
+/// carrying the original message is thrown after the walk completes.
+void PySymbolTable::walkSymbolTables(PyOperationBase &from,
+                                     bool allSymUsesVisible,
+                                     py::object callback) {
+  PyOperation &fromOperation = from.getOperation();
+  fromOperation.checkValid();
+  // State shared with the C callback. Exceptions should not unwind through
+  // the C API, so the first failure is recorded here and re-raised once the
+  // walk has returned.
+  struct UserData {
+    PyMlirContextRef context;
+    py::object callback;
+    bool gotException;
+    std::string exceptionWhat;
+  };
+  UserData userData{fromOperation.getContext(), std::move(callback), false,
+                    {}};
+  mlirSymbolTableWalkSymbolTables(
+      fromOperation.get(), allSymUsesVisible,
+      [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
+        UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
+        // The C API has no way to abort the walk, so after a failure skip the
+        // remaining tables without creating any more Python wrappers.
+        if (calleeUserData->gotException)
+          return;
+        try {
+          auto pyFoundOp =
+              PyOperation::forOperation(calleeUserData->context, foundOp);
+          calleeUserData->callback(pyFoundOp.getObject(), isVisible);
+        } catch (const std::exception &e) {
+          // Covers py::error_already_set (a Python exception raised by the
+          // callback) as well as any other C++ exception from the bindings.
+          calleeUserData->gotException = true;
+          calleeUserData->exceptionWhat = e.what();
+        }
+      },
+      static_cast<void *>(&userData));
+  if (userData.gotException) {
+    std::string message("Exception raised in callback: ");
+    message.append(userData.exceptionWhat);
+    throw std::runtime_error(std::move(message));
+  }
+}
+
namespace {
/// CRTP base class for Python MLIR values that subclass Value and should be
/// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2773,10 +2879,26 @@ void mlir::python::populateIRCore(py::module &m) {
.def("insert", &PySymbolTable::insert, py::arg("operation"))
.def("erase", &PySymbolTable::erase, py::arg("operation"))
.def("__delitem__", &PySymbolTable::dunderDel)
- .def("__contains__", [](PySymbolTable &table, const std::string &name) {
- return !mlirOperationIsNull(mlirSymbolTableLookup(
- table, mlirStringRefCreate(name.data(), name.length())));
- });
+ .def("__contains__",
+ [](PySymbolTable &table, const std::string &name) {
+ return !mlirOperationIsNull(mlirSymbolTableLookup(
+ table, mlirStringRefCreate(name.data(), name.length())));
+ })
+ // Static helpers.
+ .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
+ py::arg("symbol"), py::arg("name"))
+ .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
+ py::arg("symbol"))
+ .def_static("get_visibility", &PySymbolTable::getVisibility,
+ py::arg("symbol"))
+ .def_static("set_visibility", &PySymbolTable::setVisibility,
+ py::arg("symbol"), py::arg("visibility"))
+ .def_static("replace_all_symbol_uses",
+ &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
+ py::arg("new_symbol"), py::arg("from_op"))
+ .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
+ py::arg("from_op"), py::arg("all_sym_uses_visible"),
+ py::arg("callback"));
// Container bindings.
PyBlockArgumentList::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index f0d0cc654eabb..d5e8eb4aece55 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -910,6 +910,25 @@ class PySymbolTable {
/// the symbol trait.
PyAttribute insert(PyOperationBase &symbol);
+ /// Gets and sets the name of a symbol op. Both raise a Python ValueError if
+ /// the op has no symbol name attribute; the setter only renames an op that
+ /// is already a symbol.
+ static PyAttribute getSymbolName(PyOperationBase &symbol);
+ static void setSymbolName(PyOperationBase &symbol, const std::string &name);
+
+ /// Gets and sets the visibility of a symbol op. Both raise a Python
+ /// ValueError if the op has no visibility attribute; the setter also raises
+ /// unless `visibility` is "public", "private" or "nested".
+ static PyAttribute getVisibility(PyOperationBase &symbol);
+ static void setVisibility(PyOperationBase &symbol,
+ const std::string &visibility);
+
+ /// Replaces all symbol uses within an operation. See the API
+ /// mlirSymbolTableReplaceAllSymbolUses for all caveats. Raises a Python
+ /// ValueError if the replacement fails.
+ static void replaceAllSymbolUses(const std::string &oldSymbol,
+ const std::string &newSymbol,
+ PyOperationBase &from);
+
+ /// Walks all symbol tables under and including 'from', invoking
+ /// `callback(op, isVisible)` for each one. If the callback raises, the
+ /// remaining tables are skipped and a std::runtime_error (RuntimeError in
+ /// Python) is thrown once the walk finishes.
+ static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
+ pybind11::object callback);
+
/// Casts the bindings class into the C API structure.
operator MlirSymbolTable() { return symbolTable; }
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 35a059275ffb2..424bbae179c33 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -786,6 +786,10 @@ MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
return wrap(SymbolTable::getSymbolAttrName());
}
+// Counterpart of mlirSymbolTableGetSymbolAttributeName: exposes the name of
+// the attribute in which SymbolTable stores a symbol's visibility.
+MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
+ return wrap(SymbolTable::getVisibilityAttrName());
+}
+
MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>())
return wrap(static_cast<SymbolTable *>(nullptr));
@@ -810,3 +814,25 @@ void mlirSymbolTableErase(MlirSymbolTable symbolTable,
MlirOperation operation) {
unwrap(symbolTable)->erase(unwrap(operation));
}
+
+MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
+                                                      MlirStringRef newSymbol,
+                                                      MlirOperation from) {
+  auto *cppFrom = unwrap(from);
+  auto *context = cppFrom->getContext();
+  // The C++ API takes the symbol names as StringAttrs; create them in the
+  // context that owns `from`.
+  auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
+  auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
+  return wrap(
+      SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, cppFrom));
+}
+
+void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
+                                     void (*callback)(MlirOperation, bool,
+                                                      void *userData),
+                                     void *userData) {
+  // Adapts the C callback and its opaque user data to the callable expected
+  // by the C++ walker; userData is forwarded untouched on every invocation.
+  auto forwardToCallback = [callback, userData](Operation *symbolTableOp,
+                                                bool allUsesVisible) {
+    callback(wrap(symbolTableOp), allUsesVisible, userData);
+  };
+  SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible,
+                                forwardToCallback);
+}
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 47ebeb291c35f..3c7653feb75a9 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -7,7 +7,7 @@
# * Local edits to signatures and types that MyPy did not auto detect (or
# detected incorrectly).
-from typing import Any, ClassVar, List, Optional
+from typing import Any, Callable, ClassVar, List, Optional
from typing import overload
@@ -90,38 +90,34 @@ __all__ = [
"_OperationBase",
]
-
-class AffineAddExpr(AffineBinaryExpr):
- def __init__(self, expr: AffineExpr) -> None: ...
- def get(self, *args, **kwargs) -> Any: ...
- def isinstance(self, *args, **kwargs) -> Any: ...
-
-class AffineBinaryExpr(AffineExpr):
- def __init__(self, expr: AffineExpr) -> None: ...
- def isinstance(self, *args, **kwargs) -> Any: ...
+# Base classes: declared first to simplify declarations below.
+class _OperationBase:
+ def __init__(self, *args, **kwargs) -> None: ...
+ def detach_from_parent(self) -> object: ...
+ def get_asm(self, binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> object: ...
+ def move_after(self, other: _OperationBase) -> None: ...
+ def move_before(self, other: _OperationBase) -> None: ...
+ def print(self, file: object = ..., binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> None: ...
+ def verify(self) -> bool: ...
+ @overload
+ def __eq__(self, arg0: _OperationBase) -> bool: ...
+ @overload
+ def __eq__(self, arg0: object) -> bool: ...
+ def __hash__(self) -> int: ...
@property
- def lhs(self) -> AffineExpr: ...
+ def _CAPIPtr(self) -> object: ...
@property
- def rhs(self) -> AffineExpr: ...
-
-class AffineCeilDivExpr(AffineBinaryExpr):
- def __init__(self, expr: AffineExpr) -> None: ...
- def get(self, *args, **kwargs) -> Any: ...
- def isinstance(self, *args, **kwargs) -> Any: ...
-
-class AffineConstantExpr(AffineExpr):
- def __init__(self, expr: AffineExpr) -> None: ...
- def get(self, *args, **kwargs) -> Any: ...
- def isinstance(self, *args, **kwargs) -> Any: ...
+ def attributes(self) -> Any: ...
@property
- def value(self) -> int: ...
-
-class AffineDimExpr(AffineExpr):
- def __init__(self, expr: AffineExpr) -> None: ...
- def get(self, *args, **kwargs) -> Any: ...
- def isinstance(self, *args, **kwargs) -> Any: ...
+ def location(self) -> Location: ...
@property
- def position(self) -> int: ...
+ def operands(self) -> Any: ...
+ @property
+ def regions(self) -> Any: ...
+ @property
+ def result(self) -> Any: ...
+ @property
+ def results(self) -> Any: ...
class AffineExpr:
def __init__(self, *args, **kwargs) -> None: ...
@@ -154,6 +150,91 @@ class AffineExpr:
@property
def context(self) -> object: ...
+class Attribute:
+ def __init__(self, cast_from_type: Attribute) -> None: ...
+ def _CAPICreate(self) -> Attribute: ...
+ def dump(self) -> None: ...
+ def get_named(self, *args, **kwargs) -> Any: ...
+ def parse(self, *args, **kwargs) -> Any: ...
+ @overload
+ def __eq__(self, arg0: Attribute) -> bool: ...
+ @overload
+ def __eq__(self, arg0: object) -> bool: ...
+ def __hash__(self) -> int: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def context(self) -> object: ...
+ @property
+ def type(self) -> Any: ...
+
+class Type:
+ def __init__(self, cast_from_type: Type) -> None: ...
+ def _CAPICreate(self) -> Type: ...
+ def dump(self) -> None: ...
+ def parse(self, *args, **kwargs) -> Any: ...
+ @overload
+ def __eq__(self, arg0: Type) -> bool: ...
+ @overload
+ def __eq__(self, arg0: object) -> bool: ...
+ def __hash__(self) -> int: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def context(self) -> object: ...
+
+class Value:
+ def __init__(self, *args, **kwargs) -> None: ...
+ def _CAPICreate(self) -> Value: ...
+ def dump(self) -> None: ...
+ @overload
+ def __eq__(self, arg0: Value) -> bool: ...
+ @overload
+ def __eq__(self, arg0: object) -> bool: ...
+ def __hash__(self) -> int: ...
+ @property
+ def _CAPIPtr(self) -> object: ...
+ @property
+ def context(self) -> Any: ...
+ @property
+ def owner(self) -> object: ...
+ @property
+ def type(self) -> Type: ...
+
+
+# Classes with no particular order sensitivity in alpha order.
+class AffineAddExpr(AffineBinaryExpr):
+ def __init__(self, expr: AffineExpr) -> None: ...
+ def get(self, *args, **kwargs) -> Any: ...
+ def isinstance(self, *args, **kwargs) -> Any: ...
+
+class AffineBinaryExpr(AffineExpr):
+ def __init__(self, expr: AffineExpr) -> None: ...
+ def isinstance(self, *args, **kwargs) -> Any: ...
+ @property
+ def lhs(self) -> AffineExpr: ...
+ @property
+ def rhs(self) -> AffineExpr: ...
+
+class AffineCeilDivExpr(AffineBinaryExpr):
+ def __init__(self, expr: AffineExpr) -> None: ...
+ def get(self, *args, **kwargs) -> Any: ...
+ def isinstance(self, *args, **kwargs) -> Any: ...
+
+class AffineConstantExpr(AffineExpr):
+ def __init__(self, expr: AffineExpr) -> None: ...
+ def get(self, *args, **kwargs) -> Any: ...
+ def isinstance(self, *args, **kwargs) -> Any: ...
+ @property
+ def value(self) -> int: ...
+
+class AffineDimExpr(AffineExpr):
+ def __init__(self, expr: AffineExpr) -> None: ...
+ def get(self, *args, **kwargs) -> Any: ...
+ def isinstance(self, *args, **kwargs) -> Any: ...
+ @property
+ def position(self) -> int: ...
+
class AffineExprList:
def __init__(self, *args, **kwargs) -> None: ...
def __add__(self, arg0: AffineExprList) -> List[AffineExpr]: ...
@@ -245,24 +326,6 @@ class ArrayAttributeIterator:
def __iter__(self) -> ArrayAttributeIterator: ...
def __next__(self) -> Attribute: ...
-class Attribute:
- def __init__(self, cast_from_type: Attribute) -> None: ...
- def _CAPICreate(self) -> Attribute: ...
- def dump(self) -> None: ...
- def get_named(self, *args, **kwargs) -> Any: ...
- def parse(self, *args, **kwargs) -> Any: ...
- @overload
- def __eq__(self, arg0: Attribute) -> bool: ...
- @overload
- def __eq__(self, arg0: object) -> bool: ...
- def __hash__(self) -> int: ...
- @property
- def _CAPIPtr(self) -> object: ...
- @property
- def context(self) -> object: ...
- @property
- def type(self) -> Any: ...
-
class BF16Type(Type):
def __init__(self, cast_from_type: Type) -> None: ...
def get(self, *args, **kwargs) -> Any: ...
@@ -751,7 +814,19 @@ class StringAttr(Attribute):
class SymbolTable:
def __init__(self, arg0: _OperationBase) -> None: ...
def erase(self, operation: _OperationBase) -> None: ...
+    @staticmethod
+    def get_symbol_name(symbol: _OperationBase) -> Attribute: ...
+    @staticmethod
+    def get_visibility(symbol: _OperationBase) -> Attribute: ...
def insert(self, operation: _OperationBase) -> Attribute: ...
+    @staticmethod
+    def replace_all_symbol_uses(old_symbol: str, new_symbol: str, from_op: _OperationBase) -> None: ...
+    @staticmethod
+    def set_symbol_name(symbol: _OperationBase, name: str) -> None: ...
+    @staticmethod
+    def set_visibility(symbol: _OperationBase, visibility: str) -> None: ...
+    @staticmethod
+    def walk_symbol_tables(from_op: _OperationBase, all_sym_uses_visible: bool, callback: Callable[[_OperationBase, bool], None]) -> None: ...
def __contains__(self, arg0: str) -> bool: ...
def __delitem__(self, arg0: str) -> None: ...
def __getitem__(self, arg0: str) -> object: ...
@@ -764,21 +839,6 @@ class TupleType(Type):
@property
def num_types(self) -> int: ...
-class Type:
- def __init__(self, cast_from_type: Type) -> None: ...
- def _CAPICreate(self) -> Type: ...
- def dump(self) -> None: ...
- def parse(self, *args, **kwargs) -> Any: ...
- @overload
- def __eq__(self, arg0: Type) -> bool: ...
- @overload
- def __eq__(self, arg0: object) -> bool: ...
- def __hash__(self) -> int: ...
- @property
- def _CAPIPtr(self) -> object: ...
- @property
- def context(self) -> object: ...
-
class TypeAttr(Attribute):
def __init__(self, cast_from_attr: Attribute) -> None: ...
def get(self, *args, **kwargs) -> Any: ...
@@ -807,24 +867,6 @@ class UnrankedTensorType(ShapedType):
def get(self, *args, **kwargs) -> Any: ...
def isinstance(self, *args, **kwargs) -> Any: ...
-class Value:
- def __init__(self, *args, **kwargs) -> None: ...
- def _CAPICreate(self) -> Value: ...
- def dump(self) -> None: ...
- @overload
- def __eq__(self, arg0: Value) -> bool: ...
- @overload
- def __eq__(self, arg0: object) -> bool: ...
- def __hash__(self) -> int: ...
- @property
- def _CAPIPtr(self) -> object: ...
- @property
- def context(self) -> Any: ...
- @property
- def owner(self) -> object: ...
- @property
- def type(self) -> Type: ...
-
class VectorType(ShapedType):
def __init__(self, cast_from_type: Type) -> None: ...
def get(self, *args, **kwargs) -> Any: ...
@@ -833,31 +875,3 @@ class VectorType(ShapedType):
class _GlobalDebug:
flag: ClassVar[bool] = ...
def __init__(self, *args, **kwargs) -> None: ...
-
-class _OperationBase:
- def __init__(self, *args, **kwargs) -> None: ...
- def detach_from_parent(self) -> object: ...
- def get_asm(self, binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> object: ...
- def move_after(self, other: _OperationBase) -> None: ...
- def move_before(self, other: _OperationBase) -> None: ...
- def print(self, file: object = ..., binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> None: ...
- def verify(self) -> bool: ...
- @overload
- def __eq__(self, arg0: _OperationBase) -> bool: ...
- @overload
- def __eq__(self, arg0: object) -> bool: ...
- def __hash__(self) -> int: ...
- @property
- def _CAPIPtr(self) -> object: ...
- @property
- def attributes(self) -> Any: ...
- @property
- def location(self) -> Location: ...
- @property
- def operands(self) -> Any: ...
- @property
- def regions(self) -> Any: ...
- @property
- def result(self) -> Any: ...
- @property
- def results(self) -> Any: ...
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 133edc2e1aee5..db8acc82ac3ab 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -835,79 +835,6 @@ def testDetachFromParent():
# CHECK-NOT: func private @foo
-# CHECK-LABEL: TEST: testSymbolTable
-@run
-def testSymbolTable():
-  with Context() as ctx:
-    ctx.allow_unregistered_dialects = True
-    m1 = Module.parse("""
-      func private @foo()
-      func private @bar()""")
-    m2 = Module.parse("""
-      func private @qux()
-      func private @foo()
-      "foo.bar"() : () -> ()""")
-
-    symbol_table = SymbolTable(m1.operation)
-
-    # CHECK: func private @foo
-    # CHECK: func private @bar
-    assert "foo" in symbol_table
-    print(symbol_table["foo"])
-    assert "bar" in symbol_table
-    bar = symbol_table["bar"]
-    print(symbol_table["bar"])
-
-    assert "qux" not in symbol_table
-
-    del symbol_table["bar"]
-    try:
-      symbol_table.erase(symbol_table["bar"])
-    except KeyError:
-      pass
-    else:
-      assert False, "expected KeyError"
-
-    # CHECK: module
-    # CHECK: func private @foo()
-    print(m1)
-    assert "bar" not in symbol_table
-
-    try:
-      print(bar)
-    except RuntimeError as e:
-      if "the operation has been invalidated" not in str(e):
-        raise
-    else:
-      assert False, "expected RuntimeError due to invalidated operation"
-
-    qux = m2.body.operations[0]
-    m1.body.append(qux)
-    symbol_table.insert(qux)
-    assert "qux" in symbol_table
-
-    # Check that insertion actually renames this symbol in the symbol table.
-    foo2 = m2.body.operations[0]
-    m1.body.append(foo2)
-    updated_name = symbol_table.insert(foo2)
-    assert foo2.name.value != "foo"
-    assert foo2.name == updated_name
-
-    # CHECK: module
-    # CHECK: func private @foo()
-    # CHECK: func private @qux()
-    # CHECK: func private @foo{{.*}}
-    print(m1)
-
-    try:
-      symbol_table.insert(m2.body.operations[0])
-    except ValueError as e:
-      if "Expected operation to have a symbol name" not in str(e):
-        raise
-    else:
-      assert False, "exepcted ValueError when adding a non-symbol"
-
-
# CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
new file mode 100644
index 0000000000000..af8eafb605279
--- /dev/null
+++ b/mlir/test/python/ir/symbol_table.py
@@ -0,0 +1,156 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+import io
+import itertools
+from mlir.ir import *
+
+
+# Test decorator: prints a "TEST: <name>" label, runs `f` immediately when it
+# is defined, forces a GC pass and asserts that no Context is still alive
+# (i.e. the test did not leak one). Returns `f` unchanged.
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
+
+
+# CHECK-LABEL: TEST: testSymbolTableInsert
+@run
+def testSymbolTableInsert():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    m1 = Module.parse("""
+      func private @foo()
+      func private @bar()""")
+    m2 = Module.parse("""
+      func private @qux()
+      func private @foo()
+      "foo.bar"() : () -> ()""")
+
+    symbol_table = SymbolTable(m1.operation)
+
+    # CHECK: func private @foo
+    # CHECK: func private @bar
+    assert "foo" in symbol_table
+    print(symbol_table["foo"])
+    assert "bar" in symbol_table
+    bar = symbol_table["bar"]
+    print(symbol_table["bar"])
+
+    assert "qux" not in symbol_table
+
+    del symbol_table["bar"]
+    try:
+      symbol_table.erase(symbol_table["bar"])
+    except KeyError:
+      pass
+    else:
+      assert False, "expected KeyError"
+
+    # CHECK: module
+    # CHECK: func private @foo()
+    print(m1)
+    assert "bar" not in symbol_table
+
+    # The erased op's Python wrapper must now be invalidated.
+    try:
+      print(bar)
+    except RuntimeError as e:
+      if "the operation has been invalidated" not in str(e):
+        raise
+    else:
+      assert False, "expected RuntimeError due to invalidated operation"
+
+    qux = m2.body.operations[0]
+    m1.body.append(qux)
+    symbol_table.insert(qux)
+    assert "qux" in symbol_table
+
+    # Check that insertion actually renames this symbol in the symbol table.
+    foo2 = m2.body.operations[0]
+    m1.body.append(foo2)
+    updated_name = symbol_table.insert(foo2)
+    assert foo2.name.value != "foo"
+    assert foo2.name == updated_name
+
+    # CHECK: module
+    # CHECK: func private @foo()
+    # CHECK: func private @qux()
+    # CHECK: func private @foo{{.*}}
+    print(m1)
+
+    try:
+      symbol_table.insert(m2.body.operations[0])
+    except ValueError as e:
+      if "Expected operation to have a symbol name" not in str(e):
+        raise
+    else:
+      assert False, "expected ValueError when adding a non-symbol"
+
+
+# CHECK-LABEL: TEST: testSymbolTableRAUW
+@run
+def testSymbolTableRAUW():
+  with Context() as ctx:
+    m = Module.parse("""
+      func private @foo() {
+        call @bar() : () -> ()
+        return
+      }
+      func private @bar()
+      """)
+    foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
+    SymbolTable.set_symbol_name(bar, "bam")
+    # Note that module.operation counts as a "nested symbol table" which won't
+    # be traversed into, so it is necessary to traverse its children.
+    SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
+    # CHECK: call @bam()
+    # CHECK: func private @bam
+    print(m)
+    # CHECK: Foo symbol: "foo"
+    # CHECK: Bar symbol: "bam"
+    print(f"Foo symbol: {SymbolTable.get_symbol_name(foo)}")
+    print(f"Bar symbol: {SymbolTable.get_symbol_name(bar)}")
+
+
+# CHECK-LABEL: TEST: testSymbolTableVisibility
+@run
+def testSymbolTableVisibility():
+  with Context() as ctx:
+    m = Module.parse("""
+      func private @foo() {
+        return
+      }
+      """)
+    foo = m.operation.regions[0].blocks[0].operations[0]
+    # CHECK: Existing visibility: "private"
+    print(f"Existing visibility: {SymbolTable.get_visibility(foo)}")
+    SymbolTable.set_visibility(foo, "public")
+    # CHECK: func public @foo
+    print(m)
+
+
+# CHECK-LABEL: TEST: testWalkSymbolTables
+@run
+def testWalkSymbolTables():
+  with Context() as ctx:
+    m = Module.parse("""
+      module @outer {
+        module @inner{
+        }
+      }
+      """)
+    def callback(symbol_table_op, uses_visible):
+      print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")
+    # CHECK: SYMBOL TABLE: True: module @inner
+    # CHECK: SYMBOL TABLE: True: module @outer
+    SymbolTable.walk_symbol_tables(m.operation, True, callback)
+
+    # Make sure exceptions in the callback are handled.
+    def error_callback(symbol_table_op, uses_visible):
+      assert False, "Raised from python"
+    try:
+      SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
+    except RuntimeError as e:
+      # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
+      print(f"GOT EXCEPTION: {e}")
+    else:
+      assert False, "expected RuntimeError from the failing callback"
+
More information about the Mlir-commits
mailing list