[Mlir-commits] [mlir] 4eee9ef - Add SymbolRefAttr to python bindings
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 5 18:52:21 PDT 2023
Author: max
Date: 2023-07-05T20:51:33-05:00
New Revision: 4eee9ef9768b1335800878b8f0b7aa3e549e41dc
URL: https://github.com/llvm/llvm-project/commit/4eee9ef9768b1335800878b8f0b7aa3e549e41dc
DIFF: https://github.com/llvm/llvm-project/commit/4eee9ef9768b1335800878b8f0b7aa3e549e41dc.diff
LOG: Add SymbolRefAttr to python bindings
Differential Revision: https://reviews.llvm.org/D154541
Added:
Modified:
mlir/include/mlir-c/BuiltinAttributes.h
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/python/mlir/ir.py
mlir/test/python/dialects/ml_program.py
mlir/test/python/ir/attributes.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index b760dd0cdb9a55..63198192453efb 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -283,9 +283,6 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx,
MLIR_CAPI_EXPORTED MlirStringRef
mlirFlatSymbolRefAttrGetValue(MlirAttribute attr);
-/// Returns the typeID of an FlatSymbolRef attribute.
-MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void);
-
//===----------------------------------------------------------------------===//
// Type attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 99881b35c96d31..4ee06fa7a6d751 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -442,14 +442,59 @@ class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
}
};
+/// Python binding for the builtin SymbolRefAttr (exposed as `SymbolRefAttr`).
+/// A symbol reference is represented on the Python side as a list of symbol
+/// names: the root reference followed by each nested reference in order.
+class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
+  static constexpr const char *pyClassName = "SymbolRefAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  /// Builds a SymbolRefAttr from `symbols`. `symbols[0]` becomes the root
+  /// reference; every following name is wrapped in a FlatSymbolRefAttr and
+  /// passed as a nested reference, in order.
+  /// Throws std::runtime_error if `symbols` is empty.
+  static MlirAttribute fromList(const std::vector<std::string> &symbols,
+                                PyMlirContext &context) {
+    if (symbols.empty())
+      throw std::runtime_error("SymbolRefAttr must be composed of at least "
+                               "one symbol.");
+    MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
+    SmallVector<MlirAttribute, 3> referenceAttrs;
+    // Everything after the root is a nested reference.
+    referenceAttrs.reserve(symbols.size() - 1);
+    for (size_t i = 1; i < symbols.size(); ++i)
+      referenceAttrs.push_back(
+          mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
+    return mlirSymbolRefAttrGet(context.get(), rootSymbol,
+                                referenceAttrs.size(), referenceAttrs.data());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](const std::vector<std::string> &symbols,
+           DefaultingPyMlirContext context) {
+          return PySymbolRefAttribute::fromList(symbols, context.resolve());
+        },
+        py::arg("symbols"), py::arg("context") = py::none(),
+        "Gets a uniqued SymbolRef attribute from a list of symbol names");
+    c.def_property_readonly(
+        "value",
+        [](PySymbolRefAttribute &self) {
+          // Query the nested-reference count once, using the C API's own
+          // intptr_t type rather than mixing it with an int loop index.
+          const intptr_t numNested =
+              mlirSymbolRefAttrGetNumNestedReferences(self);
+          std::vector<std::string> symbols;
+          symbols.reserve(static_cast<size_t>(numNested) + 1);
+          symbols.push_back(
+              unwrap(mlirSymbolRefAttrGetRootReference(self)).str());
+          for (intptr_t i = 0; i < numNested; ++i)
+            symbols.push_back(
+                unwrap(mlirSymbolRefAttrGetRootReference(
+                           mlirSymbolRefAttrGetNestedReference(self, i)))
+                    .str());
+          return symbols;
+        },
+        "Returns the value of the SymbolRef attribute as a list[str]");
+  }
+};
+
class PyFlatSymbolRefAttribute
: public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
public:
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
static constexpr const char *pyClassName = "FlatSymbolRefAttr";
using PyConcreteAttribute::PyConcreteAttribute;
- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
- mlirFlatSymbolRefAttrGetTypeID;
static void bindDerived(ClassTy &c) {
c.def_static(
@@ -1167,6 +1212,16 @@ py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
throw py::cast_error(msg);
}
+/// Downcasts a generic attribute to the most specific SymbolRef Python class.
+/// Registered for the SymbolRefAttr TypeID; flat references presumably share
+/// that TypeID, so the narrower FlatSymbolRefAttr predicate is tried first.
+/// Throws py::cast_error when neither predicate matches.
+py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
+  if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
+    return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
+
+  if (!PySymbolRefAttribute::isaFunction(pyAttribute)) {
+    std::string errorText = "Can't cast unknown SymbolRef attribute (";
+    errorText += std::string(py::repr(py::cast(pyAttribute)));
+    errorText += ")";
+    throw py::cast_error(errorText);
+  }
+
+  return py::cast(PySymbolRefAttribute(pyAttribute));
+}
+
} // namespace
void mlir::python::populateIRAttributes(py::module &m) {
@@ -1201,6 +1256,11 @@ void mlir::python::populateIRAttributes(py::module &m) {
pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
PyDictAttribute::bind(m);
+ PySymbolRefAttribute::bind(m);
+ PyGlobals::get().registerTypeCaster(
+ mlirSymbolRefAttrGetTypeID(),
+ pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
+
PyFlatSymbolRefAttribute::bind(m);
PyOpaqueAttribute::bind(m);
PyFloatAttribute::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index da8a58de775a30..3ab6d57b41690d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3131,13 +3131,13 @@ void mlir::python::populateIRCore(py::module &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
.def_static(
"parse",
- [](std::string attrSpec, DefaultingPyMlirContext context) {
+ [](const std::string &attrSpec, DefaultingPyMlirContext context) {
PyMlirContext::ErrorCapture errors(context->getRef());
- MlirAttribute type = mlirAttributeParseGet(
+ MlirAttribute attr = mlirAttributeParseGet(
context->get(), toMlirStringRef(attrSpec));
- if (mlirAttributeIsNull(type))
+ if (mlirAttributeIsNull(attr))
throw MLIRError("Unable to parse attribute", errors.take());
- return type;
+ return attr;
},
py::arg("asm"), py::arg("context") = py::none(),
"Parses an attribute from an assembly form. Raises an MLIRError on "
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 289913d4f5480e..de221ddbfa7a92 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -305,10 +305,6 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
return wrap(llvm::cast<FlatSymbolRefAttr>(unwrap(attr)).getValue());
}
-MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) {
- return wrap(FlatSymbolRefAttr::getTypeID());
-}
-
//===----------------------------------------------------------------------===//
// Type attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 76077acb6a579c..e36736f2974f0c 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -73,6 +73,14 @@ def _symbolNameAttr(x, context):
@register_attribute_builder("SymbolRefAttr")
def _symbolRefAttr(x, context):
+ if isinstance(x, list):
+ return SymbolRefAttr.get(x, context=context)
+ else:
+ return FlatSymbolRefAttr.get(x, context=context)
+
+
+ at register_attribute_builder("FlatSymbolRefAttr")
+def _flatSymbolRefAttr(x, context):
return FlatSymbolRefAttr.get(x, context=context)
@@ -105,6 +113,7 @@ def _f64ArrayAttr(x, context):
def _denseI64ArrayAttr(x, context):
return DenseI64ArrayAttr.get(x, context=context)
+
@register_attribute_builder("DenseBoolArrayAttr")
def _denseBoolArrayAttr(x, context):
return DenseBoolArrayAttr.get(x, context=context)
diff --git a/mlir/test/python/dialects/ml_program.py b/mlir/test/python/dialects/ml_program.py
index f16de2add37995..edffcfbf0138d3 100644
--- a/mlir/test/python/dialects/ml_program.py
+++ b/mlir/test/python/dialects/ml_program.py
@@ -2,7 +2,7 @@
# This is just a smoke test that the dialect is functional.
from mlir.ir import *
-from mlir.dialects import ml_program
+from mlir.dialects import ml_program, arith, builtin
def constructAndPrintInModule(f):
@@ -26,3 +26,21 @@ def testFuncOp():
with InsertionPoint(block):
# CHECK: ml_program.return
ml_program.ReturnOp([block.arguments[0]])
+
+
+# CHECK-LABEL: testGlobalStoreOp
+ at constructAndPrintInModule
+def testGlobalStoreOp():
+ # CHECK: %cst = arith.constant 4.242000e+01 : f32
+ cst = arith.ConstantOp(value=42.42, result=F32Type.get())
+
+ m = builtin.ModuleOp()
+ m.sym_name = StringAttr.get("symbol1")
+ m.sym_visibility = StringAttr.get("public")
+ # CHECK: module @symbol1 attributes {sym_visibility = "public"} {
+ # CHECK: ml_program.global public mutable @symbol2 : f32
+ # CHECK: }
+ with InsertionPoint(m.body):
+ ml_program.GlobalOp("symbol2", F32Type.get(), is_mutable=True)
+ # CHECK: ml_program.global_store @symbol1::@symbol2 = %cst : f32
+ ml_program.GlobalStoreOp(["symbol1", "symbol2"], cst)
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 221c186ae7d521..28729e86ccd4c0 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -228,7 +228,7 @@ def testBoolAttr():
@run
def testFlatSymbolRefAttr():
with Context() as ctx:
- sattr = FlatSymbolRefAttr(Attribute.parse("@symbol"))
+ sattr = Attribute.parse("@symbol")
# CHECK: symattr value: symbol
print("symattr value:", sattr.value)
@@ -237,6 +237,21 @@ def testFlatSymbolRefAttr():
print("default_get:", FlatSymbolRefAttr.get("foobar"))
+# CHECK-LABEL: TEST: testSymbolRefAttr
+ at run
+def testSymbolRefAttr():
+ with Context() as ctx:
+ sattr = Attribute.parse("@symbol1::@symbol2")
+ # CHECK: symattr value: ['symbol1', 'symbol2']
+ print("symattr value:", sattr.value)
+
+ # CHECK: default_get: @symbol1::@symbol2
+ print("default_get:", SymbolRefAttr.get(["symbol1", "symbol2"]))
+
+ # CHECK: default_get: @"@symbol1"::@"@symbol2"
+ print("default_get:", SymbolRefAttr.get(["@symbol1", "@symbol2"]))
+
+
# CHECK-LABEL: TEST: testOpaqueAttr
@run
def testOpaqueAttr():
More information about the Mlir-commits
mailing list