[Mlir-commits] [mlir] [MLIR] Allowing unsupported conv2d op to fail gracefully in vectorization pass (PR #130181)
Zhuoran Yin
llvmlistbot at llvm.org
Thu Mar 6 13:48:12 PST 2025
https://github.com/jerryyin created https://github.com/llvm/llvm-project/pull/130181
In corner cases, the vectorization pass may be asked to lower a conv2d op and then assert in a seemingly unrelated location inside the vectorizeConvolution() subroutine.
This PR rejects the conv2d op early and makes the asserting routine return failure instead, as a defensive workaround.
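For reference, a 2d+ convolution can still be vectorized if it is first lowered to a 1d convolution, which is the contract this patch assumes. A minimal sketch of that recipe using the transform dialect is below (not part of this patch; the op names follow current upstream syntax, and the tile-size positions assume the loop order of linalg.conv_2d_nhwc_hwcf, i.e. N, OH, OW, F, KH, KW, C):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    // Tile the OH and KH loops to size 1 so the 2-D window degenerates to 1-D.
    %tiled, %loops:2 = transform.structured.tile_using_for %conv tile_sizes [0, 1, 0, 0, 1, 0, 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    // Rewrite the unit-window conv_2d into a conv_1d, which the precondition accepts.
    %conv1d = transform.structured.decompose %tiled : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %conv1d : !transform.any_op
    transform.yield
  }
}

Without a decomposition step of this kind, the precondition added below fires and vectorization fails cleanly instead of asserting.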
From a33b2116cf8bea9287053bb57d277add19f77cb8 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 6 Mar 2025 21:40:25 +0000
Subject: [PATCH] Blocking conv2d from vectorization pass
---
.../Linalg/Transforms/Vectorization.cpp | 20 +++++++++++++++----
.../Linalg/vectorization-unsupported.mlir | 19 ++++++++++++++++++
2 files changed, 35 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..319dd4b2043c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1990,8 +1990,18 @@ static LogicalResult vectorizeLinalgOpPrecondition(
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. But we will still need stride/dilation attributes that will be
// annoying to reverse-engineer...
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+ // Check if it is 2d+ convolution. If it is, return failure because we don't
+ // support it. To use this pass on a 2d+ convolution, it should have already
+ // been decomposed to 1d convolution via
+ // DecomposeConvolutionToLowerDimOpsPass.
+ if (linalgOp.getNumParallelLoops() >= 4) {
+ LDBG("precondition failed: Regular 2d+ convolutions not supported.\n");
+ return failure();
+ }
return success();
+ }
+
// TODO: the common vector shape is equal to the static loop sizes only when
// all indexing maps are projected permutations. For convs and stencils the
// logic will need to evolve.
@@ -3929,9 +3939,11 @@ static FailureOr<Operation *> vectorizeConvolution(
if (!inputVecSizes.empty()) {
// Only use the input vector size corresponding to the channel dim. Other
// vector dims will be inferred from the Ops.
- assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
- isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
- "Not a 1D depthwise conv!");
+ if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
+ !isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
+ return rewriter.notifyMatchFailure(
+ op, "Unexpected convolution: expected 1D depthwise conv");
+ }
size_t chDimIdx =
TypeSwitch<Operation *, size_t>(op)
.Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 8f3b199145ce0..88d9e98c02bca 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -112,6 +112,25 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @conv2d(%3: tensor<1x64x58x58xf32>, %4: tensor<64x64x3x3xf32>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %5 = tensor.empty() : tensor<1x64x56x56xf32>
+ %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+ // expected-error @+1 {{Attempted to vectorize, but failed}}
+ %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%6 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
%pad = arith.constant 0.000000e+00 : f32
// expected-error @+1 {{Attempted to vectorize, but failed}}