[Mlir-commits] [mlir] [Linalg] Update Conv Decomposition patterns to work with generic convolution ops as well (PR #174196)
Abhishek Varma
llvmlistbot at llvm.org
Tue Jan 6 01:16:21 PST 2026
https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/174196
>From 8ca77c9228105aa10039464c17e11403882660e0 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 1 Jan 2026 07:36:31 +0000
Subject: [PATCH 1/2] [Linalg] Update Conv Decomposition to work with generic
conv ops
-- This commit updates Conv Decomposition to work with both named and
generic convolution ops.
-- This required updating the `isaConvolutionOpOfType` API to also
populate dilations/strides for named convolution ops. Because the
decomposition patterns are now rooted on the LinalgOp interface (so
they also see generic ops), the assert that the op implements
ConvolutionOpInterface has been replaced with an early `return false`.
-- The unsigned max/min pooling body matchers now accept only
`arith.maxui`/`arith.minui`; the float `arith.maximumf`/`arith.minimumf`
combiners are no longer matched.
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
.../Dialect/Linalg/Transforms/Transforms.h | 27 +-
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 8 +-
.../Dialect/Linalg/Transforms/Transforms.cpp | 107 +--
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 686 ++++++++++++++----
.../Linalg/transform-op-decompose.mlir | 238 ++++--
5 files changed, 770 insertions(+), 296 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6678d693719bf..32067358438d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1641,16 +1641,16 @@ FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
//===----------------------------------------------------------------------===//
/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
-/// convolution ops.
+/// convolution ops. Matches the named `Conv2DOp` or an equivalent generic op.
template <typename Conv2DOp, typename Conv1DOp>
struct DownscaleSizeOneWindowed2DConvolution final
- : public OpRewritePattern<Conv2DOp> {
- using OpRewritePattern<Conv2DOp>::OpRewritePattern;
+ : public OpInterfaceRewritePattern<LinalgOp> {
+ using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
- FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+ FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const;
- LogicalResult matchAndRewrite(Conv2DOp convOp,
+ LogicalResult matchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(convOp, rewriter);
}
@@ -1664,29 +1664,30 @@ extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
-/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
-/// dimensions into 1-D depthwise convolution ops.
+/// Rewrites 2-D depthwise convolution ops (named or equivalent generic) with
+/// size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
- : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+ : public OpInterfaceRewritePattern<LinalgOp> {
DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
PatternBenefit benefit = 1)
- : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit) {}
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
FailureOr<DepthwiseConv1DNwcWcOp>
- returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
- PatternRewriter &rewriter) const;
+ returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const;
- LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+ LogicalResult matchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(convOp, rewriter);
}
};
+/// Downscales 2-D convolution ops (`Conv2DOp`, named or equivalent generic)
+/// to 1-D `Conv1DOp` ops.
-struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> {
+struct DownscaleConv2DOp final : public OpInterfaceRewritePattern<LinalgOp> {
DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
- : OpRewritePattern<Conv2DOp>(context, benefit) {}
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
- FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+ FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const;
- LogicalResult matchAndRewrite(Conv2DOp convOp,
+ LogicalResult matchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(convOp, rewriter);
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 9da01f30b52d2..16d557a6ed7fa 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -108,10 +108,12 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
/// Given a linalg `op` this function returns true if it is a convolution op of
/// type `ConvOpTy` and populates `dilations` and `strides` with values inferred
-/// from the indexing maps.
+/// from the indexing maps (named ops use their attributes, or all ones if
+/// they have none). Either pointer may be nullptr to skip that output.
template <typename ConvOpTy>
-bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides);
+bool isaConvolutionOpOfType(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 96cc378f6c21a..7972408318b95 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -32,6 +32,7 @@
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"
+#include <type_traits>
#include <utility>
#define DEBUG_TYPE "linalg-transforms"
@@ -1406,13 +1407,18 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
- returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
+ returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
+ // Match a named/generic Conv2DOp; its dilations/strides are reused below.
+ SmallVector<int64_t> dilations, strides;
+ if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+ return failure();
+
if (convOp.hasPureBufferSemantics())
return failure(); // To be implemented.
- Value input = convOp.getInputs().front();
- Value kernel = convOp.getInputs().back();
- Value output = convOp.getOutputs().front();
+ Value input = convOp.getDpsInputs().front();
+ Value kernel = convOp.getDpsInputs().back();
+ Value output = convOp.getDpsInits().front();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
@@ -1421,38 +1427,38 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
auto kernelShape = kernelType.getShape();
auto outputShape = outputType.getShape();
- // Get domain indices based on conv2D layout.
- auto [khIndex, kwIndex, ohIndex, owIndex] =
- TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
- convOp)
- .Case([&](linalg::Conv2DNhwcHwcfOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::Conv2DNchwFchwOp op) {
- return std::make_tuple(2, 3, 2, 3);
- })
- .Case([&](linalg::PoolingNhwcSumOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::PoolingNchwSumOp op) {
- return std::make_tuple(0, 1, 2, 3);
- })
- .Case([&](linalg::PoolingNhwcMaxOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::PoolingNhwcMinOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
- return std::make_tuple(0, 1, 1, 2);
- })
- .Case([&](linalg::PoolingNchwMaxOp op) {
- return std::make_tuple(0, 1, 2, 3);
- })
- .DefaultUnreachable("unexpected conv2d/pool2d operation.");
+ // Get domain indices based on Conv2DOp type. These are known at compile time.
+ int64_t khIndex, kwIndex, ohIndex, owIndex;
+ if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
+ // NHWC layout: kernel [H, W, ...], output [N, H, W, C]
+ khIndex = 0;
+ kwIndex = 1;
+ ohIndex = 1;
+ owIndex = 2;
+ } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
+ // NCHW_FCHW layout: kernel [..., H, W], output [N, C, H, W]
+ khIndex = 2;
+ kwIndex = 3;
+ ohIndex = 2;
+ owIndex = 3;
+ } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
+ std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
+ // NCHW pooling layout: kernel [H, W], output [N, C, H, W]
+ khIndex = 0;
+ kwIndex = 1;
+ ohIndex = 2;
+ owIndex = 3;
+ } else {
+ // Any other instantiation is unsupported: fail the build rather than
+ // leave the four indices above uninitialized. (Dependent-false assert.)
+ static_assert(!std::is_same_v<Conv2DOp, Conv2DOp>,
+ "unexpected conv2d/pool2d operation.");
+ }
// Only handle the case where at least one of the window dimensions is
// of size 1. Other cases can rely on tiling to reduce to such cases.
@@ -1484,13 +1485,12 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
// Rank-reduce strides and dilations too.
// TODO: dropDim 1-liner helper.
- auto strides =
- llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
+ // Filled by isaConvolutionOpOfType; erase() below needs both spatial dims.
+ assert(strides.size() == 2 && dilations.size() == 2 &&
+ "expected one stride and one dilation per spatial dimension");
strides.erase(strides.begin() + (removeH ? 0 : 1));
auto stridesAttr = rewriter.getI64VectorAttr(strides);
- auto dilations =
- llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
dilations.erase(dilations.begin() + (removeH ? 0 : 1));
auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
@@ -1527,13 +1524,19 @@ template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
- DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
+ LinalgOp convOp, PatternRewriter &rewriter) const {
+ // Match a DepthwiseConv2DNhwcHwcOp (named or generic); get dilations/strides.
+ SmallVector<int64_t> dilations, strides;
+ if (!linalg::isaConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(
+ convOp, &dilations, &strides))
+ return failure();
+
if (convOp.hasPureBufferSemantics())
return failure(); // To be implemented.
- Value input = convOp.getInputs().front();
- Value kernel = convOp.getInputs().back();
- Value output = convOp.getOutputs().front();
+ Value input = convOp.getDpsInputs().front();
+ Value kernel = convOp.getDpsInputs().back();
+ Value output = convOp.getDpsInits().front();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
@@ -1572,12 +1575,12 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
// Rank-reduce strides and dilations too.
// TODO: dropDim 1-liner helper.
- auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
+ // Filled by isaConvolutionOpOfType; erase() below needs both spatial dims.
+ assert(strides.size() == 2 && dilations.size() == 2 &&
+ "expected one stride and one dilation per spatial dimension");
strides.erase(strides.begin() + (removeH ? 0 : 1));
auto stridesAttr = rewriter.getI64VectorAttr(strides);
- auto dilations =
- llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
dilations.erase(dilations.begin() + (removeH ? 0 : 1));
auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
@@ -1594,14 +1594,25 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
}
FailureOr<Conv1DOp>
-DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
+DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const {
+ // Check if this LinalgOp is a Conv2DOp (named or generic).
+ SmallVector<int64_t> dilations, strides;
+ if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+ return failure();
+
+ // Conv1DOp has no strides/dilations attributes; reject non-unit values
+ // (a generic op's indexing maps may imply them) instead of dropping them.
+ auto isUnit = [](int64_t v) { return v == 1; };
+ if (!llvm::all_of(strides, isUnit) || !llvm::all_of(dilations, isUnit))
+ return failure();
+
if (convOp.hasPureBufferSemantics())
return failure(); // To be implemented.
- Value input = convOp.getInputs().front();
- Value kernel = convOp.getInputs().back();
- Value output = convOp.getOutputs().front();
+ Value input = convOp.getDpsInputs().front();
+ Value kernel = convOp.getDpsInputs().back();
+ Value output = convOp.getDpsInits().front();
auto inputType = dyn_cast<RankedTensorType>(input.getType());
auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2718124251c18..1cdd01567c4e7 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -373,11 +373,9 @@ static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
body);
}
-// max_unsigned ops should not allow float data type.
-// TODO(#164800): Retire OPDSL logic.
+// Unsigned max pooling is integer-only: match arith.maxui, not arith.maximumf.
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
- body);
+ return bodyMatcherForPoolOps<arith::MaxUIOp>(yieldVal, body);
}
static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
@@ -385,11 +382,9 @@ static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
body);
}
-// min_unsigned ops should not allow float data type.
-// TODO(#164800): Retire OPDSL logic.
+// Unsigned min pooling is integer-only: match arith.minui, not arith.minimumf.
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
- body);
+ return bodyMatcherForPoolOps<arith::MinUIOp>(yieldVal, body);
}
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
@@ -601,11 +595,20 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv1DOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (isa<linalg::Conv1DOp>(op)) {
+    // Conv1DOp has no strides/dilations attributes; they default to 1.
+    *dilations = SmallVector<int64_t>(1, 1);
+    *strides = SmallVector<int64_t>(1, 1);
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
AffineExpr W = m.dim(0);
@@ -622,11 +625,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv1DNwcWcfOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv1DNwcWcfOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
AffineExpr N = m.dim(0);
@@ -646,11 +657,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv1DNcwFcwOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv1DNcwFcwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
AffineExpr N = m.dim(0);
@@ -670,11 +689,20 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (isa<linalg::Conv2DOp>(op)) {
+    // Conv2DOp has no strides/dilations attributes; they default to 1.
+    *dilations = SmallVector<int64_t>(2, 1);
+    *strides = SmallVector<int64_t>(2, 1);
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr H = m.dim(0);
@@ -694,11 +722,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNhwcHwcfOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -721,11 +757,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -750,11 +794,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNhwcFhwcOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -777,11 +829,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -806,11 +866,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNchwFchwOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -833,11 +901,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNchwFchwQOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -862,11 +938,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNgchwFgchwOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv2DNgchwFgchwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -891,11 +975,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNgchwGfchwOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -920,11 +1012,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -951,11 +1051,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -980,11 +1088,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1011,11 +1127,20 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv3DOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (isa<linalg::Conv3DOp>(op)) {
+    // Conv3DOp has no strides/dilations attributes; they default to 1.
+    *dilations = SmallVector<int64_t>(3, 1);
+    *strides = SmallVector<int64_t>(3, 1);
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
AffineExpr D = m.dim(0);
@@ -1039,11 +1164,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv3DNdhwcDhwcfOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1070,11 +1203,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfQOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1103,11 +1244,19 @@ template <>
bool isaConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::Conv3DNcdhwFcdhwOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp = dyn_cast<linalg::Conv3DNcdhwFcdhwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1134,11 +1283,20 @@ template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv1DNcwCwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1157,11 +1315,20 @@ template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1180,11 +1347,20 @@ template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv1DNwcWcmOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1204,11 +1380,20 @@ template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv2DNchwChwOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1230,11 +1415,20 @@ template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op))
+  SmallVector<int64_t> localDilations, localStrides; // Scratch for nullptr.
+  if (!dilations)
+    dilations = &localDilations;
+  if (!strides)
+    strides = &localStrides;
+  if (auto convOp =
+          dyn_cast<linalg::DepthwiseConv2DNhwcHwcOp>(op.getOperation())) {
+    *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+    *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+  }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+  if (!isaConvolutionOpInterface(op)) // Not a convolution: no match.
+    return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1256,11 +1450,20 @@ template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto convOp =
+ dyn_cast<linalg::DepthwiseConv2DNhwcHwcQOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1284,11 +1487,20 @@ template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto convOp =
+ dyn_cast<linalg::DepthwiseConv2DNhwcHwcmOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1311,11 +1523,20 @@ template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto convOp =
+ dyn_cast<linalg::DepthwiseConv2DNhwcHwcmQOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1340,11 +1561,20 @@ template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto convOp =
+ dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1370,11 +1600,20 @@ template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto convOp =
+ dyn_cast<linalg::DepthwiseConv3DNcdhwCdhwOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1400,11 +1639,20 @@ template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto convOp =
+ dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
AffineExpr N = m.dim(0);
@@ -1431,11 +1679,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNhwcMaxOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNhwcMaxOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::MaxSigned);
@@ -1458,11 +1714,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNhwcMinOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNhwcMinOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::MinSigned);
@@ -1485,11 +1749,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNhwcSumOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNhwcSumOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::Sum);
@@ -1512,11 +1784,20 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp =
+ dyn_cast<linalg::PoolingNhwcMaxUnsignedOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::MaxUnsigned);
@@ -1539,11 +1820,20 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp =
+ dyn_cast<linalg::PoolingNhwcMinUnsignedOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::MinUnsigned);
@@ -1566,11 +1856,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNchwSumOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNchwSumOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::Sum);
@@ -1593,11 +1891,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNchwMaxOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNchwMaxOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
PoolingType::MaxSigned);
@@ -1620,11 +1926,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNwcSumOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNwcSumOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNwcSumOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
PoolingType::Sum);
@@ -1644,11 +1958,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNcwSumOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNcwSumOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNcwSumOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
PoolingType::Sum);
@@ -1668,11 +1990,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNwcMaxOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNwcMaxOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNwcMaxOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
PoolingType::MaxSigned);
@@ -1692,11 +2022,20 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNwcMaxUnsignedOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNwcMaxUnsignedOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp =
+ dyn_cast<linalg::PoolingNwcMaxUnsignedOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
PoolingType::MaxUnsigned);
@@ -1716,11 +2055,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNcwMaxOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNcwMaxOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNcwMaxOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
PoolingType::MaxSigned);
@@ -1740,11 +2087,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNwcMinOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNwcMinOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNwcMinOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
PoolingType::MinSigned);
@@ -1764,11 +2119,20 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNwcMinUnsignedOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNwcMinUnsignedOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp =
+ dyn_cast<linalg::PoolingNwcMinUnsignedOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
PoolingType::MinUnsigned);
@@ -1788,11 +2152,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNdhwcSumOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNdhwcSumOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
PoolingType::Sum);
@@ -1819,11 +2191,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNdhwcMaxOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMaxOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
PoolingType::MaxSigned);
@@ -1850,11 +2230,19 @@ template <>
bool isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> *strides) {
- if (isa<linalg::PoolingNdhwcMinOp>(op))
+ SmallVector<int64_t> localDilations, localStrides;
+ if (!dilations)
+ dilations = &localDilations;
+ if (!strides)
+ strides = &localStrides;
+ if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMinOp>(op.getOperation())) {
+ *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
return true;
+ }
- assert(isaConvolutionOpInterface(op) &&
- "expected op to implement ConvolutionOpInterface");
+ if (!isaConvolutionOpInterface(op))
+ return false;
ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
PoolingType::MinSigned);
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 60a4c555fa19a..7798cb76e4fb9 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -3,113 +3,168 @@
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map_nhwc_hwcf_input = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+#map_nhwc_hwcf_filter = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+#map_nhwc_hwcf_output = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+
// CHECK-LABEL: @conv_2d_nhwc_hwcf
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
- // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
- // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
- // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf
- // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ // CHECK: tensor.extract_slice %[[ARG0]]
+ // CHECK: tensor.extract_slice %[[ARG1]]
+ // CHECK: tensor.extract_slice %[[ARG2]]
+ // Both named and generic ops should decompose to conv_1d_nwc_wcf
+ // CHECK-COUNT-2: linalg.conv_1d_nwc_wcf
%0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
- // CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ // Generic op version with same semantics.
+ %1 = linalg.generic {indexing_maps = [#map_nhwc_hwcf_input, #map_nhwc_hwcf_filter, #map_nhwc_hwcf_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.mulf %in, %in_0 : f32
+ %3 = arith.addf %out, %2 : f32
+ linalg.yield %3 : f32
+ } -> tensor<?x1x?x?xf32>
+ return %1 : tensor<?x1x?x?xf32>
}
+#map_nchw_fchw_input = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+#map_nchw_fchw_filter = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+#map_nchw_fchw_output = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+
// CHECK-LABEL: @conv_2d_nchw_fchw
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
- // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
- // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
- // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw
- // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ // CHECK: tensor.extract_slice %[[ARG0]]
+ // CHECK: tensor.extract_slice %[[ARG1]]
+ // CHECK: tensor.extract_slice %[[ARG2]]
+ // Both named and generic ops should decompose to conv_1d_ncw_fcw
+ // CHECK-COUNT-2: linalg.conv_1d_ncw_fcw
%0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
- // CHECK: return %[[RES]]
- return %0 : tensor<?x?x1x?xf32>
+ // Generic op version with same semantics.
+ %1 = linalg.generic {indexing_maps = [#map_nchw_fchw_input, #map_nchw_fchw_filter, #map_nchw_fchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.mulf %in, %in_0 : f32
+ %3 = arith.addf %out, %2 : f32
+ linalg.yield %3 : f32
+ } -> tensor<?x?x1x?xf32>
+ return %1 : tensor<?x?x1x?xf32>
}
+#map_depthwise_nhwc_hwc_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)>
+#map_depthwise_nhwc_hwc_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+#map_depthwise_nhwc_hwc_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
// CHECK: %[[RES:.+]] = tensor.empty
%init = tensor.empty() : tensor<1x1x56x96xf32>
- // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
- // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
- // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
- // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
- // CHECK-SAME: outs(%[[SLICERES]]
- // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
+ // CHECK: tensor.extract_slice %[[ARG0]]
+ // CHECK: tensor.extract_slice %[[ARG1]]
+ // Both named and generic ops should decompose to depthwise_conv_1d_nwc_wc
+ // CHECK-COUNT-2: linalg.depthwise_conv_1d_nwc_wc
%0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
outs(%init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
- // CHECK: %[[INSERTED]]
- return %0: tensor<1x1x56x96xf32>
+ // Generic op version with same semantics (strides = 2).
+ %1 = linalg.generic {indexing_maps = [#map_depthwise_nhwc_hwc_input, #map_depthwise_nhwc_hwc_filter, #map_depthwise_nhwc_hwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<1x1x113x96xf32>, tensor<1x3x96xf32>) outs(%0 : tensor<1x1x56x96xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.mulf %in, %in_0 : f32
+ %3 = arith.addf %out, %2 : f32
+ linalg.yield %3 : f32
+ } -> tensor<1x1x56x96xf32>
+ return %1: tensor<1x1x56x96xf32>
}
+#map_conv_2d_input = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
+#map_conv_2d_filter = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map_conv_2d_output = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
// CHECK-LABEL: @conv_2d
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<1x?xf32>)
func.func @conv_2d(%input: tensor<1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<1x?xf32>) -> tensor<1x?xf32> {
- // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
- // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
- // CHECK: %[[SLICERES:.+]] = linalg.conv_1d
- // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ // CHECK: tensor.extract_slice %[[ARG0]]
+ // CHECK: tensor.extract_slice %[[ARG1]]
+ // CHECK: tensor.extract_slice %[[ARG2]]
+ // Both named and generic ops should decompose to conv_1d
+ // CHECK-COUNT-2: linalg.conv_1d
%0 = linalg.conv_2d
ins (%input, %filter: tensor<1x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<1x?xf32>) -> tensor<1x?xf32>
- // CHECK: return %[[RES]]
- return %0 : tensor<1x?xf32>
+ // Generic op version with same semantics.
+ %1 = linalg.generic {indexing_maps = [#map_conv_2d_input, #map_conv_2d_filter, #map_conv_2d_output], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<1x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.mulf %in, %in_0 : f32
+ %3 = arith.addf %out, %2 : f32
+ linalg.yield %3 : f32
+ } -> tensor<1x?xf32>
+ return %1 : tensor<1x?xf32>
}
+#map_pooling_nhwc_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+#map_pooling_nhwc_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+#map_pooling_nhwc_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
// CHECK-LABEL: @pooling_nhwc_sum
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_sum(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
- // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
- // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
- // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum
- // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ // CHECK: tensor.extract_slice %[[ARG0]]
+ // CHECK: tensor.extract_slice %[[ARG1]]
+ // CHECK: tensor.extract_slice %[[ARG2]]
+ // Both named and generic ops should decompose to pooling_nwc_sum
+ // CHECK-COUNT-2: linalg.pooling_nwc_sum
%0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
- // CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ // Generic op version with same semantics.
+ %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.addf %out, %in : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x1x?x?xf32>
+ return %1 : tensor<?x1x?x?xf32>
}
+#map_pooling_nchw_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+#map_pooling_nchw_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+#map_pooling_nchw_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
// CHECK-LABEL: @pooling_nchw_sum
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
- // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
- // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
- // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum
- // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ // CHECK: tensor.extract_slice %[[ARG0]]
+ // CHECK: tensor.extract_slice %[[ARG1]]
+ // CHECK: tensor.extract_slice %[[ARG2]]
+ // Both named and generic ops should decompose to pooling_ncw_sum
+ // CHECK-COUNT-2: linalg.pooling_ncw_sum
%0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
- // CHECK: return %[[RES]]
- return %0 : tensor<?x?x1x?xf32>
+ // Generic op version with same semantics.
+ %1 = linalg.generic {indexing_maps = [#map_pooling_nchw_input, #map_pooling_nchw_filter, #map_pooling_nchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.addf %out, %in : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?x1x?xf32>
+ return %1 : tensor<?x?x1x?xf32>
}
// CHECK-LABEL: @pooling_nhwc_max
@@ -117,17 +172,22 @@ func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
- // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
- // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
- // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max
- // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ // CHECK: tensor.extract_slice %[[ARG0]]
+ // CHECK: tensor.extract_slice %[[ARG1]]
+ // CHECK: tensor.extract_slice %[[ARG2]]
+ // Both named and generic ops should decompose to pooling_nwc_max
+ // CHECK-COUNT-2: linalg.pooling_nwc_max
%0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
- // CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ // Generic op version with same semantics.
+ %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.maximumf %out, %in : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x1x?x?xf32>
+ return %1 : tensor<?x1x?x?xf32>
}
// CHECK-LABEL: @pooling_nhwc_max_unsigned
@@ -135,17 +195,22 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
- // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
- // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
- // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned
- // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ // CHECK: tensor.extract_slice %[[ARG0]]
+ // CHECK: tensor.extract_slice %[[ARG1]]
+ // CHECK: tensor.extract_slice %[[ARG2]]
+ // Both named and generic ops should decompose to pooling_nwc_max_unsigned
+ // CHECK-COUNT-2: linalg.pooling_nwc_max_unsigned
%0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
- // CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xi32>
+ // Generic op version with same semantics.
+ %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xi32>, tensor<1x?xi32>) outs(%0 : tensor<?x1x?x?xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %2 = arith.maxui %out, %in : i32
+ linalg.yield %2 : i32
+ } -> tensor<?x1x?x?xi32>
+ return %1 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nhwc_min
@@ -153,17 +218,22 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tenso
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
- // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
- // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
- // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min
- // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ // CHECK: tensor.extract_slice %[[ARG0]]
+ // CHECK: tensor.extract_slice %[[ARG1]]
+ // CHECK: tensor.extract_slice %[[ARG2]]
+ // Both named and generic ops should decompose to pooling_nwc_min
+ // CHECK-COUNT-2: linalg.pooling_nwc_min
%0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
- // CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ // Generic op version with same semantics.
+ %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.minimumf %out, %in : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x1x?x?xf32>
+ return %1 : tensor<?x1x?x?xf32>
}
// CHECK-LABEL: @pooling_nhwc_min_unsigned
@@ -171,17 +241,22 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
- // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
- // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
- // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned
- // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ // CHECK: tensor.extract_slice %[[ARG0]]
+ // CHECK: tensor.extract_slice %[[ARG1]]
+ // CHECK: tensor.extract_slice %[[ARG2]]
+ // Both named and generic ops should decompose to pooling_nwc_min_unsigned
+ // CHECK-COUNT-2: linalg.pooling_nwc_min_unsigned
%0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
- // CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xi32>
+ // Generic op version with same semantics.
+ %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xi32>, tensor<1x?xi32>) outs(%0 : tensor<?x1x?x?xi32>) {
+ ^bb0(%in: i32, %in_0: i32, %out: i32):
+ %2 = arith.minui %out, %in : i32
+ linalg.yield %2 : i32
+ } -> tensor<?x1x?x?xi32>
+ return %1 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nchw_max
@@ -189,17 +264,22 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tenso
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
- // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
- // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
- // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
- // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max
- // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ // CHECK: tensor.extract_slice %[[ARG0]]
+ // CHECK: tensor.extract_slice %[[ARG1]]
+ // CHECK: tensor.extract_slice %[[ARG2]]
+ // Both named and generic ops should decompose to pooling_ncw_max
+ // CHECK-COUNT-2: linalg.pooling_ncw_max
%0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
- // CHECK: return %[[RES]]
- return %0 : tensor<?x?x1x?xf32>
+ // Generic op version with same semantics.
+ %1 = linalg.generic {indexing_maps = [#map_pooling_nchw_input, #map_pooling_nchw_filter, #map_pooling_nchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %2 = arith.maximumf %out, %in : f32
+ linalg.yield %2 : f32
+ } -> tensor<?x?x1x?xf32>
+ return %1 : tensor<?x?x1x?xf32>
}
func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
>From 10c93d8a300d497e5730cee7d1f547917a75df5f Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 6 Jan 2026 08:10:41 +0000
Subject: [PATCH 2/2] Review comment by Hanhan v1.0 : Update API + different
RUN line
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 17 +-
.../Dialect/Linalg/Transforms/Specialize.cpp | 8 +-
.../Dialect/Linalg/Transforms/Transforms.cpp | 20 +-
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 1827 ++++++++---------
.../Linalg/transform-op-decompose.mlir | 239 +--
5 files changed, 996 insertions(+), 1115 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 16d557a6ed7fa..db0cf474d6254 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -106,14 +106,17 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
// Convolution matcher utility
//===----------------------------------------------------------------------===//
-/// Given a linalg `op` this function returns true if it is a convolution op of
-/// type `ConvOpTy` and populates `dilations` and `strides` with values inferred
-/// from the indexing maps. If `dilations` or `strides` is nullptr, the
-/// corresponding values are not populated.
+/// A struct containing dilations and strides inferred from convolution ops.
+struct DilationsAndStrides {
+ SmallVector<int64_t> dilations;
+ SmallVector<int64_t> strides;
+};
+
+/// Given a linalg `op` this function returns DilationsAndStrides if it is a
+/// convolution op of type `ConvOpTy`, otherwise returns std::nullopt. Named
+/// ops read attributes (or default to 1); generic ops use the indexing maps.
 template <typename ConvOpTy>
-bool isaConvolutionOpOfType(LinalgOp op,
-                            SmallVector<int64_t> *dilations = nullptr,
-                            SmallVector<int64_t> *strides = nullptr);
+std::optional<DilationsAndStrides> isaConvolutionOpOfType(LinalgOp op);
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 0c7b998ffcab9..7db401b253abe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -267,11 +267,11 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
/// Converts linalg.generic to named linalg.*conv/pooling* where possible.
static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
GenericOp genericOp) {
- SmallVector<int64_t> dilations, strides;
#define CONV_OP_SPECIALIZER(ConvOpTy) \
- if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides)) \
- return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \
- strides); \
+ if (std::optional<DilationsAndStrides> convParams = \
+ isaConvolutionOpOfType<ConvOpTy>(genericOp)) \
+ return specializeToConvOp<ConvOpTy>( \
+ rewriter, genericOp, convParams->dilations, convParams->strides); \
// -----------------------------
// Convolution ops.
// -----------------------------
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7972408318b95..c1deed891a68c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1409,9 +1409,12 @@ template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
// Check if this LinalgOp is of the expected Conv2DOp type (named or generic).
- SmallVector<int64_t> dilations, strides;
- if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+ std::optional<DilationsAndStrides> convParams =
+ linalg::isaConvolutionOpOfType<Conv2DOp>(convOp);
+ if (!convParams)
return failure();
+ SmallVector<int64_t> dilations = std::move(convParams->dilations);
+ SmallVector<int64_t> strides = std::move(convParams->strides);
if (convOp.hasPureBufferSemantics())
return failure(); // To be implemented.
@@ -1526,10 +1529,12 @@ FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
LinalgOp convOp, PatternRewriter &rewriter) const {
// Check if this LinalgOp is a DepthwiseConv2DNhwcHwcOp (named or generic).
- SmallVector<int64_t> dilations, strides;
- if (!linalg::isaConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(
- convOp, &dilations, &strides))
+ std::optional<DilationsAndStrides> convParams =
+ linalg::isaConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(convOp);
+ if (!convParams)
return failure();
+ SmallVector<int64_t> dilations = std::move(convParams->dilations);
+ SmallVector<int64_t> strides = std::move(convParams->strides);
if (convOp.hasPureBufferSemantics())
return failure(); // To be implemented.
@@ -1597,8 +1602,9 @@ FailureOr<Conv1DOp>
DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
PatternRewriter &rewriter) const {
// Check if this LinalgOp is a Conv2DOp (named or generic).
- SmallVector<int64_t> dilations, strides;
- if (!linalg::isaConvolutionOpOfType<Conv2DOp>(convOp, &dilations, &strides))
+ std::optional<DilationsAndStrides> convParams =
+ linalg::isaConvolutionOpOfType<Conv2DOp>(convOp);
+ if (!convParams)
return failure();
if (convOp.hasPureBufferSemantics())
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 1cdd01567c4e7..1a99e3d64f3f8 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -592,151 +592,142 @@ class ConvMatcherBuilder {
//===----------------------------------------------------------------------===//
template <>
-bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
- SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (isa<linalg::Conv1DOp>(op)) {
// Conv1DOp has no strides/dilations attributes, default to 1.
- *dilations = SmallVector<int64_t>(1, 1);
- *strides = SmallVector<int64_t>(1, 1);
- return true;
+ result.dilations = SmallVector<int64_t>(1, 1);
+ result.strides = SmallVector<int64_t>(1, 1);
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
+ &result.strides);
AffineExpr W = m.dim(0);
AffineExpr w = m.dim(1);
- return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
- .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
- /*filterMap=*/{w},
- /*outputMap=*/{W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
+ /*filterMap=*/{w},
+ /*outputMap=*/{W}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto convOp = dyn_cast<linalg::Conv1DNwcWcfOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr W = m.dim(1);
AffineExpr F = m.dim(2);
AffineExpr w = m.dim(3);
AffineExpr c = m.dim(4);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
- /*filterMap=*/{w, c, F},
- /*outputMap=*/{N, W, F}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
+ /*filterMap=*/{w, c, F},
+ /*outputMap=*/{N, W, F}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto convOp = dyn_cast<linalg::Conv1DNcwFcwOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr F = m.dim(1);
AffineExpr W = m.dim(2);
AffineExpr c = m.dim(3);
AffineExpr w = m.dim(4);
- return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
- /*filterMap=*/{F, c, w},
- /*outputMap=*/{N, F, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
+ /*filterMap=*/{F, c, w},
+ /*outputMap=*/{N, F, W}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
- SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (isa<linalg::Conv2DOp>(op)) {
// Conv2DOp has no strides/dilations attributes, default to 1.
- *dilations = SmallVector<int64_t>(2, 1);
- *strides = SmallVector<int64_t>(2, 1);
- return true;
+ result.dilations = SmallVector<int64_t>(2, 1);
+ result.strides = SmallVector<int64_t>(2, 1);
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides);
AffineExpr H = m.dim(0);
AffineExpr W = m.dim(1);
AffineExpr h = m.dim(2);
AffineExpr w = m.dim(3);
- return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
- .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
- .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{h, w},
- /*outputMap=*/{H, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+ .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{H, W}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -745,33 +736,33 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
AffineExpr w = m.dim(5);
AffineExpr c = m.dim(6);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
- /*filterMap=*/{h, w, c, F},
- /*outputMap=*/{N, H, W, F}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+ /*filterMap=*/{h, w, c, F},
+ /*outputMap=*/{N, H, W, F}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfQOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -780,35 +771,35 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
AffineExpr w = m.dim(5);
AffineExpr c = m.dim(6);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
- /*filterMap=*/{h, w, c, F},
- /*scalarMap=*/{},
- /*scalarMap=*/{},
- /*outputMap=*/{N, H, W, F}})
- .matchBody(/*containsZeroPointOffset=*/true);
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+ /*filterMap=*/{h, w, c, F},
+ /*scalarMap=*/{},
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, H, W, F}})
+ .matchBody(/*containsZeroPointOffset=*/true))
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -817,33 +808,33 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
AffineExpr w = m.dim(5);
AffineExpr c = m.dim(6);
- return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
- /*filterMap=*/{F, h, w, c},
- /*outputMap=*/{N, H, W, F}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+ /*filterMap=*/{F, h, w, c},
+ /*outputMap=*/{N, H, W, F}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcQOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -852,35 +843,35 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
AffineExpr w = m.dim(5);
AffineExpr c = m.dim(6);
- return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
- /*filterMap=*/{F, h, w, c},
- /*scalarMap=*/{},
- /*scalarMap=*/{},
- /*outputMap=*/{N, H, W, F}})
- .matchBody(/*containsZeroPointOffset=*/true);
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
+ /*filterMap=*/{F, h, w, c},
+ /*scalarMap=*/{},
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, H, W, F}})
+ .matchBody(/*containsZeroPointOffset=*/true))
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr F = m.dim(1);
AffineExpr H = m.dim(2);
@@ -889,33 +880,33 @@ bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
AffineExpr h = m.dim(5);
AffineExpr w = m.dim(6);
- return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
- .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{F, c, h, w},
- /*outputMap=*/{N, F, H, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+ .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{F, c, h, w},
+ /*outputMap=*/{N, F, H, W}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwQOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr F = m.dim(1);
AffineExpr H = m.dim(2);
@@ -924,35 +915,35 @@ bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
AffineExpr h = m.dim(5);
AffineExpr w = m.dim(6);
- return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
- .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{F, c, h, w},
- /*scalarMap=*/{},
- /*scalarMap=*/{},
- /*outputMap=*/{N, F, H, W}})
- .matchBody(/*containsZeroPointOffset=*/true);
+ if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+ .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{F, c, h, w},
+ /*scalarMap=*/{}, // Q variant: two scalar (zero-point) operands.
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, F, H, W}})
+ .matchBody(/*containsZeroPointOffset=*/true))
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp = dyn_cast<linalg::Conv2DNgchwFgchwOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr G = m.dim(1);
AffineExpr F = m.dim(2);
@@ -962,34 +953,33 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
AffineExpr h = m.dim(6);
AffineExpr w = m.dim(7);
- return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
- .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
- .matchMaps(
- {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{F, G, c, h, w},
- /*outputMap=*/{N, G, F, H, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
+ .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{F, G, c, h, w},
+ /*outputMap=*/{N, G, F, H, W}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr G = m.dim(1);
AffineExpr F = m.dim(2);
@@ -999,34 +989,33 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
AffineExpr h = m.dim(6);
AffineExpr w = m.dim(7);
- return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
- .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
- .matchMaps(
- {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{G, F, c, h, w},
- /*outputMap=*/{N, G, F, H, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
+ .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{G, F, c, h, w},
+ /*outputMap=*/{N, G, F, H, W}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwQOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr G = m.dim(1);
AffineExpr F = m.dim(2);
@@ -1036,36 +1025,35 @@ bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
AffineExpr h = m.dim(6);
AffineExpr w = m.dim(7);
- return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
- .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
- .matchMaps(
- {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{G, F, c, h, w},
- /*scalarMap=*/{},
- /*scalarMap=*/{},
- /*outputMap=*/{N, G, F, H, W}})
- .matchBody(/*containsZeroPointOffset=*/true);
+ if (m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
+ .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{G, F, c, h, w},
+ /*scalarMap=*/{}, // Q variant: two scalar (zero-point) operands.
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, G, F, H, W}})
+ .matchBody(/*containsZeroPointOffset=*/true))
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1075,34 +1063,33 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
AffineExpr w = m.dim(6);
AffineExpr c = m.dim(7);
- return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
- .matchMaps(
- {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
- /*filterMap=*/{G, F, h, w, c},
- /*outputMap=*/{N, H, W, G, F}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
+ /*filterMap=*/{G, F, h, w, c},
+ /*outputMap=*/{N, H, W, G, F}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcQOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1112,37 +1099,35 @@ bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
AffineExpr w = m.dim(6);
AffineExpr c = m.dim(7);
- return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
- .matchMaps(
- {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
- /*filterMap=*/{G, F, h, w, c},
- /*scalarMap=*/{},
- /*scalarMap=*/{},
- /*outputMap=*/{N, H, W, G, F}})
- .matchBody(/*containsZeroPointOffset=*/true);
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
+ /*filterMap=*/{G, F, h, w, c},
+ /*scalarMap=*/{}, // Q variant: two scalar (zero-point) operands.
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, H, W, G, F}})
+ .matchBody(/*containsZeroPointOffset=*/true))
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
- SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated with unit defaults or by the matcher below.
if (isa<linalg::Conv3DOp>(op)) {
// Conv3DOp has no strides/dilations attributes, default to 1.
- *dilations = SmallVector<int64_t>(3, 1);
- *strides = SmallVector<int64_t>(3, 1);
- return true;
+ result.dilations = SmallVector<int64_t>(3, 1);
+ result.strides = SmallVector<int64_t>(3, 1);
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr D = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1150,35 +1135,34 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
AffineExpr h = m.dim(4);
AffineExpr w = m.dim(5);
- return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
- .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
- .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
- .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2)},
- /*filterMap=*/{d, h, w},
- /*outputMap=*/{D, H, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
+ .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
+ .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2)},
+ /*filterMap=*/{d, h, w},
+ /*outputMap=*/{D, H, W}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr D = m.dim(1);
AffineExpr H = m.dim(2);
@@ -1189,35 +1173,34 @@ bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(
AffineExpr w = m.dim(7);
AffineExpr c = m.dim(8);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
- .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2), c},
- /*filterMap=*/{d, h, w, c, F},
- /*outputMap=*/{N, D, H, W, F}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2), c},
+ /*filterMap=*/{d, h, w, c, F},
+ /*outputMap=*/{N, D, H, W, F}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfQOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr D = m.dim(1);
AffineExpr H = m.dim(2);
@@ -1228,37 +1211,36 @@ bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(
AffineExpr w = m.dim(7);
AffineExpr c = m.dim(8);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
- .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2), c},
- /*filterMap=*/{d, h, w, c, F},
- /*scalarMap=*/{},
- /*scalarMap=*/{},
- /*outputMap=*/{N, D, H, W, F}})
- .matchBody(/*containsZeroPointOffset=*/true);
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2), c},
+ /*filterMap=*/{d, h, w, c, F},
+ /*scalarMap=*/{}, // Q variant: two scalar (zero-point) operands.
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, D, H, W, F}})
+ .matchBody(/*containsZeroPointOffset=*/true))
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp = dyn_cast<linalg::Conv3DNcdhwFcdhwOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr F = m.dim(1);
AffineExpr D = m.dim(2);
@@ -1269,133 +1251,129 @@ bool isaConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(
AffineExpr h = m.dim(7);
AffineExpr w = m.dim(8);
- return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
- .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
- .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/2)
- .matchMaps({/*inputMap=*/{N, c, m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2)},
- /*filterMap=*/{F, c, d, h, w},
- /*outputMap=*/{N, F, D, H, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
+ .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
+ .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{N, c, m.strided(D, d, 0),
+ m.strided(H, h, 1), m.strided(W, w, 2)},
+ /*filterMap=*/{F, c, d, h, w},
+ /*outputMap=*/{N, F, D, H, W}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp =
dyn_cast<linalg::DepthwiseConv1DNcwCwOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr W = m.dim(1);
AffineExpr C = m.dim(2);
AffineExpr w = m.dim(3);
- return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
- /*filterMap=*/{C, w},
- /*outputMap=*/{N, C, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+ /*filterMap=*/{C, w},
+ /*outputMap=*/{N, C, W}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp =
dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr W = m.dim(1);
AffineExpr C = m.dim(2);
AffineExpr w = m.dim(3);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
- /*filterMap=*/{w, C},
- /*outputMap=*/{N, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C},
+ /*outputMap=*/{N, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp =
dyn_cast<linalg::DepthwiseConv1DNwcWcmOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr W = m.dim(1);
AffineExpr C = m.dim(2);
AffineExpr CM = m.dim(3);
AffineExpr w = m.dim(4);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
- /*filterMap=*/{w, C, CM},
- /*outputMap=*/{N, W, C, CM}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w, C, CM},
+ /*outputMap=*/{N, W, C, CM}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp =
dyn_cast<linalg::DepthwiseConv2DNchwChwOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1403,34 +1381,34 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
AffineExpr h = m.dim(4);
AffineExpr w = m.dim(5);
- return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
- .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{C, h, w},
- /*outputMap=*/{N, C, H, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{C, h, w},
+ /*outputMap=*/{N, C, H, W}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp =
dyn_cast<linalg::DepthwiseConv2DNhwcHwcOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1438,34 +1416,34 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
AffineExpr h = m.dim(4);
AffineExpr w = m.dim(5);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w, C},
- /*outputMap=*/{N, H, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w, C},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp =
dyn_cast<linalg::DepthwiseConv2DNhwcHwcQOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1473,36 +1451,36 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
AffineExpr h = m.dim(4);
AffineExpr w = m.dim(5);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w, C},
- /*scalarMap=*/{},
- /*scalarMap=*/{},
- /*outputMap=*/{N, H, W, C}})
- .matchBody(/*containsZeroPointOffset=*/true);
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w, C},
+ /*scalarMap=*/{}, // Q variant: two scalar (zero-point) operands.
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody(/*containsZeroPointOffset=*/true))
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides> // Named op or equivalent LinalgOp -> dilations/strides; else nullopt.
+isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(LinalgOp op) {
+ DilationsAndStrides result; // Populated from named-op attributes or by the matcher below.
if (auto convOp =
dyn_cast<linalg::DepthwiseConv2DNhwcHwcmOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations = // Named op: read the attributes directly.
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt; // Rejected by isaConvolutionOpInterface.
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations, // Matcher writes inferred values here.
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1511,34 +1489,34 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
AffineExpr h = m.dim(5);
AffineExpr w = m.dim(6);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w, C, CM},
- /*outputMap=*/{N, H, W, C, CM}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w, C, CM},
+ /*outputMap=*/{N, H, W, C, CM}})
+ .matchBody())
+ return result;
+ return std::nullopt; // Mismatch: any partially inferred values are dropped.
}
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto convOp =
dyn_cast<linalg::DepthwiseConv2DNhwcHwcmQOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1547,36 +1525,36 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
AffineExpr h = m.dim(5);
AffineExpr w = m.dim(6);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w, C, CM},
- /*scalarMap=*/{},
- /*scalarMap=*/{},
- /*outputMap=*/{N, H, W, C, CM}})
- .matchBody(/*containsZeroPointOffset=*/true);
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w, C, CM},
+ /*scalarMap=*/{},
+ /*scalarMap=*/{},
+ /*outputMap=*/{N, H, W, C, CM}})
+ .matchBody(/*containsZeroPointOffset=*/true))
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto convOp =
dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr D = m.dim(1);
AffineExpr H = m.dim(2);
@@ -1586,36 +1564,35 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(
AffineExpr w = m.dim(6);
AffineExpr C = m.dim(7);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
- .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2), C},
- /*filterMap=*/{d, h, w, C},
- /*outputMap=*/{N, D, H, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2), C},
+ /*filterMap=*/{d, h, w, C},
+ /*outputMap=*/{N, D, H, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto convOp =
dyn_cast<linalg::DepthwiseConv3DNcdhwCdhwOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr D = m.dim(1);
AffineExpr H = m.dim(2);
@@ -1625,36 +1602,35 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(
AffineExpr w = m.dim(6);
AffineExpr C = m.dim(7);
- return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
- .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
- .matchStride(/*iDim=*/4, /*fDim=*/3, /*oDim=*/4, /*idx=*/2)
- .matchMaps({/*inputMap=*/{N, C, m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2)},
- /*filterMap=*/{C, d, h, w},
- /*outputMap=*/{N, C, D, H, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
+ .matchStride(/*iDim=*/4, /*fDim=*/3, /*oDim=*/4, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{N, C, m.strided(D, d, 0),
+ m.strided(H, h, 1), m.strided(W, w, 2)},
+ /*filterMap=*/{C, d, h, w},
+ /*outputMap=*/{N, C, D, H, W}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto convOp =
dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op.getOperation())) {
- *dilations = llvm::to_vector(convOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(convOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
+ &result.strides);
AffineExpr N = m.dim(0);
AffineExpr D = m.dim(1);
AffineExpr H = m.dim(2);
@@ -1665,36 +1641,34 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
AffineExpr w = m.dim(7);
AffineExpr C = m.dim(8);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
- .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2), C},
- /*filterMap=*/{d, h, w, C, CM},
- /*outputMap=*/{N, D, H, W, C, CM}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2), C},
+ /*filterMap=*/{d, h, w, C, CM},
+ /*outputMap=*/{N, D, H, W, C, CM}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNhwcMaxOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MaxSigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides, PoolingType::MaxSigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1702,34 +1676,33 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
AffineExpr h = m.dim(4);
AffineExpr w = m.dim(5);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNhwcMinOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MinSigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides, PoolingType::MinSigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1737,34 +1710,33 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
AffineExpr h = m.dim(4);
AffineExpr w = m.dim(5);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNhwcSumOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::Sum);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides, PoolingType::Sum);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1772,35 +1744,34 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
AffineExpr h = m.dim(4);
AffineExpr w = m.dim(5);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp =
dyn_cast<linalg::PoolingNhwcMaxUnsignedOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MaxUnsigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides, PoolingType::MaxUnsigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1808,35 +1779,34 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
AffineExpr h = m.dim(4);
AffineExpr w = m.dim(5);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp =
dyn_cast<linalg::PoolingNhwcMinUnsignedOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MinUnsigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides, PoolingType::MinUnsigned);
AffineExpr N = m.dim(0);
AffineExpr H = m.dim(1);
AffineExpr W = m.dim(2);
@@ -1844,34 +1814,33 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
AffineExpr h = m.dim(4);
AffineExpr w = m.dim(5);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, H, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, H, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNchwSumOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::Sum);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides, PoolingType::Sum);
AffineExpr N = m.dim(0);
AffineExpr C = m.dim(1);
AffineExpr H = m.dim(2);
@@ -1879,34 +1848,33 @@ bool isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(
AffineExpr h = m.dim(4);
AffineExpr w = m.dim(5);
- return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
- .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, C, H, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+ .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, C, H, W}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNchwMaxOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
- PoolingType::MaxSigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
+ &result.strides, PoolingType::MaxSigned);
AffineExpr N = m.dim(0);
AffineExpr C = m.dim(1);
AffineExpr H = m.dim(2);
@@ -1914,260 +1882,245 @@ bool isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(
AffineExpr h = m.dim(4);
AffineExpr w = m.dim(5);
- return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
- .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
- .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
- /*filterMap=*/{h, w},
- /*outputMap=*/{N, C, H, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+ .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
+ .matchMaps(
+ {/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
+ /*filterMap=*/{h, w},
+ /*outputMap=*/{N, C, H, W}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNwcSumOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNwcSumOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNwcSumOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
- PoolingType::Sum);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
+ &result.strides, PoolingType::Sum);
AffineExpr N = m.dim(0);
AffineExpr W = m.dim(1);
AffineExpr C = m.dim(2);
AffineExpr w = m.dim(3);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
- /*filterMap=*/{w},
- /*outputMap=*/{N, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w},
+ /*outputMap=*/{N, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNcwSumOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNcwSumOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNcwSumOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
- PoolingType::Sum);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
+ &result.strides, PoolingType::Sum);
AffineExpr N = m.dim(0);
AffineExpr C = m.dim(1);
AffineExpr W = m.dim(2);
AffineExpr w = m.dim(3);
- return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
- /*filterMap=*/{w},
- /*outputMap=*/{N, C, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+ /*filterMap=*/{w},
+ /*outputMap=*/{N, C, W}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNwcMaxOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNwcMaxOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNwcMaxOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
- PoolingType::MaxSigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
+ &result.strides, PoolingType::MaxSigned);
AffineExpr N = m.dim(0);
AffineExpr W = m.dim(1);
AffineExpr C = m.dim(2);
AffineExpr w = m.dim(3);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
- /*filterMap=*/{w},
- /*outputMap=*/{N, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w},
+ /*outputMap=*/{N, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNwcMaxUnsignedOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNwcMaxUnsignedOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp =
dyn_cast<linalg::PoolingNwcMaxUnsignedOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
- PoolingType::MaxUnsigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
+ &result.strides, PoolingType::MaxUnsigned);
AffineExpr N = m.dim(0);
AffineExpr W = m.dim(1);
AffineExpr C = m.dim(2);
AffineExpr w = m.dim(3);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
- /*filterMap=*/{w},
- /*outputMap=*/{N, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w},
+ /*outputMap=*/{N, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNcwMaxOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNcwMaxOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNcwMaxOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
- PoolingType::MaxSigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
+ &result.strides, PoolingType::MaxSigned);
AffineExpr N = m.dim(0);
AffineExpr C = m.dim(1);
AffineExpr W = m.dim(2);
AffineExpr w = m.dim(3);
- return m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
- /*filterMap=*/{w},
- /*outputMap=*/{N, C, W}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
+ /*filterMap=*/{w},
+ /*outputMap=*/{N, C, W}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNwcMinOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNwcMinOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNwcMinOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
- PoolingType::MinSigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
+ &result.strides, PoolingType::MinSigned);
AffineExpr N = m.dim(0);
AffineExpr W = m.dim(1);
AffineExpr C = m.dim(2);
AffineExpr w = m.dim(3);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
- /*filterMap=*/{w},
- /*outputMap=*/{N, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w},
+ /*outputMap=*/{N, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNwcMinUnsignedOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNwcMinUnsignedOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp =
dyn_cast<linalg::PoolingNwcMinUnsignedOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides,
- PoolingType::MinUnsigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
+ &result.strides, PoolingType::MinUnsigned);
AffineExpr N = m.dim(0);
AffineExpr W = m.dim(1);
AffineExpr C = m.dim(2);
AffineExpr w = m.dim(3);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
- /*filterMap=*/{w},
- /*outputMap=*/{N, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
+ /*filterMap=*/{w},
+ /*outputMap=*/{N, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNdhwcSumOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
- PoolingType::Sum);
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
+ &result.strides, PoolingType::Sum);
AffineExpr N = m.dim(0);
AffineExpr D = m.dim(1);
AffineExpr H = m.dim(2);
@@ -2177,36 +2130,34 @@ bool isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(
AffineExpr h = m.dim(6);
AffineExpr w = m.dim(7);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
- .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2), C},
- /*filterMap=*/{d, h, w},
- /*outputMap=*/{N, D, H, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2), C},
+ /*filterMap=*/{d, h, w},
+ /*outputMap=*/{N, D, H, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMaxOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
- PoolingType::MaxSigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
+ &result.strides, PoolingType::MaxSigned);
AffineExpr N = m.dim(0);
AffineExpr D = m.dim(1);
AffineExpr H = m.dim(2);
@@ -2216,36 +2167,34 @@ bool isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(
AffineExpr h = m.dim(6);
AffineExpr w = m.dim(7);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
- .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2), C},
- /*filterMap=*/{d, h, w},
- /*outputMap=*/{N, D, H, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2), C},
+ /*filterMap=*/{d, h, w},
+ /*outputMap=*/{N, D, H, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
template <>
-bool isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- SmallVector<int64_t> localDilations, localStrides;
- if (!dilations)
- dilations = &localDilations;
- if (!strides)
- strides = &localStrides;
+std::optional<DilationsAndStrides>
+isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(LinalgOp op) {
+ DilationsAndStrides result;
if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMinOp>(op.getOperation())) {
- *dilations = llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
- *strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
- return true;
+ result.dilations =
+ llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
+ result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
+ return result;
}
if (!isaConvolutionOpInterface(op))
- return false;
+ return std::nullopt;
- ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides,
- PoolingType::MinSigned);
+ ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
+ &result.strides, PoolingType::MinSigned);
AffineExpr N = m.dim(0);
AffineExpr D = m.dim(1);
AffineExpr H = m.dim(2);
@@ -2255,14 +2204,16 @@ bool isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(
AffineExpr h = m.dim(6);
AffineExpr w = m.dim(7);
- return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
- .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
- .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
- .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
- m.strided(W, w, 2), C},
- /*filterMap=*/{d, h, w},
- /*outputMap=*/{N, D, H, W, C}})
- .matchBody();
+ if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
+ .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
+ .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
+ .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
+ m.strided(W, w, 2), C},
+ /*filterMap=*/{d, h, w},
+ /*outputMap=*/{N, D, H, W, C}})
+ .matchBody())
+ return result;
+ return std::nullopt;
}
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 7798cb76e4fb9..9c9aaf8c20b8d 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -1,170 +1,116 @@
// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --linalg-generalize-named-ops --transform-interpreter --split-input-file %s | FileCheck %s
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
-#map_nhwc_hwcf_input = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
-#map_nhwc_hwcf_filter = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
-#map_nhwc_hwcf_output = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-
// CHECK-LABEL: @conv_2d_nhwc_hwcf
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
- // CHECK: tensor.extract_slice %[[ARG0]]
- // CHECK: tensor.extract_slice %[[ARG1]]
- // CHECK: tensor.extract_slice %[[ARG2]]
- // Both named and generic ops should decompose to conv_1d_nwc_wcf
- // CHECK-COUNT-2: linalg.conv_1d_nwc_wcf
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
- // Generic op version with same semantics.
- %1 = linalg.generic {indexing_maps = [#map_nhwc_hwcf_input, #map_nhwc_hwcf_filter, #map_nhwc_hwcf_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.mulf %in, %in_0 : f32
- %3 = arith.addf %out, %2 : f32
- linalg.yield %3 : f32
- } -> tensor<?x1x?x?xf32>
- return %1 : tensor<?x1x?x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xf32>
}
-#map_nchw_fchw_input = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
-#map_nchw_fchw_filter = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
-#map_nchw_fchw_output = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
-
// CHECK-LABEL: @conv_2d_nchw_fchw
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
- // CHECK: tensor.extract_slice %[[ARG0]]
- // CHECK: tensor.extract_slice %[[ARG1]]
- // CHECK: tensor.extract_slice %[[ARG2]]
- // Both named and generic ops should decompose to conv_1d_ncw_fcw
- // CHECK-COUNT-2: linalg.conv_1d_ncw_fcw
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
- // Generic op version with same semantics.
- %1 = linalg.generic {indexing_maps = [#map_nchw_fchw_input, #map_nchw_fchw_filter, #map_nchw_fchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.mulf %in, %in_0 : f32
- %3 = arith.addf %out, %2 : f32
- linalg.yield %3 : f32
- } -> tensor<?x?x1x?xf32>
- return %1 : tensor<?x?x1x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x?x1x?xf32>
}
-#map_depthwise_nhwc_hwc_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 2 + d5, d3)>
-#map_depthwise_nhwc_hwc_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
-#map_depthwise_nhwc_hwc_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-
// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
// CHECK: %[[RES:.+]] = tensor.empty
%init = tensor.empty() : tensor<1x1x56x96xf32>
- // CHECK: tensor.extract_slice %[[ARG0]]
- // CHECK: tensor.extract_slice %[[ARG1]]
- // Both named and generic ops should decompose to depthwise_conv_1d_nwc_wc
- // CHECK-COUNT-2: linalg.depthwise_conv_1d_nwc_wc
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
+ // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
+ // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
+ // CHECK-SAME: outs(%[[SLICERES]]
+ // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
%0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
outs(%init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
- // Generic op version with same semantics (strides = 2).
- %1 = linalg.generic {indexing_maps = [#map_depthwise_nhwc_hwc_input, #map_depthwise_nhwc_hwc_filter, #map_depthwise_nhwc_hwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<1x1x113x96xf32>, tensor<1x3x96xf32>) outs(%0 : tensor<1x1x56x96xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.mulf %in, %in_0 : f32
- %3 = arith.addf %out, %2 : f32
- linalg.yield %3 : f32
- } -> tensor<1x1x56x96xf32>
- return %1: tensor<1x1x56x96xf32>
+ // CHECK: %[[INSERTED]]
+ return %0: tensor<1x1x56x96xf32>
}
-#map_conv_2d_input = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
-#map_conv_2d_filter = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
-#map_conv_2d_output = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
-
// CHECK-LABEL: @conv_2d
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<1x?xf32>)
func.func @conv_2d(%input: tensor<1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<1x?xf32>) -> tensor<1x?xf32> {
- // CHECK: tensor.extract_slice %[[ARG0]]
- // CHECK: tensor.extract_slice %[[ARG1]]
- // CHECK: tensor.extract_slice %[[ARG2]]
- // Both named and generic ops should decompose to conv_1d
- // CHECK-COUNT-2: linalg.conv_1d
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.conv_1d
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.conv_2d
ins (%input, %filter: tensor<1x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<1x?xf32>) -> tensor<1x?xf32>
- // Generic op version with same semantics.
- %1 = linalg.generic {indexing_maps = [#map_conv_2d_input, #map_conv_2d_filter, #map_conv_2d_output], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<1x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.mulf %in, %in_0 : f32
- %3 = arith.addf %out, %2 : f32
- linalg.yield %3 : f32
- } -> tensor<1x?xf32>
- return %1 : tensor<1x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<1x?xf32>
}
-#map_pooling_nhwc_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
-#map_pooling_nhwc_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-#map_pooling_nhwc_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-
// CHECK-LABEL: @pooling_nhwc_sum
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_sum(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
- // CHECK: tensor.extract_slice %[[ARG0]]
- // CHECK: tensor.extract_slice %[[ARG1]]
- // CHECK: tensor.extract_slice %[[ARG2]]
- // Both named and generic ops should decompose to pooling_nwc_sum
- // CHECK-COUNT-2: linalg.pooling_nwc_sum
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
- // Generic op version with same semantics.
- %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.addf %out, %in : f32
- linalg.yield %2 : f32
- } -> tensor<?x1x?x?xf32>
- return %1 : tensor<?x1x?x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xf32>
}
-#map_pooling_nchw_input = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
-#map_pooling_nchw_filter = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
-#map_pooling_nchw_output = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
-
// CHECK-LABEL: @pooling_nchw_sum
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
- // CHECK: tensor.extract_slice %[[ARG0]]
- // CHECK: tensor.extract_slice %[[ARG1]]
- // CHECK: tensor.extract_slice %[[ARG2]]
- // Both named and generic ops should decompose to pooling_ncw_sum
- // CHECK-COUNT-2: linalg.pooling_ncw_sum
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
- // Generic op version with same semantics.
- %1 = linalg.generic {indexing_maps = [#map_pooling_nchw_input, #map_pooling_nchw_filter, #map_pooling_nchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.addf %out, %in : f32
- linalg.yield %2 : f32
- } -> tensor<?x?x1x?xf32>
- return %1 : tensor<?x?x1x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x?x1x?xf32>
}
// CHECK-LABEL: @pooling_nhwc_max
@@ -172,22 +118,17 @@ func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
- // CHECK: tensor.extract_slice %[[ARG0]]
- // CHECK: tensor.extract_slice %[[ARG1]]
- // CHECK: tensor.extract_slice %[[ARG2]]
- // Both named and generic ops should decompose to pooling_nwc_max
- // CHECK-COUNT-2: linalg.pooling_nwc_max
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
- // Generic op version with same semantics.
- %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.maximumf %out, %in : f32
- linalg.yield %2 : f32
- } -> tensor<?x1x?x?xf32>
- return %1 : tensor<?x1x?x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xf32>
}
// CHECK-LABEL: @pooling_nhwc_max_unsigned
@@ -195,22 +136,17 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
- // CHECK: tensor.extract_slice %[[ARG0]]
- // CHECK: tensor.extract_slice %[[ARG1]]
- // CHECK: tensor.extract_slice %[[ARG2]]
- // Both named and generic ops should decompose to pooling_nwc_max_unsigned
- // CHECK-COUNT-2: linalg.pooling_nwc_max_unsigned
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
- // Generic op version with same semantics.
- %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xi32>, tensor<1x?xi32>) outs(%0 : tensor<?x1x?x?xi32>) {
- ^bb0(%in: i32, %in_0: i32, %out: i32):
- %2 = arith.maxui %out, %in : i32
- linalg.yield %2 : i32
- } -> tensor<?x1x?x?xi32>
- return %1 : tensor<?x1x?x?xi32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nhwc_min
@@ -218,22 +154,17 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tenso
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
- // CHECK: tensor.extract_slice %[[ARG0]]
- // CHECK: tensor.extract_slice %[[ARG1]]
- // CHECK: tensor.extract_slice %[[ARG2]]
- // Both named and generic ops should decompose to pooling_nwc_min
- // CHECK-COUNT-2: linalg.pooling_nwc_min
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
- // Generic op version with same semantics.
- %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x1x?x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.minimumf %out, %in : f32
- linalg.yield %2 : f32
- } -> tensor<?x1x?x?xf32>
- return %1 : tensor<?x1x?x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xf32>
}
// CHECK-LABEL: @pooling_nhwc_min_unsigned
@@ -241,22 +172,17 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
- // CHECK: tensor.extract_slice %[[ARG0]]
- // CHECK: tensor.extract_slice %[[ARG1]]
- // CHECK: tensor.extract_slice %[[ARG2]]
- // Both named and generic ops should decompose to pooling_nwc_min_unsigned
- // CHECK-COUNT-2: linalg.pooling_nwc_min_unsigned
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
- // Generic op version with same semantics.
- %1 = linalg.generic {indexing_maps = [#map_pooling_nhwc_input, #map_pooling_nhwc_filter, #map_pooling_nhwc_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x1x?x?xi32>, tensor<1x?xi32>) outs(%0 : tensor<?x1x?x?xi32>) {
- ^bb0(%in: i32, %in_0: i32, %out: i32):
- %2 = arith.minui %out, %in : i32
- linalg.yield %2 : i32
- } -> tensor<?x1x?x?xi32>
- return %1 : tensor<?x1x?x?xi32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nchw_max
@@ -264,22 +190,17 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tenso
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
- // CHECK: tensor.extract_slice %[[ARG0]]
- // CHECK: tensor.extract_slice %[[ARG1]]
- // CHECK: tensor.extract_slice %[[ARG2]]
- // Both named and generic ops should decompose to pooling_ncw_max
- // CHECK-COUNT-2: linalg.pooling_ncw_max
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
- // Generic op version with same semantics.
- %1 = linalg.generic {indexing_maps = [#map_pooling_nchw_input, #map_pooling_nchw_filter, #map_pooling_nchw_output], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<1x?xf32>) outs(%0 : tensor<?x?x1x?xf32>) {
- ^bb0(%in: f32, %in_0: f32, %out: f32):
- %2 = arith.maximumf %out, %in : f32
- linalg.yield %2 : f32
- } -> tensor<?x?x1x?xf32>
- return %1 : tensor<?x?x1x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x?x1x?xf32>
}
func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
More information about the Mlir-commits
mailing list