[Mlir-commits] [mlir] [mlir] Expose `AffineExpr.shift_dims/shift_symbols` through C and Python bindings (PR #131521)
Ivan Butygin
llvmlistbot at llvm.org
Sun Mar 16 07:30:26 PDT 2025
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/131521
None
>From 663f5086aa39641bb41fcac353377f8e56b654dd Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 16 Mar 2025 15:26:43 +0100
Subject: [PATCH] [mlir] Expose `AffineExpr.shift_dims/shift_symbols` through C
and Python bindings
---
mlir/include/mlir-c/AffineExpr.h | 12 ++++++++++++
mlir/lib/Bindings/Python/IRAffine.cpp | 19 +++++++++++++++++++
mlir/lib/CAPI/IR/AffineExpr.cpp | 12 ++++++++++++
mlir/test/python/ir/affine_expr.py | 11 +++++++++++
4 files changed, 54 insertions(+)
diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h
index 14e951ddee9ad..ab768eb2ec870 100644
--- a/mlir/include/mlir-c/AffineExpr.h
+++ b/mlir/include/mlir-c/AffineExpr.h
@@ -92,6 +92,18 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr,
MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose(
MlirAffineExpr affineExpr, struct MlirAffineMap affineMap);
+/// Returns `affineExpr` with dims[offset ... numDims) replaced by
+/// dims[offset + shift ... numDims + shift); other dims are left unchanged.
+MLIR_CAPI_EXPORTED MlirAffineExpr
+mlirAffineExprShiftDims(MlirAffineExpr affineExpr, uint32_t numDims,
+                        uint32_t shift, uint32_t offset);
+
+/// Returns `affineExpr` with symbols[offset ... numSymbols) replaced by
+/// symbols[offset + shift ... numSymbols + shift); other symbols unchanged.
+MLIR_CAPI_EXPORTED MlirAffineExpr
+mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols,
+                           uint32_t shift, uint32_t offset);
+
//===----------------------------------------------------------------------===//
// Affine Dimension Expression.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index a2df824f59a53..3c95d29c4bcca 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -580,6 +580,25 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
return PyAffineExpr(self.getContext(),
mlirAffineExprCompose(self, other));
})
+      // shift_dims/shift_symbols: renumber positions [offset, num) by +shift.
+      .def(
+          "shift_dims",
+          [](PyAffineExpr &self, uint32_t numDims, uint32_t shift,
+             uint32_t offset) {
+            return PyAffineExpr(
+                self.getContext(),
+                mlirAffineExprShiftDims(self, numDims, shift, offset));
+          },
+          nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset") = 0)
+      .def(
+          "shift_symbols",
+          [](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift,
+             uint32_t offset) {
+            return PyAffineExpr(
+                self.getContext(),
+                mlirAffineExprShiftSymbols(self, numSymbols, shift, offset));
+          },
+          nb::arg("num_symbols"), nb::arg("shift"), nb::arg("offset") = 0)
.def_static(
"get_add", &PyAffineAddExpr::get,
"Gets an affine expression containing a sum of two expressions.")
diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp
index 6e3328b65cb08..bc3dcd4174736 100644
--- a/mlir/lib/CAPI/IR/AffineExpr.cpp
+++ b/mlir/lib/CAPI/IR/AffineExpr.cpp
@@ -61,6 +61,18 @@ MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr,
return wrap(unwrap(affineExpr).compose(unwrap(affineMap)));
}
+MlirAffineExpr mlirAffineExprShiftDims(MlirAffineExpr affineExpr,
+                                       uint32_t numDims, uint32_t shift,
+                                       uint32_t offset) {
+  return wrap(unwrap(affineExpr).shiftDims(numDims, shift, offset));
+}
+
+MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr,
+                                          uint32_t numSymbols, uint32_t shift,
+                                          uint32_t offset) {
+  return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset));
+}
+
//===----------------------------------------------------------------------===//
// Affine Dimension Expression.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/affine_expr.py b/mlir/test/python/ir/affine_expr.py
index c7861c1acfe12..2f64aff143420 100644
--- a/mlir/test/python/ir/affine_expr.py
+++ b/mlir/test/python/ir/affine_expr.py
@@ -405,3 +405,14 @@ def testHash():
dictionary[s1] = 1
assert d0 in dictionary
assert s1 in dictionary
+
+
+# CHECK-LABEL: TEST: testAffineExprShift
+@run
+def testAffineExprShift():
+    with Context():
+        dims = [AffineExpr.get_dim(i) for i in range(4)]
+        syms = [AffineExpr.get_symbol(i) for i in range(4)]
+
+        assert (dims[2] + dims[3]) == (dims[0] + dims[1]).shift_dims(2, 2)
+        assert (syms[0] + syms[3]) == (syms[0] + syms[1]).shift_symbols(2, 2, 1)
More information about the Mlir-commits
mailing list