[Mlir-commits] [mlir] [Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops (PR #163724)
Abhishek Varma
llvmlistbot at llvm.org
Sun Nov 9 23:26:16 PST 2025
================
@@ -237,6 +237,78 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
+/// Replaces `genericOp` with a newly built named convolution/pooling op of
+/// type `ConvOpTy` and returns that op.
+///
+/// The named op takes the DPS inputs and inits of `genericOp` as operands.
+/// When `genericOp` has pure tensor semantics, the result types are the types
+/// of its inits; otherwise the named op is created with no result types.
+/// `linalg::Conv1DOp`, `linalg::Conv2DOp` and `linalg::Conv3DOp` are built
+/// without stride/dilation attributes, so `dilations` and `strides` are
+/// ignored for them. Every other `ConvOpTy` receives both as i64 tensor
+/// attributes.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+  SmallVector<Value> inputs = genericOp.getDpsInputs();
+  ValueRange outputs = genericOp.getDpsInits();
+  // `outputs` is already a ValueRange, so it converts to a TypeRange directly.
+  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+                                      ? TypeRange(outputs)
+                                      : TypeRange{};
+  // Kept as a LinalgOp temporary: returning the concrete ConvOpTy directly
+  // would require two implicit conversions to reach FailureOr<LinalgOp>.
+  LinalgOp namedOp;
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+                                                    inputs, outputs);
+  } else {
+    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  }
+  return namedOp;
+}
+
+/// Runs the `ConvOpTy` matcher on `genericOp` and, if it accepts, rewrites
+/// `genericOp` into `ConvOpTy`.
+///
+/// Returns true when the matcher accepted, and then stores the rewrite outcome
+/// in `result` (which may itself be a failure). Returns false when the matcher
+/// rejected, and then leaves `result` untouched. `dilations` and `strides` are
+/// handed straight to the matcher and then to the rewrite.
+template <typename ConvOpTy>
+static bool trySpecializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+                                  SmallVector<int64_t> &dilations,
+                                  SmallVector<int64_t> &strides,
+                                  FailureOr<LinalgOp> &result) {
+  if (!isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides))
+    return false;
+  result =
+      specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, strides);
+  return true;
+}
+
+/// Tries each candidate in `ConvOpTys`, left to right, and rewrites
+/// `genericOp` into the first one whose matcher accepts it.
+///
+/// Returns the outcome of that rewrite, or failure if no candidate matched.
+/// One pair of `dilations`/`strides` vectors is shared by every attempt.
+template <typename... ConvOpTys>
+static FailureOr<LinalgOp>
+specializeToFirstMatchingConvOp(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  FailureOr<LinalgOp> result = failure();
+  // `||` short-circuits, so the fold stops at the first accepting matcher.
+  (void)(trySpecializeToConvOp<ConvOpTys>(rewriter, genericOp, dilations,
+                                          strides, result) ||
+         ...);
+  return result;
+}
+
+/// Converts a `linalg.generic` into a named depthwise-convolution or pooling
+/// op when one of the supported matchers accepts it. Returns failure
+/// otherwise. Candidates are tried in the order listed below, and the first
+/// match wins.
+static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
+                                                        GenericOp genericOp) {
+  return specializeToFirstMatchingConvOp<
+      // Depthwise convolution ops.
+      linalg::DepthwiseConv1DNwcWcOp, linalg::DepthwiseConv2DNchwChwOp,
+      linalg::DepthwiseConv3DNdhwcDhwcmOp,
+      // Pooling ops.
+      linalg::PoolingNhwcMaxOp, linalg::PoolingNhwcMinOp,
+      linalg::PoolingNhwcSumOp, linalg::PoolingNhwcMaxUnsignedOp,
+      linalg::PoolingNhwcMinUnsignedOp>(rewriter, genericOp);
+}
----------------
Abhishek-Varma wrote:
Wow. This was super helpful! Thanks Hanhan! Updated.
https://github.com/llvm/llvm-project/pull/163724
More information about the Mlir-commits mailing list