[Mlir-commits] [mlir] [Linalg] Update Vectorization to work with both named as well as generic conv ops (PR #176339)
Abhishek Varma
llvmlistbot at llvm.org
Mon Jan 19 00:01:46 PST 2026
https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/176339
>From 81eb4738dbfe613db89844748b60038ed83b24f0 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 2 Jan 2026 09:42:58 +0000
Subject: [PATCH 1/2] [Linalg] Update Vectorization to work with both named as
well as generic convolution ops
-- This commit updates Vectorization to work with both named and
generic convolution ops.
-- For the i1 data type, a convolution op's body uses arith.ori(arith.andi)
instead of arith.add*(arith.mul*), so the body matcher utility for generic
convolution ops has been updated to accept that form as well.
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
.../Linalg/Transforms/Vectorization.cpp | 97 ++++++++++++++-----
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 27 ++++--
.../convolution-with-patterns-flatten.mlir | 13 ++-
.../convolution-with-patterns.mlir | 51 +++++-----
.../Linalg/vectorization/convolution.mlir | 9 +-
5 files changed, 129 insertions(+), 68 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c7d5dff74c5a9..109bfe1936914 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -2070,7 +2071,7 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
return failure();
}
- if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
+ if (!isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
return failure();
}
@@ -2431,10 +2432,11 @@ static LogicalResult vectorizeLinalgOpPrecondition(
if (isElementwise(linalgOp))
return success();
- // TODO: isaConvolutionOpInterface that can also infer from generic
- // features. But we will still need stride/dilation attributes that will be
- // annoying to reverse-engineer...
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+ // Check for convolution ops - both named ops implementing
+ // ConvolutionOpInterface and generic ops that semantically match convolution
+ // patterns.
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+ isaConvolutionOpInterface(linalgOp))
return vectorizeConvOpPrecondition(linalgOp);
// TODO: the common vector shape is equal to the static loop sizes only when
@@ -2640,12 +2642,12 @@ vectorizeScalableVectorPrecondition(Operation *op,
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
- return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
- isa<linalg::BatchMatmulOp>(op) ||
- isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
- isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
- isa<linalg::BatchMmt4DOp>(op) ||
- hasReductionIterator(linalgOp));
+ return success(
+ isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+ isa<linalg::BatchMatmulOp>(op) ||
+ isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(linalgOp) ||
+ isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+ isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp));
}
LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2736,7 +2738,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
// TODO: isaConvolutionOpInterface that can also infer from
// generic features. Will require stride/dilation attributes
// inference.
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
+ isaConvolutionOpInterface(linalgOp)) {
FailureOr<Operation *> convOr = vectorizeConvolution(
rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
flatten1DDepthwiseConv);
@@ -3482,6 +3485,42 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
bindShapeDims<0>(shapedType, vals...);
}
+/// Tries each known 1D conv/pool op type via matchConvolutionOpOfType. On the
+/// first match, stores the leading stride/dilation in `strideW`/`dilationW` and
+/// returns true; otherwise returns false and leaves both outputs unmodified.
+static bool extract1DConvPoolStrideDilation(LinalgOp op, int &strideW,
+ int &dilationW) {
+#define EXTRACT_1D_CONV_POOL_STRIDE_DILATION(ConvOpTy) \
+ if (std::optional<DilationsAndStrides> convParams = \
+ matchConvolutionOpOfType<ConvOpTy>(op)) { \
+ strideW = static_cast<int>(convParams->strides.front()); \
+ dilationW = static_cast<int>(convParams->dilations.front()); \
+ return true; \
+ }
+
+ // 1D Convolution ops
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNwcWcfOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNcwFcwOp);
+ // Depthwise 1D Convolution ops
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNcwCwOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcmOp);
+ // 1D Pooling ops (NWC layout)
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcSumOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxUnsignedOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinUnsignedOp);
+ // 1D Pooling ops (NCW layout)
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwSumOp);
+ EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwMaxOp);
+
+#undef EXTRACT_1D_CONV_POOL_STRIDE_DILATION
+
+ return false;
+}
+
namespace {
/// Generate a vector implementation for either:
/// ```
@@ -3537,14 +3576,19 @@ struct Conv1DGenerator
auto maybeKind = getCombinerOpKind(reduceOp);
reductionKind = maybeKind.value();
- // The ConvolutionOpInterface gives us guarantees of existence for
- // strides/dilations. However, we do not need to rely on those, we can
- // simply use them if present, otherwise use the default and let the generic
- // conv. matcher in the ConvGenerator succeed or fail.
- auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
- auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
- strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
- dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+ // Try to extract strides/dilations from named 1D conv/pool ops using
+ // matchConvolutionOpOfType. This works for both named ops and generic ops
+ // that match their semantics. For unrecognized generic ops, fall back to
+ // checking attributes directly (which may not exist for generic ops).
+ if (!extract1DConvPoolStrideDilation(linalgOp, strideW, dilationW)) {
+ // Fallback: check for stride/dilation attributes directly.
+ // For generic ops without these attributes, default to 1.
+ auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+ auto dilations =
+ linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+ strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
+ dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+ }
}
/// Generate a vector implementation for:
@@ -4267,13 +4311,14 @@ static FailureOr<Operation *> vectorizeConvolution(
if (!inputVecSizes.empty()) {
// Only use the input vector size corresponding to the channel dim. Other
// vector dims will be inferred from the Ops.
- assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
- isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+ assert((isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op) ||
+ isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op)) &&
"Not a 1D depthwise conv!");
- size_t chDimIdx =
- TypeSwitch<Operation *, size_t>(op)
- .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
- .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
+ size_t chDimIdx = 0;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(op))
+ chDimIdx = 2;
+ else if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(op))
+ chDimIdx = 1;
vecChDimSize = inputVecSizes[chDimIdx];
vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index daf02442bb21a..9b96fde6106fc 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -325,23 +325,29 @@ static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
/// where, %input_scalar can have optional upcast operation.
static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
bool containsZeroPointOffset = false) {
- Operation *addOp = yieldVal.getDefiningOp();
- if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
- return false;
+ bool isOrOp = false;
+ Operation *addOrOp = yieldVal.getDefiningOp();
+ if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOrOp)) {
+ if (!isa_and_present<arith::OrIOp>(addOrOp))
+ return false;
+ isOrOp = true;
+ }
- Operation *mulOp = addOp->getOperand(1).getDefiningOp();
- if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
+ Operation *mulAndOp = addOrOp->getOperand(1).getDefiningOp();
+ if (!isOrOp && !isa_and_present<arith::MulIOp, arith::MulFOp>(mulAndOp))
+ return false;
+ if (isOrOp && !isa_and_present<arith::AndIOp>(mulAndOp))
return false;
if (containsZeroPointOffset) {
- return bodyMatcherForZeroPointOffsets(addOp, mulOp, body);
+ return bodyMatcherForZeroPointOffsets(addOrOp, mulAndOp, body);
}
BlockArgument lhsBlockArg =
- getBlockArgumentWithOptionalCastOps(mulOp->getOperand(0));
+ getBlockArgumentWithOptionalCastOps(mulAndOp->getOperand(0));
BlockArgument rhsBlockArg =
- getBlockArgumentWithOptionalCastOps(mulOp->getOperand(1));
+ getBlockArgumentWithOptionalCastOps(mulAndOp->getOperand(1));
BlockArgument outBlockArg =
- getBlockArgumentWithOptionalCastOps(addOp->getOperand(0));
+ getBlockArgumentWithOptionalCastOps(addOrOp->getOperand(0));
if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
@@ -387,7 +393,8 @@ static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
}
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+ return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp, arith::OrIOp>(
+ yieldVal, body);
}
static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex,
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
index c47824a18cf56..3bb077946b257 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns-flatten.mlir
@@ -1,4 +1,7 @@
// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+// Test the same patterns on generic convolution ops by first generalizing the
+// named ops. This avoids duplicating lit tests for linalg.generic conv ops.
+// RUN: mlir-opt --split-input-file --linalg-generalize-named-ops --transform-interpreter %s | FileCheck %s
///----------------------------------------------------------------------------------------
/// Tests for vectorizing depthwise convolutions (with patterns) with the
@@ -19,7 +22,7 @@ func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x3xi8>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -52,7 +55,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[SC_ADDI:.*]] = vector.shape_cast %[[ADDI]] : vector<1x24xi8> to vector<1x8x3xi8>
// CHECK: vector.transfer_write %[[SC_ADDI]], %[[OUTPUT]][%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]]
-//------
+//-----
func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3x5x4xf32>,
%filter: memref<2x4xf32>,
@@ -106,7 +109,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dillation_2(%input: memref<3
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -170,7 +173,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref_dilation_2(%input: memref<3x5
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -302,7 +305,7 @@ func.func @depthwise_conv1d_nwc_wc_3x9x4xi8_tensor_stride_2(%input: tensor<3x9x4
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
transform.yield
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
index cea60842f4606..f8781ff5452d9 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution-with-patterns.mlir
@@ -1,4 +1,7 @@
// RUN: mlir-opt -transform-interpreter -split-input-file %s | FileCheck %s
+// Test the same patterns on generic convolution ops by first generalizing the
+// named ops. This avoids duplicating lit tests for linalg.generic conv ops.
+// RUN: mlir-opt --linalg-generalize-named-ops --transform-interpreter --split-input-file %s | FileCheck %s
func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf32>, %output: memref<4x2x8xf32>) {
linalg.conv_1d_nwc_wcf
@@ -63,7 +66,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -102,7 +105,7 @@ func.func @conv1d_nwc_4x2x8_memref_i1(%input: memref<4x6x3xi1>, %filter: memref<
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -179,7 +182,7 @@ func.func @conv1d_nwc_4x2x8_i8i8i32_memref(%input: memref<4x6x3xi8>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -265,7 +268,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -323,7 +326,7 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -403,7 +406,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -443,7 +446,7 @@ func.func @conv1d_ncw_4x8x2_memref_i1(%input: memref<4x3x6xi1>, %filter: memref<
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -537,7 +540,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -604,7 +607,7 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_ncw_fcw", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -666,7 +669,7 @@ func.func @conv1d_8_tensor(%input: tensor<11xf32>, %filter: tensor<4xf32>, %outp
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -716,7 +719,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -770,7 +773,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %fi
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -804,7 +807,7 @@ func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter:
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -838,7 +841,7 @@ func.func @conv_1d_nwc_wcf_mixed_int_fp_memref(%input: memref<1x2x3xi8>, %filter
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.conv_1d_nwc_wcf", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -873,7 +876,7 @@ func.func @pooling_nwc_sum_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -908,7 +911,7 @@ func.func @pooling_nwc_max_memref_1_2_1_3(%input: memref<4x4x3xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_max"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_max", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -949,7 +952,7 @@ func.func @pooling_nwc_sum_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %fil
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -990,7 +993,7 @@ func.func @pooling_nwc_max_i8i8i32_memref_1_2_1_3(%input: memref<4x4x3xi8>, %fil
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_max"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_max", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1029,7 +1032,7 @@ func.func @pooling_nwc_sum_memref_2_2_2_3(%input: memref<4x6x3xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1068,7 +1071,7 @@ func.func @pooling_ncw_sum_memref_1_2_1_3(%input: memref<4x3x4xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_ncw_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_ncw_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1099,7 +1102,7 @@ func.func @pooling_nwc_sum_mixed_type_memref_1_2_1_1(%input: memref<1x2x3xf16>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1131,7 +1134,7 @@ func.func @pooling_nwc_sum_memref_2_2_2_1(%input: memref<4x4x3xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_nwc_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_nwc_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1173,7 +1176,7 @@ func.func @pooling_ncw_sum_memref_2_2_2_3(%input: memref<4x3x6xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_ncw_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_ncw_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
@@ -1207,7 +1210,7 @@ func.func @pooling_ncw_sum_memref_2_3_2_1(%input: memref<4x2x5xf32>, %filter: me
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.pooling_ncw_sum"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.pooling_ncw_sum", "linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
transform.yield
diff --git a/mlir/test/Dialect/Linalg/vectorization/convolution.mlir b/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
index 5c321d40f6c60..4f01e77039158 100644
--- a/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/convolution.mlir
@@ -1,4 +1,7 @@
// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
+// Test the same patterns on generic convolution ops by first generalizing the
+// named ops. This avoids duplicating lit tests for linalg.generic conv ops.
+// RUN: mlir-opt --split-input-file --linalg-generalize-named-ops --transform-interpreter -cse %s | FileCheck %s
///----------------------------------------------------------------------------------------
/// ATM, all tests in this file require masking. As the support for masking is
@@ -23,7 +26,7 @@ func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor(%input: tensor<1x8x?xi8>,
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [1, 8, 4, 1] : !transform.any_op
transform.yield
}
@@ -84,7 +87,7 @@ func.func @depthwise_conv1d_nwc_wc_1x8x3xi8_tensor_scalable(
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [1, 8, [4], 1] : !transform.any_op
transform.yield
}
@@ -145,7 +148,7 @@ func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc", "linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize %0 vector_sizes [3, 2, [4], 2] : !transform.any_op
transform.yield
}
>From 83572d55c3af7cf8b6c481ba8c4be98d26b543fa Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Mon, 19 Jan 2026 07:24:16 +0000
Subject: [PATCH 2/2] Review comment v1.0
---
.../Linalg/Transforms/Vectorization.cpp | 82 ++++++++++---------
1 file changed, 42 insertions(+), 40 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 109bfe1936914..dbe7456c1daf4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2432,11 +2432,8 @@ static LogicalResult vectorizeLinalgOpPrecondition(
if (isElementwise(linalgOp))
return success();
- // Check for convolution ops - both named ops implementing
- // ConvolutionOpInterface and generic ops that semantically match convolution
- // patterns.
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
- isaConvolutionOpInterface(linalgOp))
+ // Check for both named as well as generic convolution ops.
+ if (isaConvolutionOpInterface(linalgOp))
return vectorizeConvOpPrecondition(linalgOp);
// TODO: the common vector shape is equal to the static loop sizes only when
@@ -2735,11 +2732,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
auto vectorizeResult =
TypeSwitch<Operation *, LogicalResult>(op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
- // TODO: isaConvolutionOpInterface that can also infer from
- // generic features. Will require stride/dilation attributes
- // inference.
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation()) ||
- isaConvolutionOpInterface(linalgOp)) {
+ // Check for both named as well as generic convolution ops.
+ if (isaConvolutionOpInterface(linalgOp)) {
FailureOr<Operation *> convOr = vectorizeConvolution(
rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
flatten1DDepthwiseConv);
@@ -3498,21 +3492,22 @@ static bool extract1DConvPoolStrideDilation(LinalgOp op, int &strideW,
return true; \
}
- // 1D Convolution ops
+ // 1D Convolution ops.
EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DOp);
EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNwcWcfOp);
EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::Conv1DNcwFcwOp);
- // Depthwise 1D Convolution ops
+ // Depthwise 1D Convolution ops.
+ // Note: Only NWC layout without channel multiplier is supported.
+ // DepthwiseConv1DNcwCwOp (NCW) and DepthwiseConv1DNwcWcmOp (with multiplier)
+ // are not supported.
EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcOp);
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNcwCwOp);
- EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::DepthwiseConv1DNwcWcmOp);
- // 1D Pooling ops (NWC layout)
+ // 1D Pooling ops (NWC layout).
EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcSumOp);
EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxOp);
EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMaxUnsignedOp);
EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinOp);
EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNwcMinUnsignedOp);
- // 1D Pooling ops (NCW layout)
+ // 1D Pooling ops (NCW layout).
EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwSumOp);
EXTRACT_1D_CONV_POOL_STRIDE_DILATION(linalg::PoolingNcwMaxOp);
@@ -3558,8 +3553,26 @@ namespace {
/// kw is unrolled, w is unrolled iff dilationW > 1.
struct Conv1DGenerator
: public StructuredGenerator<LinalgOp, utils::IteratorType> {
- Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
- : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
+ /// Factory method to create a Conv1DGenerator. Returns failure if the
+ /// operation doesn't have valid strides/dilations.
+ static FailureOr<Conv1DGenerator> create(RewriterBase &rewriter,
+ LinalgOp linalgOp) {
+ int strideW, dilationW;
+
+ // Try to extract strides/dilations from named 1D conv/pool ops using
+ // matchConvolutionOpOfType. This works for both named ops and generic ops
+ // that match their semantics.
+ if (!extract1DConvPoolStrideDilation(linalgOp, strideW, dilationW))
+ return failure();
+
+ return Conv1DGenerator(rewriter, linalgOp, strideW, dilationW);
+ }
+
+private:
+ Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
+ int dilationW)
+ : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
+ strideW(strideW), dilationW(dilationW) {
lhsShaped = linalgOp.getDpsInputOperand(0)->get();
rhsShaped = linalgOp.getDpsInputOperand(1)->get();
@@ -3575,22 +3588,9 @@ struct Conv1DGenerator
auto maybeKind = getCombinerOpKind(reduceOp);
reductionKind = maybeKind.value();
-
- // Try to extract strides/dilations from named 1D conv/pool ops using
- // matchConvolutionOpOfType. This works for both named ops and generic ops
- // that match their semantics. For unrecognized generic ops, fall back to
- // checking attributes directly (which may not exist for generic ops).
- if (!extract1DConvPoolStrideDilation(linalgOp, strideW, dilationW)) {
- // Fallback: check for stride/dilation attributes directly.
- // For generic ops without these attributes, default to 1.
- auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
- auto dilations =
- linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
- strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
- dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
- }
}
+public:
/// Generate a vector implementation for:
/// ```
/// Op def: ( w, kw )
@@ -4286,20 +4286,22 @@ struct Conv1DGenerator
static FailureOr<Operation *> vectorizeConvolution(
RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
- Conv1DGenerator conv1dGen(rewriter, op);
- auto res = conv1dGen.generateNonChanneledConv();
+ FailureOr<Conv1DGenerator> conv1dGen = Conv1DGenerator::create(rewriter, op);
+ if (failed(conv1dGen))
+ return failure();
+ auto res = conv1dGen->generateNonChanneledConv();
if (succeeded(res))
return res;
- res = conv1dGen.generateNwcConv();
+ res = conv1dGen->generateNwcConv();
if (succeeded(res))
return res;
- res = conv1dGen.generateNcwConv();
+ res = conv1dGen->generateNcwConv();
if (succeeded(res))
return res;
- res = conv1dGen.generateNwcPooling();
+ res = conv1dGen->generateNwcPooling();
if (succeeded(res))
return res;
- res = conv1dGen.generateNcwPooling();
+ res = conv1dGen->generateNcwPooling();
if (succeeded(res))
return res;
@@ -4323,8 +4325,8 @@ static FailureOr<Operation *> vectorizeConvolution(
vecChDimSize = inputVecSizes[chDimIdx];
vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
}
- return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
- flatten1DDepthwiseConv);
+ return conv1dGen->generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
+ flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
More information about the Mlir-commits
mailing list