[Mlir-commits] [mlir] [Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops (PR #163724)
Abhishek Varma
llvmlistbot at llvm.org
Thu Oct 23 04:14:38 PDT 2025
================
@@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
+/// Utility to replace a `genericOp` with a convolution op of type `ConvOpTy`
+/// with `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+ ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+ SmallVector<Value> inputs = genericOp.getDpsInputs();
+ ValueRange outputs = genericOp.getDpsInits();
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
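+  // With tensor semantics the named op returns the types of the inits; with
+  // buffer (memref) semantics it has no results.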
+ SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+ ? TypeRange(ValueRange(outputs))
+ : TypeRange{};
+ LinalgOp namedOp;
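+  // The Conv1D/Conv2D/Conv3D named ops take neither a strides nor a
+  // dilations attribute; every other supported conv/pooling op takes both.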
+ if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+ inputs, outputs);
+ } else {
+ Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+ Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+ genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+ }
+ return namedOp;
+}
+
+/// TODO(avarma): Convolution ops whose iterator types array has rank 2 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator types array has rank 5 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+ rewriter, genericOp, dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+ rewriter, genericOp, dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator types array has rank 7 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator types array has rank 8 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+// Converts a linalg.generic to a named linalg.*conv*/*pool* op where
+// possible. To improve search speed, the convolution ops are segregated
+// based on the rank of their iterator types array.
+static FailureOr<LinalgOp>
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+ SmallVector<utils::IteratorType> iteratorTypes =
+ genericOp.getIteratorTypesArray();
+ unsigned totalIterators = iteratorTypes.size();
+ switch (totalIterators) {
+ case 2:
+ return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
+ case 4:
+ return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
+ case 5:
+ return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
+ case 6:
+ return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
+ case 7:
+ return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
+ case 8:
+ return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
+ case 9:
+ return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
----------------
Abhishek-Varma wrote:
In the latest push to the patch I've removed the `inferAndSpecializeBasedOnRankNConvIteratorTypes` helpers (N = 2, 4, 5, ...) and now check for and form the convolution op type within a single function, `inferAndSpecializeToConvolutionOp`.
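For reference, a rough sketch of what the consolidated entry point might look like, assuming the matchers and `specializeToConvOp` stay as in this revision (the exact code in the new push may differ):

```cpp
// Converts a linalg.generic to a named linalg.*conv*/*pool* op where
// possible, trying each supported matcher in turn instead of dispatching on
// the rank of the iterator types array.
static FailureOr<LinalgOp>
inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
  SmallVector<int64_t> dilations, strides;
  if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
          genericOp, &dilations, &strides))
    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
        rewriter, genericOp, dilations, strides);
  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
          genericOp, &dilations, &strides))
    return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
        rewriter, genericOp, dilations, strides);
  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
                                                       &strides))
    return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
                                                        dilations, strides);
  // ... remaining pooling and depthwise-conv matchers, as in the diff above.
  return failure();
}
```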
https://github.com/llvm/llvm-project/pull/163724