[Mlir-commits] [mlir] dfd5b85 - [mlir][linalg] Use inferConvolutionDims for generic convolution downscaling (#180586)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 16 00:55:54 PDT 2026
Author: Abhishek Varma
Date: 2026-03-16T13:25:48+05:30
New Revision: dfd5b85d0bb1f641c38a5344134f9796cbb5b407
URL: https://github.com/llvm/llvm-project/commit/dfd5b85d0bb1f641c38a5344134f9796cbb5b407
DIFF: https://github.com/llvm/llvm-project/commit/dfd5b85d0bb1f641c38a5344134f9796cbb5b407.diff
LOG: [mlir][linalg] Use inferConvolutionDims for generic convolution downscaling (#180586)
The goal of this PR is to implement a generic, structure-aware
convolution downscaling transformation that works for any
convolution-like operation regardless of its specific layout or naming,
rather than relying on pattern-matching against specific named
operations.
Each pattern we currently have hardcodes dimension indices
specific to its layout (e.g., NHWC vs NCHW).
This approach:
1. Requires maintaining many similar patterns.
2. Is brittle when new layouts are introduced.
3. Cannot handle batchless versions of the conv variants.
This PR thus creates a single downscaleSizeOneWindowedConvolution
function that uses `inferConvolutionDims` to semantically understand the
convolution structure (batch dims, output image dims, filter loop dims,
etc.) rather than hardcoding indices.
It works with any layout - NHWC, NCHW, or any other - because it reasons
about the meaning of dimensions, not their positions.
If the input to the downscaling pattern is a named op, the output will
be a named op (when specialization succeeds); otherwise, a generic op input yields a generic op output.
For this reason, we now remove the second RUN line: the transformation
handles named and generic ops uniformly, and the test file exercises both directly.
Signed-off-by: Abhishek Varma <abhvarma at amd.com>
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-op-decompose.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 12e228bcaeefa..cd842fb1c5392 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -310,9 +310,13 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- Decomposes named complex operations, such as higher-dimensional
- (depthwise) convolutions, into combinations of lower-dimensional equivalents
- when possible.
+ Decomposes higher-dimensional convolution ops into lower-dimensional
+ equivalents when possible. This operates on both named ops and equivalent
+ `linalg.generic` ops that have convolution-like structure (as determined
+ by `inferConvolutionDims`).
+
+ When the target is a named op, the transformation attempts to specialize the
+ result back to a named op; `linalg.generic` targets yield a `linalg.generic`.
#### Return modes
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index dcb7f1f212207..486ef75b76859 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1640,63 +1640,22 @@ decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
FailureOr<linalg::GenericOp> deduplicateOperandsAndRemoveDeadResults(
RewriterBase &rewriter, linalg::GenericOp genericOp, bool removeOutputs);
+/// Rewrite convolution/pooling/depthwise ops with size-1 window dimensions
+/// into lower-dimensional ops. Uses `inferConvolutionDims` to work with any
+/// layout and handles both named ops and equivalent linalg.generic ops
+/// uniformly. If the input was a named op, the result is specialized back to a
+/// named op when possible.
+/// TODO: Support n-D to (n-1)-D downscaling. Currently it only supports 2D->1D
+/// downscaling. Ops with pure buffer semantics are not supported yet.
+FailureOr<LinalgOp> downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
+ LinalgOp op);
+
//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
// functional-stye API call.
//===----------------------------------------------------------------------===//
-/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
-/// convolution ops. Works with both named ops and equivalent generic ops.
-template <typename Conv2DOp, typename Conv1DOp>
-struct DownscaleSizeOneWindowed2DConvolution final
- : public OpInterfaceRewritePattern<LinalgOp> {
- using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
- FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
- PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(LinalgOp convOp,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(convOp, rewriter);
- }
-};
-
-extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
- Conv1DNwcWcfOp>;
-extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
- Conv1DNcwFcwOp>;
-
-/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
-/// dimensions into 1-D depthwise convolution ops.
-struct DownscaleDepthwiseConv2DNhwcHwcOp final
- : public OpInterfaceRewritePattern<LinalgOp> {
- DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
- PatternBenefit benefit = 1)
- : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
-
- FailureOr<DepthwiseConv1DNwcWcOp>
- returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(LinalgOp convOp,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(convOp, rewriter);
- }
-};
-
-struct DownscaleConv2DOp final : public OpInterfaceRewritePattern<LinalgOp> {
- DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1)
- : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
-
- FailureOr<Conv1DOp> returningMatchAndRewrite(LinalgOp convOp,
- PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(LinalgOp convOp,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(convOp, rewriter);
- }
-};
-
///
/// Linalg generalization pattern.
///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d84408c024e25..d751488d186ad 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -505,32 +505,12 @@ transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
-#define DOWNSCALE(trans) \
- { \
- FailureOr<LinalgOp> res = tryApply<trans>(target); \
- if (succeeded(res)) { \
- results.push_back(*res); \
- return DiagnosedSilenceableFailure::success(); \
- } \
- }
-
-#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
-#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
-
- DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
- DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
- DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
- DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
- DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
- DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
- DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
- DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
- DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
- DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
- DOWNSCALE(DownscaleConv2DOp)
-#undef DOWNSCALE_NORMAL
-#undef DOWNSCALE_CALL
-#undef DOWNSCALE
+ FailureOr<linalg::LinalgOp> res =
+ downscaleSizeOneWindowedConvolution(rewriter, target);
+ if (succeeded(res)) { // Report the downscaled (possibly specialized) op.
+ results.push_back(*res);
+ return DiagnosedSilenceableFailure::success();
+ }
return emitDefaultSilenceableFailure(target);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2b4986aeac14f..260e36fb47f04 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1424,289 +1424,213 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
return success();
}
-// The following are patterns for downscaling convolution ops with size-1
-// window dimensions.
+//===----------------------------------------------------------------------===//
+// Generic DownscaleSizeOneWindowedConvolution
+//===----------------------------------------------------------------------===//
//
-// Note that we'd eventually want to write such transformations in a generic
-// way, e.g., converting to linalg.generic, removing the size-1 dimensions,
-// and then turning back to named ops. But for now it's fine to have a few
-// patterns matching special ops to get started.
-
-template <typename Conv2DOp, typename Conv1DOp>
-FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
- returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
- // Check if this LinalgOp is of the expected Conv2DOp type (named or generic).
- std::optional<DilationsAndStrides> convParams =
- matchConvolutionOpOfType<Conv2DOp>(convOp);
- if (!convParams)
- return failure();
- SmallVector<int64_t> dilations = std::move(convParams->dilations);
- SmallVector<int64_t> strides = std::move(convParams->strides);
-
- if (convOp.hasPureBufferSemantics())
- return failure(); // To be implemented.
-
- Value input = convOp.getDpsInputs().front();
- Value kernel = convOp.getDpsInputs().back();
- Value output = convOp.getDpsInits().front();
-
- auto inputType = dyn_cast<RankedTensorType>(input.getType());
- auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
- auto outputType = dyn_cast<RankedTensorType>(output.getType());
-
- auto kernelShape = kernelType.getShape();
- auto outputShape = outputType.getShape();
-
- // Get domain indices based on Conv2DOp type. These are known at compile time.
- int64_t khIndex, kwIndex, ohIndex, owIndex;
- if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
- // NHWC layout: kernel [H, W, ...], output [N, H, W, C]
- khIndex = 0;
- kwIndex = 1;
- ohIndex = 1;
- owIndex = 2;
- } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
- // NCHW_FCHW layout: kernel [..., H, W], output [N, C, H, W]
- khIndex = 2;
- kwIndex = 3;
- ohIndex = 2;
- owIndex = 3;
- } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
- std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
- // NCHW pooling layout: kernel [H, W], output [N, C, H, W]
- khIndex = 0;
- kwIndex = 1;
- ohIndex = 2;
- owIndex = 3;
+/// Collects, for each dim in `dims`, the index of the first affine map result
+/// that references it; indices may repeat and unreferenced dims are skipped.
+static SmallVector<unsigned>
+getResultIndicesReferencingDims(AffineMap map, ArrayRef<unsigned> dims) {
+ SmallVector<unsigned> resultIndices;
+ for (unsigned dim : dims) {
+ for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+ AffineExpr expr = map.getResult(i);
+ if (expr.isFunctionOfDim(dim)) {
+ resultIndices.push_back(i);
+ break;
+ }
+ }
}
+ return resultIndices;
+}
- // Only handle the case where at least one of the window dimensions is
- // of size 1. Other cases can rely on tiling to reduce to such cases.
- int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
- int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
- bool removeH = (khSize == 1 && ohSize == 1);
- bool removeW = (kwSize == 1 && owSize == 1);
- if (!removeH && !removeW)
- return failure();
+/// Helper to create a rank-reducing extract_slice that removes specific
+/// dimensions from a tensor.
+static Value createRankReducingExtractSlice(RewriterBase &rewriter,
+ Location loc, Value tensor,
+ ArrayRef<unsigned> dimsToRemove) {
+ auto tensorType = cast<RankedTensorType>(tensor.getType());
+ int64_t rank = tensorType.getRank();
+
+ // Compute new shape by removing the specified dimensions.
+ SmallVector<int64_t> newShape;
+ for (int64_t i = 0; i < rank; ++i) {
+ if (!llvm::is_contained(dimsToRemove, i))
+ newShape.push_back(tensorType.getDimSize(i));
+ }
- // Get new shapes and types for all operands by removing the size-1
- // dimension.
- using RTTBuilder = RankedTensorType::Builder;
- RankedTensorType newInputType =
- RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
- RankedTensorType newKernelType =
- RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
- RankedTensorType newOutputType =
- RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
-
- // Rank-reduce operands.
- Location loc = convOp.getLoc();
- Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, input, newInputType);
- Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, kernel, newKernelType);
- Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, output, newOutputType);
-
- // Rank-reduce strides and dilations too.
- // TODO: dropDim 1-liner helper.
- strides.erase(strides.begin() + (removeH ? 0 : 1));
- auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
- dilations.erase(dilations.begin() + (removeH ? 0 : 1));
- auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
- auto conv1DOp = Conv1DOp::create(
- rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
- ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
- // Insert back.
- Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
- rewriter, loc, conv1DOp.getResult(0), output);
- rewriter.replaceOp(convOp, inserted);
-
- return conv1DOp;
+ auto newType = RankedTensorType::get(newShape, tensorType.getElementType());
+ return tensor::createCanonicalRankReducingExtractSliceOp(rewriter, loc,
+ tensor, newType);
}
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
- Conv1DNwcWcfOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
- Conv1DNcwFcwOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
- PoolingNwcSumOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
- PoolingNcwSumOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
- PoolingNwcMaxOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<
- PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
- PoolingNwcMinOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<
- PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
-template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
- PoolingNcwMaxOp>;
-
-FailureOr<DepthwiseConv1DNwcWcOp>
-DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
- LinalgOp convOp, PatternRewriter &rewriter) const {
- // Check if this LinalgOp is a DepthwiseConv2DNhwcHwcOp (named or generic).
- std::optional<DilationsAndStrides> convParams =
- matchConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(convOp);
- if (!convParams)
+/// Drops specified dimensions from an AffineExpr and compresses remaining
+/// dimension indices. Returns std::nullopt if the expression references at
+/// least one dropped dimension and no others; dropped dims otherwise become 0.
+static std::optional<AffineExpr>
+dropDimsAndCompress(AffineExpr expr, ArrayRef<unsigned> dimsToDrop,
+ unsigned newNumDims, MLIRContext *ctx) {
+ // Check if expr only references dimensions to be dropped.
+ bool onlyReferencesDroppedDims = true;
+ for (unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
+ if (expr.isFunctionOfDim(d) && !llvm::is_contained(dimsToDrop, d)) {
+ onlyReferencesDroppedDims = false;
+ break;
+ }
+ }
+ if (onlyReferencesDroppedDims && llvm::any_of(dimsToDrop, [&](unsigned d) {
+ return expr.isFunctionOfDim(d);
+ }))
+ return std::nullopt;
+
+ // Replace dimensions: compute new index for each old dimension.
+ // Dropped dimensions get mapped to constant 0, others get compressed.
+ SmallVector<AffineExpr> dimReplacements;
+ unsigned newDimIdx = 0;
+ for (unsigned d = 0; d < newNumDims + dimsToDrop.size(); ++d) {
+ if (llvm::is_contained(dimsToDrop, d)) {
+ dimReplacements.push_back(getAffineConstantExpr(0, ctx));
+ } else {
+ dimReplacements.push_back(getAffineDimExpr(newDimIdx++, ctx));
+ }
+ }
+
+ return expr.replaceDims(dimReplacements);
+}
+
+FailureOr<LinalgOp>
+linalg::downscaleSizeOneWindowedConvolution(RewriterBase &rewriter,
+ LinalgOp op) {
+ auto maybeDims = inferConvolutionDims(op);
+ if (failed(maybeDims))
return failure();
- SmallVector<int64_t> dilations = std::move(convParams->dilations);
- SmallVector<int64_t> strides = std::move(convParams->strides);
-
- if (convOp.hasPureBufferSemantics())
- return failure(); // To be implemented.
-
- Value input = convOp.getDpsInputs().front();
- Value kernel = convOp.getDpsInputs().back();
- Value output = convOp.getDpsInits().front();
-
- auto inputType = dyn_cast<RankedTensorType>(input.getType());
- auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
- auto outputType = dyn_cast<RankedTensorType>(output.getType());
-
- auto kernelShape = kernelType.getShape();
- auto outputShape = outputType.getShape();
-
- // Only handle the case where at least one of the window dimensions is
- // of size 1. Other cases can rely on tiling to reduce to such cases.
- int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
- int64_t ohSize = outputShape[1], owSize = outputShape[2];
- bool removeH = (khSize == 1 && ohSize == 1);
- bool removeW = (kwSize == 1 && owSize == 1);
- if (!removeH && !removeW)
+
+ // Currently supports only 2D convolutions.
+ if (maybeDims->outputImage.size() != 2 || maybeDims->filterLoop.size() != 2)
return failure();
- // Get new shapes and types for all operands by removing the size-1
- // dimension.
- using RTTBuilder = RankedTensorType::Builder;
- RankedTensorType newInputType =
- RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
- RankedTensorType newKernelType =
- RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
- RankedTensorType newOutputType =
- RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
-
- // Rank-reduce operands.
- Location loc = convOp.getLoc();
- Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, input, newInputType);
- Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, kernel, newKernelType);
- Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, output, newOutputType);
-
- // Rank-reduce strides and dilations too.
- // TODO: dropDim 1-liner helper.
- strides.erase(strides.begin() + (removeH ? 0 : 1));
- auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
- dilations.erase(dilations.begin() + (removeH ? 0 : 1));
- auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
- auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
- rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
- ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
- // Insert back.
- Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
- rewriter, loc, conv1DOp.getResult(0), output);
- rewriter.replaceOp(convOp, inserted);
-
- return conv1DOp;
-}
+ if (op.hasPureBufferSemantics())
+ return failure(); // To be implemented.
-FailureOr<Conv1DOp>
-DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
- PatternRewriter &rewriter) const {
- // Check if this LinalgOp is a Conv2DOp (named or generic).
- std::optional<DilationsAndStrides> convParams =
- matchConvolutionOpOfType<Conv2DOp>(convOp);
- if (!convParams)
+ // Get loop domain indices for spatial dimensions.
+ unsigned outSpatial0 = maybeDims->outputImage[0];
+ unsigned outSpatial1 = maybeDims->outputImage[1];
+ unsigned filterSpatial0 = maybeDims->filterLoop[0];
+ unsigned filterSpatial1 = maybeDims->filterLoop[1];
+
+ // Get sizes from loop bounds.
+ SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
+ int64_t outSize0 = loopRanges[outSpatial0];
+ int64_t outSize1 = loopRanges[outSpatial1];
+ int64_t filterSize0 = loopRanges[filterSpatial0];
+ int64_t filterSize1 = loopRanges[filterSpatial1];
+
+ // Check if we can downscale by removing a spatial dimension.
+ bool canRemoveSpatial0 = (filterSize0 == 1 && outSize0 == 1);
+ bool canRemoveSpatial1 = (filterSize1 == 1 && outSize1 == 1);
+ if (!canRemoveSpatial0 && !canRemoveSpatial1)
return failure();
- if (convOp.hasPureBufferSemantics())
- return failure(); // To be implemented.
+ // Determine which loop dims to remove (output spatial + corresponding filter).
+ // The list is sorted; note the helpers below only test membership
+ // (llvm::is_contained), so the order itself does not affect the rewrite.
+ SmallVector<unsigned> loopDimsToRemove;
+ if (canRemoveSpatial0) {
+ loopDimsToRemove.push_back(outSpatial0);
+ loopDimsToRemove.push_back(filterSpatial0);
+ } else {
+ loopDimsToRemove.push_back(outSpatial1);
+ loopDimsToRemove.push_back(filterSpatial1);
+ }
+ llvm::sort(loopDimsToRemove);
- Value input = convOp.getDpsInputs().front();
- Value kernel = convOp.getDpsInputs().back();
- Value output = convOp.getDpsInits().front();
+ // Create new indexing maps with dimensions removed.
+ SmallVector<AffineMap> newMaps;
+ MLIRContext *ctx = op.getContext();
+ unsigned numDims = op.getNumLoops();
+ unsigned newNumDims = numDims - loopDimsToRemove.size();
+ for (AffineMap map : op.getIndexingMapsArray()) {
+ SmallVector<AffineExpr> newResults;
+ for (AffineExpr expr : map.getResults()) {
+ auto newExpr =
+ dropDimsAndCompress(expr, loopDimsToRemove, newNumDims, ctx);
+ if (newExpr)
+ newResults.push_back(*newExpr);
+ }
+ newMaps.push_back(AffineMap::get(newNumDims, 0, newResults, ctx));
+ }
- auto inputType = dyn_cast<RankedTensorType>(input.getType());
- auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
- auto outputType = dyn_cast<RankedTensorType>(output.getType());
+ // Create new iterator types.
+ SmallVector<utils::IteratorType> newIterTypes;
+ auto iterTypes = op.getIteratorTypesArray();
+ for (unsigned idx = 0; idx < iterTypes.size(); ++idx) {
+ if (!llvm::is_contained(loopDimsToRemove, idx))
+ newIterTypes.push_back(iterTypes[idx]);
+ }
- auto kernelShape = kernelType.getShape();
- auto outputShape = outputType.getShape();
+ // Rank-reduce operands using extract_slice.
+ Location loc = op.getLoc();
+ SmallVector<Value> newInputs;
+ for (OpOperand *input : op.getDpsInputOperands()) {
+ AffineMap map = op.getMatchingIndexingMap(input); // NOTE(review): dims removed here must match results dropped by dropDimsAndCompress.
+ SmallVector<unsigned> tensorDimsToRemove =
+ getResultIndicesReferencingDims(map, loopDimsToRemove);
+ Value reduced = createRankReducingExtractSlice(rewriter, loc, input->get(),
+ tensorDimsToRemove);
+ newInputs.push_back(reduced);
+ }
- // Only handle the case where at least one of the window dimensions is
- // of size 1. Other cases can rely on tiling to reduce to such cases.
- int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
- int64_t ohSize = outputShape[0], owSize = outputShape[1];
- bool removeH = (khSize == 1 && ohSize == 1);
- bool removeW = (kwSize == 1 && owSize == 1);
- if (!removeH && !removeW)
- return failure();
+ OpOperand &output = *op.getDpsInitsMutable().begin(); // Only the first init is rewritten; presumably inferConvolutionDims admits a single init - TODO confirm.
+ AffineMap outputMap = op.getMatchingIndexingMap(&output);
+ SmallVector<unsigned> outputDimsToRemove =
+ getResultIndicesReferencingDims(outputMap, loopDimsToRemove);
+ Value newOutput = createRankReducingExtractSlice(rewriter, loc, output.get(),
+ outputDimsToRemove);
+
+ // Create new linalg.generic with reduced dimensions.
+ auto newOp =
+ linalg::GenericOp::create(rewriter, loc, TypeRange{newOutput.getType()},
+ newInputs, newOutput, newMaps, newIterTypes);
+ rewriter.inlineRegionBefore(op->getRegion(0), newOp.getRegion(),
+ newOp.getRegion().begin());
+
+ // Try to specialize the generic back to a named op only if the input was
+ // already a specialized (named) op.
+ LinalgOp resultOp = newOp;
+ if (!isa<GenericOp>(op)) {
+ FailureOr<LinalgOp> specializedOp = specializeGenericOp(rewriter, newOp);
+ if (succeeded(specializedOp))
+ resultOp = *specializedOp;
+ }
- // Get new shapes and types for all operands by removing the size-1
- // dimension.
- using RTTBuilder = RankedTensorType::Builder;
- RankedTensorType newInputType =
- RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
- RankedTensorType newKernelType =
- RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
- RankedTensorType newOutputType =
- RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
-
- // Rank-reduce operands.
- Location loc = convOp.getLoc();
- Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, input, newInputType);
- Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, kernel, newKernelType);
- Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, output, newOutputType);
-
- auto conv1DOp =
- Conv1DOp::create(rewriter, loc, newOutputType,
- ValueRange{newInput, newKernel}, ValueRange{newOutput});
-
- // Insert back.
- Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
- rewriter, loc, conv1DOp.getResult(0), output);
- rewriter.replaceOp(convOp, inserted);
-
- return conv1DOp;
+ // Insert result back into original shape.
+ Value result = tensor::createCanonicalRankReducingInsertSliceOp(
+ rewriter, loc, resultOp->getResult(0), output.get());
+
+ rewriter.replaceOp(op, result);
+ return resultOp;
}
+namespace {
+/// Pattern wrapper around `downscaleSizeOneWindowedConvolution`.
+struct DownscaleSizeOneWindowedConvolution final
+ : public OpInterfaceRewritePattern<LinalgOp> {
+ DownscaleSizeOneWindowedConvolution(MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit) {}
+
+ LogicalResult matchAndRewrite(LinalgOp op,
+ PatternRewriter &rewriter) const override {
+ return linalg::downscaleSizeOneWindowedConvolution(rewriter, op);
+ }
+};
+} // namespace
+
void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
- patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
- Conv1DNwcWcfOp>,
- DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
- Conv1DNcwFcwOp>,
- DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
- patterns.getContext(), benefit);
- patterns.add<
- DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
- DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
- DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
- DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
- PoolingNwcMaxUnsignedOp>,
- DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
- DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
- PoolingNwcMinUnsignedOp>,
- DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
- patterns.getContext(), benefit);
+ patterns.add<DownscaleSizeOneWindowedConvolution>(patterns.getContext(),
+ benefit);
}
void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 6b03885069a37..3897f8502bb04 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -1,11 +1,26 @@
// RUN: mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s
-// Test the same patterns on generic convolution ops by first generalizing the
-// named ops. This avoids duplicating lit tests for linalg.generic conv ops.
-// RUN: mlir-opt --linalg-generalize-named-ops --transform-interpreter --split-input-file %s | FileCheck %s
+
+// Expected indexing maps for batchless conv_1d_nwc_wcf.
+// CHECK-DAG: #[[$CONV_I:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d3)>
+// CHECK-DAG: #[[$CONV_F:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
+// CHECK-DAG: #[[$CONV_O:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
+// Expected indexing maps for batchless depthwise_conv_1d_wc_wc.
+// CHECK-DAG: #[[$DW_I:.+]] = affine_map<(d0, d1, d2) -> (d0 + d2, d1)>
+// CHECK-DAG: #[[$DW_F:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+
+// Expected indexing maps for batchless pooling_cw_min.
+// CHECK-DAG: #[[$POOL_I:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+// CHECK-DAG: #[[$POOL_F:.+]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// Expected indexing maps for 1D conv (cross-conv after downscale from generic).
+// CHECK-DAG: #[[$CROSS_1D_I:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG: #[[$CROSS_1D_F:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[$CROSS_1D_O:.+]] = affine_map<(d0, d1) -> (d0)>
+
// CHECK-LABEL: @conv_2d_nhwc_hwcf
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
@@ -42,10 +57,11 @@ func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x
return %0 : tensor<?x?x1x?xf32>
}
-// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
+// Depthwise conv with height=1 (downscales height dimension)
+// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc_height
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
-func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
+func.func @depthwise_conv_2d_nhwc_hwc_height(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
// CHECK: %[[RES:.+]] = tensor.empty
%init = tensor.empty() : tensor<1x1x56x96xf32>
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
@@ -62,6 +78,27 @@ func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: t
return %0: tensor<1x1x56x96xf32>
}
+// Depthwise conv with width=1 (downscales width dimension)
+// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc_width
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x113x1x96xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<3x1x96xf32>
+func.func @depthwise_conv_2d_nhwc_hwc_width(%input: tensor<1x113x1x96xf32>, %filter: tensor<3x1x96xf32>) -> tensor<1x56x1x96xf32> {
+ // CHECK: %[[RES:.+]] = tensor.empty
+ %init = tensor.empty() : tensor<1x56x1x96xf32>
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
+ // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
+ // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
+ // CHECK-SAME: outs(%[[SLICERES]]
+ // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
+ %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+ ins(%input, %filter: tensor<1x113x1x96xf32>, tensor<3x1x96xf32>)
+ outs(%init: tensor<1x56x1x96xf32>) -> tensor<1x56x1x96xf32>
+ // CHECK: %[[INSERTED]]
+ return %0: tensor<1x56x1x96xf32>
+}
+
// CHECK-LABEL: @conv_2d
// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<1x?xf32>,
// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
@@ -205,6 +242,125 @@ func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32
return %0 : tensor<?x?x1x?xf32>
}
+#map_conv_i = affine_map<(oh, ow, f, kh, kw, c) -> (oh + kh, ow + kw, c)>
+#map_conv_f = affine_map<(oh, ow, f, kh, kw, c) -> (kh, kw, c, f)>
+#map_conv_o = affine_map<(oh, ow, f, kh, kw, c) -> (oh, ow, f)>
+
+// Batchless 2-D convolution written as a linalg.generic: HWC input, HWCF
+// filter, HWF output, with no batch loop. Output height (oh) and filter height
+// (kh) are both 1. The op is therefore expected to become a 1-D generic
+// convolution (two parallel and two reduction loops) over the extracted
+// slices, with its result inserted back into the original output.
+// CHECK-LABEL: @batchless_conv_2d_hwc_hwcf
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x14x8xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x8x16xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<1x12x16xf32>
+func.func @batchless_conv_2d_hwc_hwcf(%input: tensor<1x14x8xf32>, %filter: tensor<1x3x8x16xf32>, %output: tensor<1x12x16xf32>) -> tensor<1x12x16xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$CONV_I]], #[[$CONV_F]], #[[$CONV_O]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+ // The downscaled op must read and write the extracted slices, not the
+ // original full-rank operands.
+ // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
+ // CHECK-SAME: outs(%[[SLICE2]]
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.generic {
+ indexing_maps = [#map_conv_i, #map_conv_f, #map_conv_o],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
+ } ins(%input, %filter : tensor<1x14x8xf32>, tensor<1x3x8x16xf32>)
+ outs(%output : tensor<1x12x16xf32>) {
+ ^bb0(%in: f32, %fil: f32, %out: f32):
+ %mul = arith.mulf %in, %fil : f32
+ %add = arith.addf %out, %mul : f32
+ linalg.yield %add : f32
+ } -> tensor<1x12x16xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<1x12x16xf32>
+}
+
+#map_dw_i = affine_map<(oh, ow, c, kh, kw) -> (oh + kh, ow + kw, c)>
+#map_dw_f = affine_map<(oh, ow, c, kh, kw) -> (kh, kw, c)>
+#map_dw_o = affine_map<(oh, ow, c, kh, kw) -> (oh, ow, c)>
+
+// Batchless depthwise 2-D convolution written as a linalg.generic. The channel
+// loop (c) indexes input, filter and output alike. Output height (oh) and
+// filter height (kh) are both 1. The op is therefore expected to become a 1-D
+// generic (two parallel loops, one reduction loop) over the extracted slices.
+// CHECK-LABEL: @batchless_depthwise_conv_2d_hwc_hwc
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x14x8xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x8xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<1x12x8xf32>
+func.func @batchless_depthwise_conv_2d_hwc_hwc(%input: tensor<1x14x8xf32>, %filter: tensor<1x3x8xf32>, %output: tensor<1x12x8xf32>) -> tensor<1x12x8xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$DW_I]], #[[$DW_F]], #[[$MAP1]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+ // The downscaled op must read and write the extracted slices, not the
+ // original full-rank operands.
+ // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
+ // CHECK-SAME: outs(%[[SLICE2]]
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.generic {
+ indexing_maps = [#map_dw_i, #map_dw_f, #map_dw_o],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
+ } ins(%input, %filter : tensor<1x14x8xf32>, tensor<1x3x8xf32>)
+ outs(%output : tensor<1x12x8xf32>) {
+ ^bb0(%in: f32, %fil: f32, %out: f32):
+ %mul = arith.mulf %in, %fil : f32
+ %add = arith.addf %out, %mul : f32
+ linalg.yield %add : f32
+ } -> tensor<1x12x8xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<1x12x8xf32>
+}
+
+#map_pool_i = affine_map<(c, oh, ow, kh, kw) -> (c, oh + kh, ow + kw)>
+#map_pool_f = affine_map<(c, oh, ow, kh, kw) -> (kh, kw)>
+#map_pool_o = affine_map<(c, oh, ow, kh, kw) -> (c, oh, ow)>
+
+// Batchless min pooling in CHW layout written as a linalg.generic. The filter
+// operand only provides the window shape: the region combines %out and %in and
+// never reads %fil. Output height (oh) and window height (kh) are both 1. The
+// op is therefore expected to become a 1-D generic (two parallel loops, one
+// reduction loop) over the extracted slices.
+// CHECK-LABEL: @batchless_pooling_chw_min
+// CHECK-SAME: %[[ARG0:.+]]: tensor<8x1x14xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<8x1x12xf32>
+func.func @batchless_pooling_chw_min(%input: tensor<8x1x14xf32>, %filter: tensor<1x3xf32>, %output: tensor<8x1x12xf32>) -> tensor<8x1x12xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$POOL_I]], #[[$POOL_F]], #[[$MAP1]]]
+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+ // The downscaled op must read and write the extracted slices, not the
+ // original full-rank operands.
+ // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
+ // CHECK-SAME: outs(%[[SLICE2]]
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.generic {
+ indexing_maps = [#map_pool_i, #map_pool_f, #map_pool_o],
+ iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
+ } ins(%input, %filter : tensor<8x1x14xf32>, tensor<1x3xf32>)
+ outs(%output : tensor<8x1x12xf32>) {
+ ^bb0(%in: f32, %fil: f32, %out: f32):
+ %min = arith.minimumf %out, %in : f32
+ linalg.yield %min : f32
+ } -> tensor<8x1x12xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<8x1x12xf32>
+}
+
+#map_cross_i = affine_map<(d0, d1, d2, d3) -> (d0 + d3, d1 + d2)>
+#map_cross_f = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map_cross_o = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+
+// 2-D convolution whose loops are not paired in declaration order. The input
+// access pairs output dim d0 with filter dim d3 (d0 + d3) and output dim d1
+// with filter dim d2 (d1 + d2). d0 and d3 both have size 1, so that pair is
+// expected to be dropped. This leaves a 1-D generic with one parallel loop and
+// one reduction loop.
+// NOTE(review): the input's second dimension is 15, one more than the
+// 12 + 3 - 1 = 14 elements the (d1 + d2) window actually reads. Confirm this
+// is intentional.
+// CHECK-LABEL: @cross_conv_nonstandard_loop_order
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x15xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<3x1xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<1x12xf32>
+func.func @cross_conv_nonstandard_loop_order(%input: tensor<1x15xf32>, %filter: tensor<3x1xf32>, %output: tensor<1x12xf32>) -> tensor<1x12xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.generic
+ // CHECK-SAME: indexing_maps = [#[[$CROSS_1D_I]], #[[$CROSS_1D_F]], #[[$CROSS_1D_O]]]
+ // CHECK-SAME: iterator_types = ["parallel", "reduction"]
+ // The downscaled op must read and write the extracted slices, not the
+ // original full-rank operands.
+ // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
+ // CHECK-SAME: outs(%[[SLICE2]]
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.generic {
+ indexing_maps = [#map_cross_i, #map_cross_f, #map_cross_o],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]
+ } ins(%input, %filter : tensor<1x15xf32>, tensor<3x1xf32>)
+ outs(%output : tensor<1x12xf32>) {
+ ^bb0(%in: f32, %fil: f32, %out: f32):
+ %mul = arith.mulf %in, %fil : f32
+ %add = arith.addf %out, %mul : f32
+ linalg.yield %add : f32
+ } -> tensor<1x12xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<1x12xf32>
+}
+
func.func @softmax(%arg0: tensor<2x16x32xf32>, %dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
%1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%dst: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
return %1 : tensor<2x16x32xf32>
More information about the Mlir-commits
mailing list