[Mlir-commits] [mlir] [MLIR][Python] Added a base class to all builtin floating point types (PR #81720)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 14 02:12:47 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Sergei Lebedev (superbobry)

<details>
<summary>Changes</summary>

This makes it possible to

* check if a given ir.Type is a floating point type via isinstance() or issubclass()
* get the bitwidth of a floating point type

See the motivation and discussion at https://discourse.llvm.org/t/add-floattype-to-mlir-python-bindings/76959.

---
Full diff: https://github.com/llvm/llvm-project/pull/81720.diff


5 Files Affected:

- (modified) mlir/include/mlir-c/BuiltinTypes.h (+6) 
- (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+28-10) 
- (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+8) 
- (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+19-9) 
- (modified) mlir/test/python/ir/builtin_types.py (+34-1) 


``````````diff
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 881b6dad2b84d7..99c5e3f46b04c1 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -73,6 +73,12 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+/// Checks whether the given type is a floating-point type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);
+
+/// Returns the bitwidth of a floating-point type.
+MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);
+
 /// Returns the typeID of an Float8E5M2 type.
 MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);
 
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 820992de659068..e1e4eb999b3aa8 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -109,8 +109,22 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
   }
 };
 
+class PyFloatType : public PyConcreteType<PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
+  static constexpr const char *pyClassName = "FloatType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_property_readonly(
+        "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
+        "Returns the width of the floating-point type");
+  }
+};
+
 /// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
+class PyFloat8E4M3FNType
+    : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -130,7 +144,7 @@ class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
 };
 
 /// Floating Point Type subclass - Float8M5E2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
+class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -150,7 +164,8 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
 };
 
 /// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
+class PyFloat8E4M3FNUZType
+    : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -170,7 +185,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
 };
 
 /// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
+class PyFloat8E4M3B11FNUZType
+    : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -190,7 +206,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
 };
 
 /// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
+class PyFloat8E5M2FNUZType
+    : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -210,7 +227,7 @@ class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
 };
 
 /// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type> {
+class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -230,7 +247,7 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
 };
 
 /// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type> {
+class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -250,7 +267,7 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
 };
 
 /// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type> {
+class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -270,7 +287,7 @@ class PyTF32Type : public PyConcreteType<PyTF32Type> {
 };
 
 /// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type> {
+class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -290,7 +307,7 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
 };
 
 /// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type> {
+class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
 public:
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -819,6 +836,7 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
 
 void mlir::python::populateIRTypes(py::module &m) {
   PyIntegerType::bind(m);
+  PyFloatType::bind(m);
   PyIndexType::bind(m);
   PyFloat8E4M3FNType::bind(m);
   PyFloat8E5M2Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 18c9414c5d0f34..e1a5d82587cf9e 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -78,6 +78,14 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+bool mlirTypeIsAFloat(MlirType type) {
+  return llvm::isa<FloatType>(unwrap(type));
+}
+
+unsigned mlirFloatTypeGetWidth(MlirType type) {
+  return llvm::cast<FloatType>(unwrap(type)).getWidth();
+}
+
 MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
   return wrap(Float8E5M2Type::getTypeID());
 }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 344abb64a57d23..586bf7f8e93fba 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -1442,7 +1442,17 @@ class DictAttr(Attribute):
     @property
     def typeid(self) -> TypeID: ...
 
-class F16Type(Type):
+class FloatType(Type):
+    @staticmethod
+    def isinstance(other: Type) -> bool: ...
+    def __init__(self, cast_from_type: Type) -> None: ...
+    @property
+    def width(self) -> int:
+        """
+        Returns the width of the floating-point type.
+        """
+
+class F16Type(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> F16Type:
@@ -1455,7 +1465,7 @@ class F16Type(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class F32Type(Type):
+class F32Type(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> F32Type:
@@ -1468,7 +1478,7 @@ class F32Type(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class F64Type(Type):
+class F64Type(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> F64Type:
@@ -1502,7 +1512,7 @@ class FlatSymbolRefAttr(Attribute):
         Returns the value of the FlatSymbolRef attribute as a string
         """
 
-class Float8E4M3B11FNUZType(Type):
+class Float8E4M3B11FNUZType(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType:
@@ -1515,7 +1525,7 @@ class Float8E4M3B11FNUZType(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class Float8E4M3FNType(Type):
+class Float8E4M3FNType(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E4M3FNType:
@@ -1528,7 +1538,7 @@ class Float8E4M3FNType(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class Float8E4M3FNUZType(Type):
+class Float8E4M3FNUZType(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E4M3FNUZType:
@@ -1541,7 +1551,7 @@ class Float8E4M3FNUZType(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class Float8E5M2FNUZType(Type):
+class Float8E5M2FNUZType(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E5M2FNUZType:
@@ -1554,7 +1564,7 @@ class Float8E5M2FNUZType(Type):
     @property
     def typeid(self) -> TypeID: ...
 
-class Float8E5M2Type(Type):
+class Float8E5M2Type(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E5M2Type:
@@ -1601,7 +1611,7 @@ class FloatAttr(Attribute):
         Returns the value of the float attribute
         """
 
-class FloatTF32Type(Type):
+class FloatTF32Type(FloatType):
     static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
     @staticmethod
     def get(context: Optional[Context] = None) -> FloatTF32Type:
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 30a5054ada91ac..4eea1a9c372ef7 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -100,8 +100,38 @@ def testTypeIsInstance():
     print(IntegerType.isinstance(t1))
     # CHECK: False
     print(F32Type.isinstance(t1))
+    # CHECK: False
+    print(FloatType.isinstance(t1))
     # CHECK: True
     print(F32Type.isinstance(t2))
+    # CHECK: True
+    print(FloatType.isinstance(t2))
+
+
+# CHECK-LABEL: TEST: testFloatTypeSubclasses
+@run
+def testFloatTypeSubclasses():
+    ctx = Context()
+    # CHECK: True
+    print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f8E5M2", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f8E4M3FNUZ", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f8E4M3B11FNUZ", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f8E5M2FNUZ", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f16", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("bf16", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f32", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("tf32", ctx), FloatType))
+    # CHECK: True
+    print(isinstance(Type.parse("f64", ctx), FloatType))
 
 
 # CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@@ -218,7 +248,10 @@ def testFloatType():
         # CHECK: float: f32
         print("float:", F32Type.get())
         # CHECK: float: f64
-        print("float:", F64Type.get())
+        f64 = F64Type.get()
+        print("float:", f64)
+        # CHECK: f64 width: 64
+        print("f64 width:", f64.width)
 
 
 # CHECK-LABEL: TEST: testNoneType

``````````

</details>


https://github.com/llvm/llvm-project/pull/81720


More information about the Mlir-commits mailing list