[Mlir-commits] [mlir] 3781b79 - [mlir][py] Enable building ops with raw inputs
Jacques Pienaar
llvmlistbot at llvm.org
Wed Dec 21 10:10:37 PST 2022
Author: Jacques Pienaar
Date: 2022-12-21T10:10:31-08:00
New Revision: 3781b7905d8d808e5d4e97d597263f8ac48541b8
URL: https://github.com/llvm/llvm-project/commit/3781b7905d8d808e5d4e97d597263f8ac48541b8
DIFF: https://github.com/llvm/llvm-project/commit/3781b7905d8d808e5d4e97d597263f8ac48541b8.diff
LOG: [mlir][py] Enable building ops with raw inputs
For cases where we can automatically construct the Attribute, allow more
user-friendly input. This is consistent with C++ builder generation as well as
with the choice of which single builder to generate here (the most
specialized/user-friendly one).
Registration of attribute builders from more pythonic input happens entirely on
the Python side. The downsides are that
* op builders do extra checking to see whether the user provided a custom
builder, and
* the ODS attribute name is load-bearing;
the upsides are that
* these builders are easy to change, and dialect-specific ones can be
registered in downstream projects, and
* adding or changing convenience builders happens alongside the rest of the
convenience functions in Python (with no additional changes to the tablegen
file and no recompilation needed).
Building with either Attributes or raw inputs is allowed. This change should
therefore be backwards compatible, and it avoids recreating an Attribute where
one is already available.
Differential Revision: https://reviews.llvm.org/D139568
Added:
Modified:
mlir/docs/Bindings/Python.md
mlir/lib/Bindings/Python/Globals.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.cpp
mlir/python/mlir/ir.py
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/test/python/dialects/shape.py
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index cdb00dc22146e..a7b2b313ea423 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -743,6 +743,34 @@ with Context():
dictionary = DictAttr.get({"array": array, "unit": UnitAttr.get()})
```
+Custom builders for Attributes to be used during Operation creation can be
+registered by way of the `register_attribute_builder` decorator. In particular,
+the following is how a custom builder is registered for `I32Attr`:
+
+```python
+@register_attribute_builder("I32Attr")
+def _i32Attr(x: int, context: Context):
+ return IntegerAttr.get(
+ IntegerType.get_signless(32, context=context), x)
+```
+
+This allows an op with an `I32Attr` to be created with
+
+```python
+foo.Op(30)
+```
+
+Registration is keyed on the ODS attribute name but is performed purely from
+Python. Only a single custom builder may be registered per ODS attribute type
+(e.g., `I32Attr` can have only one, although several ODS attribute types may
+correspond to the same underlying `IntegerAttr` type).
+
+Such a builder saves having to construct the Attribute explicitly, as in
+
+```python
+foo.Op(IntegerAttr.get(IntegerType.get_signless(32, context=context), 30))
+```
+
## Style
In general, for the core parts of MLIR, the Python bindings should be largely
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 6613d2b6963c0..ba6cfb545b71b 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -58,6 +58,12 @@ class PyGlobals {
/// have a DIALECT_NAMESPACE attribute.
pybind11::object registerDialectDecorator(pybind11::object pyClass);
+ /// Adds a user-friendly Attribute builder.
+ /// Raises an exception if the mapping already exists.
+ /// This is intended to be called by implementation code.
+ void registerAttributeBuilder(const std::string &attributeKind,
+ pybind11::function pyFunc);
+
/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
@@ -71,6 +77,10 @@ class PyGlobals {
pybind11::object pyClass,
pybind11::object rawOpViewClass);
+ /// Returns the custom Attribute builder for Attribute kind.
+ std::optional<pybind11::function>
+ lookupAttributeBuilder(const std::string &attributeKind);
+
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
llvm::Optional<pybind11::object>
@@ -92,6 +102,8 @@ class PyGlobals {
/// Map of operation name to custom subclass that directly initializes
/// the OpView base class (bypassing the user class constructor).
llvm::StringMap<pybind11::object> rawOpViewClassMap;
+ /// Map of attribute ODS name to custom builder.
+ llvm::StringMap<pybind11::function> attributeBuilderMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 794be974284cd..f2aa8da5bc79f 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -194,6 +194,29 @@ struct PyGlobalDebugFlag {
}
};
+struct PyAttrBuilderMap {
+ static bool dunderContains(const std::string &attributeKind) {
+ return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
+ }
+ static py::function dundeGetItemNamed(const std::string &attributeKind) {
+ auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
+ if (!builder)
+ throw py::key_error();
+ return *builder;
+ }
+ static void dundeSetItemNamed(const std::string &attributeKind,
+ py::function func) {
+ PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func));
+ }
+
+ static void bind(py::module &m) {
+ py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
+ .def_static("contains", &PyAttrBuilderMap::dunderContains)
+ .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed)
+ .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed);
+ }
+};
+
//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
@@ -3283,4 +3306,7 @@ void mlir::python::populateIRCore(py::module &m) {
// Debug bindings.
PyGlobalDebugFlag::bind(m);
+
+ // Attribute builder getter.
+ PyAttrBuilderMap::bind(m);
}
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index b6d1df51f44e6..be6de5fd2f1ba 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -60,6 +60,17 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
loadedDialectModulesCache.insert(dialectNamespace);
}
+void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
+ py::function pyFunc) {
+ py::function &found = attributeBuilderMap[attributeKind];
+ if (found) {
+ throw std::runtime_error((llvm::Twine("Attribute builder for '") +
+ attributeKind + "' is already registered")
+ .str());
+ }
+ found = std::move(pyFunc);
+}
+
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
@@ -84,6 +95,22 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
}
+std::optional<py::function>
+PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
+ // Fast match against the class map first (common case).
+ const auto foundIt = attributeBuilderMap.find(attributeKind);
+ if (foundIt != attributeBuilderMap.end()) {
+ if (foundIt->second.is_none())
+ return std::nullopt;
+ assert(foundIt->second && "py::function is defined");
+ return foundIt->second;
+ }
+
+ // Not found and loading did not yield a registration. Negative cache.
+ attributeBuilderMap[attributeKind] = py::none();
+ return std::nullopt;
+}
+
llvm::Optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
loadDialectModule(dialectNamespace);
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 99e88ff743848..19986917d69bb 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,3 +4,44 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
+
+
+# Convenience decorator for registering user-friendly Attribute builders.
+def register_attribute_builder(kind):
+ def decorator_builder(func):
+ AttrBuilder.insert(kind, func)
+ return func
+ return decorator_builder
+
+
+@register_attribute_builder("BoolAttr")
+def _boolAttr(x: bool, context: Context):
+ return BoolAttr.get(x, context=context)
+
+@register_attribute_builder("IndexAttr")
+def _indexAttr(x: int, context: Context):
+ return IntegerAttr.get(IndexType.get(context=context), x)
+
+@register_attribute_builder("I32Attr")
+def _i32Attr(x: int, context: Context):
+ return IntegerAttr.get(
+ IntegerType.get_signless(32, context=context), x)
+
+@register_attribute_builder("I64Attr")
+def _i64Attr(x: int, context: Context):
+ return IntegerAttr.get(
+ IntegerType.get_signless(64, context=context), x)
+
+@register_attribute_builder("SymbolNameAttr")
+def _symbolNameAttr(x: str, context: Context):
+ return StringAttr.get(x, context=context)
+
+try:
+ import numpy as np
+ @register_attribute_builder("IndexElementsAttr")
+ def _indexElementsAttr(x: list[int], context: Context):
+ return DenseElementsAttr.get(
+ np.array(x, dtype=np.int64), type=IndexType.get(context=context),
+ context=context)
+except ImportError:
+ pass
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 2dda3db53bb2e..97fe306d6c160 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -115,11 +115,14 @@ def AttributedOp : TestOp<"attributed_op"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
- // CHECK: attributes["i32attr"] = i32attr
- // CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr
+ // CHECK: attributes["i32attr"] = (i32attr if (
+ // CHECK-NEXT: issubclass(type(i32attr), _ods_ir.Attribute) or
+ // CHECK-NEXT: not _ods_ir.AttrBuilder.contains('I32Attr')
+ // CHECK-NEXT: _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)
+ // CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = (optionalF32Attr
// CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
- // CHECK: attributes["in"] = in_
+ // CHECK: attributes["in"] = (in_
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -161,7 +164,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
- // CHECK: if is_ is not None: attributes["is"] = is_
+ // CHECK: if is_ is not None: attributes["is"] = (is_
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -188,8 +191,8 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
- // CHECK: if arr is not None: attributes["arr"] = arr
- // CHECK: if unsupported is not None: attributes["unsupported"] = unsupported
+ // CHECK: if arr is not None: attributes["arr"] = (arr
+ // CHECK: if unsupported is not None: attributes["unsupported"] = (unsupported
// CHECK: _ods_successors = None
// CHECK: super().__init__(self.build_generic(
// CHECK: attributes=attributes, results=results, operands=operands,
@@ -202,7 +205,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
- // CHECK: def __init__(self, type, *, loc=None, ip=None):
+ // CHECK: def __init__(self, type_, *, loc=None, ip=None):
// CHECK: operands = []
// CHECK: results = []
// CHECK: _ods_result_type_source_attr = attributes["type"]
@@ -217,7 +220,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
- // CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None):
+ // CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
let arguments = (ins TypeAttr:$type);
let results = (outs AnyType:$res, Variadic<AnyType>);
}
diff --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py
index 2ebad0d8acbfa..2d2a2034bfb34 100644
--- a/mlir/test/python/dialects/shape.py
+++ b/mlir/test/python/dialects/shape.py
@@ -22,9 +22,18 @@ def testConstShape():
@func.FuncOp.from_py_func(
RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
def const_shape_tensor(arg):
+ shape.ConstWitnessOp(False)
+ shape.ConstSizeOp(30)
+ shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
+ shape.ConstShapeOp([1, 2])
return shape.ConstShapeOp(
- DenseElementsAttr.get(np.array([10, 20], dtype=np.int64), type=IndexType.get()))
+ DenseElementsAttr.get(
+ np.array([3, 4], dtype=np.int64), type=IndexType.get()))
# CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
- # CHECK: shape.const_shape [10, 20] : tensor<2xindex>
+ # CHECK-DAG: shape.const_witness false
+ # CHECK-DAG: shape.const_size 30
+ # CHECK-DAG: shape.const_size 40
+ # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
+ # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
print(module)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index a5ffcc44519f2..1bd98eeb019ce 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -280,15 +280,16 @@ static llvm::cl::opt<std::string> clDialectExtensionName(
using AttributeClasses = DenseMap<StringRef, StringRef>;
-/// Checks whether `str` is a Python keyword.
-static bool isPythonKeyword(StringRef str) {
- static llvm::StringSet<> keywords(
- {"and", "as", "assert", "break", "class", "continue",
- "def", "del", "elif", "else", "except", "finally",
- "for", "from", "global", "if", "import", "in",
- "is", "lambda", "nonlocal", "not", "or", "pass",
- "raise", "return", "try", "while", "with", "yield"});
- return keywords.contains(str);
+/// Checks whether `str` is a Python keyword or would shadow a builtin function.
+static bool isPythonReserved(StringRef str) {
+ static llvm::StringSet<> reserved(
+ {"and", "as", "assert", "break", "callable", "class",
+ "continue", "def", "del", "elif", "else", "except",
+ "finally", "for", "from", "global", "if", "import",
+ "in", "is", "lambda", "nonlocal", "not", "or",
+ "pass", "raise", "return", "issubclass", "try", "type",
+ "while", "with", "yield"});
+ return reserved.contains(str);
}
/// Checks whether `str` would shadow a generated variable or attribute
@@ -306,7 +307,7 @@ static bool isODSReserved(StringRef str) {
/// (does not change the `name` if it already is suitable) and returns the
/// modified version.
static std::string sanitizeName(StringRef name) {
- if (isPythonKeyword(name) || isODSReserved(name))
+ if (isPythonReserved(name) || isODSReserved(name))
return (name + "_").str();
return name.str();
}
@@ -531,16 +532,30 @@ constexpr const char *multiOperandAppendPackTemplate =
"operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
-/// Template for setting an attribute in the operation builder.
-/// {0} is the attribute name;
-/// {1} is the builder argument name.
-constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";
-
-/// Template for setting an optional attribute in the operation builder.
-/// {0} is the attribute name;
-/// {1} is the builder argument name.
-constexpr const char *initOptionalAttributeTemplate =
- R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
+/// Template for attribute builder from raw input in the operation builder.
+/// {0} is the builder argument name;
+/// {1} is the attribute name;
+/// {2} is the ODS attribute def name, used as the AttrBuilder lookup key.
+/// Use the value the user passed in if either it is already an Attribute or
+/// there is no method registered to make it an Attribute.
+constexpr const char *initAttributeWithBuilderTemplate =
+ R"Py(attributes["{1}"] = ({0} if (
+ issubclass(type({0}), _ods_ir.Attribute) or
+ not _ods_ir.AttrBuilder.contains('{2}')) else
+ _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+
+/// Template for attribute builder from raw input for optional attribute in the
+/// operation builder.
+/// {0} is the builder argument name;
+/// {1} is the attribute name;
+/// {2} is the ODS attribute def name, used as the AttrBuilder lookup key.
+/// Use the value the user passed in if either it is already an Attribute or
+/// there is no method registered to make it an Attribute.
+constexpr const char *initOptionalAttributeWithBuilderTemplate =
+ R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
+ issubclass(type({0}), _ods_ir.Attribute) or
+ not _ods_ir.AttrBuilder.contains('{2}')) else
+ _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -656,6 +671,7 @@ static void
populateBuilderLinesAttr(const Operator &op,
llvm::ArrayRef<std::string> argNames,
llvm::SmallVectorImpl<std::string> &builderLines) {
+ builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
Argument arg = op.getArg(i);
auto *attribute = arg.dyn_cast<NamedAttribute *>();
@@ -670,10 +686,10 @@ populateBuilderLinesAttr(const Operator &op,
}
builderLines.push_back(llvm::formatv(
- (attribute->attr.isOptional() || attribute->attr.hasDefaultValue())
- ? initOptionalAttributeTemplate
- : initAttributeTemplate,
- attribute->name, argNames[i]));
+ attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
+ ? initOptionalAttributeWithBuilderTemplate
+ : initAttributeWithBuilderTemplate,
+ argNames[i], attribute->name, attribute->attr.getAttrDefName()));
}
}
@@ -753,8 +769,7 @@ constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
/// corresponding interface:
/// - {0} is the name of the class for which the types are inferred.
constexpr const char *inferTypeInterfaceTemplate =
- R"PY(_ods_context = _ods_get_default_loc_context(loc)
-results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
+ R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
operands=operands,
attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
context=_ods_context,
More information about the Mlir-commits
mailing list