[Mlir-commits] [mlir] [mlir][linalg] Use inferConvolutionDims for generic convolution downscaling (PR #180586)

Han-Chung Wang llvmlistbot at llvm.org
Tue Feb 24 16:56:30 PST 2026


================
@@ -1422,289 +1422,234 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   return success();
 }
 
-// The following are patterns for downscaling convolution ops with size-1
-// window dimensions.
+//===----------------------------------------------------------------------===//
+// Generic DownscaleSizeOneWindowedConvolution
+//===----------------------------------------------------------------------===//
 //
-// Note that we'd eventually want to write such transformations in a generic
-// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
-// and then turning back to named ops. But for now it's fine to have a few
-// patterns matching special ops to get started.
-
-template <typename Conv2DOp, typename Conv1DOp>
-FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
-    returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
-  // Check if this LinalgOp is of the expected Conv2DOp type (named or generic).
-  std::optional<DilationsAndStrides> convParams =
-      matchConvolutionOpOfType<Conv2DOp>(convOp);
-  if (!convParams)
-    return failure();
-  SmallVector<int64_t> dilations = std::move(convParams->dilations);
-  SmallVector<int64_t> strides = std::move(convParams->strides);
-
-  if (convOp.hasPureBufferSemantics())
-    return failure(); // To be implemented.
-
-  Value input = convOp.getDpsInputs().front();
-  Value kernel = convOp.getDpsInputs().back();
-  Value output = convOp.getDpsInits().front();
-
-  auto inputType = dyn_cast<RankedTensorType>(input.getType());
-  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
-  auto outputType = dyn_cast<RankedTensorType>(output.getType());
-
-  auto kernelShape = kernelType.getShape();
-  auto outputShape = outputType.getShape();
-
-  // Get domain indices based on Conv2DOp type. These are known at compile time.
-  int64_t khIndex, kwIndex, ohIndex, owIndex;
-  if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
-                std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
-                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
-                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
-                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
-                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
-    // NHWC layout: kernel [H, W, ...], output [N, H, W, C]
-    khIndex = 0;
-    kwIndex = 1;
-    ohIndex = 1;
-    owIndex = 2;
-  } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
-    // NCHW_FCHW layout: kernel [..., H, W], output [N, C, H, W]
-    khIndex = 2;
-    kwIndex = 3;
-    ohIndex = 2;
-    owIndex = 3;
-  } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
-                       std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
-    // NCHW pooling layout: kernel [H, W], output [N, C, H, W]
-    khIndex = 0;
-    kwIndex = 1;
-    ohIndex = 2;
-    owIndex = 3;
+/// Returns the indices of the affine map results that reference at least one
+/// of the given dimensions and no other dimension, in increasing order and
+/// without duplicates. This mirrors the criterion `dropDimsAndCompress` uses to
+/// delete a result from an indexing map, so the returned indices are the
+/// operand dimensions that must be rank-reduced to keep the operand rank equal
+/// to the new map's number of results.
+static SmallVector<unsigned>
+getResultIndicesReferencingDims(AffineMap map, ArrayRef<unsigned> dims) {
+  SmallVector<unsigned> resultIndices;
+  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+    AffineExpr expr = map.getResult(i);
+    bool referencesDroppedDim = false;
+    bool referencesKeptDim = false;
+    for (unsigned d = 0, numDims = map.getNumDims(); d < numDims; ++d) {
+      if (!expr.isFunctionOfDim(d))
+        continue;
+      if (llvm::is_contained(dims, d))
+        referencesDroppedDim = true;
+      else
+        referencesKeptDim = true;
+    }
+    // A result mixing dropped and kept dims stays in the rewritten map (the
+    // dropped dims are substituted by 0 there), so the operand dimension it
+    // indexes must stay too; removing it would desynchronize map and operand.
+    if (referencesDroppedDim && !referencesKeptDim)
+      resultIndices.push_back(i);
   }
+  return resultIndices;
+}
 
-  // Only handle the case where at least one of the window dimensions is
-  // of size 1. Other cases can rely on tiling to reduce to such cases.
-  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
-  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
-  bool removeH = (khSize == 1 && ohSize == 1);
-  bool removeW = (kwSize == 1 && owSize == 1);
-  if (!removeH && !removeW)
-    return failure();
+/// Helper to create a rank-reducing extract_slice that removes the dimensions
+/// listed in `dimsToRemove` from `tensor`. Only the shape changes: the element
+/// type and the encoding of the original tensor type are preserved.
+/// NOTE(review): the removed dimensions presumably have to be unit-sized for
+/// the rank-reducing slice to be valid; callers are expected to guarantee it.
+static Value createRankReducingExtractSlice(RewriterBase &rewriter,
+                                            Location loc, Value tensor,
+                                            ArrayRef<unsigned> dimsToRemove) {
+  auto tensorType = cast<RankedTensorType>(tensor.getType());
+  int64_t rank = tensorType.getRank();
+
+  // Compute new shape by removing the specified dimensions.
+  SmallVector<int64_t> newShape;
+  for (int64_t i = 0; i < rank; ++i) {
+    if (!llvm::is_contained(dimsToRemove, i))
+      newShape.push_back(tensorType.getDimSize(i));
+  }
 
-  // Get new shapes and types for all operands by removing the size-1
-  // dimension.
-  using RTTBuilder = RankedTensorType::Builder;
-  RankedTensorType newInputType =
-      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
-  RankedTensorType newKernelType =
-      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
-  RankedTensorType newOutputType =
-      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
-
-  // Rank-reduce operands.
-  Location loc = convOp.getLoc();
-  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, input, newInputType);
-  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, kernel, newKernelType);
-  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, output, newOutputType);
-
-  // Rank-reduce strides and dilations too.
-  // TODO: dropDim 1-liner helper.
-  strides.erase(strides.begin() + (removeH ? 0 : 1));
-  auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
-  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
-  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
-  auto conv1DOp = Conv1DOp::create(
-      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
-      ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
-  // Insert back.
-  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
-      rewriter, loc, conv1DOp.getResult(0), output);
-  rewriter.replaceOp(convOp, inserted);
-
-  return conv1DOp;
+  // Carry the encoding over: the previous RankedTensorType::Builder-based code
+  // preserved it, and dropping it would silently change the operand type.
+  auto newType = RankedTensorType::get(newShape, tensorType.getElementType(),
+                                       tensorType.getEncoding());
+  return tensor::createCanonicalRankReducingExtractSliceOp(rewriter, loc,
+                                                           tensor, newType);
 }
 
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
-                                                              Conv1DNwcWcfOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
-                                                              Conv1DNcwFcwOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
-                                                              PoolingNwcSumOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
-                                                              PoolingNcwSumOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
-                                                              PoolingNwcMaxOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<
-    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
-                                                              PoolingNwcMinOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<
-    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
-                                                              PoolingNcwMaxOp>;
-
-FailureOr<DepthwiseConv1DNwcWcOp>
-DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
-    LinalgOp convOp, PatternRewriter &rewriter) const {
-  // Check if this LinalgOp is a DepthwiseConv2DNhwcHwcOp (named or generic).
-  std::optional<DilationsAndStrides> convParams =
-      matchConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(convOp);
-  if (!convParams)
+/// Rewrites `expr` for an indexing map from which the loop dimensions in
+/// `dimsToDrop` are removed. Each dropped dimension is replaced by the
+/// constant 0 and the surviving dimensions are renumbered densely, keeping
+/// their relative order. `newNumDims` is the dimension count after the drop,
+/// so the original map had `newNumDims + dimsToDrop.size()` dimensions.
+/// Returns std::nullopt when `expr` depends on at least one dropped dimension
+/// and on no surviving one, meaning the whole result should disappear.
+static std::optional<AffineExpr>
+dropDimsAndCompress(AffineExpr expr, ArrayRef<unsigned> dimsToDrop,
+                    unsigned newNumDims, MLIRContext *ctx) {
+  const unsigned oldNumDims = newNumDims + dimsToDrop.size();
+  const bool usesDroppedDim = llvm::any_of(
+      dimsToDrop, [&](unsigned d) { return expr.isFunctionOfDim(d); });
+
+  // Single pass over the old dimension space: record whether a surviving
+  // dimension is used and build the old->new dimension substitution.
+  bool usesKeptDim = false;
+  SmallVector<AffineExpr> dimReplacements;
+  dimReplacements.reserve(oldNumDims);
+  unsigned nextNewDim = 0;
+  for (unsigned oldDim = 0; oldDim < oldNumDims; ++oldDim) {
+    if (llvm::is_contained(dimsToDrop, oldDim)) {
+      dimReplacements.push_back(getAffineConstantExpr(0, ctx));
+      continue;
+    }
+    if (expr.isFunctionOfDim(oldDim))
+      usesKeptDim = true;
+    dimReplacements.push_back(getAffineDimExpr(nextNewDim++, ctx));
+  }
+
+  // The expression only depends on removed loops: drop the result entirely.
+  if (usesDroppedDim && !usesKeptDim)
+    return std::nullopt;
+
+  return expr.replaceDims(dimReplacements);
+}
+
+FailureOr<LinalgOp>
+linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
+                                            LinalgOp op) {
+  auto maybeDims = inferConvolutionDims(op);
+  if (failed(maybeDims))
     return failure();
-  SmallVector<int64_t> dilations = std::move(convParams->dilations);
-  SmallVector<int64_t> strides = std::move(convParams->strides);
-
-  if (convOp.hasPureBufferSemantics())
-    return failure(); // To be implemented.
-
-  Value input = convOp.getDpsInputs().front();
-  Value kernel = convOp.getDpsInputs().back();
-  Value output = convOp.getDpsInits().front();
-
-  auto inputType = dyn_cast<RankedTensorType>(input.getType());
-  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
-  auto outputType = dyn_cast<RankedTensorType>(output.getType());
-
-  auto kernelShape = kernelType.getShape();
-  auto outputShape = outputType.getShape();
-
-  // Only handle the case where at least one of the window dimensions is
-  // of size 1. Other cases can rely on tiling to reduce to such cases.
-  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
-  int64_t ohSize = outputShape[1], owSize = outputShape[2];
-  bool removeH = (khSize == 1 && ohSize == 1);
-  bool removeW = (kwSize == 1 && owSize == 1);
-  if (!removeH && !removeW)
+
+  // Currently supports only 2D convolutions.
+  if (maybeDims->outputImage.size() != 2 || maybeDims->filterLoop.size() != 2)
     return failure();
 
-  // Get new shapes and types for all operands by removing the size-1
-  // dimension.
-  using RTTBuilder = RankedTensorType::Builder;
-  RankedTensorType newInputType =
-      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-  RankedTensorType newKernelType =
-      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
-  RankedTensorType newOutputType =
-      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
-
-  // Rank-reduce operands.
-  Location loc = convOp.getLoc();
-  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, input, newInputType);
-  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, kernel, newKernelType);
-  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, output, newOutputType);
-
-  // Rank-reduce strides and dilations too.
-  // TODO: dropDim 1-liner helper.
-  strides.erase(strides.begin() + (removeH ? 0 : 1));
-  auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
-  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
-  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
-  auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
-      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
-      ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
-  // Insert back.
-  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
-      rewriter, loc, conv1DOp.getResult(0), output);
-  rewriter.replaceOp(convOp, inserted);
-
-  return conv1DOp;
-}
+  if (op.hasPureBufferSemantics())
+    return failure();
 
-FailureOr<Conv1DOp>
-DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
-                                            PatternRewriter &rewriter) const {
-  // Check if this LinalgOp is a Conv2DOp (named or generic).
-  std::optional<DilationsAndStrides> convParams =
-      matchConvolutionOpOfType<Conv2DOp>(convOp);
-  if (!convParams)
+  // Get loop domain indices for spatial dimensions.
+  unsigned outSpatial0 = maybeDims->outputImage[0];
+  unsigned outSpatial1 = maybeDims->outputImage[1];
+  unsigned filterSpatial0 = maybeDims->filterLoop[0];
+  unsigned filterSpatial1 = maybeDims->filterLoop[1];
+
+  // Get sizes from loop bounds.
+  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
+  int64_t outSize0 = loopRanges[outSpatial0];
+  int64_t outSize1 = loopRanges[outSpatial1];
+  int64_t filterSize0 = loopRanges[filterSpatial0];
+  int64_t filterSize1 = loopRanges[filterSpatial1];
+
+  // Check if we can downscale by removing a spatial dimension.
+  bool canRemoveSpatial0 = (filterSize0 == 1 && outSize0 == 1);
+  bool canRemoveSpatial1 = (filterSize1 == 1 && outSize1 == 1);
+  if (!canRemoveSpatial0 && !canRemoveSpatial1)
     return failure();
 
-  if (convOp.hasPureBufferSemantics())
-    return failure(); // To be implemented.
+  // Prioritize dropping the leading spatial dimension if both are removable.
+  bool removeSpatial0 = canRemoveSpatial0;
 
-  Value input = convOp.getDpsInputs().front();
-  Value kernel = convOp.getDpsInputs().back();
-  Value output = convOp.getDpsInits().front();
+  // Determine which loop dims to remove (output spatial + corresponding filter)
+  SmallVector<unsigned> loopDimsToRemove;
+  if (removeSpatial0) {
+    loopDimsToRemove.push_back(outSpatial0);
+    loopDimsToRemove.push_back(filterSpatial0);
+  } else {
+    loopDimsToRemove.push_back(outSpatial1);
+    loopDimsToRemove.push_back(filterSpatial1);
+  }
+  // Sort for correct index compression when removing dimensions from affine
+  // maps.
+  llvm::sort(loopDimsToRemove);
 
-  auto inputType = dyn_cast<RankedTensorType>(input.getType());
-  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
-  auto outputType = dyn_cast<RankedTensorType>(output.getType());
+  // Create new indexing maps with dimensions removed.
+  SmallVector<AffineMap> newMaps;
+  MLIRContext *ctx = op.getContext();
+  unsigned numDims = op.getNumLoops();
+  unsigned newNumDims = numDims - loopDimsToRemove.size();
+
+  for (AffineMap map : op.getIndexingMapsArray()) {
+    // Remove the loop dimensions from the map.
+    SmallVector<AffineExpr> newResults;
+    for (AffineExpr expr : map.getResults()) {
+      auto newExpr =
+          dropDimsAndCompress(expr, loopDimsToRemove, newNumDims, ctx);
+      if (newExpr)
+        newResults.push_back(*newExpr);
+    }
+    newMaps.push_back(AffineMap::get(newNumDims, 0, newResults, ctx));
+  }
 
-  auto kernelShape = kernelType.getShape();
-  auto outputShape = outputType.getShape();
+  // Create new iterator types.
+  SmallVector<utils::IteratorType> newIterTypes;
+  auto iterTypes = op.getIteratorTypesArray();
+  for (unsigned idx = 0; idx < iterTypes.size(); ++idx) {
+    if (!llvm::is_contained(loopDimsToRemove, idx))
+      newIterTypes.push_back(iterTypes[idx]);
+  }
 
-  // Only handle the case where at least one of the window dimensions is
-  // of size 1. Other cases can rely on tiling to reduce to such cases.
-  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
-  int64_t ohSize = outputShape[0], owSize = outputShape[1];
-  bool removeH = (khSize == 1 && ohSize == 1);
-  bool removeW = (kwSize == 1 && owSize == 1);
-  if (!removeH && !removeW)
-    return failure();
+  // Rank-reduce operands using extract_slice.
+  Location loc = op.getLoc();
+  SmallVector<Value> newInputs;
+  for (OpOperand *input : op.getDpsInputOperands()) {
+    AffineMap map = op.getMatchingIndexingMap(input);
+    SmallVector<unsigned> tensorDimsToRemove =
+        getResultIndicesReferencingDims(map, loopDimsToRemove);
+    Value reduced = createRankReducingExtractSlice(rewriter, loc, input->get(),
+                                                   tensorDimsToRemove);
+    newInputs.push_back(reduced);
+  }
 
-  // Get new shapes and types for all operands by removing the size-1
-  // dimension.
-  using RTTBuilder = RankedTensorType::Builder;
-  RankedTensorType newInputType =
-      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
-  RankedTensorType newKernelType =
-      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
-  RankedTensorType newOutputType =
-      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
-
-  // Rank-reduce operands.
-  Location loc = convOp.getLoc();
-  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, input, newInputType);
-  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, kernel, newKernelType);
-  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
-      rewriter, loc, output, newOutputType);
-
-  auto conv1DOp =
-      Conv1DOp::create(rewriter, loc, newOutputType,
-                       ValueRange{newInput, newKernel}, ValueRange{newOutput});
-
-  // Insert back.
-  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
-      rewriter, loc, conv1DOp.getResult(0), output);
-  rewriter.replaceOp(convOp, inserted);
-
-  return conv1DOp;
+  SmallVector<Value> newOutputs;
+  Value originalOutput;
+  SmallVector<OpOperand *> initOperands =
+      llvm::to_vector(llvm::make_pointer_range(op.getDpsInitsMutable()));
+  for (OpOperand *output : initOperands) {
+    originalOutput = output->get();
+    AffineMap map = op.getMatchingIndexingMap(output);
+    SmallVector<unsigned> tensorDimsToRemove =
+        getResultIndicesReferencingDims(map, loopDimsToRemove);
+    Value reduced = createRankReducingExtractSlice(rewriter, loc, output->get(),
+                                                   tensorDimsToRemove);
+    newOutputs.push_back(reduced);
+  }
+
+  // Create new linalg.generic with reduced dimensions.
+  auto newOp = linalg::GenericOp::create(
+      rewriter, loc, TypeRange{newOutputs[0].getType()}, newInputs, newOutputs,
+      newMaps, newIterTypes,
+      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
+        IRMapping mapping;
+        for (auto [oldArg, newArg] :
+             llvm::zip(op.getBlock()->getArguments(), args))
+          mapping.map(oldArg, newArg);
+        for (Operation &bodyOp : op.getBlock()->without_terminator())
+          b.clone(bodyOp, mapping);
+        auto yield = cast<linalg::YieldOp>(op.getBlock()->getTerminator());
+        linalg::YieldOp::create(b, nestedLoc,
+                                mapping.lookup(yield.getOperand(0)));
+      });
----------------
hanhanW wrote:

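I think you can create the generic op without the body builder and then transfer the original region into it with a region method such as `inlineRegionBefore` (note that it moves the body rather than cloning it; `cloneRegionBefore` is the copying variant). E.g.,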

```cpp
  auto newOp = linalg::GenericOp::create(
      rewriter, loc, TypeRange{newOutputs[0].getType()}, newInputs, newOutputs,
      newMaps, newIterTypes);
  rewriter.inlineRegionBefore(op->getRegion(0), newOp.getRegion(),
                              newOp.getRegion().begin());
```

https://github.com/llvm/llvm-project/pull/180586


More information about the Mlir-commits mailing list