[Mlir-commits] [mlir] [mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims (PR #134935)

Bangtian Liu llvmlistbot at llvm.org
Tue Apr 8 15:03:47 PDT 2025


https://github.com/bangtianliu created https://github.com/llvm/llvm-project/pull/134935

This PR mainly exposes Python bindings (via new C API functions) for `linalg::isaContractionOpInterface` and `linalg::inferContractionDims`.

>From 29f4f994731d19bb0e6e40139ca28e4b0328dbf5 Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Tue, 8 Apr 2025 15:03:11 -0700
Subject: [PATCH] [mlir][python] expose python bindings for
 linalg::isaContractionOpInterface and linalg::inferContractionDims

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
 mlir/include/mlir-c/Dialect/Linalg.h       | 12 +++++
 mlir/lib/Bindings/Python/DialectLinalg.cpp | 62 +++++++++++++++++++++-
 mlir/lib/CAPI/Dialect/Linalg.cpp           | 32 +++++++++++
 mlir/test/python/dialects/linalg/ops.py    | 33 ++++++++++++
 4 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 0ab201e158033..c57d193e62d25 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,6 +22,27 @@ extern "C" {
 MLIR_CAPI_EXPORTED void
 mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);

+/// Returns true if `op` implements the Linalg `LinalgOp` interface and
+/// `linalg::isaContractionOpInterface` accepts it; returns false otherwise.
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
+
+/// Loop dimensions of a Linalg contraction, grouped by role. Each field is a
+/// DenseI32ArrayAttr holding (possibly empty) loop-dimension indices.
+struct MlirLinalgContractionDimensions {
+  MlirAttribute batch;
+  MlirAttribute m;
+  MlirAttribute n;
+  MlirAttribute k;
+};
+typedef struct MlirLinalgContractionDimensions MlirLinalgContractionDimensions;
+
+/// Infers the batch/m/n/k loop dimensions of `op` with
+/// `linalg::inferContractionDims`. If `op` is not a Linalg op or inference
+/// fails, every field of the result is a null attribute (check with
+/// `mlirAttributeIsNull`).
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op);
+
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);

 #ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 548df4ee100aa..0dbd4f18b7212 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -6,12 +6,56 @@
 //
 //===----------------------------------------------------------------------===//

+#include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Dialect/Linalg.h"
 #include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+#include <optional>
+#include <vector>

 namespace nb = nanobind;
+using namespace mlir::python::nanobind_adaptors;
+
+/// Python-facing value type returned by `infer_contraction_dimensions`.
+/// Instances are created only from a successful inference result, so each
+/// field of `value` holds a DenseI32ArrayAttr.
+struct PyContractionDimensions {
+  MlirLinalgContractionDimensions value{};
+
+  PyContractionDimensions() = default;
+  explicit PyContractionDimensions(const MlirLinalgContractionDimensions &v)
+      : value(v) {}
+};
+
+/// Wraps mlirLinalgInferContractionDimensions, mapping the C API's all-null
+/// failure result to std::nullopt (None in Python).
+static std::optional<PyContractionDimensions>
+mlirLinalgInferContractionDimensionsBinding(MlirOperation op) {
+  MlirLinalgContractionDimensions dims =
+      mlirLinalgInferContractionDimensions(op);
+
+  // On success the C API sets all four fields; on failure it leaves all four
+  // null, so the result is "empty" exactly when every field is null.
+  if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
+      mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k))
+    return std::nullopt;
+  return PyContractionDimensions(dims);
+}
+
+/// Copies the elements of a DenseI32ArrayAttr into a vector (a Python list
+/// once converted by nanobind). A null attribute yields an empty vector.
+static std::vector<int32_t> convertDenseI32AttrToList(MlirAttribute attr) {
+  std::vector<int32_t> result;
+  if (mlirAttributeIsNull(attr))
+    return result;
+  intptr_t size = mlirDenseArrayGetNumElements(attr);
+  result.reserve(static_cast<size_t>(size));
+  for (intptr_t i = 0; i < size; ++i)
+    result.push_back(mlirDenseI32ArrayGetElement(attr, i));
+  return result;
+}

 static void populateDialectLinalgSubmodule(nb::module_ m) {
   m.def(
@@ -20,6 +64,33 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
       nb::arg("op"),
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
+
+  m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
+        "Checks if the given operation is a Linalg contraction operation.",
+        nb::arg("op"));
+
+  nb::class_<PyContractionDimensions>(m, "ContractionDimensions")
+      .def_prop_ro("batch",
+                   [](const PyContractionDimensions &self) {
+                     return convertDenseI32AttrToList(self.value.batch);
+                   })
+      .def_prop_ro("m",
+                   [](const PyContractionDimensions &self) {
+                     return convertDenseI32AttrToList(self.value.m);
+                   })
+      .def_prop_ro("n",
+                   [](const PyContractionDimensions &self) {
+                     return convertDenseI32AttrToList(self.value.n);
+                   })
+      .def_prop_ro("k", [](const PyContractionDimensions &self) {
+        return convertDenseI32AttrToList(self.value.k);
+      });
+
+  m.def("infer_contraction_dimensions",
+        &mlirLinalgInferContractionDimensionsBinding,
+        "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
+        "op. Returns None if `op` is not a Linalg op or inference fails.",
+        nb::arg("op"));
 }

 NB_MODULE(_mlirDialectsLinalg, m) {
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 2fb5bc651de07..7e053d1188f24 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,4 +41,43 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
   fun(b, *body, op->getAttrs());
 }

+bool mlirLinalgIsContractionOp(MlirOperation op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
+  // Only Linalg ops can be contractions; bail out early for anything else.
+  if (!linalgOp)
+    return false;
+  return linalg::isaContractionOpInterface(linalgOp);
+}
+
+MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op) {
+  // Zero-initialized: every field is a null attribute, which is how failure is
+  // reported to callers.
+  MlirLinalgContractionDimensions result{};
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
+  if (!linalgOp)
+    return result;
+
+  FailureOr<linalg::ContractionDimensions> maybeDims =
+      linalg::inferContractionDims(linalgOp);
+  if (failed(maybeDims))
+    return result;
+
+  const linalg::ContractionDimensions &contractionDims = *maybeDims;
+  MLIRContext *ctx = linalgOp.getContext();
+
+  // Loop indices are stored as `unsigned`; expose them as a DenseI32ArrayAttr.
+  auto toAttr = [&](ArrayRef<unsigned> vals) -> MlirAttribute {
+    SmallVector<int32_t> intVals(vals.begin(), vals.end());
+    return wrap(DenseI32ArrayAttr::get(ctx, intVals));
+  };
+
+  result.batch = toAttr(contractionDims.batch);
+  result.m = toAttr(contractionDims.m);
+  result.n = toAttr(contractionDims.n);
+  result.k = toAttr(contractionDims.k);
+
+  return result;
+}
+
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index e32a911b24b11..3129a9bbe1d8a 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -606,3 +606,39 @@ def tensor_pack(src, dst):
         # CHECK:           return %[[VAL_4]] : tensor<128x128xf32>
         # CHECK:         }
         print(module)
+
+
+@run
+def test_infer_contraction_dimensions():
+    with Context(), Location.unknown():
+        module = Module.parse(r"""
+            module {
+                func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>)
+                    -> tensor<4x4xf32> {
+                    %cst = arith.constant 0.0 : f32
+                    %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
+                    %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>)
+                        outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
+                    return %1 : tensor<4x4xf32>
+                }
+            }
+        """)
+        func_op = module.body.operations[0]
+        body_block = func_op.regions[0].blocks[0]
+        # Ops in the body: [0] arith.constant, [1] linalg.fill,
+        # [2] linalg.matmul, [3] func.return.
+        fill_op = body_block.operations[1]
+        matmul_op = body_block.operations[2]
+
+        assert not linalg.isa_contraction_op(fill_op)
+        assert linalg.isa_contraction_op(matmul_op)
+
+        dims = linalg.infer_contraction_dimensions(fill_op)
+        assert dims is None
+        dims = linalg.infer_contraction_dimensions(matmul_op)
+        assert dims is not None
+
+        assert dims.batch == [], f"Expected batch=[], got {dims.batch}"
+        assert dims.m == [0], f"Expected m=[0], got {dims.m}"
+        assert dims.n == [1], f"Expected n=[1], got {dims.n}"
+        assert dims.k == [2], f"Expected k=[2], got {dims.k}"



More information about the Mlir-commits mailing list