[Mlir-commits] [mlir] [MLIR] Refactor to create vectorization convOp precondition check (PR #130181)

Zhuoran Yin llvmlistbot at llvm.org
Tue Mar 11 10:19:17 PDT 2025


https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/130181

>From a33b2116cf8bea9287053bb57d277add19f77cb8 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 6 Mar 2025 21:40:25 +0000
Subject: [PATCH 1/7] Blocking conv2d from vectorization pass

---
 .../Linalg/Transforms/Vectorization.cpp       | 20 +++++++++++++++----
 .../Linalg/vectorization-unsupported.mlir     | 19 ++++++++++++++++++
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..319dd4b2043c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1990,8 +1990,18 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. But we will still need stride/dilation attributes that will be
   // annoying to reverse-engineer...
-  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+    // Check if it is 2d+ convolution. If it is, return failure because we don't
+    // support it. To use this pass on a 2d+ convolution, it should have already
+    // been decomposed to 1d convolution via
+    // DecomposeConvolutionToLowerDimOpsPass.
+    if (linalgOp.getNumParallelLoops() >= 4) {
+      LDBG("precondition failed: Regular 2d+ convolutions not supported.\n");
+      return failure();
+    }
     return success();
+  }
+
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
   // logic will need to evolve.
@@ -3929,9 +3939,11 @@ static FailureOr<Operation *> vectorizeConvolution(
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
     // vector dims will be inferred from the Ops.
-    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
-            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
-           "Not a 1D depthwise conv!");
+    if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
+        !isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
+      return rewriter.notifyMatchFailure(
+          op, "Unexpected convolution: expected 1D depthwise conv");
+    }
     size_t chDimIdx =
         TypeSwitch<Operation *, size_t>(op)
             .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 8f3b199145ce0..88d9e98c02bca 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -112,6 +112,25 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @conv2d(%3: tensor<1x64x58x58xf32>, %4:  tensor<64x64x3x3xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %5 = tensor.empty() : tensor<1x64x56x56xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%6 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
   %pad = arith.constant 0.000000e+00 : f32
   // expected-error @+1 {{Attempted to vectorize, but failed}}

>From 80474f73a516dc8df0df10555dd9d59cda86121c Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 7 Mar 2025 14:56:44 +0000
Subject: [PATCH 2/7] Refactor Conv1DGenerator

---
 .../Linalg/Transforms/Vectorization.cpp       | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 319dd4b2043c3..d57d4214a78ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3135,10 +3135,8 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
 /// kw is unrolled, w is unrolled iff dilationW > 1.
 struct Conv1DGenerator
     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
-  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
-                  int dilationW)
-      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
-        strideW(strideW), dilationW(dilationW) {
+  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
+      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
     // Determine whether `linalgOp` can be generated with this generator
     if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
       return;
@@ -3185,6 +3183,16 @@ struct Conv1DGenerator
         return;
       break;
     }
+
+    // The ConvolutionOpInterface gives us guarantees of existence for
+    // strides/dilations. However, we do not need to rely on those, we can
+    // simply use them if present, otherwise use the default and let the generic
+    // conv. matcher in the ConvGenerator succeed or fail.
+    auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+    auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+    strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
+    dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+
     // The op is now known to be valid.
     valid = true;
   }
@@ -3906,15 +3914,7 @@ struct Conv1DGenerator
 static FailureOr<Operation *> vectorizeConvolution(
     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
-  // The ConvolutionOpInterface gives us guarantees of existence for
-  // strides/dilations. However, we do not need to rely on those, we can
-  // simply use them if present, otherwise use the default and let the generic
-  // conv. matcher in the ConvGenerator succeed or fail.
-  auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
-  auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
-  auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
-  auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-  Conv1DGenerator e(rewriter, op, stride, dilation);
+  Conv1DGenerator e(rewriter, op);
   auto res = e.generateNonChanneledConv();
   if (succeeded(res))
     return res;

>From 00c3a3380530944e5c6c306c97fddf05d967afdc Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 7 Mar 2025 16:39:06 +0000
Subject: [PATCH 3/7] Forward declare Conv1DGenerator for validity

---
 .../Linalg/Transforms/Vectorization.cpp       | 26 ++++++++++++++++---
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d57d4214a78ea..358b30b7b9712 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -52,6 +52,12 @@ using namespace mlir::linalg;
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+// Forward declaration of Conv1DGenerator and its validator
+namespace {
+struct Conv1DGenerator;
+bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp);
+} // namespace
+
 /// Try to vectorize `convOp` as a convolution.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
@@ -1991,14 +1997,17 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   // features. But we will still need stride/dilation attributes that will be
   // annoying to reverse-engineer...
   if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-    // Check if it is 2d+ convolution. If it is, return failure because we don't
+    // Create a dummy rewriter first, a rewriter is not required for
+    // validation
+    IRRewriter dummyBuilder(linalgOp.getContext());
+    // Check if we can successfully construct a 1d convolution generator.
+    // For example, if it is 2d+ convolution, return failure because we don't
     // support it. To use this pass on a 2d+ convolution, it should have already
     // been decomposed to 1d convolution via
     // DecomposeConvolutionToLowerDimOpsPass.
-    if (linalgOp.getNumParallelLoops() >= 4) {
-      LDBG("precondition failed: Regular 2d+ convolutions not supported.\n");
+    if (!validateConv1DGenerator(dummyBuilder, linalgOp))
       return failure();
-    }
+
     return success();
   }
 
@@ -3197,6 +3206,8 @@ struct Conv1DGenerator
     valid = true;
   }
 
+  bool isValid() { return valid; }
+
   /// Generate a vector implementation for:
   /// ```
   ///   Op def: (     w,     kw  )
@@ -3907,6 +3918,13 @@ struct Conv1DGenerator
     }
   }
 };
+
+// Helper function to construct Conv1DGenerator
+bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) {
+  Conv1DGenerator conv1dGen(rewriter, linalgOp);
+  return conv1dGen.isValid();
+}
+
 } // namespace
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.

>From 74a898618bd0227e56413812411ac8bed809e7c9 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 10 Mar 2025 18:34:45 +0000
Subject: [PATCH 4/7] Address review feedback

---
 .../Linalg/Transforms/Vectorization.cpp       | 70 ++++++++-----------
 .../Linalg/vectorization-unsupported.mlir     | 40 +++++++++--
 2 files changed, 65 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 358b30b7b9712..6e5907d72e97e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -52,12 +52,6 @@ using namespace mlir::linalg;
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
-// Forward declaration of Conv1DGenerator and its validator
-namespace {
-struct Conv1DGenerator;
-bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp);
-} // namespace
-
 /// Try to vectorize `convOp` as a convolution.
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
@@ -1945,6 +1939,22 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
+  // We only support 1D convolutions, reject all other cases.
+  if (isa<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcFhwcOp,
+          linalg::Conv2DNchwFchwOp>(convOp)) {
+    LDBG("2D convolutions are not supported\n");
+    return failure();
+  }
+
+  if (isa<linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNcdhwFcdhwOp>(convOp)) {
+    LDBG("3D convolutions are not supported\n");
+    return failure();
+  }
+
+  return success();
+}
+
 static LogicalResult vectorizeLinalgOpPrecondition(
     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -1996,20 +2006,8 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. But we will still need stride/dilation attributes that will be
   // annoying to reverse-engineer...
-  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-    // Create a dummy rewriter first, a rewriter is not required for
-    // validation
-    IRRewriter dummyBuilder(linalgOp.getContext());
-    // Check if we can successfully construct a 1d convolution generator.
-    // For example, if it is 2d+ convolution, return failure because we don't
-    // support it. To use this pass on a 2d+ convolution, it should have already
-    // been decomposed to 1d convolution via
-    // DecomposeConvolutionToLowerDimOpsPass.
-    if (!validateConv1DGenerator(dummyBuilder, linalgOp))
-      return failure();
-
-    return success();
-  }
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+    return vectorizeConvOpPrecondition(linalgOp);
 
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
@@ -3918,13 +3916,6 @@ struct Conv1DGenerator
     }
   }
 };
-
-// Helper function to construct Conv1DGenerator
-bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) {
-  Conv1DGenerator conv1dGen(rewriter, linalgOp);
-  return conv1dGen.isValid();
-}
-
 } // namespace
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
@@ -3932,20 +3923,21 @@ bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) {
 static FailureOr<Operation *> vectorizeConvolution(
     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
-  Conv1DGenerator e(rewriter, op);
-  auto res = e.generateNonChanneledConv();
+  Conv1DGenerator conv1dGen(rewriter, op);
+  assert(conv1dGen.isValid() && "Conv1DGenerator failed");
+  auto res = conv1dGen.generateNonChanneledConv();
   if (succeeded(res))
     return res;
-  res = e.generateNwcConv();
+  res = conv1dGen.generateNwcConv();
   if (succeeded(res))
     return res;
-  res = e.generateNcwConv();
+  res = conv1dGen.generateNcwConv();
   if (succeeded(res))
     return res;
-  res = e.generateNwcPooling();
+  res = conv1dGen.generateNwcPooling();
   if (succeeded(res))
     return res;
-  res = e.generateNcwPooling();
+  res = conv1dGen.generateNcwPooling();
   if (succeeded(res))
     return res;
 
@@ -3957,11 +3949,9 @@ static FailureOr<Operation *> vectorizeConvolution(
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
     // vector dims will be inferred from the Ops.
-    if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
-        !isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
-      return rewriter.notifyMatchFailure(
-          op, "Unexpected convolution: expected 1D depthwise conv");
-    }
+    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
+            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+           "Not a 1D depthwise conv!");
     size_t chDimIdx =
         TypeSwitch<Operation *, size_t>(op)
             .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
@@ -3970,8 +3960,8 @@ static FailureOr<Operation *> vectorizeConvolution(
     vecChDimSize = inputVecSizes[chDimIdx];
     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
   }
-  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
-                               flatten1DDepthwiseConv);
+  return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
+                                       flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 88d9e98c02bca..2d1f0191eb798 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -112,12 +112,9 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @conv2d(%3: tensor<1x64x58x58xf32>, %4:  tensor<64x64x3x3xf32>) {
-  %cst = arith.constant 0.000000e+00 : f32
-  %5 = tensor.empty() : tensor<1x64x56x56xf32>
-  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+func.func @conv2d_nchw_fchw(%input: tensor<1x5x8x8xf32>, %filter:  tensor<4x5x3x3xf32>, %output: tensor<1x4x6x6xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
-  %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%6 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+  linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x5x8x8xf32>, tensor<4x5x3x3xf32>) outs(%output : tensor<1x4x6x6xf32>) -> tensor<1x4x6x6xf32>
   return
 }
 
@@ -131,6 +128,39 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @conv2d_nhwc_fhwc(%input: tensor<1x8x8x5xf32>, %filter: tensor<4x3x3x5xf32>, %output: tensor<1x6x6x4xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x8x8x5xf32>, tensor<4x3x3x5xf32>) outs(%output : tensor<1x6x6x4xf32>) -> tensor<1x6x6x4xf32>
+  return
+}
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @conv3d_ncdhw_fcdhw(%input: tensor<1x5x8x8x8xf32>, %filter: tensor<4x5x3x3x3xf32>, %output: tensor<1x4x6x6x6xf32>) {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : vector<3xi64>, strides = dense<1> : vector<3xi64>} ins(%input, %filter : tensor<1x5x8x8x8xf32>, tensor<4x5x3x3x3xf32>) outs(%output : tensor<1x4x6x6x6xf32>) -> tensor<1x4x6x6x6xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_3d_ncdhw_fcdhw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
   %pad = arith.constant 0.000000e+00 : f32
   // expected-error @+1 {{Attempted to vectorize, but failed}}

>From 2ef5555f82feb9f8dbaa1cac04cfc25013c0d16f Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 10 Mar 2025 21:51:35 +0000
Subject: [PATCH 5/7] Peel away conv1dgen validator into precondition check

---
 .../Linalg/Transforms/Vectorization.cpp       | 226 ++++++++++--------
 1 file changed, 121 insertions(+), 105 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6e5907d72e97e..5c830693e8349 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1939,19 +1939,124 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+namespace {
+bool isCastOfBlockArgument(Operation *op) {
+  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+         isa<BlockArgument>(op->getOperand(0));
+}
+
+// Returns true iff it is a valid conv/pooling op.
+// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
+// + yield) and rhs is not used) then it is the body of a pooling
+// If conv, check for single `mul` predecessor. The `mul` operands must be
+// block arguments or extension of block arguments.
+// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
+// must be block arguments or extension of block arguments.
+enum OperKind { Conv, Pool };
+bool getOperKind(Operation *reduceOp, OperKind &oper) {
+  int numBlockArguments =
+      llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
+
+  switch (numBlockArguments) {
+  case 1: {
+    // Will be convolution if feeder is a MulOp.
+    // A strength reduced version of MulOp for i1 type is AndOp which is also
+    // supported. Otherwise, it can be pooling. This strength reduction logic
+    // is in `buildBinaryFn` helper in the Linalg dialect.
+    auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+                                       llvm::IsaPred<BlockArgument>);
+    Operation *feedOp = (*feedValIt).getDefiningOp();
+    // llvm::outs() << "feedOp: " << *feedOp << "\n";
+    if (isCastOfBlockArgument(feedOp)) {
+      oper = Pool;
+    } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+                  (isa<arith::AndIOp>(feedOp) &&
+                   feedOp->getResultTypes()[0].isInteger(1))) &&
+                 llvm::all_of(feedOp->getOperands(), [](Value v) {
+                   if (isa<BlockArgument>(v))
+                     return true;
+                   if (Operation *op = v.getDefiningOp())
+                     return isCastOfBlockArgument(op);
+                   return false;
+                 }))) {
+      return false;
+    }
+    return true;
+  }
+  case 2:
+    // Must be pooling
+    oper = Pool;
+    return true;
+  default:
+    return false;
+  }
+}
+
+bool isSupportedPoolKind(vector::CombiningKind kind) {
+  switch (kind) {
+  case vector::CombiningKind::ADD:
+  case vector::CombiningKind::MAXNUMF:
+  case vector::CombiningKind::MAXIMUMF:
+  case vector::CombiningKind::MAXSI:
+  case vector::CombiningKind::MAXUI:
+  case vector::CombiningKind::MINNUMF:
+  case vector::CombiningKind::MINIMUMF:
+  case vector::CombiningKind::MINSI:
+  case vector::CombiningKind::MINUI:
+    return true;
+  default:
+    return false;
+  }
+}
+} // namespace
+
 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
-  // We only support 1D convolutions, reject all other cases.
-  if (isa<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcFhwcOp,
-          linalg::Conv2DNchwFchwOp>(convOp)) {
-    LDBG("2D convolutions are not supported\n");
+  if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
+    return failure();
+
+  auto lhsShaped = convOp.getDpsInputOperand(0)->get();
+  auto rhsShaped = convOp.getDpsInputOperand(1)->get();
+  auto resShaped = convOp.getDpsInitOperand(0)->get();
+  auto lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
+  auto rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
+  auto resShapedType = dyn_cast<ShapedType>(resShaped.getType());
+  if (!lhsShapedType || !rhsShapedType || !resShapedType)
+    return failure();
+  // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
+  // (non-channeled convolution -> LHS and RHS both have single dimensions).
+  if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
+      (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
     return failure();
-  }
 
-  if (isa<linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNcdhwFcdhwOp>(convOp)) {
-    LDBG("3D convolutions are not supported\n");
+  Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
+  if (!reduceOp)
+    return failure();
+
+  OperKind oper = Conv;
+  if (!getOperKind(reduceOp, oper))
+    return failure();
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  // Typically convolution will have a `Add` CombiningKind but for i1 type it
+  // can get strength reduced to `OR` which is also supported. This strength
+  // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
+  if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+                      *maybeKind != vector::CombiningKind::OR) &&
+                     (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
     return failure();
   }
 
+  auto rhsRank = rhsShapedType.getRank();
+  switch (oper) {
+  case Conv:
+    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
+      return failure();
+    break;
+  case Pool:
+    if (rhsRank != 1)
+      return failure();
+    break;
+  }
+
   return success();
 }
 
@@ -3084,28 +3189,6 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
 }
 
 namespace {
-bool isCastOfBlockArgument(Operation *op) {
-  return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
-         isa<BlockArgument>(op->getOperand(0));
-}
-
-bool isSupportedPoolKind(vector::CombiningKind kind) {
-  switch (kind) {
-  case vector::CombiningKind::ADD:
-  case vector::CombiningKind::MAXNUMF:
-  case vector::CombiningKind::MAXIMUMF:
-  case vector::CombiningKind::MAXSI:
-  case vector::CombiningKind::MAXUI:
-  case vector::CombiningKind::MINNUMF:
-  case vector::CombiningKind::MINIMUMF:
-  case vector::CombiningKind::MINSI:
-  case vector::CombiningKind::MINUI:
-    return true;
-  default:
-    return false;
-  }
-}
-
 /// Generate a vector implementation for either:
 /// ```
 ///   Op def: (     w,     kw  )
@@ -3144,53 +3227,22 @@ struct Conv1DGenerator
     : public StructuredGenerator<LinalgOp, utils::IteratorType> {
   Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
       : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
-    // Determine whether `linalgOp` can be generated with this generator
-    if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
-      return;
+
     lhsShaped = linalgOp.getDpsInputOperand(0)->get();
     rhsShaped = linalgOp.getDpsInputOperand(1)->get();
     resShaped = linalgOp.getDpsInitOperand(0)->get();
     lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
     rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
     resShapedType = dyn_cast<ShapedType>(resShaped.getType());
-    if (!lhsShapedType || !rhsShapedType || !resShapedType)
-      return;
-    // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
-    // (non-channeled convolution -> LHS and RHS both have single dimensions).
-    if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
-        (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
-      return;
 
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
-    if (!reduceOp)
-      return;
     redOp = reduceOp->getName().getIdentifier();
 
-    if (!setOperKind(reduceOp))
-      return;
+    setOperKind(reduceOp);
+
     auto maybeKind = getCombinerOpKind(reduceOp);
-    // Typically convolution will have a `Add` CombiningKind but for i1 type it
-    // can get strength reduced to `OR` which is also supported. This strength
-    // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
-    if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
-                        *maybeKind != vector::CombiningKind::OR) &&
-                       (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
-      return;
-    }
     reductionKind = maybeKind.value();
 
-    auto rhsRank = rhsShapedType.getRank();
-    switch (oper) {
-    case Conv:
-      if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
-        return;
-      break;
-    case Pool:
-      if (rhsRank != 1)
-        return;
-      break;
-    }
-
     // The ConvolutionOpInterface gives us guarantees of existence for
     // strides/dilations. However, we do not need to rely on those, we can
     // simply use them if present, otherwise use the default and let the generic
@@ -3199,13 +3251,8 @@ struct Conv1DGenerator
     auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
     strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
     dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-
-    // The op is now known to be valid.
-    valid = true;
   }
 
-  bool isValid() { return valid; }
-
   /// Generate a vector implementation for:
   /// ```
   ///   Op def: (     w,     kw  )
@@ -3225,9 +3272,6 @@ struct Conv1DGenerator
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
   FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
-
     int64_t nSize, wSize, cSize, kwSize, fSize;
     SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
     bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
@@ -3510,9 +3554,6 @@ struct Conv1DGenerator
   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
                                        bool channelDimScalableFlag,
                                        bool flatten) {
-    if (!valid)
-      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
-
     bool scalableChDim = false;
     bool useMasking = false;
     int64_t nSize, wSize, cSize, kwSize;
@@ -3857,8 +3898,6 @@ struct Conv1DGenerator
   }
 
 private:
-  enum OperKind { Conv, Pool };
-  bool valid = false;
   OperKind oper = Conv;
   StringAttr redOp;
   StringAttr poolExtOp;
@@ -3869,18 +3908,10 @@ struct Conv1DGenerator
   vector::CombiningKind reductionKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
-  // Returns true iff it is a valid conv/pooling op.
-  // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
-  // + yield) and rhs is not used) then it is the body of a pooling
-  // If conv, check for single `mul` predecessor. The `mul` operands must be
-  // block arguments or extension of block arguments.
-  // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
-  // must be block arguments or extension of block arguments.
-  bool setOperKind(Operation *reduceOp) {
+  void setOperKind(Operation *reduceOp) {
     int numBlockArguments =
         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
-    switch (numBlockArguments) {
-    case 1: {
+    if (numBlockArguments == 1) {
       // Will be convolution if feeder is a MulOp.
       // A strength reduced version of MulOp for i1 type is AndOp which is also
       // supported. Otherwise, it can be pooling. This strength reduction logic
@@ -3892,27 +3923,13 @@ struct Conv1DGenerator
         oper = Pool;
         isPoolExt = true;
         poolExtOp = feedOp->getName().getIdentifier();
-      } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
-                    (isa<arith::AndIOp>(feedOp) &&
-                     feedOp->getResultTypes()[0].isInteger(1))) &&
-                   llvm::all_of(feedOp->getOperands(), [](Value v) {
-                     if (isa<BlockArgument>(v))
-                       return true;
-                     if (Operation *op = v.getDefiningOp())
-                       return isCastOfBlockArgument(op);
-                     return false;
-                   }))) {
-        return false;
+      } else {
+        oper = Conv;
       }
-      return true;
-    }
-    case 2:
-      // Must be pooling
+    } else {
+      // Pooling.
       oper = Pool;
       isPoolExt = false;
-      return true;
-    default:
-      return false;
     }
   }
 };
@@ -3924,7 +3941,6 @@ static FailureOr<Operation *> vectorizeConvolution(
     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
   Conv1DGenerator conv1dGen(rewriter, op);
-  assert(conv1dGen.isValid() && "Conv1DGenerator failed");
   auto res = conv1dGen.generateNonChanneledConv();
   if (succeeded(res))
     return res;

>From 13f5183d6ca64ba6a09985f6eaeb13881c237bf4 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 11 Mar 2025 14:35:49 +0000
Subject: [PATCH 6/7] Addressing review feedback

---
 .../Linalg/Transforms/Vectorization.cpp       | 112 +++++++++---------
 1 file changed, 57 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5c830693e8349..df39e6daf271d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1940,7 +1940,10 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
 }
 
 namespace {
-bool isCastOfBlockArgument(Operation *op) {
+enum class ConvOperationKind { Conv, Pool };
+} // namespace
+
+static bool isCastOfBlockArgument(Operation *op) {
   return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
          isa<BlockArgument>(op->getOperand(0));
 }
@@ -1952,8 +1955,8 @@ bool isCastOfBlockArgument(Operation *op) {
 // block arguments or extension of block arguments.
 // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
 // must be block arguments or extension of block arguments.
-enum OperKind { Conv, Pool };
-bool getOperKind(Operation *reduceOp, OperKind &oper) {
+static std::optional<ConvOperationKind>
+getConvOperationKind(Operation *reduceOp) {
   int numBlockArguments =
       llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
 
@@ -1966,33 +1969,34 @@ bool getOperKind(Operation *reduceOp, OperKind &oper) {
     auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
                                        llvm::IsaPred<BlockArgument>);
     Operation *feedOp = (*feedValIt).getDefiningOp();
-    // llvm::outs() << "feedOp: " << *feedOp << "\n";
     if (isCastOfBlockArgument(feedOp)) {
-      oper = Pool;
-    } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
-                  (isa<arith::AndIOp>(feedOp) &&
-                   feedOp->getResultTypes()[0].isInteger(1))) &&
-                 llvm::all_of(feedOp->getOperands(), [](Value v) {
-                   if (isa<BlockArgument>(v))
-                     return true;
-                   if (Operation *op = v.getDefiningOp())
-                     return isCastOfBlockArgument(op);
-                   return false;
-                 }))) {
-      return false;
+      return ConvOperationKind::Pool;
     }
-    return true;
+
+    if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+           (isa<arith::AndIOp>(feedOp) &&
+            feedOp->getResultTypes()[0].isInteger(1))) &&
+          llvm::all_of(feedOp->getOperands(), [](Value v) {
+            if (isa<BlockArgument>(v))
+              return true;
+            if (Operation *op = v.getDefiningOp())
+              return isCastOfBlockArgument(op);
+            return false;
+          }))) {
+      return std::nullopt;
+    }
+
+    return ConvOperationKind::Conv;
   }
   case 2:
     // Must be pooling
-    oper = Pool;
-    return true;
+    return ConvOperationKind::Pool;
   default:
-    return false;
+    return std::nullopt;
   }
 }
 
-bool isSupportedPoolKind(vector::CombiningKind kind) {
+static bool isSupportedPoolKind(vector::CombiningKind kind) {
   switch (kind) {
   case vector::CombiningKind::ADD:
   case vector::CombiningKind::MAXNUMF:
@@ -2008,7 +2012,6 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
     return false;
   }
 }
-} // namespace
 
 static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
   if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
@@ -2032,29 +2035,28 @@ static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
   if (!reduceOp)
     return failure();
 
-  OperKind oper = Conv;
-  if (!getOperKind(reduceOp, oper))
+  auto maybeOper = getConvOperationKind(reduceOp);
+  if (!maybeOper.has_value())
     return failure();
+
   auto maybeKind = getCombinerOpKind(reduceOp);
   // Typically convolution will have a `Add` CombiningKind but for i1 type it
   // can get strength reduced to `OR` which is also supported. This strength
   // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
   if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                       *maybeKind != vector::CombiningKind::OR) &&
-                     (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
+                     (*maybeOper != ConvOperationKind::Pool ||
+                      !isSupportedPoolKind(*maybeKind)))) {
     return failure();
   }
 
   auto rhsRank = rhsShapedType.getRank();
-  switch (oper) {
-  case Conv:
-    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
-      return failure();
-    break;
-  case Pool:
+  if (*maybeOper == ConvOperationKind::Pool) {
     if (rhsRank != 1)
       return failure();
-    break;
+  } else {
+    if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
+      return failure();
   }
 
   return success();
@@ -3238,7 +3240,7 @@ struct Conv1DGenerator
     Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
     redOp = reduceOp->getName().getIdentifier();
 
-    setOperKind(reduceOp);
+    setConvOperationKind(reduceOp);
 
     auto maybeKind = getCombinerOpKind(reduceOp);
     reductionKind = maybeKind.value();
@@ -3293,11 +3295,11 @@ struct Conv1DGenerator
       // out{n, w, f}
       bindShapeDims(resShapedType, nSize, wSize, fSize);
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         // kernel{kw, c, f}
         bindShapeDims(rhsShapedType, kwSize, cSize);
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         // kernel{kw}
         bindShapeDims(rhsShapedType, kwSize);
         cSize = fSize;
@@ -3311,10 +3313,10 @@ struct Conv1DGenerator
                       1,
                   cSize};
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         rhsShape = {kwSize, cSize, fSize};
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         rhsShape = {kwSize};
         break;
       }
@@ -3324,11 +3326,11 @@ struct Conv1DGenerator
       // out{n, f, w}
       bindShapeDims(resShapedType, nSize, fSize, wSize);
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         // kernel{f, c, kw}
         bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         // kernel{kw}
         bindShapeDims(rhsShapedType, kwSize);
         cSize = fSize;
@@ -3341,10 +3343,10 @@ struct Conv1DGenerator
                   ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
                       1};
       switch (oper) {
-      case Conv:
+      case ConvOperationKind::Conv:
         rhsShape = {fSize, cSize, kwSize};
         break;
-      case Pool:
+      case ConvOperationKind::Pool:
         rhsShape = {kwSize};
         break;
       }
@@ -3376,7 +3378,7 @@ struct Conv1DGenerator
                                                         lhsPadding);
     // This is needed only for Conv.
     Value rhs = nullptr;
-    if (oper == Conv)
+    if (oper == ConvOperationKind::Conv)
       rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
                                                     rhsPadding);
     Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
@@ -3399,7 +3401,7 @@ struct Conv1DGenerator
       static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
 
       // This is needed only for Conv.
-      if (oper == Conv)
+      if (oper == ConvOperationKind::Conv)
         rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
       // nfw -> nwf
       static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
@@ -3417,7 +3419,7 @@ struct Conv1DGenerator
                                      kwSize, strideW, dilationW, wSizeStep,
                                      isSingleChanneled);
     // Do not do for pooling.
-    if (oper == Conv)
+    if (oper == ConvOperationKind::Conv)
       rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
     resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
                                       wSizeStep, isSingleChanneled);
@@ -3432,7 +3434,7 @@ struct Conv1DGenerator
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         switch (oper) {
-        case Conv:
+        case ConvOperationKind::Conv:
           if (isSingleChanneled) {
             resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
                                                    lhsVals[linearIndex(kw, w)],
@@ -3443,7 +3445,7 @@ struct Conv1DGenerator
                                                   rhsVals[kw], resVals[w]);
           }
           break;
-        case Pool:
+        case ConvOperationKind::Pool:
           resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
                                    resVals[w]);
           break;
@@ -3898,7 +3900,7 @@ struct Conv1DGenerator
   }
 
 private:
-  OperKind oper = Conv;
+  ConvOperationKind oper = ConvOperationKind::Conv;
   StringAttr redOp;
   StringAttr poolExtOp;
   bool isPoolExt = false;
@@ -3908,7 +3910,7 @@ struct Conv1DGenerator
   vector::CombiningKind reductionKind;
 
   // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
-  void setOperKind(Operation *reduceOp) {
+  void setConvOperationKind(Operation *reduceOp) {
     int numBlockArguments =
         llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
     if (numBlockArguments == 1) {
@@ -3920,17 +3922,17 @@ struct Conv1DGenerator
                                          llvm::IsaPred<BlockArgument>);
       Operation *feedOp = (*feedValIt).getDefiningOp();
       if (isCastOfBlockArgument(feedOp)) {
-        oper = Pool;
+        oper = ConvOperationKind::Pool;
         isPoolExt = true;
         poolExtOp = feedOp->getName().getIdentifier();
-      } else {
-        oper = Conv;
+        return;
       }
-    } else {
-      // Pooling.
-      oper = Pool;
-      isPoolExt = false;
+      oper = ConvOperationKind::Conv;
+      return;
     }
+    // numBlockArguments == 2, so this is a pooling op.
+    oper = ConvOperationKind::Pool;
+    isPoolExt = false;
   }
 };
 } // namespace

>From 2b5d1dc903ec25affb72e02817b1fc237df0b426 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 11 Mar 2025 17:18:59 +0000
Subject: [PATCH 7/7] Amending comments of getConvOperationKind()

---
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index df39e6daf271d..94a55355d51e6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1948,7 +1948,10 @@ static bool isCastOfBlockArgument(Operation *op) {
          isa<BlockArgument>(op->getOperand(0));
 }
 
-// Returns true iff it is a valid conv/pooling op.
+// Returns the ConvOperationKind of the op, determined from reduceOp (the
+// reduction op of the generic payload). If the op is neither a convolution
+// nor a pooling, returns std::nullopt.
+//
 // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
 // + yield) and rhs is not used) then it is the body of a pooling
 // If conv, check for single `mul` predecessor. The `mul` operands must be



More information about the Mlir-commits mailing list