[Mlir-commits] [mlir] c1de154 - [MLIR][Python] Add a `.get` method to `IntegerType` (#174406)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 5 19:57:43 PST 2026


Author: Twice
Date: 2026-01-06T11:57:39+08:00
New Revision: c1de1543bf3bd9c2840d7d79f6a37a22b8a3b99a

URL: https://github.com/llvm/llvm-project/commit/c1de1543bf3bd9c2840d7d79f6a37a22b8a3b99a
DIFF: https://github.com/llvm/llvm-project/commit/c1de1543bf3bd9c2840d7d79f6a37a22b8a3b99a.diff

LOG: [MLIR][Python] Add a `.get` method to `IntegerType` (#174406)

In this PR, I added a `.get` static method to `IntegerType`. The main
goal is to ensure that types from upstream dialects have a `.get` method
(at least for the builtin dialect). The benefit is that any MLIR type
can then be constructed directly, without special-casing types that
don’t provide a `.get` method.

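The design mirrors `mlir::IntegerType` in C++: it takes `width` and
`signedness` parameters (plus an optional `context`), and `signedness`
defaults to signless (`IntegerType.SIGNLESS` in Python). The PR also adds
an `IntegerType.Signedness` enum, whose `SIGNLESS`, `SIGNED`, and
`UNSIGNED` values are additionally exported as `IntegerType` attributes,
and a read-only `signedness` property on integer types.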

It is related to #169045.

Added: 
    

Modified: 
    mlir/include/mlir/Bindings/Python/IRTypes.h
    mlir/lib/Bindings/Python/IRTypes.cpp
    mlir/test/python/ir/builtin_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index a0901fefec5ce..b305dec188f5a 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -26,6 +26,8 @@ class MLIR_PYTHON_API_EXPORTED PyIntegerType
   static constexpr const char *pyClassName = "IntegerType";
   using PyConcreteType::PyConcreteType;
 
+  enum Signedness { Signless, Signed, Unsigned };
+
   static void bindDerived(ClassTy &c);
 };
 

diff  --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index ca56fc3248ed8..af07dd53c2b54 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -35,6 +35,12 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
 }
 
 void PyIntegerType::bindDerived(ClassTy &c) {
+  nb::enum_<Signedness>(c, "Signedness")
+      .value("SIGNLESS", Signless)
+      .value("SIGNED", Signed)
+      .value("UNSIGNED", Unsigned)
+      .export_values();
+
   c.def_static(
       "get_signless",
       [](unsigned width, DefaultingPyMlirContext context) {
@@ -59,6 +65,33 @@ void PyIntegerType::bindDerived(ClassTy &c) {
       },
       nb::arg("width"), nb::arg("context") = nb::none(),
       "Create an unsigned integer type");
+  c.def_static(
+      "get",
+      [](unsigned width, Signedness signedness,
+         DefaultingPyMlirContext context) {
+        MlirType t;
+        switch (signedness) {
+        case Signless:
+          t = mlirIntegerTypeGet(context->get(), width);
+          break;
+        case Signed:
+          t = mlirIntegerTypeSignedGet(context->get(), width);
+          break;
+        case Unsigned:
+          t = mlirIntegerTypeUnsignedGet(context->get(), width);
+          break;
+        }
+        return PyIntegerType(context->getRef(), t);
+      },
+      nb::arg("width"), nb::arg("signedness") = Signless,
+      nb::arg("context") = nb::none(), "Create an integer type");
+  c.def_prop_ro("signedness", [](PyIntegerType &self) -> Signedness {
+    if (mlirIntegerTypeIsSignless(self))
+      return Signless;
+    if (mlirIntegerTypeIsSigned(self))
+      return Signed;
+    return Unsigned;
+  });
   c.def_prop_ro(
       "width",
       [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },

diff  --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index aa1665a4020fc..fc39ff1aa5332 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -227,6 +227,20 @@ def testIntegerType():
         print("signed:", IntegerType.get_signed(8))
         # CHECK: unsigned: ui64
         print("unsigned:", IntegerType.get_unsigned(64))
+        # CHECK: signless: i8
+        print("signless:", IntegerType.get(8))
+        # CHECK: signless: i16
+        print("signless:", IntegerType.get(16, IntegerType.SIGNLESS))
+        # CHECK: signed: si8
+        print("signed:", IntegerType.get(8, IntegerType.SIGNED))
+        # CHECK: unsigned: ui64
+        print("unsigned:", IntegerType.get(64, IntegerType.UNSIGNED))
+        # CHECK: SIGNLESS
+        print(IntegerType.get(8).signedness)
+        # CHECK: SIGNED
+        print(IntegerType.get(8, IntegerType.SIGNED).signedness)
+        # CHECK: UNSIGNED
+        print(IntegerType.get(8, IntegerType.UNSIGNED).signedness)
 
 
 # CHECK-LABEL: TEST: testIndexType


        


More information about the Mlir-commits mailing list