[Mlir-commits] [mlir] [mlir][linalg] Use inferConvolutionDims for generic convolution downscaling (PR #180586)
Abhishek Varma
llvmlistbot at llvm.org
Wed Feb 11 22:53:39 PST 2026
================
@@ -1422,289 +1422,221 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
return success();
}
-// The following are patterns for downscaling convolution ops with size-1
-// window dimensions.
+//===----------------------------------------------------------------------===//
+// Generic DownscaleSizeOneWindowedConvolution
+//===----------------------------------------------------------------------===//
//
-// Note that we'd eventually want to write such transformations in a generic
-// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
-// and then turning back to named ops. But for now it's fine to have a few
-// patterns matching special ops to get started.
-
-template <typename Conv2DOp, typename Conv1DOp>
-FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
- returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
- // Check if this LinalgOp is of the expected Conv2DOp type (named or generic).
- std::optional<DilationsAndStrides> convParams =
- matchConvolutionOpOfType<Conv2DOp>(convOp);
- if (!convParams)
- return failure();
- SmallVector<int64_t> dilations = std::move(convParams->dilations);
- SmallVector<int64_t> strides = std::move(convParams->strides);
-
- if (convOp.hasPureBufferSemantics())
- return failure(); // To be implemented.
-
- Value input = convOp.getDpsInputs().front();
- Value kernel = convOp.getDpsInputs().back();
- Value output = convOp.getDpsInits().front();
-
- auto inputType = dyn_cast<RankedTensorType>(input.getType());
- auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
- auto outputType = dyn_cast<RankedTensorType>(output.getType());
-
- auto kernelShape = kernelType.getShape();
- auto outputShape = outputType.getShape();
-
- // Get domain indices based on Conv2DOp type. These are known at compile time.
- int64_t khIndex, kwIndex, ohIndex, owIndex;
- if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
- // NHWC layout: kernel [H, W, ...], output [N, H, W, C]
- khIndex = 0;
- kwIndex = 1;
- ohIndex = 1;
- owIndex = 2;
- } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
- // NCHW_FCHW layout: kernel [..., H, W], output [N, C, H, W]
- khIndex = 2;
- kwIndex = 3;
- ohIndex = 2;
- owIndex = 3;
- } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
- // NCHW pooling layout: kernel [H, W], output [N, C, H, W]
- khIndex = 0;
- kwIndex = 1;
- ohIndex = 2;
- owIndex = 3;
+// This pattern rewrites 2-D convolution/pooling/depthwise ops with size-1
+// window dimensions into lower-dimensional ops. It uses inferConvolutionDims
+// to work with any layout and handles both named ops and equivalent
+// linalg.generic ops uniformly.
+//
+/// For each dimension in `dims`, returns the index of the first result of
+/// `map` whose expression is a function of that dimension. Dimensions that no
+/// result references contribute nothing. The same result index can appear
+/// more than once when several of `dims` occur in one result expression
+/// (e.g. `d1 + d4`), and indices follow the order of `dims`, not result order.
+static SmallVector<unsigned>
+getResultIndicesReferencingDims(AffineMap map, ArrayRef<unsigned> dims) {
+  ArrayRef<AffineExpr> results = map.getResults();
+  SmallVector<unsigned> indices;
+  for (unsigned dim : dims) {
+    // Only the first matching result is recorded for each dimension.
+    const auto *match = llvm::find_if(
+        results, [&](AffineExpr e) { return e.isFunctionOfDim(dim); });
+    if (match != results.end())
+      indices.push_back(static_cast<unsigned>(match - results.begin()));
 }
+  return indices;
+}
- // Only handle the case where at least one of the window dimensions is
- // of size 1. Other cases can rely on tiling to reduce to such cases.
- int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
- int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
- bool removeH = (khSize == 1 && ohSize == 1);
- bool removeW = (kwSize == 1 && owSize == 1);
- if (!removeH && !removeW)
- return failure();
+/// Builds a tensor.extract_slice that rank-reduces `tensor` by dropping every
+/// dimension whose index appears in `dimsToRemove` (repeated indices are
+/// harmless). The result type keeps the source element type and the sizes of
+/// the surviving dimensions in their original order; no tensor encoding is
+/// attached to it. Presumably every listed dimension is unit-sized (a
+/// rank-reducing slice can only drop size-1 dimensions); this is not checked.
+static Value createRankReducingExtractSlice(RewriterBase &rewriter,
+                                            Location loc, Value tensor,
+                                            ArrayRef<unsigned> dimsToRemove) {
+  auto srcType = cast<RankedTensorType>(tensor.getType());
+
+  // Sizes of the dimensions that survive the rank reduction.
+  SmallVector<int64_t> keptSizes;
+  for (auto [idx, size] : llvm::enumerate(srcType.getShape()))
+    if (!llvm::is_contained(dimsToRemove, idx))
+      keptSizes.push_back(size);
- // Get new shapes and types for all operands by removing the size-1
- // dimension.
- using RTTBuilder = RankedTensorType::Builder;
- RankedTensorType newInputType =
- RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
- RankedTensorType newKernelType =
- RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
- RankedTensorType newOutputType =
- RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
-
- // Rank-reduce operands.
- Location loc = convOp.getLoc();
- Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, input, newInputType);
- Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, kernel, newKernelType);
- Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, output, newOutputType);
-
- // Rank-reduce strides and dilations too.
- // TODO: dropDim 1-liner helper.
- strides.erase(strides.begin() + (removeH ? 0 : 1));
- auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
- dilations.erase(dilations.begin() + (removeH ? 0 : 1));
- auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
- auto conv1DOp = Conv1DOp::create(
- rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
- ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
- // Insert back.
- Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
- rewriter, loc, conv1DOp.getResult(0), output);
- rewriter.replaceOp(convOp, inserted);
-
- return conv1DOp;
+
+  RankedTensorType slicedType =
+      RankedTensorType::get(keptSizes, srcType.getElementType());
+  return tensor::createCanonicalRankReducingExtractSliceOp(rewriter, loc,
+                                                           tensor, slicedType);
 }
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
- Conv1DNwcWcfOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
- Conv1DNcwFcwOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
- PoolingNwcSumOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
- PoolingNcwSumOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
- PoolingNwcMaxOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<
- PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
- PoolingNwcMinOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<
- PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
- PoolingNcwMaxOp>;
-
-FailureOr<DepthwiseConv1DNwcWcOp>
-DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
- LinalgOp convOp, PatternRewriter &rewriter) const {
- // Check if this LinalgOp is a DepthwiseConv2DNhwcHwcOp (named or generic).
- std::optional<DilationsAndStrides> convParams =
- matchConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(convOp);
- if (!convParams)
+/// Re-expresses `expr`, written over `newNumDims + dimsToDrop.size()`
+/// dimensions, over `newNumDims` dimensions: every dimension listed in
+/// `dimsToDrop` is substituted by the constant 0 and the remaining dimensions
+/// are renumbered consecutively in their original order. Returns std::nullopt
+/// when `expr` uses at least one dropped dimension but no kept one, i.e. the
+/// result would only have indexed dimensions that no longer exist.
+static std::optional<AffineExpr>
+dropDimsAndCompress(AffineExpr expr, ArrayRef<unsigned> dimsToDrop,
+                    unsigned newNumDims, MLIRContext *ctx) {
+  const size_t oldNumDims = newNumDims + dimsToDrop.size();
+
+  // One pass over the old dimensions: build each one's replacement and record
+  // whether `expr` still depends on any dimension that is kept.
+  SmallVector<AffineExpr> replacements;
+  replacements.reserve(oldNumDims);
+  bool usesKeptDim = false;
+  unsigned nextNewDim = 0;
+  for (unsigned oldDim = 0; oldDim < oldNumDims; ++oldDim) {
+    if (llvm::is_contained(dimsToDrop, oldDim)) {
+      // Substituting 0 is sound only for a dimension whose single value is 0
+      // (presumably unit-extent, as established by the caller).
+      replacements.push_back(getAffineConstantExpr(0, ctx));
+      continue;
+    }
+    usesKeptDim = usesKeptDim || expr.isFunctionOfDim(oldDim);
+    replacements.push_back(getAffineDimExpr(nextNewDim++, ctx));
+  }
+
+  const bool usesDroppedDim = llvm::any_of(
+      dimsToDrop, [&](unsigned d) { return expr.isFunctionOfDim(d); });
+  if (usesDroppedDim && !usesKeptDim)
+    return std::nullopt;
+
+  return expr.replaceDims(replacements);
+}
+
+FailureOr<LinalgOp>
+linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
+ LinalgOp op) {
+ auto maybeDims = inferConvolutionDims(op);
+ if (failed(maybeDims))
return failure();
- SmallVector<int64_t> dilations = std::move(convParams->dilations);
- SmallVector<int64_t> strides = std::move(convParams->strides);
-
- if (convOp.hasPureBufferSemantics())
- return failure(); // To be implemented.
-
- Value input = convOp.getDpsInputs().front();
- Value kernel = convOp.getDpsInputs().back();
- Value output = convOp.getDpsInits().front();
-
- auto inputType = dyn_cast<RankedTensorType>(input.getType());
- auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
- auto outputType = dyn_cast<RankedTensorType>(output.getType());
-
- auto kernelShape = kernelType.getShape();
- auto outputShape = outputType.getShape();
-
- // Only handle the case where at least one of the window dimensions is
- // of size 1. Other cases can rely on tiling to reduce to such cases.
- int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
- int64_t ohSize = outputShape[1], owSize = outputShape[2];
- bool removeH = (khSize == 1 && ohSize == 1);
- bool removeW = (kwSize == 1 && owSize == 1);
- if (!removeH && !removeW)
+
+ // Must be 2D Conv.
+ if (maybeDims->outputImage.size() != 2 || maybeDims->filterLoop.size() != 2)
return failure();
- // Get new shapes and types for all operands by removing the size-1
- // dimension.
- using RTTBuilder = RankedTensorType::Builder;
- RankedTensorType newInputType =
- RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
- RankedTensorType newKernelType =
- RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
- RankedTensorType newOutputType =
- RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
-
- // Rank-reduce operands.
- Location loc = convOp.getLoc();
- Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, input, newInputType);
- Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, kernel, newKernelType);
- Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, output, newOutputType);
-
- // Rank-reduce strides and dilations too.
- // TODO: dropDim 1-liner helper.
- strides.erase(strides.begin() + (removeH ? 0 : 1));
- auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
- dilations.erase(dilations.begin() + (removeH ? 0 : 1));
- auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
- auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
- rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
- ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
- // Insert back.
- Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
- rewriter, loc, conv1DOp.getResult(0), output);
- rewriter.replaceOp(convOp, inserted);
-
- return conv1DOp;
-}
+ if (op.hasPureBufferSemantics())
+ return failure();
-FailureOr<Conv1DOp>
-DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
- PatternRewriter &rewriter) const {
- // Check if this LinalgOp is a Conv2DOp (named or generic).
- std::optional<DilationsAndStrides> convParams =
- matchConvolutionOpOfType<Conv2DOp>(convOp);
- if (!convParams)
+ // Get loop domain indices.
+ unsigned ohLoopIdx = maybeDims->outputImage[0];
+ unsigned owLoopIdx = maybeDims->outputImage[1];
+ unsigned khLoopIdx = maybeDims->filterLoop[0];
+ unsigned kwLoopIdx = maybeDims->filterLoop[1];
+
+ // Get sizes from loop bounds.
+ SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
+ int64_t ohSize = loopRanges[ohLoopIdx];
+ int64_t owSize = loopRanges[owLoopIdx];
+ int64_t khSize = loopRanges[khLoopIdx];
+ int64_t kwSize = loopRanges[kwLoopIdx];
+
+ // Check if we can downscale.
+ bool canRemoveH = (khSize == 1 && ohSize == 1);
+ bool canRemoveW = (kwSize == 1 && owSize == 1);
+ if (!canRemoveH && !canRemoveW)
return failure();
- if (convOp.hasPureBufferSemantics())
- return failure(); // To be implemented.
+ // Prefer removing H if both are possible.
+ bool removeH = canRemoveH;
----------------
Abhishek-Varma wrote:
I've rebased and added a lit test for cross_conv
https://github.com/llvm/llvm-project/pull/180586
More information about the Mlir-commits
mailing list