[Mlir-commits] [mlir] [MLIR] Allowing unsupported conv2d op to fail gracefully from vectorization (PR #130181)

Zhuoran Yin llvmlistbot at llvm.org
Fri Mar 7 08:39:25 PST 2025


https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/130181

>From a33b2116cf8bea9287053bb57d277add19f77cb8 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 6 Mar 2025 21:40:25 +0000
Subject: [PATCH 1/3] Blocking conv2d from vectorization pass

---
 .../Linalg/Transforms/Vectorization.cpp       | 20 +++++++++++++++----
 .../Linalg/vectorization-unsupported.mlir     | 19 ++++++++++++++++++
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..319dd4b2043c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1990,8 +1990,18 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. But we will still need stride/dilation attributes that will be
   // annoying to reverse-engineer...
-  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+    // Check if it is 2d+ convolution. If it is, return failure because we don't
+    // support it. To use this pass on a 2d+ convolution, it should have already
+    // been decomposed to 1d convolution via
+    // DecomposeConvolutionToLowerDimOpsPass.
+    if (linalgOp.getNumParallelLoops() >= 4) {
+      LDBG("precondition failed: Regular 2d+ convolutions not supported.\n");
+      return failure();
+    }
     return success();
+  }
+
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
   // logic will need to evolve.
@@ -3929,9 +3939,11 @@ static FailureOr<Operation *> vectorizeConvolution(
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
     // vector dims will be inferred from the Ops.
-    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
-            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
-           "Not a 1D depthwise conv!");
+    if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
+        !isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
+      return rewriter.notifyMatchFailure(
+          op, "Unexpected convolution: expected 1D depthwise conv");
+    }
     size_t chDimIdx =
         TypeSwitch<Operation *, size_t>(op)
             .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 8f3b199145ce0..88d9e98c02bca 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -112,6 +112,25 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @conv2d(%3: tensor<1x64x58x58xf32>, %4:  tensor<64x64x3x3xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %5 = tensor.empty() : tensor<1x64x56x56xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%6 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+  return
+}
+// Vectorize the 2-D conv directly (not decomposed to 1-D): must fail cleanly.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
   %pad = arith.constant 0.000000e+00 : f32
   // expected-error @+1 {{Attempted to vectorize, but failed}}

>From 80474f73a516dc8df0df10555dd9d59cda86121c Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 7 Mar 2025 14:56:44 +0000
Subject: [PATCH 2/3] Refactor Conv1DGenerator

---
 .../Linalg/Transforms/Vectorization.cpp       | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 319dd4b2043c3..d57d4214a78ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3135,10 +3135,8 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
 /// kw is unrolled, w is unrolled iff dilationW > 1.
 struct Conv1DGenerator
     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
-  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
-                  int dilationW)
-      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
-        strideW(strideW), dilationW(dilationW) {
+  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
+      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
     // Determine whether `linalgOp` can be generated with this generator
     if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
       return;
@@ -3185,6 +3183,16 @@ struct Conv1DGenerator
         return;
       break;
     }
+
+    // The ConvolutionOpInterface gives us guarantees of existence for
+    // strides/dilations. However, we do not need to rely on those, we can
+    // simply use them if present, otherwise use the default and let the generic
+    // conv. matcher in the ConvGenerator succeed or fail.
+    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+    auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
+    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+
     // The op is now known to be valid.
     valid = true;
   }
@@ -3906,15 +3914,7 @@ struct Conv1DGenerator
 static FailureOr<Operation *> vectorizeConvolution(
     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
-  // The ConvolutionOpInterface gives us guarantees of existence for
-  // strides/dilations. However, we do not need to rely on those, we can
-  // simply use them if present, otherwise use the default and let the generic
-  // conv. matcher in the ConvGenerator succeed or fail.
-  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
-  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
-  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
-  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-  Conv1DGenerator e(rewriter, op, stride, dilation);
+  Conv1DGenerator e(rewriter, op);
   auto res = e.generateNonChanneledConv();
   if (succeeded(res))
     return res;

>From 00c3a3380530944e5c6c306c97fddf05d967afdc Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 7 Mar 2025 16:39:06 +0000
Subject: [PATCH 3/3] Forward declare Conv1DGenerator for validity

---
 .../Linalg/Transforms/Vectorization.cpp       | 26 ++++++++++++++++---
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d57d4214a78ea..358b30b7b9712 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -52,6 +52,12 @@ using namespace mlir::linalg;
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+// Forward declarations so the precondition can validate a Conv1DGenerator.
+namespace {
+struct Conv1DGenerator;
+bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp);
+} // namespace
+
 /// Try to vectorize `convOp` as a convolution.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
@@ -1991,14 +1997,17 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   // features. But we will still need stride/dilation attributes that will be
   // annoying to reverse-engineer...
   if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-    // Check if it is 2d+ convolution. If it is, return failure because we don't
+    // Create a dummy rewriter first, a rewriter is not required for
+    // validation
+    IRRewriter dummyBuilder(linalgOp.getContext());
+    // Check if we can successfully construct a 1d convolution generator.
+    // For example, if it is 2d+ convolution, return failure because we don't
     // support it. To use this pass on a 2d+ convolution, it should have already
     // been decomposed to 1d convolution via
     // DecomposeConvolutionToLowerDimOpsPass.
-    if (linalgOp.getNumParallelLoops() >= 4) {
-      LDBG("precondition failed: Regular 2d+ convolutions not supported.\n");
+    if (!validateConv1DGenerator(dummyBuilder, linalgOp))
       return failure();
-    }
+
     return success();
   }
 
@@ -3197,6 +3206,8 @@ struct Conv1DGenerator
     valid = true;
   }
 
+  bool isValid() const { return valid; }
+
   /// Generate a vector implementation for:
   /// ```
   ///   Op def: (     w,     kw  )
@@ -3907,6 +3918,13 @@ struct Conv1DGenerator
     }
   }
 };
+
+// Returns true iff constructing a Conv1DGenerator for `linalgOp` marks it valid.
+bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) {
+  Conv1DGenerator conv1dGen(rewriter, linalgOp);
+  return conv1dGen.isValid();
+}
+
 } // namespace
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.



More information about the Mlir-commits mailing list