[Mlir-commits] [mlir] [Linalg] Add matchers to infer which Convolution Op a given linalg.generic is (PR #163374)
Abhishek Varma
llvmlistbot at llvm.org
Wed Oct 15 01:30:31 PDT 2025
https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/163374
>From a964bb437e612c57ce741157977502ed1f7815fb Mon Sep 17 00:00:00 2001
From: Abhishek Varma <avarma094 at gmail.com>
Date: Mon, 22 Sep 2025 12:16:48 +0000
Subject: [PATCH 01/18] [WIP] Specialize linalg.generic to named convolution ops
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 158 ++++++++++++++++++
1 file changed, 158 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d68e358f..4e9572ee7cb04 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,159 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
+static bool matchingIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
+ArrayRef<mlir::utils::IteratorType> expectedIteratorTypes) {
+ if (iteratorTypes.size() != expectedIteratorTypes.size()) return false;
+ for (auto [orig, expected] : llvm::zip_equal(iteratorTypes, expectedIteratorTypes)) {
+ if (orig != expected) return false;
+ }
+ return true;
+}
+
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+ uint32_t mapIndex, uint32_t dimIndex) {
+ auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+ // uint32_t nResults = affineMap.getNumResults();
+ // llvm::outs()<<affineMap<<"\n";
+ // llvm::outs()<<"Total result = "<<affineMap.getNumResults()<<"\n";
+ // llvm::outs()<<"N = "<<nResults<<", dimIndex = "<<dimIndex<<"\n";
+ // llvm::outs().flush();
+ return affineMap.getResult(dimIndex);
+}
+
+static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+ SmallVector<utils::IteratorType> expectedIteratorTypes = {
+ utils::IteratorType::parallel, utils::IteratorType::reduction
+ };
+
+ if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
+ return "linalg.conv_1d";
+ return "";
+}
+
+static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
+ ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+ if (indexingMaps.size() != 3) return "";
+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+ // Conv 1D
+ // depthwise_conv_1d_ncw_cw
+ // depthwise_conv_1d_nwc_wc
+ // ["parallel", "parallel", "parallel", "reduction"]
+ SmallVector<utils::IteratorType> expectedIteratorTypes = {
+ utils::IteratorType::parallel, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::reduction
+ };
+ // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
+ if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))
+ return "linalg.depthwise_conv_1d_ncw_cw";
+ else if (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2))
+ return "linalg.depthwise_conv_1d_nwc_wc";
+ }
+
+ //
+ expectedIteratorTypes[2] = utils::IteratorType::reduction;
+ if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
+ return "linalg.conv_2d";
+ }
+ return "";
+}
+
+static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
+ ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+ if (indexingMaps.size() != 3) return "";
+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+ // "parallel", "parallel", "parallel", "reduction", "reduction"]
+ SmallVector<utils::IteratorType> expectedIteratorTypes = {
+ utils::IteratorType::parallel, utils::IteratorType::parallel,
+ utils::IteratorType::parallel, utils::IteratorType::parallel,
+ utils::IteratorType::reduction
+ };
+ if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
+ return "linalg.depthwise_conv_1d_nwc_wcm";
+
+ expectedIteratorTypes[3] = utils::IteratorType::reduction;
+ // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
+ if (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))
+ return "linalg.conv_1d_nwc_wcf";
+ else if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))
+ return "linalg.conv_1d_ncw_fcw";
+ }
+ return "";
+}
+
+static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+ SmallVector<utils::IteratorType> expectedIteratorTypes = {
+ utils::IteratorType::parallel, utils::IteratorType::reduction
+ };
+ if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
+ return "linalg.conv_1d";
+ return "";
+}
+
+static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+ SmallVector<utils::IteratorType> expectedIteratorTypes = {
+ utils::IteratorType::parallel, utils::IteratorType::reduction
+ };
+ if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
+ return "linalg.conv_1d";
+ return "";
+}
+
+static std::string inferConvolutionKind(GenericOp genericOp) {
+ SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+ unsigned totalIterators = iteratorTypes.size();
+ switch(totalIterators) {
+ case 2:
+ return inferBasedOnRank2ConvIteratorTypes(genericOp);
+ case 4:
+ return inferBasedOnRank4ConvIteratorTypes(genericOp);
+ case 5:
+ return inferBasedOnRank5ConvIteratorTypes(genericOp);
+ case 7:
+ return inferBasedOnRank7ConvIteratorTypes(genericOp);
+ case 8:
+ return inferBasedOnRank8ConvIteratorTypes(genericOp);
+ }
+ return "";
+}
+
+// Converts linalg.generic to named linalg.*conv* where possible.
+static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ std::string convKind = inferConvolutionKind(genericOp);
+ if (convKind == "") return failure();
+ SmallVector<Value> inputs = genericOp.getDpsInputs();
+ ValueRange outputs = genericOp.getDpsInits();
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+ ? TypeRange(ValueRange(outputs))
+ : TypeRange{};
+ LinalgOp namedOp;
+ if (convKind == "linalg.conv_1d") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_1d_nwc_wcf") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DNwcWcfOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_1d_ncw_fcw") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DNcwFcwOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.depthwise_conv_1d_ncw_cw") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNcwCwOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.depthwise_conv_1d_nwc_wc") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DOp>(genericOp, resultTypes, inputs, outputs);
+ }
+ return namedOp;
+
+ return failure();
+}
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -316,6 +469,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp);
}
+
+ // Convolution - e.g. *conv*
+ if (isaConvolutionOpInterface(genericOp)) {
+ return specializeLinalgConvolutions(rewriter, genericOp);
+ }
return failure();
}
>From 89b7190e79b16954b254acca9b5f4d8f6f7c9eb6 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <avarma094 at gmail.com>
Date: Mon, 29 Sep 2025 17:27:30 +0000
Subject: [PATCH 02/18] Infer convolution kind by matching indexing maps instead of iterator types
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 257 +++++++++++++-----
1 file changed, 187 insertions(+), 70 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4e9572ee7cb04..84b080fd53535 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,33 +237,17 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
-static bool matchingIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
-ArrayRef<mlir::utils::IteratorType> expectedIteratorTypes) {
- if (iteratorTypes.size() != expectedIteratorTypes.size()) return false;
- for (auto [orig, expected] : llvm::zip_equal(iteratorTypes, expectedIteratorTypes)) {
- if (orig != expected) return false;
- }
- return true;
-}
-
static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
uint32_t mapIndex, uint32_t dimIndex) {
auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
- // uint32_t nResults = affineMap.getNumResults();
- // llvm::outs()<<affineMap<<"\n";
- // llvm::outs()<<"Total result = "<<affineMap.getNumResults()<<"\n";
- // llvm::outs()<<"N = "<<nResults<<", dimIndex = "<<dimIndex<<"\n";
- // llvm::outs().flush();
return affineMap.getResult(dimIndex);
}
static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
- SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
- SmallVector<utils::IteratorType> expectedIteratorTypes = {
- utils::IteratorType::parallel, utils::IteratorType::reduction
- };
-
- if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
+ ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+ if (indexingMaps.size() != 3) return "";
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ if (getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0)))
return "linalg.conv_1d";
return "";
}
@@ -271,74 +255,187 @@ static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
if (indexingMaps.size() != 3) return "";
- SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
- // Conv 1D
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
// depthwise_conv_1d_ncw_cw
+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))))
+ return "linalg.depthwise_conv_1d_ncw_cw";
// depthwise_conv_1d_nwc_wc
- // ["parallel", "parallel", "parallel", "reduction"]
- SmallVector<utils::IteratorType> expectedIteratorTypes = {
- utils::IteratorType::parallel, utils::IteratorType::parallel,
- utils::IteratorType::parallel, utils::IteratorType::reduction
- };
- // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
- unsigned iIndex = 0, fIndex = 1, oIndex = 2;
- if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
- if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))
- return "linalg.depthwise_conv_1d_ncw_cw";
- else if (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2))
- return "linalg.depthwise_conv_1d_nwc_wc";
- }
-
- //
- expectedIteratorTypes[2] = utils::IteratorType::reduction;
- if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))))
+ return "linalg.depthwise_conv_1d_nwc_wc";
+ // conv_2d
+ // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))))
return "linalg.conv_2d";
- }
return "";
}
static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
if (indexingMaps.size() != 3) return "";
- SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
- // "parallel", "parallel", "parallel", "reduction", "reduction"]
- SmallVector<utils::IteratorType> expectedIteratorTypes = {
- utils::IteratorType::parallel, utils::IteratorType::parallel,
- utils::IteratorType::parallel, utils::IteratorType::parallel,
- utils::IteratorType::reduction
- };
- if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
- return "linalg.depthwise_conv_1d_nwc_wcm";
-
- expectedIteratorTypes[3] = utils::IteratorType::reduction;
- // inputMapIndex = 0, filterMapIndex = 1, outputMapIndex = 2;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
- if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes)) {
- if (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))
- return "linalg.conv_1d_nwc_wcf";
- else if (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1))
- return "linalg.conv_1d_ncw_fcw";
- }
+ // depthwise_conv_1d_nwc_wcm
+ // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
+ (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 3)))
+ return "linalg.depthwise_conv_1d_nwc_wcm";
+ // conv_1d_nwc_wcf
+ // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)))
+ return "linalg.conv_1d_nwc_wcf";
+ // conv_1d_ncw_fcw
+ // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+ return "linalg.conv_1d_ncw_fcw";
return "";
}
static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
- SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
- SmallVector<utils::IteratorType> expectedIteratorTypes = {
- utils::IteratorType::parallel, utils::IteratorType::reduction
- };
- if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
- return "linalg.conv_1d";
+ ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+ if (indexingMaps.size() < 3) return "";
+ unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+ // conv_2d_nhwc_fhwc
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+ if (indexingMaps.size() == 3 &&
+ (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
+ (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
+ return "linalg.conv_2d_nhwc_fhwc";
+ // conv_2d_nhwc_hwcf
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
+ (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
+ return "linalg.conv_2d_nhwc_hwcf";
+ // conv_2d_nchw_fchw
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+ if (indexingMaps.size() == 3 &&
+ (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+ return "linalg.conv_2d_nchw_fchw";
+ // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+ if (indexingMaps.size() == 5 &&
+ (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
+ (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
+ (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
+ return "linalg.conv_2d_nhwc_fhwc_q";
+ // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps)
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+ llvm::outs()<<"Indexing map size = "<<indexingMaps.size()<<"\n";
+ llvm::outs()<<"(indexingMaps[2] == indexingMaps[3]) == "<<(indexingMaps[2] == indexingMaps[3])<<"\n";
+ llvm::outs()<<"cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() = "<<cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults()<<"\n";
+ if (indexingMaps.size() == 5 &&
+ (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
+ (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+ return "linalg.conv_2d_nchw_fchw_q";
return "";
}
static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
- SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
- SmallVector<utils::IteratorType> expectedIteratorTypes = {
- utils::IteratorType::parallel, utils::IteratorType::reduction
- };
- if (matchingIteratorTypes(iteratorTypes, expectedIteratorTypes))
- return "linalg.conv_1d";
+ ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+ if (indexingMaps.size() < 3) return "";
+ unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+ // conv_2d_ngchw_fgchw
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+ (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+ return "linalg.conv_2d_ngchw_fgchw";
+ // conv_2d_ngchw_gfchw
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ if (indexingMaps.size() == 3 &&
+ (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+ (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
+ return "linalg.conv_2d_ngchw_gfchw";
+ // conv_2d_ngchw_gfchw_q
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ if (indexingMaps.size() == 5 &&
+ (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
+ (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+ (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
+ return "linalg.conv_2d_ngchw_gfchw_q";
+ // conv_2d_nhwgc_gfhwc
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) &&
+ (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4)))
+ return "linalg.conv_2d_nhwgc_gfhwc";
return "";
}
@@ -382,8 +479,28 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNcwCwOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.depthwise_conv_1d_nwc_wc") {
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.depthwise_conv_1d_nwc_wcm") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcmOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.conv_2d") {
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d_nhwc_fhwc") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcFhwcOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d_nhwc_hwcf") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d_nchw_fchw") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d_nhwc_fhwc_q") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcFhwcQOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d_nchw_fchw_q") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwQOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d_ngchw_fgchw") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwFgchwOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d_ngchw_gfchw") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d_ngchw_gfchw_q") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwQOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcOp>(genericOp, resultTypes, inputs, outputs);
}
return namedOp;
>From dac92f142afdfd6a7a616574e1b0983513a37ba1 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 3 Oct 2025 06:56:34 -0500
Subject: [PATCH 03/18] Complete convolution matchers; start on pooling ops
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 146 +++++++++++++++++-
1 file changed, 142 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 84b080fd53535..6603967b991ab 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -316,6 +316,39 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
return "";
}
+static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
+ ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+ if (indexingMaps.size() < 3) return "";
+ unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+ // depthwise_conv_2d_nchw_chw
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))))
+ return "linalg.depthwise_conv_2d_nchw_chw";
+ // depthwise_conv_2d_nhwc_hwc
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
+ return "linalg.depthwise_conv_2d_nhwc_hwc";
+ // conv_3d
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))))
+ return "linalg.conv_3d";
+ return "";
+}
+
static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
if (indexingMaps.size() < 3) return "";
@@ -370,9 +403,6 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- llvm::outs()<<"Indexing map size = "<<indexingMaps.size()<<"\n";
- llvm::outs()<<"(indexingMaps[2] == indexingMaps[3]) == "<<(indexingMaps[2] == indexingMaps[3])<<"\n";
- llvm::outs()<<"cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() = "<<cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults()<<"\n";
if (indexingMaps.size() == 5 &&
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
@@ -381,6 +411,30 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
return "linalg.conv_2d_nchw_fchw_q";
+ // depthwise_conv_2d_nhwc_hwcm
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+ if (indexingMaps.size() == 3 &&
+ (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
+ (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
+ return "linalg.depthwise_conv_2d_nhwc_hwcm";
+ // depthwise_conv_2d_nhwc_hwcm_q
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+ if (indexingMaps.size() == 5 &&
+ (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
+ (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
+ (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
+ return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
return "";
}
@@ -397,7 +451,7 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
(getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
(getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
- (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+ (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 2)))
return "linalg.conv_2d_ngchw_fgchw";
// conv_2d_ngchw_gfchw
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
@@ -436,6 +490,66 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
(getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) &&
(getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4)))
return "linalg.conv_2d_nhwgc_gfhwc";
+ // depthwise_conv_3d_ncdhw_cdhw
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 4))))
+ return "linalg.depthwise_conv_3d_ncdhw_cdhw";
+ // depthwise_conv_3d_ndhwc_dhwc
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
+ return "linalg.depthwise_conv_3d_ndhwc_dhwc";
+ return "";
+}
+
+/// Infers which named convolution op a linalg.generic with 9 loops
+/// represents by comparing result expressions of the input (first), filter
+/// (second) and output (last) indexing maps. Returns the named op's name, or
+/// "" when none of the known patterns matches.
+static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
+  ArrayAttr maps = genericOp.getIndexingMaps();
+  if (maps.size() < 3)
+    return "";
+  const unsigned outMap = maps.size() - 1;
+  // Result `d` of the input, filter and output indexing maps respectively.
+  auto in = [&](unsigned d) { return getAffineMapDim(maps, 0, d); };
+  auto flt = [&](unsigned d) { return getAffineMapDim(maps, 1, d); };
+  auto out = [&](unsigned d) { return getAffineMapDim(maps, outMap, d); };
+
+  // conv_3d_ncdhw_fcdhw
+  //   input:  (d0, d5, d2 + d6, d3 + d7, d4 + d8)
+  //   filter: (d1, d5, d6, d7, d8)
+  //   output: (d0, d1, d2, d3, d4)
+  if (in(0) == out(0) && in(1) == flt(1) && in(2) == flt(2) + out(2) &&
+      in(3) == flt(3) + out(3) && in(4) == flt(4) + out(4) &&
+      flt(0) == out(1))
+    return "linalg.conv_3d_ncdhw_fcdhw";
+
+  // conv_3d_ndhwc_dhwcf
+  //   input:  (d0, d1 + d5, d2 + d6, d3 + d7, d8)
+  //   filter: (d5, d6, d7, d8, d4)
+  //   output: (d0, d1, d2, d3, d4)
+  if (in(0) == out(0) && in(1) == flt(0) + out(1) &&
+      in(2) == flt(1) + out(2) && in(3) == flt(2) + out(3) &&
+      in(4) == flt(3) && flt(4) == out(4))
+    return "linalg.conv_3d_ndhwc_dhwcf";
+
+  // depthwise_conv_3d_ndhwc_dhwcm
+  //   input:  (d0, d1 + d5, d2 + d6, d3 + d7, d8)
+  //   filter: (d5, d6, d7, d8, d4)
+  //   output: (d0, d1, d2, d3, d8, d4)
+  if (in(0) == out(0) && in(1) == flt(0) + out(1) &&
+      in(2) == flt(1) + out(2) && in(3) == flt(2) + out(3) &&
+      in(4) == flt(3) && in(4) == out(4) && flt(4) == out(5))
+    return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
return "";
}
@@ -449,10 +563,14 @@ static std::string inferConvolutionKind(GenericOp genericOp) {
return inferBasedOnRank4ConvIteratorTypes(genericOp);
case 5:
return inferBasedOnRank5ConvIteratorTypes(genericOp);
+ case 6:
+ return inferBasedOnRank6ConvIteratorTypes(genericOp);
case 7:
return inferBasedOnRank7ConvIteratorTypes(genericOp);
case 8:
return inferBasedOnRank8ConvIteratorTypes(genericOp);
+ case 9:
+ return inferBasedOnRank9ConvIteratorTypes(genericOp);
}
return "";
}
@@ -501,6 +619,26 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwQOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") {
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.depthwise_conv_2d_nchw_chw") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNchwChwOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm_q") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_3d") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_3d_ncdhw_fcdhw") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNcdhwFcdhwOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwcm") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.depthwise_conv_3d_ncdhw_cdhw") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwc") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(genericOp, resultTypes, inputs, outputs);
}
return namedOp;
>From 789fb856517377966b58ba9b1c4fdea1d1b12324 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 7 Oct 2025 09:04:37 -0500
Subject: [PATCH 04/18] Add pooling ops to the mix - has a few issues but we can
shift to considering dilations/strides now
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 167 ++++++++++++++++++
1 file changed, 167 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 6603967b991ab..2efa410e4b855 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,39 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
+/// Utility to match the block body of linalg.pooling_* ops. The yielded value
+/// must be produced by one of `OpTypes`, and both of that op's operands must
+/// be block arguments of `body` (e.g. `yield max(%in, %out)`).
+///
+/// `yieldVal` is the value yielded by `body`'s terminator. Returns false when
+/// `yieldVal` has no defining op (it is itself a block argument), when the op
+/// is not one of `OpTypes`, or when either operand is not an argument of
+/// `body`.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+  Operation *defOp = yieldVal.getDefiningOp();
+  // isa_and_present tolerates a null defOp.
+  if (!(isa_and_present<OpTypes>(defOp) || ...))
+    return false;
+
+  // Every op type used with this matcher is a binary arith op, so operands 0
+  // and 1 exist.
+  auto lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+  auto rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+  if (!lhsArg || !rhsArg)
+    return false;
+  // Only accept the payload arguments of this body, not block arguments that
+  // belong to some other block.
+  return lhsArg.getOwner() == body && rhsArg.getOwner() == body;
+}
+
+/// Body matcher for max pooling that reduces with `arith.maxsi` or
+/// `arith.maximumf`.
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaxSIOp, arith::MaximumFOp>(
+      yieldVal, /*body=*/body);
+}
+
+/// Body matcher for unsigned max pooling that reduces with `arith.maxui` or
+/// `arith.maximumf`.
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaxUIOp, arith::MaximumFOp>(
+      yieldVal, /*body=*/body);
+}
+
+/// Body matcher for min pooling that reduces with `arith.minsi` or
+/// `arith.minimumf`.
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinSIOp, arith::MinimumFOp>(
+      yieldVal, /*body=*/body);
+}
+
+/// Body matcher for unsigned min pooling that reduces with `arith.minui` or
+/// `arith.minimumf`.
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinUIOp, arith::MinimumFOp>(
+      yieldVal, /*body=*/body);
+}
+
+/// Body matcher for sum pooling that reduces with `arith.addf` or
+/// `arith.addi`.
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::AddFOp, arith::AddIOp>(
+      yieldVal, /*body=*/body);
+}
+
static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
uint32_t mapIndex, uint32_t dimIndex) {
auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
@@ -279,6 +312,39 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))))
return "linalg.conv_2d";
+
+ Block *body = genericOp.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ // pooling_ncw_max
+ // pooling_ncw_sum
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+ // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2)))) {
+ if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
+ return "linalg.pooling_ncw_max";
+ if (bodyMatcherForSumPoolOps(yieldVal, body))
+ return "linalg.pooling_ncw_sum";
+ }
+ // pooling_nwc_max
+ // pooling_nwc_min
+ // pooling_nwc_sum
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+ // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) {
+ if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
+ return "linalg.pooling_nwc_max";
+ if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
+ return "linalg.pooling_nwc_min";
+ if (bodyMatcherForSumPoolOps(yieldVal, body))
+ return "linalg.pooling_nwc_sum";
+ }
return "";
}
@@ -346,6 +412,55 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))))
return "linalg.conv_3d";
+
+ Block *body = genericOp.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ // pooling_nchw_max
+ // pooling_nchw_sum
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 3)))) {
+ if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
+ return "linalg.pooling_nchw_max";
+ if (bodyMatcherForSumPoolOps(yieldVal, body))
+ return "linalg.pooling_nchw_sum";
+ }
+ // pooling_nhwc_max
+ // pooling_nhwc_min
+ // pooling_nhwc_sum
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
+ if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
+ return "linalg.pooling_nhwc_max";
+ if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
+ return "linalg.pooling_nhwc_min";
+ if (bodyMatcherForSumPoolOps(yieldVal, body))
+ return "linalg.pooling_nhwc_sum";
+ }
+ // pooling_nhwc_max_unsigned
+ // pooling_nhwc_min_unsigned
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
+ if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
+ return "linalg.pooling_nhwc_max_unsigned";
+ if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
+ return "linalg.pooling_nhwc_max_unsigned";
+ }
return "";
}
@@ -510,6 +625,28 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
(getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
return "linalg.depthwise_conv_3d_ndhwc_dhwc";
+
+ Block *body = genericOp.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ // pooling_ndhwc_max
+ // pooling_ndhwc_min
+ // pooling_ndhwc_sum
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+ // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) {
+ if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
+ return "linalg.pooling_ndhwc_max";
+ if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
+ return "linalg.pooling_ndhwc_min";
+ if (bodyMatcherForSumPoolOps(yieldVal, body))
+ return "linalg.pooling_ndhwc_sum";
+ }
return "";
}
@@ -639,6 +776,36 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwc") {
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_nchw_max") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNchwMaxOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_nchw_sum") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNchwSumOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_nhwc_max") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_nhwc_min") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMinOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_nhwc_sum") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcSumOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_nhwc_max_unsigned") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_nhwc_min_unsigned") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMinUnsignedOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_ncw_max") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNcwMaxOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_ncw_sum") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNcwSumOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_nwc_max") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcMaxOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_nwc_min") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcMinOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_nwc_sum") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcSumOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_ndhwc_max") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_ndhwc_min") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcMinOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.pooling_ndhwc_sum") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcSumOp>(genericOp, resultTypes, inputs, outputs);
}
return namedOp;
>From 8e65bc6fc1a75de626a62ad72784087b57f2cd65 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 8 Oct 2025 03:42:40 -0500
Subject: [PATCH 05/18] Concisely v1.0
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 199 ++++++++++++------
1 file changed, 130 insertions(+), 69 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 2efa410e4b855..01a5c3bebd146 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -273,14 +273,75 @@ static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+/// Returns result `dimIndex` of the affine map stored at
+/// `indexingMaps[mapIndex]`. With this patch, an out-of-range `dimIndex`
+/// yields a null AffineExpr instead of asserting. `mapIndex` itself is not
+/// range-checked.
static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
uint32_t mapIndex, uint32_t dimIndex) {
auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
- return affineMap.getResult(dimIndex);
+ // Lets callers probe maps of differing rank. A null result compares unequal
+ // to any non-null expr (but equal to another null expr).
+ if (dimIndex < affineMap.getNumResults())
+ return affineMap.getResult(dimIndex);
+ return nullptr;
+}
+
+/// Returns true if `expr` has one of these forms, where `c` is a constant:
+///   - `d_i` alone (implicit multiplier 1), or
+///   - `d_i * c` or `c * d_i`.
+/// On success, `dim` is set to the dimension expression and `constantValue`
+/// to the multiplier. On failure, both outputs are left untouched. A null
+/// `expr` does not match.
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+                                        int64_t &constantValue) {
+  if (auto dExpr = dyn_cast_if_present<AffineDimExpr>(expr)) {
+    dim = dExpr;
+    constantValue = 1;
+    return true;
+  }
+
+  auto mulExpr = dyn_cast_if_present<AffineBinaryOpExpr>(expr);
+  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+    return false;
+
+  // Accept the constant on either operand: normalize to `dim * constant`.
+  AffineExpr lhs = mulExpr.getLHS();
+  AffineExpr rhs = mulExpr.getRHS();
+  if (isa<AffineConstantExpr>(lhs))
+    std::swap(lhs, rhs);
+
+  auto dExpr = dyn_cast<AffineDimExpr>(lhs);
+  auto cst = dyn_cast<AffineConstantExpr>(rhs);
+  if (!dExpr || !cst)
+    return false;
+  dim = dExpr;
+  constantValue = cst.getValue();
+  return true;
+}
+
+/// Returns true if result `iDim` of the input map (first indexing map) is a
+/// sum of two terms, each of the form `d_i` or `d_i * c` (see
+/// isDimTimesConstantOrDimOnly). The two dimensions must equal, in either
+/// order, result `fDim` of the filter map (second) and result `oDim` of the
+/// output map (last). E.g. input `d1 * 2 + d4` matches filter `d4` and
+/// output `d1`. Out-of-range result indices never match.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                       unsigned fDim, unsigned oDim) {
+  // Need distinct input, filter and output maps.
+  if (indexingMaps.size() < 3)
+    return false;
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+  // getAffineMapDim returns a null expr for out-of-range results, and plain
+  // dyn_cast asserts on a null value.
+  auto addExpr = dyn_cast_if_present<AffineBinaryOpExpr>(inpExpr);
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr dim0, dim1;
+  // TODO(Abhishek-Varma): Use the extracted multipliers (presumably the
+  // strides/dilations) when building the named op.
+  int64_t c0, c1;
+  if (!isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) ||
+      !isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1))
+    return false;
+
+  AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+  AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+  return (dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr);
+}
+
+/// Returns true if result `aDim` of indexing map `aIndex` and result `bDim`
+/// of indexing map `bIndex` are the same non-null affine expression.
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+                                    unsigned aDim, unsigned bIndex,
+                                    unsigned bDim) {
+  AffineExpr aExpr = getAffineMapDim(indexingMaps, aIndex, aDim);
+  // getAffineMapDim returns null for out-of-range results. Two missing
+  // results must not count as a match.
+  return aExpr && aExpr == getAffineMapDim(indexingMaps, bIndex, bDim);
}
static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
if (indexingMaps.size() != 3) return "";
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
- if (getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0)))
+ if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0))
return "linalg.conv_1d";
return "";
}
@@ -295,7 +356,7 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))))
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2))
return "linalg.depthwise_conv_1d_ncw_cw";
// depthwise_conv_1d_nwc_wc
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
@@ -303,14 +364,14 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))))
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1))
return "linalg.depthwise_conv_1d_nwc_wc";
// conv_2d
// #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))))
+ if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1))
return "linalg.conv_2d";
Block *body = genericOp.getBlock();
@@ -323,7 +384,7 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
// #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2)))) {
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2)) {
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
return "linalg.pooling_ncw_max";
if (bodyMatcherForSumPoolOps(yieldVal, body))
@@ -336,7 +397,7 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
// #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
(getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) {
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
return "linalg.pooling_nwc_max";
@@ -357,7 +418,7 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
(getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
(getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 3)))
return "linalg.depthwise_conv_1d_nwc_wcm";
@@ -366,7 +427,7 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
(getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
(getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)))
return "linalg.conv_1d_nwc_wcf";
@@ -376,7 +437,7 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
return "linalg.conv_1d_ncw_fcw";
return "";
@@ -392,25 +453,25 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))))
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3))
return "linalg.depthwise_conv_2d_nchw_chw";
// depthwise_conv_2d_nhwc_hwc
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
return "linalg.depthwise_conv_2d_nhwc_hwc";
// conv_3d
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))))
+ if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2))
return "linalg.conv_3d";
Block *body = genericOp.getBlock();
@@ -423,8 +484,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 3)))) {
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3)) {
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
return "linalg.pooling_nchw_max";
if (bodyMatcherForSumPoolOps(yieldVal, body))
@@ -437,8 +498,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
return "linalg.pooling_nhwc_max";
@@ -453,13 +514,13 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
return "linalg.pooling_nhwc_max_unsigned";
if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
- return "linalg.pooling_nhwc_max_unsigned";
+ return "linalg.pooling_nhwc_min_unsigned";
}
return "";
}
@@ -474,8 +535,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
if (indexingMaps.size() == 3 &&
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
return "linalg.conv_2d_nhwc_fhwc";
@@ -484,8 +545,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
(getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
return "linalg.conv_2d_nhwc_hwcf";
@@ -496,8 +557,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
if (indexingMaps.size() == 3 &&
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
return "linalg.conv_2d_nchw_fchw";
// conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
@@ -508,8 +569,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
if (indexingMaps.size() == 5 &&
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
return "linalg.conv_2d_nhwc_fhwc_q";
@@ -522,8 +583,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
return "linalg.conv_2d_nchw_fchw_q";
// depthwise_conv_2d_nhwc_hwcm
@@ -532,8 +593,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
if (indexingMaps.size() == 3 &&
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
(getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
return "linalg.depthwise_conv_2d_nhwc_hwcm";
@@ -545,8 +606,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
if (indexingMaps.size() == 5 &&
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
(getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
@@ -564,8 +625,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
(getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 2)))
return "linalg.conv_2d_ngchw_fgchw";
// conv_2d_ngchw_gfchw
@@ -576,8 +637,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
(getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
(getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
return "linalg.conv_2d_ngchw_gfchw";
// conv_2d_ngchw_gfchw_q
@@ -590,8 +651,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
(getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
(getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
return "linalg.conv_2d_ngchw_gfchw_q";
// conv_2d_nhwgc_gfhwc
@@ -599,8 +660,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
(getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) &&
(getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4)))
@@ -611,18 +672,18 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 4))))
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4))
return "linalg.depthwise_conv_3d_ncdhw_cdhw";
// depthwise_conv_3d_ndhwc_dhwc
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
(getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
return "linalg.depthwise_conv_3d_ndhwc_dhwc";
@@ -636,9 +697,9 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
// #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
(getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) {
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
return "linalg.pooling_ndhwc_max";
@@ -660,9 +721,9 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
return "linalg.conv_3d_ncdhw_fcdhw";
// conv_3d_ndhwc_dhwcf
@@ -670,22 +731,22 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
- (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
+ (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
return "linalg.conv_3d_ndhwc_dhwcf";
// depthwise_conv_3d_ndhwc_dhwcm
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) &&
- (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5)))
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) &&
+ (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5)))
return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
return "";
}
>From aae7e048da5f9fa34e7eb1a0de7559e823e23866 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 8 Oct 2025 06:23:27 -0500
Subject: [PATCH 06/18] Concise v2.0
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 4 +
.../Dialect/Linalg/Transforms/Specialize.cpp | 178 ++++++++++--------
2 files changed, 100 insertions(+), 82 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..46f9f10789e6c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -114,6 +114,10 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
// Fusion / Tiling utilities
//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+// Fusion / Tiling utilities
+//===----------------------------------------------------------------------===//
+
/// The type of loops to be generated during tiling.
enum class LinalgTilingLoopType {
Loops = 0,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 01a5c3bebd146..a741cd126dd3b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -354,16 +354,18 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2))
return "linalg.depthwise_conv_1d_ncw_cw";
// depthwise_conv_1d_nwc_wc
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1))
return "linalg.depthwise_conv_1d_nwc_wc";
// conv_2d
@@ -382,8 +384,8 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
// #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2)) {
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
return "linalg.pooling_ncw_max";
@@ -396,9 +398,9 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
// #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) {
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2)) {
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
return "linalg.pooling_nwc_max";
if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
@@ -417,28 +419,29 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)) &&
- (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 3)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3))
return "linalg.depthwise_conv_1d_nwc_wcm";
// conv_1d_nwc_wcf
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
- (getAffineMapDim(indexingMaps, fIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2))
return "linalg.conv_1d_nwc_wcf";
// conv_1d_ncw_fcw
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
return "linalg.conv_1d_ncw_fcw";
return "";
}
@@ -451,8 +454,9 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3))
return "linalg.depthwise_conv_2d_nchw_chw";
@@ -460,10 +464,11 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3))
return "linalg.depthwise_conv_2d_nhwc_hwc";
// conv_3d
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
@@ -482,8 +487,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3)) {
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
@@ -497,10 +502,10 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) {
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
return "linalg.pooling_nhwc_max";
if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
@@ -513,10 +518,10 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) {
if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
return "linalg.pooling_nhwc_max_unsigned";
if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
@@ -534,32 +539,32 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
if (indexingMaps.size() == 3 &&
- (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
- (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3))
return "linalg.conv_2d_nhwc_fhwc";
// conv_2d_nhwc_hwcf
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
- (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3))
return "linalg.conv_2d_nhwc_hwcf";
// conv_2d_nchw_fchw
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
if (indexingMaps.size() == 3 &&
- (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
return "linalg.conv_2d_nchw_fchw";
// conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
@@ -568,11 +573,11 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
if (indexingMaps.size() == 5 &&
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
- (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
- (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 3)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3))
return "linalg.conv_2d_nhwc_fhwc_q";
// conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps)
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
@@ -581,22 +586,23 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
if (indexingMaps.size() == 5 &&
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
- (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
return "linalg.conv_2d_nchw_fchw_q";
// depthwise_conv_2d_nhwc_hwcm
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
if (indexingMaps.size() == 3 &&
- (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
- (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4))
return "linalg.depthwise_conv_2d_nhwc_hwcm";
// depthwise_conv_2d_nhwc_hwcm_q
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
@@ -605,11 +611,12 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
if (indexingMaps.size() == 5 &&
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
- (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
- (getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4))
return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
return "";
}
@@ -622,24 +629,26 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
- (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 2)))
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2))
return "linalg.conv_2d_ngchw_fgchw";
// conv_2d_ngchw_gfchw
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
if (indexingMaps.size() == 3 &&
- (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
- (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
+ matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2))
return "linalg.conv_2d_ngchw_gfchw";
// conv_2d_ngchw_gfchw_q
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
@@ -648,30 +657,33 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
if (indexingMaps.size() == 5 &&
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
- (getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
- (getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
- (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 2)))
+ matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2))
return "linalg.conv_2d_ngchw_gfchw_q";
// conv_2d_nhwgc_gfhwc
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
- (getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) &&
- (getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4))
return "linalg.conv_2d_nhwgc_gfhwc";
// depthwise_conv_3d_ncdhw_cdhw
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4))
@@ -680,11 +692,12 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4))
return "linalg.depthwise_conv_3d_ndhwc_dhwc";
Block *body = genericOp.getBlock();
@@ -696,11 +709,11 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
// #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) {
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)) {
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
return "linalg.pooling_ndhwc_max";
if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
@@ -719,34 +732,35 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
- (getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
- (getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
return "linalg.conv_3d_ncdhw_fcdhw";
// conv_3d_ndhwc_dhwcf
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
- (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4))
return "linalg.conv_3d_ndhwc_dhwcf";
// depthwise_conv_3d_ndhwc_dhwcm
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
- if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
+ if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
- (getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) &&
- (getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5)))
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5))
return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
return "";
}
>From bafdb41a8e539d77adcf6253b1631d22c69b2d45 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 8 Oct 2025 07:08:18 -0500
Subject: [PATCH 07/18] Start pulling out into separate APIs
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 9 +-
.../Dialect/Linalg/Transforms/Specialize.cpp | 34 +--
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 219 ++++++++++++++++++
3 files changed, 233 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 46f9f10789e6c..222b66ca51708 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -111,9 +111,16 @@ std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
//===----------------------------------------------------------------------===//
-// Fusion / Tiling utilities
+// Convolution matcher utilities
//===----------------------------------------------------------------------===//
+bool isaConv1DOp(LinalgOp op);
+bool isaConv1DNwcWcfOp(LinalgOp op);
+bool isaConv1DNcwFcwOp(LinalgOp op);
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op);
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op);
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op);
+
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index a741cd126dd3b..8b51b8f13ce0d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -338,35 +338,24 @@ bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned a
}
static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
- ArrayAttr indexingMaps = genericOp.getIndexingMaps();
- if (indexingMaps.size() != 3) return "";
- unsigned iIndex = 0, fIndex = 1, oIndex = 2;
- if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0))
- return "linalg.conv_1d";
+ if (isaConv1DOp(genericOp)) return "linalg.conv_1d";
return "";
}
static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
if (indexingMaps.size() != 3) return "";
- unsigned iIndex = 0, fIndex = 1, oIndex = 2;
// depthwise_conv_1d_ncw_cw
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2))
+ if (isaDepthwiseConv1DNcwCwOp(genericOp))
return "linalg.depthwise_conv_1d_ncw_cw";
// depthwise_conv_1d_nwc_wc
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1))
+ if (isaDepthwiseConv1DNwcWcOp(genericOp))
return "linalg.depthwise_conv_1d_nwc_wc";
// conv_2d
// #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
@@ -414,34 +403,23 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
+/// Returns the name of the linalg named convolution op (for example
+/// "linalg.conv_1d_nwc_wcf") whose indexing-map pattern `genericOp` matches,
+/// or "" if none of the patterns below match. Every pattern checked here uses
+/// five loop dims (d0..d4) and three indexing maps (input, filter, output);
+/// the first matching pattern wins.
static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
if (indexingMaps.size() != 3) return "";
- unsigned iIndex = 0, fIndex = 1, oIndex = 2;
// depthwise_conv_1d_nwc_wcm
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3))
+ if (isaDepthwiseConv1DNwcWcmOp(genericOp))
return "linalg.depthwise_conv_1d_nwc_wcm";
// conv_1d_nwc_wcf
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2))
+ if (isaConv1DNwcWcfOp(genericOp))
return "linalg.conv_1d_nwc_wcf";
// conv_1d_ncw_fcw
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
+ if (isaConv1DNcwFcwOp(genericOp))
return "linalg.conv_1d_ncw_fcw";
return "";
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 3593b5348d268..12f88caf08fc7 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,6 +240,225 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
return iteratorType == utils::IteratorType::reduction;
}
+// -------------------------------
+// ---------- CONV ---------------
+// -------------------------------
+
+/// Body matcher shared by the linalg.pool* matchers. Returns true when
+/// `yieldVal` is produced by an op of one of `OpTypes` whose first two
+/// operands are both block arguments. `body` is currently not inspected.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+  Operation *producer = yieldVal.getDefiningOp();
+  // isa_and_present tolerates a null producer (e.g. yieldVal is itself a
+  // block argument), so no separate null check is needed.
+  const bool producerMatches = (isa_and_present<OpTypes>(producer) || ...);
+  if (!producerMatches)
+    return false;
+  return isa<BlockArgument>(producer->getOperand(0)) &&
+         isa<BlockArgument>(producer->getOperand(1));
+}
+
+/// Max-pool body: arith.maximumf or arith.maxsi.
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+                                                                  body);
+}
+
+/// Unsigned max-pool body: arith.maximumf or arith.maxui.
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+                                                                  body);
+}
+
+/// Min-pool body: arith.minimumf or arith.minsi.
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+                                                                  body);
+}
+
+/// Unsigned min-pool body: arith.minimumf or arith.minui.
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+                                                                  body);
+}
+
+/// Sum-pool body: arith.addi or arith.addf.
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+}
+
+/// Returns result expression `dimIndex` of the affine map stored at position
+/// `mapIndex` of `indexingMaps`, or a null AffineExpr when that map has no
+/// result at `dimIndex`. `mapIndex` itself must be in range.
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+                                        uint32_t mapIndex, uint32_t dimIndex) {
+  AffineMap map = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+  if (dimIndex >= map.getNumResults())
+    return nullptr;
+  return map.getResult(dimIndex);
+}
+
+/// Decomposes `expr` as `dim * constantValue`. Accepted forms are a bare
+/// dimension (reported with constantValue = 1) or a Mul of a dimension and a
+/// constant, with the operands in either order. On success `dim` and
+/// `constantValue` are set and true is returned; otherwise returns false.
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+                                        int64_t &constantValue) {
+  // Bare dimension: implicitly scaled by 1.
+  if (isa<AffineDimExpr>(expr)) {
+    dim = expr;
+    constantValue = 1;
+    return true;
+  }
+
+  auto product = dyn_cast<AffineBinaryOpExpr>(expr);
+  if (!product || product.getKind() != AffineExprKind::Mul)
+    return false;
+
+  // Succeeds when `maybeDim` is a dimension and `maybeCst` a constant.
+  auto tryOrder = [&](AffineExpr maybeDim, AffineExpr maybeCst) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(maybeDim);
+    auto cstExpr = dyn_cast<AffineConstantExpr>(maybeCst);
+    if (!dimExpr || !cstExpr)
+      return false;
+    dim = dimExpr;
+    constantValue = cstExpr.getValue();
+    return true;
+  };
+  // Try (dim * cst) first, then (cst * dim).
+  return tryOrder(product.getLHS(), product.getRHS()) ||
+         tryOrder(product.getRHS(), product.getLHS());
+}
+
+/// Returns true if result `iDim` of the input map (map 0) is a sum of two
+/// terms, each a bare dimension or a dimension times a constant (see
+/// isDimTimesConstantOrDimOnly). Those two dimensions must be, in either
+/// order, result `fDim` of the filter map (map 1) and result `oDim` of the
+/// output map (the last map). Returns false when fewer than three maps are
+/// given or when any referenced result index is out of range.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                       unsigned fDim, unsigned oDim) {
+  // Input, filter and output maps are all required; with fewer maps
+  // `size() - 1` would alias the filter map or underflow.
+  if (indexingMaps.size() < 3)
+    return false;
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+
+  // getAffineMapDim returns a null expr for an out-of-range result index, and
+  // plain dyn_cast asserts on a null value.
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+  auto addExpr = dyn_cast_if_present<AffineBinaryOpExpr>(inpExpr);
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr dim0, dim1;
+  // TODO(Abhishek-Varma): Use this information in specialize.cpp.
+  int64_t c0, c1;
+  if (!isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) ||
+      !isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1))
+    return false;
+
+  AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+  AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+  if (!fExpr || !oExpr)
+    return false;
+  return (dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr);
+}
+
+/// Returns true if result `aDim` of map `aIndex` and result `bDim` of map
+/// `bIndex` are the same affine expression. An out-of-range result index
+/// yields a null expression, which never matches.
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+                                    unsigned aDim, unsigned bIndex,
+                                    unsigned bDim) {
+  AffineExpr aExpr = getAffineMapDim(indexingMaps, aIndex, aDim);
+  // Without the null check, two out-of-range lookups (null == null) would be
+  // reported as a match.
+  return aExpr && aExpr == getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Returns true iff `indexingMaps` holds exactly `expectedSizes.size()` maps
+/// and the i-th map has exactly `expectedSizes[i]` results.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+                                       ArrayRef<int64_t> expectedSizes) {
+  if (indexingMaps.size() != expectedSizes.size())
+    return false;
+  return llvm::all_of(
+      llvm::zip_equal(indexingMaps, expectedSizes), [](auto entry) {
+        auto [mapAttr, numResults] = entry;
+        return cast<AffineMapAttr>(mapAttr).getValue().getNumResults() ==
+               numResults;
+      });
+}
+
+/// Returns true if `op` is a linalg.conv_1d, or a linalg op satisfying
+/// isaConvolutionOpInterface whose three indexing maps have one result each
+/// and match
+///   input:  (d0, d1) -> (d0 + d1)
+///   filter: (d0, d1) -> (d1)
+///   output: (d0, d1) -> (d0)
+/// (the summed terms may be constant-scaled and appear in either order).
+bool isaConv1DOp(LinalgOp op) {
+  if (isa<linalg::Conv1DOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  const bool shapeOk = verifyConvIndexingMapSizes(maps, {1, 1, 1});
+  return shapeOk &&
+         matchConvDimAddExprPattern(maps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0);
+}
+
+/// Returns true if `op` is a linalg.conv_1d_nwc_wcf, or a linalg op satisfying
+/// isaConvolutionOpInterface with three 3-result indexing maps matching
+///   input:  (d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)
+///   filter: (d0, d1, d2, d3, d4) -> (d3, d4, d2)
+///   output: (d0, d1, d2, d3, d4) -> (d0, d1, d2)
+bool isaConv1DNwcWcfOp(LinalgOp op) {
+  if (isa<linalg::Conv1DNwcWcfOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(maps, {3, 3, 3}))
+    return false;
+
+  constexpr unsigned kIn = 0, kFilter = 1, kOut = 2;
+  return matchConvDimExprPattern(maps, kIn, 0, kOut, 0) && // n
+         matchConvDimAddExprPattern(maps, /*iDim=*/1, /*fDim=*/0,
+                                    /*oDim=*/1) && // w
+         matchConvDimExprPattern(maps, kIn, 2, kFilter, 1) && // c
+         matchConvDimExprPattern(maps, kFilter, 2, kOut, 2);  // f
+}
+
+/// Returns true if `op` is a linalg.conv_1d_ncw_fcw, or a linalg op satisfying
+/// isaConvolutionOpInterface with three 3-result indexing maps matching
+///   input:  (d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)
+///   filter: (d0, d1, d2, d3, d4) -> (d1, d3, d4)
+///   output: (d0, d1, d2, d3, d4) -> (d0, d1, d2)
+bool isaConv1DNcwFcwOp(LinalgOp op) {
+  if (isa<linalg::Conv1DNcwFcwOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(maps, {3, 3, 3}))
+    return false;
+
+  constexpr unsigned kIn = 0, kFilter = 1, kOut = 2;
+  return matchConvDimExprPattern(maps, kIn, 0, kOut, 0) &&    // n
+         matchConvDimExprPattern(maps, kIn, 1, kFilter, 1) && // c
+         matchConvDimAddExprPattern(maps, /*iDim=*/2, /*fDim=*/2,
+                                    /*oDim=*/2) &&             // w
+         matchConvDimExprPattern(maps, kFilter, 0, kOut, 1);  // f
+}
+
+/// Returns true if `op` is a linalg.depthwise_conv_1d_ncw_cw, or a linalg op
+/// satisfying isaConvolutionOpInterface whose indexing maps have {3, 2, 3}
+/// results and match
+///   input:  (d0, d1, d2, d3) -> (d0, d2, d1 + d3)
+///   filter: (d0, d1, d2, d3) -> (d2, d3)
+///   output: (d0, d1, d2, d3) -> (d0, d2, d1)
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op) {
+  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(maps, {3, 2, 3}))
+    return false;
+
+  constexpr unsigned kIn = 0, kFilter = 1, kOut = 2;
+  return matchConvDimExprPattern(maps, kIn, 0, kOut, 0) &&    // n
+         matchConvDimExprPattern(maps, kIn, 1, kFilter, 0) && // c (filter)
+         matchConvDimExprPattern(maps, kIn, 1, kOut, 1) &&    // c (output)
+         matchConvDimAddExprPattern(maps, /*iDim=*/2, /*fDim=*/1,
+                                    /*oDim=*/2);              // w
+}
+
+/// Returns true if `op` is a linalg.depthwise_conv_1d_nwc_wc, or a linalg op
+/// satisfying isaConvolutionOpInterface whose indexing maps have {3, 2, 3}
+/// results and match
+///   input:  (d0, d1, d2, d3) -> (d0, d1 + d3, d2)
+///   filter: (d0, d1, d2, d3) -> (d3, d2)
+///   output: (d0, d1, d2, d3) -> (d0, d1, d2)
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op) {
+  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(maps, {3, 2, 3}))
+    return false;
+
+  constexpr unsigned kIn = 0, kFilter = 1, kOut = 2;
+  return matchConvDimExprPattern(maps, kIn, 0, kOut, 0) &&    // n
+         matchConvDimExprPattern(maps, kIn, 2, kFilter, 1) && // c (filter)
+         matchConvDimExprPattern(maps, kIn, 2, kOut, 2) &&    // c (output)
+         matchConvDimAddExprPattern(maps, /*iDim=*/1, /*fDim=*/0,
+                                    /*oDim=*/1);              // w
+}
+
+/// Returns true if `op` is a linalg.depthwise_conv_1d_nwc_wcm, or a linalg op
+/// satisfying isaConvolutionOpInterface whose indexing maps have {3, 3, 4}
+/// results and match
+///   input:  (d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)
+///   filter: (d0, d1, d2, d3, d4) -> (d4, d2, d3)
+///   output: (d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) {
+  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr maps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(maps, {3, 3, 4}))
+    return false;
+
+  constexpr unsigned kIn = 0, kFilter = 1, kOut = 2;
+  return matchConvDimExprPattern(maps, kIn, 0, kOut, 0) &&    // n
+         matchConvDimAddExprPattern(maps, /*iDim=*/1, /*fDim=*/0,
+                                    /*oDim=*/1) &&             // w
+         matchConvDimExprPattern(maps, kIn, 2, kFilter, 1) && // c (filter)
+         matchConvDimExprPattern(maps, kIn, 2, kOut, 2) &&    // c (output)
+         matchConvDimExprPattern(maps, kFilter, 2, kOut, 3);  // m
+}
+
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold,
ValueRange typeDynDims) {
>From a08247c4804f504535a061c77fe782e5b75991f4 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 8 Oct 2025 07:47:06 -0500
Subject: [PATCH 08/18] Add isaConv2D* matchers and use them in Specialize.cpp
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 14 +
.../Dialect/Linalg/Transforms/Specialize.cpp | 105 +------
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 277 ++++++++++++++++++
3 files changed, 306 insertions(+), 90 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 222b66ca51708..b4955625b6dec 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -120,6 +120,20 @@ bool isaConv1DNcwFcwOp(LinalgOp op);
bool isaDepthwiseConv1DNcwCwOp(LinalgOp op);
bool isaDepthwiseConv1DNwcWcOp(LinalgOp op);
bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op);
+bool isaConv2DOp(LinalgOp op);
+bool isaConv2DNhwcFhwcOp(LinalgOp op);
+bool isaConv2DNhwcHwcfOp(LinalgOp op);
+bool isaConv2DNchwFchwOp(LinalgOp op);
+bool isaConv2DNhwcFhwcQOp(LinalgOp op);
+bool isaConv2DNchwFchwQOp(LinalgOp op);
+bool isaConv2DNgchwFgchwOp(LinalgOp op);
+bool isaConv2DNgchwGfchwOp(LinalgOp op);
+bool isaConv2DNgchwGfchwQOp(LinalgOp op);
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op);
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 8b51b8f13ce0d..968370c05615a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -361,10 +361,10 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
- if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1))
+ if (isaConv2DOp(genericOp))
return "linalg.conv_2d";
+ unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
Block *body = genericOp.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
@@ -432,21 +432,13 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3))
+ if (isaDepthwiseConv2DNchwChwOp(genericOp))
return "linalg.depthwise_conv_2d_nchw_chw";
// depthwise_conv_2d_nhwc_hwc
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3))
+ if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
return "linalg.depthwise_conv_2d_nhwc_hwc";
// conv_3d
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
@@ -511,90 +503,50 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
+/// Returns the name of the linalg named convolution op whose indexing-map
+/// pattern `genericOp` matches, or "" if none of the patterns below match.
+/// Every pattern checked here uses seven loop dims (d0..d6); the first
+/// matching pattern wins.
static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
if (indexingMaps.size() < 3) return "";
+  // The non-quantized patterns below use 3 indexing maps; the *_q patterns
+  // use 5, where the third and fourth maps have no results.
- unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
// conv_2d_nhwc_fhwc
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- if (indexingMaps.size() == 3 &&
- matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3))
+ if (isaConv2DNhwcFhwcOp(genericOp))
return "linalg.conv_2d_nhwc_fhwc";
// conv_2d_nhwc_hwcf
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3))
+ if (isaConv2DNhwcHwcfOp(genericOp))
return "linalg.conv_2d_nhwc_hwcf";
// conv_2d_nchw_fchw
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- if (indexingMaps.size() == 3 &&
- matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
+ if (isaConv2DNchwFchwOp(genericOp))
return "linalg.conv_2d_nchw_fchw";
// conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc, but with 5 indexing maps; #map2 is used for both the third and fourth maps)
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- if (indexingMaps.size() == 5 &&
- (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3))
+ if (isaConv2DNhwcFhwcQOp(genericOp))
return "linalg.conv_2d_nhwc_fhwc_q";
// conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw, but with 5 indexing maps; #map2 is used for both the third and fourth maps)
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- if (indexingMaps.size() == 5 &&
- (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
+ if (isaConv2DNchwFchwQOp(genericOp))
return "linalg.conv_2d_nchw_fchw_q";
// depthwise_conv_2d_nhwc_hwcm
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
- if (indexingMaps.size() == 3 &&
- matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4))
+ if (isaDepthwiseConv2DNhwcHwcmOp(genericOp))
return "linalg.depthwise_conv_2d_nhwc_hwcm";
// depthwise_conv_2d_nhwc_hwcm_q
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
- if (indexingMaps.size() == 5 &&
- (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4))
+ if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp))
return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
return "";
}
@@ -607,53 +559,26 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2))
+ if (isaConv2DNgchwFgchwOp(genericOp))
return "linalg.conv_2d_ngchw_fgchw";
// conv_2d_ngchw_gfchw
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- if (indexingMaps.size() == 3 &&
- matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2))
+ if (isaConv2DNgchwGfchwOp(genericOp))
return "linalg.conv_2d_ngchw_gfchw";
// conv_2d_ngchw_gfchw_q
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- if (indexingMaps.size() == 5 &&
- (indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2))
+ if (isaConv2DNgchwGfchwQOp(genericOp))
return "linalg.conv_2d_ngchw_gfchw_q";
// conv_2d_nhwgc_gfhwc
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4))
+ if (isaConv2DNhwgcGfhwcOp(genericOp))
return "linalg.conv_2d_nhwgc_gfhwc";
// depthwise_conv_3d_ncdhw_cdhw
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 12f88caf08fc7..c5bb184c726f8 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -459,6 +459,283 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) {
matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3));
}
+/// Returns true if `op` is a linalg.conv_2d, or a linalg op satisfying
+/// isaConvolutionOpInterface whose three indexing maps have two results each
+/// and match
+///   input:  (d0, d1, d2, d3) -> (d0 + d2, d1 + d3)
+///   filter: (d0, d1, d2, d3) -> (d2, d3)
+///   output: (d0, d1, d2, d3) -> (d0, d1)
+/// (summed terms may be constant-scaled and appear in either order).
+bool isaConv2DOp(LinalgOp op) {
+  // Fast path: the named 2-D op itself (must not be Conv1DOp, which would
+  // misreport linalg.conv_1d as conv_2d).
+  if (isa<linalg::Conv2DOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {2, 2, 2}))
+    return false;
+
+  return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+                                    /*oDim=*/0) &&
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                    /*oDim=*/1);
+}
+
+/// Returns true if `op` is a linalg.conv_2d_nhwc_fhwc, or a linalg op
+/// satisfying isaConvolutionOpInterface with three 4-result indexing maps
+/// matching
+///   input:  (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)
+///   filter: (d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)
+///   output: (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)
+bool isaConv2DNhwcFhwcOp(LinalgOp op) {
+  // Fast path: the named op itself (previously Conv1DOp by mistake).
+  if (isa<linalg::Conv2DNhwcFhwcOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  return matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && // n
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                    /*oDim=*/1) && // h
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                    /*oDim=*/2) && // w
+         matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && // c
+         matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3);   // f
+}
+
+/// Returns true if `op` is a linalg.conv_2d_nhwc_hwcf, or a linalg op
+/// satisfying isaConvolutionOpInterface with three 4-result indexing maps
+/// matching
+///   input:  (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)
+///   filter: (d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)
+///   output: (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)
+bool isaConv2DNhwcHwcfOp(LinalgOp op) {
+  // Fast path: the named op itself (previously Conv1DOp by mistake).
+  if (isa<linalg::Conv2DNhwcHwcfOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  return matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && // n
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                    /*oDim=*/1) && // h
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                    /*oDim=*/2) && // w
+         matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) && // c
+         matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3);   // f
+}
+
+/// Returns true if `op` is a linalg.conv_2d_nchw_fchw, or a linalg op
+/// satisfying isaConvolutionOpInterface with three 4-result indexing maps
+/// matching
+///   input:  (d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)
+///   filter: (d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)
+///   output: (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)
+bool isaConv2DNchwFchwOp(LinalgOp op) {
+  // Fast path: the named op itself (previously Conv1DOp by mistake).
+  if (isa<linalg::Conv2DNchwFchwOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  return matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && // n
+         matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && // c
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                    /*oDim=*/2) && // h
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                    /*oDim=*/3) && // w
+         matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1); // f
+}
+
+/// Returns true if `op` is a linalg.conv_2d_nhwc_fhwc_q, or a linalg op
+/// satisfying isaConvolutionOpInterface with five indexing maps whose result
+/// counts are {4, 4, 0, 0, 4} and which match
+///   input:  (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)
+///   filter: (d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)
+///   maps 2 and 3: (d0, ..., d6) -> ()
+///   output: (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)
+bool isaConv2DNhwcFhwcQOp(LinalgOp op) {
+  // Fast path: the named op itself (previously Conv1DOp by mistake).
+  if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4}))
+    return false;
+
+  // The output map is the last (fifth) map.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  return matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && // n
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+                                    /*oDim=*/1) && // h
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                    /*oDim=*/2) && // w
+         matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) && // c
+         matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3);   // f
+}
+
+/// Returns true if `op` is a linalg.conv_2d_nchw_fchw_q, or a linalg op
+/// satisfying isaConvolutionOpInterface with five indexing maps whose result
+/// counts are {4, 4, 0, 0, 4} and which match
+///   input:  (d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)
+///   filter: (d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)
+///   maps 2 and 3: (d0, ..., d6) -> ()
+///   output: (d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)
+bool isaConv2DNchwFchwQOp(LinalgOp op) {
+  // Fast path: the named op itself (previously Conv1DOp by mistake).
+  if (isa<linalg::Conv2DNchwFchwQOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4}))
+    return false;
+
+  // The output map is the last (fifth) map.
+  unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+  return matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && // n
+         matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && // c
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+                                    /*oDim=*/2) && // h
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                    /*oDim=*/3) && // w
+         matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1); // f
+}
+
+/// Returns true if `op` is a linalg.conv_2d_ngchw_fgchw, or a linalg op
+/// satisfying isaConvolutionOpInterface with three 5-result indexing maps
+/// matching
+///   input:  (d0, ..., d7) -> (d0, d1, d5, d3 + d6, d4 + d7)
+///   filter: (d0, ..., d7) -> (d2, d1, d5, d6, d7)
+///   output: (d0, ..., d7) -> (d0, d1, d2, d3, d4)
+bool isaConv2DNgchwFgchwOp(LinalgOp op) {
+  // Fast path: the named op itself (previously Conv1DOp by mistake).
+  if (isa<linalg::Conv2DNgchwFgchwOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  return matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && // n
+         matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) && // g
+         matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && // g
+         matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && // c
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                    /*oDim=*/3) && // h
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+                                    /*oDim=*/4) && // w
+         matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2); // f
+}
+
+/// Returns true if `op` is a linalg.conv_2d_ngchw_gfchw, or a linalg op
+/// satisfying isaConvolutionOpInterface with three 5-result indexing maps
+/// matching
+///   input:  (d0, ..., d7) -> (d0, d1, d5, d3 + d6, d4 + d7)
+///   filter: (d0, ..., d7) -> (d1, d2, d5, d6, d7)
+///   output: (d0, ..., d7) -> (d0, d1, d2, d3, d4)
+bool isaConv2DNgchwGfchwOp(LinalgOp op) {
+  // Fast path: the named op itself (previously Conv1DOp by mistake).
+  if (isa<linalg::Conv2DNgchwGfchwOp>(op))
+    return true;
+  if (!isaConvolutionOpInterface(op))
+    return false;
+
+  // verifyConvIndexingMapSizes already guarantees exactly three maps, so no
+  // separate indexingMaps.size() check is needed below.
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+    return false;
+
+  unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+  return matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && // n
+         matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && // g
+         matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && // g
+         matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) && // c
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+                                    /*oDim=*/3) && // h
+         matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+                                    /*oDim=*/4) && // w
+         matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2); // f
+}
+
+/// Returns true if `op` is recognized as linalg.conv_2d_ngchw_gfchw_q, either
+/// through the isa<> fast path or by structure. The structural match requires
+/// that `op` passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {5,5,0,0,5}, and that every dimension
+/// check in the return expression holds. That list has five entries, so five
+/// maps are expected and the output map is last (oIndex = 4). The two
+/// rank-0 maps are presumably the input/filter zero points; both are the
+/// `()` map shown as #map2 below.
+/// matchConvDimExprPattern(maps, a, i, b, j) presumably requires result i of
+/// map a to equal result j of map b. matchConvDimAddExprPattern(maps, iDim,
+/// fDim, oDim) presumably requires input result iDim to be output result oDim
+/// plus filter result fDim. TODO confirm against their definitions.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; the later "Fix the NamedOp versions" patch in this series changes it
+/// to linalg::Conv2DNgchwGfchwQOp.
+bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
+
+ // The output map follows the two scalar zero-point maps, hence index 4.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ // Same dimension checks as isaConv2DNgchwGfchwOp; the zero-point maps are
+ // not inspected beyond their size.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
+}
+
+/// Returns true if `op` is recognized as linalg.conv_2d_nhwgc_gfhwc, either
+/// through the isa<> fast path or by structure. The structural match requires
+/// that `op` passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {5,5,5} (presumably the result count of the
+/// input/filter/output maps), and that every dimension check in the return
+/// expression holds.
+/// matchConvDimExprPattern(maps, a, i, b, j) presumably requires result i of
+/// map a to equal result j of map b. matchConvDimAddExprPattern(maps, iDim,
+/// fDim, oDim) presumably requires input result iDim to be output result oDim
+/// plus filter result fDim. TODO confirm against their definitions.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; the later "Fix the NamedOp versions" patch in this series changes it
+/// to linalg::Conv2DNhwgcGfhwcOp.
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ // Per the maps: input n,h,w,g,c = d0,d1+d5,d2+d6,d3,d7; filter
+ // g,f,h,w,c = d3,d4,d5,d6,d7; output n,h,w,g,f = d0,d1,d2,d3,d4.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+}
+
+/// Returns true if `op` is recognized as linalg.depthwise_conv_2d_nchw_chw,
+/// either through the isa<> fast path or by structure. The structural match
+/// requires that `op` passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {4,3,4} (presumably the result count of the
+/// input/filter/output maps), and that every dimension check in the return
+/// expression holds.
+/// matchConvDimExprPattern(maps, a, i, b, j) presumably requires result i of
+/// map a to equal result j of map b. matchConvDimAddExprPattern(maps, iDim,
+/// fDim, oDim) presumably requires input result iDim to be output result oDim
+/// plus filter result fDim. TODO confirm against their definitions.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; the later "Fix the NamedOp versions" patch in this series changes it
+/// to linalg::DepthwiseConv2DNchwChwOp.
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+ // Per the maps: input n,c,h,w = d0,d3,d1+d4,d2+d5; filter c,h,w = d3,d4,d5;
+ // output n,c,h,w = d0,d3,d1,d2. The channel d3 is shared by all three maps.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3));
+}
+
+/// Returns true if `op` is recognized as linalg.depthwise_conv_2d_nhwc_hwc,
+/// either through the isa<> fast path or by structure. The structural match
+/// requires that `op` passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {4,3,4} (presumably the result count of the
+/// input/filter/output maps), and that every dimension check in the return
+/// expression holds.
+/// matchConvDimExprPattern(maps, a, i, b, j) presumably requires result i of
+/// map a to equal result j of map b. matchConvDimAddExprPattern(maps, iDim,
+/// fDim, oDim) presumably requires input result iDim to be output result oDim
+/// plus filter result fDim. TODO confirm against their definitions.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; the later "Fix the NamedOp versions" patch in this series changes it
+/// to linalg::DepthwiseConv2DNhwcHwcOp.
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ // Per the maps: input n,h,w,c = d0,d1+d4,d2+d5,d3; filter h,w,c = d4,d5,d3;
+ // output n,h,w,c = d0,d1,d2,d3.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+}
+
+/// Returns true if `op` is recognized as linalg.depthwise_conv_2d_nhwc_hwcm,
+/// either through the isa<> fast path or by structure. The structural match
+/// requires that `op` passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {4,4,5} (presumably the result count of the
+/// input/filter/output maps), and that every dimension check in the return
+/// expression holds.
+/// matchConvDimExprPattern(maps, a, i, b, j) presumably requires result i of
+/// map a to equal result j of map b. matchConvDimAddExprPattern(maps, iDim,
+/// fDim, oDim) presumably requires input result iDim to be output result oDim
+/// plus filter result fDim. TODO confirm against their definitions.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; the later "Fix the NamedOp versions" patch in this series changes it
+/// to linalg::DepthwiseConv2DNhwcHwcmOp.
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,5})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+ // Per the maps: input n,h,w,c = d0,d1+d5,d2+d6,d3; filter h,w,c,m =
+ // d5,d6,d3,d4; output n,h,w,c,m = d0,d1,d2,d3,d4.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+}
+
+/// Returns true if `op` is recognized as linalg.depthwise_conv_2d_nhwc_hwcm_q,
+/// either through the isa<> fast path or by structure. The structural match
+/// requires that `op` passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {4,4,0,0,5}, and that every dimension
+/// check in the return expression holds. That list has five entries, so
+/// five maps are expected and the output map is last (oIndex = 4). The two
+/// rank-0 maps are presumably the input/filter zero points; both are the
+/// `()` map shown as #map2 below.
+/// matchConvDimExprPattern(maps, a, i, b, j) presumably requires result i of
+/// map a to equal result j of map b. matchConvDimAddExprPattern(maps, iDim,
+/// fDim, oDim) presumably requires input result iDim to be output result oDim
+/// plus filter result fDim. TODO confirm against their definitions.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; the later "Fix the NamedOp versions" patch in this series changes it
+/// to linalg::DepthwiseConv2DNhwcHwcmQOp.
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,5})) return false;
+
+ // The output map follows the two scalar zero-point maps, hence index 4.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+ // Same dimension checks as isaDepthwiseConv2DNhwcHwcmOp; the zero-point
+ // maps are not inspected beyond their size.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+}
+
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold,
ValueRange typeDynDims) {
>From 053d912f321b72cde08df10ef5cc9690248247cc Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 8 Oct 2025 07:59:55 -0500
Subject: [PATCH 09/18] Clean a bit
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 97 +------------------
1 file changed, 5 insertions(+), 92 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 968370c05615a..ea94b49946545 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -343,27 +343,14 @@ static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
}
static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
- ArrayAttr indexingMaps = genericOp.getIndexingMaps();
- if (indexingMaps.size() != 3) return "";
- // depthwise_conv_1d_ncw_cw
- // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
- // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
if (isaDepthwiseConv1DNcwCwOp(genericOp))
return "linalg.depthwise_conv_1d_ncw_cw";
- // depthwise_conv_1d_nwc_wc
- // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
- // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
if (isaDepthwiseConv1DNwcWcOp(genericOp))
return "linalg.depthwise_conv_1d_nwc_wc";
- // conv_2d
- // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
- // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
if (isaConv2DOp(genericOp))
return "linalg.conv_2d";
+ ArrayAttr indexingMaps = genericOp.getIndexingMaps();
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
Block *body = genericOp.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
@@ -401,45 +388,24 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
}
static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
- ArrayAttr indexingMaps = genericOp.getIndexingMaps();
- if (indexingMaps.size() != 3) return "";
- // depthwise_conv_1d_nwc_wcm
- // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
if (isaDepthwiseConv1DNwcWcmOp(genericOp))
return "linalg.depthwise_conv_1d_nwc_wcm";
- // conv_1d_nwc_wcf
- // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
if (isaConv1DNwcWcfOp(genericOp))
return "linalg.conv_1d_nwc_wcf";
- // conv_1d_ncw_fcw
- // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
if (isaConv1DNcwFcwOp(genericOp))
return "linalg.conv_1d_ncw_fcw";
return "";
}
static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
- ArrayAttr indexingMaps = genericOp.getIndexingMaps();
- if (indexingMaps.size() < 3) return "";
- unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
- // depthwise_conv_2d_nchw_chw
- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
if (isaDepthwiseConv2DNchwChwOp(genericOp))
return "linalg.depthwise_conv_2d_nchw_chw";
- // depthwise_conv_2d_nhwc_hwc
- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
return "linalg.depthwise_conv_2d_nhwc_hwc";
+
+ ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+ if (indexingMaps.size() < 3) return "";
+ unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
// conv_3d
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
@@ -501,83 +467,30 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
}
static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
- ArrayAttr indexingMaps = genericOp.getIndexingMaps();
- if (indexingMaps.size() < 3) return "";
- // conv_2d_nhwc_fhwc
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
if (isaConv2DNhwcFhwcOp(genericOp))
return "linalg.conv_2d_nhwc_fhwc";
- // conv_2d_nhwc_hwcf
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
if (isaConv2DNhwcHwcfOp(genericOp))
return "linalg.conv_2d_nhwc_hwcf";
- // conv_2d_nchw_fchw
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
if (isaConv2DNchwFchwOp(genericOp))
return "linalg.conv_2d_nchw_fchw";
- // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
if (isaConv2DNhwcFhwcQOp(genericOp))
return "linalg.conv_2d_nhwc_fhwc_q";
- // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps)
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
if (isaConv2DNchwFchwQOp(genericOp))
return "linalg.conv_2d_nchw_fchw_q";
- // depthwise_conv_2d_nhwc_hwcm
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
if (isaDepthwiseConv2DNhwcHwcmOp(genericOp))
return "linalg.depthwise_conv_2d_nhwc_hwcm";
- // depthwise_conv_2d_nhwc_hwcm_q
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp))
return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
return "";
}
static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
- ArrayAttr indexingMaps = genericOp.getIndexingMaps();
- if (indexingMaps.size() < 3) return "";
- unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
- // conv_2d_ngchw_fgchw
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
if (isaConv2DNgchwFgchwOp(genericOp))
return "linalg.conv_2d_ngchw_fgchw";
- // conv_2d_ngchw_gfchw
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
if (isaConv2DNgchwGfchwOp(genericOp))
return "linalg.conv_2d_ngchw_gfchw";
- // conv_2d_ngchw_gfchw_q
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
if (isaConv2DNgchwGfchwQOp(genericOp))
return "linalg.conv_2d_ngchw_gfchw_q";
- // conv_2d_nhwgc_gfhwc
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
if (isaConv2DNhwgcGfhwcOp(genericOp))
return "linalg.conv_2d_nhwgc_gfhwc";
// depthwise_conv_3d_ncdhw_cdhw
>From 87b91ee5602ff9b031f98a980f6283d517190092 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 9 Oct 2025 02:36:54 -0500
Subject: [PATCH 10/18] Add 3D APIs
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 6 +
.../Dialect/Linalg/Transforms/Specialize.cpp | 70 ++---------
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 117 ++++++++++++++++++
3 files changed, 132 insertions(+), 61 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index b4955625b6dec..ad5e0818b90f5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -134,6 +134,12 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
+bool isaConv3DOp(LinalgOp op);
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op);
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index ea94b49946545..6ecc6a024bed8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -406,13 +406,7 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
if (indexingMaps.size() < 3) return "";
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
- // conv_3d
- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
- if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2))
+ if (isaConv3DOp(genericOp))
return "linalg.conv_3d";
Block *body = genericOp.getBlock();
@@ -493,29 +487,14 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
return "linalg.conv_2d_ngchw_gfchw_q";
if (isaConv2DNhwgcGfhwcOp(genericOp))
return "linalg.conv_2d_nhwgc_gfhwc";
- // depthwise_conv_3d_ncdhw_cdhw
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4))
+ if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp))
return "linalg.depthwise_conv_3d_ncdhw_cdhw";
- // depthwise_conv_3d_ndhwc_dhwc
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4))
+ if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp))
return "linalg.depthwise_conv_3d_ndhwc_dhwc";
+ ArrayAttr indexingMaps = genericOp.getIndexingMaps();
+ if (indexingMaps.size() < 3) return "";
+ unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
Block *body = genericOp.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
@@ -541,42 +520,11 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
}
static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
- ArrayAttr indexingMaps = genericOp.getIndexingMaps();
- if (indexingMaps.size() < 3) return "";
- unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
- // conv_3d_ncdhw_fcdhw
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
+ if (isaConv3DNcdhwFcdhwOp(genericOp))
return "linalg.conv_3d_ncdhw_fcdhw";
- // conv_3d_ndhwc_dhwcf
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4))
+ if (isaConv3DNdhwcDhwcfOp(genericOp))
return "linalg.conv_3d_ndhwc_dhwcf";
- // depthwise_conv_3d_ndhwc_dhwcm
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5))
+ if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp))
return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
return "";
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index c5bb184c726f8..b3e79e8c1a409 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -736,6 +736,123 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
}
+/// Returns true if `op` is recognized as linalg.conv_3d, either through the
+/// isa<> fast path or by structure. The structural match requires that `op`
+/// passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {3,3,3} (presumably the result count of the
+/// input/filter/output maps), and that every input result is an
+/// output-plus-filter sum.
+/// matchConvDimAddExprPattern(maps, iDim, fDim, oDim) presumably requires
+/// input result iDim to be output result oDim plus filter result fDim.
+/// TODO confirm against its definition.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; the later "Fix the NamedOp versions" patch in this series changes it
+/// to linalg::Conv3DOp.
+bool isaConv3DOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
+
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
+ // No batch or channel dims: each of the three spatial dims is a plain sum.
+ return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2));
+}
+
+/// Returns true if `op` is recognized as linalg.conv_3d_ncdhw_fcdhw, either
+/// through the isa<> fast path or by structure. The structural match requires
+/// that `op` passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {5,5,5} (presumably the result count of the
+/// input/filter/output maps), and that every dimension check in the return
+/// expression holds.
+/// matchConvDimExprPattern(maps, a, i, b, j) presumably requires result i of
+/// map a to equal result j of map b. matchConvDimAddExprPattern(maps, iDim,
+/// fDim, oDim) presumably requires input result iDim to be output result oDim
+/// plus filter result fDim. TODO confirm against their definitions.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; the later "Fix the NamedOp versions" patch in this series changes it
+/// to linalg::Conv3DNcdhwFcdhwOp.
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+ // Per the maps: input n,c,d,h,w = d0,d5,d2+d6,d3+d7,d4+d8; filter
+ // f,c,d,h,w = d1,d5,d6,d7,d8; output n,f,d,h,w = d0,d1,d2,d3,d4.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+}
+
+/// Returns true if `op` is recognized as linalg.conv_3d_ndhwc_dhwcf, either
+/// through the isa<> fast path or by structure. The structural match requires
+/// that `op` passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {5,5,5} (presumably the result count of the
+/// input/filter/output maps), and that every dimension check in the return
+/// expression holds.
+/// matchConvDimExprPattern(maps, a, i, b, j) presumably requires result i of
+/// map a to equal result j of map b. matchConvDimAddExprPattern(maps, iDim,
+/// fDim, oDim) presumably requires input result iDim to be output result oDim
+/// plus filter result fDim. TODO confirm against their definitions.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; it should presumably test linalg::Conv3DNdhwcDhwcfOp, as the
+/// "Fix the NamedOp versions" patch does for the sibling matchers. Confirm.
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+ // Per the maps: input n,d,h,w,c = d0,d1+d5,d2+d6,d3+d7,d8; filter
+ // d,h,w,c,f = d5,d6,d7,d8,d4; output n,d,h,w,f = d0,d1,d2,d3,d4.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+}
+
+/// Returns true if `op` is recognized as linalg.depthwise_conv_3d_ndhwc_dhwcm,
+/// either through the isa<> fast path or by structure. The structural match
+/// requires that `op` passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {5,5,6} (presumably the result count of the
+/// input/filter/output maps), and that every dimension check in the return
+/// expression holds.
+/// matchConvDimExprPattern(maps, a, i, b, j) presumably requires result i of
+/// map a to equal result j of map b. matchConvDimAddExprPattern(maps, iDim,
+/// fDim, oDim) presumably requires input result iDim to be output result oDim
+/// plus filter result fDim. TODO confirm against their definitions.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; it should presumably test linalg::DepthwiseConv3DNdhwcDhwcmOp, as the
+/// "Fix the NamedOp versions" patch does for the sibling matchers. Confirm.
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,6})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
+ // Per the maps: input n,d,h,w,c = d0,d1+d5,d2+d6,d3+d7,d8; filter
+ // d,h,w,c,m = d5,d6,d7,d8,d4; output n,d,h,w,c,m = d0,d1,d2,d3,d8,d4.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+}
+
+/// Returns true if `op` is recognized as linalg.depthwise_conv_3d_ncdhw_cdhw,
+/// either through the isa<> fast path or by structure. The structural match
+/// requires that `op` passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {5,4,5} (presumably the result count of the
+/// input/filter/output maps), and that every dimension check in the return
+/// expression holds.
+/// matchConvDimExprPattern(maps, a, i, b, j) presumably requires result i of
+/// map a to equal result j of map b. matchConvDimAddExprPattern(maps, iDim,
+/// fDim, oDim) presumably requires input result iDim to be output result oDim
+/// plus filter result fDim. TODO confirm against their definitions.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; it should presumably test linalg::DepthwiseConv3DNcdhwCdhwOp, as the
+/// "Fix the NamedOp versions" patch does for the sibling matchers. Confirm.
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
+ // Per the maps: input n,c,d,h,w = d0,d7,d1+d4,d2+d5,d3+d6; filter
+ // c,d,h,w = d7,d4,d5,d6; output n,c,d,h,w = d0,d7,d1,d2,d3.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4));
+}
+
+/// Returns true if `op` is recognized as linalg.depthwise_conv_3d_ndhwc_dhwc,
+/// either through the isa<> fast path or by structure. The structural match
+/// requires that `op` passes isaConvolutionOpInterface, that its maps pass
+/// verifyConvIndexingMapSizes with {5,4,5} (presumably the result count of the
+/// input/filter/output maps), and that every dimension check in the return
+/// expression holds.
+/// matchConvDimExprPattern(maps, a, i, b, j) presumably requires result i of
+/// map a to equal result j of map b. matchConvDimAddExprPattern(maps, iDim,
+/// fDim, oDim) presumably requires input result iDim to be output result oDim
+/// plus filter result fDim. TODO confirm against their definitions.
+/// NOTE(review): the fast path tests linalg::Conv1DOp, so a named
+/// linalg.conv_1d is reported as a match here. This looks like a copy-paste
+/// slip; it should presumably test linalg::DepthwiseConv3DNdhwcDhwcOp, as the
+/// "Fix the NamedOp versions" patch does for the sibling matchers. Confirm.
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) {
+ if (isa<linalg::Conv1DOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
+ // Per the maps: input n,d,h,w,c = d0,d1+d4,d2+d5,d3+d6,d7; filter
+ // d,h,w,c = d4,d5,d6,d7; output n,d,h,w,c = d0,d1,d2,d3,d7.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4));
+}
+
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold,
ValueRange typeDynDims) {
>From e5acca4e6d5c76cfd00688d487c16f6082ecd3a7 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 9 Oct 2025 02:43:28 -0500
Subject: [PATCH 11/18] Fix the NamedOp versions
---
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 40 ++++++++++++-------------
1 file changed, 20 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index b3e79e8c1a409..2d6d51d858853 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -460,7 +460,7 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) {
}
bool isaConv2DOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv2DOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -475,7 +475,7 @@ bool isaConv2DOp(LinalgOp op) {
}
bool isaConv2DNhwcFhwcOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv2DNhwcFhwcOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -494,7 +494,7 @@ bool isaConv2DNhwcFhwcOp(LinalgOp op) {
}
bool isaConv2DNhwcHwcfOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv2DNhwcHwcfOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -513,7 +513,7 @@ bool isaConv2DNhwcHwcfOp(LinalgOp op) {
}
bool isaConv2DNchwFchwOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv2DNchwFchwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -532,7 +532,7 @@ bool isaConv2DNchwFchwOp(LinalgOp op) {
}
bool isaConv2DNhwcFhwcQOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv2DNhwcFhwcQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -552,7 +552,7 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op) {
}
bool isaConv2DNchwFchwQOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv2DNchwFchwQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -572,7 +572,7 @@ bool isaConv2DNchwFchwQOp(LinalgOp op) {
}
bool isaConv2DNgchwFgchwOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv2DNgchwFgchwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -593,7 +593,7 @@ bool isaConv2DNgchwFgchwOp(LinalgOp op) {
}
bool isaConv2DNgchwGfchwOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv2DNgchwGfchwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -615,7 +615,7 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op) {
}
bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv2DNgchwGfchwQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -637,7 +637,7 @@ bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
}
bool isaConv2DNhwgcGfhwcOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv2DNhwgcGfhwcOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -658,7 +658,7 @@ bool isaConv2DNhwgcGfhwcOp(LinalgOp op) {
}
bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::DepthwiseConv2DNchwChwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -677,7 +677,7 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) {
}
bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -696,7 +696,7 @@ bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) {
}
bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -716,7 +716,7 @@ bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) {
}
bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -737,7 +737,7 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
}
bool isaConv3DOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv3DOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -753,7 +753,7 @@ bool isaConv3DOp(LinalgOp op) {
}
bool isaConv3DNcdhwFcdhwOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv3DNcdhwFcdhwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -773,7 +773,7 @@ bool isaConv3DNcdhwFcdhwOp(LinalgOp op) {
}
bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv3DNdhwcDhwcfOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -793,7 +793,7 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
}
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -814,7 +814,7 @@ bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
}
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -834,7 +834,7 @@ bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) {
}
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
>From 535f7e95288274f104b5c6fda01012016613d3b3 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 9 Oct 2025 03:22:39 -0500
Subject: [PATCH 12/18] Pooling ops
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 15 +
.../Dialect/Linalg/Transforms/Specialize.cpp | 243 ++-----------
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 328 ++++++++++++++++++
3 files changed, 373 insertions(+), 213 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index ad5e0818b90f5..1a1b70d3eb979 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -140,6 +140,21 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);
+/// Pooling-op matchers. Each returns true if `op` is the corresponding named
+/// linalg pooling op, or an op (e.g. a linalg.generic) whose indexing maps and
+/// yielded combiner structurally match that named op.
+bool isaPoolingNchwMaxOp(LinalgOp op);
+bool isaPoolingNchwSumOp(LinalgOp op);
+bool isaPoolingNhwcMaxOp(LinalgOp op);
+bool isaPoolingNhwcMinOp(LinalgOp op);
+bool isaPoolingNhwcSumOp(LinalgOp op);
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op);
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op);
+bool isaPoolingNcwMaxOp(LinalgOp op);
+bool isaPoolingNcwSumOp(LinalgOp op);
+bool isaPoolingNwcMaxOp(LinalgOp op);
+bool isaPoolingNwcMinOp(LinalgOp op);
+bool isaPoolingNwcSumOp(LinalgOp op);
+bool isaPoolingNdhwcMaxOp(LinalgOp op);
+bool isaPoolingNdhwcMinOp(LinalgOp op);
+bool isaPoolingNdhwcSumOp(LinalgOp op);
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 6ecc6a024bed8..aef3a1480d289 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,106 +237,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
-/// Utility to match block body for linalg.pool* ops.
-template <typename... OpTypes>
-static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
- Operation *defOp = yieldVal.getDefiningOp();
- // if (!defOp) return false;
- if (!(isa_and_present<OpTypes>(defOp) || ...)) return false;
-
- BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
- BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
- if (!lhsArg || !rhsArg) return false;
- return true;
-}
-
-static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, body);
-}
-
-static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, body);
-}
-
-static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, body);
-}
-
-static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, body);
-}
-
-static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
-}
-
-static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
- uint32_t mapIndex, uint32_t dimIndex) {
- auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
- if (dimIndex < affineMap.getNumResults())
- return affineMap.getResult(dimIndex);
- return nullptr;
-}
-
-// Check if `expr` is either:
-// - a dimension expr alone (implying *1), or
-// - a multiplication of dimension expr by constant.
-bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_t &constantValue) {
- if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
- dim = dExpr;
- constantValue = 1;
- return true;
- }
-
- auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
- if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
- return false;
-
- AffineExpr lhs = mulExpr.getLHS();
- AffineExpr rhs = mulExpr.getRHS();
-
- if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
- if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
- dim = dExpr;
- constantValue = cst.getValue();
- return true;
- }
- }
- if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
- if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
- dim = dExpr;
- constantValue = cst.getValue();
- return true;
- }
- }
- return false;
-}
-
-bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) {
- unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
- AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
- auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
- if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
- return false;
-
- AffineExpr dim0, dim1;
- // TODO(Abhishek-Varma): Use this information in specialize.cpp.
- int64_t c0, c1;
-
- if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
- isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
- // Pattern matched with dims and constants extracted.
- AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
- AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
- return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr));
- }
- return false;
-}
-
-bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
- return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim);
-}
-
static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
if (isaConv1DOp(genericOp)) return "linalg.conv_1d";
return "";
@@ -349,41 +249,16 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
return "linalg.depthwise_conv_1d_nwc_wc";
if (isaConv2DOp(genericOp))
return "linalg.conv_2d";
-
- ArrayAttr indexingMaps = genericOp.getIndexingMaps();
- unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
- Block *body = genericOp.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
- // pooling_ncw_max
- // pooling_ncw_sum
- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
- // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
- // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2)) {
- if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
- return "linalg.pooling_ncw_max";
- if (bodyMatcherForSumPoolOps(yieldVal, body))
- return "linalg.pooling_ncw_sum";
- }
- // pooling_nwc_max
- // pooling_nwc_min
- // pooling_nwc_sum
- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
- // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
- // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2)) {
- if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
- return "linalg.pooling_nwc_max";
- if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
- return "linalg.pooling_nwc_min";
- if (bodyMatcherForSumPoolOps(yieldVal, body))
- return "linalg.pooling_nwc_sum";
- }
+ if (isaPoolingNcwMaxOp(genericOp))
+ return "linalg.pooling_ncw_max";
+ if (isaPoolingNcwSumOp(genericOp))
+ return "linalg.pooling_ncw_sum";
+ if (isaPoolingNwcMaxOp(genericOp))
+ return "linalg.pooling_nwc_max";
+ if (isaPoolingNwcMinOp(genericOp))
+ return "linalg.pooling_nwc_min";
+ if (isaPoolingNwcSumOp(genericOp))
+ return "linalg.pooling_nwc_sum";
return "";
}
@@ -402,61 +277,22 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
return "linalg.depthwise_conv_2d_nchw_chw";
if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
return "linalg.depthwise_conv_2d_nhwc_hwc";
-
- ArrayAttr indexingMaps = genericOp.getIndexingMaps();
- if (indexingMaps.size() < 3) return "";
- unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
if (isaConv3DOp(genericOp))
return "linalg.conv_3d";
-
- Block *body = genericOp.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
- // pooling_nchw_max
- // pooling_nchw_sum
- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3)) {
- if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
- return "linalg.pooling_nchw_max";
- if (bodyMatcherForSumPoolOps(yieldVal, body))
- return "linalg.pooling_nchw_sum";
- }
- // pooling_nhwc_max
- // pooling_nhwc_min
- // pooling_nhwc_sum
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
- // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) {
- if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
- return "linalg.pooling_nhwc_max";
- if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
- return "linalg.pooling_nhwc_min";
- if (bodyMatcherForSumPoolOps(yieldVal, body))
- return "linalg.pooling_nhwc_sum";
- }
- // pooling_nhwc_max_unsigned
- // pooling_nhwc_min_unsigned
- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) {
- if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
- return "linalg.pooling_nhwc_max_unsigned";
- if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
- return "linalg.pooling_nhwc_min_unsigned";
- }
+ if (isaPoolingNchwMaxOp(genericOp))
+ return "linalg.pooling_nchw_max";
+ if (isaPoolingNchwSumOp(genericOp))
+ return "linalg.pooling_nchw_sum";
+ if (isaPoolingNhwcMaxOp(genericOp))
+ return "linalg.pooling_nhwc_max";
+ if (isaPoolingNhwcMinOp(genericOp))
+ return "linalg.pooling_nhwc_min";
+ if (isaPoolingNhwcSumOp(genericOp))
+ return "linalg.pooling_nhwc_sum";
+ if (isaPoolingNhwcMaxUnsignedOp(genericOp))
+ return "linalg.pooling_nhwc_max_unsigned";
+ if (isaPoolingNhwcMinUnsignedOp(genericOp))
+ return "linalg.pooling_nhwc_min_unsigned";
return "";
}
@@ -491,31 +327,12 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
return "linalg.depthwise_conv_3d_ncdhw_cdhw";
if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp))
return "linalg.depthwise_conv_3d_ndhwc_dhwc";
-
- ArrayAttr indexingMaps = genericOp.getIndexingMaps();
- if (indexingMaps.size() < 3) return "";
- unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
- Block *body = genericOp.getBlock();
- auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
- Value yieldVal = yieldOp.getOperand(0);
- // pooling_ndhwc_max
- // pooling_ndhwc_min
- // pooling_ndhwc_sum
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
- // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)) {
- if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
- return "linalg.pooling_ndhwc_max";
- if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
- return "linalg.pooling_ndhwc_min";
- if (bodyMatcherForSumPoolOps(yieldVal, body))
- return "linalg.pooling_ndhwc_sum";
- }
+ if (isaPoolingNdhwcMaxOp(genericOp))
+ return "linalg.pooling_ndhwc_max";
+ if (isaPoolingNdhwcMinOp(genericOp))
+ return "linalg.pooling_ndhwc_min";
+ if (isaPoolingNdhwcSumOp(genericOp))
+ return "linalg.pooling_ndhwc_sum";
return "";
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 2d6d51d858853..127e61d7db050 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -853,6 +853,334 @@ bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) {
matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4));
}
+/// Returns true if `op` is a `linalg.pooling_nchw_max` op, or an op accepted by
+/// `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {4, 2, 4};
+/// - the maps follow the NCHW input/window/output layout sketched in the
+///   comments below;
+/// - the yielded value is accepted by `bodyMatcherForMaxSignedPoolOps`
+///   (presumably arith.maximumf / arith.maxsi -- confirm against its
+///   definition).
+/// NOTE(review): iterator types are not checked here. The map checks are
+/// identical to isaPoolingNchwSumOp's, so only the body matcher separates the
+/// two. Apart from the map/body checks, this body duplicates the other
+/// isaPooling* matchers and is a candidate for a shared helper.
+bool isaPoolingNchwMaxOp(LinalgOp op) {
+ if (isa<linalg::PoolingNchwMaxOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) &&
+ bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_nchw_sum` op, or an op accepted by
+/// `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {4, 2, 4};
+/// - the maps follow the NCHW layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForSumPoolOps` (presumably
+///   arith.addf / arith.addi -- confirm against its definition).
+/// NOTE(review): iterator types are not checked here. The map checks are
+/// identical to isaPoolingNchwMaxOp's, so only the body matcher separates them.
+bool isaPoolingNchwSumOp(LinalgOp op) {
+ if (isa<linalg::PoolingNchwSumOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) &&
+ bodyMatcherForSumPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_nhwc_max` op, or an op accepted by
+/// `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {4, 2, 4};
+/// - the maps follow the NHWC layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForMaxSignedPoolOps`
+///   (presumably arith.maximumf / arith.maxsi -- confirm).
+/// NOTE(review): iterator types are not checked here. The same map checks are
+/// used by the NHWC min/sum and *Unsigned matchers; only the body matcher
+/// separates them.
+bool isaPoolingNhwcMaxOp(LinalgOp op) {
+ if (isa<linalg::PoolingNhwcMaxOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_nhwc_min` op, or an op accepted by
+/// `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {4, 2, 4};
+/// - the maps follow the NHWC layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForMinSignedPoolOps`
+///   (presumably arith.minimumf / arith.minsi -- confirm).
+/// NOTE(review): iterator types are not checked here; only the body matcher
+/// distinguishes this from the other NHWC pooling matchers.
+bool isaPoolingNhwcMinOp(LinalgOp op) {
+ if (isa<linalg::PoolingNhwcMinOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMinSignedPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_nhwc_sum` op, or an op accepted by
+/// `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {4, 2, 4};
+/// - the maps follow the NHWC layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForSumPoolOps` (presumably
+///   arith.addf / arith.addi -- confirm).
+/// NOTE(review): iterator types are not checked here; only the body matcher
+/// distinguishes this from the other NHWC pooling matchers.
+bool isaPoolingNhwcSumOp(LinalgOp op) {
+ if (isa<linalg::PoolingNhwcSumOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForSumPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_nhwc_max_unsigned` op, or an op
+/// accepted by `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {4, 2, 4};
+/// - the maps follow the NHWC layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForMaxUnsignedPoolOps`.
+/// NOTE(review): if that body matcher also accepts arith.maximumf (presumably
+/// it does, alongside arith.maxui -- confirm), then a floating-point op that
+/// matches here also matches isaPoolingNhwcMaxOp. Callers testing both must
+/// decide which takes precedence. Iterator types are not checked here.
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op) {
+ if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_nhwc_min_unsigned` op, or an op
+/// accepted by `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {4, 2, 4};
+/// - the maps follow the NHWC layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForMinUnsignedPoolOps`.
+/// NOTE(review): if that body matcher also accepts arith.minimumf (presumably
+/// it does, alongside arith.minui -- confirm), then a floating-point op that
+/// matches here also matches isaPoolingNhwcMinOp. Callers testing both must
+/// decide which takes precedence. Iterator types are not checked here.
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op) {
+ if (isa<linalg::PoolingNhwcMinUnsignedOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMinUnsignedPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_ncw_max` op, or an op accepted by
+/// `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {3, 1, 3};
+/// - the maps follow the NCW layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForMaxSignedPoolOps`
+///   (presumably arith.maximumf / arith.maxsi -- confirm).
+/// NOTE(review): iterator types are not checked here. The map checks are
+/// identical to isaPoolingNcwSumOp's, so only the body matcher separates them.
+bool isaPoolingNcwMaxOp(LinalgOp op) {
+ if (isa<linalg::PoolingNcwMaxOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+ // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+ bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_ncw_sum` op, or an op accepted by
+/// `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {3, 1, 3};
+/// - the maps follow the NCW layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForSumPoolOps` (presumably
+///   arith.addf / arith.addi -- confirm).
+/// NOTE(review): iterator types are not checked here. The map checks are
+/// identical to isaPoolingNcwMaxOp's, so only the body matcher separates them.
+bool isaPoolingNcwSumOp(LinalgOp op) {
+ if (isa<linalg::PoolingNcwSumOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+ // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+ bodyMatcherForSumPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_nwc_max` op, or an op accepted by
+/// `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {3, 1, 3};
+/// - the maps follow the NWC layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForMaxSignedPoolOps`
+///   (presumably arith.maximumf / arith.maxsi -- confirm).
+/// NOTE(review): iterator types are not checked here; only the body matcher
+/// distinguishes this from the other NWC pooling matchers.
+bool isaPoolingNwcMaxOp(LinalgOp op) {
+ if (isa<linalg::PoolingNwcMaxOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+ // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_nwc_min` op, or an op accepted by
+/// `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {3, 1, 3};
+/// - the maps follow the NWC layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForMinSignedPoolOps`
+///   (presumably arith.minimumf / arith.minsi -- confirm).
+/// NOTE(review): iterator types are not checked here; only the body matcher
+/// distinguishes this from the other NWC pooling matchers.
+bool isaPoolingNwcMinOp(LinalgOp op) {
+ if (isa<linalg::PoolingNwcMinOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+ // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ bodyMatcherForMinSignedPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_nwc_sum` op, or an op accepted by
+/// `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {3, 1, 3};
+/// - the maps follow the NWC layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForSumPoolOps` (presumably
+///   arith.addf / arith.addi -- confirm).
+/// NOTE(review): iterator types are not checked here; only the body matcher
+/// distinguishes this from the other NWC pooling matchers.
+bool isaPoolingNwcSumOp(LinalgOp op) {
+ if (isa<linalg::PoolingNwcSumOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+ // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ bodyMatcherForSumPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_ndhwc_max` op, or an op accepted
+/// by `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {5, 3, 5};
+/// - the maps follow the NDHWC layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForMaxSignedPoolOps`
+///   (presumably arith.maximumf / arith.maxsi -- confirm).
+/// NOTE(review): iterator types are not checked here; only the body matcher
+/// distinguishes this from the other NDHWC pooling matchers.
+bool isaPoolingNdhwcMaxOp(LinalgOp op) {
+ if (isa<linalg::PoolingNdhwcMaxOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+ // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_ndhwc_min` op, or an op accepted
+/// by `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {5, 3, 5};
+/// - the maps follow the NDHWC layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForMinSignedPoolOps`
+///   (presumably arith.minimumf / arith.minsi -- confirm).
+/// NOTE(review): iterator types are not checked here; only the body matcher
+/// distinguishes this from the other NDHWC pooling matchers.
+bool isaPoolingNdhwcMinOp(LinalgOp op) {
+ if (isa<linalg::PoolingNdhwcMinOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+ // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ bodyMatcherForMinSignedPoolOps(yieldVal, body));
+}
+
+/// Returns true if `op` is a `linalg.pooling_ndhwc_sum` op, or an op accepted
+/// by `isaConvolutionOpInterface` that matches it structurally:
+/// - its indexing maps pass `verifyConvIndexingMapSizes` with {5, 3, 5};
+/// - the maps follow the NDHWC layout sketched in the comments below;
+/// - the yielded value is accepted by `bodyMatcherForSumPoolOps` (presumably
+///   arith.addf / arith.addi -- confirm).
+/// NOTE(review): iterator types are not checked here; only the body matcher
+/// distinguishes this from the other NDHWC pooling matchers.
+bool isaPoolingNdhwcSumOp(LinalgOp op) {
+ if (isa<linalg::PoolingNdhwcSumOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+ // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ bodyMatcherForSumPoolOps(yieldVal, body));
+}
+
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold,
ValueRange typeDynDims) {
>From b06ba750c040a5b2506271ccf96d60fceb02cdef Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 9 Oct 2025 03:30:56 -0500
Subject: [PATCH 13/18] Updated maps
---
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 66 ++++++++++++-------------
1 file changed, 33 insertions(+), 33 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 127e61d7db050..e847f2cf3aef2 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -909,9 +909,9 @@ bool isaPoolingNhwcMaxOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
- // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
@@ -931,9 +931,9 @@ bool isaPoolingNhwcMinOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
- // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
@@ -953,9 +953,9 @@ bool isaPoolingNhwcSumOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
- // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
@@ -1019,9 +1019,9 @@ bool isaPoolingNcwMaxOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
- // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
- // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
@@ -1040,9 +1040,9 @@ bool isaPoolingNcwSumOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
- // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
- // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
@@ -1061,9 +1061,9 @@ bool isaPoolingNwcMaxOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
- // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
- // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
@@ -1082,9 +1082,9 @@ bool isaPoolingNwcMinOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
- // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
- // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
@@ -1103,9 +1103,9 @@ bool isaPoolingNwcSumOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
- // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
- // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
@@ -1124,9 +1124,9 @@ bool isaPoolingNdhwcMaxOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
- // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
@@ -1147,9 +1147,9 @@ bool isaPoolingNdhwcMinOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
- // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
@@ -1170,9 +1170,9 @@ bool isaPoolingNdhwcSumOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
- // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
>From 5aeb3716ff7ddb376bb4a9db962431dba1b55b56 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 9 Oct 2025 04:42:10 -0500
Subject: [PATCH 14/18] Missing ops
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 4 +
.../Dialect/Linalg/Transforms/Specialize.cpp | 16 ++++
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 83 +++++++++++++++++++
3 files changed, 103 insertions(+)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 1a1b70d3eb979..2f7868bd55182 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -128,15 +128,19 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op);
bool isaConv2DNchwFchwQOp(LinalgOp op);
bool isaConv2DNgchwFgchwOp(LinalgOp op);
bool isaConv2DNgchwGfchwOp(LinalgOp op);
+bool isaConv2DNhwcHwcfQOp(LinalgOp op);
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op);
bool isaConv2DNgchwGfchwQOp(LinalgOp op);
bool isaConv2DNhwgcGfhwcOp(LinalgOp op);
bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op);
bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
bool isaConv3DOp(LinalgOp op);
bool isaConv3DNcdhwFcdhwOp(LinalgOp op);
bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op);
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index aef3a1480d289..031cb3b919b96 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -277,6 +277,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
return "linalg.depthwise_conv_2d_nchw_chw";
if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
return "linalg.depthwise_conv_2d_nhwc_hwc";
+ if (isaDepthwiseConv2DNhwcHwcQOp(genericOp))
+ return "linalg.depthwise_conv_2d_nhwc_hwc_q";
if (isaConv3DOp(genericOp))
return "linalg.conv_3d";
if (isaPoolingNchwMaxOp(genericOp))
@@ -307,6 +309,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
return "linalg.conv_2d_nhwc_fhwc_q";
if (isaConv2DNchwFchwQOp(genericOp))
return "linalg.conv_2d_nchw_fchw_q";
+ if (isaConv2DNhwcHwcfQOp(genericOp))
+ return "linalg.conv_2d_nhwc_hwcf_q";
if (isaDepthwiseConv2DNhwcHwcmOp(genericOp))
return "linalg.depthwise_conv_2d_nhwc_hwcm";
if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp))
@@ -323,6 +327,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
return "linalg.conv_2d_ngchw_gfchw_q";
if (isaConv2DNhwgcGfhwcOp(genericOp))
return "linalg.conv_2d_nhwgc_gfhwc";
+ if (isaConv2DNhwgcGfhwcQOp(genericOp))
+ return "linalg.conv_2d_nhwgc_gfhwc_q";
if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp))
return "linalg.depthwise_conv_3d_ncdhw_cdhw";
if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp))
@@ -341,6 +347,8 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
return "linalg.conv_3d_ncdhw_fcdhw";
if (isaConv3DNdhwcDhwcfOp(genericOp))
return "linalg.conv_3d_ndhwc_dhwcf";
+ if (isaConv3DNdhwcDhwcfQOp(genericOp))
+ return "linalg.conv_3d_ndhwc_dhwcf_q";
if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp))
return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
return "";
@@ -412,6 +420,10 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwQOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") {
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d_nhwc_hwcf_q") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfQOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc_q") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcQOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.depthwise_conv_2d_nchw_chw") {
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNchwChwOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc") {
@@ -420,12 +432,16 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm_q") {
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc_q") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcQOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.conv_3d") {
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.conv_3d_ncdhw_fcdhw") {
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNcdhwFcdhwOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.conv_3d_ndhwc_dhwcf") {
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfOp>(genericOp, resultTypes, inputs, outputs);
+ } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf_q") {
+ namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfQOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwcm") {
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(genericOp, resultTypes, inputs, outputs);
} else if (convKind == "linalg.depthwise_conv_3d_ncdhw_cdhw") {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e847f2cf3aef2..b239f62a7049d 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -614,6 +614,48 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op) {
matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
}
+/// Returns true if `op` is linalg.conv_2d_nhwc_hwcf_q: either the named op
+/// itself, or a convolution-interface op whose indexing maps match the layout
+/// listed below.
+/// NOTE(review): only the indexing maps are checked here; the region body is
+/// not inspected.
+bool isaConv2DNhwcHwcfQOp(LinalgOp op) {
+ if (isa<linalg::Conv2DNhwcHwcfQOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ // Five maps: input, filter, two rank-0 scalar operands (presumably the
+ // input/filter zero points -- confirm), and output.
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+ // Expected maps; #map2 is the map of both scalar operands:
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+ // Input dims 1-2 are (output dim + filter dim). The input's last dim links
+ // to filter dim 2, and the filter's last dim links to the output's last dim.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
+}
+
+/// Returns true if `op` is linalg.conv_2d_nhwgc_gfhwc_q: either the named op
+/// itself, or a convolution-interface op whose indexing maps match the grouped
+/// layout listed below.
+/// NOTE(review): only the indexing maps are checked here; the region body is
+/// not inspected.
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op) {
+ if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ // Five maps: input, filter, two rank-0 scalar operands (presumably the
+ // input/filter zero points -- confirm), and output.
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+ // Expected maps; #map2 is the map of both scalar operands:
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
+ // Input dims 1-2 are (output dim + filter dims 2-3). d3 is shared by input,
+ // filter dim 0 and output. d7 links input and filter; d4 links filter and output.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+}
+
bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
if (isa<linalg::Conv2DNgchwGfchwQOp>(op)) return true;
@@ -736,6 +778,26 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
}
+/// Returns true if `op` is linalg.depthwise_conv_2d_nhwc_hwc_q: either the
+/// named op itself, or a convolution-interface op whose indexing maps match
+/// the layout listed below.
+/// NOTE(review): only the indexing maps are checked here; the region body is
+/// not inspected.
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op) {
+ if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ // Five maps: input, filter, two rank-0 scalar operands (presumably the
+ // input/filter zero points -- confirm), and output.
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,0,0,4})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+ // Expected maps; #map2 is the map of both scalar operands:
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ // Input dims 1-2 are (output dim + filter dim). The channel dim d3 is
+ // shared by input, filter and output (depthwise: no reduction over it).
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+}
+
bool isaConv3DOp(LinalgOp op) {
if (isa<linalg::Conv3DOp>(op)) return true;
@@ -792,6 +854,27 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
}
+/// Returns true if `op` is linalg.conv_3d_ndhwc_dhwcf_q: either the named op
+/// itself, or a convolution-interface op whose indexing maps match the layout
+/// listed below.
+/// NOTE(review): only the indexing maps are checked here; the region body is
+/// not inspected.
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op) {
+ if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op)) return true;
+
+ if (!isaConvolutionOpInterface(op)) return false;
+
+ // Five maps: input, filter, two rank-0 scalar operands (presumably the
+ // input/filter zero points -- confirm), and output.
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
+
+ // Positions of the input, filter and output maps within `indexingMaps`.
+ unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+ // Expected maps; #map2 is the map of both scalar operands:
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()>
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
+ // Input dims 1-3 are (output dim + filter dim). d8 links input and filter;
+ // d4 links filter and output.
+ return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+}
+
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) return true;
>From f1b8e80ff65c4db082b634eadc3976c35dc4ccac Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 14 Oct 2025 04:37:28 -0500
Subject: [PATCH 15/18] Make use of dilations/strides info
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 84 +--
.../Dialect/Linalg/Transforms/Specialize.cpp | 355 +++++-------
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 546 ++++++++++++------
3 files changed, 548 insertions(+), 437 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 2f7868bd55182..44ebc101d7c37 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -115,50 +115,50 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
//===----------------------------------------------------------------------===//
bool isaConv1DOp(LinalgOp op);
-bool isaConv1DNwcWcfOp(LinalgOp op);
-bool isaConv1DNcwFcwOp(LinalgOp op);
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op);
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op);
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op);
+bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
bool isaConv2DOp(LinalgOp op);
-bool isaConv2DNhwcFhwcOp(LinalgOp op);
-bool isaConv2DNhwcHwcfOp(LinalgOp op);
-bool isaConv2DNchwFchwOp(LinalgOp op);
-bool isaConv2DNhwcFhwcQOp(LinalgOp op);
-bool isaConv2DNchwFchwQOp(LinalgOp op);
-bool isaConv2DNgchwFgchwOp(LinalgOp op);
-bool isaConv2DNgchwGfchwOp(LinalgOp op);
-bool isaConv2DNhwcHwcfQOp(LinalgOp op);
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op);
-bool isaConv2DNgchwGfchwQOp(LinalgOp op);
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op);
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op);
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
+bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
bool isaConv3DOp(LinalgOp op);
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op);
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op);
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);
-bool isaPoolingNchwMaxOp(LinalgOp op);
-bool isaPoolingNchwSumOp(LinalgOp op);
-bool isaPoolingNhwcMaxOp(LinalgOp op);
-bool isaPoolingNhwcMinOp(LinalgOp op);
-bool isaPoolingNhwcSumOp(LinalgOp op);
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op);
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op);
-bool isaPoolingNcwMaxOp(LinalgOp op);
-bool isaPoolingNcwSumOp(LinalgOp op);
-bool isaPoolingNwcMaxOp(LinalgOp op);
-bool isaPoolingNwcMinOp(LinalgOp op);
-bool isaPoolingNwcSumOp(LinalgOp op);
-bool isaPoolingNdhwcMaxOp(LinalgOp op);
-bool isaPoolingNdhwcMinOp(LinalgOp op);
-bool isaPoolingNdhwcSumOp(LinalgOp op);
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 031cb3b919b96..94dfbcc15d055 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,250 +237,169 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
-static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
- if (isaConv1DOp(genericOp)) return "linalg.conv_1d";
- return "";
+/// Replaces `genericOp` with named op `ConvOpTy`; `dilations`/`strides` are only emitted as attributes for ops other than conv_1d/2d/3d.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp> specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+ SmallVector<Value> inputs = genericOp.getDpsInputs();
+ ValueRange outputs = genericOp.getDpsInits();
+ SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+ ? TypeRange(ValueRange(outputs))
+ : TypeRange{};
+ LinalgOp namedOp;
+ if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> || std::is_same_v<ConvOpTy, linalg::Conv2DOp> || std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes, inputs, outputs);
+ } else {
+ Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+ Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+ }
+ return namedOp;
}
-static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
- if (isaDepthwiseConv1DNcwCwOp(genericOp))
- return "linalg.depthwise_conv_1d_ncw_cw";
- if (isaDepthwiseConv1DNwcWcOp(genericOp))
- return "linalg.depthwise_conv_1d_nwc_wc";
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+ // The only rank-2 candidate checked here is linalg.conv_1d, which takes no strides/dilations.
+ if (!isaConv1DOp(genericOp)) return failure();
+ return specializeToConvOp<linalg::Conv1DOp>(rewriter, genericOp, /*dilations=*/{}, /*strides=*/{});
+}
+
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+ SmallVector<int64_t> convDilations, convStrides;
+ if (isaDepthwiseConv1DNcwCwOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::DepthwiseConv1DNcwCwOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaDepthwiseConv1DNwcWcOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(rewriter, genericOp, convDilations, convStrides);
if (isaConv2DOp(genericOp))
- return "linalg.conv_2d";
- if (isaPoolingNcwMaxOp(genericOp))
- return "linalg.pooling_ncw_max";
- if (isaPoolingNcwSumOp(genericOp))
- return "linalg.pooling_ncw_sum";
- if (isaPoolingNwcMaxOp(genericOp))
- return "linalg.pooling_nwc_max";
- if (isaPoolingNwcMinOp(genericOp))
- return "linalg.pooling_nwc_min";
- if (isaPoolingNwcSumOp(genericOp))
- return "linalg.pooling_nwc_sum";
- return "";
+ return specializeToConvOp<linalg::Conv2DOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNcwMaxOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNcwMaxOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNcwSumOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNcwSumOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNwcMaxOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNwcMaxOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNwcMinOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNwcMinOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNwcSumOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNwcSumOp>(rewriter, genericOp, convDilations, convStrides);
+ return failure();
}
-static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
- if (isaDepthwiseConv1DNwcWcmOp(genericOp))
- return "linalg.depthwise_conv_1d_nwc_wcm";
- if (isaConv1DNwcWcfOp(genericOp))
- return "linalg.conv_1d_nwc_wcf";
- if (isaConv1DNcwFcwOp(genericOp))
- return "linalg.conv_1d_ncw_fcw";
- return "";
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+ SmallVector<int64_t> convDilations, convStrides;
+ if (isaDepthwiseConv1DNwcWcmOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::DepthwiseConv1DNwcWcmOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv1DNwcWcfOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv1DNwcWcfOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv1DNcwFcwOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv1DNcwFcwOp>(rewriter, genericOp, convDilations, convStrides);
+ return failure();
}
-static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
- if (isaDepthwiseConv2DNchwChwOp(genericOp))
- return "linalg.depthwise_conv_2d_nchw_chw";
- if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
- return "linalg.depthwise_conv_2d_nhwc_hwc";
- if (isaDepthwiseConv2DNhwcHwcQOp(genericOp))
- return "linalg.depthwise_conv_2d_nhwc_hwc_q";
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+ SmallVector<int64_t> convDilations, convStrides;
+ if (isaDepthwiseConv2DNchwChwOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaDepthwiseConv2DNhwcHwcOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaDepthwiseConv2DNhwcHwcQOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcQOp>(rewriter, genericOp, convDilations, convStrides);
if (isaConv3DOp(genericOp))
- return "linalg.conv_3d";
- if (isaPoolingNchwMaxOp(genericOp))
- return "linalg.pooling_nchw_max";
- if (isaPoolingNchwSumOp(genericOp))
- return "linalg.pooling_nchw_sum";
- if (isaPoolingNhwcMaxOp(genericOp))
- return "linalg.pooling_nhwc_max";
- if (isaPoolingNhwcMinOp(genericOp))
- return "linalg.pooling_nhwc_min";
- if (isaPoolingNhwcSumOp(genericOp))
- return "linalg.pooling_nhwc_sum";
- if (isaPoolingNhwcMaxUnsignedOp(genericOp))
- return "linalg.pooling_nhwc_max_unsigned";
- if (isaPoolingNhwcMinUnsignedOp(genericOp))
- return "linalg.pooling_nhwc_min_unsigned";
- return "";
+ return specializeToConvOp<linalg::Conv3DOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNchwMaxOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNchwMaxOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNchwSumOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNchwSumOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNhwcMaxOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNhwcMinOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNhwcSumOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNhwcMaxUnsignedOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNhwcMinUnsignedOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(rewriter, genericOp, convDilations, convStrides);
+ return failure();
}
-static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
- if (isaConv2DNhwcFhwcOp(genericOp))
- return "linalg.conv_2d_nhwc_fhwc";
- if (isaConv2DNhwcHwcfOp(genericOp))
- return "linalg.conv_2d_nhwc_hwcf";
- if (isaConv2DNchwFchwOp(genericOp))
- return "linalg.conv_2d_nchw_fchw";
- if (isaConv2DNhwcFhwcQOp(genericOp))
- return "linalg.conv_2d_nhwc_fhwc_q";
- if (isaConv2DNchwFchwQOp(genericOp))
- return "linalg.conv_2d_nchw_fchw_q";
- if (isaConv2DNhwcHwcfQOp(genericOp))
- return "linalg.conv_2d_nhwc_hwcf_q";
- if (isaDepthwiseConv2DNhwcHwcmOp(genericOp))
- return "linalg.depthwise_conv_2d_nhwc_hwcm";
- if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp))
- return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
- return "";
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+ SmallVector<int64_t> convDilations, convStrides;
+ if (isaConv2DNhwcFhwcOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv2DNhwcFhwcOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv2DNhwcHwcfOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv2DNhwcHwcfOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv2DNchwFchwOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv2DNchwFchwOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv2DNhwcFhwcQOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv2DNhwcFhwcQOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv2DNchwFchwQOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv2DNchwFchwQOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv2DNhwcHwcfQOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv2DNhwcHwcfQOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaDepthwiseConv2DNhwcHwcmOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(rewriter, genericOp, convDilations, convStrides);
+ return failure();
}
-static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
- if (isaConv2DNgchwFgchwOp(genericOp))
- return "linalg.conv_2d_ngchw_fgchw";
- if (isaConv2DNgchwGfchwOp(genericOp))
- return "linalg.conv_2d_ngchw_gfchw";
- if (isaConv2DNgchwGfchwQOp(genericOp))
- return "linalg.conv_2d_ngchw_gfchw_q";
- if (isaConv2DNhwgcGfhwcOp(genericOp))
- return "linalg.conv_2d_nhwgc_gfhwc";
- if (isaConv2DNhwgcGfhwcQOp(genericOp))
- return "linalg.conv_2d_nhwgc_gfhwc_q";
- if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp))
- return "linalg.depthwise_conv_3d_ncdhw_cdhw";
- if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp))
- return "linalg.depthwise_conv_3d_ndhwc_dhwc";
- if (isaPoolingNdhwcMaxOp(genericOp))
- return "linalg.pooling_ndhwc_max";
- if (isaPoolingNdhwcMinOp(genericOp))
- return "linalg.pooling_ndhwc_min";
- if (isaPoolingNdhwcSumOp(genericOp))
- return "linalg.pooling_ndhwc_sum";
- return "";
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+ SmallVector<int64_t> convDilations, convStrides;
+ if (isaConv2DNgchwFgchwOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv2DNgchwFgchwOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv2DNgchwGfchwOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv2DNgchwGfchwOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv2DNgchwGfchwQOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv2DNgchwGfchwQOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv2DNhwgcGfhwcOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv2DNhwgcGfhwcOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv2DNhwgcGfhwcQOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv2DNhwgcGfhwcQOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNdhwcMaxOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNdhwcMaxOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNdhwcMinOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNdhwcMinOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaPoolingNdhwcSumOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::PoolingNdhwcSumOp>(rewriter, genericOp, convDilations, convStrides);
+ return failure();
}
-static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
- if (isaConv3DNcdhwFcdhwOp(genericOp))
- return "linalg.conv_3d_ncdhw_fcdhw";
- if (isaConv3DNdhwcDhwcfOp(genericOp))
- return "linalg.conv_3d_ndhwc_dhwcf";
- if (isaConv3DNdhwcDhwcfQOp(genericOp))
- return "linalg.conv_3d_ndhwc_dhwcf_q";
- if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp))
- return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
- return "";
+static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+ SmallVector<int64_t> convDilations, convStrides;
+ if (isaConv3DNcdhwFcdhwOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv3DNcdhwFcdhwOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv3DNdhwcDhwcfOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv3DNdhwcDhwcfOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaConv3DNdhwcDhwcfQOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::Conv3DNdhwcDhwcfQOp>(rewriter, genericOp, convDilations, convStrides);
+ if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp, &convDilations, &convStrides))
+ return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(rewriter, genericOp, convDilations, convStrides);
+ return failure();
}
-static std::string inferConvolutionKind(GenericOp genericOp) {
+// Converts linalg.generic to named linalg.*conv* where possible.
+static FailureOr<LinalgOp> inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
unsigned totalIterators = iteratorTypes.size();
switch(totalIterators) {
case 2:
- return inferBasedOnRank2ConvIteratorTypes(genericOp);
+ return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
case 4:
- return inferBasedOnRank4ConvIteratorTypes(genericOp);
+ return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
case 5:
- return inferBasedOnRank5ConvIteratorTypes(genericOp);
+ return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
case 6:
- return inferBasedOnRank6ConvIteratorTypes(genericOp);
+ return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
case 7:
- return inferBasedOnRank7ConvIteratorTypes(genericOp);
+ return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
case 8:
- return inferBasedOnRank8ConvIteratorTypes(genericOp);
+ return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
case 9:
- return inferBasedOnRank9ConvIteratorTypes(genericOp);
- }
- return "";
-}
-
-// Converts linalg.generic to named linalg.*conv* where possible.
-static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
- GenericOp genericOp) {
- std::string convKind = inferConvolutionKind(genericOp);
- if (convKind == "") return failure();
- SmallVector<Value> inputs = genericOp.getDpsInputs();
- ValueRange outputs = genericOp.getDpsInits();
- SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
- SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
- ? TypeRange(ValueRange(outputs))
- : TypeRange{};
- LinalgOp namedOp;
- if (convKind == "linalg.conv_1d") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_1d_nwc_wcf") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DNwcWcfOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_1d_ncw_fcw") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv1DNcwFcwOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.depthwise_conv_1d_ncw_cw") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNcwCwOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.depthwise_conv_1d_nwc_wc") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.depthwise_conv_1d_nwc_wcm") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcmOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d_nhwc_fhwc") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcFhwcOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d_nhwc_hwcf") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d_nchw_fchw") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d_nhwc_fhwc_q") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcFhwcQOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d_nchw_fchw_q") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNchwFchwQOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d_ngchw_fgchw") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwFgchwOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d_ngchw_gfchw") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d_ngchw_gfchw_q") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwQOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d_nhwc_hwcf_q") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfQOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_2d_nhwgc_gfhwc_q") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcQOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.depthwise_conv_2d_nchw_chw") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNchwChwOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm_q") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc_q") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcQOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_3d") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_3d_ncdhw_fcdhw") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNcdhwFcdhwOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.conv_3d_ndhwc_dhwcf_q") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfQOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwcm") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.depthwise_conv_3d_ncdhw_cdhw") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwc") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_nchw_max") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNchwMaxOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_nchw_sum") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNchwSumOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_nhwc_max") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_nhwc_min") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMinOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_nhwc_sum") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcSumOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_nhwc_max_unsigned") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_nhwc_min_unsigned") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMinUnsignedOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_ncw_max") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNcwMaxOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_ncw_sum") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNcwSumOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_nwc_max") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcMaxOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_nwc_min") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcMinOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_nwc_sum") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcSumOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_ndhwc_max") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_ndhwc_min") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcMinOp>(genericOp, resultTypes, inputs, outputs);
- } else if (convKind == "linalg.pooling_ndhwc_sum") {
- namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcSumOp>(genericOp, resultTypes, inputs, outputs);
+ return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
}
- return namedOp;
-
return failure();
}
@@ -566,7 +485,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
// Convolution - e.g. *conv*
if (isaConvolutionOpInterface(genericOp)) {
- return specializeLinalgConvolutions(rewriter, genericOp);
+ return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
}
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index b239f62a7049d..548f43f83b0ed 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -319,7 +319,8 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_
return false;
}
-static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) {
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim,
+ int64_t& dilation, int64_t& stride) {
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
@@ -327,7 +328,6 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un
return false;
AffineExpr dim0, dim1;
- // TODO(Abhishek-Varma): Use this information in specialize.cpp.
int64_t c0, c1;
if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
@@ -335,7 +335,15 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un
// Pattern matched with dims and constants extracted.
AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
- return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr));
+ if (dim0 == fExpr && dim1 == oExpr) {
+ dilation = c0;
+ stride = c1;
+ return true;
+ } else if (dim1 == fExpr && dim0 == oExpr) {
+ dilation = c1;
+ stride = c0;
+ return true;
+ }
}
return false;
}
@@ -354,6 +362,16 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t>
return true;
}
+static bool updateConvDilationsAndStrides(SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides, ArrayRef<int64_t> tempDilations, ArrayRef<int64_t> tempStrides) {
+ // Copy the matched values out only when both out-params were supplied; overwrite (not append) so reused vectors never keep stale entries.
+ if (!(dilations && strides))
+ return true;
+ dilations->assign(tempDilations.begin(), tempDilations.end());
+ strides->assign(tempStrides.begin(), tempStrides.end());
+ // Always true: the isa*Op matchers chain this after a successful match with `&&`.
+ return true;
+}
+
bool isaConv1DOp(LinalgOp op) {
if (isa<linalg::Conv1DOp>(op)) return true;
@@ -365,10 +383,12 @@ bool isaConv1DOp(LinalgOp op) {
// #map = affine_map<(d0, d1) -> (d0 + d1)>
// #map1 = affine_map<(d0, d1) -> (d1)>
// #map2 = affine_map<(d0, d1) -> (d0)>
- return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0);
+ SmallVector<int64_t> tempDilations(1,1);
+ SmallVector<int64_t> tempStrides(1,1);
+ return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]);
}
-bool isaConv1DNwcWcfOp(LinalgOp op) {
+bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv1DNwcWcfOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -377,16 +397,20 @@ bool isaConv1DNwcWcfOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1,1);
+ SmallVector<int64_t> tempStrides(1,1);
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaConv1DNcwFcwOp(LinalgOp op) {
+bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv1DNcwFcwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -395,16 +419,20 @@ bool isaConv1DNcwFcwOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1,1);
+ SmallVector<int64_t> tempStrides(1,1);
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op) {
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::DepthwiseConv1DNcwCwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -413,16 +441,21 @@ bool isaDepthwiseConv1DNcwCwOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1,1);
+ SmallVector<int64_t> tempStrides(1,1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2));
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op) {
+// -------------------
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::DepthwiseConv1DNwcWcOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -431,16 +464,20 @@ bool isaDepthwiseConv1DNwcWcOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1,1);
+ SmallVector<int64_t> tempStrides(1,1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1));
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) {
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -449,14 +486,18 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,4})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(1, 1);
+ SmallVector<int64_t> dimStrides(1, 1);
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
bool isaConv2DOp(LinalgOp op) {
@@ -467,14 +508,20 @@ bool isaConv2DOp(LinalgOp op) {
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!verifyConvIndexingMapSizes(indexingMaps, {2,2,2})) return false;
+ SmallVector<int64_t> tempDilations(2,1);
+ SmallVector<int64_t> tempStrides(2,1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
- return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1));
+ // linalg.conv_2d has no strides/dilations attributes, so a generic whose
+ // add-expressions imply non-unit stride/dilation must not be matched as it.
+ return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1]) &&
+ llvm::all_of(tempDilations, [](int64_t d) { return d == 1; }) &&
+ llvm::all_of(tempStrides, [](int64_t s) { return s == 1; }));
}
-bool isaConv2DNhwcFhwcOp(LinalgOp op) {
+bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv2DNhwcFhwcOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -483,17 +526,21 @@ bool isaConv2DNhwcFhwcOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaConv2DNhwcHwcfOp(LinalgOp op) {
+bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv2DNhwcHwcfOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -502,17 +549,21 @@ bool isaConv2DNhwcHwcfOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaConv2DNchwFchwOp(LinalgOp op) {
+bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv2DNchwFchwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -521,17 +572,21 @@ bool isaConv2DNchwFchwOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaConv2DNhwcFhwcQOp(LinalgOp op) {
+bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv2DNhwcFhwcQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -540,18 +595,22 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaConv2DNchwFchwQOp(LinalgOp op) {
+bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv2DNchwFchwQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -560,18 +619,22 @@ bool isaConv2DNchwFchwQOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaConv2DNgchwFgchwOp(LinalgOp op) {
+bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv2DNgchwFgchwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -580,19 +643,23 @@ bool isaConv2DNgchwFgchwOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaConv2DNgchwGfchwOp(LinalgOp op) {
+bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv2DNgchwGfchwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -601,20 +668,23 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- return (indexingMaps.size() == 3 &&
- matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaConv2DNhwcHwcfQOp(LinalgOp op) {
+bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv2DNhwcHwcfQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -623,18 +693,22 @@ bool isaConv2DNhwcHwcfQOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op) {
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -643,20 +717,24 @@ bool isaConv2DNhwgcGfhwcQOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+
+ SmallVector<int64_t> tempDilations(2,1);
+ SmallVector<int64_t> tempStrides(2,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)
+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
+bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv2DNgchwGfchwQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -665,20 +743,24 @@ bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op) {
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv2DNhwgcGfhwcOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -687,19 +769,23 @@ bool isaConv2DNhwgcGfhwcOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) {
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::DepthwiseConv2DNchwChwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -708,17 +794,21 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3));
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, dimDilations[1], dimStrides[1]));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) {
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -727,17 +817,21 @@ bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) {
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -746,18 +840,22 @@ bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -766,19 +864,23 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op) {
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -787,15 +889,19 @@ bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,0,0,4})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+
+ SmallVector<int64_t> dimDilations(2, 1);
+ SmallVector<int64_t> dimStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, dimDilations[1], dimStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
bool isaConv3DOp(LinalgOp op) {
@@ -806,15 +912,21 @@ bool isaConv3DOp(LinalgOp op) {
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
+ SmallVector<int64_t> tempDilations(3,1);
+ SmallVector<int64_t> tempStrides(3,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
- return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2));
+ // linalg.conv_3d has no strides/dilations attributes, so a generic whose
+ // add-expressions imply non-unit stride/dilation must not be matched as it.
+ return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[2], tempStrides[2]) &&
+ llvm::all_of(tempDilations, [](int64_t d) { return d == 1; }) &&
+ llvm::all_of(tempStrides, [](int64_t s) { return s == 1; }));
}
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op) {
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv3DNcdhwFcdhwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -823,18 +931,22 @@ bool isaConv3DNcdhwFcdhwOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(3, 1);
+ SmallVector<int64_t> dimStrides(3, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, dimDilations[1], dimStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, dimDilations[2], dimStrides[2]) &&
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv3DNdhwcDhwcfOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -843,18 +955,22 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> dimDilations(3, 1);
+ SmallVector<int64_t> dimStrides(3, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, dimDilations[1], dimStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, dimDilations[2], dimStrides[2]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op) {
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -863,19 +979,23 @@ bool isaConv3DNdhwcDhwcfQOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
+
+ SmallVector<int64_t> dimDilations(3, 1);
+ SmallVector<int64_t> dimStrides(3, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ bool matched = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, dimDilations[0], dimStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, dimDilations[1], dimStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, dimDilations[2], dimStrides[2]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+ return matched && updateConvDilationsAndStrides(dilations, strides, dimDilations, dimStrides);
}
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -884,19 +1004,23 @@ bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,6})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(3,1);
+ SmallVector<int64_t> tempStrides(3,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) {
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -905,18 +1029,22 @@ bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(3,1);
+ SmallVector<int64_t> tempStrides(3,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4));
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4, tempDilations[2], tempStrides[2]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) {
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -925,18 +1053,22 @@ bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) {
if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(3,1);
+ SmallVector<int64_t> tempStrides(3,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNchwMaxOp(LinalgOp op) {
+bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNchwMaxOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -948,17 +1080,21 @@ bool isaPoolingNchwMaxOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2,1);
+ SmallVector<int64_t> tempStrides(2,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNchwSumOp(LinalgOp op) {
+bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNchwSumOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -970,17 +1106,21 @@ bool isaPoolingNchwSumOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2,1);
+ SmallVector<int64_t> tempStrides(2,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
bodyMatcherForSumPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNhwcMaxOp(LinalgOp op) {
+bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNhwcMaxOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -992,17 +1132,21 @@ bool isaPoolingNhwcMaxOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2,1);
+ SmallVector<int64_t> tempStrides(2,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNhwcMinOp(LinalgOp op) {
+bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNhwcMinOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1014,17 +1158,21 @@ bool isaPoolingNhwcMinOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2,1);
+ SmallVector<int64_t> tempStrides(2,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
bodyMatcherForMinSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNhwcSumOp(LinalgOp op) {
+bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNhwcSumOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1036,17 +1184,21 @@ bool isaPoolingNhwcSumOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2,1);
+ SmallVector<int64_t> tempStrides(2,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
bodyMatcherForSumPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op) {
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1058,17 +1210,21 @@ bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2,1);
+ SmallVector<int64_t> tempStrides(2,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op) {
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNhwcMinUnsignedOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1080,17 +1236,21 @@ bool isaPoolingNhwcMinUnsignedOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2,1);
+ SmallVector<int64_t> tempStrides(2,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
bodyMatcherForMinUnsignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNcwMaxOp(LinalgOp op) {
+bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNcwMaxOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1102,16 +1262,20 @@ bool isaPoolingNcwMaxOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1,1);
+ SmallVector<int64_t> tempStrides(1,1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNcwSumOp(LinalgOp op) {
+bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNcwSumOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1123,16 +1287,20 @@ bool isaPoolingNcwSumOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1,1);
+ SmallVector<int64_t> tempStrides(1,1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
bodyMatcherForSumPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNwcMaxOp(LinalgOp op) {
+bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNwcMaxOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1144,16 +1312,20 @@ bool isaPoolingNwcMaxOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1,1);
+ SmallVector<int64_t> tempStrides(1,1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNwcMinOp(LinalgOp op) {
+bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNwcMinOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1165,16 +1337,20 @@ bool isaPoolingNwcMinOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1,1);
+ SmallVector<int64_t> tempStrides(1,1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
bodyMatcherForMinSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNwcSumOp(LinalgOp op) {
+bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNwcSumOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1186,16 +1362,20 @@ bool isaPoolingNwcSumOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1,1);
+ SmallVector<int64_t> tempStrides(1,1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
bodyMatcherForSumPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNdhwcMaxOp(LinalgOp op) {
+bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNdhwcMaxOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1207,18 +1387,22 @@ bool isaPoolingNdhwcMaxOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(3,1);
+ SmallVector<int64_t> tempStrides(3,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNdhwcMinOp(LinalgOp op) {
+bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNdhwcMinOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1230,18 +1414,22 @@ bool isaPoolingNdhwcMinOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(3,1);
+ SmallVector<int64_t> tempStrides(3,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
bodyMatcherForMinSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
-bool isaPoolingNdhwcSumOp(LinalgOp op) {
+bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
if (isa<linalg::PoolingNdhwcSumOp>(op)) return true;
if (!isaConvolutionOpInterface(op)) return false;
@@ -1253,15 +1441,19 @@ bool isaPoolingNdhwcSumOp(LinalgOp op) {
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(3,1);
+ SmallVector<int64_t> tempStrides(3,1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
+ bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
bodyMatcherForSumPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
}
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
>From 1a9417d3623074ba40ee95c7ef94c5c14b678d87 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 14 Oct 2025 06:19:23 -0500
Subject: [PATCH 16/18] Add lit test and clean up
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 5 +-
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 9 +
...oundtrip-linalg-convolution-named-ops.mlir | 615 ++++++++++++++++++
3 files changed, 627 insertions(+), 2 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 94dfbcc15d055..12eb17ef0a435 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,7 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
+/// Utility to specialize `genericOp` into a convolution op of type `ConvOpTy` with the given `dilations` and `strides`.
template <typename ConvOpTy>
static FailureOr<LinalgOp> specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
SmallVector<Value> inputs = genericOp.getDpsInputs();
@@ -380,7 +381,7 @@ static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank9ConvIteratorTypes(Rewri
return failure();
}
-// Converts linalg.generic to named linalg.*conv* where possible.
+// Converts linalg.generic to named linalg.*conv/pooling* ops where possible. To speed up matching, candidate ops are grouped by the number of iterator types.
static FailureOr<LinalgOp> inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
unsigned totalIterators = iteratorTypes.size();
@@ -483,7 +484,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
return specializeLinalgContractions(rewriter, genericOp);
}
- // Convolution - e.g. *conv*
+ // Convolution - e.g. *conv/pooling*
if (isaConvolutionOpInterface(genericOp)) {
return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 548f43f83b0ed..0d4e8aa5e6382 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -319,6 +319,11 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_
return false;
}
+/// Given an array of AffineMaps `indexingMaps` (n = indexingMaps.size()), verify that
+/// indexingMaps[0].getResult(iDim) ==
+///     indexingMaps[1].getResult(fDim) * <CST_1> +
+///     indexingMaps[n-1].getResult(oDim) * <CST_2>
+/// where CST_1 and CST_2 are arbitrary constants.
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim,
int64_t& dilation, int64_t& stride) {
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
@@ -348,10 +353,13 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un
return false;
}
+/// Returns true if result `aDim` of `indexingMaps[aIndex]` is the same affine
+/// expression as result `bDim` of `indexingMaps[bIndex]`.
static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim);
}
+/// Given an array of AffineMaps, verify that there is exactly one map per entry of `expectedSizes` and that each map has the corresponding expected size.
static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t> expectedSizes) {
if (indexingMaps.size() != expectedSizes.size()) return false;
@@ -362,6 +370,7 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t>
return true;
}
+/// Utility to update `dilations` and `strides` by copying the corresponding data from `tempDilations` and `tempStrides`. Returns true without updating anything if either output pointer is null.
static bool updateConvDilationsAndStrides(SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides, ArrayRef<int64_t> tempDilations, ArrayRef<int64_t> tempStrides) {
if (!(dilations && strides))
return true;
diff --git a/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir b/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir
new file mode 100644
index 0000000000000..8cd57044e613f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/roundtrip-linalg-convolution-named-ops.mlir
@@ -0,0 +1,615 @@
+// The following tests lower linalg convolution and pooling named ops to
+// linalg.generic and then lift them back up to named ops.
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+func.func @conv_1d_nwc_wcf(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+ linalg.conv_1d_nwc_wcf {dilations = dense<3> : tensor<1xi64>,
+ strides = dense<2> : tensor<1xi64>}
+ ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs (%output: memref<?x?x?xf32>)
+ return
+}
+// CHECK: @conv_1d_nwc_wcf
+// CHECK: linalg.conv_1d_nwc_wcf
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_1d_ncw_fcw(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?xf32>) {
+ linalg.conv_1d_ncw_fcw {dilations = dense<3> : tensor<1xi64>,
+ strides = dense<2> : tensor<1xi64>}
+ ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs (%output: memref<?x?x?xf32>)
+ return
+}
+// CHECK: @conv_1d_ncw_fcw
+// CHECK: linalg.conv_1d_ncw_fcw
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_1d(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
+ linalg.conv_1d ins(%in, %filter : memref<?xf32>, memref<?xf32>)
+ outs(%out : memref<?xf32>)
+ return
+}
+// CHECK: @conv_1d
+// CHECK: linalg.conv_1d
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_1d_ncw_cw(%input: memref<?x?x?xf32>, %filter: memref<?x?xf32>, %output: memref<?x?x?xf32>) {
+ linalg.depthwise_conv_1d_ncw_cw {dilations = dense<3> : tensor<1xi64>,
+ strides = dense<2> : tensor<1xi64>}
+ ins (%input, %filter: memref<?x?x?xf32>, memref<?x?xf32>)
+ outs (%output: memref<?x?x?xf32>)
+ return
+}
+// CHECK: @depthwise_conv_1d_ncw_cw
+// CHECK: linalg.depthwise_conv_1d_ncw_cw
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_1d_nwc_wc(%input: memref<?x?x?xf32>, %filter: memref<?x?xf32>, %output: memref<?x?x?xf32>) {
+ linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>,
+ strides = dense<2> : tensor<1xi64>}
+ ins (%input, %filter: memref<?x?x?xf32>, memref<?x?xf32>)
+ outs (%output: memref<?x?x?xf32>)
+ return
+}
+// CHECK: @depthwise_conv_1d_nwc_wc
+// CHECK: linalg.depthwise_conv_1d_nwc_wc
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_1d_nwc_wcm(%input: memref<?x?x?xf32>, %filter: memref<?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
+ linalg.depthwise_conv_1d_nwc_wcm {dilations = dense<1> : tensor<1xi64>,
+ strides = dense<1> : tensor<1xi64>}
+ ins (%input, %filter: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs (%output: memref<?x?x?x?xf32>)
+ return
+}
+// CHECK: @depthwise_conv_1d_nwc_wcm
+// CHECK: linalg.depthwise_conv_1d_nwc_wcm
+// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
+ linalg.conv_2d ins(%in, %filter : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%out: memref<?x?xf32>)
+ return
+}
+// CHECK: @conv_2d
+// CHECK: linalg.conv_2d
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d_nchw_fchw(%arg0 : tensor<?x?x?x?xf32>,
+ %arg1 : tensor<?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>) ->
+ (tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>) {
+ %0 = linalg.conv_2d_nchw_fchw {dilations = dense<[2,4]> : tensor<2xi64>, strides = dense<[3,5]> : tensor<2xi64>}
+ ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ %1 = tensor.cast %0 : tensor<?x?x?x?xf32> to tensor<4x8x12x16xf32>
+ return %1, %0 : tensor<4x8x12x16xf32>, tensor<?x?x?x?xf32>
+}
+// CHECK: @conv_2d_nchw_fchw
+// CHECK: linalg.conv_2d_nchw_fchw
+// CHECK-SAME: dilations = dense<[2, 4]> : tensor<2xi64>, strides = dense<[3, 5]> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d_nchw_fchw_q(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter, %inputzp, %filterzp: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, i32, i32)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @conv_2d_nchw_fchw_q
+// CHECK: linalg.conv_2d_nchw_fchw_q
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @conv_2d_ngchw_fgchw
+// CHECK: linalg.conv_2d_ngchw_fgchw
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw(%input: memref<?x?x?x?x?xi8>, %filter: memref<?x?x?x?x?xi8>, %output: memref<?x?x?x?x?xi32>) {
+ linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>)
+ outs (%output: memref<?x?x?x?x?xi32>)
+ return
+}
+// CHECK: @conv_2d_ngchw_gfchw
+// CHECK: linalg.conv_2d_ngchw_gfchw
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d_ngchw_gfchw_q(%input: memref<?x?x?x?x?xi8>, %filter: memref<?x?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?x?xi32>) {
+ linalg.conv_2d_ngchw_gfchw_q {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter, %inputzp, %filterzp: memref<?x?x?x?x?xi8>, memref<?x?x?x?x?xi8>, i32, i32)
+ outs (%output: memref<?x?x?x?x?xi32>)
+ return
+}
+// CHECK: @conv_2d_ngchw_gfchw_q
+// CHECK: linalg.conv_2d_ngchw_gfchw_q
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_hwcf_q(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?xf32>) {
+ linalg.conv_2d_nhwc_hwcf_q {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } ins(%input, %filter, %inputzp, %filterzp : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, i32, i32) outs(%output : memref<?x?x?x?xf32>)
+ return
+}
+// CHECK: @conv_2d_nhwc_hwcf_q
+// CHECK: linalg.conv_2d_nhwc_hwcf_q
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwgc_gfhwc_q(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %output: memref<?x?x?x?x?xf32>) {
+ linalg.conv_2d_nhwgc_gfhwc_q {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } ins(%input, %filter, %inputzp, %filterzp : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>, i32, i32) outs(%output : memref<?x?x?x?x?xf32>)
+ return
+}
+// CHECK: @conv_2d_nhwgc_gfhwc_q
+// CHECK: linalg.conv_2d_nhwgc_gfhwc_q
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwc_q(%input: tensor<?x?x?x?xi8>, %filter: tensor<?x?x?xi8>, %inputzp: i32, %filterzp: i32, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>{
+ %res = linalg.depthwise_conv_2d_nhwc_hwc_q {
+ dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>
+ } ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%output : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+ return %res : tensor<?x?x?x?xi32>
+}
+// CHECK: @depthwise_conv_2d_nhwc_hwc_q
+// CHECK: linalg.depthwise_conv_2d_nhwc_hwc_q
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @conv_2d_nhwc_fhwc
+// CHECK: linalg.conv_2d_nhwc_fhwc
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_fhwc_q(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter, %inputzp, %filterzp: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, i32, i32)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @conv_2d_nhwc_fhwc_q
+// CHECK: linalg.conv_2d_nhwc_fhwc_q
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @conv_2d_nhwc_hwcf
+// CHECK: linalg.conv_2d_nhwc_hwcf
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_2d_nhwgc_gfhwc(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+ linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+ outs (%output: memref<?x?x?x?x?xf32>)
+ return
+}
+// CHECK: @conv_2d_nhwgc_gfhwc
+// CHECK: linalg.conv_2d_nhwgc_gfhwc
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_2d_nchw_chw
+// CHECK: linalg.depthwise_conv_2d_nchw_chw
+// CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_2d_nhwc_hwc
+// CHECK: linalg.depthwise_conv_2d_nhwc_hwc
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_2d_nhwc_hwcm
+// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nhwc_hwcm_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x?xi8>, %arg2: tensor<?x?x?x?x?xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x?xi32> {
+ %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+ return %0 : tensor<?x?x?x?x?xi32>
+}
+// CHECK: @depthwise_conv_2d_nhwc_hwcm_q
+// CHECK: linalg.depthwise_conv_2d_nhwc_hwcm_q
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_3d(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
+ linalg.conv_3d ins(%in, %filter : memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%out : memref<?x?x?xf32>)
+ return
+}
+// CHECK: @conv_3d
+// CHECK: linalg.conv_3d
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_3d_ncdhw_fcdhw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.conv_3d_ncdhw_fcdhw {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @conv_3d_ncdhw_fcdhw
+// CHECK: linalg.conv_3d_ncdhw_fcdhw
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @conv_3d_ndhwc_dhwcf
+// CHECK: linalg.conv_3d_ndhwc_dhwcf
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @conv_3d_ndhwc_dhwcf_q(%input: tensor<?x?x?x?x?xi8>, %filter: tensor<?x?x?x?x?xi8>, %inputzp: i32, %filterzp: i32, %init: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32> {
+ %0 = linalg.conv_3d_ndhwc_dhwcf_q {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins(%input, %filter, %inputzp, %filterzp : tensor<?x?x?x?x?xi8>, tensor<?x?x?x?x?xi8>, i32, i32)
+ outs (%init: tensor<?x?x?x?x?xi32>) -> tensor<?x?x?x?x?xi32>
+ return %0 : tensor<?x?x?x?x?xi32>
+}
+// CHECK: @conv_3d_ndhwc_dhwcf_q
+// CHECK: linalg.conv_3d_ndhwc_dhwcf_q
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_3d_ncdhw_cdhw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.depthwise_conv_3d_ncdhw_cdhw {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_3d_ncdhw_cdhw
+// CHECK: linalg.depthwise_conv_3d_ncdhw_cdhw
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_3d_ndhwc_dhwc(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.depthwise_conv_3d_ndhwc_dhwc {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_3d_ndhwc_dhwc
+// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwc
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> {
+ %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_3d_ndhwc_dhwcm
+// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nchw_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nchw_max
+// CHECK: linalg.pooling_nchw_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nchw_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nchw_sum
+// CHECK: linalg.pooling_nchw_sum
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_ncw_max(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.pooling_ncw_max {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+ ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+ outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @pooling_ncw_max
+// CHECK: linalg.pooling_ncw_max
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_ncw_sum(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.pooling_ncw_sum {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+ ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+ outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @pooling_ncw_sum
+// CHECK: linalg.pooling_ncw_sum
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_ndhwc_max(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.pooling_ndhwc_max {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @pooling_ndhwc_max
+// CHECK: linalg.pooling_ndhwc_max
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_ndhwc_min(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.pooling_ndhwc_min {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @pooling_ndhwc_min
+// CHECK: linalg.pooling_ndhwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_ndhwc_sum(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+ %0 = linalg.pooling_ndhwc_sum {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?xf32>
+}
+// CHECK: @pooling_ndhwc_sum
+// CHECK: linalg.pooling_ndhwc_sum
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_max
+// CHECK: linalg.pooling_nhwc_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_min(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_min
+// CHECK: linalg.pooling_nhwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_sum
+// CHECK: linalg.pooling_nhwc_sum
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+ %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
+ outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+ return %0 : tensor<?x?x?x?xi32>
+}
+// CHECK: @pooling_nhwc_max_unsigned
+// CHECK: linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { // With f32 elements, min_unsigned is expected to generalize to the same body as pooling_nhwc_min, so the round trip yields linalg.pooling_nhwc_min (the CHECK is anchored on " {" so it cannot also match pooling_nhwc_min_unsigned).
+ %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_min_unsigned
+// CHECK: linalg.pooling_nhwc_min {
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nwc_max(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.pooling_nwc_max {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+ outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @pooling_nwc_max
+// CHECK: linalg.pooling_nwc_max
+// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nwc_min(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.pooling_nwc_min {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+ ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+ outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @pooling_nwc_min
+// CHECK: linalg.pooling_nwc_min
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nwc_sum(%input: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>, %filter: tensor<?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.pooling_nwc_sum {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+ ins(%input, %filter: tensor<?x?x?xf32>, tensor<?xf32>)
+ outs(%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK: @pooling_nwc_sum
+// CHECK: linalg.pooling_nwc_sum
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+// CHECK-NOT: linalg.generic
>From bba3921a36b934ee96f4bc4794e47dd329e9995b Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 14 Oct 2025 07:22:14 -0500
Subject: [PATCH 17/18] Format code
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 153 +-
.../Dialect/Linalg/Transforms/Specialize.cpp | 221 +-
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 2066 ++++++++++-------
3 files changed, 1527 insertions(+), 913 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 44ebc101d7c37..0f39098ca9946 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -115,50 +115,119 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
//===----------------------------------------------------------------------===//
bool isaConv1DOp(LinalgOp op);
-bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
bool isaConv2DOp(LinalgOp op);
-bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNhwcFhwcQOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNchwFchwQOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNgchwFgchwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNgchwGfchwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNhwcHwcfQOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNgchwGfchwQOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
bool isaConv3DOp(LinalgOp op);
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
-bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations = nullptr, SmallVector<int64_t>* strides = nullptr);
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNdhwcMaxOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNdhwcMinOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+bool isaPoolingNdhwcSumOp(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 12eb17ef0a435..e08705b90e7b0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,9 +237,12 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
-/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy` with `dilations` and `strides`.
+/// Utility to replace `genericOp` with a named convolution op of type
+/// `ConvOpTy` built from `dilations`/`strides` (ignored for Conv{1,2,3}DOp).
template <typename ConvOpTy>
-static FailureOr<LinalgOp> specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+ ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
SmallVector<Value> inputs = genericOp.getDpsInputs();
ValueRange outputs = genericOp.getDpsInits();
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
@@ -247,159 +250,227 @@ static FailureOr<LinalgOp> specializeToConvOp(RewriterBase &rewriter, GenericOp
? TypeRange(ValueRange(outputs))
: TypeRange{};
LinalgOp namedOp;
- if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> || std::is_same_v<ConvOpTy, linalg::Conv2DOp> || std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
- namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes, inputs, outputs);
+ if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+ inputs, outputs);
} else {
Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
- namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+ genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
}
return namedOp;
}
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
- if (isaConv1DOp(genericOp)) return specializeToConvOp<linalg::Conv1DOp>(rewriter, genericOp, dilations, strides);
+ if (isaConv1DOp(genericOp))
+ return specializeToConvOp<linalg::Conv1DOp>(rewriter, genericOp, dilations,
+ strides);
return failure();
}
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
if (isaDepthwiseConv1DNcwCwOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv1DNcwCwOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::DepthwiseConv1DNcwCwOp>(
+ rewriter, genericOp, dilations, strides);
if (isaDepthwiseConv1DNwcWcOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+ rewriter, genericOp, dilations, strides);
if (isaConv2DOp(genericOp))
- return specializeToConvOp<linalg::Conv2DOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DOp>(rewriter, genericOp, dilations,
+ strides);
if (isaPoolingNcwMaxOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNcwMaxOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNcwMaxOp>(rewriter, genericOp,
+ dilations, strides);
if (isaPoolingNcwSumOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNcwSumOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNcwSumOp>(rewriter, genericOp,
+ dilations, strides);
if (isaPoolingNwcMaxOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNwcMaxOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNwcMaxOp>(rewriter, genericOp,
+ dilations, strides);
if (isaPoolingNwcMinOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNwcMinOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNwcMinOp>(rewriter, genericOp,
+ dilations, strides);
if (isaPoolingNwcSumOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNwcSumOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNwcSumOp>(rewriter, genericOp,
+ dilations, strides);
return failure();
}
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
if (isaDepthwiseConv1DNwcWcmOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv1DNwcWcmOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::DepthwiseConv1DNwcWcmOp>(
+ rewriter, genericOp, dilations, strides);
if (isaConv1DNwcWcfOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv1DNwcWcfOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv1DNwcWcfOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv1DNcwFcwOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv1DNcwFcwOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv1DNcwFcwOp>(rewriter, genericOp,
+ dilations, strides);
return failure();
}
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
if (isaDepthwiseConv2DNchwChwOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+ rewriter, genericOp, dilations, strides);
if (isaDepthwiseConv2DNhwcHwcOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcOp>(
+ rewriter, genericOp, dilations, strides);
if (isaDepthwiseConv2DNhwcHwcQOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcQOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcQOp>(
+ rewriter, genericOp, dilations, strides);
if (isaConv3DOp(genericOp))
- return specializeToConvOp<linalg::Conv3DOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv3DOp>(rewriter, genericOp, dilations,
+ strides);
if (isaPoolingNchwMaxOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNchwMaxOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNchwMaxOp>(rewriter, genericOp,
+ dilations, strides);
if (isaPoolingNchwSumOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNchwSumOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNchwSumOp>(rewriter, genericOp,
+ dilations, strides);
if (isaPoolingNhwcMaxOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+ dilations, strides);
if (isaPoolingNhwcMinOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+ dilations, strides);
if (isaPoolingNhwcSumOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+ dilations, strides);
if (isaPoolingNhwcMaxUnsignedOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+ rewriter, genericOp, dilations, strides);
if (isaPoolingNhwcMinUnsignedOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+ rewriter, genericOp, dilations, strides);
return failure();
}
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
if (isaConv2DNhwcFhwcOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv2DNhwcFhwcOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DNhwcFhwcOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv2DNhwcHwcfOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv2DNhwcHwcfOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DNhwcHwcfOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv2DNchwFchwOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv2DNchwFchwOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DNchwFchwOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv2DNhwcFhwcQOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv2DNhwcFhwcQOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DNhwcFhwcQOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv2DNchwFchwQOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv2DNchwFchwQOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DNchwFchwQOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv2DNhwcHwcfQOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv2DNhwcHwcfQOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DNhwcHwcfQOp>(rewriter, genericOp,
+ dilations, strides);
if (isaDepthwiseConv2DNhwcHwcmOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmOp>(
+ rewriter, genericOp, dilations, strides);
if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(
+ rewriter, genericOp, dilations, strides);
return failure();
}
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
if (isaConv2DNgchwFgchwOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv2DNgchwFgchwOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DNgchwFgchwOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv2DNgchwGfchwOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv2DNgchwGfchwOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DNgchwGfchwOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv2DNgchwGfchwQOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv2DNgchwGfchwQOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DNgchwGfchwQOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv2DNhwgcGfhwcOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv2DNhwgcGfhwcOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DNhwgcGfhwcOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv2DNhwgcGfhwcQOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv2DNhwgcGfhwcQOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv2DNhwgcGfhwcQOp>(rewriter, genericOp,
+ dilations, strides);
if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(
+ rewriter, genericOp, dilations, strides);
if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(
+ rewriter, genericOp, dilations, strides);
if (isaPoolingNdhwcMaxOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNdhwcMaxOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNdhwcMaxOp>(rewriter, genericOp,
+ dilations, strides);
if (isaPoolingNdhwcMinOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNdhwcMinOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNdhwcMinOp>(rewriter, genericOp,
+ dilations, strides);
if (isaPoolingNdhwcSumOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::PoolingNdhwcSumOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::PoolingNdhwcSumOp>(rewriter, genericOp,
+ dilations, strides);
return failure();
}
-static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
if (isaConv3DNcdhwFcdhwOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv3DNcdhwFcdhwOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv3DNcdhwFcdhwOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv3DNdhwcDhwcfOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv3DNdhwcDhwcfOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv3DNdhwcDhwcfOp>(rewriter, genericOp,
+ dilations, strides);
if (isaConv3DNdhwcDhwcfQOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::Conv3DNdhwcDhwcfQOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::Conv3DNdhwcDhwcfQOp>(rewriter, genericOp,
+ dilations, strides);
if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(rewriter, genericOp, dilations, strides);
+ return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ rewriter, genericOp, dilations, strides);
return failure();
}
-// Converts linalg.generic to named linalg.*conv/pooling* where possible. To improve the search speed, the convolution ops have been segregated based on the rank of iterator types array.
-static FailureOr<LinalgOp> inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
- SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
+// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
+// reduce the number of matchers tried, candidate ops are grouped by the
+// number of loops (i.e. the size of the iterator types array).
+static FailureOr<LinalgOp>
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+ SmallVector<utils::IteratorType> iteratorTypes =
+ genericOp.getIteratorTypesArray();
unsigned totalIterators = iteratorTypes.size();
- switch(totalIterators) {
- case 2:
- return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
- case 4:
- return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
- case 5:
- return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
- case 6:
- return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
- case 7:
- return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
- case 8:
- return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
- case 9:
- return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
+ switch (totalIterators) {
+ case 2:
+ return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
+ case 4:
+ return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
+ case 5:
+ return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
+ case 6:
+ return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
+ case 7:
+ return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
+ case 8:
+ return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
+ case 9:
+ return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
}
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 0d4e8aa5e6382..8ea5e7a10e17e 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -249,28 +249,34 @@ template <typename... OpTypes>
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
Operation *defOp = yieldVal.getDefiningOp();
// if (!defOp) return false;
- if (!(isa_and_present<OpTypes>(defOp) || ...)) return false;
+ if (!(isa_and_present<OpTypes>(defOp) || ...))
+ return false;
- BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
- BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
- if (!lhsArg || !rhsArg) return false;
+ BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+ BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+ if (!lhsArg || !rhsArg)
+ return false;
return true;
}
static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, body);
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+ body);
}
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, body);
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+ body);
}
static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, body);
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+ body);
}
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
- return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, body);
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+ body);
}
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
@@ -288,7 +294,8 @@ static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
// Check if `expr` is either:
// - a dimension expr alone (implying *1), or
// - a multiplication of dimension expr by constant.
-static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_t &constantValue) {
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+ int64_t &constantValue) {
if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
dim = dExpr;
constantValue = 1;
@@ -320,12 +327,13 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_
}
/// Given an array of AffineMaps `indexingMaps` verify the following :-
-/// indexingMaps[0].getResult(iDim) ==
+/// indexingMaps[0].getResult(iDim) ==
/// indexingMaps[1].getResult(fDim) * <CST_1> +
/// indexingMaps[n-1].getResult(oDim) * <CST_2>
/// where, CST_1 and CST_2 can be any constant.
-static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim,
- int64_t& dilation, int64_t& stride) {
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+ unsigned fDim, unsigned oDim,
+ int64_t &dilation, int64_t &stride) {
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
@@ -354,24 +362,37 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un
}
/// Given an array of AffineMaps `indexingMaps` verify the following :-
-/// indexingMaps[aIndex].getResult(aDim) == indexingMaps[bIndex].getResult(bDim)
-static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
- return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim);
-}
-
-/// Give an array of AffineMaps, verify each map to be of the corresponding `expectedSize`.
-static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t> expectedSizes) {
- if (indexingMaps.size() != expectedSizes.size()) return false;
+/// indexingMaps[aIndex].getResult(aDim) ==
+/// indexingMaps[bIndex].getResult(bDim)
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+ unsigned aDim, unsigned bIndex,
+ unsigned bDim) {
+ return getAffineMapDim(indexingMaps, aIndex, aDim) ==
+ getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Verify that `indexingMaps` has one map per entry of `expectedSizes` and
+/// that each map's number of results equals the corresponding entry.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+ ArrayRef<int64_t> expectedSizes) {
+ if (indexingMaps.size() != expectedSizes.size())
+ return false;
- for (auto [indexingMap, expectedSize] : llvm::zip_equal(indexingMaps, expectedSizes)) {
+ for (auto [indexingMap, expectedSize] :
+ llvm::zip_equal(indexingMaps, expectedSizes)) {
auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
- if (affineMap.getNumResults() != expectedSize) return false;
+ if (affineMap.getNumResults() != expectedSize)
+ return false;
}
return true;
}
-/// Utility to update `dilations` and `strides` by copy the corresponding data from `tempDilations` and `tempStrides`.
-static bool updateConvDilationsAndStrides(SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides, ArrayRef<int64_t> tempDilations, ArrayRef<int64_t> tempStrides) {
+/// Utility to update `dilations` and `strides` by copying the corresponding
+/// data from `tempDilations` and `tempStrides`. No-op if either output is null.
+static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides,
+ ArrayRef<int64_t> tempDilations,
+ ArrayRef<int64_t> tempStrides) {
if (!(dilations && strides))
return true;
for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) {
@@ -382,1087 +403,1540 @@ static bool updateConvDilationsAndStrides(SmallVector<int64_t>* dilations, Small
}
bool isaConv1DOp(LinalgOp op) {
- if (isa<linalg::Conv1DOp>(op)) return true;
+ if (isa<linalg::Conv1DOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {1,1,1})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {1, 1, 1}))
+ return false;
+
// #map = affine_map<(d0, d1) -> (d0 + d1)>
// #map1 = affine_map<(d0, d1) -> (d1)>
// #map2 = affine_map<(d0, d1) -> (d0)>
- SmallVector<int64_t> tempDilations(1,1);
- SmallVector<int64_t> tempStrides(1,1);
- return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]);
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
+ return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+ /*oDim=*/0, tempDilations[0],
+ tempStrides[0]);
}
-bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv1DNwcWcfOp>(op)) return true;
+bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv1DNwcWcfOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(1,1);
- SmallVector<int64_t> tempStrides(1,1);
+
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv1DNcwFcwOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv1DNcwFcwOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(1,1);
- SmallVector<int64_t> tempStrides(1,1);
+
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::DepthwiseConv1DNcwCwOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(1,1);
- SmallVector<int64_t> tempStrides(1,1);
+
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
}
// -------------------
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::DepthwiseConv1DNwcWcOp>(op)) return true;
+bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(1,1);
- SmallVector<int64_t> tempStrides(1,1);
+
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 4}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(1,1);
- SmallVector<int64_t> tempStrides(1,1);
+
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
}
bool isaConv2DOp(LinalgOp op) {
- if (isa<linalg::Conv2DOp>(op)) return true;
+ if (isa<linalg::Conv2DOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {2,2,2})) return false;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
+ if (!verifyConvIndexingMapSizes(indexingMaps, {2, 2, 2}))
+ return false;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
- return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1]));
+ return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+ /*oDim=*/0, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+ /*oDim=*/1, tempDilations[1],
+ tempStrides[1]));
}
-bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv2DNhwcFhwcOp>(op)) return true;
+bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNhwcFhwcOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv2DNhwcHwcfOp>(op)) return true;
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5,
+ // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNhwcHwcfOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv2DNchwFchwOp>(op)) return true;
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5,
+ // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNchwFchwOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 4}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv2DNhwcFhwcQOp>(op)) return true;
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 +
+ // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5,
+ // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv2DNchwFchwQOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNchwFchwQOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 +
+ // d6)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv2DNgchwFgchwOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNgchwFgchwOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv2DNgchwGfchwOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6,
+ // d4 + d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1,
+ // d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0,
+ // d1, d2, d3, d4)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+ /*oDim=*/4, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNgchwGfchwOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv2DNhwcHwcfQOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6,
+ // d4 + d7)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2,
+ // d5, d6, d7)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0,
+ // d1, d2, d3, d4)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+ /*oDim=*/4, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 4}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, ..., d6) -> (d0, d1 + d4, d2 + d5, d6)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv2DNgchwGfchwQOp>(op)) return true;
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, ..., d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
+ // #map1 = affine_map<(d0, ..., d7) -> (d3, d4, d5, d6, d7)>
+ // #map2 = affine_map<(d0, ..., d7) -> ()>
+ // #map3 = affine_map<(d0, ..., d7) -> (d0, d1, d2, d3, d4)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv2DNhwgcGfhwcOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, ..., d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
+ // #map1 = affine_map<(d0, ..., d7) -> (d1, d2, d5, d6, d7)>
+ // #map2 = affine_map<(d0, ..., d7) -> ()>
+ // #map3 = affine_map<(d0, ..., d7) -> (d0, d1, d2, d3, d4)>
+ // #map2 indexes both rank-0 operands (indexing maps 2 and 3).
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+ /*oDim=*/4, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::DepthwiseConv2DNchwChwOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, ..., d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
+ // #map1 = affine_map<(d0, ..., d7) -> (d3, d4, d5, d6, d7)>
+ // #map2 = affine_map<(d0, ..., d7) -> (d0, d1, d2, d3, d4)>
+ // where (d0, ..., d7) abbreviates (d0, d1, d2, d3, d4, d5, d6, d7).
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], tempStrides[1]));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[1],
+ tempStrides[1]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op)) return true;
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, ..., d6) -> (d0, d1 + d5, d2 + d6, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 4, 0, 0, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, ..., d6) -> (d0, d1 + d5, d2 + d6, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,0,0,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 0, 0, 4}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
}
bool isaConv3DOp(LinalgOp op) {
- if (isa<linalg::Conv3DOp>(op)) return true;
+ if (isa<linalg::Conv3DOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
-
- SmallVector<int64_t> tempDilations(3,1);
- SmallVector<int64_t> tempStrides(3,1);
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 3, 3}))
+ return false;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
- return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1, tempDilations[1], tempStrides[1]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[2], tempStrides[2]));
-}
-
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv3DNcdhwFcdhwOp>(op)) return true;
+ return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
+ /*oDim=*/0, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
+ /*oDim=*/1, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, tempDilations[2],
+ tempStrides[2]));
+}
+
+bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv3DNcdhwFcdhwOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(3,1);
- SmallVector<int64_t> tempStrides(3,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4, tempDilations[2], tempStrides[2]) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv3DNdhwcDhwcfOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+ // #map = affine_map<(d0, ..., d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
+ // #map1 = affine_map<(d0, ..., d8) -> (d1, d5, d6, d7, d8)>
+ // #map2 = affine_map<(d0, ..., d8) -> (d0, d1, d2, d3, d4)>
+ // where (d0, ..., d8) abbreviates (d0, d1, d2, d3, d4, d5, d6, d7, d8).
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3,
+ /*oDim=*/3, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4,
+ /*oDim=*/4, tempDilations[2],
+ tempStrides[2]) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv3DNdhwcDhwcfOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(3,1);
- SmallVector<int64_t> tempStrides(3,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+ // #map = affine_map<(d0, ..., d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+ // #map1 = affine_map<(d0, ..., d8) -> (d5, d6, d7, d8, d4)>
+ // #map2 = affine_map<(d0, ..., d8) -> (d0, d1, d2, d3, d4)>
+ // where (d0, ..., d8) abbreviates (d0, d1, d2, d3, d4, d5, d6, d7, d8).
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[2],
+ tempStrides[2]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 0, 0, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
-
- SmallVector<int64_t> tempDilations(3,1);
- SmallVector<int64_t> tempStrides(3,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()>
- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+ // #map = affine_map<(d0, ..., d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+ // #map1 = affine_map<(d0, ..., d8) -> (d5, d6, d7, d8, d4)>
+ // #map2 = affine_map<(d0, ..., d8) -> ()>
+ // #map3 = affine_map<(d0, ..., d8) -> (d0, d1, d2, d3, d4)>
+ // #map2 indexes both rank-0 operands (indexing maps 2 and 3).
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[2],
+ tempStrides[2]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,6})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(3,1);
- SmallVector<int64_t> tempStrides(3,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
- matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+  // Input, filter and output indexing maps over (d0, ..., d8):
+  //   #map  = affine_map<(d0..d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+  //   #map1 = affine_map<(d0..d8) -> (d5, d6, d7, d8, d4)>
+  //   #map2 = affine_map<(d0..d8) -> (d0, d1, d2, d3, d8, d4)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[2],
+ tempStrides[2]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 4, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(3,1);
- SmallVector<int64_t> tempStrides(3,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4, tempDilations[2], tempStrides[2]));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+  // Input, filter and output indexing maps over (d0, ..., d7):
+  //   #map  = affine_map<(d0..d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
+  //   #map1 = affine_map<(d0..d7) -> (d7, d4, d5, d6)>
+  //   #map2 = affine_map<(d0..d7) -> (d0, d7, d1, d2, d3)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3,
+ /*oDim=*/4, tempDilations[2],
+ tempStrides[2]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 4, 5}))
+ return false;
+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(3,1);
- SmallVector<int64_t> tempStrides(3,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNchwMaxOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+  // Input, filter and output indexing maps over (d0, ..., d7):
+  //   #map  = affine_map<(d0..d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
+  //   #map1 = affine_map<(d0..d7) -> (d4, d5, d6, d7)>
+  //   #map2 = affine_map<(d0..d7) -> (d0, d1, d2, d3, d7)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[2],
+ tempStrides[2]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNchwMaxOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
- bodyMatcherForMaxSignedPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNchwSumOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1,
+ /*oDim=*/3, tempDilations[1],
+ tempStrides[1]) &&
+ bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNchwSumOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3, tempDilations[1], tempStrides[1]) &&
- bodyMatcherForSumPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNhwcMaxOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1,
+ /*oDim=*/3, tempDilations[1],
+ tempStrides[1]) &&
+ bodyMatcherForSumPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMaxOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- bodyMatcherForMaxSignedPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNhwcMinOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMinOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- bodyMatcherForMinSignedPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNhwcSumOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMinSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcSumOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- bodyMatcherForSumPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForSumPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNhwcMinUnsignedOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {4,2,4})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(2,1);
- SmallVector<int64_t> tempStrides(2,1);
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
- bodyMatcherForMinUnsignedPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNcwMaxOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMinUnsignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNcwMaxOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(1,1);
- SmallVector<int64_t> tempStrides(1,1);
+
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
- bodyMatcherForMaxSignedPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNcwSumOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNcwSumOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(1,1);
- SmallVector<int64_t> tempStrides(1,1);
+
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2, tempDilations[0], tempStrides[0]) &&
- bodyMatcherForSumPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNwcMaxOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ bodyMatcherForSumPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNwcMaxOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- SmallVector<int64_t> tempDilations(1,1);
- SmallVector<int64_t> tempStrides(1,1);
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
- bodyMatcherForMaxSignedPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNwcMinOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNwcMinOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- SmallVector<int64_t> tempDilations(1,1);
- SmallVector<int64_t> tempStrides(1,1);
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
- bodyMatcherForMinSignedPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNwcSumOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ bodyMatcherForMinSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNwcSumOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {3,1,3})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 1, 3}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
- SmallVector<int64_t> tempDilations(1,1);
- SmallVector<int64_t> tempStrides(1,1);
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3)>
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
- bodyMatcherForSumPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNdhwcMaxOp>(op)) return true;
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ bodyMatcherForSumPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNdhwcMaxOp>(op))
+ return true;
- if (!isaConvolutionOpInterface(op)) return false;
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(3,1);
- SmallVector<int64_t> tempStrides(3,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
- bodyMatcherForMaxSignedPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNdhwcMinOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+ // Indexing maps, all over dims (d0, d1, d2, d3, d4, d5, d6, d7):
+ // #map = affine_map<(d0, ..., d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+ // #map1 = affine_map<(d0, ..., d7) -> (d5, d6, d7)>
+ // #map2 = affine_map<(d0, ..., d7) -> (d0, d1, d2, d3, d4)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[2],
+ tempStrides[2]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNdhwcMinOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(3,1);
- SmallVector<int64_t> tempStrides(3,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
- bodyMatcherForMinSignedPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
-}
-
-bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides) {
- if (isa<linalg::PoolingNdhwcSumOp>(op)) return true;
-
- if (!isaConvolutionOpInterface(op)) return false;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+ // Indexing maps, all over dims (d0, d1, d2, d3, d4, d5, d6, d7):
+ // #map = affine_map<(d0, ..., d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+ // #map1 = affine_map<(d0, ..., d7) -> (d5, d6, d7)>
+ // #map2 = affine_map<(d0, ..., d7) -> (d0, d1, d2, d3, d4)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[2],
+ tempStrides[2]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ bodyMatcherForMinSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNdhwcSumOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
ArrayAttr indexingMaps = op.getIndexingMaps();
- if (!verifyConvIndexingMapSizes(indexingMaps, {5,3,5})) return false;
-
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 3, 5}))
+ return false;
+
Block *body = op.getBlock();
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
Value yieldVal = yieldOp.getOperand(0);
unsigned iIndex = 0, oIndex = 2;
-
- SmallVector<int64_t> tempDilations(3,1);
- SmallVector<int64_t> tempStrides(3,1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
- bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) &&
- matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) &&
- matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
- bodyMatcherForSumPoolOps(yieldVal, body));
- return returnVal && updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides);
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+ // Indexing maps, all over dims (d0, d1, d2, d3, d4, d5, d6, d7):
+ // #map = affine_map<(d0, ..., d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
+ // #map1 = affine_map<(d0, ..., d7) -> (d5, d6, d7)>
+ // #map2 = affine_map<(d0, ..., d7) -> (d0, d1, d2, d3, d4)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[2],
+ tempStrides[2]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ bodyMatcherForSumPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
}
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
>From 3852dc4ffeac76056c4e31ac74397ade5f3dc228 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 15 Oct 2025 03:26:58 -0500
Subject: [PATCH 18/18] Export just a single API
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 116 +----
.../Dialect/Linalg/Transforms/Specialize.cpp | 132 +++--
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 452 ++++++++++++++----
3 files changed, 450 insertions(+), 250 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 0f39098ca9946..771d753a8bddb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -111,123 +111,13 @@ std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
//===----------------------------------------------------------------------===//
-// Convolution matcher utilities
+// Convolution matcher utility
//===----------------------------------------------------------------------===//
-bool isaConv1DOp(LinalgOp op);
-bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DOp(LinalgOp op);
-bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNhwcFhwcQOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNchwFchwQOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNgchwFgchwOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNgchwGfchwOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNhwcHwcfQOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op,
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op,
SmallVector<int64_t> *dilations = nullptr,
SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNgchwGfchwQOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv3DOp(LinalgOp op);
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNdhwcMaxOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNdhwcMinOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
-bool isaPoolingNdhwcSumOp(LinalgOp op,
- SmallVector<int64_t> *dilations = nullptr,
- SmallVector<int64_t> *strides = nullptr);
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index e08705b90e7b0..929904fa2c510 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -268,7 +268,7 @@ static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
- if (isaConv1DOp(genericOp))
+ if (isaConvolutionOpOfType<linalg::Conv1DOp>(genericOp, &dilations, &strides))
return specializeToConvOp<linalg::Conv1DOp>(rewriter, genericOp, dilations,
strides);
return failure();
@@ -278,28 +278,35 @@ static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
- if (isaDepthwiseConv1DNcwCwOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv1DNcwCwOp>(
rewriter, genericOp, dilations, strides);
- if (isaDepthwiseConv1DNwcWcOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
rewriter, genericOp, dilations, strides);
- if (isaConv2DOp(genericOp))
+ if (isaConvolutionOpOfType<linalg::Conv2DOp>(genericOp, &dilations, &strides))
return specializeToConvOp<linalg::Conv2DOp>(rewriter, genericOp, dilations,
strides);
- if (isaPoolingNcwMaxOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNcwMaxOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNcwMaxOp>(rewriter, genericOp,
dilations, strides);
- if (isaPoolingNcwSumOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNcwSumOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNcwSumOp>(rewriter, genericOp,
dilations, strides);
- if (isaPoolingNwcMaxOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNwcMaxOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNwcMaxOp>(rewriter, genericOp,
dilations, strides);
- if (isaPoolingNwcMinOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNwcMinOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNwcMinOp>(rewriter, genericOp,
dilations, strides);
- if (isaPoolingNwcSumOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNwcSumOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNwcSumOp>(rewriter, genericOp,
dilations, strides);
return failure();
@@ -309,13 +316,16 @@ static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
- if (isaDepthwiseConv1DNwcWcmOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv1DNwcWcmOp>(
rewriter, genericOp, dilations, strides);
- if (isaConv1DNwcWcfOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv1DNwcWcfOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv1DNcwFcwOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv1DNcwFcwOp>(rewriter, genericOp,
dilations, strides);
return failure();
@@ -325,37 +335,47 @@ static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
- if (isaDepthwiseConv2DNchwChwOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
rewriter, genericOp, dilations, strides);
- if (isaDepthwiseConv2DNhwcHwcOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcOp>(
rewriter, genericOp, dilations, strides);
- if (isaDepthwiseConv2DNhwcHwcQOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcQOp>(
rewriter, genericOp, dilations, strides);
- if (isaConv3DOp(genericOp))
+ if (isaConvolutionOpOfType<linalg::Conv3DOp>(genericOp, &dilations, &strides))
return specializeToConvOp<linalg::Conv3DOp>(rewriter, genericOp, dilations,
strides);
- if (isaPoolingNchwMaxOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNchwMaxOp>(rewriter, genericOp,
dilations, strides);
- if (isaPoolingNchwSumOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNchwSumOp>(rewriter, genericOp,
dilations, strides);
- if (isaPoolingNhwcMaxOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
dilations, strides);
- if (isaPoolingNhwcMinOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
dilations, strides);
- if (isaPoolingNhwcSumOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
dilations, strides);
- if (isaPoolingNhwcMaxUnsignedOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
rewriter, genericOp, dilations, strides);
- if (isaPoolingNhwcMinUnsignedOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
rewriter, genericOp, dilations, strides);
return failure();
@@ -365,28 +385,36 @@ static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
- if (isaConv2DNhwcFhwcOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv2DNhwcFhwcOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv2DNhwcHwcfOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv2DNhwcHwcfOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv2DNchwFchwOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv2DNchwFchwOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv2DNhwcFhwcQOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv2DNhwcFhwcQOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv2DNchwFchwQOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv2DNchwFchwQOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv2DNhwcHwcfQOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv2DNhwcHwcfQOp>(rewriter, genericOp,
dilations, strides);
- if (isaDepthwiseConv2DNhwcHwcmOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmOp>(
rewriter, genericOp, dilations, strides);
- if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(
rewriter, genericOp, dilations, strides);
return failure();
@@ -396,34 +424,44 @@ static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
- if (isaConv2DNgchwFgchwOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv2DNgchwFgchwOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv2DNgchwGfchwOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv2DNgchwGfchwOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv2DNgchwGfchwQOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv2DNgchwGfchwQOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv2DNhwgcGfhwcOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv2DNhwgcGfhwcOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv2DNhwgcGfhwcQOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv2DNhwgcGfhwcQOp>(rewriter, genericOp,
dilations, strides);
- if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(
rewriter, genericOp, dilations, strides);
- if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(
rewriter, genericOp, dilations, strides);
- if (isaPoolingNdhwcMaxOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNdhwcMaxOp>(rewriter, genericOp,
dilations, strides);
- if (isaPoolingNdhwcMinOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNdhwcMinOp>(rewriter, genericOp,
dilations, strides);
- if (isaPoolingNdhwcSumOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::PoolingNdhwcSumOp>(rewriter, genericOp,
dilations, strides);
return failure();
@@ -433,16 +471,20 @@ static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
- if (isaConv3DNcdhwFcdhwOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv3DNcdhwFcdhwOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv3DNdhwcDhwcfOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv3DNdhwcDhwcfOp>(rewriter, genericOp,
dilations, strides);
- if (isaConv3DNdhwcDhwcfQOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(genericOp, &dilations,
+ &strides))
return specializeToConvOp<linalg::Conv3DNdhwcDhwcfQOp>(rewriter, genericOp,
dilations, strides);
- if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp, &dilations, &strides))
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
rewriter, genericOp, dilations, strides);
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 8ea5e7a10e17e..13235d99887a7 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,15 +240,14 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
return iteratorType == utils::IteratorType::reduction;
}
-// -------------------------------
-// ---------- CONV ---------------
-// -------------------------------
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
/// Utility to match block body for linalg.pool* ops.
template <typename... OpTypes>
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
Operation *defOp = yieldVal.getDefiningOp();
- // if (!defOp) return false;
if (!(isa_and_present<OpTypes>(defOp) || ...))
return false;
@@ -402,7 +401,7 @@ static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
return true;
}
-bool isaConv1DOp(LinalgOp op) {
+static bool isaConv1DOp(LinalgOp op) {
if (isa<linalg::Conv1DOp>(op))
return true;
@@ -423,8 +422,8 @@ bool isaConv1DOp(LinalgOp op) {
tempStrides[0]);
}
-bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv1DNwcWcfOp>(op))
return true;
@@ -453,8 +452,8 @@ bool isaConv1DNwcWcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv1DNcwFcwOp>(op))
return true;
@@ -483,8 +482,9 @@ bool isaConv1DNcwFcwOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv1DNcwCwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
return true;
@@ -514,8 +514,9 @@ bool isaDepthwiseConv1DNcwCwOp(LinalgOp op, SmallVector<int64_t> *dilations,
}
// -------------------
-bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
return true;
@@ -544,8 +545,9 @@ bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
return true;
@@ -575,7 +577,7 @@ bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv2DOp(LinalgOp op) {
+static bool isaConv2DOp(LinalgOp op) {
if (isa<linalg::Conv2DOp>(op))
return true;
@@ -599,8 +601,8 @@ bool isaConv2DOp(LinalgOp op) {
tempStrides[1]));
}
-bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNhwcFhwcOp>(op))
return true;
@@ -632,8 +634,8 @@ bool isaConv2DNhwcFhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNhwcHwcfOp>(op))
return true;
@@ -665,8 +667,8 @@ bool isaConv2DNhwcHwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNchwFchwOp>(op))
return true;
@@ -698,8 +700,8 @@ bool isaConv2DNchwFchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
return true;
@@ -732,8 +734,8 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNchwFchwQOp>(op))
return true;
@@ -766,8 +768,8 @@ bool isaConv2DNchwFchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNgchwFgchwOp>(op))
return true;
@@ -802,8 +804,8 @@ bool isaConv2DNgchwFgchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNgchwGfchwOp>(op))
return true;
@@ -838,8 +840,8 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
return true;
@@ -872,8 +874,8 @@ bool isaConv2DNhwcHwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op))
return true;
@@ -908,8 +910,8 @@ bool isaConv2DNhwgcGfhwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
return true;
@@ -945,8 +947,8 @@ bool isaConv2DNgchwGfchwQOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
return true;
@@ -981,8 +983,9 @@ bool isaConv2DNhwgcGfhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
return true;
@@ -1014,8 +1017,9 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op))
return true;
@@ -1047,8 +1051,9 @@ bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
return true;
@@ -1081,8 +1086,9 @@ bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op))
return true;
@@ -1116,8 +1122,9 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
return true;
@@ -1150,7 +1157,7 @@ bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv3DOp(LinalgOp op) {
+static bool isaConv3DOp(LinalgOp op) {
if (isa<linalg::Conv3DOp>(op))
return true;
@@ -1177,8 +1184,8 @@ bool isaConv3DOp(LinalgOp op) {
tempStrides[2]));
}
-bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv3DNcdhwFcdhwOp>(op))
return true;
@@ -1214,8 +1221,8 @@ bool isaConv3DNcdhwFcdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv3DNdhwcDhwcfOp>(op))
return true;
@@ -1251,8 +1258,8 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op))
return true;
@@ -1289,9 +1296,9 @@ bool isaConv3DNdhwcDhwcfQOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
- SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
return true;
@@ -1328,8 +1335,9 @@ bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
tempDilations, tempStrides);
}
-bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op))
return true;
@@ -1365,8 +1373,9 @@ bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op))
return true;
@@ -1402,8 +1411,8 @@ bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNchwMaxOp>(op))
return true;
@@ -1438,8 +1447,8 @@ bool isaPoolingNchwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNchwSumOp>(op))
return true;
@@ -1474,8 +1483,8 @@ bool isaPoolingNchwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNhwcMaxOp>(op))
return true;
@@ -1510,8 +1519,8 @@ bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNhwcMinOp>(op))
return true;
@@ -1546,8 +1555,8 @@ bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNhwcSumOp>(op))
return true;
@@ -1582,8 +1591,9 @@ bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
return true;
@@ -1618,8 +1628,9 @@ bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
return true;
@@ -1654,8 +1665,8 @@ bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNcwMaxOp>(op))
return true;
@@ -1687,8 +1698,8 @@ bool isaPoolingNcwMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNcwSumOp>(op))
return true;
@@ -1720,8 +1731,8 @@ bool isaPoolingNcwSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNwcMaxOp>(op))
return true;
@@ -1753,8 +1764,8 @@ bool isaPoolingNwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNwcMinOp>(op))
return true;
@@ -1786,8 +1797,8 @@ bool isaPoolingNwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNwcSumOp>(op))
return true;
@@ -1819,8 +1830,8 @@ bool isaPoolingNwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNdhwcMaxOp>(op))
return true;
@@ -1859,8 +1870,8 @@ bool isaPoolingNdhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNdhwcMinOp>(op))
return true;
@@ -1899,8 +1910,8 @@ bool isaPoolingNdhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+static bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNdhwcSumOp>(op))
return true;
@@ -1939,6 +1950,263 @@ bool isaPoolingNdhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
+/// Returns true if `op` is a `ConvOpTy`, or a LinalgOp (presumably a
+/// linalg.generic) that the per-op `isa*Op` matcher for `ConvOpTy` recognizes
+/// as equivalent to it. Dispatch happens at compile time on `ConvOpTy`.
+///
+/// `dilations` and `strides` are forwarded unchanged to the selected matcher.
+/// The matchers for `Conv1DOp`, `Conv2DOp` and `Conv3DOp` take only `op`, so
+/// for those three types both pointers are ignored by this call.
+///
+/// Instantiating this template with a type that has no matcher is a compile
+/// error rather than a silent "no match".
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
+                            SmallVector<int64_t> *strides) {
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp>) {
+    return isaConv1DOp(op);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DNwcWcfOp>) {
+    return isaConv1DNwcWcfOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DNcwFcwOp>) {
+    return isaConv1DNcwFcwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv1DNcwCwOp>) {
+    return isaDepthwiseConv1DNcwCwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv1DNwcWcOp>) {
+    return isaDepthwiseConv1DNwcWcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv1DNwcWcmOp>) {
+    return isaDepthwiseConv1DNwcWcmOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DOp>) {
+    return isaConv2DOp(op);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwcFhwcOp>) {
+    return isaConv2DNhwcFhwcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwcHwcfOp>) {
+    return isaConv2DNhwcHwcfOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNchwFchwOp>) {
+    return isaConv2DNchwFchwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwcFhwcQOp>) {
+    return isaConv2DNhwcFhwcQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNchwFchwQOp>) {
+    return isaConv2DNchwFchwQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNgchwFgchwOp>) {
+    return isaConv2DNgchwFgchwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNgchwGfchwOp>) {
+    return isaConv2DNgchwGfchwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwcHwcfQOp>) {
+    return isaConv2DNhwcHwcfQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwgcGfhwcQOp>) {
+    return isaConv2DNhwgcGfhwcQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNgchwGfchwQOp>) {
+    return isaConv2DNgchwGfchwQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv2DNhwgcGfhwcOp>) {
+    return isaConv2DNhwgcGfhwcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNchwChwOp>) {
+    return isaDepthwiseConv2DNchwChwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNhwcHwcOp>) {
+    return isaDepthwiseConv2DNhwcHwcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNhwcHwcmOp>) {
+    return isaDepthwiseConv2DNhwcHwcmOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNhwcHwcQOp>) {
+    return isaDepthwiseConv2DNhwcHwcQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNhwcHwcmQOp>) {
+    return isaDepthwiseConv2DNhwcHwcmQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    return isaConv3DOp(op);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv3DNcdhwFcdhwOp>) {
+    return isaConv3DNcdhwFcdhwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv3DNdhwcDhwcfOp>) {
+    return isaConv3DNdhwcDhwcfOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::Conv3DNdhwcDhwcfQOp>) {
+    return isaConv3DNdhwcDhwcfQOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv3DNdhwcDhwcmOp>) {
+    return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv3DNcdhwCdhwOp>) {
+    return isaDepthwiseConv3DNcdhwCdhwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv3DNdhwcDhwcOp>) {
+    return isaDepthwiseConv3DNdhwcDhwcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNchwMaxOp>) {
+    return isaPoolingNchwMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNchwSumOp>) {
+    return isaPoolingNchwSumOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMaxOp>) {
+    return isaPoolingNhwcMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMinOp>) {
+    return isaPoolingNhwcMinOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcSumOp>) {
+    return isaPoolingNhwcSumOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::PoolingNhwcMaxUnsignedOp>) {
+    return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::PoolingNhwcMinUnsignedOp>) {
+    return isaPoolingNhwcMinUnsignedOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNcwMaxOp>) {
+    return isaPoolingNcwMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNcwSumOp>) {
+    return isaPoolingNcwSumOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNwcMaxOp>) {
+    return isaPoolingNwcMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNwcMinOp>) {
+    return isaPoolingNwcMinOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNwcSumOp>) {
+    return isaPoolingNwcSumOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNdhwcMaxOp>) {
+    return isaPoolingNdhwcMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNdhwcMinOp>) {
+    return isaPoolingNdhwcMinOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNdhwcSumOp>) {
+    return isaPoolingNdhwcSumOp(op, dilations, strides);
+  } else {
+    // `sizeof(ConvOpTy) == 0` is always false but depends on the template
+    // parameter, so the assertion only fires if this branch is actually
+    // instantiated, i.e. for a ConvOpTy with no matcher above.
+    static_assert(sizeof(ConvOpTy) == 0,
+                  "isaConvolutionOpOfType: unsupported convolution op type");
+    return false;
+  }
+}
+
+template bool
+isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool
+isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool
+isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool
+isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool
+isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNchwMaxOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNchwSumOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool
+isaConvolutionOpOfType<linalg::PoolingNcwMaxOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool
+isaConvolutionOpOfType<linalg::PoolingNcwSumOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool
+isaConvolutionOpOfType<linalg::PoolingNwcMaxOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool
+isaConvolutionOpOfType<linalg::PoolingNwcMinOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool
+isaConvolutionOpOfType<linalg::PoolingNwcSumOp>(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNdhwcMaxOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNdhwcMinOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNdhwcSumOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold,
ValueRange typeDynDims) {
More information about the Mlir-commits
mailing list