[Mlir-commits] [mlir] 30d6189 - [mlir] provide C API and Python bindings for symbol tables

Alex Zinenko llvmlistbot at llvm.org
Tue Nov 2 06:23:08 PDT 2021


Author: Alex Zinenko
Date: 2021-11-02T14:22:58+01:00
New Revision: 30d61893fb7bbe364bf25074feaf0b178dac64e6

URL: https://github.com/llvm/llvm-project/commit/30d61893fb7bbe364bf25074feaf0b178dac64e6
DIFF: https://github.com/llvm/llvm-project/commit/30d61893fb7bbe364bf25074feaf0b178dac64e6.diff

LOG: [mlir] provide C API and Python bindings for symbol tables

Symbol tables are a broadly useful top-level IR construct. For example, they
make it easy to access functions in a module by name instead of traversing the
module's list of operations to find the corresponding function.

Depends On D112886

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D112821

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir-c/Support.h
    mlir/include/mlir/CAPI/IR.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/lib/CAPI/IR/Support.cpp
    mlir/test/CAPI/ir.c
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index ca0c45224f3a5..1610191256eea 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -54,6 +54,7 @@ DEFINE_C_API_STRUCT(MlirOperation, void);
 DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
 DEFINE_C_API_STRUCT(MlirBlock, void);
 DEFINE_C_API_STRUCT(MlirRegion, void);
+DEFINE_C_API_STRUCT(MlirSymbolTable, void);
 
 DEFINE_C_API_STRUCT(MlirAttribute, const void);
 DEFINE_C_API_STRUCT(MlirIdentifier, const void);
@@ -738,6 +739,47 @@ MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
 /// Returns the hash value of the type id.
 MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
 
+//===----------------------------------------------------------------------===//
+// Symbol and SymbolTable API.
+//===----------------------------------------------------------------------===//
+
+/// Returns the name of the attribute used to store symbol names compatible with
+/// symbol tables.
+// `(void)` rather than `()`: this header is consumed from C, where an empty
+// parameter list declares a function without a prototype.
+MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void);
+
+/// Creates a symbol table for the given operation. If the operation does not
+/// have the SymbolTable trait, returns a null symbol table. A non-null result
+/// is owned by the caller and must be released with mlirSymbolTableDestroy.
+/// The table reflects the symbols present when it is created: symbol
+/// operations added to the region afterwards are only visible through it once
+/// they have been passed to mlirSymbolTableInsert.
+MLIR_CAPI_EXPORTED MlirSymbolTable
+mlirSymbolTableCreate(MlirOperation operation);
+
+/// Returns true if the symbol table is null.
+static inline bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable) {
+  return !symbolTable.ptr;
+}
+
+/// Destroys the symbol table created with mlirSymbolTableCreate. This does not
+/// affect the operations in the table. Destroying a null symbol table is a
+/// no-op.
+MLIR_CAPI_EXPORTED void mlirSymbolTableDestroy(MlirSymbolTable symbolTable);
+
+/// Looks up a symbol with the given name in the given symbol table and returns
+/// the operation that corresponds to the symbol. If the symbol cannot be found,
+/// returns a null operation. `name` does not need to be null-terminated.
+MLIR_CAPI_EXPORTED MlirOperation
+mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name);
+
+/// Inserts the given operation into the given symbol table. The operation must
+/// have the symbol trait. If the symbol table already has a symbol with the
+/// same name, renames the symbol being inserted to ensure name uniqueness. Note
+/// that this does not move the operation itself into the block of the symbol
+/// table operation, this should be done separately. Returns the name of the
+/// symbol after insertion (a string attribute), which may differ from the
+/// operation's original name if it was renamed.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation);
+
+/// Removes the given operation from the symbol table and erases it. The
+/// operation must not be used after this call.
+MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable,
+                                             MlirOperation operation);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h
index 315f6c4564eba..f20e58fe62317 100644
--- a/mlir/include/mlir-c/Support.h
+++ b/mlir/include/mlir-c/Support.h
@@ -79,6 +79,10 @@ inline static MlirStringRef mlirStringRefCreate(const char *str,
 MLIR_CAPI_EXPORTED MlirStringRef
 mlirStringRefCreateFromCString(const char *str);
 
+/// Returns true if two string references are equal, false otherwise. The
+/// comparison is by content: both references must have the same length and
+/// the same bytes. The data does not need to be null-terminated.
+MLIR_CAPI_EXPORTED bool mlirStringRefEqual(MlirStringRef string,
+                                           MlirStringRef other);
+
 /// A callback for returning string references.
 ///
 /// This function is called back by the functions that need to return a

diff  --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index d5e961367e79a..a864175d01912 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -27,6 +27,7 @@ DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
 DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
 DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
 DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
+DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
 
 DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute)
 DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier)

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d47d06a3aa75e..8f451cf34bed3 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1530,6 +1530,57 @@ PyValue PyValue::createFromCapsule(pybind11::object capsule) {
   return PyValue(ownerRef, value);
 }
 
+//------------------------------------------------------------------------------
+// PySymbolTable.
+//------------------------------------------------------------------------------
+
+PySymbolTable::PySymbolTable(PyOperationBase &operation)
+    : operation(operation.getOperation().getRef()) {
+  // Inside the body `operation` names the constructor parameter, which shadows
+  // the member of the same name.
+  MlirOperation op = operation.getOperation().get();
+  symbolTable = mlirSymbolTableCreate(op);
+  if (!mlirSymbolTableIsNull(symbolTable))
+    return;
+  // The C API yields a null table for operations without the SymbolTable trait.
+  throw py::cast_error("Operation is not a Symbol Table.");
+}
+
+py::object PySymbolTable::dunderGetItem(const std::string &name) {
+  operation->checkValid();
+  MlirStringRef nameRef = mlirStringRefCreate(name.data(), name.length());
+  MlirOperation symbol = mlirSymbolTableLookup(symbolTable, nameRef);
+  if (!mlirOperationIsNull(symbol)) {
+    // The symbol-table operation is passed as the parent keep-alive so the
+    // returned op view does not outlive it.
+    PyOperationRef symbolRef = PyOperation::forOperation(
+        operation->getContext(), symbol, operation.getObject());
+    return symbolRef->createOpView();
+  }
+  throw py::key_error("Symbol '" + name + "' not in the symbol table.");
+}
+
+void PySymbolTable::erase(PyOperationBase &symbol) {
+  operation->checkValid();
+  PyOperation &symbolOp = symbol.getOperation();
+  symbolOp.checkValid();
+  mlirSymbolTableErase(symbolTable, symbolOp.get());
+  // mlirSymbolTableErase destroys the underlying operation too. Python code may
+  // still hold references to `symbolOp`, so instead of dropping it from the
+  // live-operation map it is only marked invalid; checkValid() on it will then
+  // fail.
+  symbolOp.valid = false;
+}
+
+void PySymbolTable::dunderDel(const std::string &name) {
+  // The lookup throws KeyError for an unknown name before anything is
+  // modified; the found symbol is then erased through the regular path.
+  py::object symbol = dunderGetItem(name);
+  erase(py::cast<PyOperationBase &>(symbol));
+}
+
+PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
+  operation->checkValid();
+  PyOperation &symbolOp = symbol.getOperation();
+  symbolOp.checkValid();
+  // Reject operations that carry no symbol-name attribute before handing them
+  // to the C API.
+  MlirStringRef symbolAttrName = mlirSymbolTableGetSymbolAttributeName();
+  MlirAttribute existingName =
+      mlirOperationGetAttributeByName(symbolOp.get(), symbolAttrName);
+  if (mlirAttributeIsNull(existingName))
+    throw py::value_error("Expected operation to have a symbol name.");
+  // The returned name may differ from `existingName` if it had to be renamed.
+  MlirAttribute insertedName =
+      mlirSymbolTableInsert(symbolTable, symbolOp.get());
+  return PyAttribute(symbolOp.getContext(), insertedName);
+}
+
 namespace {
 /// CRTP base class for Python MLIR values that subclass Value and should be
 /// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2670,6 +2721,20 @@ void mlir::python::populateIRCore(py::module &m) {
   PyBlockArgument::bind(m);
   PyOpResult::bind(m);
 
+  //----------------------------------------------------------------------------
+  // Mapping of SymbolTable.
+  //----------------------------------------------------------------------------
+  // Exposes PySymbolTable to Python as `SymbolTable`, constructed from an
+  // operation, with `table[name]`, `name in table`, `del table[name]` and
+  // explicit `insert`/`erase` methods.
+  py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
+      .def(py::init<PyOperationBase &>())
+      .def("__getitem__", &PySymbolTable::dunderGetItem)
+      .def("insert", &PySymbolTable::insert)
+      .def("erase", &PySymbolTable::erase)
+      .def("__delitem__", &PySymbolTable::dunderDel)
+      // NOTE(review): unlike __getitem__, this queries the table without first
+      // calling checkValid() on the operation that owns it - confirm this is
+      // safe if that operation has been invalidated.
+      .def("__contains__", [](PySymbolTable &table, const std::string &name) {
+        return !mlirOperationIsNull(mlirSymbolTableLookup(
+            table, mlirStringRefCreate(name.data(), name.length())));
+      });
+
   // Container bindings.
   PyBlockArgumentList::bind(m);
   PyBlockIterator::bind(m);

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 73924fc74bdbf..eb5c2385a165d 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -32,6 +32,7 @@ class DefaultingPyMlirContext;
 class PyModule;
 class PyOperation;
 class PyType;
+class PySymbolTable;
 class PyValue;
 
 /// Template for a reference to a concrete type which captures a python
@@ -513,6 +514,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   bool valid = true;
 
   friend class PyOperationBase;
+  friend class PySymbolTable;
 };
 
 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
@@ -876,6 +878,38 @@ class PyIntegerSet : public BaseContextObject {
   MlirIntegerSet integerSet;
 };
 
+/// Bindings for MLIR symbol tables.
+///
+/// Wraps a C API symbol table created by mlirSymbolTableCreate. This object
+/// owns the table and releases it in its destructor, and it also keeps a
+/// reference to the Python operation the table was built for. Copying is
+/// disabled because two copies would both destroy the same underlying table.
+class PySymbolTable {
+public:
+  /// Constructs a symbol table for the given operation.
+  explicit PySymbolTable(PyOperationBase &operation);
+
+  /// Destroys the symbol table.
+  ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); }
+
+  // The wrapped symbol table is uniquely owned; forbid copies to avoid a
+  // double mlirSymbolTableDestroy.
+  PySymbolTable(const PySymbolTable &) = delete;
+  PySymbolTable &operator=(const PySymbolTable &) = delete;
+
+  /// Returns the symbol (opview) with the given name, throws if there is no
+  /// such symbol in the table.
+  pybind11::object dunderGetItem(const std::string &name);
+
+  /// Removes the given operation from the symbol table and erases it.
+  void erase(PyOperationBase &symbol);
+
+  /// Removes the operation with the given name from the symbol table and erases
+  /// it, throws if there is no such symbol in the table.
+  void dunderDel(const std::string &name);
+
+  /// Inserts the given operation into the symbol table. The operation must have
+  /// the symbol trait. Returns the symbol's name after insertion, which may
+  /// differ from its original name if it had to be renamed for uniqueness.
+  PyAttribute insert(PyOperationBase &symbol);
+
+  /// Casts the bindings class into the C API structure.
+  operator MlirSymbolTable() { return symbolTable; }
+
+private:
+  PyOperationRef operation;
+  MlirSymbolTable symbolTable;
+};
+
 void populateIRAffine(pybind11::module &m);
 void populateIRAttributes(pybind11::module &m);
 void populateIRCore(pybind11::module &m);

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 6f617dc19269d..13490b342d9f7 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -763,3 +763,36 @@ bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
 size_t mlirTypeIDHashValue(MlirTypeID typeID) {
   return hash_value(unwrap(typeID));
 }
+
+//===----------------------------------------------------------------------===//
+// Symbol and SymbolTable API.
+//===----------------------------------------------------------------------===//
+
+MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
+  // The name is provided by a static member of SymbolTable, so no context or
+  // operation is needed.
+  StringRef attrName = SymbolTable::getSymbolAttrName();
+  return wrap(attrName);
+}
+
+MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
+  Operation *op = unwrap(operation);
+  // Only operations with the SymbolTable trait get a table; otherwise a null
+  // table is handed back to the caller.
+  SymbolTable *table = nullptr;
+  if (op->hasTrait<OpTrait::SymbolTable>())
+    table = new SymbolTable(op);
+  return wrap(table);
+}
+
+void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
+  // Deleting a null pointer is a no-op, so null tables are accepted.
+  SymbolTable *table = unwrap(symbolTable);
+  delete table;
+}
+
+MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
+                                    MlirStringRef name) {
+  // A missing symbol yields a null Operation*, which wraps to a null handle.
+  StringRef symbolName(name.data, name.length);
+  Operation *found = unwrap(symbolTable)->lookup(symbolName);
+  return wrap(found);
+}
+
+MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
+                                    MlirOperation operation) {
+  // SymbolTable::insert may rename the symbol to keep names unique; the name
+  // it ends up with is what gets returned.
+  SymbolTable *table = unwrap(symbolTable);
+  return wrap(table->insert(unwrap(operation)));
+}
+
+void mlirSymbolTableErase(MlirSymbolTable symbolTable,
+                          MlirOperation operation) {
+  // Drops the symbol from the table and erases the operation itself.
+  SymbolTable *table = unwrap(symbolTable);
+  table->erase(unwrap(operation));
+}

diff  --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp
index e4b409906297d..b6e1f9180c771 100644
--- a/mlir/lib/CAPI/IR/Support.cpp
+++ b/mlir/lib/CAPI/IR/Support.cpp
@@ -7,9 +7,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir-c/Support.h"
+#include "llvm/ADT/StringRef.h"
 
 #include <cstring>
 
 MlirStringRef mlirStringRefCreateFromCString(const char *str) {
   return mlirStringRefCreate(str, strlen(str));
 }
+
+bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
+  // Content comparison via llvm::StringRef: same length and same bytes.
+  llvm::StringRef lhs(string.data, string.length);
+  llvm::StringRef rhs(other.data, other.length);
+  return lhs == rhs;
+}

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index fd245cd6afb17..ef555377c1ce5 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1692,57 +1692,6 @@ static void deleteUserData(void *userData) {
           (intptr_t)userData);
 }
 
-void testDiagnostics() {
-  MlirContext ctx = mlirContextCreate();
-  MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
-      ctx, errorHandler, (void *)42, deleteUserData);
-  fprintf(stderr, "@test_diagnostics\n");
-  MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
-  mlirEmitError(unknownLoc, "test diagnostics");
-  MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
-      ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
-  mlirEmitError(fileLineColLoc, "test diagnostics");
-  MlirLocation callSiteLoc = mlirLocationCallSiteGet(
-      mlirLocationFileLineColGet(
-          ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
-      fileLineColLoc);
-  mlirEmitError(callSiteLoc, "test diagnostics");
-  MlirLocation null = {0};
-  MlirLocation nameLoc =
-      mlirLocationNameGet(ctx, mlirStringRefCreateFromCString("named"), null);
-  mlirEmitError(nameLoc, "test diagnostics");
-  MlirLocation locs[2] = {nameLoc, callSiteLoc};
-  MlirAttribute nullAttr = {0};
-  MlirLocation fusedLoc = mlirLocationFusedGet(ctx, 2, locs, nullAttr);
-  mlirEmitError(fusedLoc, "test diagnostics");
-  mlirContextDetachDiagnosticHandler(ctx, id);
-  mlirEmitError(unknownLoc, "more test diagnostics");
-  // CHECK-LABEL: @test_diagnostics
-  // CHECK: processing diagnostic (userData: 42) <<
-  // CHECK:   test diagnostics
-  // CHECK:   loc(unknown)
-  // CHECK: >> end of diagnostic (userData: 42)
-  // CHECK: processing diagnostic (userData: 42) <<
-  // CHECK:   test diagnostics
-  // CHECK:   loc("file.c":1:2)
-  // CHECK: >> end of diagnostic (userData: 42)
-  // CHECK: processing diagnostic (userData: 42) <<
-  // CHECK:   test diagnostics
-  // CHECK:   loc(callsite("other-file.c":2:3 at "file.c":1:2))
-  // CHECK: >> end of diagnostic (userData: 42)
-  // CHECK: processing diagnostic (userData: 42) <<
-  // CHECK:   test diagnostics
-  // CHECK:   loc("named")
-  // CHECK: >> end of diagnostic (userData: 42)
-  // CHECK: processing diagnostic (userData: 42) <<
-  // CHECK:   test diagnostics
-  // CHECK:   loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)])
-  // CHECK: deleting user data (userData: 42)
-  // CHECK-NOT: processing diagnostic
-  // CHECK:     more test diagnostics
-  mlirContextDestroy(ctx);
-}
-
 int testTypeID(MlirContext ctx) {
   fprintf(stderr, "@testTypeID\n");
 
@@ -1841,6 +1790,148 @@ int testTypeID(MlirContext ctx) {
   return 0;
 }
 
+/// Exercises the symbol table C API: creation, lookup, insertion (including
+/// renaming on a name clash) and erasure. Returns 0 on success or a distinct
+/// non-zero code identifying the first failed check.
+int testSymbolTable(MlirContext ctx) {
+  fprintf(stderr, "@testSymbolTable\n");
+
+  const char *moduleString = "func private @foo()"
+                             "func private @bar()";
+  const char *otherModuleString = "func private @qux()"
+                                  "func private @foo()";
+
+  MlirModule module =
+      mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString));
+  MlirModule otherModule = mlirModuleCreateParse(
+      ctx, mlirStringRefCreateFromCString(otherModuleString));
+  // Fail with a code rather than crashing if either module does not parse.
+  if (mlirModuleIsNull(module) || mlirModuleIsNull(otherModule))
+    return 9;
+
+  MlirSymbolTable symbolTable =
+      mlirSymbolTableCreate(mlirModuleGetOperation(module));
+  // The module operation is expected to have the SymbolTable trait.
+  if (mlirSymbolTableIsNull(symbolTable))
+    return 10;
+
+  MlirOperation funcFoo =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("foo"));
+  if (mlirOperationIsNull(funcFoo))
+    return 1;
+
+  MlirOperation funcBar =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
+  if (mlirOperationEqual(funcFoo, funcBar))
+    return 2;
+
+  MlirOperation missing =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
+  if (!mlirOperationIsNull(missing))
+    return 3;
+
+  MlirBlock moduleBody = mlirModuleGetBody(module);
+  MlirBlock otherModuleBody = mlirModuleGetBody(otherModule);
+  MlirOperation operation = mlirBlockGetFirstOperation(otherModuleBody);
+  mlirOperationRemoveFromParent(operation);
+  mlirBlockAppendOwnedOperation(moduleBody, operation);
+
+  // At this moment, the operation is still missing from the symbol table.
+  MlirOperation stillMissing =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
+  if (!mlirOperationIsNull(stillMissing))
+    return 4;
+
+  // After it is added to the symbol table, and not only the operation with
+  // which the table is associated, it can be looked up.
+  mlirSymbolTableInsert(symbolTable, operation);
+  MlirOperation funcQux =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("qux"));
+  if (!mlirOperationEqual(operation, funcQux))
+    return 5;
+
+  // Erasing from the symbol table also removes the operation.
+  mlirSymbolTableErase(symbolTable, funcBar);
+  MlirOperation nowMissing =
+      mlirSymbolTableLookup(symbolTable, mlirStringRefCreateFromCString("bar"));
+  if (!mlirOperationIsNull(nowMissing))
+    return 6;
+
+  // Adding a symbol with the same name to the table should rename.
+  MlirOperation duplicateNameOp = mlirBlockGetFirstOperation(otherModuleBody);
+  mlirOperationRemoveFromParent(duplicateNameOp);
+  mlirBlockAppendOwnedOperation(moduleBody, duplicateNameOp);
+  MlirAttribute newName = mlirSymbolTableInsert(symbolTable, duplicateNameOp);
+  MlirStringRef newNameStr = mlirStringAttrGetValue(newName);
+  if (mlirStringRefEqual(newNameStr, mlirStringRefCreateFromCString("foo")))
+    return 7;
+  MlirAttribute updatedName = mlirOperationGetAttributeByName(
+      duplicateNameOp, mlirSymbolTableGetSymbolAttributeName());
+  if (!mlirAttributeEqual(updatedName, newName))
+    return 8;
+
+  mlirOperationDump(mlirModuleGetOperation(module));
+  mlirOperationDump(mlirModuleGetOperation(otherModule));
+  // clang-format off
+  // CHECK-LABEL: @testSymbolTable
+  // CHECK: module
+  // CHECK:   func private @foo
+  // CHECK:   func private @qux
+  // CHECK:   func private @foo{{.+}}
+  // CHECK: module
+  // CHECK-NOT: @qux
+  // CHECK-NOT: @foo
+  // clang-format on
+
+  mlirSymbolTableDestroy(symbolTable);
+  mlirModuleDestroy(module);
+  mlirModuleDestroy(otherModule);
+
+  return 0;
+}
+
+/// Checks diagnostic handler plumbing. A handler with user data 42 is attached,
+/// and errors are emitted at unknown, file/line/col, call-site, name and fused
+/// locations. The handler is then detached, which is expected to release its
+/// user data via deleteUserData. One more error is emitted that the handler
+/// must not receive. The expected stderr is listed in the directives at the
+/// end of the function.
+void testDiagnostics() {
+  MlirContext ctx = mlirContextCreate();
+  MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
+      ctx, errorHandler, (void *)42, deleteUserData);
+  fprintf(stderr, "@test_diagnostics\n");
+  MlirLocation unknownLoc = mlirLocationUnknownGet(ctx);
+  mlirEmitError(unknownLoc, "test diagnostics");
+  MlirLocation fileLineColLoc = mlirLocationFileLineColGet(
+      ctx, mlirStringRefCreateFromCString("file.c"), 1, 2);
+  mlirEmitError(fileLineColLoc, "test diagnostics");
+  MlirLocation callSiteLoc = mlirLocationCallSiteGet(
+      mlirLocationFileLineColGet(
+          ctx, mlirStringRefCreateFromCString("other-file.c"), 2, 3),
+      fileLineColLoc);
+  mlirEmitError(callSiteLoc, "test diagnostics");
+  MlirLocation null = {0};
+  MlirLocation nameLoc =
+      mlirLocationNameGet(ctx, mlirStringRefCreateFromCString("named"), null);
+  mlirEmitError(nameLoc, "test diagnostics");
+  MlirLocation locs[2] = {nameLoc, callSiteLoc};
+  MlirAttribute nullAttr = {0};
+  MlirLocation fusedLoc = mlirLocationFusedGet(ctx, 2, locs, nullAttr);
+  mlirEmitError(fusedLoc, "test diagnostics");
+  // From here on errorHandler is no longer attached, so the next error must
+  // not be routed through it.
+  mlirContextDetachDiagnosticHandler(ctx, id);
+  mlirEmitError(unknownLoc, "more test diagnostics");
+  // CHECK-LABEL: @test_diagnostics
+  // CHECK: processing diagnostic (userData: 42) <<
+  // CHECK:   test diagnostics
+  // CHECK:   loc(unknown)
+  // CHECK: >> end of diagnostic (userData: 42)
+  // CHECK: processing diagnostic (userData: 42) <<
+  // CHECK:   test diagnostics
+  // CHECK:   loc("file.c":1:2)
+  // CHECK: >> end of diagnostic (userData: 42)
+  // CHECK: processing diagnostic (userData: 42) <<
+  // CHECK:   test diagnostics
+  // CHECK:   loc(callsite("other-file.c":2:3 at "file.c":1:2))
+  // CHECK: >> end of diagnostic (userData: 42)
+  // CHECK: processing diagnostic (userData: 42) <<
+  // CHECK:   test diagnostics
+  // CHECK:   loc("named")
+  // CHECK: >> end of diagnostic (userData: 42)
+  // CHECK: processing diagnostic (userData: 42) <<
+  // CHECK:   test diagnostics
+  // CHECK:   loc(fused["named", callsite("other-file.c":2:3 at "file.c":1:2)])
+  // CHECK: deleting user data (userData: 42)
+  // CHECK-NOT: processing diagnostic
+  // CHECK:     more test diagnostics
+  mlirContextDestroy(ctx);
+}
+
 int main() {
   MlirContext ctx = mlirContextCreate();
   mlirRegisterAllDialects(ctx);
@@ -1870,9 +1961,10 @@ int main() {
     return 11;
   if (testClone())
     return 12;
-  if (testTypeID(ctx)) {
+  if (testTypeID(ctx))
     return 13;
-  }
+  if (testSymbolTable(ctx))
+    return 14;
 
   mlirContextDestroy(ctx);
 

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index c94c22ea53a0b..950c31217fa22 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -773,7 +773,7 @@ def testAppendMoveFromAnotherBlock():
   with Context():
     m1 = Module.parse("func private @foo()")
     m2 = Module.parse("func private @bar()")
-    func = m1.body.operations[0]  
+    func = m1.body.operations[0]
     m2.body.append(func)
 
     # CHECK: module
@@ -803,3 +803,76 @@ def testDetachFromParent():
 
     print(m1)
     # CHECK-NOT: func private @foo
+
+
+# CHECK-LABEL: TEST: testSymbolTable
+ at run
+def testSymbolTable():
+  with Context() as ctx:
+    ctx.allow_unregistered_dialects = True
+    m1 = Module.parse("""
+      func private @foo()
+      func private @bar()""")
+    m2 = Module.parse("""
+      func private @qux()
+      func private @foo()
+      "foo.bar"() : () -> ()""")
+
+    symbol_table = SymbolTable(m1.operation)
+
+    # CHECK: func private @foo
+    # CHECK: func private @bar
+    assert "foo" in symbol_table
+    print(symbol_table["foo"])
+    assert "bar" in symbol_table
+    bar = symbol_table["bar"]
+    print(symbol_table["bar"])
+
+    assert "qux" not in symbol_table
+
+    del symbol_table["bar"]
+    try:
+      symbol_table.erase(symbol_table["bar"])
+    except KeyError:
+      pass
+    else:
+      assert False, "expected KeyError"
+
+    # CHECK: module
+    # CHECK:   func private @foo()
+    print(m1)
+    assert "bar" not in symbol_table
+
+    try:
+      print(bar)
+    except RuntimeError as e:
+      if "the operation has been invalidated" not in str(e):
+        raise
+    else:
+      assert False, "expected RuntimeError due to invalidated operation"
+
+    qux = m2.body.operations[0]
+    m1.body.append(qux)
+    symbol_table.insert(qux)
+    assert "qux" in symbol_table
+
+    # Check that insertion actually renames this symbol in the symbol table.
+    foo2 = m2.body.operations[0]
+    m1.body.append(foo2)
+    updated_name = symbol_table.insert(foo2)
+    assert foo2.name.value != "foo"
+    assert foo2.name == updated_name
+
+    # CHECK: module
+    # CHECK:   func private @foo()
+    # CHECK:   func private @qux()
+    # CHECK:   func private @foo{{.*}}
+    print(m1)
+
+    try:
+      symbol_table.insert(m2.body.operations[0])
+    except ValueError as e:
+      if "Expected operation to have a symbol name" not in str(e):
+        raise
+    else:
+      assert False, "exepcted ValueError when adding a non-symbol"


        


More information about the Mlir-commits mailing list