[Mlir-commits] [mlir] ce2e198 - [mlir] add decompose and generalize to structured transform ops

Alex Zinenko llvmlistbot at llvm.org
Thu Jun 2 06:25:45 PDT 2022


Author: Alex Zinenko
Date: 2022-06-02T15:25:18+02:00
New Revision: ce2e198bc2546f24a64fbeff62bf1489bcc53c27

URL: https://github.com/llvm/llvm-project/commit/ce2e198bc2546f24a64fbeff62bf1489bcc53c27
DIFF: https://github.com/llvm/llvm-project/commit/ce2e198bc2546f24a64fbeff62bf1489bcc53c27.diff

LOG: [mlir] add decompose and generalize to structured transform ops

These ops complement the tiling/padding transformations by rewriting
higher-level named structured operations, such as depthwise convolutions,
into lower-level and/or generic equivalents that are better handled by
downstream transformations.
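
A minimal sketch of how the new ops compose in a transform script (adapted
from the tests added below; @pdl_target is assumed to be a pdl.pattern that
matches the targeted linalg operation):

  transform.sequence %arg0 {
  ^bb1(%arg1: !pdl.operation):
    // Match the named op and decompose it into its 1-D equivalent.
    %0 = pdl_match @pdl_target in %arg1
    %1 = transform.structured.decompose %0
    // Rewrite the produced op into linalg.generic form with an explicit region.
    %2 = transform.structured.generalize %1
  }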

Differential Revision: https://reviews.llvm.org/D126698

Added: 
    mlir/test/Dialect/Linalg/transform-op-decompose.mlir
    mlir/test/Dialect/Linalg/transform-op-generalize.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/test/python/dialects/transform_structured_ext.py

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 387521a0e7245..205b0987ff98a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -16,6 +16,49 @@ include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Decomposes named complex operations, such as higher-dimensional
+    (depthwise) convolutions, into combinations of lower-dimensional equivalents
+    when possible. The operand handle must point to a list of such operations.
+    The returned handle points to the main produced computational operation,
+    such as the lower-dimensional convolution.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$transformed);
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+        ::mlir::linalg::LinalgOp target);
+  }];
+}
+
+def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Transforms a named structured operation into the generic form with an
+    explicitly attached region. The operand handle must point to a list of
+    structured operations; it is consumed by the transformation and is not
+    expected to be used afterwards. The resulting handle points to the list
+    of equivalent generic operations, in the same order as the original named
+    operations.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$transformed);
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+        ::mlir::linalg::LinalgOp target);
+  }];
+}
+
 def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
     TransformOpInterface, TransformEachOpTrait]> {

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6d058e03fafed..3db28c32f740c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -708,6 +708,56 @@ struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
   LinalgPaddingOptions options;
 };
 
+/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
+/// convolution ops.
+struct DownscaleSizeOneWindowed2DConvolution final
+    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+  DownscaleSizeOneWindowed2DConvolution(
+      MLIRContext *context,
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
+        filter(std::move(f)) {}
+
+  FailureOr<Conv1DNwcWcfOp>
+  returningMatchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                           PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(convOp, rewriter);
+  }
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+};
+
+/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
+/// dimensions into 1-D depthwise convolution ops.
+struct DownscaleDepthwiseConv2DNhwcHwcOp final
+    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+  DownscaleDepthwiseConv2DNhwcHwcOp(
+      MLIRContext *context,
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
+        filter(std::move(f)) {}
+
+  FailureOr<DepthwiseConv1DNwcWcOp>
+  returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+                           PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(convOp, rewriter);
+  }
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+};
+
 struct LinalgFusionOptions {
   /// List of operands indices to use for fusion.
   llvm::SmallSet<unsigned, 1> indicesToFuse = {};

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f80ba4fc286f7..b081e241a848d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -50,6 +50,68 @@ class SimpleRewriter : public PatternRewriter {
 };
 } // namespace
 
+/// Attempts to apply the pattern specified as template argument to the given
+/// operation. The pattern is expected to have a `returningMatchAndRewrite`
+/// function that returns the "main" result or failure. Returns failure if the
+/// pattern failed to apply. Extra arguments are forwarded to the pattern
+/// constructor.
+template <typename PatternTy, typename... Args>
+static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
+  // Check if the given operation has the type expected by the pattern.
+  using OpTy = typename llvm::function_traits<
+      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+  auto op = dyn_cast<OpTy>(operation);
+  if (!op)
+    return failure();
+
+  // Apply the pattern directly to the op.
+  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
+  SimpleRewriter rewriter(operation->getContext());
+  rewriter.setInsertionPoint(operation);
+  auto result = pattern.returningMatchAndRewrite(op, rewriter);
+  if (failed(result))
+    return failure();
+  return cast<LinalgOp>(result->getOperation());
+}
+
+//===----------------------------------------------------------------------===//
+// DecomposeOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
+  FailureOr<LinalgOp> windowed =
+      tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
+  if (succeeded(windowed))
+    return windowed;
+
+  FailureOr<LinalgOp> depthwise =
+      tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
+  if (succeeded(depthwise))
+    return depthwise;
+
+  InFlightDiagnostic diag = emitError() << "failed to apply";
+  diag.attachNote(target.getLoc()) << "attempted to apply to this op";
+  return diag;
+}
+
+//===----------------------------------------------------------------------===//
+// GeneralizeOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
+  // Exit early if no transformation is needed.
+  if (isa<GenericOp>(target))
+    return target;
+
+  FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
+  if (succeeded(generic))
+    return generic;
+
+  InFlightDiagnostic diag = emitError() << "failed to apply";
+  diag.attachNote(target.getLoc()) << "attempted to apply to this op";
+  return diag;
+}
+
 //===----------------------------------------------------------------------===//
 // InterchangeOp
 //===----------------------------------------------------------------------===//
@@ -70,15 +132,7 @@ FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
     return diag;
   }
 
-  GenericOpInterchangePattern pattern(getContext(), interchangeVector);
-  SimpleRewriter rewriter(getContext());
-  rewriter.setInsertionPoint(target);
-  FailureOr<GenericOp> result =
-      pattern.returningMatchAndRewrite(genericTarget, rewriter);
-  if (failed(result))
-    return failure();
-
-  return cast<LinalgOp>(result->getOperation());
+  return tryApply<GenericOpInterchangePattern>(target, interchangeVector);
 }
 
 LogicalResult transform::InterchangeOp::verify() {
@@ -147,18 +201,15 @@ FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
   paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
   paddingOptions.setTransposePaddings(transposePaddings);
 
-  LinalgPaddingPattern pattern(getContext(), paddingOptions);
-  SimpleRewriter rewriter(getContext());
-  rewriter.setInsertionPoint(target);
-  FailureOr<LinalgOp> patternResult =
-      pattern.returningMatchAndRewrite(target, rewriter);
-  if (failed(patternResult)) {
-    InFlightDiagnostic diag = emitError()
-                              << "failed to apply pattern to target op";
-    diag.attachNote(target.getLoc()) << "target op";
-    return diag;
-  }
-  return patternResult;
+  FailureOr<LinalgOp> result =
+      tryApply<LinalgPaddingPattern>(target, paddingOptions);
+  if (succeeded(result))
+    return result;
+
+  InFlightDiagnostic diag = emitError()
+                            << "failed to apply pattern to target op";
+  diag.attachNote(target.getLoc()) << "target op";
+  return diag;
 }
 
 LogicalResult transform::PadOp::verify() {

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7fed6c0428fb2..6b347561a09e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -945,7 +945,6 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   return success();
 }
 
-namespace {
 // The following are patterns for downscaling convolution ops with size-1
 // window dimensions.
 //
@@ -954,179 +953,145 @@ namespace {
 // and then turning back to named ops. But for now it's fine to have a few
 // patterns matching special ops to get started.
 
-/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
-/// convolution ops.
-struct DownscaleSizeOneWindowed2DConvolution final
-    : public OpRewritePattern<Conv2DNhwcHwcfOp> {
-  DownscaleSizeOneWindowed2DConvolution(
-      MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
-        filter(std::move(f)) {}
-
-  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(filter.checkAndNotify(rewriter, convOp)))
-      return failure();
-    if (convOp.hasBufferSemantics())
-      return failure(); // To be implemented
-
-    Value input = convOp.inputs().front();
-    Value kernel = convOp.inputs().back();
-    Value output = convOp.outputs().front();
-
-    auto inputType = input.getType().dyn_cast<RankedTensorType>();
-    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
-    auto outputType = output.getType().dyn_cast<RankedTensorType>();
-
-    auto kernelShape = kernelType.getShape();
-    auto outputShape = outputType.getShape();
-
-    // Only handle the case where at least one of the window dimensions is
-    // of size 1. Other cases can rely on tiling to reduce to such cases.
-    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
-    int64_t ohSize = outputShape[1], owSize = outputShape[2];
-    bool removeH = (khSize == 1 && ohSize == 1);
-    bool removeW = (kwSize == 1 && owSize == 1);
-    if (!removeH && !removeW)
-      return failure();
-
-    // Get new shapes and types for all operands by removing the size-1
-    // dimension.
-    using RTTBuilder = RankedTensorType::Builder;
-    RankedTensorType newInputType =
-        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-    RankedTensorType newKernelType =
-        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
-    RankedTensorType newOutputType =
-        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
-
-    // Rank-reduce operands.
-    Location loc = convOp.getLoc();
-    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, input, newInputType);
-    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, kernel, newKernelType);
-    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, output, newOutputType);
-
-    // Rank-reduce strides and dilations too.
-    // TODO: dropDim 1-liner helper.
-    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
-    strides.erase(strides.begin() + (removeH ? 0 : 1));
-    auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
-    auto dilations =
-        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
-    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
-    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
-    auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
-        loc, newOutputType, ValueRange{newInput, newKernel},
-        ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
-    // Insert back.
-    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
-        rewriter, loc, conv1DOp.getResult(0), output);
-    rewriter.replaceOp(convOp, inserted);
-
-    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
-    return success();
-  };
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
-};
-
-/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
-/// dimensions into 1-D depthwise convolution ops.
-struct DownscaleDepthwiseConv2DNhwcHwcOp final
-    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
-  DownscaleDepthwiseConv2DNhwcHwcOp(
-      MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
-        filter(std::move(f)) {}
-
-  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(filter.checkAndNotify(rewriter, convOp)))
-      return failure();
-    if (convOp.hasBufferSemantics())
-      return failure(); // To be implemented
-
-    Value input = convOp.inputs().front();
-    Value kernel = convOp.inputs().back();
-    Value output = convOp.outputs().front();
-
-    auto inputType = input.getType().dyn_cast<RankedTensorType>();
-    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
-    auto outputType = output.getType().dyn_cast<RankedTensorType>();
-
-    auto kernelShape = kernelType.getShape();
-    auto outputShape = outputType.getShape();
-
-    // Only handle the case where at least one of the window dimensions is
-    // of size 1. Other cases can rely on tiling to reduce to such cases.
-    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
-    int64_t ohSize = outputShape[1], owSize = outputShape[2];
-    bool removeH = (khSize == 1 && ohSize == 1);
-    bool removeW = (kwSize == 1 && owSize == 1);
-    if (!removeH && !removeW)
-      return failure();
+FailureOr<Conv1DNwcWcfOp>
+DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
+    linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, convOp)))
+    return failure();
+  if (convOp.hasBufferSemantics())
+    return failure(); // To be implemented.
+
+  Value input = convOp.inputs().front();
+  Value kernel = convOp.inputs().back();
+  Value output = convOp.outputs().front();
+
+  auto inputType = input.getType().dyn_cast<RankedTensorType>();
+  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+  auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+  auto kernelShape = kernelType.getShape();
+  auto outputShape = outputType.getShape();
+
+  // Only handle the case where at least one of the window dimensions is
+  // of size 1. Other cases can rely on tiling to reduce to such cases.
+  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+  int64_t ohSize = outputShape[1], owSize = outputShape[2];
+  bool removeH = (khSize == 1 && ohSize == 1);
+  bool removeW = (kwSize == 1 && owSize == 1);
+  if (!removeH && !removeW)
+    return failure();
 
-    // Get new shapes and types for all operands by removing the size-1
-    // dimension.
-    using RTTBuilder = RankedTensorType::Builder;
-    RankedTensorType newInputType =
-        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-    RankedTensorType newKernelType =
-        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
-    RankedTensorType newOutputType =
-        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
-
-    // Rank-reduce operands.
-    Location loc = convOp.getLoc();
-    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, input, newInputType);
-    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, kernel, newKernelType);
-    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, output, newOutputType);
-
-    // Rank-reduce strides and dilations too.
-    // TODO: dropDim 1-liner helper.
-    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
-    strides.erase(strides.begin() + (removeH ? 0 : 1));
-    auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
-    auto dilations =
-        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
-    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
-    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
-    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
-        loc, newOutputType, ValueRange{newInput, newKernel},
-        ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
-    // Insert back.
-    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
-        rewriter, loc, conv1DOp.getResult(0), output);
-    rewriter.replaceOp(convOp, inserted);
-
-    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
-    return success();
-  };
+  // Get new shapes and types for all operands by removing the size-1
+  // dimension.
+  using RTTBuilder = RankedTensorType::Builder;
+  RankedTensorType newInputType =
+      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+  RankedTensorType newKernelType =
+      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+  RankedTensorType newOutputType =
+      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+
+  // Rank-reduce operands.
+  Location loc = convOp.getLoc();
+  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, input, newInputType);
+  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, kernel, newKernelType);
+  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, output, newOutputType);
+
+  // Rank-reduce strides and dilations too.
+  // TODO: dropDim 1-liner helper.
+  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+  strides.erase(strides.begin() + (removeH ? 0 : 1));
+  auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+
+  auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
+      loc, newOutputType, ValueRange{newInput, newKernel},
+      ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+  // Insert back.
+  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, conv1DOp.getResult(0), output);
+  rewriter.replaceOp(convOp, inserted);
+
+  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
+  return conv1DOp;
+}
 
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
-};
+FailureOr<DepthwiseConv1DNwcWcOp>
+DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
+    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, convOp)))
+    return failure();
+  if (convOp.hasBufferSemantics())
+    return failure(); // To be implemented.
+
+  Value input = convOp.inputs().front();
+  Value kernel = convOp.inputs().back();
+  Value output = convOp.outputs().front();
+
+  auto inputType = input.getType().dyn_cast<RankedTensorType>();
+  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+  auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+  auto kernelShape = kernelType.getShape();
+  auto outputShape = outputType.getShape();
+
+  // Only handle the case where at least one of the window dimensions is
+  // of size 1. Other cases can rely on tiling to reduce to such cases.
+  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+  int64_t ohSize = outputShape[1], owSize = outputShape[2];
+  bool removeH = (khSize == 1 && ohSize == 1);
+  bool removeW = (kwSize == 1 && owSize == 1);
+  if (!removeH && !removeW)
+    return failure();
 
-} // namespace
+  // Get new shapes and types for all operands by removing the size-1
+  // dimension.
+  using RTTBuilder = RankedTensorType::Builder;
+  RankedTensorType newInputType =
+      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+  RankedTensorType newKernelType =
+      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+  RankedTensorType newOutputType =
+      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+
+  // Rank-reduce operands.
+  Location loc = convOp.getLoc();
+  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, input, newInputType);
+  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, kernel, newKernelType);
+  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, output, newOutputType);
+
+  // Rank-reduce strides and dilations too.
+  // TODO: dropDim 1-liner helper.
+  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+  strides.erase(strides.begin() + (removeH ? 0 : 1));
+  auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+
+  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
+      loc, newOutputType, ValueRange{newInput, newKernel},
+      ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+  // Insert back.
+  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, conv1DOp.getResult(0), output);
+  rewriter.replaceOp(convOp, inserted);
+
+  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
+  return conv1DOp;
+}
 
 void linalg::populateDecomposeConvolutionPatterns(
     RewritePatternSet &patterns, const LinalgTransformationFilter &filter,

diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 70e39be5289da..e5a2a473150cc 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -69,6 +69,28 @@ def _get_int_int_array_attr(
   return ArrayAttr.get([_get_int_array_attr(value) for value in values])
 
 
+class DecomposeOp:
+  """Specialization for DecomposeOp class."""
+
+  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        loc=loc,
+        ip=ip)
+
+
+class GeneralizeOp:
+  """Specialization for GeneralizeOp class."""
+
+  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        loc=loc,
+        ip=ip)
+
+
 class InterchangeOp:
   """Specialization for InterchangeOp class."""
 

diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
new file mode 100644
index 0000000000000..e80c3b1078d6d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @conv_2d_nhwc_hwcf
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.conv_2d_nhwc_hwcf"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = transform.structured.decompose %0
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
+func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
+  // CHECK: %[[RES:.+]] = linalg.init_tensor
+  %init = linalg.init_tensor [1, 1, 56, 96] : tensor<1x1x56x96xf32>
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
+  // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
+  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
+  // CHECK-SAME: outs(%[[SLICERES]]
+  // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+         ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
+         outs(%init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
+  // CHECK: return %[[INSERTED]]
+  return %0: tensor<1x1x56x96xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.depthwise_conv_2d_nhwc_hwc"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = transform.structured.decompose %0
+  }
+}

diff --git a/mlir/test/Dialect/Linalg/transform-op-generalize.mlir b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
new file mode 100644
index 0000000000000..1a20cf7502cab
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+
+// CHECK-LABEL: func.func @generalize_unary
+func.func @generalize_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  // CHECK-NOT:   linalg.elemwise_unary
+  //     CHECK:   linalg.generic
+  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = transform.structured.generalize %0
+  }
+}

diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 463dec10d7bd5..a34b03fb9d0bc 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -16,6 +16,28 @@ def run(f):
   return f
 
 
+@run
+def testDecompose():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.DecomposeOp(sequence.bodyTarget)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testDecompose
+  # CHECK: transform.sequence
+  # CHECK: transform.structured.decompose
+
+
+@run
+def testGeneralize():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.GeneralizeOp(sequence.bodyTarget)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testGeneralize
+  # CHECK: transform.sequence
+  # CHECK: transform.structured.generalize
+
+
 @run
 def testInterchange():
   sequence = transform.SequenceOp()


More information about the Mlir-commits mailing list