[Mlir-commits] [mlir] c359f76 - [mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims (#134935)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 9 17:01:42 PDT 2025


Author: Bangtian Liu
Date: 2025-04-09T20:01:38-04:00
New Revision: c359f7625f4d5bacbd88c9c9d26943b7a7e45a3e

URL: https://github.com/llvm/llvm-project/commit/c359f7625f4d5bacbd88c9c9d26943b7a7e45a3e
DIFF: https://github.com/llvm/llvm-project/commit/c359f7625f4d5bacbd88c9c9d26943b7a7e45a3e.diff

LOG: [mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims (#134935)

This PR is mainly about exposing the python bindings for
`linalg::isaContractionOpInterface` and `linalg::inferContractionDims`.

---------

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>

Added: 
    mlir/test/python/dialects/linalg/utils.py

Modified: 
    mlir/include/mlir-c/Dialect/Linalg.h
    mlir/lib/Bindings/Python/DialectLinalg.cpp
    mlir/lib/CAPI/Dialect/Linalg.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 0ab201e158033..c57d193e62d25 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,6 +22,18 @@ extern "C" {
 MLIR_CAPI_EXPORTED void
 mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
 
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
+
+struct MlirLinalgContractionDimensions {
+  MlirAttribute batch;
+  MlirAttribute m;
+  MlirAttribute n;
+  MlirAttribute k;
+};
+
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op);
+
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
 
 #ifdef __cplusplus

diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 548df4ee100aa..978ea8664b6b9 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -8,10 +8,25 @@
 
 #include "mlir-c/Dialect/Linalg.h"
 #include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
+using namespace mlir::python::nanobind_adaptors;
+
+static std::optional<MlirLinalgContractionDimensions>
+InferContractionDimensions(MlirOperation op) {
+  MlirLinalgContractionDimensions dims =
+      mlirLinalgInferContractionDimensions(op);
+
+  // Detect "empty" result. This occurs when `op` is not a contraction op,
+  // or when `linalg::inferContractionDims` fails.
+  if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) &&
+      mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) {
+    return std::nullopt;
+  }
+  return dims;
+}
 
 static void populateDialectLinalgSubmodule(nb::module_ m) {
   m.def(
@@ -20,6 +35,30 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
       nb::arg("op"),
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
+
+  m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
+        "Checks if the given operation is a Linalg contraction operation.",
+        nb::arg("op"));
+
+  nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
+      .def_prop_ro("batch",
+                   [](const MlirLinalgContractionDimensions &self) {
+                     return self.batch;
+                   })
+      .def_prop_ro(
+          "m",
+          [](const MlirLinalgContractionDimensions &self) { return self.m; })
+      .def_prop_ro(
+          "n",
+          [](const MlirLinalgContractionDimensions &self) { return self.n; })
+      .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
+        return self.k;
+      });
+
+  m.def("infer_contraction_dimensions", &InferContractionDimensions,
+        "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
+        "op.",
+        nb::arg("op"));
 }
 
 NB_MODULE(_mlirDialectsLinalg, m) {

diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 2fb5bc651de07..362b89bdef6c9 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,4 +41,38 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
   fun(b, *body, op->getAttrs());
 }
 
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
+  auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+  // isaContractionOpInterface handles null linalgOp internally.
+  return linalg::isaContractionOpInterface(linalgOp);
+}
+
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op) {
+  MlirLinalgContractionDimensions result{};
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(unwrap(op));
+  if (!linalgOp)
+    return result;
+
+  FailureOr<linalg::ContractionDimensions> maybeDims =
+      linalg::inferContractionDims(linalgOp);
+  if (failed(maybeDims))
+    return result;
+
+  linalg::ContractionDimensions contractionDims = *maybeDims;
+  MLIRContext *ctx = linalgOp.getContext();
+
+  auto toAttr = [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
+    return wrap(
+        DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
+  };
+
+  result.batch = toAttr(contractionDims.batch);
+  result.m = toAttr(contractionDims.m);
+  result.n = toAttr(contractionDims.n);
+  result.k = toAttr(contractionDims.k);
+
+  return result;
+}
+
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py
new file mode 100644
index 0000000000000..a48aa90fa5836
--- /dev/null
+++ b/mlir/test/python/dialects/linalg/utils.py
@@ -0,0 +1,97 @@
+# RUN: %PYTHON %s
+
+from mlir.dialects import arith, func, linalg
+from mlir.dialects.linalg.opdsl.lang import *
+from mlir.ir import *
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    f()
+    return f
+
+
+ at run
+def test_infer_contraction_dimensions_from_ops():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            # === Static shapes ===
+            m, n, k = 4, 4, 4
+            a_type = RankedTensorType.get((m, k), f32)
+            b_type = RankedTensorType.get((k, n), f32)
+            c_type = RankedTensorType.get((m, n), f32)
+
+            @func.FuncOp.from_py_func(a_type, b_type, c_type)
+            def contraction_fn(a, b, c):
+                zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+                filled = linalg.fill(zero, outs=[c])
+                fill_op = filled.owner
+
+                assert not linalg.isa_contraction_op(zero.operation)
+                assert not linalg.isa_contraction_op(fill_op)
+                assert linalg.infer_contraction_dimensions(fill_op) is None
+
+                dim_m = AffineDimExpr.get(0)
+                dim_n = AffineDimExpr.get(1)
+                dim_k = AffineDimExpr.get(2)
+
+                a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+                b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+                c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+                result = linalg.contract(
+                    a,
+                    b,
+                    outs=(filled,),
+                    indexing_maps=[a_map, b_map, c_map],
+                )
+                contraction_op = result.owner
+
+                assert linalg.isa_contraction_op(contraction_op)
+                dims = linalg.infer_contraction_dimensions(contraction_op)
+                assert dims is not None
+
+                # Expect m=[0], n=[1], k=[2] as per standard matmul
+                assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
+                assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
+                assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
+                assert (
+                    list(dims.batch) == []
+                ), f"Expected batch=[], got {list(dims.batch)}"
+
+            # === Dynamic shape case ===
+            dyn = ShapedType.get_dynamic_size()
+            a_dyn_type = RankedTensorType.get((4, dyn), f32)
+            b_dyn_type = RankedTensorType.get((dyn, 4), f32)
+            c_type = RankedTensorType.get((4, 4), f32)
+
+            @func.FuncOp.from_py_func(a_dyn_type, b_dyn_type, c_type)
+            def dynamic_contraction_fn(a, b, c):
+                zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+                filled = linalg.fill(zero, outs=[c])
+                dim_m = AffineDimExpr.get(0)
+                dim_n = AffineDimExpr.get(1)
+                dim_k = AffineDimExpr.get(2)
+
+                a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+                b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+                c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+
+                result = linalg.contract(
+                    a,
+                    b,
+                    outs=(filled,),
+                    indexing_maps=[a_map, b_map, c_map],
+                )
+                contraction_op = result.owner
+
+                assert linalg.isa_contraction_op(contraction_op)
+                dims = linalg.infer_contraction_dimensions(contraction_op)
+                assert dims is not None
+                assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
+                assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
+                assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
+                assert (
+                    list(dims.batch) == []
+                ), f"Expected batch=[], got {list(dims.batch)}"


        


More information about the Mlir-commits mailing list