[Mlir-commits] [mlir] [Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops (PR #163724)
Abhishek Varma
llvmlistbot at llvm.org
Fri Oct 24 01:00:03 PDT 2025
https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/163724
>From c6aea9193db7ad415f2fcde00e8bdcc3d98cfea4 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 16 Oct 2025 02:54:14 -0500
Subject: [PATCH 1/5] [Linalg] Add basic infra to add matchers for
linalg.*conv*/*pool* ops
-- This commit includes the basic infra/utilities to add matchers for
linalg.*conv*/*pool* ops - such that given a `linalg.generic` op it
identifies which linalg.*conv*/*pool* op it is.
-- It adds a few representative linalg.*conv*/*pool* ops to demo the
matchers' capability and does so as part of `linalg-specialize-generic-ops`
pass.
-- The goal is directed towards addressing the aim of
[[RFC] Op explosion in Linalg](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863)
iteratively for `*conv*/*pooling*` ops.
-- This is part-1 of a series of PRs aimed to add matchers for Convolution ops.
-- For further details, refer to https://github.com/llvm/llvm-project/pull/163374#pullrequestreview-3341048722
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 9 +
.../Dialect/Linalg/Transforms/Specialize.cpp | 144 +++++
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 502 ++++++++++++++++++
.../convolution/roundtrip-convolution.mlir | 112 ++++
4 files changed, 767 insertions(+)
create mode 100644 mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..771d753a8bddb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
+//===----------------------------------------------------------------------===//
+// Convolution matcher utility
+//===----------------------------------------------------------------------===//
+
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d68e358f..35861002e309e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
+/// Utility to replace a `genericOp` with a convolution op of type `ConvOpTy`
+/// using the given `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+ ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+ SmallVector<Value> inputs = genericOp.getDpsInputs();
+ ValueRange outputs = genericOp.getDpsInits();
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+ ? TypeRange(ValueRange(outputs))
+ : TypeRange{};
+ LinalgOp namedOp;
+ if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+ inputs, outputs);
+ } else {
+ Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+ Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+ genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+ }
+ return namedOp;
+}
+
+/// TODO(avarma): Convolution ops with a rank-2 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops with a rank-5 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+ rewriter, genericOp, dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+ rewriter, genericOp, dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops with a rank-7 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops with a rank-8 iterator types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
+// improve the search speed, the convolution ops have been segregated based on
+// the rank of iterator types array.
+static FailureOr<LinalgOp>
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+ SmallVector<utils::IteratorType> iteratorTypes =
+ genericOp.getIteratorTypesArray();
+ unsigned totalIterators = iteratorTypes.size();
+ switch (totalIterators) {
+ case 2:
+ return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
+ case 4:
+ return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
+ case 5:
+ return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
+ case 6:
+ return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
+ case 7:
+ return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
+ case 8:
+ return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
+ case 9:
+ return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
+ }
+ return failure();
+}
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -316,6 +455,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp);
}
+
+ // Convolution - e.g. *conv/pooling*
+ if (isaConvolutionOpInterface(genericOp)) {
+ return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
+ }
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722cf5426..c3c2819652129 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,6 +240,508 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
return iteratorType == utils::IteratorType::reduction;
}
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Utility to match block body for linalg.pool* ops.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+ Operation *defOp = yieldVal.getDefiningOp();
+ if (!(isa_and_present<OpTypes>(defOp) || ...))
+ return false;
+
+ BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+ BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+ if (!lhsArg || !rhsArg)
+ return false;
+ return true;
+}
+
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+}
+
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+ uint32_t mapIndex, uint32_t dimIndex) {
+ auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+ if (dimIndex < affineMap.getNumResults())
+ return affineMap.getResult(dimIndex);
+ return nullptr;
+}
+
+// Check if `expr` is either:
+// - a dimension expr alone (implying *1), or
+// - a multiplication of dimension expr by constant.
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+ int64_t &constantValue) {
+ if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
+ dim = dExpr;
+ constantValue = 1;
+ return true;
+ }
+
+ auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+ if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+ return false;
+
+ AffineExpr lhs = mulExpr.getLHS();
+ AffineExpr rhs = mulExpr.getRHS();
+
+ if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
+ if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
+ dim = dExpr;
+ constantValue = cst.getValue();
+ return true;
+ }
+ }
+ if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
+ if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
+ dim = dExpr;
+ constantValue = cst.getValue();
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps` verify the following :-
+/// indexingMaps[0].getResult(iDim) ==
+/// indexingMaps[1].getResult(fDim) * <CST_1> +
+/// indexingMaps[n-1].getResult(oDim) * <CST_2>
+/// where, CST_1 and CST_2 can be any constant.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+ unsigned fDim, unsigned oDim,
+ int64_t &dilation, int64_t &stride) {
+ unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+ AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+ auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+ if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+ return false;
+
+ AffineExpr dim0, dim1;
+ int64_t c0, c1;
+
+ if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
+ isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
+ // Pattern matched with dims and constants extracted.
+ AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+ AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+ if (dim0 == fExpr && dim1 == oExpr) {
+ dilation = c0;
+ stride = c1;
+ return true;
+ } else if (dim1 == fExpr && dim0 == oExpr) {
+ dilation = c1;
+ stride = c0;
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps` verify the following :-
+/// indexingMaps[aIndex].getResult(aDim) ==
+/// indexingMaps[bIndex].getResult(bDim)
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+ unsigned aDim, unsigned bIndex,
+ unsigned bDim) {
+ return getAffineMapDim(indexingMaps, aIndex, aDim) ==
+ getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Given an array of AffineMaps, verify that each map has the corresponding
+/// size from `expectedSizes`.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+ ArrayRef<int64_t> expectedSizes) {
+ if (indexingMaps.size() != expectedSizes.size())
+ return false;
+
+ for (auto [indexingMap, expectedSize] :
+ llvm::zip_equal(indexingMaps, expectedSizes)) {
+ auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
+ if (affineMap.getNumResults() != expectedSize)
+ return false;
+ }
+ return true;
+}
+
+/// Utility to update `dilations` and `strides` by copying the corresponding
+/// data from `tempDilations` and `tempStrides`.
+static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides,
+ ArrayRef<int64_t> tempDilations,
+ ArrayRef<int64_t> tempStrides) {
+ if (!(dilations && strides))
+ return true;
+ for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) {
+ dilations->push_back(dilation);
+ strides->push_back(stride);
+ }
+ return true;
+}
+
+static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
+ return false;
+
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
+ return false;
+
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[1],
+ tempStrides[1]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
+ return false;
+
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+ // -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+ // -> (d5, d6, d7, d8, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+ // -> (d0, d1, d2, d3, d8, d4)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[2],
+ tempStrides[2]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMaxOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMaxSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMinOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMinSignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcSumOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForSumPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMaxUnsignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+ return false;
+
+ Block *body = op.getBlock();
+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+ Value yieldVal = yieldOp.getOperand(0);
+ unsigned iIndex = 0, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
+ bodyMatcherForMinUnsignedPoolOps(yieldVal, body));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if constexpr (std::is_same_v<ConvOpTy, linalg::DepthwiseConv1DNwcWcOp>) {
+ return isaDepthwiseConv1DNwcWcOp(op, dilations, strides);
+ } else if constexpr (std::is_same_v<ConvOpTy,
+ linalg::DepthwiseConv2DNchwChwOp>) {
+ return isaDepthwiseConv2DNchwChwOp(op, dilations, strides);
+ } else if constexpr (std::is_same_v<ConvOpTy,
+ linalg::DepthwiseConv3DNdhwcDhwcmOp>) {
+ return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides);
+ } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMaxOp>) {
+ return isaPoolingNhwcMaxOp(op, dilations, strides);
+ } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMinOp>) {
+ return isaPoolingNhwcMinOp(op, dilations, strides);
+ } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcSumOp>) {
+ return isaPoolingNhwcSumOp(op, dilations, strides);
+ } else if constexpr (std::is_same_v<ConvOpTy,
+ linalg::PoolingNhwcMaxUnsignedOp>) {
+ return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides);
+ } else if constexpr (std::is_same_v<ConvOpTy,
+ linalg::PoolingNhwcMinUnsignedOp>) {
+ return isaPoolingNhwcMinUnsignedOp(op, dilations, strides);
+ } else {
+ return false;
+ }
+}
+
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides);
+
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold,
ValueRange typeDynDims) {
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
new file mode 100644
index 0000000000000..5a18ca8519be3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -0,0 +1,112 @@
+// The following test examples of linalg convolution named ops lowered to linalg.generic and then
+// lifted back up to named op.
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+func.func @depthwise_conv_1d_nwc_wc(%input: memref<?x?x?xf32>, %filter: memref<?x?xf32>, %output: memref<?x?x?xf32>) {
+ linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>,
+ strides = dense<2> : tensor<1xi64>}
+ ins (%input, %filter: memref<?x?x?xf32>, memref<?x?xf32>)
+ outs (%output: memref<?x?x?xf32>)
+ return
+}
+// CHECK: @depthwise_conv_1d_nwc_wc
+// CHECK: linalg.depthwise_conv_1d_nwc_wc
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_2d_nchw_chw
+// CHECK: linalg.depthwise_conv_2d_nchw_chw
+// CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> {
+ %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>,
+ strides = dense<1> : tensor<3xi64>}
+ ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+ outs (%init: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_3d_ndhwc_dhwcm
+// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_max
+// CHECK: linalg.pooling_nhwc_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_min(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_min
+// CHECK: linalg.pooling_nhwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_sum
+// CHECK: linalg.pooling_nhwc_sum
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+ %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
+ outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+ return %0 : tensor<?x?x?x?xi32>
+}
+// CHECK: @pooling_nhwc_max_unsigned
+// CHECK: linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+ %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
+ outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+ return %0 : tensor<?x?x?x?xi32>
+}
+// CHECK: @pooling_nhwc_min_unsigned
+// CHECK: linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
>From cd1b88a9d7febdd1f933ac22254303f74643f1c2 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 17 Oct 2025 02:46:56 -0500
Subject: [PATCH 2/5] Review comment v1.0
---
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 143 +++++++++---------
.../convolution/roundtrip-convolution.mlir | 16 +-
2 files changed, 87 insertions(+), 72 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index c3c2819652129..4dfec7b361eab 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -418,9 +418,9 @@ static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
SmallVector<int64_t> tempDilations(1, 1);
SmallVector<int64_t> tempStrides(1, 1);
- // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
- // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ // #map = affine_map<(N, W, C, w) -> (N, W + w, C)>
+ // #map1 = affine_map<(N, W, C, w) -> (w, C)>
+ // #map2 = affine_map<(N, W, C, w) -> (N, W, C)>
bool returnVal =
(matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
@@ -449,9 +449,9 @@ static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
SmallVector<int64_t> tempDilations(2, 1);
SmallVector<int64_t> tempStrides(2, 1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+ // #map = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
+ // #map1 = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
+ // #map2 = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
bool returnVal =
(matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
@@ -483,12 +483,12 @@ static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
SmallVector<int64_t> tempDilations(3, 1);
SmallVector<int64_t> tempStrides(3, 1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
- // -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
- // -> (d5, d6, d7, d8, d4)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
- // -> (d0, d1, d2, d3, d8, d4)>
+ // #map = affine_map<(N, D, H, W, CM, d, h, w, C)
+ // -> (N, D + d, H + h, W + w, C)>
+ // #map1 = affine_map<(N, D, H, W, CM, d, h, w, C)
+ // -> (d, h, w, C, CM)>
+ // #map2 = affine_map<(N, D, H, W, CM, d, h, w, C)
+ // -> (N, D, H, W, C, CM)>
bool returnVal =
(matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -526,9 +526,9 @@ static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> tempDilations(2, 1);
SmallVector<int64_t> tempStrides(2, 1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+ // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
+ // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
bool returnVal =
(matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -562,9 +562,9 @@ static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> tempDilations(2, 1);
SmallVector<int64_t> tempStrides(2, 1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+ // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
+ // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
bool returnVal =
(matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -598,9 +598,9 @@ static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
SmallVector<int64_t> tempDilations(2, 1);
SmallVector<int64_t> tempStrides(2, 1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+ // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
+ // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
bool returnVal =
(matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -635,9 +635,9 @@ static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
SmallVector<int64_t> tempDilations(2, 1);
SmallVector<int64_t> tempStrides(2, 1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+ // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
+ // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
bool returnVal =
(matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -672,9 +672,9 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
SmallVector<int64_t> tempDilations(2, 1);
SmallVector<int64_t> tempStrides(2, 1);
- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+ // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)>
+ // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
bool returnVal =
(matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
@@ -689,58 +689,61 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
tempDilations, tempStrides);
}
-template <typename ConvOpTy>
-bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- if constexpr (std::is_same_v<ConvOpTy, linalg::DepthwiseConv1DNwcWcOp>) {
- return isaDepthwiseConv1DNwcWcOp(op, dilations, strides);
- } else if constexpr (std::is_same_v<ConvOpTy,
- linalg::DepthwiseConv2DNchwChwOp>) {
- return isaDepthwiseConv2DNchwChwOp(op, dilations, strides);
- } else if constexpr (std::is_same_v<ConvOpTy,
- linalg::DepthwiseConv3DNdhwcDhwcmOp>) {
- return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides);
- } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMaxOp>) {
- return isaPoolingNhwcMaxOp(op, dilations, strides);
- } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMinOp>) {
- return isaPoolingNhwcMinOp(op, dilations, strides);
- } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcSumOp>) {
- return isaPoolingNhwcSumOp(op, dilations, strides);
- } else if constexpr (std::is_same_v<ConvOpTy,
- linalg::PoolingNhwcMaxUnsignedOp>) {
- return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides);
- } else if constexpr (std::is_same_v<ConvOpTy,
- linalg::PoolingNhwcMinUnsignedOp>) {
- return isaPoolingNhwcMinUnsignedOp(op, dilations, strides);
- } else {
- return false;
- }
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ return isaDepthwiseConv1DNwcWcOp(op, dilations, strides);
}
-template bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ SmallVector<int64_t> *strides) {
+ return isaDepthwiseConv2DNchwChwOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+ SmallVector<int64_t> *strides) {
+ return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+ SmallVector<int64_t> *strides) {
+ return isaPoolingNhwcMaxOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+ SmallVector<int64_t> *strides) {
+ return isaPoolingNhwcMinOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ SmallVector<int64_t> *strides) {
+ return isaPoolingNhwcSumOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides);
-template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+ SmallVector<int64_t> *strides) {
+ return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides);
+}
+
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides);
+ SmallVector<int64_t> *strides) {
+ return isaPoolingNhwcMinUnsignedOp(op, dilations, strides);
+}
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold,
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 5a18ca8519be3..06c9a84049d81 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -99,14 +99,26 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tenso
// -----
-func.func @pooling_nhwc_min_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+func.func @pooling_nhwc_min_unsigned_integer(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
%0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
outs (%init: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
return %0 : tensor<?x?x?x?xi32>
}
-// CHECK: @pooling_nhwc_min_unsigned
+// CHECK: @pooling_nhwc_min_unsigned_integer
// CHECK: linalg.pooling_nhwc_min_unsigned
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
// CHECK-NOT: linalg.generic
+
+func.func @pooling_nhwc_min_unsigned_float(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_min_unsigned_float
+// CHECK: linalg.pooling_nhwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+// CHECK-NOT: linalg.generic
>From 0e9946b47c518867dae394c5221fba7d812c4803 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 22 Oct 2025 03:28:15 -0500
Subject: [PATCH 3/5] Review comment Hanhan v1.0
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 40 -------
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 109 +++++-------------
2 files changed, 32 insertions(+), 117 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 35861002e309e..2bfa21d9062ee 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -264,14 +264,6 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
return namedOp;
}
-/// TODO(avarma): Convolution ops which rank-2 iteratory types array will be
-/// added here incrementally in follow-up PRs.
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
- GenericOp genericOp) {
- return failure();
-}
-
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
GenericOp genericOp) {
@@ -283,14 +275,6 @@ inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
return failure();
}
-/// TODO(avarma): Convolution ops which rank-5 iteratory types array will be
-/// added here incrementally in follow-up PRs.
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
- GenericOp genericOp) {
- return failure();
-}
-
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
GenericOp genericOp) {
@@ -322,22 +306,6 @@ inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
return failure();
}
-/// TODO(avarma): Convolution ops which rank-7 iteratory types array will be
-/// added here incrementally in follow-up PRs.
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
- GenericOp genericOp) {
- return failure();
-}
-
-/// TODO(avarma): Convolution ops which rank-8 iteratory types array will be
-/// added here incrementally in follow-up PRs.
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
- GenericOp genericOp) {
- return failure();
-}
-
static FailureOr<LinalgOp>
inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
GenericOp genericOp) {
@@ -358,18 +326,10 @@ inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
genericOp.getIteratorTypesArray();
unsigned totalIterators = iteratorTypes.size();
switch (totalIterators) {
- case 2:
- return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
case 4:
return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
- case 5:
- return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
case 6:
return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
- case 7:
- return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
- case 8:
- return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
case 9:
return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 4dfec7b361eab..23c7fb68a5534 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -401,9 +401,10 @@ static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
return true;
}
-static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
- SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
return true;
@@ -432,9 +433,10 @@ static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
tempDilations, tempStrides);
}
-static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
- SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
return true;
@@ -466,9 +468,10 @@ static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
tempDilations, tempStrides);
}
-static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
- SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
return true;
@@ -507,8 +510,10 @@ static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
tempDilations, tempStrides);
}
-static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNhwcMaxOp>(op))
return true;
@@ -543,8 +548,10 @@ static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNhwcMinOp>(op))
return true;
@@ -579,8 +586,10 @@ static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNhwcSumOp>(op))
return true;
@@ -615,9 +624,10 @@ static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector<int64_t> *dilations,
tempDilations, tempStrides);
}
-static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
- SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
return true;
@@ -652,9 +662,10 @@ static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op,
tempDilations, tempStrides);
}
-static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
- SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+ LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
return true;
@@ -689,62 +700,6 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op,
tempDilations, tempStrides);
}
-template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- return isaDepthwiseConv1DNwcWcOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- return isaDepthwiseConv2DNchwChwOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- return isaPoolingNhwcMaxOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- return isaPoolingNhwcMinOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- return isaPoolingNhwcSumOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides);
-}
-
-template <>
-bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
- LinalgOp op, SmallVector<int64_t> *dilations,
- SmallVector<int64_t> *strides) {
- return isaPoolingNhwcMinUnsignedOp(op, dilations, strides);
-}
-
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
Value source, Value pad, bool nofold,
ValueRange typeDynDims) {
>From d44cc34ce67daccce72d930f6fea0982ce02a273 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 23 Oct 2025 06:04:50 -0500
Subject: [PATCH 4/5] Review comment Andrzej v2.0
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 54 ++++---------------
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 6 +++
2 files changed, 17 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 2bfa21d9062ee..ce3df6a485f92 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -264,25 +264,26 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
return namedOp;
}
+// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
+// improve the search speed, the convolution ops have been segregated based on
+// the rank of iterator types array.
static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
- GenericOp genericOp) {
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
+ // Depthwise Convolution ops.
if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
rewriter, genericOp, dilations, strides);
- return failure();
-}
-
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
- GenericOp genericOp) {
- SmallVector<int64_t> dilations, strides;
if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
rewriter, genericOp, dilations, strides);
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ rewriter, genericOp, dilations, strides);
+ // Pooling ops.
if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
&strides))
return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
@@ -306,36 +307,6 @@ inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
return failure();
}
-static FailureOr<LinalgOp>
-inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
- GenericOp genericOp) {
- SmallVector<int64_t> dilations, strides;
- if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
- genericOp, &dilations, &strides))
- return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
- rewriter, genericOp, dilations, strides);
- return failure();
-}
-
-// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
-// improve the search speed, the convolution ops have been segregated based on
-// the rank of iterator types array.
-static FailureOr<LinalgOp>
-inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
- SmallVector<utils::IteratorType> iteratorTypes =
- genericOp.getIteratorTypesArray();
- unsigned totalIterators = iteratorTypes.size();
- switch (totalIterators) {
- case 4:
- return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
- case 6:
- return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
- case 9:
- return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
- }
- return failure();
-}
-
} // namespace
//===----------------------------------------------------------------------===//
@@ -417,10 +388,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
}
// Convolution - e.g. *conv/pooling*
- if (isaConvolutionOpInterface(genericOp)) {
- return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
- }
- return failure();
+ return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
}
namespace {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 23c7fb68a5534..cd518fc38819e 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -263,6 +263,9 @@ static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
body);
}
+// max_unsigned ops should not allow float data type.
+// TODO: Retire OPDSL logic. Refer to :
+// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
body);
@@ -273,6 +276,9 @@ static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
body);
}
+// min_unsigned ops should not allow float data type.
+// TODO: Retire OPDSL logic. Refer to :
+// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
body);
>From 7b47d9e56db22366604e8608d099002cba5e9fd6 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Fri, 24 Oct 2025 02:58:28 -0500
Subject: [PATCH 5/5] Review comment Andrzej v3.0
---
.../Dialect/Linalg/Transforms/Specialize.cpp | 13 +++++--
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 39 ++++++++++---------
2 files changed, 31 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index ce3df6a485f92..c68f7bd88c1ae 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -267,10 +267,12 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
// improve the search speed, the convolution ops have been segregated based on
// the rank of iterator types array.
-static FailureOr<LinalgOp>
-inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
+ GenericOp genericOp) {
SmallVector<int64_t> dilations, strides;
+ // -----------------------------
// Depthwise Convolution ops.
+ //------------------------------
if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
@@ -283,7 +285,9 @@ inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
genericOp, &dilations, &strides))
return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
rewriter, genericOp, dilations, strides);
+ // -----------------------------
// Pooling ops.
+ //------------------------------
if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
&strides))
return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
@@ -388,7 +392,10 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
}
// Convolution - e.g. *conv/pooling*
- return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
+ if (isaConvolutionOpInterface(genericOp)) {
+ return specializeLinalgConvolutions(rewriter, genericOp);
+ }
+ return failure();
}
namespace {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index cd518fc38819e..c5c9e4b2f8387 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -265,7 +265,7 @@ static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
// max_unsigned ops should not allow float data type.
// TODO: Retire OPDSL logic. Refer to :
-// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337
+// https://github.com/llvm/llvm-project/issues/164800
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
body);
@@ -278,7 +278,7 @@ static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
// min_unsigned ops should not allow float data type.
// TODO: Retire OPDSL logic. Refer to :
-// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337
+// https://github.com/llvm/llvm-project/issues/164800
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
body);
@@ -407,6 +407,9 @@ static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
return true;
}
+// ---------------------------------------------
+// Matchers for specific convolution operation.
+//----------------------------------------------
template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
LinalgOp op, SmallVector<int64_t> *dilations,
@@ -414,8 +417,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
return true;
- if (!isaConvolutionOpInterface(op))
- return false;
+ assert(isaConvolutionOpInterface(op) &&
+ "expected linalgOp to implement ConvolutionOpInterface");
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
@@ -446,8 +449,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
return true;
- if (!isaConvolutionOpInterface(op))
- return false;
+ assert(isaConvolutionOpInterface(op) &&
+ "expected linalgOp to implement ConvolutionOpInterface");
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
@@ -481,8 +484,8 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
return true;
- if (!isaConvolutionOpInterface(op))
- return false;
+ assert(isaConvolutionOpInterface(op) &&
+ "expected linalgOp to implement ConvolutionOpInterface");
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
@@ -523,8 +526,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
if (isa<linalg::PoolingNhwcMaxOp>(op))
return true;
- if (!isaConvolutionOpInterface(op))
- return false;
+ assert(isaConvolutionOpInterface(op) &&
+ "expected linalgOp to implement ConvolutionOpInterface");
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
@@ -561,8 +564,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
if (isa<linalg::PoolingNhwcMinOp>(op))
return true;
- if (!isaConvolutionOpInterface(op))
- return false;
+ assert(isaConvolutionOpInterface(op) &&
+ "expected linalgOp to implement ConvolutionOpInterface");
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
@@ -599,8 +602,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
if (isa<linalg::PoolingNhwcSumOp>(op))
return true;
- if (!isaConvolutionOpInterface(op))
- return false;
+ assert(isaConvolutionOpInterface(op) &&
+ "expected linalgOp to implement ConvolutionOpInterface");
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
@@ -637,8 +640,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
return true;
- if (!isaConvolutionOpInterface(op))
- return false;
+ assert(isaConvolutionOpInterface(op) &&
+ "expected linalgOp to implement ConvolutionOpInterface");
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
@@ -675,8 +678,8 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
return true;
- if (!isaConvolutionOpInterface(op))
- return false;
+ assert(isaConvolutionOpInterface(op) &&
+ "expected linalgOp to implement ConvolutionOpInterface");
ArrayAttr indexingMaps = op.getIndexingMaps();
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
More information about the Mlir-commits
mailing list