[Mlir-commits] [mlir] [mlir][CAPI][python] expose the python bindings for linalg::isaConvolutionOpInterface and linalg::inferConvolutionDims (PR #135253)

Bangtian Liu llvmlistbot at llvm.org
Thu Apr 10 14:02:38 PDT 2025


https://github.com/bangtianliu updated https://github.com/llvm/llvm-project/pull/135253

>From bed4492ef18774921cdf8aed52d8da7a7cacf96a Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Thu, 10 Apr 2025 13:49:51 -0700
Subject: [PATCH 1/2] [tuner] expose the python bindings for
 linalg::isaConvolutionOpInterface and linalg::inferConvolutionDims

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
 mlir/include/mlir-c/Dialect/Linalg.h       | 18 +++++-
 mlir/lib/Bindings/Python/DialectLinalg.cpp | 63 ++++++++++++++++++++-
 mlir/lib/CAPI/Dialect/Linalg.cpp           | 47 ++++++++++++++-
 mlir/test/python/dialects/linalg/utils.py  | 66 +++++++++++++++++++++-
 4 files changed, 190 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index c57d193e62d25..8715739473f6c 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,7 +22,7 @@ extern "C" {
 MLIR_CAPI_EXPORTED void
 mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
 
-MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op);
+MLIR_CAPI_EXPORTED bool mlirLinalgIsaContractionOp(MlirOperation op);
 
 struct MlirLinalgContractionDimensions {
   MlirAttribute batch;
@@ -34,6 +34,22 @@ struct MlirLinalgContractionDimensions {
 MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
 mlirLinalgInferContractionDimensions(MlirOperation op);
 
+MLIR_CAPI_EXPORTED bool mlirLinalgIsaConvolutionOp(MlirOperation op);
+
+struct MlirLinalgConvolutionDimensions {
+  MlirAttribute batch;
+  MlirAttribute outputImage;
+  MlirAttribute outputChannel;
+  MlirAttribute filterLoop;
+  MlirAttribute inputChannel;
+  MlirAttribute depth;
+  MlirAttribute strides;
+  MlirAttribute dilations;
+};
+
+MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
+mlirLinalgInferConvolutionDimensions(MlirOperation op);
+
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
 
 #ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 978ea8664b6b9..d98bfd9f2d979 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -28,6 +28,26 @@ InferContractionDimensions(MlirOperation op) {
   return dims;
 }
 
+static std::optional<MlirLinalgConvolutionDimensions>
+InferConvolutionDimensions(MlirOperation op) {
+  MlirLinalgConvolutionDimensions dims =
+      mlirLinalgInferConvolutionDimensions(op);
+
+  // Detect "empty" result. This occurs when `op` is not a convolution op,
+  // or when `linalg::inferConvolutionDims` fails.
+  if (mlirAttributeIsNull(dims.batch) &&
+      mlirAttributeIsNull(dims.outputImage) &&
+      mlirAttributeIsNull(dims.outputChannel) &&
+      mlirAttributeIsNull(dims.filterLoop) &&
+      mlirAttributeIsNull(dims.inputChannel) &&
+      mlirAttributeIsNull(dims.depth) && mlirAttributeIsNull(dims.strides) &&
+      mlirAttributeIsNull(dims.dilations)) {
+    return std::nullopt;
+  }
+
+  return dims;
+}
+
 static void populateDialectLinalgSubmodule(nb::module_ m) {
   m.def(
       "fill_builtin_region",
@@ -36,7 +56,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
 
-  m.def("isa_contraction_op", &mlirLinalgIsContractionOp,
+  m.def("isa_contraction_op", &mlirLinalgIsaContractionOp,
         "Checks if the given operation is a Linalg contraction operation.",
         nb::arg("op"));
 
@@ -59,6 +79,47 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
         "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction "
         "op.",
         nb::arg("op"));
+
+  m.def("isa_convolution_op", &mlirLinalgIsaConvolutionOp,
+        "Checks if the given operation is a Linalg convolution operation.",
+        nb::arg("op"));
+
+  nb::class_<MlirLinalgConvolutionDimensions>(m, "ConvolutionDimensions")
+      .def_prop_ro("batch",
+                   [](const MlirLinalgConvolutionDimensions &self) {
+                     return self.batch;
+                   })
+      .def_prop_ro("output_image",
+                   [](const MlirLinalgConvolutionDimensions &self) {
+                     return self.outputImage;
+                   })
+      .def_prop_ro("output_channel",
+                   [](const MlirLinalgConvolutionDimensions &self) {
+                     return self.outputChannel;
+                   })
+      .def_prop_ro("filter_loop",
+                   [](const MlirLinalgConvolutionDimensions &self) {
+                     return self.filterLoop;
+                   })
+      .def_prop_ro("input_channel",
+                   [](const MlirLinalgConvolutionDimensions &self) {
+                     return self.inputChannel;
+                   })
+      .def_prop_ro("depth",
+                   [](const MlirLinalgConvolutionDimensions &self) {
+                     return self.depth;
+                   })
+      .def_prop_ro("strides",
+                   [](const MlirLinalgConvolutionDimensions &self) {
+                     return self.strides;
+                   })
+      .def_prop_ro("dilations",
+                   [](const MlirLinalgConvolutionDimensions &self) {
+                     return self.dilations;
+                   });
+
+  m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
+        "Infers convolution dimensions for a Linalg convolution op.", nb::arg("op"));
 }
 
 NB_MODULE(_mlirDialectsLinalg, m) {
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 362b89bdef6c9..737d7e6e68641 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,7 +41,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
   fun(b, *body, op->getAttrs());
 }
 
-MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) {
+MLIR_CAPI_EXPORTED bool mlirLinalgIsaContractionOp(MlirOperation op) {
   auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
   // isaContractionOpInterface handles null linalgOp internally.
   return linalg::isaContractionOpInterface(linalgOp);
@@ -75,4 +75,49 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
   return result;
 }
 
+MLIR_CAPI_EXPORTED bool mlirLinalgIsaConvolutionOp(MlirOperation op) {
+  auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+  if (!linalgOp)
+    return false;
+
+  return linalg::isaConvolutionOpInterface(linalgOp);
+}
+
+MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
+mlirLinalgInferConvolutionDimensions(MlirOperation op) {
+  MlirLinalgConvolutionDimensions result{};
+  auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+  if (!linalgOp)
+    return result;
+
+  FailureOr<linalg::ConvolutionDimensions> maybeDims =
+      linalg::inferConvolutionDims(linalgOp);
+  if (failed(maybeDims))
+    return result;
+
+  linalg::ConvolutionDimensions dims = *maybeDims;
+  MLIRContext *ctx = linalgOp.getContext();
+
+  auto toI32Attr =
+      [&ctx](const SmallVector<unsigned, 2> &vals) -> MlirAttribute {
+    return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of<int32_t>(vals)));
+  };
+
+  auto toI64Attr =
+      [&ctx](const SmallVector<int64_t, 2> &vals) -> MlirAttribute {
+    return wrap(DenseI64ArrayAttr::get(ctx, vals));
+  };
+
+  result.batch = toI32Attr(dims.batch);
+  result.outputImage = toI32Attr(dims.outputImage);
+  result.outputChannel = toI32Attr(dims.outputChannel);
+  result.filterLoop = toI32Attr(dims.filterLoop);
+  result.inputChannel = toI32Attr(dims.inputChannel);
+  result.depth = toI32Attr(dims.depth);
+  result.strides = toI64Attr(dims.strides);
+  result.dilations = toI64Attr(dims.dilations);
+
+  return result;
+}
+
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py
index a48aa90fa5836..98157b0e443cf 100644
--- a/mlir/test/python/dialects/linalg/utils.py
+++ b/mlir/test/python/dialects/linalg/utils.py
@@ -52,7 +52,7 @@ def contraction_fn(a, b, c):
                 dims = linalg.infer_contraction_dimensions(contraction_op)
                 assert dims is not None
 
-                # Expect m=[0], n=[1], k=[2] as per standard matmul
+                # Expect m=[0], n=[1], k=[2] as per standard matmul.
                 assert list(dims.m) == [0], f"Expected m=[0], got {list(dims.m)}"
                 assert list(dims.n) == [1], f"Expected n=[1], got {list(dims.n)}"
                 assert list(dims.k) == [2], f"Expected k=[2], got {list(dims.k)}"
@@ -95,3 +95,67 @@ def dynamic_contraction_fn(a, b, c):
                 assert (
                     list(dims.batch) == []
                 ), f"Expected batch=[], got {list(dims.batch)}"
+
+
+@run
+def test_infer_convolution_dimensions_from_ops():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+
+        with InsertionPoint(module.body):
+            # === Static shapes ===
+            batch, h, w, c_in, kh, kw, c_out = 1, 8, 8, 4, 3, 3, 16
+            input_type = RankedTensorType.get((batch, h, w, c_in), f32)
+            filter_type = RankedTensorType.get((kh, kw, c_in, c_out), f32)
+            output_type = RankedTensorType.get(
+                (batch, h - kh + 1, w - kw + 1, c_out), f32
+            )
+
+            @func.FuncOp.from_py_func(input_type, filter_type, output_type)
+            def conv_fn(input, filter, output):
+                zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+                filled = linalg.fill(zero, outs=[output])
+                fill_op = filled.owner
+
+                assert not linalg.isa_convolution_op(fill_op)
+                assert linalg.infer_convolution_dimensions(fill_op) is None
+
+                result = linalg.conv_2d_nhwc_hwcf(input, filter, outs=[filled])
+                conv_op = result.owner
+
+                assert linalg.isa_convolution_op(conv_op)
+                dims = linalg.infer_convolution_dimensions(conv_op)
+                assert dims is not None
+                assert list(dims.batch) == [0]
+                assert list(dims.output_image) == [1, 2]
+                assert list(dims.output_channel) == [3]
+                assert list(dims.filter_loop) == [4, 5]
+                assert list(dims.input_channel) == [6]
+                assert list(dims.depth) == []
+                assert list(dims.strides) == [1, 1]
+                assert list(dims.dilations) == [1, 1]
+
+            # === Dynamic shapes ===
+            dyn = ShapedType.get_dynamic_size()
+            dyn_input_type = RankedTensorType.get((batch, dyn, dyn, c_in), f32)
+            dyn_output_type = RankedTensorType.get((batch, dyn, dyn, c_out), f32)
+
+            @func.FuncOp.from_py_func(dyn_input_type, filter_type, dyn_output_type)
+            def dyn_conv_fn(input, filter, output):
+                zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+                filled = linalg.fill(zero, outs=[output])
+                result = linalg.conv_2d_nhwc_hwcf(input, filter, outs=[filled])
+                conv_op = result.owner
+
+                assert linalg.isa_convolution_op(conv_op)
+                dims = linalg.infer_convolution_dimensions(conv_op)
+                assert dims is not None
+                assert list(dims.batch) == [0]
+                assert list(dims.output_image) == [1, 2]
+                assert list(dims.output_channel) == [3]
+                assert list(dims.filter_loop) == [4, 5]
+                assert list(dims.input_channel) == [6]
+                assert list(dims.depth) == []
+                assert list(dims.strides) == [1, 1]
+                assert list(dims.dilations) == [1, 1]

>From d1647e4888d254a37f2eafa0e51299748b907224 Mon Sep 17 00:00:00 2001
From: Bangtian Liu <liubangtian at gmail.com>
Date: Thu, 10 Apr 2025 14:07:17 -0700
Subject: [PATCH 2/2] fix the casing of the exported C API function names (Isa -> IsA)

Signed-off-by: Bangtian Liu <liubangtian at gmail.com>
---
 mlir/include/mlir-c/Dialect/Linalg.h       | 4 ++--
 mlir/lib/Bindings/Python/DialectLinalg.cpp | 4 ++--
 mlir/lib/CAPI/Dialect/Linalg.cpp           | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 8715739473f6c..838c280903e2e 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -22,7 +22,7 @@ extern "C" {
 MLIR_CAPI_EXPORTED void
 mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp);
 
-MLIR_CAPI_EXPORTED bool mlirLinalgIsaContractionOp(MlirOperation op);
+MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op);
 
 struct MlirLinalgContractionDimensions {
   MlirAttribute batch;
@@ -34,7 +34,7 @@ struct MlirLinalgContractionDimensions {
 MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions
 mlirLinalgInferContractionDimensions(MlirOperation op);
 
-MLIR_CAPI_EXPORTED bool mlirLinalgIsaConvolutionOp(MlirOperation op);
+MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op);
 
 struct MlirLinalgConvolutionDimensions {
   MlirAttribute batch;
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index d98bfd9f2d979..ce1102a3b3498 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -56,7 +56,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
       "Fill the region for `op`, which is assumed to be a builtin named Linalg "
       "op.");
 
-  m.def("isa_contraction_op", &mlirLinalgIsaContractionOp,
+  m.def("isa_contraction_op", &mlirLinalgIsAContractionOp,
         "Checks if the given operation is a Linalg contraction operation.",
         nb::arg("op"));
 
@@ -80,7 +80,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
         "op.",
         nb::arg("op"));
 
-  m.def("isa_convolution_op", &mlirLinalgIsaConvolutionOp,
+  m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp,
         "Checks if the given operation is a Linalg convolution operation.",
         nb::arg("op"));
 
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 737d7e6e68641..7c456102a2c0c 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -41,7 +41,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
   fun(b, *body, op->getAttrs());
 }
 
-MLIR_CAPI_EXPORTED bool mlirLinalgIsaContractionOp(MlirOperation op) {
+MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) {
   auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
   // isaContractionOpInterface handles null linalgOp internally.
   return linalg::isaContractionOpInterface(linalgOp);
@@ -75,7 +75,7 @@ mlirLinalgInferContractionDimensions(MlirOperation op) {
   return result;
 }
 
-MLIR_CAPI_EXPORTED bool mlirLinalgIsaConvolutionOp(MlirOperation op) {
+MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) {
   auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
   if (!linalgOp)
     return false;



More information about the Mlir-commits mailing list