[Mlir-commits] [mlir] b57acb9 - Revert "Revert "[mlir][py] Enable building ops with raw inputs""

Jacques Pienaar llvmlistbot at llvm.org
Wed Dec 21 16:22:45 PST 2022


Author: Jacques Pienaar
Date: 2022-12-21T16:22:39-08:00
New Revision: b57acb9a405c289069345a498ebfc1d1b9b110de

URL: https://github.com/llvm/llvm-project/commit/b57acb9a405c289069345a498ebfc1d1b9b110de
DIFF: https://github.com/llvm/llvm-project/commit/b57acb9a405c289069345a498ebfc1d1b9b110de.diff

LOG: Revert "Revert "[mlir][py] Enable building ops with raw inputs""

Fix Python 3.6.9 issue encountered due to type checking here. Will
add back in follow up.

This reverts commit 1f47fee2948ef48781084afe0426171d000d7997.

Added: 
    

Modified: 
    mlir/docs/Bindings/Python.md
    mlir/lib/Bindings/Python/Globals.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.cpp
    mlir/python/mlir/ir.py
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/test/python/dialects/shape.py
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index cdb00dc22146e..a7b2b313ea423 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -743,6 +743,34 @@ with Context():
   dictionary = DictAttr.get({"array": array, "unit": UnitAttr.get()})
 ```
 
+Custom builders for Attributes to be used during Operation creation can be
+registered with the `register_attribute_builder` decorator. For example, the
+following registers a custom builder for `I32Attr`:
+
+```python
+@register_attribute_builder("I32Attr")
+def _i32Attr(x: int, context: Context):
+  return IntegerAttr.get(
+        IntegerType.get_signless(32, context=context), x)
+```
+
+This allows an op taking an `I32Attr` to be created directly from a Python int:
+
+```python
+foo.Op(30)
+```
+
+Registration is keyed on the ODS attribute name, but the registry itself is
+populated from pure Python. Only a single custom builder may be registered per
+ODS attribute type (e.g., I32Attr can have only one builder, even though several
+ODS attribute types may correspond to the same underlying IntegerAttr type).
+
+Without a registered builder, the attribute would have to be constructed
+explicitly instead:
+
+```python
+foo.Op(IntegerAttr.get(IntegerType.get_signless(32, context=context), 30))
+```
+
 ## Style
 
 In general, for the core parts of MLIR, the Python bindings should be largely

diff  --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 6613d2b6963c0..ba6cfb545b71b 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -58,6 +58,12 @@ class PyGlobals {
   /// have a DIALECT_NAMESPACE attribute.
   pybind11::object registerDialectDecorator(pybind11::object pyClass);
 
+  /// Adds a user-friendly Attribute builder for the ODS attribute kind
+  /// `attributeKind` (e.g. "I32Attr"). Generated op builders call it as
+  /// `pyFunc(value, context=...)` to turn a raw Python value into an Attribute.
+  /// Raises an exception (std::runtime_error) if a builder for this kind is
+  /// already registered.
+  /// This is intended to be called by implementation code.
+  void registerAttributeBuilder(const std::string &attributeKind,
+                                pybind11::function pyFunc);
+
   /// Adds a concrete implementation dialect class.
   /// Raises an exception if the mapping already exists.
   /// This is intended to be called by implementation code.
@@ -71,6 +77,10 @@ class PyGlobals {
                              pybind11::object pyClass,
                              pybind11::object rawOpViewClass);
 
+  /// Returns the custom Attribute builder registered for the ODS attribute
+  /// kind `attributeKind`, or std::nullopt if none has been registered.
+  std::optional<pybind11::function>
+  lookupAttributeBuilder(const std::string &attributeKind);
+
   /// Looks up a registered dialect class by namespace. Note that this may
   /// trigger loading of the defining module and can arbitrarily re-enter.
   llvm::Optional<pybind11::object>
@@ -92,6 +102,8 @@ class PyGlobals {
   /// Map of operation name to custom subclass that directly initializes
   /// the OpView base class (bypassing the user class constructor).
   llvm::StringMap<pybind11::object> rawOpViewClassMap;
+  /// Map of attribute ODS name (e.g. "I32Attr") to the Python callable
+  /// registered as its custom builder via registerAttributeBuilder.
+  llvm::StringMap<pybind11::function> attributeBuilderMap;
 
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 794be974284cd..f2aa8da5bc79f 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -194,6 +194,29 @@ struct PyGlobalDebugFlag {
   }
 };
 
+/// Python-facing static view of the global attribute-builder registry held by
+/// PyGlobals. Bound as `AttrBuilder`, it exposes `contains`, `get` and `insert`
+/// so that generated op builders can turn raw Python values into Attributes.
+struct PyAttrBuilderMap {
+  /// Returns true if a builder is registered for `attributeKind`.
+  static bool dunderContains(const std::string &attributeKind) {
+    return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
+  }
+  /// Returns the builder registered for `attributeKind`. Raises KeyError,
+  /// carrying the missing kind as its message, if there is none.
+  static py::function dunderGetItemNamed(const std::string &attributeKind) {
+    auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
+    if (!builder)
+      throw py::key_error(attributeKind);
+    return *builder;
+  }
+  /// Registers `func` as the builder for `attributeKind`; raises if one is
+  /// already registered.
+  static void dunderSetItemNamed(const std::string &attributeKind,
+                                 py::function func) {
+    PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func));
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyAttrBuilderMap>(m, "AttrBuilder", py::module_local())
+        .def_static("contains", &PyAttrBuilderMap::dunderContains)
+        .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed)
+        .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed);
+  }
+};
+
 //------------------------------------------------------------------------------
 // Collections.
 //------------------------------------------------------------------------------
@@ -3283,4 +3306,7 @@ void mlir::python::populateIRCore(py::module &m) {
 
   // Debug bindings.
   PyGlobalDebugFlag::bind(m);
+
+  // Attribute builder getter.
+  PyAttrBuilderMap::bind(m);
 }

diff  --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index b6d1df51f44e6..be6de5fd2f1ba 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -60,6 +60,17 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
   loadedDialectModulesCache.insert(dialectNamespace);
 }
 
+/// Stores `pyFunc` as the builder for ODS attribute kind `attributeKind`.
+/// A second registration for the same kind is rejected with
+/// std::runtime_error and leaves the existing builder untouched.
+void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
+                                         py::function pyFunc) {
+  py::function &slot = attributeBuilderMap[attributeKind];
+  if (!slot) {
+    slot = std::move(pyFunc);
+    return;
+  }
+  throw std::runtime_error("Attribute builder for '" + attributeKind +
+                           "' is already registered");
+}
+
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                     py::object pyClass) {
   py::object &found = dialectClassMap[dialectNamespace];
@@ -84,6 +95,22 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
 }
 
+/// Returns the builder registered for `attributeKind`, or std::nullopt when
+/// no builder has been registered for that kind.
+std::optional<py::function>
+PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
+  // Unlike dialect lookup, nothing is loaded on a miss, so the miss is
+  // reported directly and not cached. Caching it by storing py::none() in this
+  // py::function-typed map does not work: pybind11 rejects the non-callable
+  // None with a TypeError (leaving a null entry behind), and a cached entry
+  // would also make a later registration of this kind look like a duplicate.
+  // Any null entry is therefore treated as "not registered".
+  const auto foundIt = attributeBuilderMap.find(attributeKind);
+  if (foundIt == attributeBuilderMap.end() || !foundIt->second)
+    return std::nullopt;
+  return foundIt->second;
+}
+
 llvm::Optional<py::object>
 PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
   loadDialectModule(dialectNamespace);

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 99e88ff743848..82468e8b76b40 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,3 +4,44 @@
 
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
+
+
+# Convenience decorator for registering user-friendly Attribute builders.
+def register_attribute_builder(kind):
+  def decorator_builder(func):
+    AttrBuilder.insert(kind, func)
+    return func
+  return decorator_builder
+
+
+ at register_attribute_builder("BoolAttr")
+def _boolAttr(x, context):
+  return BoolAttr.get(x, context=context)
+
+ at register_attribute_builder("IndexAttr")
+def _indexAttr(x, context):
+  return IntegerAttr.get(IndexType.get(context=context), x)
+
+ at register_attribute_builder("I32Attr")
+def _i32Attr(x, context):
+  return IntegerAttr.get(
+      IntegerType.get_signless(32, context=context), x)
+
+ at register_attribute_builder("I64Attr")
+def _i64Attr(x, context):
+  return IntegerAttr.get(
+      IntegerType.get_signless(64, context=context), x)
+
+ at register_attribute_builder("SymbolNameAttr")
+def _symbolNameAttr(x, context):
+  return StringAttr.get(x, context=context)
+
+try:
+  import numpy as np
+  @register_attribute_builder("IndexElementsAttr")
+  def _indexElementsAttr(x, context):
+    return DenseElementsAttr.get(
+        np.array(x, dtype=np.int64), type=IndexType.get(context=context),
+        context=context)
+except ImportError:
+  pass

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 2dda3db53bb2e..97fe306d6c160 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -115,11 +115,14 @@ def AttributedOp : TestOp<"attributed_op"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   attributes["i32attr"] = i32attr
-  // CHECK:   if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr
+  // CHECK:   attributes["i32attr"] = (i32attr if (
+  // CHECK-NEXT:   issubclass(type(i32attr), _ods_ir.Attribute) or
+  // CHECK-NEXT:   not _ods_ir.AttrBuilder.contains('I32Attr')
+  // CHECK-NEXT:   _ods_ir.AttrBuilder.get('I32Attr')(i32attr, context=_ods_context)
+  // CHECK:   if optionalF32Attr is not None: attributes["optionalF32Attr"] = (optionalF32Attr
   // CHECK:   if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get(
   // CHECK:     _ods_get_default_loc_context(loc))
-  // CHECK:   attributes["in"] = in_
+  // CHECK:   attributes["in"] = (in_
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -161,7 +164,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_2))
   // CHECK:   if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
   // CHECK:     _ods_get_default_loc_context(loc))
-  // CHECK:   if is_ is not None: attributes["is"] = is_
+  // CHECK:   if is_ is not None: attributes["is"] = (is_
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -188,8 +191,8 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   if arr is not None: attributes["arr"] = arr
-  // CHECK:   if unsupported is not None: attributes["unsupported"] = unsupported
+  // CHECK:   if arr is not None: attributes["arr"] = (arr
+  // CHECK:   if unsupported is not None: attributes["unsupported"] = (unsupported
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(self.build_generic(
   // CHECK:     attributes=attributes, results=results, operands=operands,
@@ -202,7 +205,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
-  // CHECK: def __init__(self, type, *, loc=None, ip=None):
+  // CHECK: def __init__(self, type_, *, loc=None, ip=None):
   // CHECK:   operands = []
   // CHECK:   results = []
   // CHECK:   _ods_result_type_source_attr = attributes["type"]
@@ -217,7 +220,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
-  // CHECK: def __init__(self, res, _gen_res_1, type, *, loc=None, ip=None):
+  // CHECK: def __init__(self, res, _gen_res_1, type_, *, loc=None, ip=None):
   let arguments = (ins TypeAttr:$type);
   let results = (outs AnyType:$res, Variadic<AnyType>);
 }

diff  --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py
index 2ebad0d8acbfa..2d2a2034bfb34 100644
--- a/mlir/test/python/dialects/shape.py
+++ b/mlir/test/python/dialects/shape.py
@@ -22,9 +22,18 @@ def testConstShape():
       @func.FuncOp.from_py_func(
           RankedTensorType.get((12, ShapedType.get_dynamic_size()), f32))
       def const_shape_tensor(arg):
+        shape.ConstWitnessOp(False)
+        shape.ConstSizeOp(30)
+        shape.ConstSizeOp(IntegerAttr.get(IndexType.get(), 40))
+        shape.ConstShapeOp([1, 2])
         return shape.ConstShapeOp(
-          DenseElementsAttr.get(np.array([10, 20], dtype=np.int64), type=IndexType.get()))
+            DenseElementsAttr.get(
+                np.array([3, 4], dtype=np.int64), type=IndexType.get()))
 
     # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
-    # CHECK: shape.const_shape [10, 20] : tensor<2xindex>
+    # CHECK-DAG: shape.const_witness false
+    # CHECK-DAG: shape.const_size 30
+    # CHECK-DAG: shape.const_size 40
+    # CHECK-DAG: shape.const_shape [1, 2] : tensor<2xindex>
+    # CHECK-DAG: shape.const_shape [3, 4] : tensor<2xindex>
     print(module)

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index a5ffcc44519f2..1bd98eeb019ce 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -280,15 +280,16 @@ static llvm::cl::opt<std::string> clDialectExtensionName(
 
 using AttributeClasses = DenseMap<StringRef, StringRef>;
 
-/// Checks whether `str` is a Python keyword.
-static bool isPythonKeyword(StringRef str) {
-  static llvm::StringSet<> keywords(
-      {"and",   "as",     "assert",   "break", "class",  "continue",
-       "def",   "del",    "elif",     "else",  "except", "finally",
-       "for",   "from",   "global",   "if",    "import", "in",
-       "is",    "lambda", "nonlocal", "not",   "or",     "pass",
-       "raise", "return", "try",      "while", "with",   "yield"});
-  return keywords.contains(str);
+/// Checks whether `str` is a Python keyword or would shadow a builtin function.
+/// sanitizeName appends a trailing underscore to such names so that the
+/// generated Python remains valid.
+static bool isPythonReserved(StringRef str) {
+  static llvm::StringSet<> reserved({
+      // Python keywords. `False`, `None` and `True` are keywords too, and
+      // `async`/`await` are keywords since Python 3.7; using any of them as a
+      // parameter name is a SyntaxError.
+      "False", "None", "True", "and", "as", "assert", "async", "await",
+      "break", "class", "continue", "def", "del", "elif", "else", "except",
+      "finally", "for", "from", "global", "if", "import", "in", "is",
+      "lambda", "nonlocal", "not", "or", "pass", "raise", "return", "try",
+      "while", "with", "yield",
+      // Builtins that builder parameters must not shadow; `issubclass` and
+      // `type` are called by the generated attribute-handling code.
+      "callable", "issubclass", "type"});
+  return reserved.contains(str);
+}
 
 /// Checks whether `str` would shadow a generated variable or attribute
@@ -306,7 +307,7 @@ static bool isODSReserved(StringRef str) {
 /// (does not change the `name` if it already is suitable) and returns the
 /// modified version.
 static std::string sanitizeName(StringRef name) {
-  if (isPythonKeyword(name) || isODSReserved(name))
+  if (isPythonReserved(name) || isODSReserved(name))
     return (name + "_").str();
   return name.str();
 }
@@ -531,16 +532,30 @@ constexpr const char *multiOperandAppendPackTemplate =
     "operands.append(_get_op_results_or_values({0}))";
 constexpr const char *multiResultAppendTemplate = "results.extend({0})";
 
-/// Template for setting an attribute in the operation builder.
-///   {0} is the attribute name;
-///   {1} is the builder argument name.
-constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";
-
-/// Template for setting an optional attribute in the operation builder.
-///   {0} is the attribute name;
-///   {1} is the builder argument name.
-constexpr const char *initOptionalAttributeTemplate =
-    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";
+/// Template for attribute builder from raw input in the operation builder.
+///   {0} is the builder argument name;
+///   {1} is the attribute name (the key in `attributes`);
+///   {2} is the ODS attribute definition name used to look up the builder.
+/// Use the value the user passed in if either it is already an Attribute or
+/// there is no method registered to make it an Attribute.
+/// The generated code reads `_ods_context`, which must already be defined in
+/// the emitted builder body.
+constexpr const char *initAttributeWithBuilderTemplate =
+    R"Py(attributes["{1}"] = ({0} if (
+    issubclass(type({0}), _ods_ir.Attribute) or
+    not _ods_ir.AttrBuilder.contains('{2}')) else
+      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
+
+/// Template for attribute builder from raw input for optional attribute in the
+/// operation builder.
+///   {0} is the builder argument name; when it is None the attribute is left
+///       unset;
+///   {1} is the attribute name (the key in `attributes`);
+///   {2} is the ODS attribute definition name used to look up the builder.
+/// Use the value the user passed in if either it is already an Attribute or
+/// there is no method registered to make it an Attribute.
+/// The generated code reads `_ods_context`, which must already be defined in
+/// the emitted builder body.
+constexpr const char *initOptionalAttributeWithBuilderTemplate =
+    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
+        issubclass(type({0}), _ods_ir.Attribute) or
+        not _ods_ir.AttrBuilder.contains('{2}')) else
+          _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
 
 constexpr const char *initUnitAttributeTemplate =
     R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
@@ -656,6 +671,7 @@ static void
 populateBuilderLinesAttr(const Operator &op,
                          llvm::ArrayRef<std::string> argNames,
                          llvm::SmallVectorImpl<std::string> &builderLines) {
+  builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
     Argument arg = op.getArg(i);
     auto *attribute = arg.dyn_cast<NamedAttribute *>();
@@ -670,10 +686,10 @@ populateBuilderLinesAttr(const Operator &op,
     }
 
     builderLines.push_back(llvm::formatv(
-        (attribute->attr.isOptional() || attribute->attr.hasDefaultValue())
-            ? initOptionalAttributeTemplate
-            : initAttributeTemplate,
-        attribute->name, argNames[i]));
+        attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
+            ? initOptionalAttributeWithBuilderTemplate
+            : initAttributeWithBuilderTemplate,
+        argNames[i], attribute->name, attribute->attr.getAttrDefName()));
   }
 }
 
@@ -753,8 +769,7 @@ constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
 /// corresponding interface:
 ///   - {0} is the name of the class for which the types are inferred.
 constexpr const char *inferTypeInterfaceTemplate =
-    R"PY(_ods_context = _ods_get_default_loc_context(loc)
-results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
+    R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
     operands=operands,
     attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
     context=_ods_context,


        


More information about the Mlir-commits mailing list