[Mlir-commits] [mlir] [mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims (PR #134935)
Bangtian Liu
llvmlistbot at llvm.org
Tue Apr 8 15:03:47 PDT 2025
https://github.com/bangtianliu created https://github.com/llvm/llvm-project/pull/134935
This PR mainly exposes Python bindings for `linalg::isaContractionOpInterface` and `linalg::inferContractionDims`.
From 29f4f994731d19bb0e6e40139ca28e4b0328dbf5 Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Tue, 8 Apr 2025 15:03:11 -0700
Subject: [PATCH] [mlir][python] expose python bindings for
linalg::isaContractionOpInterface and linalg::inferContractionDims
Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
mlir/include/mlir-c/Dialect/Linalg.h | 12 +++++
mlir/lib/Bindings/Python/DialectLinalg.cpp | 62 +++++++++++++++++++++-
mlir/lib/CAPI/Dialect/Linalg.cpp | 32 +++++++++++
mlir/test/python/dialects/linalg/ops.py | 33 ++++++++++++
4 files changed, 138 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 0ab201e158033..c57d193e62d25 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,6 +22,18 @@ extern "C" {
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
+/// Returns true if `op` is a Linalg operation that satisfies
+/// `linalg::isaContractionOpInterface`, false otherwise.
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
+
+/// Contraction dimensions (batch/m/n/k) of a Linalg op as computed by
+/// `linalg::inferContractionDims`. Each member holds a DenseI32ArrayAttr of
+/// dimension indices. All members are null when inference fails.
+struct MlirLinalgContractionDimensions {
+  MlirAttribute batch;
+  MlirAttribute m;
+  MlirAttribute n;
+  MlirAttribute k;
+};
+// The typedef lets C code use the bare type name, as C++ callers already do.
+typedef struct MlirLinalgContractionDimensions MlirLinalgContractionDimensions;
+
+/// Infers the contraction dimensions of `op`. If `op` is not a Linalg
+/// operation, or inference fails, every member of the result is null.
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op);
+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
#ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 548df4ee100aa..0dbd4f18b7212 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -6,12 +6,45 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
+using namespace mlir::python::nanobind_adaptors;
+
+/// Python-facing wrapper around the C API contraction-dimension result.
+/// NOTE(review): holds raw MlirAttribute handles without keeping the owning
+/// context alive; presumably only valid while that context lives — confirm.
+struct PyContractionDimensions {
+  // Value-initialized so a default-constructed wrapper holds null attributes
+  // rather than indeterminate handles.
+  MlirLinalgContractionDimensions value{};
+
+  PyContractionDimensions() = default;
+  explicit PyContractionDimensions(const MlirLinalgContractionDimensions &v)
+      : value(v) {}
+};
+
+/// Returns the inferred contraction dimensions of `op`, or std::nullopt
+/// (None in Python) when the C API reports no result.
+static std::optional<PyContractionDimensions>
+mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
+  const MlirLinalgContractionDimensions inferred =
+      mlirLinalgInferContractionDimensions(op);
+
+  // The C API signals failure by leaving every field null; any non-null
+  // field means a result was produced.
+  const bool hasAnyField = !mlirAttributeIsNull(inferred.batch) ||
+                           !mlirAttributeIsNull(inferred.m) ||
+                           !mlirAttributeIsNull(inferred.n) ||
+                           !mlirAttributeIsNull(inferred.k);
+  if (hasAnyField)
+    return PyContractionDimensions{inferred};
+  return std::nullopt;
+}
+
+/// Copies the elements of a DenseI32ArrayAttr into a std::vector, which the
+/// bindings return to Python as a list. A null attribute yields an empty list
+/// instead of being passed to the C API accessors.
+static std::vector<int32_t> convertDenseI32AttrToList(MlirAttribute attr) {
+  if (mlirAttributeIsNull(attr))
+    return {};
+  const intptr_t numElements = mlirDenseArrayGetNumElements(attr);
+  std::vector<int32_t> result;
+  result.reserve(static_cast<size_t>(numElements));
+  for (intptr_t i = 0; i < numElements; ++i)
+    result.push_back(mlirDenseI32ArrayGetElement(attr, i));
+  return result;
+}
static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
@@ -20,6 +53,33 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
nb::arg("op"),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");
+
+ // linalg.isa_contraction_op(op) -> bool, forwarded directly to the C API
+ // function mlirLinalgIsContractionOp.
+ m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
+ "Checks if the given operation is a Linalg contraction operation.",
+ nb::arg("op"));
+
+ // Read-only Python class exposing each dimension group (batch/m/n/k) as a
+ // list of ints. Each property converts the stored DenseI32ArrayAttr anew
+ // on every access.
+ nb::class_<PyContractionDimensions>(m, "ContractionDimensions")
+ .def_prop_ro("batch",
+ [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.batch);
+ })
+ .def_prop_ro("m",
+ [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.m);
+ })
+ .def_prop_ro("n",
+ [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.n);
+ })
+ .def_prop_ro("k", [](const PyContractionDimensions &self) {
+ return convertDenseI32AttrToList(self.value.k);
+ });
+
+ // linalg.infer_contraction_dimensions(op) returns a ContractionDimensions,
+ // or None when mlirLinalgInferContractionDimensionsBinding yields nullopt.
+ m.def("infer_contraction_dimensions",
+ &mlirLinalgInferContractionDimensionsBinding,
+ "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
+ "op.",
+ nb::arg("op"));
}
NB_MODULE(_mlirDialectsLinalg, m) {
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 2fb5bc651de07..7e053d1188f24 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,4 +41,36 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
fun(b, *body, op->getAttrs());
}
+/// Returns true if `op` is a Linalg operation matching the contraction op
+/// interface; returns false for any operation that is not a LinalgOp.
+bool mlirLinalgIsContractionOp(MlirOperation op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
+  // Reject non-Linalg ops here instead of forwarding a null LinalgOp.
+  if (!linalgOp)
+    return false;
+  return linalg::isaContractionOpInterface(linalgOp);
+}
+
+/// Infers the batch/m/n/k contraction dimensions of `op` and returns each
+/// group as a DenseI32ArrayAttr. All fields of the result are null if `op`
+/// is not a LinalgOp or if `linalg::inferContractionDims` fails.
+MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op) {
+  MlirLinalgContractionDimensions result{};
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
+  if (!linalgOp)
+    return result;
+
+  auto maybeDims = linalg::inferContractionDims(linalgOp);
+  if (failed(maybeDims))
+    return result;
+
+  // Bind by reference: the dimension lists are only read below, so there is
+  // no need to copy them out of the FailureOr.
+  const linalg::ContractionDimensions &contractionDims = *maybeDims;
+  MLIRContext *ctx = linalgOp.getContext();
+
+  // Converts a list of dimension indices into a wrapped DenseI32ArrayAttr.
+  auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
+    SmallVector<int32_t> intVals(vals.begin(), vals.end());
+    return wrap(DenseI32ArrayAttr::get(ctx, intVals));
+  };
+
+  result.batch = toAttr(contractionDims.batch);
+  result.m = toAttr(contractionDims.m);
+  result.n = toAttr(contractionDims.n);
+  result.k = toAttr(contractionDims.k);
+
+  return result;
+}
+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index e32a911b24b11..3129a9bbe1d8a 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -606,3 +606,36 @@ def tensor_pack(src, dst):
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
# CHECK: }
print(module)
+
+
+ at run
+def test_infer_contraction_dimensions():
+ with Context(), Location.unknown():
+ module = ir.Module.parse(r"""
+ module {
+ func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>)
+ -> tensor<4x4xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
+ %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>)
+ outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
+ return %1 : tensor<4x4xf32>
+ }
+ }
+ """)
+ func_op = module.body.operations[0]
+ body_block = func_op.regions[0].blocks[0]
+ fill_op = body_block.operations[1]
+ matmul_op = body_block.operations[2]
+
+ assert not linalg.isa_contraction_op(fill_op)
+ assert linalg.isa_contraction_op(matmul_op)
+
+ dims = linalg.infer_contraction_dimensions(fill_op)
+ assert dims is None
+ dims = linalg.infer_contraction_dimensions(matmul_op)
+ assert dims
+
+ assert dims.m == [0], f"Expected m=[0], got {dims.m}"
+ assert dims.n == [1], f"Expected n=[1], got {dims.n}"
+ assert dims.k == [2], f"Expected k=[2], got {dims.k}"
More information about the Mlir-commits
mailing list