[Mlir-commits] [mlir] 30d6189 - [mlir] provide C API and Python bindings for symbol tables
Alex Zinenko
llvmlistbot at llvm.org
Tue Nov 2 06:23:08 PDT 2021
Author: Alex Zinenko
Date: 2021-11-02T14:22:58+01:00
New Revision: 30d61893fb7bbe364bf25074feaf0b178dac64e6
URL: https://github.com/llvm/llvm-project/commit/30d61893fb7bbe364bf25074feaf0b178dac64e6
DIFF: https://github.com/llvm/llvm-project/commit/30d61893fb7bbe364bf25074feaf0b178dac64e6.diff
LOG: [mlir] provide C API and Python bindings for symbol tables
Symbol tables are a widely useful top-level IR construct; for example, they
make it easy to access functions in a module by name instead of traversing the
list of module's operations to find the corresponding function.
Depends On D112886
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D112821
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/include/mlir-c/Support.h
mlir/include/mlir/CAPI/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/CAPI/IR/Support.cpp
mlir/test/CAPI/ir.c
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index ca0c45224f3a5..1610191256eea 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -54,6 +54,7 @@ DEFINE_C_API_STRUCT(MlirOperation, void);
DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
DEFINE_C_API_STRUCT(MlirBlock, void);
DEFINE_C_API_STRUCT(MlirRegion, void);
+DEFINE_C_API_STRUCT(MlirSymbolTable, void);
DEFINE_C_API_STRUCT(MlirAttribute, const void);
DEFINE_C_API_STRUCT(MlirIdentifier, const void);
@@ -738,6 +739,47 @@ MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
/// Returns the hash value of the type id.
MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
+//===----------------------------------------------------------------------===//
+// Symbol and SymbolTable API.
+//===----------------------------------------------------------------------===//
+
+/// Returns the name of the attribute used to store symbol names compatible with
+/// symbol tables.
+// `(void)` rather than `()`: in C an empty parameter list is not a prototype.
+MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void);
+
+/// Creates a symbol table for the given operation. If the operation does not
+/// have the SymbolTable trait, returns a null symbol table. A non-null result
+/// is owned by the caller and must be released with mlirSymbolTableDestroy.
+MLIR_CAPI_EXPORTED MlirSymbolTable
+mlirSymbolTableCreate(MlirOperation operation);
+
+/// Returns true if the symbol table is null.
+static inline bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable) {
+  return !symbolTable.ptr;
+}
+
+/// Destroys the symbol table created with mlirSymbolTableCreate. This does not
+/// affect the operations in the table.
+MLIR_CAPI_EXPORTED void mlirSymbolTableDestroy(MlirSymbolTable symbolTable);
+
+/// Looks up a symbol with the given name in the given symbol table and returns
+/// the operation that corresponds to the symbol. If the symbol cannot be found,
+/// returns a null operation.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name);
+
+/// Inserts the given operation into the given symbol table. The operation must
+/// have the symbol trait. If the symbol table already has a symbol with the
+/// same name, renames the symbol being inserted to ensure name uniqueness. Note
+/// that this does not move the operation itself into the block of the symbol
+/// table operation; this should be done separately. Returns the name of the
+/// symbol after insertion (a string attribute).
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation);
+
+/// Removes the given operation from the symbol table and erases it.
+MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable,
+                                             MlirOperation operation);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h
index 315f6c4564eba..f20e58fe62317 100644
--- a/mlir/include/mlir-c/Support.h
+++ b/mlir/include/mlir-c/Support.h
@@ -79,6 +79,10 @@ inline static MlirStringRef mlirStringRefCreate(const char *str,
MLIR_CAPI_EXPORTED MlirStringRef
mlirStringRefCreateFromCString(const char *str);
+/// Returns true if two string references are equal, false otherwise. The
+/// referenced characters are compared, so two references to different buffers
+/// holding the same text compare equal.
+MLIR_CAPI_EXPORTED bool mlirStringRefEqual(MlirStringRef string,
+ MlirStringRef other);
+
/// A callback for returning string references.
///
/// This function is called back by the functions that need to return a
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index d5e961367e79a..a864175d01912 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -27,6 +27,7 @@ DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
+DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute)
DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d47d06a3aa75e..8f451cf34bed3 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1530,6 +1530,57 @@ PyValue PyValue::createFromCapsule(pybind11::object capsule) {
return PyValue(ownerRef, value);
}
+//------------------------------------------------------------------------------
+// PySymbolTable.
+//------------------------------------------------------------------------------
+
+PySymbolTable::PySymbolTable(PyOperationBase &operation)
+    : operation(operation.getOperation().getRef()) {
+  symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
+  if (mlirSymbolTableIsNull(symbolTable)) {
+    throw py::cast_error("Operation is not a Symbol Table.");
+  }
+}
+
+py::object PySymbolTable::dunderGetItem(const std::string &name) {
+  operation->checkValid();
+  MlirOperation symbol = mlirSymbolTableLookup(
+      symbolTable, mlirStringRefCreate(name.data(), name.length()));
+  if (mlirOperationIsNull(symbol))
+    throw py::key_error("Symbol '" + name + "' not in the symbol table.");
+
+  return PyOperation::forOperation(operation->getContext(), symbol,
+                                   operation.getObject())
+      ->createOpView();
+}
+
+void PySymbolTable::erase(PyOperationBase &symbol) {
+  operation->checkValid();
+  symbol.getOperation().checkValid();
+  // mlirSymbolTableErase also erases the operation itself. Symbols are expected
+  // to live in the block of the symbol table operation (see the
+  // mlirSymbolTableInsert contract), so refuse to erase an operation nested
+  // anywhere else instead of handing a foreign operation to the C API.
+  MlirOperation parent =
+      mlirOperationGetParentOperation(symbol.getOperation().get());
+  if (mlirOperationIsNull(parent) ||
+      !mlirOperationEqual(parent, operation->get()))
+    throw py::value_error(
+        "Expected operation to be a child of the symbol table operation.");
+  mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
+  // The operation is also erased, so we must invalidate it. There may be Python
+  // references to this operation so we don't want to delete it from the list of
+  // live operations here.
+  symbol.getOperation().valid = false;
+}
+
+void PySymbolTable::dunderDel(const std::string &name) {
+  // Named `symbolObj` so it does not shadow the `operation` member, which
+  // refers to the symbol table operation rather than the symbol being erased.
+  py::object symbolObj = dunderGetItem(name);
+  erase(py::cast<PyOperationBase &>(symbolObj));
+}
+
+PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
+  operation->checkValid();
+  symbol.getOperation().checkValid();
+  // Reject operations without a symbol name attribute up front with a Python
+  // error rather than passing them to the C API.
+  MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
+      symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
+  if (mlirAttributeIsNull(symbolAttr))
+    throw py::value_error("Expected operation to have a symbol name.");
+  return PyAttribute(
+      symbol.getOperation().getContext(),
+      mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
+}
+
namespace {
/// CRTP base class for Python MLIR values that subclass Value and should be
/// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2670,6 +2721,20 @@ void mlir::python::populateIRCore(py::module &m) {
PyBlockArgument::bind(m);
PyOpResult::bind(m);
+  //----------------------------------------------------------------------------
+  // Mapping of SymbolTable.
+  //----------------------------------------------------------------------------
+  py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
+      .def(py::init<PyOperationBase &>())
+      .def("__getitem__", &PySymbolTable::dunderGetItem)
+      .def("insert", &PySymbolTable::insert)
+      .def("erase", &PySymbolTable::erase)
+      .def("__delitem__", &PySymbolTable::dunderDel)
+      .def("__contains__", [](PySymbolTable &table, const std::string &name) {
+        return !mlirOperationIsNull(mlirSymbolTableLookup(
+            table, mlirStringRefCreate(name.data(), name.length())));
+      });
+
// Container bindings.
PyBlockArgumentList::bind(m);
PyBlockIterator::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 73924fc74bdbf..eb5c2385a165d 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -32,6 +32,7 @@ class DefaultingPyMlirContext;
class PyModule;
class PyOperation;
class PyType;
+class PySymbolTable;
class PyValue;
/// Template for a reference to a concrete type which captures a python
@@ -513,6 +514,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
bool valid = true;
friend class PyOperationBase;
+  friend class PySymbolTable;
};
/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
@@ -876,6 +878,38 @@ class PyIntegerSet : public BaseContextObject {
MlirIntegerSet integerSet;
};
+/// Bindings for MLIR symbol tables. Owns the MlirSymbolTable created for the
+/// wrapped operation and destroys it together with this object.
+class PySymbolTable {
+public:
+  /// Constructs a symbol table for the given operation.
+  explicit PySymbolTable(PyOperationBase &operation);
+
+  /// Destroys the symbol table.
+  ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); }
+
+  /// Copying is disabled: the destructor releases `symbolTable`, so a copy
+  /// would destroy the same C API object twice.
+  PySymbolTable(const PySymbolTable &) = delete;
+  PySymbolTable &operator=(const PySymbolTable &) = delete;
+
+  /// Returns the symbol (opview) with the given name, throws if there is no
+  /// such symbol in the table.
+  pybind11::object dunderGetItem(const std::string &name);
+
+  /// Removes the given operation from the symbol table and erases it.
+  void erase(PyOperationBase &symbol);
+
+  /// Removes the operation with the given name from the symbol table and erases
+  /// it, throws if there is no such symbol in the table.
+  void dunderDel(const std::string &name);
+
+  /// Inserts the given operation into the symbol table. The operation must have
+  /// the symbol trait.
+  PyAttribute insert(PyOperationBase &symbol);
+
+  /// Casts the bindings class into the C API structure.
+  operator MlirSymbolTable() { return symbolTable; }
+
+private:
+  /// Reference to the operation the symbol table was created for.
+  PyOperationRef operation;
+  /// The owned C API symbol table.
+  MlirSymbolTable symbolTable;
+};
+
void populateIRAffine(pybind11::module &m);
void populateIRAttributes(pybind11::module &m);
void populateIRCore(pybind11::module &m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 6f617dc19269d..13490b342d9f7 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -763,3 +763,36 @@ bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
size_t mlirTypeIDHashValue(MlirTypeID typeID) {
return hash_value(unwrap(typeID));
}
+
+//===----------------------------------------------------------------------===//
+// Symbol and SymbolTable API.
+//===----------------------------------------------------------------------===//
+
+MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
+  StringRef attrName = SymbolTable::getSymbolAttrName();
+  return wrap(attrName);
+}
+
+// Heap-allocates the table for the caller; a null table signals that `operation`
+// lacks the SymbolTable trait.
+MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
+  Operation *op = unwrap(operation);
+  SymbolTable *table = nullptr;
+  if (op->hasTrait<OpTrait::SymbolTable>())
+    table = new SymbolTable(op);
+  return wrap(table);
+}
+
+void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
+  SymbolTable *table = unwrap(symbolTable);
+  delete table;
+}
+
+MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
+                                    MlirStringRef name) {
+  StringRef symbolName(name.data, name.length);
+  return wrap(unwrap(symbolTable)->lookup(symbolName));
+}
+
+MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
+                                    MlirOperation operation) {
+  SymbolTable *table = unwrap(symbolTable);
+  return wrap(table->insert(unwrap(operation)));
+}
+
+void mlirSymbolTableErase(MlirSymbolTable symbolTable,
+                          MlirOperation operation) {
+  SymbolTable *table = unwrap(symbolTable);
+  table->erase(unwrap(operation));
+}
diff --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp
index e4b409906297d..b6e1f9180c771 100644
--- a/mlir/lib/CAPI/IR/Support.cpp
+++ b/mlir/lib/CAPI/IR/Support.cpp
@@ -7,9 +7,15 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/Support.h"
+#include "llvm/ADT/StringRef.h"
#include <cstring>
MlirStringRef mlirStringRefCreateFromCString(const char *str) {
return mlirStringRefCreate(str, strlen(str));
}
+
+// Compares the characters referenced by both arguments, not their pointers.
+bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
+  llvm::StringRef lhs(string.data, string.length);
+  llvm::StringRef rhs(other.data, other.length);
+  return lhs == rhs;
+}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index fd245cd6afb17..ef555377c1ce5 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1692,57 +1692,6 @@ static void deleteUserData(void *userData) {
(intptr_t)userData);
}
-void testDiagnostics() {
- MlirContext ctx = mlirContextCreate();
- MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
- ctx, errorHandler, (void *)42, deleteUserData);
- fprintf(stderr, "@test_diagnostics\n");
- MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
- mlirEmitError(unknownLoc, "test diagnostics");
- MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
- ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
- mlirEmitError(fileLineColLoc, "test diagnostics");
- MlirLocation callSiteLoc = mlirLocationCallSiteGet(
- mlirLocationFileLineColGet(
- ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
- fileLineColLoc);
- mlirEmitError(callSiteLoc, "test diagnostics");
- MlirLocation null = {0};
- MlirLocation nameLoc =
- mlirLocationNameGet(ctx, mlirStringRefCreateFromCString("named"), null);
- mlirEmitError(nameLoc, "test diagnostics");
- MlirLocation locs[2] = {nameLoc, callSiteLoc};
- MlirAttribute nullAttr = {0};
- MlirLocation fusedLoc = mlirLocationFusedGet(ctx, 2, locs, nullAttr);
- mlirEmitError(fusedLoc, "test diagnostics");
- mlirContextDetachDiagnosticHandler(ctx, id);
- mlirEmitError(unknownLoc, "more test diagnostics");
- // CHECK-LABEL: @test_diagnostics
- // CHECK: processing diagnostic (userData: 42) <<
- // CHECK: test diagnostics
- // CHECK: loc(unknown)
- // CHECK: >> end of diagnostic (userData: 42)
- // CHECK: processing diagnostic (userData: 42) <<
- // CHECK: test diagnostics
- // CHECK: loc("file.c":1:2)
- // CHECK: >> end of diagnostic (userData: 42)
- // CHECK: processing diagnostic (userData: 42) <<
- // CHECK: test diagnostics
- // CHECK: loc(callsite("other-file.c":2:3 at "file.c":1:2))
- // CHECK: >> end of diagnostic (userData: 42)
- // CHECK: processing diagnostic (userData: 42) <<
- // CHECK: test diagnostics
- // CHECK: loc("named")
- // CHECK: >> end of diagnostic (userData: 42)
- // CHECK: processing diagnostic (userData: 42) <<
- // CHECK: test diagnostics
- // CHECK: loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)])
- // CHECK: deleting user data (userData: 42)
- // CHECK-NOT: processing diagnostic
- // CHECK: more test diagnostics
- mlirContextDestroy(ctx);
-}
-
int testTypeID(MlirContext ctx) {
fprintf(stderr, "@testTypeID\n");
@@ -1841,6 +1790,148 @@ int testTypeID(MlirContext ctx) {
return 0;
}
+/// Exercises the symbol table C API: creation, lookup, insertion (including
+/// renaming on name clashes) and erasure. Returns 0 on success and a distinct
+/// non-zero code identifying the first failed check otherwise.
+int testSymbolTable(MlirContext ctx) {
+  fprintf(stderr, "@testSymbolTable\n");
+
+  const char *moduleString = "func private @foo()"
+                             "func private @bar()";
+  const char *otherModuleString = "func private @qux()"
+                                  "func private @foo()";
+
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  MlirModule otherModule = mlirModuleCreateParse(
+      ctx, mlirStringRefCreateFromCString(otherModuleString));
+
+  // The module operation is expected to have the SymbolTable trait. Bail out
+  // on a null table: every call below would dereference it.
+  MlirSymbolTable symbolTable =
+      mlirSymbolTableCreate(mlirModuleGetOperation(module));
+  if (mlirSymbolTableIsNull(symbolTable))
+    return 9;
+
+  MlirOperation funcFoo =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("foo"));
+  if (mlirOperationIsNull(funcFoo))
+    return 1;
+
+  // `bar` must be found (it is erased below) and must differ from `foo`.
+  MlirOperation funcBar =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
+  if (mlirOperationIsNull(funcBar) || mlirOperationEqual(funcFoo, funcBar))
+    return 2;
+
+  MlirOperation missing =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
+  if (!mlirOperationIsNull(missing))
+    return 3;
+
+  MlirBlock moduleBody = mlirModuleGetBody(module);
+  MlirBlock otherModuleBody = mlirModuleGetBody(otherModule);
+  MlirOperation operation = mlirBlockGetFirstOperation(otherModuleBody);
+  mlirOperationRemoveFromParent(operation);
+  mlirBlockAppendOwnedOperation(moduleBody, operation);
+
+  // At this moment, the operation is still missing from the symbol table.
+  MlirOperation stillMissing =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
+  if (!mlirOperationIsNull(stillMissing))
+    return 4;
+
+  // After it is added to the symbol table, and not only the operation with
+  // which the table is associated, it can be looked up.
+  mlirSymbolTableInsert(symbolTable, operation);
+  MlirOperation funcQux =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
+  if (!mlirOperationEqual(operation, funcQux))
+    return 5;
+
+  // Erasing from the symbol table also removes the operation.
+  mlirSymbolTableErase(symbolTable, funcBar);
+  MlirOperation nowMissing =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
+  if (!mlirOperationIsNull(nowMissing))
+    return 6;
+
+  // Adding a symbol with the same name to the table should rename.
+  MlirOperation duplicateNameOp = mlirBlockGetFirstOperation(otherModuleBody);
+  mlirOperationRemoveFromParent(duplicateNameOp);
+  mlirBlockAppendOwnedOperation(moduleBody, duplicateNameOp);
+  MlirAttribute newName = mlirSymbolTableInsert(symbolTable, duplicateNameOp);
+  MlirStringRef newNameStr = mlirStringAttrGetValue(newName);
+  if (mlirStringRefEqual(newNameStr, mlirStringRefCreateFromCString("foo")))
+    return 7;
+  MlirAttribute updatedName = mlirOperationGetAttributeByName(
+      duplicateNameOp, mlirSymbolTableGetSymbolAttributeName());
+  if (!mlirAttributeEqual(updatedName, newName))
+    return 8;
+
+  mlirOperationDump(mlirModuleGetOperation(module));
+  mlirOperationDump(mlirModuleGetOperation(otherModule));
+  // clang-format off
+  // CHECK-LABEL: @testSymbolTable
+  // CHECK: module
+  // CHECK: func private @foo
+  // CHECK: func private @qux
+  // CHECK: func private @foo{{.+}}
+  // CHECK: module
+  // CHECK-NOT: @qux
+  // CHECK-NOT: @foo
+  // clang-format on
+
+  mlirSymbolTableDestroy(symbolTable);
+  mlirModuleDestroy(module);
+  mlirModuleDestroy(otherModule);
+
+  return 0;
+}
+
+/// Checks that a diagnostic handler attached to the context receives the errors
+/// emitted at unknown, file/line/col, call-site, name and fused locations, that
+/// its user data is deleted on detach, and that later errors bypass it.
+void testDiagnostics() {
+ MlirContext ctx = mlirContextCreate();
+ MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
+ ctx, errorHandler, (void *)42, deleteUserData);
+ fprintf(stderr, "@test_diagnostics\n");
+ MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
+ mlirEmitError(unknownLoc, "test diagnostics");
+ MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
+ ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
+ mlirEmitError(fileLineColLoc, "test diagnostics");
+ MlirLocation callSiteLoc = mlirLocationCallSiteGet(
+ mlirLocationFileLineColGet(
+ ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
+ fileLineColLoc);
+ mlirEmitError(callSiteLoc, "test diagnostics");
+ MlirLocation null = {0};
+ MlirLocation nameLoc =
+ mlirLocationNameGet(ctx, mlirStringRefCreateFromCString("named"), null);
+ mlirEmitError(nameLoc, "test diagnostics");
+ MlirLocation locs[2] = {nameLoc, callSiteLoc};
+ MlirAttribute nullAttr = {0};
+ MlirLocation fusedLoc = mlirLocationFusedGet(ctx, 2, locs, nullAttr);
+ mlirEmitError(fusedLoc, "test diagnostics");
+ mlirContextDetachDiagnosticHandler(ctx, id);
+ mlirEmitError(unknownLoc, "more test diagnostics");
+ // CHECK-LABEL: @test_diagnostics
+ // CHECK: processing diagnostic (userData: 42) <<
+ // CHECK: test diagnostics
+ // CHECK: loc(unknown)
+ // CHECK: >> end of diagnostic (userData: 42)
+ // CHECK: processing diagnostic (userData: 42) <<
+ // CHECK: test diagnostics
+ // CHECK: loc("file.c":1:2)
+ // CHECK: >> end of diagnostic (userData: 42)
+ // CHECK: processing diagnostic (userData: 42) <<
+ // CHECK: test diagnostics
+ // CHECK: loc(callsite("other-file.c":2:3 at "file.c":1:2))
+ // CHECK: >> end of diagnostic (userData: 42)
+ // CHECK: processing diagnostic (userData: 42) <<
+ // CHECK: test diagnostics
+ // CHECK: loc("named")
+ // CHECK: >> end of diagnostic (userData: 42)
+ // CHECK: processing diagnostic (userData: 42) <<
+ // CHECK: test diagnostics
+ // CHECK: loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)])
+ // CHECK: deleting user data (userData: 42)
+ // CHECK-NOT: processing diagnostic
+ // CHECK: more test diagnostics
+ mlirContextDestroy(ctx);
+}
+
int main() {
MlirContext ctx = mlirContextCreate();
mlirRegisterAllDialects(ctx);
@@ -1870,9 +1961,10 @@ int main() {
return 11;
if (testClone())
return 12;
- if (testTypeID(ctx)) {
+ if (testTypeID(ctx))
return 13;
- }
+ // Each failing test is reported through its own distinct exit code.
+ if (testSymbolTable(ctx))
+ return 14;
mlirContextDestroy(ctx);
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index c94c22ea53a0b..950c31217fa22 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -773,7 +773,7 @@ def testAppendMoveFromAnotherBlock():
with Context():
m1 = Module.parse("func private @foo()")
m2 = Module.parse("func private @bar()")
- func = m1.body.operations[0]
+ func = m1.body.operations[0]
m2.body.append(func)
# CHECK: module
@@ -803,3 +803,76 @@ def testDetachFromParent():
print(m1)
# CHECK-NOT: func private @foo
+
+
+# CHECK-LABEL: TEST: testSymbolTable
+@run
+def testSymbolTable():
+  # Covers lookup, `in`, deletion, erasure, insertion with renaming and the
+  # error raised when inserting an operation without a symbol name.
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    m1 = Module.parse("""
+      func private @foo()
+      func private @bar()""")
+    m2 = Module.parse("""
+      func private @qux()
+      func private @foo()
+      "foo.bar"() : () -> ()""")
+
+    symbol_table = SymbolTable(m1.operation)
+
+    # CHECK: func private @foo
+    # CHECK: func private @bar
+    assert "foo" in symbol_table
+    print(symbol_table["foo"])
+    assert "bar" in symbol_table
+    bar = symbol_table["bar"]
+    print(symbol_table["bar"])
+
+    assert "qux" not in symbol_table
+
+    del symbol_table["bar"]
+    try:
+      symbol_table.erase(symbol_table["bar"])
+    except KeyError:
+      pass
+    else:
+      assert False, "expected KeyError"
+
+    # CHECK: module
+    # CHECK: func private @foo()
+    print(m1)
+    assert "bar" not in symbol_table
+
+    # The Python handle to the erased operation must now be invalid.
+    try:
+      print(bar)
+    except RuntimeError as e:
+      if "the operation has been invalidated" not in str(e):
+        raise
+    else:
+      assert False, "expected RuntimeError due to invalidated operation"
+
+    qux = m2.body.operations[0]
+    m1.body.append(qux)
+    symbol_table.insert(qux)
+    assert "qux" in symbol_table
+
+    # Check that insertion actually renames this symbol in the symbol table.
+    foo2 = m2.body.operations[0]
+    m1.body.append(foo2)
+    updated_name = symbol_table.insert(foo2)
+    assert foo2.name.value != "foo"
+    assert foo2.name == updated_name
+
+    # CHECK: module
+    # CHECK: func private @foo()
+    # CHECK: func private @qux()
+    # CHECK: func private @foo{{.*}}
+    print(m1)
+
+    try:
+      symbol_table.insert(m2.body.operations[0])
+    except ValueError as e:
+      if "Expected operation to have a symbol name" not in str(e):
+        raise
+    else:
+      assert False, "expected ValueError when adding a non-symbol"
More information about the Mlir-commits
mailing list