[Mlir-commits] [mlir] 2fc0d4a - [mlir] Add Float Attribute, Integer Attribute and Bool Attribute subclasses to python bindings.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 2 09:34:37 PDT 2020


Author: zhanghb97
Date: 2020-10-03T00:32:51+08:00
New Revision: 2fc0d4a8e83807d57f8d586af82934f94dead5e3

URL: https://github.com/llvm/llvm-project/commit/2fc0d4a8e83807d57f8d586af82934f94dead5e3
DIFF: https://github.com/llvm/llvm-project/commit/2fc0d4a8e83807d57f8d586af82934f94dead5e3.diff

LOG: [mlir] Add Float Attribute, Integer Attribute and Bool Attribute subclasses to python bindings.

Based on the PyAttribute and PyConcreteAttribute classes, this patch implements bindings for the Float Attribute, Integer Attribute and Bool Attribute subclasses.
This patch also defines the `mlirFloatAttrDoubleGetChecked` C API, which backs the `FloatAttr.get` Python method.

Differential Revision: https://reviews.llvm.org/D88531

Added: 
    

Modified: 
    mlir/include/mlir-c/StandardAttributes.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/CAPI/IR/StandardAttributes.cpp
    mlir/test/Bindings/Python/ir_attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/StandardAttributes.h
index e5d5aeab4343..2fc2ecc9ee1d 100644
--- a/mlir/include/mlir-c/StandardAttributes.h
+++ b/mlir/include/mlir-c/StandardAttributes.h
@@ -93,6 +93,11 @@ int mlirAttributeIsAFloat(MlirAttribute attr);
 MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
                                      double value);
 
+/** Same as "mlirFloatAttrDoubleGet", but takes a location instead of a context
+ * and returns a null MlirAttribute if the type is not valid for a FloatAttr. */
+MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value,
+                                            MlirLocation loc);
+
 /** Returns the value stored in the given floating point attribute, interpreting
  * the value as double. */
 double mlirFloatAttrGetValueDouble(MlirAttribute attr);

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 8d64b2d8de0a..36e25eebfc71 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -742,6 +742,106 @@ class PyConcreteAttribute : public BaseTy {
   static void bindDerived(ClassTy &m) {}
 };
 
+/// Floating point attribute subclass - FloatAttr.
+class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
+  static constexpr const char *pyClassName = "FloatAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        // TODO: Make the location optional and create a default location.
+        [](PyType &type, double value, PyLocation &loc) {
+          MlirAttribute attr =
+              mlirFloatAttrDoubleGetChecked(type.type, value, loc.loc);
+          // TODO: Rework error reporting once diagnostic engine is exposed
+          // in C API.
+          if (mlirAttributeIsNull(attr)) {
+            throw SetPyError(PyExc_ValueError,
+                             llvm::Twine("invalid '") +
+                                 py::repr(py::cast(type)).cast<std::string>() +
+                                 "' and expected floating point type.");
+          }
+          return PyFloatAttribute(type.getContext(), attr);
+        },
+        py::arg("type"), py::arg("value"), py::arg("loc"),
+        "Gets a uniqued floating point attribute associated with a type");
+    c.def_static(
+        "get_f32",
+        [](PyMlirContext &context, double value) {
+          MlirAttribute attr = mlirFloatAttrDoubleGet(
+              context.get(), mlirF32TypeGet(context.get()), value);
+          return PyFloatAttribute(context.getRef(), attr);
+        },
+        py::arg("context"), py::arg("value"),
+        "Gets a uniqued floating point attribute associated with an f32 type");
+    c.def_static(
+        "get_f64",
+        [](PyMlirContext &context, double value) {
+          MlirAttribute attr = mlirFloatAttrDoubleGet(
+              context.get(), mlirF64TypeGet(context.get()), value);
+          return PyFloatAttribute(context.getRef(), attr);
+        },
+        py::arg("context"), py::arg("value"),
+        "Gets a uniqued floating point attribute associated with an f64 type");
+    c.def_property_readonly(
+        "value",
+        [](PyFloatAttribute &self) {
+          return mlirFloatAttrGetValueDouble(self.attr);
+        },
+        "Returns the value of the floating point attribute");
+  }
+};
+
+/// Integer attribute subclass - IntegerAttr.
+class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
+  static constexpr const char *pyClassName = "IntegerAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyType &type, int64_t value) {
+          MlirAttribute attr = mlirIntegerAttrGet(type.type, value);
+          return PyIntegerAttribute(type.getContext(), attr);
+        },
+        py::arg("type"), py::arg("value"),
+        "Gets a uniqued integer attribute associated with a type");
+    c.def_property_readonly(
+        "value",
+        [](PyIntegerAttribute &self) {
+          return mlirIntegerAttrGetValueInt(self.attr);
+        },
+        "Returns the value of the integer attribute");
+  }
+};
+
+/// Bool attribute subclass - BoolAttr.
+class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
+  static constexpr const char *pyClassName = "BoolAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](PyMlirContext &context, bool value) {
+          MlirAttribute attr = mlirBoolAttrGet(context.get(), value);
+          return PyBoolAttribute(context.getRef(), attr);
+        },
+        py::arg("context"), py::arg("value"), "Gets a uniqued bool attribute");
+    c.def_property_readonly(
+        "value",
+        [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self.attr); },
+        "Returns the value of the bool attribute");
+  }
+};
+
 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
@@ -1630,6 +1730,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           "The underlying generic attribute of the NamedAttribute binding");
 
   // Standard attribute bindings.
+  PyFloatAttribute::bind(m);
+  PyIntegerAttribute::bind(m);
+  PyBoolAttribute::bind(m);
   PyStringAttribute::bind(m);
 
   // Mapping of Type.

diff  --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/StandardAttributes.cpp
index 77d5fcb8b33c..1277d2b041ac 100644
--- a/mlir/lib/CAPI/IR/StandardAttributes.cpp
+++ b/mlir/lib/CAPI/IR/StandardAttributes.cpp
@@ -102,6 +102,11 @@ MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
   return wrap(FloatAttr::get(unwrap(type), value));
 }
 
+MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value,
+                                            MlirLocation loc) {
+  return wrap(FloatAttr::getChecked(unwrap(type), value, unwrap(loc)));
+}
+
 double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
   return unwrap(attr).cast<FloatAttr>().getValueAsDouble();
 }

diff  --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index a2fd50056bf0..dfdc81909a9a 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -92,6 +92,63 @@ def testStandardAttrCasts():
 run(testStandardAttrCasts)
 
 
+# CHECK-LABEL: TEST: testFloatAttr
+def testFloatAttr():
+  ctx = mlir.ir.Context()
+  fattr = mlir.ir.FloatAttr(ctx.parse_attr("42.0 : f32"))
+  # CHECK: fattr value: 42.0
+  print("fattr value:", fattr.value)
+
+  # Test factory methods.
+  loc = ctx.get_unknown_location()
+  # CHECK: default_get: 4.200000e+01 : f32
+  print("default_get:", mlir.ir.FloatAttr.get(
+      mlir.ir.F32Type(ctx), 42.0, loc))
+  # CHECK: f32_get: 4.200000e+01 : f32
+  print("f32_get:", mlir.ir.FloatAttr.get_f32(ctx, 42.0))
+  # CHECK: f64_get: 4.200000e+01 : f64
+  print("f64_get:", mlir.ir.FloatAttr.get_f64(ctx, 42.0))
+  try:
+    fattr_invalid = mlir.ir.FloatAttr.get(
+        mlir.ir.IntegerType.get_signless(ctx, 32), 42, loc)
+  except ValueError as e:
+    # CHECK: invalid 'Type(i32)' and expected floating point type.
+    print(e)
+  else:
+    print("Exception not produced")
+
+run(testFloatAttr)
+
+
+# CHECK-LABEL: TEST: testIntegerAttr
+def testIntegerAttr():
+  ctx = mlir.ir.Context()
+  iattr = mlir.ir.IntegerAttr(ctx.parse_attr("42"))
+  # CHECK: iattr value: 42
+  print("iattr value:", iattr.value)
+
+  # Test factory methods.
+  # CHECK: default_get: 42 : i32
+  print("default_get:", mlir.ir.IntegerAttr.get(
+      mlir.ir.IntegerType.get_signless(ctx, 32), 42))
+
+run(testIntegerAttr)
+
+
+# CHECK-LABEL: TEST: testBoolAttr
+def testBoolAttr():
+  ctx = mlir.ir.Context()
+  battr = mlir.ir.BoolAttr(ctx.parse_attr("true"))
+  # CHECK: iattr value: 1
+  print("iattr value:", battr.value)
+
+  # Test factory methods.
+  # CHECK: default_get: true
+  print("default_get:", mlir.ir.BoolAttr.get(ctx, True))
+
+run(testBoolAttr)
+
+
 # CHECK-LABEL: TEST: testStringAttr
 def testStringAttr():
   ctx = mlir.ir.Context()


        


More information about the Mlir-commits mailing list