[Mlir-commits] [mlir] [MLIR] Refactor to create vectorization convOp precondition check (PR #130181)
Zhuoran Yin
llvmlistbot at llvm.org
Tue Mar 11 10:19:17 PDT 2025
https://github.com/jerryyin updated https://github.com/llvm/llvm-project/pull/130181
>From a33b2116cf8bea9287053bb57d277add19f77cb8 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Thu, 6 Mar 2025 21:40:25 +0000
Subject: [PATCH 1/7] Blocking conv2d from vectorization pass
---
.../Linalg/Transforms/Vectorization.cpp | 20 +++++++++++++++----
.../Linalg/vectorization-unsupported.mlir | 19 ++++++++++++++++++
2 files changed, 35 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..319dd4b2043c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1990,8 +1990,18 @@ static LogicalResult vectorizeLinalgOpPrecondition(
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. But we will still need stride/dilation attributes that will be
// annoying to reverse-engineer...
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+ // Check if it is 2d+ convolution. If it is, return failure because we don't
+ // support it. To use this pass on a 2d+ convolution, it should have already
+ // been decomposed to 1d convolution via
+ // DecomposeConvolutionToLowerDimOpsPass.
+ if (linalgOp.getNumParallelLoops() >= 4) {
+ LDBG("precondition failed: Regular 2d+ convolutions not supported.\n");
+ return failure();
+ }
return success();
+ }
+
// TODO: the common vector shape is equal to the static loop sizes only when
// all indexing maps are projected permutations. For convs and stencils the
// logic will need to evolve.
@@ -3929,9 +3939,11 @@ static FailureOr<Operation *> vectorizeConvolution(
if (!inputVecSizes.empty()) {
// Only use the input vector size corresponding to the channel dim. Other
// vector dims will be inferred from the Ops.
- assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
- isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
- "Not a 1D depthwise conv!");
+ if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
+ !isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
+ return rewriter.notifyMatchFailure(
+ op, "Unexpected convolution: expected 1D depthwise conv");
+ }
size_t chDimIdx =
TypeSwitch<Operation *, size_t>(op)
.Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 8f3b199145ce0..88d9e98c02bca 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -112,6 +112,25 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @conv2d(%3: tensor<1x64x58x58xf32>, %4: tensor<64x64x3x3xf32>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %5 = tensor.empty() : tensor<1x64x56x56xf32>
+ %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+ // expected-error @+1 {{Attempted to vectorize, but failed}}
+ %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%6 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
%pad = arith.constant 0.000000e+00 : f32
// expected-error @+1 {{Attempted to vectorize, but failed}}
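
The precondition above keys off the parallel-loop count: a 1-D channeled convolution such as linalg.conv_1d_nwc_wcf has three parallel loops (n, w, f) and two reductions (c, kw), whereas linalg.conv_2d_nchw_fchw has four parallel loops (n, f, oh, ow) plus three reductions (c, kh, kw), so getNumParallelLoops() >= 4 flags 2d+ convolutions. A standalone restatement of that heuristic follows, purely as an illustration; the helper name is hypothetical and not part of the patch.

static bool isLikely2DPlusConv(linalg::LinalgOp linalgOp) {
  // conv_1d_* ops expose at most 3 parallel loops; conv_2d_*/conv_3d_* expose
  // 4 or more, so the parallel-loop count is enough to tell them apart here.
  return isa<ConvolutionOpInterface>(linalgOp.getOperation()) &&
         linalgOp.getNumParallelLoops() >= 4;
}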
>From 80474f73a516dc8df0df10555dd9d59cda86121c Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 7 Mar 2025 14:56:44 +0000
Subject: [PATCH 2/7] Refactor Conv1DGenerator
---
.../Linalg/Transforms/Vectorization.cpp | 26 +++++++++----------
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 319dd4b2043c3..d57d4214a78ea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -3135,10 +3135,8 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
/// kw is unrolled, w is unrolled iff dilationW > 1.
struct Conv1DGenerator
: public StructuredGenerator<LinalgOp, utils::IteratorType> {
- Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
- int dilationW)
- : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
- strideW(strideW), dilationW(dilationW) {
+ Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
+ : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
// Determine whether `linalgOp` can be generated with this generator
if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
return;
@@ -3185,6 +3183,16 @@ struct Conv1DGenerator
return;
break;
}
+
+ // The ConvolutionOpInterface gives us guarantees of existence for
+ // strides/dilations. However, we do not need to rely on those, we can
+ // simply use them if present, otherwise use the default and let the generic
+ // conv. matcher in the ConvGenerator succeed or fail.
+ auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
+ auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
+ strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
+ dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
+
// The op is now known to be valid.
valid = true;
}
@@ -3906,15 +3914,7 @@ struct Conv1DGenerator
static FailureOr<Operation *> vectorizeConvolution(
RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
- // The ConvolutionOpInterface gives us guarantees of existence for
- // strides/dilations. However, we do not need to rely on those, we can
- // simply use them if present, otherwise use the default and let the generic
- // conv. matcher in the ConvGenerator succeed or fail.
- auto strides = op->getAttrOfType<DenseIntElementsAttr>("strides");
- auto dilations = op->getAttrOfType<DenseIntElementsAttr>("dilations");
- auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
- auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
- Conv1DGenerator e(rewriter, op, stride, dilation);
+ Conv1DGenerator e(rewriter, op);
auto res = e.generateNonChanneledConv();
if (succeeded(res))
return res;
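
With this change the generator reads the optional "strides"/"dilations" attributes itself and falls back to 1 when they are absent. The defaulting pattern, restated as a free-standing helper for clarity (the helper name is hypothetical; the attribute accesses mirror the patch):

static uint64_t firstValueOrOne(linalg::LinalgOp op, llvm::StringRef name) {
  // The DenseIntElementsAttr holds per-dimension values; only the first
  // (the W dimension for 1-D convs) is needed here.
  if (auto attr = op->getAttrOfType<DenseIntElementsAttr>(name))
    return *attr.getValues<uint64_t>().begin();
  return 1;
}
// strideW   = firstValueOrOne(linalgOp, "strides");
// dilationW = firstValueOrOne(linalgOp, "dilations");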
>From 00c3a3380530944e5c6c306c97fddf05d967afdc Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 7 Mar 2025 16:39:06 +0000
Subject: [PATCH 3/7] Forward declare Conv1DGenerator for validity
---
.../Linalg/Transforms/Vectorization.cpp | 26 ++++++++++++++++---
1 file changed, 22 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d57d4214a78ea..358b30b7b9712 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -52,6 +52,12 @@ using namespace mlir::linalg;
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+// Forward declaration of Conv1DGenerator and its validator
+namespace {
+struct Conv1DGenerator;
+bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp);
+} // namespace
+
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
@@ -1991,14 +1997,17 @@ static LogicalResult vectorizeLinalgOpPrecondition(
// features. But we will still need stride/dilation attributes that will be
// annoying to reverse-engineer...
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
- // Check if it is 2d+ convolution. If it is, return failure because we don't
+ // Create a dummy rewriter first, a rewriter is not required for
+ // validation
+ IRRewriter dummyBuilder(linalgOp.getContext());
+ // Check if we can successfully construct a 1d convolution generator.
+ // For example, if it is 2d+ convolution, return failure because we don't
// support it. To use this pass on a 2d+ convolution, it should have already
// been decomposed to 1d convolution via
// DecomposeConvolutionToLowerDimOpsPass.
- if (linalgOp.getNumParallelLoops() >= 4) {
- LDBG("precondition failed: Regular 2d+ convolutions not supported.\n");
+ if (!validateConv1DGenerator(dummyBuilder, linalgOp))
return failure();
- }
+
return success();
}
@@ -3197,6 +3206,8 @@ struct Conv1DGenerator
valid = true;
}
+ bool isValid() { return valid; }
+
/// Generate a vector implementation for:
/// ```
/// Op def: ( w, kw )
@@ -3907,6 +3918,13 @@ struct Conv1DGenerator
}
}
};
+
+// Helper function to construct Conv1DGenerator
+bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) {
+ Conv1DGenerator conv1dGen(rewriter, linalgOp);
+ return conv1dGen.isValid();
+}
+
} // namespace
/// Helper function to vectorize a LinalgOp with convolution semantics.
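
At this stage the validity check simply constructs a Conv1DGenerator with a throw-away rewriter; the constructor only inspects the op (operand counts, ranks, reduction body) and mutates no IR, so the probe is safe and cheap. The flow, restated with the names introduced in this patch:

IRRewriter dummyBuilder(linalgOp.getContext());
// Runs only the constructor's structural checks; vectorizeConvolution later
// builds its own generator with the real rewriter.
if (!validateConv1DGenerator(dummyBuilder, linalgOp))
  return failure();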
>From 74a898618bd0227e56413812411ac8bed809e7c9 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 10 Mar 2025 18:34:45 +0000
Subject: [PATCH 4/7] Addressing review feedback
---
.../Linalg/Transforms/Vectorization.cpp | 70 ++++++++-----------
.../Linalg/vectorization-unsupported.mlir | 40 +++++++++--
2 files changed, 65 insertions(+), 45 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 358b30b7b9712..6e5907d72e97e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -52,12 +52,6 @@ using namespace mlir::linalg;
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
-// Forward declaration of Conv1DGenerator and its validator
-namespace {
-struct Conv1DGenerator;
-bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp);
-} // namespace
-
/// Try to vectorize `convOp` as a convolution.
static FailureOr<Operation *>
vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
@@ -1945,6 +1939,22 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
return success();
}
+static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
+ // We only support 1D convolutions, reject all other cases.
+ if (isa<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcFhwcOp,
+ linalg::Conv2DNchwFchwOp>(convOp)) {
+ LDBG("2D convolutions are not supported\n");
+ return failure();
+ }
+
+ if (isa<linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNcdhwFcdhwOp>(convOp)) {
+ LDBG("3D convolutions are not supported\n");
+ return failure();
+ }
+
+ return success();
+}
+
static LogicalResult vectorizeLinalgOpPrecondition(
LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -1996,20 +2006,8 @@ static LogicalResult vectorizeLinalgOpPrecondition(
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. But we will still need stride/dilation attributes that will be
// annoying to reverse-engineer...
- if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
- // Create a dummy rewriter first, a rewriter is not required for
- // validation
- IRRewriter dummyBuilder(linalgOp.getContext());
- // Check if we can successfully construct a 1d convolution generator.
- // For example, if it is 2d+ convolution, return failure because we don't
- // support it. To use this pass on a 2d+ convolution, it should have already
- // been decomposed to 1d convolution via
- // DecomposeConvolutionToLowerDimOpsPass.
- if (!validateConv1DGenerator(dummyBuilder, linalgOp))
- return failure();
-
- return success();
- }
+ if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+ return vectorizeConvOpPrecondition(linalgOp);
// TODO: the common vector shape is equal to the static loop sizes only when
// all indexing maps are projected permutations. For convs and stencils the
@@ -3918,13 +3916,6 @@ struct Conv1DGenerator
}
}
};
-
-// Helper function to construct Conv1DGenerator
-bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) {
- Conv1DGenerator conv1dGen(rewriter, linalgOp);
- return conv1dGen.isValid();
-}
-
} // namespace
/// Helper function to vectorize a LinalgOp with convolution semantics.
@@ -3932,20 +3923,21 @@ bool validateConv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) {
static FailureOr<Operation *> vectorizeConvolution(
RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
- Conv1DGenerator e(rewriter, op);
- auto res = e.generateNonChanneledConv();
+ Conv1DGenerator conv1dGen(rewriter, op);
+ assert(conv1dGen.isValid() && "Conv1DGenerator failed");
+ auto res = conv1dGen.generateNonChanneledConv();
if (succeeded(res))
return res;
- res = e.generateNwcConv();
+ res = conv1dGen.generateNwcConv();
if (succeeded(res))
return res;
- res = e.generateNcwConv();
+ res = conv1dGen.generateNcwConv();
if (succeeded(res))
return res;
- res = e.generateNwcPooling();
+ res = conv1dGen.generateNwcPooling();
if (succeeded(res))
return res;
- res = e.generateNcwPooling();
+ res = conv1dGen.generateNcwPooling();
if (succeeded(res))
return res;
@@ -3957,11 +3949,9 @@ static FailureOr<Operation *> vectorizeConvolution(
if (!inputVecSizes.empty()) {
// Only use the input vector size corresponding to the channel dim. Other
// vector dims will be inferred from the Ops.
- if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
- !isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
- return rewriter.notifyMatchFailure(
- op, "Unexpected convolution: expected 1D depthwise conv");
- }
+ assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
+ isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
+ "Not a 1D depthwise conv!");
size_t chDimIdx =
TypeSwitch<Operation *, size_t>(op)
.Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
@@ -3970,8 +3960,8 @@ static FailureOr<Operation *> vectorizeConvolution(
vecChDimSize = inputVecSizes[chDimIdx];
vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
}
- return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
- flatten1DDepthwiseConv);
+ return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
+ flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 88d9e98c02bca..2d1f0191eb798 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -112,12 +112,9 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @conv2d(%3: tensor<1x64x58x58xf32>, %4: tensor<64x64x3x3xf32>) {
- %cst = arith.constant 0.000000e+00 : f32
- %5 = tensor.empty() : tensor<1x64x56x56xf32>
- %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+func.func @conv2d_nchw_fchw(%input: tensor<1x5x8x8xf32>, %filter: tensor<4x5x3x3xf32>, %output: tensor<1x4x6x6xf32>) {
// expected-error @+1 {{Attempted to vectorize, but failed}}
- %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%6 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+ linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x5x8x8xf32>, tensor<4x5x3x3xf32>) outs(%output : tensor<1x4x6x6xf32>) -> tensor<1x4x6x6xf32>
return
}
@@ -131,6 +128,39 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @conv2d_nhwc_fhwc(%input: tensor<1x8x8x5xf32>, %filter: tensor<4x3x3x5xf32>, %output: tensor<1x6x6x4xf32>) {
+ // expected-error @+1 {{Attempted to vectorize, but failed}}
+ linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%input, %filter : tensor<1x8x8x5xf32>, tensor<4x3x3x5xf32>) outs(%output : tensor<1x6x6x4xf32>) -> tensor<1x6x6x4xf32>
+ return
+}
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @conv3d_ncdhw_fcdhw(%input: tensor<1x5x8x8x8xf32>, %filter: tensor<4x5x3x3x3xf32>, %output: tensor<1x4x6x6x6xf32>) {
+ // expected-error @+1 {{Attempted to vectorize, but failed}}
+ linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : vector<3xi64>, strides = dense<1> : vector<3xi64>} ins(%input, %filter : tensor<1x5x8x8x8xf32>, tensor<4x5x3x3x3xf32>) outs(%output : tensor<1x4x6x6x6xf32>) -> tensor<1x4x6x6x6xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.conv_3d_ncdhw_fcdhw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
%pad = arith.constant 0.000000e+00 : f32
// expected-error @+1 {{Attempted to vectorize, but failed}}
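
These negative tests exercise named 2-D/3-D convolutions that the precondition now rejects outright. Per the earlier comments, such ops are expected to be decomposed to 1-D convolutions before vectorization. A hedged sketch of how a pass might do that; linalg::populateDecomposeConvolutionPatterns and the greedy pattern driver are assumed to be available upstream and are not touched by this PR:

// Assumed-available upstream helpers; not part of this patch.
RewritePatternSet patterns(funcOp.getContext());
linalg::populateDecomposeConvolutionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  return signalPassFailure();
// After decomposition, the resulting conv_1d_* ops pass the new precondition.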
>From 2ef5555f82feb9f8dbaa1cac04cfc25013c0d16f Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Mon, 10 Mar 2025 21:51:35 +0000
Subject: [PATCH 5/7] Peel away conv1dgen validator into precondition check
---
.../Linalg/Transforms/Vectorization.cpp | 226 ++++++++++--------
1 file changed, 121 insertions(+), 105 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6e5907d72e97e..5c830693e8349 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1939,19 +1939,124 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
return success();
}
+namespace {
+bool isCastOfBlockArgument(Operation *op) {
+ return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
+ isa<BlockArgument>(op->getOperand(0));
+}
+
+// Returns true iff it is a valid conv/pooling op.
+// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
+// + yield) and rhs is not used) then it is the body of a pooling
+// If conv, check for single `mul` predecessor. The `mul` operands must be
+// block arguments or extension of block arguments.
+// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
+// must be block arguments or extension of block arguments.
+enum OperKind { Conv, Pool };
+bool getOperKind(Operation *reduceOp, OperKind &oper) {
+ int numBlockArguments =
+ llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
+
+ switch (numBlockArguments) {
+ case 1: {
+ // Will be convolution if feeder is a MulOp.
+ // A strength reduced version of MulOp for i1 type is AndOp which is also
+ // supported. Otherwise, it can be pooling. This strength reduction logic
+ // is in `buildBinaryFn` helper in the Linalg dialect.
+ auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
+ llvm::IsaPred<BlockArgument>);
+ Operation *feedOp = (*feedValIt).getDefiningOp();
+ // llvm::outs() << "feedOp: " << *feedOp << "\n";
+ if (isCastOfBlockArgument(feedOp)) {
+ oper = Pool;
+ } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+ (isa<arith::AndIOp>(feedOp) &&
+ feedOp->getResultTypes()[0].isInteger(1))) &&
+ llvm::all_of(feedOp->getOperands(), [](Value v) {
+ if (isa<BlockArgument>(v))
+ return true;
+ if (Operation *op = v.getDefiningOp())
+ return isCastOfBlockArgument(op);
+ return false;
+ }))) {
+ return false;
+ }
+ return true;
+ }
+ case 2:
+ // Must be pooling
+ oper = Pool;
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool isSupportedPoolKind(vector::CombiningKind kind) {
+ switch (kind) {
+ case vector::CombiningKind::ADD:
+ case vector::CombiningKind::MAXNUMF:
+ case vector::CombiningKind::MAXIMUMF:
+ case vector::CombiningKind::MAXSI:
+ case vector::CombiningKind::MAXUI:
+ case vector::CombiningKind::MINNUMF:
+ case vector::CombiningKind::MINIMUMF:
+ case vector::CombiningKind::MINSI:
+ case vector::CombiningKind::MINUI:
+ return true;
+ default:
+ return false;
+ }
+}
+} // namespace
+
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
- // We only support 1D convolutions, reject all other cases.
- if (isa<linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcFhwcOp,
- linalg::Conv2DNchwFchwOp>(convOp)) {
- LDBG("2D convolutions are not supported\n");
+ if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
+ return failure();
+
+ auto lhsShaped = convOp.getDpsInputOperand(0)->get();
+ auto rhsShaped = convOp.getDpsInputOperand(1)->get();
+ auto resShaped = convOp.getDpsInitOperand(0)->get();
+ auto lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
+ auto rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
+ auto resShapedType = dyn_cast<ShapedType>(resShaped.getType());
+ if (!lhsShapedType || !rhsShapedType || !resShapedType)
+ return failure();
+ // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
+ // (non-channeled convolution -> LHS and RHS both have single dimensions).
+ if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
+ (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
return failure();
- }
- if (isa<linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNcdhwFcdhwOp>(convOp)) {
- LDBG("3D convolutions are not supported\n");
+ Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
+ if (!reduceOp)
+ return failure();
+
+ OperKind oper = Conv;
+ if (!getOperKind(reduceOp, oper))
+ return failure();
+ auto maybeKind = getCombinerOpKind(reduceOp);
+ // Typically convolution will have a `Add` CombiningKind but for i1 type it
+ // can get strength reduced to `OR` which is also supported. This strength
+ // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
+ if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
+ *maybeKind != vector::CombiningKind::OR) &&
+ (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
return failure();
}
+ auto rhsRank = rhsShapedType.getRank();
+ switch (oper) {
+ case Conv:
+ if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
+ return failure();
+ break;
+ case Pool:
+ if (rhsRank != 1)
+ return failure();
+ break;
+ }
+
return success();
}
@@ -3084,28 +3189,6 @@ static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
}
namespace {
-bool isCastOfBlockArgument(Operation *op) {
- return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
- isa<BlockArgument>(op->getOperand(0));
-}
-
-bool isSupportedPoolKind(vector::CombiningKind kind) {
- switch (kind) {
- case vector::CombiningKind::ADD:
- case vector::CombiningKind::MAXNUMF:
- case vector::CombiningKind::MAXIMUMF:
- case vector::CombiningKind::MAXSI:
- case vector::CombiningKind::MAXUI:
- case vector::CombiningKind::MINNUMF:
- case vector::CombiningKind::MINIMUMF:
- case vector::CombiningKind::MINSI:
- case vector::CombiningKind::MINUI:
- return true;
- default:
- return false;
- }
-}
-
/// Generate a vector implementation for either:
/// ```
/// Op def: ( w, kw )
@@ -3144,53 +3227,22 @@ struct Conv1DGenerator
: public StructuredGenerator<LinalgOp, utils::IteratorType> {
Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
: StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
- // Determine whether `linalgOp` can be generated with this generator
- if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
- return;
+
lhsShaped = linalgOp.getDpsInputOperand(0)->get();
rhsShaped = linalgOp.getDpsInputOperand(1)->get();
resShaped = linalgOp.getDpsInitOperand(0)->get();
lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
resShapedType = dyn_cast<ShapedType>(resShaped.getType());
- if (!lhsShapedType || !rhsShapedType || !resShapedType)
- return;
- // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
- // (non-channeled convolution -> LHS and RHS both have single dimensions).
- if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
- (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
- return;
Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
- if (!reduceOp)
- return;
redOp = reduceOp->getName().getIdentifier();
- if (!setOperKind(reduceOp))
- return;
+ setOperKind(reduceOp);
+
auto maybeKind = getCombinerOpKind(reduceOp);
- // Typically convolution will have a `Add` CombiningKind but for i1 type it
- // can get strength reduced to `OR` which is also supported. This strength
- // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
- if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
- *maybeKind != vector::CombiningKind::OR) &&
- (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
- return;
- }
reductionKind = maybeKind.value();
- auto rhsRank = rhsShapedType.getRank();
- switch (oper) {
- case Conv:
- if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
- return;
- break;
- case Pool:
- if (rhsRank != 1)
- return;
- break;
- }
-
// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can
// simply use them if present, otherwise use the default and let the generic
@@ -3199,13 +3251,8 @@ struct Conv1DGenerator
auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
-
- // The op is now known to be valid.
- valid = true;
}
- bool isValid() { return valid; }
-
/// Generate a vector implementation for:
/// ```
/// Op def: ( w, kw )
@@ -3225,9 +3272,6 @@ struct Conv1DGenerator
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
- if (!valid)
- return rewriter.notifyMatchFailure(op, "unvectorizable 1-D conv/pool");
-
int64_t nSize, wSize, cSize, kwSize, fSize;
SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
@@ -3510,9 +3554,6 @@ struct Conv1DGenerator
FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
bool channelDimScalableFlag,
bool flatten) {
- if (!valid)
- return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
-
bool scalableChDim = false;
bool useMasking = false;
int64_t nSize, wSize, cSize, kwSize;
@@ -3857,8 +3898,6 @@ struct Conv1DGenerator
}
private:
- enum OperKind { Conv, Pool };
- bool valid = false;
OperKind oper = Conv;
StringAttr redOp;
StringAttr poolExtOp;
@@ -3869,18 +3908,10 @@ struct Conv1DGenerator
vector::CombiningKind reductionKind;
// Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
- // Returns true iff it is a valid conv/pooling op.
- // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
- // + yield) and rhs is not used) then it is the body of a pooling
- // If conv, check for single `mul` predecessor. The `mul` operands must be
- // block arguments or extension of block arguments.
- // Otherwise, check for one or zero `ext` predecessor. The `ext` operands
- // must be block arguments or extension of block arguments.
- bool setOperKind(Operation *reduceOp) {
+ void setOperKind(Operation *reduceOp) {
int numBlockArguments =
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
- switch (numBlockArguments) {
- case 1: {
+ if (numBlockArguments == 1) {
// Will be convolution if feeder is a MulOp.
// A strength reduced version of MulOp for i1 type is AndOp which is also
// supported. Otherwise, it can be pooling. This strength reduction logic
@@ -3892,27 +3923,13 @@ struct Conv1DGenerator
oper = Pool;
isPoolExt = true;
poolExtOp = feedOp->getName().getIdentifier();
- } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
- (isa<arith::AndIOp>(feedOp) &&
- feedOp->getResultTypes()[0].isInteger(1))) &&
- llvm::all_of(feedOp->getOperands(), [](Value v) {
- if (isa<BlockArgument>(v))
- return true;
- if (Operation *op = v.getDefiningOp())
- return isCastOfBlockArgument(op);
- return false;
- }))) {
- return false;
+ } else {
+ oper = Conv;
}
- return true;
- }
- case 2:
- // Must be pooling
+ } else {
+ // Pooling.
oper = Pool;
isPoolExt = false;
- return true;
- default:
- return false;
}
}
};
@@ -3924,7 +3941,6 @@ static FailureOr<Operation *> vectorizeConvolution(
RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
Conv1DGenerator conv1dGen(rewriter, op);
- assert(conv1dGen.isValid() && "Conv1DGenerator failed");
auto res = conv1dGen.generateNonChanneledConv();
if (succeeded(res))
return res;
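
After this patch the constructor no longer early-returns; every structural requirement lives in vectorizeConvOpPrecondition, and vectorizeConvolution can assume a well-formed op. One of the hoisted rules is the kernel-rank constraint: convolutions accept an RHS of rank 1 (non-channeled), 2 (depthwise) or 3 (channeled), while pooling only accepts rank 1. A standalone restatement under a hypothetical helper name:

static bool isSupportedKernelRank(OperKind oper, int64_t rhsRank) {
  if (oper == Pool)
    return rhsRank == 1; // pooling window is just {kw}
  return rhsRank == 1 || rhsRank == 2 || rhsRank == 3;
}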
>From 13f5183d6ca64ba6a09985f6eaeb13881c237bf4 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 11 Mar 2025 14:35:49 +0000
Subject: [PATCH 6/7] Addressing review feedback
---
.../Linalg/Transforms/Vectorization.cpp | 112 +++++++++---------
1 file changed, 57 insertions(+), 55 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5c830693e8349..df39e6daf271d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1940,7 +1940,10 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
}
namespace {
-bool isCastOfBlockArgument(Operation *op) {
+enum class ConvOperationKind { Conv, Pool };
+} // namespace
+
+static bool isCastOfBlockArgument(Operation *op) {
return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
isa<BlockArgument>(op->getOperand(0));
}
@@ -1952,8 +1955,8 @@ bool isCastOfBlockArgument(Operation *op) {
// block arguments or extension of block arguments.
// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
// must be block arguments or extension of block arguments.
-enum OperKind { Conv, Pool };
-bool getOperKind(Operation *reduceOp, OperKind &oper) {
+static std::optional<ConvOperationKind>
+getConvOperationKind(Operation *reduceOp) {
int numBlockArguments =
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
@@ -1966,33 +1969,34 @@ bool getOperKind(Operation *reduceOp, OperKind &oper) {
auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
llvm::IsaPred<BlockArgument>);
Operation *feedOp = (*feedValIt).getDefiningOp();
- // llvm::outs() << "feedOp: " << *feedOp << "\n";
if (isCastOfBlockArgument(feedOp)) {
- oper = Pool;
- } else if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
- (isa<arith::AndIOp>(feedOp) &&
- feedOp->getResultTypes()[0].isInteger(1))) &&
- llvm::all_of(feedOp->getOperands(), [](Value v) {
- if (isa<BlockArgument>(v))
- return true;
- if (Operation *op = v.getDefiningOp())
- return isCastOfBlockArgument(op);
- return false;
- }))) {
- return false;
+ return ConvOperationKind::Pool;
}
- return true;
+
+ if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
+ (isa<arith::AndIOp>(feedOp) &&
+ feedOp->getResultTypes()[0].isInteger(1))) &&
+ llvm::all_of(feedOp->getOperands(), [](Value v) {
+ if (isa<BlockArgument>(v))
+ return true;
+ if (Operation *op = v.getDefiningOp())
+ return isCastOfBlockArgument(op);
+ return false;
+ }))) {
+ return std::nullopt;
+ }
+
+ return ConvOperationKind::Conv;
}
case 2:
// Must be pooling
- oper = Pool;
- return true;
+ return ConvOperationKind::Pool;
default:
- return false;
+ return std::nullopt;
}
}
-bool isSupportedPoolKind(vector::CombiningKind kind) {
+static bool isSupportedPoolKind(vector::CombiningKind kind) {
switch (kind) {
case vector::CombiningKind::ADD:
case vector::CombiningKind::MAXNUMF:
@@ -2008,7 +2012,6 @@ bool isSupportedPoolKind(vector::CombiningKind kind) {
return false;
}
}
-} // namespace
static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
if (convOp.getNumDpsInputs() != 2 || convOp.getNumDpsInits() != 1)
@@ -2032,29 +2035,28 @@ static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
if (!reduceOp)
return failure();
- OperKind oper = Conv;
- if (!getOperKind(reduceOp, oper))
+ auto maybeOper = getConvOperationKind(reduceOp);
+ if (!maybeOper.has_value())
return failure();
+
auto maybeKind = getCombinerOpKind(reduceOp);
// Typically convolution will have a `Add` CombiningKind but for i1 type it
// can get strength reduced to `OR` which is also supported. This strength
// reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
*maybeKind != vector::CombiningKind::OR) &&
- (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
+ (*maybeOper != ConvOperationKind::Pool ||
+ !isSupportedPoolKind(*maybeKind)))) {
return failure();
}
auto rhsRank = rhsShapedType.getRank();
- switch (oper) {
- case Conv:
- if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
- return failure();
- break;
- case Pool:
+ if (*maybeOper == ConvOperationKind::Pool) {
if (rhsRank != 1)
return failure();
- break;
+ } else {
+ if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
+ return failure();
}
return success();
@@ -3238,7 +3240,7 @@ struct Conv1DGenerator
Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
redOp = reduceOp->getName().getIdentifier();
- setOperKind(reduceOp);
+ setConvOperationKind(reduceOp);
auto maybeKind = getCombinerOpKind(reduceOp);
reductionKind = maybeKind.value();
@@ -3293,11 +3295,11 @@ struct Conv1DGenerator
// out{n, w, f}
bindShapeDims(resShapedType, nSize, wSize, fSize);
switch (oper) {
- case Conv:
+ case ConvOperationKind::Conv:
// kernel{kw, c, f}
bindShapeDims(rhsShapedType, kwSize, cSize);
break;
- case Pool:
+ case ConvOperationKind::Pool:
// kernel{kw}
bindShapeDims(rhsShapedType, kwSize);
cSize = fSize;
@@ -3311,10 +3313,10 @@ struct Conv1DGenerator
1,
cSize};
switch (oper) {
- case Conv:
+ case ConvOperationKind::Conv:
rhsShape = {kwSize, cSize, fSize};
break;
- case Pool:
+ case ConvOperationKind::Pool:
rhsShape = {kwSize};
break;
}
@@ -3324,11 +3326,11 @@ struct Conv1DGenerator
// out{n, f, w}
bindShapeDims(resShapedType, nSize, fSize, wSize);
switch (oper) {
- case Conv:
+ case ConvOperationKind::Conv:
// kernel{f, c, kw}
bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
break;
- case Pool:
+ case ConvOperationKind::Pool:
// kernel{kw}
bindShapeDims(rhsShapedType, kwSize);
cSize = fSize;
@@ -3341,10 +3343,10 @@ struct Conv1DGenerator
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
1};
switch (oper) {
- case Conv:
+ case ConvOperationKind::Conv:
rhsShape = {fSize, cSize, kwSize};
break;
- case Pool:
+ case ConvOperationKind::Pool:
rhsShape = {kwSize};
break;
}
@@ -3376,7 +3378,7 @@ struct Conv1DGenerator
lhsPadding);
// This is needed only for Conv.
Value rhs = nullptr;
- if (oper == Conv)
+ if (oper == ConvOperationKind::Conv)
rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
rhsPadding);
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
@@ -3399,7 +3401,7 @@ struct Conv1DGenerator
static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
// This is needed only for Conv.
- if (oper == Conv)
+ if (oper == ConvOperationKind::Conv)
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs);
// nfw -> nwf
static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
@@ -3417,7 +3419,7 @@ struct Conv1DGenerator
kwSize, strideW, dilationW, wSizeStep,
isSingleChanneled);
// Do not do for pooling.
- if (oper == Conv)
+ if (oper == ConvOperationKind::Conv)
rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
wSizeStep, isSingleChanneled);
@@ -3432,7 +3434,7 @@ struct Conv1DGenerator
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
switch (oper) {
- case Conv:
+ case ConvOperationKind::Conv:
if (isSingleChanneled) {
resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
lhsVals[linearIndex(kw, w)],
@@ -3443,7 +3445,7 @@ struct Conv1DGenerator
rhsVals[kw], resVals[w]);
}
break;
- case Pool:
+ case ConvOperationKind::Pool:
resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
resVals[w]);
break;
@@ -3898,7 +3900,7 @@ struct Conv1DGenerator
}
private:
- OperKind oper = Conv;
+ ConvOperationKind oper = ConvOperationKind::Conv;
StringAttr redOp;
StringAttr poolExtOp;
bool isPoolExt = false;
@@ -3908,7 +3910,7 @@ struct Conv1DGenerator
vector::CombiningKind reductionKind;
// Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
- void setOperKind(Operation *reduceOp) {
+ void setConvOperationKind(Operation *reduceOp) {
int numBlockArguments =
llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
if (numBlockArguments == 1) {
@@ -3920,17 +3922,17 @@ struct Conv1DGenerator
llvm::IsaPred<BlockArgument>);
Operation *feedOp = (*feedValIt).getDefiningOp();
if (isCastOfBlockArgument(feedOp)) {
- oper = Pool;
+ oper = ConvOperationKind::Pool;
isPoolExt = true;
poolExtOp = feedOp->getName().getIdentifier();
- } else {
- oper = Conv;
+ return;
}
- } else {
- // Pooling.
- oper = Pool;
- isPoolExt = false;
+ oper = ConvOperationKind::Conv;
+ return;
}
+  // numBlockArguments == 2 and this is a pooling op.
+ oper = ConvOperationKind::Pool;
+ isPoolExt = false;
}
};
} // namespace
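
With ConvOperationKind returned as a std::optional, the precondition and Conv1DGenerator now share one classification helper instead of a boolean result plus out-parameter. A typical call site, lightly paraphrasing the updated precondition:

auto maybeOper = getConvOperationKind(reduceOp);
if (!maybeOper.has_value())
  return failure();
const bool isPooling = *maybeOper == ConvOperationKind::Pool;
// isPooling then gates both the combining-kind and kernel-rank checks.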
>From 2b5d1dc903ec25affb72e02817b1fc237df0b426 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Tue, 11 Mar 2025 17:18:59 +0000
Subject: [PATCH 7/7] Amending comments of getConvOperationKind()
---
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index df39e6daf271d..94a55355d51e6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1948,7 +1948,10 @@ static bool isCastOfBlockArgument(Operation *op) {
isa<BlockArgument>(op->getOperand(0));
}
-// Returns true iff it is a valid conv/pooling op.
+// Returns the ConvOperationKind of the op using the reduceOp of the generic
+// payload. If it is neither a convolution nor a pooling op, it returns
+// std::nullopt.
+//
// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
// + yield) and rhs is not used) then it is the body of a pooling
// If conv, check for single `mul` predecessor. The `mul` operands must be