[Mlir-commits] [mlir] [Linalg] Update Vectorization to work with both named as well as generic conv ops (PR #176339)
Abhishek Varma
llvmlistbot at llvm.org
Tue Jan 20 22:27:59 PST 2026
https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/176339
>From 73d96e93e1fe2a8274f40dc926e67dc5c4c10904 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 2 Jan 2026 09:42:58 +0000
Subject: [PATCH 1/3] [Linalg] Update Vectorization to work with both named as
well as generic convolution ops
-- This commit updates Linalg vectorization to handle both named convolution
ops and linalg.generic ops that semantically match convolution patterns.
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
.../Linalg/Transforms/Vectorization.cpp | 97 ++++++++++++++-----
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 4 +-
.../convolution-with-patterns-flatten.mlir | 13 ++-
.../convolution-with-patterns.mlir | 51 +++++-----
.../Linalg/vectorization/convolution.mlir | 9 +-
5 files changed, 114 insertions(+), 60 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6f73f4f57e50d..bdda2d3fb0fd9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -2071,7 +2072,7 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
return failure();
}
- if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
+ if (!isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
return failure();
}
@@ -2436,10 +2437,11 @@ static LogicalResult vectorizeLinalgOpPrecondition(
if (isElementwise(linalgOp))
return success();
- // TODO: isaConvolutionOpInterface that can also infer from generic
- // features. But we will still need stride/dilation attributes that will be
- // annoying to reverse-engineer...
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+ // Check for convolution ops - both named ops implementing
+ // ConvolutionOpInterface and generic ops that semantically match convolution
+ // patterns.
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+ isaConvolutionOpInterface(linalgOp))
return vectorizeConvOpPrecondition(linalgOp);
// TODO: the common vector shape is equal to the static loop sizes only when
@@ -2649,12 +2651,12 @@ vectorizeScalableVectorPrecondition(Operation *op,
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
- return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
- isa<linalg::BatchMatmulOp>(op) ||
- isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
- isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
- isa<linalg::BatchMmt4DOp>(op) ||
- hasReductionIterator(linalgOp));
+ return success(
+ isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+ isa<linalg::BatchMatmulOp>(op) ||
+ isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(linalgOp) ||
+ isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+ isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2745,7 +2747,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
// TODO: isaConvolutionOpInterface that can also infer from
// generic features. Will require stride/dilation attributes
// inference.
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+ isaConvolutionOpInterface(linalgOp)) {
FailureOr<Operation *> convOr = vectorizeConvolution(
rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
flatten1DDepthwiseConv);
@@ -3491,6 +3494,42 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
bindShapeDims<0>(shapedType, vals...);
}
+/// Matches `op` against each 1D conv/pool op type listed below, in order. On the
+/// first match, stores the first stride and dilation (cast to int) in `strideW`
+/// and `dilationW` and returns true; else returns false and writes neither.
+static bool extract1DConvPoolStrideDilation(LinalgOp op, int &strideW,
+ int &dilationW) {
+#define EXTRACT_1D_CONV_POOL_STRIDE_DILATION(ConvOpTy) \
+ if (std::optional<DilationsAndStrides> convParams = \
+ matchConvolutionOpOfType<ConvOpTy>(op)) { \
+ strideW = static_cast<int>(convParams->strides.front()); \
+ dilationW = static_cast<int>(convParams->dilations.front()); \
+ return true; \
+ }
+
+ // 1D convolution ops (each line below expands to a match-and-return check)
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNwcWcfOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNcwFcwOp);
+ // Depthwise 1D convolution ops
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNcwCwOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcmOp);
+ // 1D pooling ops (NWC layout)
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcSumOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxUnsignedOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinUnsignedOp);
+ // 1D pooling ops (NCW layout)
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwSumOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwMaxOp);
+
+#undef EXTRACT_1D_CONV_POOL_STRIDE_DILATION
+
+ return false;
+}
+
namespace {
/// Generate a vector implementation for either:
/// ```
@@ -3546,14 +3585,19 @@ struct Conv1DGenerator
auto maybeKind = getCombinerOpKind(reduceOp);
reductionKind = maybeKind.value();
- // The ConvolutionOpInterface gives us guarantees of existence for
- // strides/dilations. However, we do not need to rely on those, we can
- // simply use them if present, otherwise use the default and let the generic
- // conv. matcher in the ConvGenerator succeed or fail.
- auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
- auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
- strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
- dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+ // Try to extract strides/dilations from named 1D conv/pool ops using
+ // matchConvolutionOpOfType. This works for both named ops and generic ops
+ // that match their semantics. For unrecognized generic ops, fall back to
+ // checking attributes directly (which may not exist for generic ops).
+ if (!extract1DConvPoolStrideDilation(linalgOp, strideW, dilationW)) {
+ // Fallback: check for stride/dilation attributes directly.
+ // For generic ops without these attributes, default to 1.
+ auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+ auto dilations =
+ linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+ strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
+ dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+ }
}
/// Generate a vector implementation for:
@@ -4276,13 +4320,14 @@ static FailureOr<Operation *> vectorizeConvolution(
if (!inputVecSizes.empty()) {
// Only use the input vector size corresponding to the channel dim. Other
// vector dims will be inferred from the Ops.
- assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
- isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+ assert((isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op) ||
+ isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op)) &&
"Not a 1D depthwise conv!");
- size_t chDimIdx =
- TypeSwitch<Operation *, size_t>(op)
- .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
- .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
+ size_t chDimIdx = 0;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op))
+ chDimIdx = 2;
+ else if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op))
+ chDimIdx = 1;
vecChDimSize = inputVecSizes[chDimIdx];
vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index a1ee6b307caf5..b24326b999755 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -346,9 +346,9 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
return bodyMatcherForZeroPointOffsets(accOp, mulOp, body);
}
BlockArgument lhsBlockArg =
- getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
+ getBlockArgumentWithOptionalCastOps(mulAndOp->getOperand(0));
BlockArgument rhsBlockArg =
- getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
+ getBlockArgumentWithOptionalCastOps(mulAndOp->getOperand(1));
BlockArgument outBlockArg =
getBlockArgumentWithOptionalCastOps(accOp->getOperand(0));
if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
index c47824a18cf56..3bb077946b257 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
@@ -1,4 +1,7 @@
// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+// Test the same patterns on generic convolution ops by first generalizing the
+// named ops. This avoids duplicating lit tests for linalg.generic conv ops.
+// RUN: mlir-opt --split-input-file --linalg-generalize-named-ops --transform-interpreter %s | FileCheck %s
///----------------------------------------------------------------------------------------
/// Tests for vectorizing depthwise convolutions (with patterns) with the
@@ -19,7 +22,7 @@ func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x3xi8>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -52,7 +55,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[SC_ADDI:.*]] = vector.shape_cast %[[ADDI]] : vector<1x24xi8> to vector<1x8x3xi8>
// CHECK: vector.transfer_write %[[SC_ADDI]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
-//------
+//-----
func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x4xf32>,
%filter: memref<2x4xf32>,
@@ -106,7 +109,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -170,7 +173,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -302,7 +305,7 @@ func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
transform.yield
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
index cea60842f4606..f8781ff5452d9 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
@@ -1,4 +1,7 @@
// RUN: mlir-opt -transform-interpreter -split-input-file %s | FileCheck %s
+// Test the same patterns on generic convolution ops by first generalizing the
+// named ops. This avoids duplicating lit tests for linalg.generic conv ops.
+// RUN: mlir-opt --linalg-generalize-named-ops --transform-interpreter --split-input-file %s | FileCheck %s
func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) {
linalg.conv_1d_nwc_wcf
@@ -63,7 +66,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -102,7 +105,7 @@ func.func @conv1d_nwc_4x2x8_memref_i1(%input: memref<4x6x3xi1>, %filter: memref<
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -179,7 +182,7 @@ func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -265,7 +268,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -323,7 +326,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -403,7 +406,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -443,7 +446,7 @@ func.func @conv1d_ncw_4x8x2_memref_i1(%input: memref<4x3x6xi1>, %filter: memref<
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -537,7 +540,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -604,7 +607,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -666,7 +669,7 @@ func.func @conv1d_8_tensor(%input: tensor<11xf32>, %filter: tensor<4xf32>, %outp
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -716,7 +719,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -770,7 +773,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %fi
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -804,7 +807,7 @@ func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter:
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -838,7 +841,7 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -873,7 +876,7 @@ func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -908,7 +911,7 @@ func.func @pooling_nwc_max_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_max"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_max", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -949,7 +952,7 @@ func.func @pooling_nwc_sum_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %fil
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -990,7 +993,7 @@ func.func @pooling_nwc_max_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %fil
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_max"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_max", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1029,7 +1032,7 @@ func.func @pooling_nwc_sum_memref_2_2_2_3(%input: memref<4x6x3xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1068,7 +1071,7 @@ func.func @pooling_ncw_sum_memref_1_2_1_3(%input: memref<4x3x4xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_ncw_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_ncw_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1099,7 +1102,7 @@ func.func @pooling_nwc_sum_mixed_type_memref_1_2_1_1(%input: memref<1x2x3xf16>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1131,7 +1134,7 @@ func.func @pooling_nwc_sum_memref_2_2_2_1(%input: memref<4x4x3xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1173,7 +1176,7 @@ func.func @pooling_ncw_sum_memref_2_2_2_3(%input: memref<4x3x6xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_ncw_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_ncw_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1207,7 +1210,7 @@ func.func @pooling_ncw_sum_memref_2_3_2_1(%input: memref<4x2x5xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_ncw_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_ncw_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
index 5c321d40f6c60..4f01e77039158 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
@@ -1,4 +1,7 @@
// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
+// Test the same patterns on generic convolution ops by first generalizing the
+// named ops. This avoids duplicating lit tests for linalg.generic conv ops.
+// RUN: mlir-opt --split-input-file --linalg-generalize-named-ops --transform-interpreter -cse %s | FileCheck %s
///----------------------------------------------------------------------------------------
/// ATM, all tests in this file require masking. As the support for masking is
@@ -23,7 +26,7 @@ func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [1, 8, 4, 1] : !transform.any_op
transform.yield
}
@@ -84,7 +87,7 @@ func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
transform.yield
}
@@ -145,7 +148,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
transform.yield
}
>From 87a4bf9f86d2432a572c1d210cb87926652ae5f8 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 19 Jan 2026 07:24:16 +0000
Subject: [PATCH 2/3] Review comment v1.0
---
.../Linalg/Transforms/Vectorization.cpp | 132 +++++++++---------
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 4 +-
.../convolution-with-patterns-flatten.mlir | 2 +-
3 files changed, 67 insertions(+), 71 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bdda2d3fb0fd9..ad9698348df2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2437,11 +2437,8 @@ static LogicalResult vectorizeLinalgOpPrecondition(
if (isElementwise(linalgOp))
return success();
- // Check for convolution ops - both named ops implementing
- // ConvolutionOpInterface and generic ops that semantically match convolution
- // patterns.
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
- isaConvolutionOpInterface(linalgOp))
+ // Check for both named as well as generic convolution ops.
+ if (isaConvolutionOpInterface(linalgOp))
return vectorizeConvOpPrecondition(linalgOp);
// TODO: the common vector shape is equal to the static loop sizes only when
@@ -2744,11 +2741,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
auto vectorizeResult =
TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
- // TODO: isaConvolutionOpInterface that can also infer from
- // generic features. Will require stride/dilation attributes
- // inference.
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
- isaConvolutionOpInterface(linalgOp)) {
+ // Check for both named as well as generic convolution ops.
+ if (isaConvolutionOpInterface(linalgOp)) {
FailureOr<Operation *> convOr = vectorizeConvolution(
rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
flatten1DDepthwiseConv);
@@ -3494,40 +3488,35 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
bindShapeDims<0>(shapedType, vals...);
}
-/// Helper to extract strides and dilations for 1D convolution/pooling ops.
-/// Returns true if the op is a recognized 1D conv/pool op and extracts the
-/// stride and dilation values. For unrecognized ops, returns false.
-static bool extract1DConvPoolStrideDilation(LinalgOp op, int &strideW,
- int &dilationW) {
-#define EXTRACT_1D_CONV_POOL_STRIDE_DILATION(ConvOpTy) \
- if (std::optional<DilationsAndStrides> convParams = \
- matchConvolutionOpOfType<ConvOpTy>(op)) { \
- strideW = static_cast<int>(convParams->strides.front()); \
- dilationW = static_cast<int>(convParams->dilations.front()); \
- return true; \
- }
-
- // 1D Convolution ops
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DOp);
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNwcWcfOp);
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNcwFcwOp);
- // Depthwise 1D Convolution ops
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcOp);
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNcwCwOp);
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcmOp);
- // 1D Pooling ops (NWC layout)
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcSumOp);
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxOp);
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxUnsignedOp);
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinOp);
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinUnsignedOp);
- // 1D Pooling ops (NCW layout)
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwSumOp);
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwMaxOp);
-
-#undef EXTRACT_1D_CONV_POOL_STRIDE_DILATION
-
- return false;
+/// Match 1D convolution or pooling operations and return their dilations and
+/// strides. Returns std::nullopt for unrecognized ops.
+static std::optional<DilationsAndStrides> match1DConvPoolOp(LinalgOp op) {
+#define MATCH_1D_CONV_POOL_OP(ConvOpTy) \
+ if (auto convParams = matchConvolutionOpOfType<ConvOpTy>(op)) \
+ return convParams;
+
+ // 1D Convolution ops.
+ MATCH_1D_CONV_POOL_OP(linalg::Conv1DOp);
+ MATCH_1D_CONV_POOL_OP(linalg::Conv1DNwcWcfOp);
+ MATCH_1D_CONV_POOL_OP(linalg::Conv1DNcwFcwOp);
+ // Depthwise 1D Convolution ops.
+ // Note: Only NWC layout without channel multiplier is supported.
+ // DepthwiseConv1DNcwCwOp (NCW) and DepthwiseConv1DNwcWcmOp (with multiplier)
+ // are not supported.
+ MATCH_1D_CONV_POOL_OP(linalg::DepthwiseConv1DNwcWcOp);
+ // 1D Pooling ops (NWC layout).
+ MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcSumOp);
+ MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxOp);
+ MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxUnsignedOp);
+ MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinOp);
+ MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinUnsignedOp);
+ // 1D Pooling ops (NCW layout).
+ MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwSumOp);
+ MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwMaxOp);
+
+#undef MATCH_1D_CONV_POOL_OP
+
+ return std::nullopt;
}
namespace {
@@ -3567,8 +3556,26 @@ namespace {
/// kw is unrolled, w is unrolled iff dilationW > 1.
struct Conv1DGenerator
: public StructuredGenerator<LinalgOp, utils::IteratorType> {
- Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
- : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
+ /// Factory method to create a Conv1DGenerator. Returns failure if the
+ /// operation is not a recognized 1D convolution/pooling op.
+ static FailureOr<Conv1DGenerator> create(RewriterBase &rewriter,
+ LinalgOp linalgOp) {
+ // Try to match a 1D conv/pool op using matchConvolutionOpOfType. This
+ // works for both named ops and generic ops that match their semantics.
+ std::optional<DilationsAndStrides> convParams = match1DConvPoolOp(linalgOp);
+ if (!convParams)
+ return failure();
+
+ int strideW = static_cast<int>(convParams->strides.front());
+ int dilationW = static_cast<int>(convParams->dilations.front());
+ return Conv1DGenerator(rewriter, linalgOp, strideW, dilationW);
+ }
+
+private:
+ Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
+ int dilationW)
+ : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
+ strideW(strideW), dilationW(dilationW) {
lhsShaped = linalgOp.getDpsInputOperand(0)->get();
rhsShaped = linalgOp.getDpsInputOperand(1)->get();
@@ -3584,22 +3591,9 @@ struct Conv1DGenerator
auto maybeKind = getCombinerOpKind(reduceOp);
reductionKind = maybeKind.value();
-
- // Try to extract strides/dilations from named 1D conv/pool ops using
- // matchConvolutionOpOfType. This works for both named ops and generic ops
- // that match their semantics. For unrecognized generic ops, fall back to
- // checking attributes directly (which may not exist for generic ops).
- if (!extract1DConvPoolStrideDilation(linalgOp, strideW, dilationW)) {
- // Fallback: check for stride/dilation attributes directly.
- // For generic ops without these attributes, default to 1.
- auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
- auto dilations =
- linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
- strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
- dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
- }
}
+public:
/// Generate a vector implementation for:
/// ```
/// Op def: ( w, kw )
@@ -4295,20 +4289,22 @@ struct Conv1DGenerator
static FailureOr<Operation *> vectorizeConvolution(
RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
- Conv1DGenerator conv1dGen(rewriter, op);
- auto res = conv1dGen.generateNonChanneledConv();
+ FailureOr<Conv1DGenerator> conv1dGen = Conv1DGenerator::create(rewriter, op);
+ if (failed(conv1dGen))
+ return failure();
+ auto res = conv1dGen->generateNonChanneledConv();
if (succeeded(res))
return res;
- res = conv1dGen.generateNwcConv();
+ res = conv1dGen->generateNwcConv();
if (succeeded(res))
return res;
- res = conv1dGen.generateNcwConv();
+ res = conv1dGen->generateNcwConv();
if (succeeded(res))
return res;
- res = conv1dGen.generateNwcPooling();
+ res = conv1dGen->generateNwcPooling();
if (succeeded(res))
return res;
- res = conv1dGen.generateNcwPooling();
+ res = conv1dGen->generateNcwPooling();
if (succeeded(res))
return res;
@@ -4332,8 +4328,8 @@ static FailureOr<Operation *> vectorizeConvolution(
vecChDimSize = inputVecSizes[chDimIdx];
vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
}
- return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
- flatten1DDepthwiseConv);
+ return conv1dGen->generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
+ flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index b24326b999755..a1ee6b307caf5 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -346,9 +346,9 @@ static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
return bodyMatcherForZeroPointOffsets(accOp, mulOp, body);
}
BlockArgument lhsBlockArg =
- getBlockArgumentWithOptionalCastOps(mulAndOp->getOperand(0));
+ getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
BlockArgument rhsBlockArg =
- getBlockArgumentWithOptionalCastOps(mulAndOp->getOperand(1));
+ getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
BlockArgument outBlockArg =
getBlockArgumentWithOptionalCastOps(accOp->getOperand(0));
if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
index 3bb077946b257..4198c1341c6fd 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
@@ -55,7 +55,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[SC_ADDI:.*]] = vector.shape_cast %[[ADDI]] : vector<1x24xi8> to vector<1x8x3xi8>
// CHECK: vector.transfer_write %[[SC_ADDI]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
-//-----
+//------
func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x4xf32>,
%filter: memref<2x4xf32>,
>From dba3a9e52cccbfcee78d8690c39023252c0f25ec Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 21 Jan 2026 06:27:30 +0000
Subject: [PATCH 3/3] Review comment v2.0
---
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ad9698348df2b..b628bc06a90bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -4322,7 +4322,7 @@ static FailureOr<Operation *> vectorizeConvolution(
size_t chDimIdx = 0;
if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op))
chDimIdx = 2;
- else if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op))
chDimIdx = 1;
vecChDimSize = inputVecSizes[chDimIdx];
More information about the Mlir-commits mailing list