[Mlir-commits] [mlir] [mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims (PR #134935)
Bangtian Liu
llvmlistbot@llvm.org
Wed Apr 9 12:57:05 PDT 2025
https://github.com/bangtianliu updated https://github.com/llvm/llvm-project/pull/134935
From d151987cbb149747d9bf0d97f23b17c21e951725 Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian@gmail.com>
Date: Tue, 8 Apr 2025 15:03:11 -0700
Subject: [PATCH 1/2] [mlir][python] expose python bindings for
linalg::isaContractionOpInterface and linalg::inferContractionDims
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
---
mlir/include/mlir-c/Dialect/Linalg.h | 12 +++++
mlir/lib/Bindings/Python/DialectLinalg.cpp | 62 +++++++++++++++++++++-
mlir/lib/CAPI/Dialect/Linalg.cpp | 32 +++++++++++
mlir/test/python/dialects/linalg/ops.py | 35 ++++++++++++
4 files changed, 140 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 0ab201e158033..c57d193e62d25 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,6 +22,18 @@ extern "C" {
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
+
+struct MlirLinalgContractionDimensions {
+ MlirAttribute batch;
+ MlirAttribute m;
+ MlirAttribute n;
+ MlirAttribute k;
+};
+
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op);
+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
#ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 548df4ee100aa..0dbd4f18b7212 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -6,12 +6,45 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
+using namespace mlir::python::nanobind_adaptors;
+
+struct PyContractionDimensions {
+ MlirLinalgContractionDimensions value;
+
+ PyContractionDimensions() = default;
+ PyContractionDimensions(const MlirLinalgContractionDimensions &v)
+ : value(v) {}
+};
+
+static std::optional<PyContractionDimensions>
+mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
+ MlirLinalgContractionDimensions dims =
+ mlirLinalgInferContractionDimensions(op);
+
+ // Detect "empty" result.
+ if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
+ mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
+ return std::nullopt;
+ }
+ return PyContractionDimensions{dims};
+}
+
+static std::vector<int32_t> convertDenseI32AttrToList(MlirAttribute attr) {
+ std::vector<int32_t> result;
+ int64_t size = mlirDenseArrayGetNumElements(attr);
+ result.reserve(size);
+ for (int64_t i = 0; i < size; ++i) {
+ result.push_back(mlirDenseI32ArrayGetElement(attr, i));
+ }
+ return result;
+}
static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
@@ -20,6 +53,33 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
nb::arg("op"),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");
+
+ m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
+ "Checks if the given operation is a Linalg contraction operation.",
+ nb::arg("op"));
+
+ nb::class_<PyContractionDimensions>(m, "ContractionDimensions")
+ .def_prop_ro("batch",
+ [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.batch);
+ })
+ .def_prop_ro("m",
+ [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.m);
+ })
+ .def_prop_ro("n",
+ [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.n);
+ })
+ .def_prop_ro("k", [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.k);
+ });
+
+ m.def("infer_contraction_dimensions",
+ &mlirLinalgInferContractionDimensionsBinding,
+ "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
+ "op.",
+ nb::arg("op"));
}
NB_MODULE(_mlirDialectsLinalg, m) {
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 2fb5bc651de07..7e053d1188f24 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,4 +41,36 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
fun(b, *body, op->getAttrs());
}
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
+ auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+ return linalg::isaContractionOpInterface(linalgOp);
+}
+
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op) {
+ MlirLinalgContractionDimensions result{};
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
+ if (!linalgOp)
+ return result;
+
+ auto maybeDims = linalg::inferContractionDims(linalgOp);
+ if (failed(maybeDims))
+ return result;
+
+ linalg::ContractionDimensions contractionDims = maybeDims.value();
+ MLIRContext *ctx = linalgOp.getContext();
+
+ auto toAttr = [&](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
+ SmallVector<int32_t> intVals(vals.begin(), vals.end());
+ return wrap(DenseI32ArrayAttr::get(ctx, intVals));
+ };
+
+ result.batch = toAttr(contractionDims.batch);
+ result.m = toAttr(contractionDims.m);
+ result.n = toAttr(contractionDims.n);
+ result.k = toAttr(contractionDims.k);
+
+ return result;
+}
+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index e32a911b24b11..2574e5736eb92 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -606,3 +606,38 @@ def tensor_pack(src, dst):
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
print(module)
+
+
+@run
+def test_infer_contraction_dimensions():
+ with Context(), Location.unknown():
+ module = ir.Module.parse(
+ r"""
+ module {
+ func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>)
+ -> tensor<4x4xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
+ %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>)
+ outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+ }
+ }
+ """
+ )
+ func_op = module.body.operations[0]
+ body_block = func_op.regions[0].blocks[0]
+ fill_op = body_block.operations[1]
+ matmul_op = body_block.operations[2]
+
+ assert not linalg.isa_contraction_op(fill_op)
+ assert linalg.isa_contraction_op(matmul_op)
+
+ dims = linalg.infer_contraction_dimensions(fill_op)
+ assert dims is None
+ dims = linalg.infer_contraction_dimensions(matmul_op)
+ assert dims
+
+ assert dims.m == [0], f"Expected m=[0], got {dims.m}"
+ assert dims.n == [1], f"Expected n=[1], got {dims.n}"
+ assert dims.k == [2], f"Expected k=[2], got {dims.k}"
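
For readers skimming the thread, here is a minimal usage sketch of the Python API added by the first patch. It is illustrative only and not part of the patch: it assumes an MLIR Python package built with these bindings and with the standard dialects registered, and with the patch-1 API the ContractionDimensions properties come back as plain Python lists of loop indices.

# Minimal usage sketch (not part of the patch). Assumes an MLIR Python
# package built with these bindings and the core dialects registered.
from mlir import ir
from mlir.dialects import linalg

with ir.Context(), ir.Location.unknown():
    module = ir.Module.parse(
        r"""
        module {
          func.func @matmul(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>,
                            %c: tensor<4x4xf32>) -> tensor<4x4xf32> {
            %0 = linalg.matmul ins(%a, %b : tensor<4x4xf32>, tensor<4x4xf32>)
                               outs(%c : tensor<4x4xf32>) -> tensor<4x4xf32>
            return %0 : tensor<4x4xf32>
          }
        }
        """
    )
    func_op = module.body.operations[0]
    matmul_op = func_op.regions[0].blocks[0].operations[0]

    # The new predicate and dimension query exposed by this patch.
    if linalg.isa_contraction_op(matmul_op):
        dims = linalg.infer_contraction_dimensions(matmul_op)
        # For a plain matmul this is expected to report
        # batch=[], m=[0], n=[1], k=[2], matching the ops.py assertions.
        print(dims.batch, dims.m, dims.n, dims.k)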
From 113cd15c7592e3fd9518192f0a935e4046dd3a69 Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian@gmail.com>
Date: Wed, 9 Apr 2025 13:01:41 -0700
Subject: [PATCH 2/2] address reviewer comments
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
---
mlir/lib/Bindings/Python/DialectLinalg.cpp | 52 ++++--------
mlir/lib/CAPI/Dialect/Linalg.cpp | 12 +--
mlir/test/python/dialects/linalg/ops.py | 35 --------
mlir/test/python/dialects/linalg/utils.py | 97 ++++++++++++++++++++++
4 files changed, 119 insertions(+), 77 deletions(-)
create mode 100644 mlir/test/python/dialects/linalg/utils.py
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 0dbd4f18b7212..e9f9dd1d27b17 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Nanobind.h"
@@ -15,16 +14,8 @@
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
-struct PyContractionDimensions {
- MlirLinalgContractionDimensions value;
-
- PyContractionDimensions() = default;
- PyContractionDimensions(const MlirLinalgContractionDimensions &v)
- : value(v) {}
-};
-
-static std::optional<PyContractionDimensions>
-mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
+static std::optional<MlirLinalgContractionDimensions>
+InferContractionDimensions(MlirOperation op) {
MlirLinalgContractionDimensions dims =
mlirLinalgInferContractionDimensions(op);
@@ -33,17 +24,7 @@ mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
return std::nullopt;
}
- return PyContractionDimensions{dims};
-}
-
-static std::vector<int32_t> convertDenseI32AttrToList(MlirAttribute attr) {
- std::vector<int32_t> result;
- int64_t size = mlirDenseArrayGetNumElements(attr);
- result.reserve(size);
- for (int64_t i = 0; i < size; ++i) {
- result.push_back(mlirDenseI32ArrayGetElement(attr, i));
- }
- return result;
+ return dims;
}
static void populateDialectLinalgSubmodule(nb::module_ m) {
@@ -58,25 +39,22 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
"Checks if the given operation is a Linalg contraction operation.",
nb::arg("op"));
- nb::class_<PyContractionDimensions>(m, "ContractionDimensions")
+ nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
.def_prop_ro("batch",
- [](const PyContractionDimensions &self) {
- return convertDenseI32AttrToList(self.value.batch);
- })
- .def_prop_ro("m",
- [](const PyContractionDimensions &self) {
- return convertDenseI32AttrToList(self.value.m);
- })
- .def_prop_ro("n",
- [](const PyContractionDimensions &self) {
- return convertDenseI32AttrToList(self.value.n);
+ [](const MlirLinalgContractionDimensions &self) {
+ return self.batch;
})
- .def_prop_ro("k", [](const PyContractionDimensions &self) {
- return convertDenseI32AttrToList(self.value.k);
+ .def_prop_ro(
+ "m",
+ [](const MlirLinalgContractionDimensions &self) { return self.m; })
+ .def_prop_ro(
+ "n",
+ [](const MlirLinalgContractionDimensions &self) { return self.n; })
+ .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
+ return self.k;
});
- m.def("infer_contraction_dimensions",
- &mlirLinalgInferContractionDimensionsBinding,
+ m.def("infer_contraction_dimensions", &InferContractionDimensions,
"Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
"op.",
nb::arg("op"));
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 7e053d1188f24..362b89bdef6c9 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -43,6 +43,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+ // isaContractionOpInterface handles null linalgOp internally.
return linalg::isaContractionOpInterface(linalgOp);
}
@@ -53,16 +54,17 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
if (!linalgOp)
return result;
- auto maybeDims = linalg::inferContractionDims(linalgOp);
+ FailureOr<linalg::ContractionDimensions> maybeDims =
+ linalg::inferContractionDims(linalgOp);
if (failed(maybeDims))
return result;
- linalg::ContractionDimensions contractionDims = maybeDims.value();
+ linalg::ContractionDimensions contractionDims = *maybeDims;
MLIRContext *ctx = linalgOp.getContext();
- auto toAttr = [&](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
- SmallVector<int32_t> intVals(vals.begin(), vals.end());
- return wrap(DenseI32ArrayAttr::get(ctx, intVals));
+ auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
+ return wrap(
+ DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
};
result.batch = toAttr(contractionDims.batch);
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index 2574e5736eb92..e32a911b24b11 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -606,38 +606,3 @@ def tensor_pack(src, dst):
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
print(module)
-
-
-@run
-def test_infer_contraction_dimensions():
- with Context(), Location.unknown():
- module = ir.Module.parse(
- r"""
- module {
- func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>)
- -> tensor<4x4xf32> {
- %cst = arith.constant 0.0 : f32
- %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
- %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>)
- outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
- return %1 : tensor<4x4xf32>
- }
- }
- """
- )
- func_op = module.body.operations[0]
- body_block = func_op.regions[0].blocks[0]
- fill_op = body_block.operations[1]
- matmul_op = body_block.operations[2]
-
- assert not linalg.isa_contraction_op(fill_op)
- assert linalg.isa_contraction_op(matmul_op)
-
- dims = linalg.infer_contraction_dimensions(fill_op)
- assert dims is None
- dims = linalg.infer_contraction_dimensions(matmul_op)
- assert dims
-
- assert dims.m == [0], f"Expected m=[0], got {dims.m}"
- assert dims.n == [1], f"Expected n=[1], got {dims.n}"
- assert dims.k == [2], f"Expected k=[2], got {dims.k}"
diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py
new file mode 100644
index 0000000000000..a48aa90fa5836
--- /dev/null
+++ b/mlir/test/python/dialects/linalg/utils.py
@@ -0,0 +1,97 @@
+# RUN: %PYTHON %s
+
+from mlir.dialects import arith, func, linalg
+from mlir.dialects.linalg.opdsl.lang import *
+from mlir.ir import *
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ return f
+
+
+@run
+def test_infer_contraction_dimensions_from_ops():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ # === Static shapes ===
+ m, n, k = 4, 4, 4
+ a_type = RankedTensorType.get((m, k), f32)
+ b_type = RankedTensorType.get((k, n), f32)
+ c_type = RankedTensorType.get((m, n), f32)
+
+ @func.FuncOp.from_py_func(a_type, b_type, c_type)
+ def contraction_fn(a, b, c):
+ zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+ filled = linalg.fill(zero, outs=[c])
+ fill_op = filled.owner
+
+ assert not linalg.isa_contraction_op(zero.operation)
+ assert not linalg.isa_contraction_op(fill_op)
+ assert linalg.infer_contraction_dimensions(fill_op) is None
+
+ dim_m = AffineDimExpr.get(0)
+ dim_n = AffineDimExpr.get(1)
+ dim_k = AffineDimExpr.get(2)
+
+ a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+ b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+ c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+ result = linalg.contract(
+ a,
+ b,
+ outs=(filled,),
+ indexing_maps=[a_map, b_map, c_map],
+ )
+ contraction_op = result.owner
+
+ assert linalg.isa_contraction_op(contraction_op)
+ dims = linalg.infer_contraction_dimensions(contraction_op)
+ assert dims is not None
+
+ # Expect m=[0], n=[1], k=[2] as per standard matmul
+ assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
+ assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
+ assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
+ assert (
+ list(dims.batch) == []
+ ), f"Expected batch=[], got {list(dims.batch)}"
+
+ # === Dynamic shape case ===
+ dyn = ShapedType.get_dynamic_size()
+ a_dyn_type = RankedTensorType.get((4, dyn), f32)
+ b_dyn_type = RankedTensorType.get((dyn, 4), f32)
+ c_type = RankedTensorType.get((4, 4), f32)
+
+ @func.FuncOp.from_py_func(a_dyn_type, b_dyn_type, c_type)
+ def dynamic_contraction_fn(a, b, c):
+ zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+ filled = linalg.fill(zero, outs=[c])
+ dim_m = AffineDimExpr.get(0)
+ dim_n = AffineDimExpr.get(1)
+ dim_k = AffineDimExpr.get(2)
+
+ a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+ b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+ c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+
+ result = linalg.contract(
+ a,
+ b,
+ outs=(filled,),
+ indexing_maps=[a_map, b_map, c_map],
+ )
+ contraction_op = result.owner
+
+ assert linalg.isa_contraction_op(contraction_op)
+ dims = linalg.infer_contraction_dimensions(contraction_op)
+ assert dims is not None
+ assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
+ assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
+ assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
+ assert (
+ list(dims.batch) == []
+ ), f"Expected batch=[], got {list(dims.batch)}"