[Mlir-commits] [mlir] c1de154 - [MLIR][Python] Add a `.get` method to `IntegerType` (#174406)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 5 19:57:43 PST 2026
Author: Twice
Date: 2026-01-06T11:57:39+08:00
New Revision: c1de1543bf3bd9c2840d7d79f6a37a22b8a3b99a
URL: https://github.com/llvm/llvm-project/commit/c1de1543bf3bd9c2840d7d79f6a37a22b8a3b99a
DIFF: https://github.com/llvm/llvm-project/commit/c1de1543bf3bd9c2840d7d79f6a37a22b8a3b99a.diff
LOG: [MLIR][Python] Add a `.get` method to `IntegerType` (#174406)
In this PR, I added a `.get` class method to `IntegerType`. The main
goal is to ensure that types from upstream dialects have a `.get` method
(at least for the builtin dialect). The benefit is that, for any MLIR
type, we can construct an instance directly without special-casing types
that don’t provide a `.get` method.
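The design mirrors `mlir::IntegerType` in C++: it takes `width` and
`signedness` parameters, and `signedness` defaults to `IntegerType.SIGNLESS`.
This PR also exposes an `IntegerType.Signedness` enum (`SIGNLESS`, `SIGNED`,
`UNSIGNED`) and a read-only `signedness` property on `IntegerType`.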
It is related to #169045.
Added:
Modified:
mlir/include/mlir/Bindings/Python/IRTypes.h
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/test/python/ir/builtin_types.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index a0901fefec5ce..b305dec188f5a 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -26,6 +26,8 @@ class MLIR_PYTHON_API_EXPORTED PyIntegerType
static constexpr const char *pyClassName = "IntegerType";
using PyConcreteType::PyConcreteType;
+ enum Signedness { Signless, Signed, Unsigned };
+
static void bindDerived(ClassTy &c);
};
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index ca56fc3248ed8..af07dd53c2b54 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -35,6 +35,12 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
}
void PyIntegerType::bindDerived(ClassTy &c) {
+ nb::enum_<Signedness>(c, "Signedness")
+ .value("SIGNLESS", Signless)
+ .value("SIGNED", Signed)
+ .value("UNSIGNED", Unsigned)
+ .export_values();
+
c.def_static(
"get_signless",
[](unsigned width, DefaultingPyMlirContext context) {
@@ -59,6 +65,33 @@ void PyIntegerType::bindDerived(ClassTy &c) {
},
nb::arg("width"), nb::arg("context") = nb::none(),
"Create an unsigned integer type");
+ c.def_static(
+ "get",
+ [](unsigned width, Signedness signedness,
+ DefaultingPyMlirContext context) {
+ MlirType t;
+ switch (signedness) {
+ case Signless:
+ t = mlirIntegerTypeGet(context->get(), width);
+ break;
+ case Signed:
+ t = mlirIntegerTypeSignedGet(context->get(), width);
+ break;
+ case Unsigned:
+ t = mlirIntegerTypeUnsignedGet(context->get(), width);
+ break;
+ }
+ return PyIntegerType(context->getRef(), t);
+ },
+ nb::arg("width"), nb::arg("signedness") = Signless,
+ nb::arg("context") = nb::none(), "Create an integer type");
+ c.def_prop_ro("signedness", [](PyIntegerType &self) -> Signedness {
+ if (mlirIntegerTypeIsSignless(self))
+ return Signless;
+ if (mlirIntegerTypeIsSigned(self))
+ return Signed;
+ return Unsigned;
+ });
c.def_prop_ro(
"width",
[](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index aa1665a4020fc..fc39ff1aa5332 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -227,6 +227,20 @@ def testIntegerType():
print("signed:", IntegerType.get_signed(8))
# CHECK: unsigned: ui64
print("unsigned:", IntegerType.get_unsigned(64))
+ # CHECK: signless: i8
+ print("signless:", IntegerType.get(8))
+ # CHECK: signless: i16
+ print("signless:", IntegerType.get(16, IntegerType.SIGNLESS))
+ # CHECK: signed: si8
+ print("signed:", IntegerType.get(8, IntegerType.SIGNED))
+ # CHECK: unsigned: ui64
+ print("unsigned:", IntegerType.get(64, IntegerType.UNSIGNED))
+ # CHECK: SIGNLESS
+ print(IntegerType.get(8).signedness)
+ # CHECK: SIGNED
+ print(IntegerType.get(8, IntegerType.SIGNED).signedness)
+ # CHECK: UNSIGNED
+ print(IntegerType.get(8, IntegerType.UNSIGNED).signedness)
# CHECK-LABEL: TEST: testIndexType
More information about the Mlir-commits
mailing list