[Mlir-commits] [mlir] [Linalg] Update Vectorization to work with both named as well as generic conv ops (PR #176339)

Abhishek Varma llvmlistbot at llvm.org
Mon Jan 19 22:45:20 PST 2026


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/176339

>From 73d96e93e1fe2a8274f40dc926e67dc5c4c10904 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 2 Jan 2026 09:42:58 +0000
Subject: [PATCH 1/2] [Linalg] Update Vectorization to work with both named as
 well as generic convolution ops

-- This commit updates Vectorization to work with both named and generic
   convolution ops.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 .../Linalg/Transforms/Vectorization.cpp       | 97 ++++++++++++++-----
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       |  4 +-
 .../convolution-with-patterns-flatten.mlir    | 13 ++-
 .../convolution-with-patterns.mlir            | 51 +++++-----
 .../Linalg/vectorization/convolution.mlir     |  9 +-
 5 files changed, 114 insertions(+), 60 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6f73f4f57e50d..bdda2d3fb0fd9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -2071,7 +2072,7 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
     return failure();
   }
 
-  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
+  if (!isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
     LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
     return failure();
   }
@@ -2436,10 +2437,11 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   if (isElementwise(linalgOp))
     return success();
 
-  // TODO: isaConvolutionOpInterface that can also infer from generic
-  // features. But we will still need stride/dilation attributes that will be
-  // annoying to reverse-engineer...
-  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+  // Check for convolution ops - both named ops implementing
+  // ConvolutionOpInterface and generic ops that semantically match convolution
+  // patterns.
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+      isaConvolutionOpInterface(linalgOp))
     return vectorizeConvOpPrecondition(linalgOp);
 
   // TODO: the common vector shape is equal to the static loop sizes only when
@@ -2649,12 +2651,12 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
   // Cond 4: Only the following ops are supported in the
   // presence of scalable vectors
-  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
-                 isa<linalg::BatchMatmulOp>(op) ||
-                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-                 isa<linalg::BatchMmt4DOp>(op) ||
-                 hasReductionIterator(linalgOp));
+  return success(
+      isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+      isa<linalg::BatchMatmulOp>(op) ||
+      isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(linalgOp) ||
+      isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+      isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2745,7 +2747,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
             // TODO: isaConvolutionOpInterface that can also infer from
             // generic features. Will require stride/dilation attributes
             // inference.
-            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+            if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+                isaConvolutionOpInterface(linalgOp)) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
                   rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                   flatten1DDepthwiseConv);
@@ -3491,6 +3494,42 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
   bindShapeDims<0>(shapedType, vals...);
 }
 
+/// Extracts the first stride and dilation of `op` if it matches one of the 1D
+/// conv/pool ops listed below (via matchConvolutionOpOfType) and returns true.
+/// Otherwise returns false and leaves `strideW`/`dilationW` unmodified.
+static bool extract1DConvPoolStrideDilation(LinalgOp op, int &strideW,
+                                            int &dilationW) {
+#define EXTRACT_1D_CONV_POOL_STRIDE_DILATION(ConvOpTy)                         \
+  if (std::optional<DilationsAndStrides> convParams =                          \
+          matchConvolutionOpOfType<ConvOpTy>(op)) {                            \
+    strideW = static_cast<int>(convParams->strides.front());                   \
+    dilationW = static_cast<int>(convParams->dilations.front());               \
+    return true;                                                               \
+  }
+
+  // 1D Convolution ops
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNwcWcfOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNcwFcwOp);
+  // Depthwise 1D Convolution ops
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNcwCwOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcmOp);
+  // 1D Pooling ops (NWC layout)
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcSumOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxUnsignedOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinUnsignedOp);
+  // 1D Pooling ops (NCW layout)
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwSumOp);
+  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwMaxOp);
+
+#undef EXTRACT_1D_CONV_POOL_STRIDE_DILATION
+
+  return false;
+}
+
 namespace {
 /// Generate a vector implementation for either:
 /// ```
@@ -3546,14 +3585,19 @@ struct Conv1DGenerator
     auto maybeKind = getCombinerOpKind(reduceOp);
     reductionKind = maybeKind.value();
 
-    // The ConvolutionOpInterface gives us guarantees of existence for
-    // strides/dilations. However, we do not need to rely on those, we can
-    // simply use them if present, otherwise use the default and let the generic
-    // conv. matcher in the ConvGenerator succeed or fail.
-    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
-    auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
-    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
-    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+    // Try to extract strides/dilations from named 1D conv/pool ops using
+    // matchConvolutionOpOfType. This works for both named ops and generic ops
+    // that match their semantics. For unrecognized generic ops, fall back to
+    // checking attributes directly (which may not exist for generic ops).
+    if (!extract1DConvPoolStrideDilation(linalgOp, strideW, dilationW)) {
+      // Fallback: check for stride/dilation attributes directly.
+      // For generic ops without these attributes, default to 1.
+      auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+      auto dilations =
+          linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+      strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
+      dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+    }
   }
 
   /// Generate a vector implementation for:
@@ -4276,13 +4320,14 @@ static FailureOr<Operation *> vectorizeConvolution(
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
     // vector dims will be inferred from the Ops.
-    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
-            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+    assert((isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op) ||
+            isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op)) &&
            "Not a 1D depthwise conv!");
-    size_t chDimIdx =
-        TypeSwitch<Operation *, size_t>(op)
-            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
-            .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
+    size_t chDimIdx = 0;
+    if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op))
+      chDimIdx = 2;
+    else if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op))
+      chDimIdx = 1;
 
     vecChDimSize = inputVecSizes[chDimIdx];
     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index a1ee6b307caf5..b24326b999755 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -346,9 +346,9 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
     return bodyMatcherForZeroPointOffsets(accOp, mulOp, body);
   }
   BlockArgument lhsBlockArg =
-      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(mulAndOp->getOperand(0));
   BlockArgument rhsBlockArg =
-      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
+      getBlockArgumentWithOptionalCastOps(mulAndOp->getOperand(1));
   BlockArgument outBlockArg =
       getBlockArgumentWithOptionalCastOps(accOp->getOperand(0));
   if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
index c47824a18cf56..3bb077946b257 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
@@ -1,4 +1,7 @@
 // RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+// Test the same patterns on generic convolution ops by first generalizing the
+// named ops. This avoids duplicating lit tests for linalg.generic conv ops.
+// RUN: mlir-opt --split-input-file --linalg-generalize-named-ops --transform-interpreter %s | FileCheck %s
 
 ///----------------------------------------------------------------------------------------
 /// Tests for vectorizing depthwise convolutions (with patterns) with the
@@ -19,7 +22,7 @@ func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x3xi8>,
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -52,7 +55,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK:           %[[SC_ADDI:.*]] = vector.shape_cast %[[ADDI]] : vector<1x24xi8> to vector<1x8x3xi8>
 // CHECK:           vector.transfer_write %[[SC_ADDI]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
 
-//------
+//-----
 
 func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x4xf32>,
                                                                 %filter: memref<2x4xf32>,
@@ -106,7 +109,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -170,7 +173,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -302,7 +305,7 @@ func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
     transform.yield
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
index cea60842f4606..f8781ff5452d9 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
@@ -1,4 +1,7 @@
 // RUN: mlir-opt -transform-interpreter -split-input-file  %s | FileCheck %s
+// Test the same patterns on generic convolution ops by first generalizing the
+// named ops. This avoids duplicating lit tests for linalg.generic conv ops.
+// RUN: mlir-opt --linalg-generalize-named-ops --transform-interpreter --split-input-file %s | FileCheck %s
 
 func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) {
   linalg.conv_1d_nwc_wcf
@@ -63,7 +66,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -102,7 +105,7 @@ func.func @conv1d_nwc_4x2x8_memref_i1(%input: memref<4x6x3xi1>, %filter: memref<
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -179,7 +182,7 @@ func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: me
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -265,7 +268,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -323,7 +326,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -403,7 +406,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -443,7 +446,7 @@ func.func @conv1d_ncw_4x8x2_memref_i1(%input: memref<4x3x6xi1>, %filter: memref<
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -537,7 +540,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -604,7 +607,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -666,7 +669,7 @@ func.func @conv1d_8_tensor(%input: tensor<11xf32>, %filter: tensor<4xf32>, %outp
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -716,7 +719,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -770,7 +773,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %fi
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -804,7 +807,7 @@ func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter:
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -838,7 +841,7 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -873,7 +876,7 @@ func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: me
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -908,7 +911,7 @@ func.func @pooling_nwc_max_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: me
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.pooling_nwc_max"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pooling_nwc_max", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -949,7 +952,7 @@ func.func @pooling_nwc_sum_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %fil
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -990,7 +993,7 @@ func.func @pooling_nwc_max_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %fil
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.pooling_nwc_max"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pooling_nwc_max", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -1029,7 +1032,7 @@ func.func @pooling_nwc_sum_memref_2_2_2_3(%input: memref<4x6x3xf32>, %filter: me
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -1068,7 +1071,7 @@ func.func @pooling_ncw_sum_memref_1_2_1_3(%input: memref<4x3x4xf32>, %filter: me
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.pooling_ncw_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pooling_ncw_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -1099,7 +1102,7 @@ func.func @pooling_nwc_sum_mixed_type_memref_1_2_1_1(%input: memref<1x2x3xf16>,
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -1131,7 +1134,7 @@ func.func @pooling_nwc_sum_memref_2_2_2_1(%input: memref<4x4x3xf32>, %filter: me
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -1173,7 +1176,7 @@ func.func @pooling_ncw_sum_memref_2_2_2_3(%input: memref<4x3x6xf32>, %filter: me
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.pooling_ncw_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pooling_ncw_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
@@ -1207,7 +1210,7 @@ func.func @pooling_ncw_sum_memref_2_3_2_1(%input: memref<4x2x5xf32>, %filter: me
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.pooling_ncw_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.pooling_ncw_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
index 5c321d40f6c60..4f01e77039158 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
@@ -1,4 +1,7 @@
 // RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
+// Test the same patterns on generic convolution ops by first generalizing the
+// named ops. This avoids duplicating lit tests for linalg.generic conv ops.
+// RUN: mlir-opt --split-input-file --linalg-generalize-named-ops --transform-interpreter -cse %s | FileCheck %s
 
 ///----------------------------------------------------------------------------------------
 /// ATM, all tests in this file require masking. As the support for masking is
@@ -23,7 +26,7 @@ func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize %0 vector_sizes [1, 8, 4, 1] : !transform.any_op
     transform.yield
   }
@@ -84,7 +87,7 @@ func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
     transform.yield
   }
@@ -145,7 +148,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
     transform.yield
   }

>From ff06548d357c35eb887c95f11887a3e9a9be419c Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 19 Jan 2026 07:24:16 +0000
Subject: [PATCH 2/2] Review comment v1.0

---
 .../Linalg/Transforms/Vectorization.cpp       | 132 +++++++++---------
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       |   4 +-
 2 files changed, 66 insertions(+), 70 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bdda2d3fb0fd9..ad9698348df2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2437,11 +2437,8 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   if (isElementwise(linalgOp))
     return success();
 
-  // Check for convolution ops - both named ops implementing
-  // ConvolutionOpInterface and generic ops that semantically match convolution
-  // patterns.
-  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
-      isaConvolutionOpInterface(linalgOp))
+  // Check for both named as well as generic convolution ops.
+  if (isaConvolutionOpInterface(linalgOp))
     return vectorizeConvOpPrecondition(linalgOp);
 
   // TODO: the common vector shape is equal to the static loop sizes only when
@@ -2744,11 +2741,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
   auto vectorizeResult =
       TypeSwitch<Operation *, LogicalResult>(op)
           .Case<linalg::LinalgOp>([&](auto linalgOp) {
-            // TODO: isaConvolutionOpInterface that can also infer from
-            // generic features. Will require stride/dilation attributes
-            // inference.
-            if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
-                isaConvolutionOpInterface(linalgOp)) {
+            // Check for both named as well as generic convolution ops.
+            if (isaConvolutionOpInterface(linalgOp)) {
               FailureOr<Operation *> convOr = vectorizeConvolution(
                   rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                   flatten1DDepthwiseConv);
@@ -3494,40 +3488,35 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
   bindShapeDims<0>(shapedType, vals...);
 }
 
-/// Helper to extract strides and dilations for 1D convolution/pooling ops.
-/// Returns true if the op is a recognized 1D conv/pool op and extracts the
-/// stride and dilation values. For unrecognized ops, returns false.
-static bool extract1DConvPoolStrideDilation(LinalgOp op, int &strideW,
-                                            int &dilationW) {
-#define EXTRACT_1D_CONV_POOL_STRIDE_DILATION(ConvOpTy)                         \
-  if (std::optional<DilationsAndStrides> convParams =                          \
-          matchConvolutionOpOfType<ConvOpTy>(op)) {                            \
-    strideW = static_cast<int>(convParams->strides.front());                   \
-    dilationW = static_cast<int>(convParams->dilations.front());               \
-    return true;                                                               \
-  }
-
-  // 1D Convolution ops
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DOp);
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNwcWcfOp);
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNcwFcwOp);
-  // Depthwise 1D Convolution ops
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcOp);
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNcwCwOp);
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcmOp);
-  // 1D Pooling ops (NWC layout)
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcSumOp);
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxOp);
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxUnsignedOp);
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinOp);
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinUnsignedOp);
-  // 1D Pooling ops (NCW layout)
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwSumOp);
-  EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwMaxOp);
-
-#undef EXTRACT_1D_CONV_POOL_STRIDE_DILATION
-
-  return false;
+/// Match 1D convolution or pooling operations and return their dilations and
+/// strides. Returns std::nullopt for unrecognized ops.
+static std::optional<DilationsAndStrides> match1DConvPoolOp(LinalgOp op) {
+#define MATCH_1D_CONV_POOL_OP(ConvOpTy)                                        \
+  if (auto convParams = matchConvolutionOpOfType<ConvOpTy>(op))                \
+    return convParams;
+
+  // 1D Convolution ops.
+  MATCH_1D_CONV_POOL_OP(linalg::Conv1DOp);
+  MATCH_1D_CONV_POOL_OP(linalg::Conv1DNwcWcfOp);
+  MATCH_1D_CONV_POOL_OP(linalg::Conv1DNcwFcwOp);
+  // Depthwise 1D Convolution ops.
+  // Note: Only NWC layout without channel multiplier is supported.
+  // DepthwiseConv1DNcwCwOp (NCW) and DepthwiseConv1DNwcWcmOp (with multiplier)
+  // are not supported.
+  MATCH_1D_CONV_POOL_OP(linalg::DepthwiseConv1DNwcWcOp);
+  // 1D Pooling ops (NWC layout).
+  MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcSumOp);
+  MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxOp);
+  MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxUnsignedOp);
+  MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinOp);
+  MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinUnsignedOp);
+  // 1D Pooling ops (NCW layout).
+  MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwSumOp);
+  MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwMaxOp);
+
+#undef MATCH_1D_CONV_POOL_OP
+
+  return std::nullopt;
 }
 
 namespace {
@@ -3567,8 +3556,26 @@ namespace {
 /// kw is unrolled, w is unrolled iff dilationW > 1.
 struct Conv1DGenerator
     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
-  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
-      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
+  /// Factory method to create a Conv1DGenerator. Returns failure if the
+  /// operation is not a recognized 1D convolution or pooling op.
+  static FailureOr<Conv1DGenerator> create(RewriterBase &rewriter,
+                                           LinalgOp linalgOp) {
+    // Try to match a 1D conv/pool op using matchConvolutionOpOfType. This
+    // works for both named ops and generic ops that match their semantics.
+    std::optional<DilationsAndStrides> convParams = match1DConvPoolOp(linalgOp);
+    if (!convParams)
+      return failure();
+
+    int strideW = static_cast<int>(convParams->strides.front());
+    int dilationW = static_cast<int>(convParams->dilations.front());
+    return Conv1DGenerator(rewriter, linalgOp, strideW, dilationW);
+  }
+
+private:
+  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
+                  int dilationW)
+      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
+        strideW(strideW), dilationW(dilationW) {
 
     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
@@ -3584,22 +3591,9 @@ struct Conv1DGenerator
 
     auto maybeKind = getCombinerOpKind(reduceOp);
     reductionKind = maybeKind.value();
-
-    // Try to extract strides/dilations from named 1D conv/pool ops using
-    // matchConvolutionOpOfType. This works for both named ops and generic ops
-    // that match their semantics. For unrecognized generic ops, fall back to
-    // checking attributes directly (which may not exist for generic ops).
-    if (!extract1DConvPoolStrideDilation(linalgOp, strideW, dilationW)) {
-      // Fallback: check for stride/dilation attributes directly.
-      // For generic ops without these attributes, default to 1.
-      auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
-      auto dilations =
-          linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
-      strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
-      dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-    }
   }
 
+public:
   /// Generate a vector implementation for:
   /// ```
   ///   Op def: (     w,     kw  )
@@ -4295,20 +4289,22 @@ struct Conv1DGenerator
 static FailureOr<Operation *> vectorizeConvolution(
     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
-  Conv1DGenerator conv1dGen(rewriter, op);
-  auto res = conv1dGen.generateNonChanneledConv();
+  FailureOr<Conv1DGenerator> conv1dGen = Conv1DGenerator::create(rewriter, op);
+  if (failed(conv1dGen))
+    return failure();
+  auto res = conv1dGen->generateNonChanneledConv();
   if (succeeded(res))
     return res;
-  res = conv1dGen.generateNwcConv();
+  res = conv1dGen->generateNwcConv();
   if (succeeded(res))
     return res;
-  res = conv1dGen.generateNcwConv();
+  res = conv1dGen->generateNcwConv();
   if (succeeded(res))
     return res;
-  res = conv1dGen.generateNwcPooling();
+  res = conv1dGen->generateNwcPooling();
   if (succeeded(res))
     return res;
-  res = conv1dGen.generateNcwPooling();
+  res = conv1dGen->generateNcwPooling();
   if (succeeded(res))
     return res;
 
@@ -4332,8 +4328,8 @@ static FailureOr<Operation *> vectorizeConvolution(
     vecChDimSize = inputVecSizes[chDimIdx];
     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
   }
-  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
-                                       flatten1DDepthwiseConv);
+  return conv1dGen->generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
+                                        flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index b24326b999755..a1ee6b307caf5 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -346,9 +346,9 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
     return bodyMatcherForZeroPointOffsets(accOp, mulOp, body);
   }
   BlockArgument lhsBlockArg =
-      getBlockArgumentWithOptionalCastOps(mulAndOp->getOperand(0));
+      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
   BlockArgument rhsBlockArg =
-      getBlockArgumentWithOptionalCastOps(mulAndOp->getOperand(1));
+      getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
   BlockArgument outBlockArg =
       getBlockArgumentWithOptionalCastOps(accOp->getOperand(0));
   if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||



More information about the Mlir-commits mailing list