[Mlir-commits] [mlir] [mlir][py] ability to downcast AffineExpr after #172892 (PR #174808)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 7 09:25:01 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

<details>
<summary>Changes</summary>

AffineExpr is a separate hierarchy of LLVM-style nested classes that doesn't rely on TypeID and is not extensible. We need the ability to downcast the Python equivalents of those classes to a specific subclass; this ability was seemingly lost in PR #172892. Bring it back by having an explicit cast. We don't really need user-defined type casters here since AffineExpr is entirely closed and not typed, unlike values.

---
Full diff: https://github.com/llvm/llvm-project/pull/174808.diff


3 Files Affected:

- (modified) mlir/include/mlir/Bindings/Python/IRCore.h (+2) 
- (modified) mlir/lib/Bindings/Python/IRAffine.cpp (+26-4) 
- (modified) mlir/test/python/ir/affine_expr.py (+11) 


``````````diff
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 729cbb6df3267..1f19683dfe80d 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -1214,6 +1214,8 @@ class MLIR_PYTHON_API_EXPORTED PyAffineExpr : public BaseContextObject {
   PyAffineExpr ceilDiv(const PyAffineExpr &other) const;
   PyAffineExpr mod(const PyAffineExpr &other) const;
 
+  nanobind::typed<nanobind::object, PyAffineExpr> maybeDownCast();
+
 private:
   MlirAffineExpr affineExpr;
 };
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index b3d15ee59566b..10c7e014f7309 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -193,14 +193,14 @@ class PyAffineBinaryExpr : public PyConcreteAffineExpr<PyAffineBinaryExpr> {
   static constexpr const char *pyClassName = "AffineBinaryExpr";
   using PyConcreteAffineExpr::PyConcreteAffineExpr;
 
-  PyAffineExpr lhs() {
+  nb::typed<nb::object, PyAffineExpr> lhs() {
     MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get());
-    return PyAffineExpr(getContext(), lhsExpr);
+    return PyAffineExpr(getContext(), lhsExpr).maybeDownCast();
   }
 
-  PyAffineExpr rhs() {
+  nb::typed<nb::object, PyAffineExpr> rhs() {
     MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get());
-    return PyAffineExpr(getContext(), rhsExpr);
+    return PyAffineExpr(getContext(), rhsExpr).maybeDownCast();
   }
 
   static void bindDerived(ClassTy &c) {
@@ -375,6 +375,27 @@ PyAffineExpr PyAffineExpr::createFromCapsule(const nb::object &capsule) {
       rawAffineExpr);
 }
 
+nb::typed<nb::object, PyAffineExpr> PyAffineExpr::maybeDownCast() {
+  MlirAffineExpr expr = get();
+  if (mlirAffineExprIsAConstant(expr))
+    return nb::cast(PyAffineConstantExpr(getContext(), expr));
+  if (mlirAffineExprIsADim(expr))
+    return nb::cast(PyAffineDimExpr(getContext(), expr));
+  if (mlirAffineExprIsASymbol(expr))
+    return nb::cast(PyAffineSymbolExpr(getContext(), expr));
+  if (mlirAffineExprIsAAdd(expr))
+    return nb::cast(PyAffineAddExpr(getContext(), expr));
+  if (mlirAffineExprIsAMul(expr))
+    return nb::cast(PyAffineMulExpr(getContext(), expr));
+  if (mlirAffineExprIsAMod(expr))
+    return nb::cast(PyAffineModExpr(getContext(), expr));
+  if (mlirAffineExprIsAFloorDiv(expr))
+    return nb::cast(PyAffineFloorDivExpr(getContext(), expr));
+  if (mlirAffineExprIsACeilDiv(expr))
+    return nb::cast(PyAffineCeilDivExpr(getContext(), expr));
+  return nb::cast(*this);
+}
+
 //------------------------------------------------------------------------------
 // PyAffineMap and utilities.
 //------------------------------------------------------------------------------
@@ -593,6 +614,7 @@ void populateIRAffine(nb::module_ &m) {
              return PyAffineExpr(self.getContext(),
                                  mlirAffineExprCompose(self, other));
            })
+      .def_prop_ro("maybe_downcast", &PyAffineExpr::maybeDownCast)
       .def(
           "shift_dims",
           [](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py
index 82c509efdf8fc..cf1d5258333d7 100644
--- a/mlir/test/python/ir/affine_expr.py
+++ b/mlir/test/python/ir/affine_expr.py
@@ -424,3 +424,14 @@ def testAffineExprSimplify():
     with Context() as ctx:
         expr = AffineExpr.get_dim(0) + AffineExpr.get_symbol(0)
         assert expr == AffineExpr.simplify_affine_expr(expr, 1, 1)
+
+
+# CHECK-LABEL: TEST: testAffineExprDowncast
+@run
+def testAffineExprDowncast():
+    with Context() as ctx:
+        expr = AffineExpr.get_dim(0) + AffineExpr.get_symbol(0)
+        assert isinstance(expr.lhs, AffineDimExpr)
+        assert isinstance(expr.rhs, AffineSymbolExpr)
+        assert expr.lhs.position == 0
+        assert expr.rhs.position == 0

``````````

</details>


https://github.com/llvm/llvm-project/pull/174808


More information about the Mlir-commits mailing list