[Mlir-commits] [mlir] 4a8e172 - [mlir][Linalg] Prevent vectorization of generic Conv with dynamic dims (#185415)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 9 09:47:22 PDT 2026


Author: Abhishek Varma
Date: 2026-03-09T16:47:16Z
New Revision: 4a8e172677c94d021260d47dc93fc16a5069f122

URL: https://github.com/llvm/llvm-project/commit/4a8e172677c94d021260d47dc93fc16a5069f122
DIFF: https://github.com/llvm/llvm-project/commit/4a8e172677c94d021260d47dc93fc16a5069f122.diff

LOG: [mlir][Linalg] Prevent vectorization of generic Conv with dynamic dims (#185415)

-- Use `isaConvolutionOpInterface` instead of `isa<ConvolutionOpInterface>`,
as it accommodates both named and generic convolution ops.
-- https://github.com/llvm/llvm-project/pull/176339 missed one such update,
in `vectorizeDynamicLinalgOpPrecondition`; the omission was exposed in a
downstream project.
-- This commit makes that update, so generic convolution ops are also routed
to `vectorizeDynamicConvOpPrecondition`.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization/unsupported.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 7ac759f635f87..0477815f329bf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2110,7 +2110,7 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
 static LogicalResult
 vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                      bool flatten1DDepthwiseConv) {
-  if (isa<ConvolutionOpInterface>(op.getOperation()))
+  if (isaConvolutionOpInterface(op))
     return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
 
   if (hasReductionIterator(op))

diff  --git a/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir
index 271d6169609e9..59d51e432b743 100644
--- a/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/unsupported.mlir
@@ -112,6 +112,36 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+/// Dynamic spatial dims are not supported for non-depthwise convs. Named conv
+/// ops already have a test for this; the following lit test checks that the
+/// same restriction applies to linalg.generic convolution ops as well.
+func.func @generic_conv1d_ncw_fcw_dyn_spatial(%input: tensor<1x8x?xf16>, %filter: tensor<4x8x1xf16>, %output: tensor<1x4x?xf16>) -> tensor<1x4x?xf16> {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>,
+                      affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>,
+                      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+    ins(%input, %filter : tensor<1x8x?xf16>, tensor<4x8x1xf16>)
+    outs(%output : tensor<1x4x?xf16>) {
+  ^bb0(%in: f16, %filt: f16, %out: f16):
+    %mul = arith.mulf %in, %filt : f16
+    %add = arith.addf %out, %mul : f16
+    linalg.yield %add : f16
+  } -> tensor<1x4x?xf16>
+  return %0 : tensor<1x4x?xf16>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @conv2d_nchw_fchw(%input: tensor<1x5x8x8xf32>, %filter:  tensor<4x5x3x3xf32>, %output: tensor<1x4x6x6xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x5x8x8xf32>, tensor<4x5x3x3xf32>) outs(%output : tensor<1x4x6x6xf32>) -> tensor<1x4x6x6xf32>


        


More information about the Mlir-commits mailing list