[Mlir-commits] [mlir] c359f76 - [mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims (#134935)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 9 17:01:42 PDT 2025
Author: Bangtian Liu
Date: 2025-04-09T20:01:38-04:00
New Revision: c359f7625f4d5bacbd88c9c9d26943b7a7e45a3e
URL: https://github.com/llvm/llvm-project/commit/c359f7625f4d5bacbd88c9c9d26943b7a7e45a3e
DIFF: https://github.com/llvm/llvm-project/commit/c359f7625f4d5bacbd88c9c9d26943b7a7e45a3e.diff
LOG: [mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims (#134935)
This PR is mainly about exposing the Python bindings for
`linalg::isaContractionOpInterface` and `linalg::inferContractionDims`.
---------
Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
Added:
mlir/test/python/dialects/linalg/utils.py
Modified:
mlir/include/mlir-c/Dialect/Linalg.h
mlir/lib/Bindings/Python/DialectLinalg.cpp
mlir/lib/CAPI/Dialect/Linalg.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 0ab201e158033..c57d193e62d25 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,6 +22,18 @@ extern "C" {
MLIR_CAPI_EXPORTED void
mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
+
+struct MlirLinalgContractionDimensions {
+ MlirAttribute batch;
+ MlirAttribute m;
+ MlirAttribute n;
+ MlirAttribute k;
+};
+
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op);
+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
#ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 548df4ee100aa..978ea8664b6b9 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -8,10 +8,25 @@
#include "mlir-c/Dialect/Linalg.h"
#include "mlir-c/IR.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
+using namespace mlir::python::nanobind_adaptors;
+
+/// Python-facing wrapper around `mlirLinalgInferContractionDimensions`.
+/// Translates the C API's failure signal (every attribute field null) into
+/// `std::nullopt`, which nanobind exposes to Python as `None`.
+static std::optional<MlirLinalgContractionDimensions>
+InferContractionDimensions(MlirOperation op) {
+  const MlirLinalgContractionDimensions inferred =
+      mlirLinalgInferContractionDimensions(op);
+
+  // Failure (op is not a contraction, or `linalg::inferContractionDims`
+  // failed) leaves all four fields null; any populated field means success.
+  const bool anyPresent = !mlirAttributeIsNull(inferred.batch) ||
+                          !mlirAttributeIsNull(inferred.m) ||
+                          !mlirAttributeIsNull(inferred.n) ||
+                          !mlirAttributeIsNull(inferred.k);
+  if (!anyPresent)
+    return std::nullopt;
+  return inferred;
+}
static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def(
@@ -20,6 +35,30 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
nb::arg("op"),
"Fill the region for `op`, which is assumed to be a builtin named Linalg "
"op.");
+
+ // `linalg.isa_contraction_op(op)`: thin binding over the C API predicate.
+ m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
+ "Checks if the given operation is a Linalg contraction operation.",
+ nb::arg("op"));
+
+ // Read-only Python view of the C struct. Each property hands back the raw
+ // MlirAttribute field; the test iterates these as integer lists, so they are
+ // presumably cast to Python attribute objects by the nanobind adaptors
+ // (NOTE(review): confirm the caster in NanobindAdaptors.h).
+ nb::class_<MlirLinalgContractionDimensions>(m, "ContractionDimensions")
+ .def_prop_ro("batch",
+ [](const MlirLinalgContractionDimensions &self) {
+ return self.batch;
+ })
+ .def_prop_ro(
+ "m",
+ [](const MlirLinalgContractionDimensions &self) { return self.m; })
+ .def_prop_ro(
+ "n",
+ [](const MlirLinalgContractionDimensions &self) { return self.n; })
+ .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) {
+ return self.k;
+ });
+
+ // Returns a ContractionDimensions instance, or None when the helper
+ // InferContractionDimensions sees an all-null result from the C API.
+ m.def("infer_contraction_dimensions", &InferContractionDimensions,
+ "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
+ "op.",
+ nb::arg("op"));
}
NB_MODULE(_mlirDialectsLinalg, m) {
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 2fb5bc651de07..362b89bdef6c9 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,4 +41,38 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
fun(b, *body, op->getAttrs());
}
+/// Returns true if `op` is a Linalg op that `linalg::isaContractionOpInterface`
+/// accepts as a contraction. A null handle or a non-Linalg op yields false.
+MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
+  // `dyn_cast_if_present` tolerates a null Operation* (plain `dyn_cast`
+  // asserts on null), so a null handle from C is simply "not a contraction".
+  auto linalgOp = llvm::dyn_cast_if_present<linalg::LinalgOp>(unwrap(op));
+  if (!linalgOp)
+    return false;
+  return linalg::isaContractionOpInterface(linalgOp);
+}
+
+/// Infers the batch/m/n/k contraction dimensions of `op`.
+///
+/// On success every field is a DenseI32ArrayAttr listing iteration-space
+/// dimension indices (an empty array when a role has no dimensions). If `op`
+/// is null, is not a Linalg op, or `linalg::inferContractionDims` fails, the
+/// returned struct has every field null.
+MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
+mlirLinalgInferContractionDimensions(MlirOperation op) {
+  // Value-initialization leaves every MlirAttribute null: the failure result.
+  MlirLinalgContractionDimensions result{};
+  // `dyn_cast_if_present` avoids the assertion plain `dyn_cast` hits on null.
+  auto linalgOp = llvm::dyn_cast_if_present<linalg::LinalgOp>(unwrap(op));
+  if (!linalgOp)
+    return result;
+
+  FailureOr<linalg::ContractionDimensions> maybeDims =
+      linalg::inferContractionDims(linalgOp);
+  if (failed(maybeDims))
+    return result;
+
+  // Bind by reference; copying would duplicate all four index vectors.
+  const linalg::ContractionDimensions &contractionDims = *maybeDims;
+  MLIRContext *ctx = linalgOp.getContext();
+
+  auto toAttr = [ctx](ArrayRef<unsigned> vals) -> MlirAttribute {
+    return wrap(
+        DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t, 2>(vals)));
+  };
+
+  result.batch = toAttr(contractionDims.batch);
+  result.m = toAttr(contractionDims.m);
+  result.n = toAttr(contractionDims.n);
+  result.k = toAttr(contractionDims.k);
+
+  return result;
+}
+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py
new file mode 100644
index 0000000000000..a48aa90fa5836
--- /dev/null
+++ b/mlir/test/python/dialects/linalg/utils.py
@@ -0,0 +1,97 @@
+# RUN: %PYTHON %s
+
+from mlir.dialects import arith, func, linalg
+from mlir.dialects.linalg.opdsl.lang import *
+from mlir.ir import *
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ return f
+
+
+ at run
+def test_infer_contraction_dimensions_from_ops():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ # === Static shapes ===
+ m, n, k = 4, 4, 4
+ a_type = RankedTensorType.get((m, k), f32)
+ b_type = RankedTensorType.get((k, n), f32)
+ c_type = RankedTensorType.get((m, n), f32)
+
+ @func.FuncOp.from_py_func(a_type, b_type, c_type)
+ def contraction_fn(a, b, c):
+ zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+ filled = linalg.fill(zero, outs=[c])
+ fill_op = filled.owner
+
+ assert not linalg.isa_contraction_op(zero.operation)
+ assert not linalg.isa_contraction_op(fill_op)
+ assert linalg.infer_contraction_dimensions(fill_op) is None
+
+ dim_m = AffineDimExpr.get(0)
+ dim_n = AffineDimExpr.get(1)
+ dim_k = AffineDimExpr.get(2)
+
+ a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+ b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+ c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+ result = linalg.contract(
+ a,
+ b,
+ outs=(filled,),
+ indexing_maps=[a_map, b_map, c_map],
+ )
+ contraction_op = result.owner
+
+ assert linalg.isa_contraction_op(contraction_op)
+ dims = linalg.infer_contraction_dimensions(contraction_op)
+ assert dims is not None
+
+ # Expect m=[0], n=[1], k=[2] as per standard matmul
+ assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
+ assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
+ assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
+ assert (
+ list(dims.batch) == []
+ ), f"Expected batch=[], got {list(dims.batch)}"
+
+ # === Dynamic shape case ===
+ dyn = ShapedType.get_dynamic_size()
+ a_dyn_type = RankedTensorType.get((4, dyn), f32)
+ b_dyn_type = RankedTensorType.get((dyn, 4), f32)
+ c_type = RankedTensorType.get((4, 4), f32)
+
+ @func.FuncOp.from_py_func(a_dyn_type, b_dyn_type, c_type)
+ def dynamic_contraction_fn(a, b, c):
+ zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+ filled = linalg.fill(zero, outs=[c])
+ dim_m = AffineDimExpr.get(0)
+ dim_n = AffineDimExpr.get(1)
+ dim_k = AffineDimExpr.get(2)
+
+ a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+ b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+ c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+
+ result = linalg.contract(
+ a,
+ b,
+ outs=(filled,),
+ indexing_maps=[a_map, b_map, c_map],
+ )
+ contraction_op = result.owner
+
+ assert linalg.isa_contraction_op(contraction_op)
+ dims = linalg.infer_contraction_dimensions(contraction_op)
+ assert dims is not None
+ assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
+ assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
+ assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
+ assert (
+ list(dims.batch) == []
+ ), f"Expected batch=[], got {list(dims.batch)}"
More information about the Mlir-commits
mailing list