[Mlir-commits] [mlir] [MLIR][Python] Added a base class to all builtin floating point types (PR #81720)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 14 02:12:47 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Sergei Lebedev (superbobry)
<details>
<summary>Changes</summary>
This allows users to:
* check whether a given ir.Type is a floating-point type via isinstance() or issubclass();
* get the bitwidth of a floating-point type.
See the motivation and discussion at https://discourse.llvm.org/t/add-floattype-to-mlir-python-bindings/76959.
---
Full diff: https://github.com/llvm/llvm-project/pull/81720.diff
5 Files Affected:
- (modified) mlir/include/mlir-c/BuiltinTypes.h (+6)
- (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+28-10)
- (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+8)
- (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+19-9)
- (modified) mlir/test/python/ir/builtin_types.py (+34-1)
``````````diff
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 881b6dad2b84d7..99c5e3f46b04c1 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -73,6 +73,12 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
// Floating-point types.
//===----------------------------------------------------------------------===//
+/// Checks whether the given type is a floating-point type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);
+
+/// Returns the bitwidth of a floating-point type.
+MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);
+
/// Returns the typeID of an Float8E5M2 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 820992de659068..e1e4eb999b3aa8 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -109,8 +109,22 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
}
};
+class PyFloatType : public PyConcreteType<PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
+ static constexpr const char *pyClassName = "FloatType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_property_readonly(
+ "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
+ "Returns the width of the floating-point type");
+ }
+};
+
/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
+class PyFloat8E4M3FNType
+ : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -130,7 +144,7 @@ class PyFloat8E4M3FNType : public PyConcreteType<PyFloat8E4M3FNType> {
};
/// Floating Point Type subclass - Float8M5E2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
+class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -150,7 +164,8 @@ class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type> {
};
/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
+class PyFloat8E4M3FNUZType
+ : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -170,7 +185,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType<PyFloat8E4M3FNUZType> {
};
/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
+class PyFloat8E4M3B11FNUZType
+ : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -190,7 +206,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType<PyFloat8E4M3B11FNUZType> {
};
/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
+class PyFloat8E5M2FNUZType
+ : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -210,7 +227,7 @@ class PyFloat8E5M2FNUZType : public PyConcreteType<PyFloat8E5M2FNUZType> {
};
/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type> {
+class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -230,7 +247,7 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
};
/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type> {
+class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -250,7 +267,7 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
};
/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type> {
+class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -270,7 +287,7 @@ class PyTF32Type : public PyConcreteType<PyTF32Type> {
};
/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type> {
+class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -290,7 +307,7 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
};
/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type> {
+class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -819,6 +836,7 @@ class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
void mlir::python::populateIRTypes(py::module &m) {
PyIntegerType::bind(m);
+ PyFloatType::bind(m);
PyIndexType::bind(m);
PyFloat8E4M3FNType::bind(m);
PyFloat8E5M2Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 18c9414c5d0f34..e1a5d82587cf9e 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -78,6 +78,14 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
// Floating-point types.
//===----------------------------------------------------------------------===//
+bool mlirTypeIsAFloat(MlirType type) {
+ return llvm::isa<FloatType>(unwrap(type));
+}
+
+unsigned mlirFloatTypeGetWidth(MlirType type) {
+ return llvm::cast<FloatType>(unwrap(type)).getWidth();
+}
+
MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
return wrap(Float8E5M2Type::getTypeID());
}
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 344abb64a57d23..586bf7f8e93fba 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -1442,7 +1442,17 @@ class DictAttr(Attribute):
@property
def typeid(self) -> TypeID: ...
-class F16Type(Type):
+class FloatType(Type):
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def width(self) -> int:
+ """
+ Returns the width of the floating-point type.
+ """
+
+class F16Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> F16Type:
@@ -1455,7 +1465,7 @@ class F16Type(Type):
@property
def typeid(self) -> TypeID: ...
-class F32Type(Type):
+class F32Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> F32Type:
@@ -1468,7 +1478,7 @@ class F32Type(Type):
@property
def typeid(self) -> TypeID: ...
-class F64Type(Type):
+class F64Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> F64Type:
@@ -1502,7 +1512,7 @@ class FlatSymbolRefAttr(Attribute):
Returns the value of the FlatSymbolRef attribute as a string
"""
-class Float8E4M3B11FNUZType(Type):
+class Float8E4M3B11FNUZType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType:
@@ -1515,7 +1525,7 @@ class Float8E4M3B11FNUZType(Type):
@property
def typeid(self) -> TypeID: ...
-class Float8E4M3FNType(Type):
+class Float8E4M3FNType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E4M3FNType:
@@ -1528,7 +1538,7 @@ class Float8E4M3FNType(Type):
@property
def typeid(self) -> TypeID: ...
-class Float8E4M3FNUZType(Type):
+class Float8E4M3FNUZType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E4M3FNUZType:
@@ -1541,7 +1551,7 @@ class Float8E4M3FNUZType(Type):
@property
def typeid(self) -> TypeID: ...
-class Float8E5M2FNUZType(Type):
+class Float8E5M2FNUZType(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E5M2FNUZType:
@@ -1554,7 +1564,7 @@ class Float8E5M2FNUZType(Type):
@property
def typeid(self) -> TypeID: ...
-class Float8E5M2Type(Type):
+class Float8E5M2Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> Float8E5M2Type:
@@ -1601,7 +1611,7 @@ class FloatAttr(Attribute):
Returns the value of the float attribute
"""
-class FloatTF32Type(Type):
+class FloatTF32Type(FloatType):
static_typeid: ClassVar[TypeID] # value = <mlir._mlir_libs._TypeID object>
@staticmethod
def get(context: Optional[Context] = None) -> FloatTF32Type:
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 30a5054ada91ac..4eea1a9c372ef7 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -100,8 +100,38 @@ def testTypeIsInstance():
print(IntegerType.isinstance(t1))
# CHECK: False
print(F32Type.isinstance(t1))
+ # CHECK: False
+ print(FloatType.isinstance(t1))
# CHECK: True
print(F32Type.isinstance(t2))
+ # CHECK: True
+ print(FloatType.isinstance(t2))
+
+
+# CHECK-LABEL: TEST: testFloatTypeSubclasses
+@run
+def testFloatTypeSubclasses():
+ ctx = Context()
+ # CHECK: True
+ print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f8E5M2", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f8E4M3FNUZ", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f8E4M3B11FNUZ", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f8E5M2FNUZ", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f16", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("bf16", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f32", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("tf32", ctx), FloatType))
+ # CHECK: True
+ print(isinstance(Type.parse("f64", ctx), FloatType))
# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@@ -218,7 +248,10 @@ def testFloatType():
# CHECK: float: f32
print("float:", F32Type.get())
# CHECK: float: f64
- print("float:", F64Type.get())
+ f64 = F64Type.get()
+ print("float:", f64)
+ # CHECK: f64 width: 64
+ print("f64 width:", f64.width)
# CHECK-LABEL: TEST: testNoneType
``````````
</details>
https://github.com/llvm/llvm-project/pull/81720
More information about the Mlir-commits mailing list