[Mlir-commits] [mlir] 4eee9ef - Add SymbolRefAttr to python bindings

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 5 18:52:21 PDT 2023


Author: max
Date: 2023-07-05T20:51:33-05:00
New Revision: 4eee9ef9768b1335800878b8f0b7aa3e549e41dc

URL: https://github.com/llvm/llvm-project/commit/4eee9ef9768b1335800878b8f0b7aa3e549e41dc
DIFF: https://github.com/llvm/llvm-project/commit/4eee9ef9768b1335800878b8f0b7aa3e549e41dc.diff

LOG: Add SymbolRefAttr to python bindings

Differential Revision: https://reviews.llvm.org/D154541

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/python/mlir/ir.py
    mlir/test/python/dialects/ml_program.py
    mlir/test/python/ir/attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index b760dd0cdb9a55..63198192453efb 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -283,9 +283,6 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx,
 MLIR_CAPI_EXPORTED MlirStringRef
 mlirFlatSymbolRefAttrGetValue(MlirAttribute attr);
 
-/// Returns the typeID of an FlatSymbolRef attribute.
-MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void);
-
 //===----------------------------------------------------------------------===//
 // Type attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 99881b35c96d31..4ee06fa7a6d751 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -442,14 +442,59 @@ class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
   }
 };
 
+/// Python binding for SymbolRefAttr: a root symbol plus nested flat refs.
+class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
+  static constexpr const char *pyClassName = "SymbolRefAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Builds @symbols[0]::@symbols[1]::...; throws if `symbols` is empty.
+  static MlirAttribute fromList(const std::vector<std::string> &symbols,
+                                PyMlirContext &context) {
+    if (symbols.empty())
+      throw std::runtime_error("SymbolRefAttr must be composed of at least "
+                               "one symbol.");
+    MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
+    SmallVector<MlirAttribute, 3> referenceAttrs;
+    for (size_t i = 1; i < symbols.size(); ++i)
+      referenceAttrs.push_back(
+          mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
+    return mlirSymbolRefAttrGet(context.get(), rootSymbol,
+                                referenceAttrs.size(), referenceAttrs.data());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const std::vector<std::string> &symbols,
+           DefaultingPyMlirContext context) {
+          return PySymbolRefAttribute::fromList(symbols, context.resolve());
+        },
+        py::arg("symbols"), py::arg("context") = py::none(),
+        "Gets a uniqued SymbolRef attribute from a list of symbol names");
+    c.def_property_readonly(
+        "value",
+        [](PySymbolRefAttribute &self) {
+          std::vector<std::string> symbols = {
+              unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
+          // The C API counts nested refs as intptr_t; compare like with like.
+          intptr_t numNested = mlirSymbolRefAttrGetNumNestedReferences(self);
+          for (intptr_t i = 0; i < numNested; ++i)
+            symbols.push_back(unwrap(mlirFlatSymbolRefAttrGetValue(
+                mlirSymbolRefAttrGetNestedReference(self, i))).str());
+          return symbols;
+        },
+        "Returns the value of the SymbolRef attribute as a list[str]");
+  }
+};
+
 class PyFlatSymbolRefAttribute
     : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
   static constexpr const char *pyClassName = "FlatSymbolRefAttr";
   using PyConcreteAttribute::PyConcreteAttribute;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFlatSymbolRefAttrGetTypeID;
 
   static void bindDerived(ClassTy &c) {
     c.def_static(
@@ -1167,6 +1212,16 @@ py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
   throw py::cast_error(msg);
 }
 
+py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
+  // A FlatSymbolRefAttr is a SymbolRefAttr without nested refs: test it first.
+  if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
+    return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
+  if (PySymbolRefAttribute::isaFunction(pyAttribute))
+    return py::cast(PySymbolRefAttribute(pyAttribute));
+  throw py::cast_error("Can't cast unknown SymbolRef attribute (" +
+                       std::string(py::repr(py::cast(pyAttribute))) + ")");
+}
+
 } // namespace
 
 void mlir::python::populateIRAttributes(py::module &m) {
@@ -1201,6 +1256,11 @@ void mlir::python::populateIRAttributes(py::module &m) {
       pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
 
   PyDictAttribute::bind(m);
+  PySymbolRefAttribute::bind(m);
+  PyGlobals::get().registerTypeCaster(
+      mlirSymbolRefAttrGetTypeID(),
+      pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
+
   PyFlatSymbolRefAttribute::bind(m);
   PyOpaqueAttribute::bind(m);
   PyFloatAttribute::bind(m);

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index da8a58de775a30..3ab6d57b41690d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3131,13 +3131,13 @@ void mlir::python::populateIRCore(py::module &m) {
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
       .def_static(
           "parse",
-          [](std::string attrSpec, DefaultingPyMlirContext context) {
+          [](const std::string &attrSpec, DefaultingPyMlirContext context) {
             PyMlirContext::ErrorCapture errors(context->getRef());
-            MlirAttribute type = mlirAttributeParseGet(
+            MlirAttribute attr = mlirAttributeParseGet(
                 context->get(), toMlirStringRef(attrSpec));
-            if (mlirAttributeIsNull(type))
+            if (mlirAttributeIsNull(attr))
               throw MLIRError("Unable to parse attribute", errors.take());
-            return type;
+            return attr;
           },
           py::arg("asm"), py::arg("context") = py::none(),
           "Parses an attribute from an assembly form. Raises an MLIRError on "

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 289913d4f5480e..de221ddbfa7a92 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -305,10 +305,6 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
   return wrap(llvm::cast<FlatSymbolRefAttr>(unwrap(attr)).getValue());
 }
 
-MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) {
-  return wrap(FlatSymbolRefAttr::getTypeID());
-}
-
 //===----------------------------------------------------------------------===//
 // Type attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 76077acb6a579c..e36736f2974f0c 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -73,6 +73,14 @@ def _symbolNameAttr(x, context):
 
 @register_attribute_builder("SymbolRefAttr")
 def _symbolRefAttr(x, context):
+    if isinstance(x, (list, tuple)):  # nested ref: ["a", "b"] -> @a::@b
+        return SymbolRefAttr.get(x, context=context)
+    else:
+        return FlatSymbolRefAttr.get(x, context=context)
+
+
+@register_attribute_builder("FlatSymbolRefAttr")
+def _flatSymbolRefAttr(x, context):  # x: a single symbol name
     return FlatSymbolRefAttr.get(x, context=context)
 
 
@@ -105,6 +113,7 @@ def _f64ArrayAttr(x, context):
 def _denseI64ArrayAttr(x, context):
     return DenseI64ArrayAttr.get(x, context=context)
 
+
 @register_attribute_builder("DenseBoolArrayAttr")
 def _denseBoolArrayAttr(x, context):
     return DenseBoolArrayAttr.get(x, context=context)

diff  --git a/mlir/test/python/dialects/ml_program.py b/mlir/test/python/dialects/ml_program.py
index f16de2add37995..edffcfbf0138d3 100644
--- a/mlir/test/python/dialects/ml_program.py
+++ b/mlir/test/python/dialects/ml_program.py
@@ -2,7 +2,7 @@
 # This is just a smoke test that the dialect is functional.
 
 from mlir.ir import *
-from mlir.dialects import ml_program
+from mlir.dialects import ml_program, arith, builtin
 
 
 def constructAndPrintInModule(f):
@@ -26,3 +26,21 @@ def testFuncOp():
     with InsertionPoint(block):
         # CHECK: ml_program.return
         ml_program.ReturnOp([block.arguments[0]])
+
+
+# CHECK-LABEL: testGlobalStoreOp
+@constructAndPrintInModule
+def testGlobalStoreOp():
+    # CHECK: %cst = arith.constant 4.242000e+01 : f32
+    cst = arith.ConstantOp(value=42.42, result=F32Type.get())
+
+    m = builtin.ModuleOp()
+    m.sym_name = StringAttr.get("symbol1")
+    m.sym_visibility = StringAttr.get("public")
+    # CHECK: module @symbol1 attributes {sym_visibility = "public"} {
+    # CHECK:   ml_program.global public mutable @symbol2 : f32
+    # CHECK: }
+    with InsertionPoint(m.body):
+        ml_program.GlobalOp("symbol2", F32Type.get(), is_mutable=True)
+    # CHECK: ml_program.global_store @symbol1::@symbol2 = %cst : f32
+    ml_program.GlobalStoreOp(["symbol1", "symbol2"], cst)  # list -> nested ref

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 221c186ae7d521..28729e86ccd4c0 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -228,7 +228,7 @@ def testBoolAttr():
 @run
 def testFlatSymbolRefAttr():
     with Context() as ctx:
-        sattr = FlatSymbolRefAttr(Attribute.parse("@symbol"))
+        sattr = Attribute.parse("@symbol")
         # CHECK: symattr value: symbol
         print("symattr value:", sattr.value)
 
@@ -237,6 +237,21 @@ def testFlatSymbolRefAttr():
         print("default_get:", FlatSymbolRefAttr.get("foobar"))
 
 
+# CHECK-LABEL: TEST: testSymbolRefAttr
+@run
+def testSymbolRefAttr():
+    with Context() as ctx:
+        sattr = Attribute.parse("@symbol1::@symbol2")
+        # CHECK: symattr value: ['symbol1', 'symbol2']
+        print("symattr value:", sattr.value)
+
+        # CHECK: default_get: @symbol1::@symbol2
+        print("default_get:", SymbolRefAttr.get(["symbol1", "symbol2"]))
+
+        # CHECK: default_get: @"@symbol1"::@"@symbol2"
+        print("default_get:", SymbolRefAttr.get(["@symbol1", "@symbol2"]))
+
+
 # CHECK-LABEL: TEST: testOpaqueAttr
 @run
 def testOpaqueAttr():


        


More information about the Mlir-commits mailing list