[Mlir-commits] [mlir] [mlir][Linalg] Prevent vectorization of generic Conv with dynamic dims (PR #185415)

Abhishek Varma llvmlistbot at llvm.org
Mon Mar 9 06:20:20 PDT 2026


https://github.com/Abhishek-Varma created https://github.com/llvm/llvm-project/pull/185415

-- Use `isaConvolutionOpInterface` instead of `isa<ConvolutionOpInterface>`, since it accommodates both named and generic convolution ops.
-- https://github.com/llvm/llvm-project/pull/176339 missed one such update in `vectorizeDynamicLinalgOpPrecondition`, and the omission was exposed in a downstream project.
-- This commit fixes that omission.

>From 54c288acca5a4c4a5a3ab29250c66a2f021dcf07 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 9 Mar 2026 13:11:39 +0000
Subject: [PATCH] [mlir][Linalg] Prevent vectorization of generic Conv with
 dynamic spatial dims

-- Use isaConvolutionOpInterface instead of isa<ConvolutionOpInterface>,
   since it accommodates both named and generic convolution ops.
-- https://github.com/llvm/llvm-project/pull/176339 missed one such
   update in `vectorizeDynamicLinalgOpPrecondition`, and the omission
   was exposed in a downstream project.
-- This commit fixes that omission.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../Linalg/Transforms/Vectorization.cpp       |  2 +-
 .../Linalg/vectorization/unsupported.mlir     | 27 +++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 7ac759f635f87..0477815f329bf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2110,7 +2110,7 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
 static LogicalResult
 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                      bool flatten1DDepthwiseConv) {
-  if (isa<ConvolutionOpInterface>(op.getOperation()))
+  if (isaConvolutionOpInterface(op))
     return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
 
   if (hasReductionIterator(op))
diff --git a/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir
index 271d6169609e9..5fa64679e595d 100644
--- a/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir
@@ -112,6 +112,33 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @generic_conv1d_ncw_fcw_dyn_spatial(%input: tensor<1x8x?xf16>, %filter: tensor<4x8x1xf16>, %output: tensor<1x4x?xf16>) -> tensor<1x4x?xf16> {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>,
+                      affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+    ins(%input, %filter : tensor<1x8x?xf16>, tensor<4x8x1xf16>)
+    outs(%output : tensor<1x4x?xf16>) {
+  ^bb0(%in: f16, %filt: f16, %out: f16):
+    %mul = arith.mulf %in, %filt : f16
+    %add = arith.addf %out, %mul : f16
+    linalg.yield %add : f16
+  } -> tensor<1x4x?xf16>
+  return %0 : tensor<1x4x?xf16>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @conv2d_nchw_fchw(%input: tensor<1x5x8x8xf32>, %filter:  tensor<4x5x3x3xf32>, %output: tensor<1x4x6x6xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x5x8x8xf32>, tensor<4x5x3x3xf32>) outs(%output : tensor<1x4x6x6xf32>) -> tensor<1x4x6x6xf32>



More information about the Mlir-commits mailing list