[Mlir-commits] [mlir] ce2e198 - [mlir] add decompose and generalize to structured transform ops
Alex Zinenko
llvmlistbot at llvm.org
Thu Jun 2 06:25:45 PDT 2022
Author: Alex Zinenko
Date: 2022-06-02T15:25:18+02:00
New Revision: ce2e198bc2546f24a64fbeff62bf1489bcc53c27
URL: https://github.com/llvm/llvm-project/commit/ce2e198bc2546f24a64fbeff62bf1489bcc53c27
DIFF: https://github.com/llvm/llvm-project/commit/ce2e198bc2546f24a64fbeff62bf1489bcc53c27.diff
LOG: [mlir] add decompose and generalize to structured transform ops
These ops complement the tiling/padding transformations by rewriting
higher-level named structured operations, such as depthwise convolutions, into
lower-level and/or generic equivalents that downstream transformations handle
better.
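As a quick illustration (a minimal sketch adapted from the tests added in this
commit; chaining the two new ops is shown purely for illustration), the ops are
driven from a transform sequence along these lines:

  transform.with_pdl_patterns {
  ^bb0(%arg0: !pdl.operation):
    // Match the named op to be rewritten (here a 2-D NHWC/HWCF convolution).
    pdl.pattern @pdl_target : benefit(1) {
      %args = operands
      %results = types
      %0 = pdl.operation "linalg.conv_2d_nhwc_hwcf"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
      rewrite %0 with "transform.dialect"
    }

    transform.sequence %arg0 {
    ^bb1(%arg1: !pdl.operation):
      %0 = pdl_match @pdl_target in %arg1
      // Rewrite the 2-D convolution into its 1-D equivalent.
      %1 = transform.structured.decompose %0
      // Then rewrite the produced named op into linalg.generic form.
      %2 = transform.structured.generalize %1
    }
  }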
Differential Revision: https://reviews.llvm.org/D126698
Added:
mlir/test/Dialect/Linalg/transform-op-decompose.mlir
mlir/test/Dialect/Linalg/transform-op-generalize.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/python/mlir/dialects/_structured_transform_ops_ext.py
mlir/test/python/dialects/transform_structured_ext.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 387521a0e7245..205b0987ff98a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -16,6 +16,49 @@ include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
+def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait]> {
+ let description = [{
+ Decomposes named complex operations, such as higher-dimensional
+ (depthwise) convolutions, into combinations of lower-dimensional equivalents
+ when possible. The operand handle must point to a list of such operations.
+ The resulting handle points to the main computational operation produced,
+ such as the lower-dimensional convolution.
+ }];
+
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+}
+
+def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait]> {
+ let description = [{
+ Transforms a named structured operation into its generic form with an
+ explicitly attached region. The operand handle must point to a list of
+ structured operations; it is consumed by the transformation and is not
+ expected to be used afterwards. The resulting handle points to the list
+ of equivalent generic operations, in the same order as the original named
+ operations.
+ }];
+
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$transformed);
+ let assemblyFormat = "$target attr-dict";
+
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+ ::mlir::linalg::LinalgOp target);
+ }];
+}
+
def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6d058e03fafed..3db28c32f740c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -708,6 +708,56 @@ struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
LinalgPaddingOptions options;
};
+/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
+/// convolution ops.
+struct DownscaleSizeOneWindowed2DConvolution final
+ : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+ DownscaleSizeOneWindowed2DConvolution(
+ MLIRContext *context,
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
+ filter(std::move(f)) {}
+
+ FailureOr<Conv1DNwcWcfOp>
+ returningMatchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+ PatternRewriter &rewriter) const;
+
+ LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(convOp, rewriter);
+ }
+
+private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ LinalgTransformationFilter filter;
+};
+
+/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
+/// dimensions into 1-D depthwise convolution ops.
+struct DownscaleDepthwiseConv2DNhwcHwcOp final
+ : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+ DownscaleDepthwiseConv2DNhwcHwcOp(
+ MLIRContext *context,
+ LinalgTransformationFilter f = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
+ filter(std::move(f)) {}
+
+ FailureOr<DepthwiseConv1DNwcWcOp>
+ returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+ PatternRewriter &rewriter) const;
+
+ LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+ PatternRewriter &rewriter) const override {
+ return returningMatchAndRewrite(convOp, rewriter);
+ }
+
+private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ LinalgTransformationFilter filter;
+};
+
struct LinalgFusionOptions {
/// List of operands indices to use for fusion.
llvm::SmallSet<unsigned, 1> indicesToFuse = {};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f80ba4fc286f7..b081e241a848d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -50,6 +50,68 @@ class SimpleRewriter : public PatternRewriter {
};
} // namespace
+/// Attempts to apply the pattern specified as template argument to the given
+/// operation. The pattern is expected to have a `returningMatchAndRewrite`
+/// function that returns the "main" result or failure. Returns failure if the
+/// pattern failed to apply. Extra arguments are forwarded to the pattern
+/// constructor.
+template <typename PatternTy, typename... Args>
+static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
+ // Check if the given operation has the type expected by the pattern.
+ using OpTy = typename llvm::function_traits<
+ decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+ auto op = dyn_cast<OpTy>(operation);
+ if (!op)
+ return failure();
+
+ // Apply the pattern directly to the op.
+ PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
+ SimpleRewriter rewriter(operation->getContext());
+ rewriter.setInsertionPoint(operation);
+ auto result = pattern.returningMatchAndRewrite(op, rewriter);
+ if (failed(result))
+ return failure();
+ return cast<LinalgOp>(result->getOperation());
+}
+
+//===----------------------------------------------------------------------===//
+// DecomposeOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
+ FailureOr<LinalgOp> windowed =
+ tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
+ if (succeeded(windowed))
+ return windowed;
+
+ FailureOr<LinalgOp> depthwise =
+ tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
+ if (succeeded(depthwise))
+ return depthwise;
+
+ InFlightDiagnostic diag = emitError() << "failed to apply";
+ diag.attachNote(target.getLoc()) << "attempted to apply to this op";
+ return diag;
+}
+
+//===----------------------------------------------------------------------===//
+// GeneralizeOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
+ // Exit early if no transformation is needed.
+ if (isa<GenericOp>(target))
+ return target;
+
+ FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
+ if (succeeded(generic))
+ return generic;
+
+ InFlightDiagnostic diag = emitError() << "failed to apply";
+ diag.attachNote(target.getLoc()) << "attempted to apply to this op";
+ return diag;
+}
+
//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//
@@ -70,15 +132,7 @@ FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
return diag;
}
- GenericOpInterchangePattern pattern(getContext(), interchangeVector);
- SimpleRewriter rewriter(getContext());
- rewriter.setInsertionPoint(target);
- FailureOr<GenericOp> result =
- pattern.returningMatchAndRewrite(genericTarget, rewriter);
- if (failed(result))
- return failure();
-
- return cast<LinalgOp>(result->getOperation());
+ return tryApply<GenericOpInterchangePattern>(target, interchangeVector);
}
LogicalResult transform::InterchangeOp::verify() {
@@ -147,18 +201,15 @@ FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
paddingOptions.setTransposePaddings(transposePaddings);
- LinalgPaddingPattern pattern(getContext(), paddingOptions);
- SimpleRewriter rewriter(getContext());
- rewriter.setInsertionPoint(target);
- FailureOr<LinalgOp> patternResult =
- pattern.returningMatchAndRewrite(target, rewriter);
- if (failed(patternResult)) {
- InFlightDiagnostic diag = emitError()
- << "failed to apply pattern to target op";
- diag.attachNote(target.getLoc()) << "target op";
- return diag;
- }
- return patternResult;
+ FailureOr<LinalgOp> result =
+ tryApply<LinalgPaddingPattern>(target, paddingOptions);
+ if (succeeded(result))
+ return result;
+
+ InFlightDiagnostic diag = emitError()
+ << "failed to apply pattern to target op";
+ diag.attachNote(target.getLoc()) << "target op";
+ return diag;
}
LogicalResult transform::PadOp::verify() {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7fed6c0428fb2..6b347561a09e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -945,7 +945,6 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
return success();
}
-namespace {
// The following are patterns for downscaling convolution ops with size-1
// window dimensions.
//
@@ -954,179 +953,145 @@ namespace {
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.
-/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
-/// convolution ops.
-struct DownscaleSizeOneWindowed2DConvolution final
- : public OpRewritePattern<Conv2DNhwcHwcfOp> {
- DownscaleSizeOneWindowed2DConvolution(
- MLIRContext *context,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
- filter(std::move(f)) {}
-
- LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
- PatternRewriter &rewriter) const override {
- if (failed(filter.checkAndNotify(rewriter, convOp)))
- return failure();
- if (convOp.hasBufferSemantics())
- return failure(); // To be implemented
-
- Value input = convOp.inputs().front();
- Value kernel = convOp.inputs().back();
- Value output = convOp.outputs().front();
-
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
-
- auto kernelShape = kernelType.getShape();
- auto outputShape = outputType.getShape();
-
- // Only handle the case where at least one of the window dimensions is
- // of size 1. Other cases can rely on tiling to reduce to such cases.
- int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
- int64_t ohSize = outputShape[1], owSize = outputShape[2];
- bool removeH = (khSize == 1 && ohSize == 1);
- bool removeW = (kwSize == 1 && owSize == 1);
- if (!removeH && !removeW)
- return failure();
-
- // Get new shapes and types for all operands by removing the size-1
- // dimension.
- using RTTBuilder = RankedTensorType::Builder;
- RankedTensorType newInputType =
- RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
- RankedTensorType newKernelType =
- RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
- RankedTensorType newOutputType =
- RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
-
- // Rank-reduce operands.
- Location loc = convOp.getLoc();
- Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, input, newInputType);
- Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, kernel, newKernelType);
- Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, output, newOutputType);
-
- // Rank-reduce strides and dilations too.
- // TODO: dropDim 1-liner helper.
- auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
- strides.erase(strides.begin() + (removeH ? 0 : 1));
- auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
- auto dilations =
- llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
- dilations.erase(dilations.begin() + (removeH ? 0 : 1));
- auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
- auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
- loc, newOutputType, ValueRange{newInput, newKernel},
- ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
- // Insert back.
- Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
- rewriter, loc, conv1DOp.getResult(0), output);
- rewriter.replaceOp(convOp, inserted);
-
- filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
- return success();
- };
-
-private:
- /// LinalgTransformMarker handles special attribute manipulations.
- LinalgTransformationFilter filter;
-};
-
-/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
-/// dimensions into 1-D depthwise convolution ops.
-struct DownscaleDepthwiseConv2DNhwcHwcOp final
- : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
- DownscaleDepthwiseConv2DNhwcHwcOp(
- MLIRContext *context,
- LinalgTransformationFilter f = LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
- filter(std::move(f)) {}
-
- LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
- PatternRewriter &rewriter) const override {
- if (failed(filter.checkAndNotify(rewriter, convOp)))
- return failure();
- if (convOp.hasBufferSemantics())
- return failure(); // To be implemented
-
- Value input = convOp.inputs().front();
- Value kernel = convOp.inputs().back();
- Value output = convOp.outputs().front();
-
- auto inputType = input.getType().dyn_cast<RankedTensorType>();
- auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
- auto outputType = output.getType().dyn_cast<RankedTensorType>();
-
- auto kernelShape = kernelType.getShape();
- auto outputShape = outputType.getShape();
-
- // Only handle the case where at least one of the window dimensions is
- // of size 1. Other cases can rely on tiling to reduce to such cases.
- int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
- int64_t ohSize = outputShape[1], owSize = outputShape[2];
- bool removeH = (khSize == 1 && ohSize == 1);
- bool removeW = (kwSize == 1 && owSize == 1);
- if (!removeH && !removeW)
- return failure();
+FailureOr<Conv1DNwcWcfOp>
+DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
+ linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
+ if (failed(filter.checkAndNotify(rewriter, convOp)))
+ return failure();
+ if (convOp.hasBufferSemantics())
+ return failure(); // To be implemented.
+
+ Value input = convOp.inputs().front();
+ Value kernel = convOp.inputs().back();
+ Value output = convOp.outputs().front();
+
+ auto inputType = input.getType().dyn_cast<RankedTensorType>();
+ auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+ auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+ auto kernelShape = kernelType.getShape();
+ auto outputShape = outputType.getShape();
+
+ // Only handle the case where at least one of the window dimensions is
+ // of size 1. Other cases can rely on tiling to reduce to such cases.
+ int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+ int64_t ohSize = outputShape[1], owSize = outputShape[2];
+ bool removeH = (khSize == 1 && ohSize == 1);
+ bool removeW = (kwSize == 1 && owSize == 1);
+ if (!removeH && !removeW)
+ return failure();
- // Get new shapes and types for all operands by removing the size-1
- // dimension.
- using RTTBuilder = RankedTensorType::Builder;
- RankedTensorType newInputType =
- RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
- RankedTensorType newKernelType =
- RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
- RankedTensorType newOutputType =
- RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
-
- // Rank-reduce operands.
- Location loc = convOp.getLoc();
- Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, input, newInputType);
- Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, kernel, newKernelType);
- Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
- rewriter, loc, output, newOutputType);
-
- // Rank-reduce strides and dilations too.
- // TODO: dropDim 1-liner helper.
- auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
- strides.erase(strides.begin() + (removeH ? 0 : 1));
- auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
- auto dilations =
- llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
- dilations.erase(dilations.begin() + (removeH ? 0 : 1));
- auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
- auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
- loc, newOutputType, ValueRange{newInput, newKernel},
- ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
- // Insert back.
- Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
- rewriter, loc, conv1DOp.getResult(0), output);
- rewriter.replaceOp(convOp, inserted);
-
- filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
- return success();
- };
+ // Get new shapes and types for all operands by removing the size-1
+ // dimension.
+ using RTTBuilder = RankedTensorType::Builder;
+ RankedTensorType newInputType =
+ RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+ RankedTensorType newKernelType =
+ RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+ RankedTensorType newOutputType =
+ RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+
+ // Rank-reduce operands.
+ Location loc = convOp.getLoc();
+ Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, input, newInputType);
+ Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, kernel, newKernelType);
+ Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, output, newOutputType);
+
+ // Rank-reduce strides and dilations too.
+ // TODO: dropDim 1-liner helper.
+ auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+ strides.erase(strides.begin() + (removeH ? 0 : 1));
+ auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+ auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+ dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+ auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+
+ auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
+ loc, newOutputType, ValueRange{newInput, newKernel},
+ ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+ // Insert back.
+ Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+ rewriter, loc, conv1DOp.getResult(0), output);
+ rewriter.replaceOp(convOp, inserted);
+
+ filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
+ return conv1DOp;
+}
-private:
- /// LinalgTransformMarker handles special attribute manipulations.
- LinalgTransformationFilter filter;
-};
+FailureOr<DepthwiseConv1DNwcWcOp>
+DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
+ DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
+ if (failed(filter.checkAndNotify(rewriter, convOp)))
+ return failure();
+ if (convOp.hasBufferSemantics())
+ return failure(); // To be implemented.
+
+ Value input = convOp.inputs().front();
+ Value kernel = convOp.inputs().back();
+ Value output = convOp.outputs().front();
+
+ auto inputType = input.getType().dyn_cast<RankedTensorType>();
+ auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+ auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+ auto kernelShape = kernelType.getShape();
+ auto outputShape = outputType.getShape();
+
+ // Only handle the case where at least one of the window dimensions is
+ // of size 1. Other cases can rely on tiling to reduce to such cases.
+ int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+ int64_t ohSize = outputShape[1], owSize = outputShape[2];
+ bool removeH = (khSize == 1 && ohSize == 1);
+ bool removeW = (kwSize == 1 && owSize == 1);
+ if (!removeH && !removeW)
+ return failure();
-} // namespace
+ // Get new shapes and types for all operands by removing the size-1
+ // dimension.
+ using RTTBuilder = RankedTensorType::Builder;
+ RankedTensorType newInputType =
+ RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+ RankedTensorType newKernelType =
+ RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+ RankedTensorType newOutputType =
+ RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+
+ // Rank-reduce operands.
+ Location loc = convOp.getLoc();
+ Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, input, newInputType);
+ Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, kernel, newKernelType);
+ Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, output, newOutputType);
+
+ // Rank-reduce strides and dilations too.
+ // TODO: dropDim 1-liner helper.
+ auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+ strides.erase(strides.begin() + (removeH ? 0 : 1));
+ auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+ auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+ dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+ auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+
+ auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
+ loc, newOutputType, ValueRange{newInput, newKernel},
+ ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+ // Insert back.
+ Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+ rewriter, loc, conv1DOp.getResult(0), output);
+ rewriter.replaceOp(convOp, inserted);
+
+ filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
+ return conv1DOp;
+}
void linalg::populateDecomposeConvolutionPatterns(
RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 70e39be5289da..e5a2a473150cc 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -69,6 +69,28 @@ def _get_int_int_array_attr(
return ArrayAttr.get([_get_int_array_attr(value) for value in values])
+class DecomposeOp:
+ """Specialization for DecomposeOp class."""
+
+ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ super().__init__(
+ pdl.OperationType.get(),
+ _get_op_result_or_value(target),
+ loc=loc,
+ ip=ip)
+
+
+class GeneralizeOp:
+ """Specialization for GeneralizeOp class."""
+
+ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+ super().__init__(
+ pdl.OperationType.get(),
+ _get_op_result_or_value(target),
+ loc=loc,
+ ip=ip)
+
+
class InterchangeOp:
"""Specialization for InterchangeOp class."""
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
new file mode 100644
index 0000000000000..e80c3b1078d6d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @conv_2d_nhwc_hwcf
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
+ outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.conv_2d_nhwc_hwcf"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1 = transform.structured.decompose %0
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
+func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
+ // CHECK: %[[RES:.+]] = linalg.init_tensor
+ %init = linalg.init_tensor [1, 1, 56, 96] : tensor<1x1x56x96xf32>
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
+ // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
+ // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
+ // CHECK-SAME: outs(%[[SLICERES]]
+ // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
+ %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+ ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
+ outs(%init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
+ // CHECK: return %[[INSERTED]]
+ return %0: tensor<1x1x56x96xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.depthwise_conv_2d_nhwc_hwc"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1 = transform.structured.decompose %0
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-generalize.mlir b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
new file mode 100644
index 0000000000000..1a20cf7502cab
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+
+// CHECK-LABEL: func.func @generalize_unary
+func.func @generalize_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+ // CHECK-NOT: linalg.elemwise_unary
+ // CHECK: linalg.generic
+ %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+ outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1 = transform.structured.generalize %0
+ }
+}
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 463dec10d7bd5..a34b03fb9d0bc 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -16,6 +16,28 @@ def run(f):
return f
+@run
+def testDecompose():
+ sequence = transform.SequenceOp()
+ with InsertionPoint(sequence.body):
+ structured.DecomposeOp(sequence.bodyTarget)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testDecompose
+ # CHECK: transform.sequence
+ # CHECK: transform.structured.decompose
+
+
+@run
+def testGeneralize():
+ sequence = transform.SequenceOp()
+ with InsertionPoint(sequence.body):
+ structured.GeneralizeOp(sequence.bodyTarget)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testGeneralize
+ # CHECK: transform.sequence
+ # CHECK: transform.structured.generalize
+
+
@run
def testInterchange():
sequence = transform.SequenceOp()