[Mlir-commits] [mlir] [Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops (PR #163724)

Abhishek Varma llvmlistbot at llvm.org
Thu Oct 23 04:14:38 PDT 2025


================
@@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
 }
 
+/// Replaces `genericOp` with a named convolution/pooling op of type
+/// `ConvOpTy` and returns the newly created op.
+///
+/// `strides` and `dilations` are attached to the new op as i64 tensor
+/// attributes, except for `linalg::Conv1DOp`, `linalg::Conv2DOp` and
+/// `linalg::Conv3DOp`, whose builders are called without them (both arrays
+/// are ignored for those three ops).
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+  SmallVector<Value> inputs = genericOp.getDpsInputs();
+  ValueRange outputs = genericOp.getDpsInits();
+  // NOTE(review): `indexingMaps` is not read anywhere below.
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  // Result types are taken from the init operands only when the generic has
+  // pure tensor semantics; otherwise the named op is built with no results.
+  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+                                      ? TypeRange(ValueRange(outputs))
+                                      : TypeRange{};
+  LinalgOp namedOp;
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    // These ops are built from inputs/outputs only; no stride/dilation attrs.
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+                                                    inputs, outputs);
+  } else {
+    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  }
+  return namedOp;
+}
+
+/// TODO(avarma): Matchers for convolution ops with a rank-2 iterator types
+/// array will be added here incrementally in follow-up PRs. Until then this
+/// always returns failure().
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+/// Tries the convolution matchers for generics with a rank-4 iterator types
+/// array (currently only `linalg::DepthwiseConv1DNwcWcOp`). On a match,
+/// replaces `genericOp` with that named op and returns it; otherwise returns
+/// failure().
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  // Out-params for the matcher; forwarded to the named op on a match.
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Matchers for convolution ops with a rank-5 iterator types
+/// array will be added here incrementally in follow-up PRs. Until then this
+/// always returns failure().
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+/// Tries, in order, the convolution/pooling matchers for generics with a
+/// rank-6 iterator types array: `DepthwiseConv2DNchwChwOp`, then the NHWC
+/// max/min/sum pooling ops, then the unsigned max/min pooling ops. The first
+/// match replaces `genericOp` with that named op and is returned; if none
+/// matches, returns failure().
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  // NOTE(review): these out-params are shared by every matcher call below.
+  // This assumes isaConvolutionOpOfType overwrites (rather than appends to)
+  // them, so a failed earlier match cannot leave stale values -- confirm.
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Matchers for convolution ops with a rank-7 iterator types
+/// array will be added here incrementally in follow-up PRs. Until then this
+/// always returns failure().
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+/// TODO(avarma): Matchers for convolution ops with a rank-8 iterator types
+/// array will be added here incrementally in follow-up PRs. Until then this
+/// always returns failure().
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+/// Tries the convolution matchers for generics with a rank-9 iterator types
+/// array (currently only `linalg::DepthwiseConv3DNdhwcDhwcmOp`). On a match,
+/// replaces `genericOp` with that named op and returns it; otherwise returns
+/// failure().
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  // Out-params for the matcher; forwarded to the named op on a match.
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
+// improve the search speed, the convolution ops have been segregated based on
+// the rank of iterator types array.
+static FailureOr<LinalgOp>
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<utils::IteratorType> iteratorTypes =
+      genericOp.getIteratorTypesArray();
+  unsigned totalIterators = iteratorTypes.size();
+  switch (totalIterators) {
+  case 2:
+    return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
+  case 4:
+    return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
+  case 5:
+    return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
+  case 6:
+    return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
+  case 7:
+    return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
+  case 8:
+    return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
+  case 9:
+    return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
----------------
Abhishek-Varma wrote:

I've removed the `inferAndSpecializeBasedOnRankNConvIteratorTypes` helpers (N = 2, 4, 5, ...). In the latest push to the patch, the convolution type is now checked for and formed within a single function, `inferAndSpecializeToConvolutionOp`.

https://github.com/llvm/llvm-project/pull/163724


More information about the Mlir-commits mailing list