[Mlir-commits] [mlir] b828506 - [mlir][Linalg] Add a DownscaleDepthwiseConv2DNhwcHwcOp decomposition pattern.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Nov 15 12:48:20 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-15T20:48:16Z
New Revision: b828506ecacee709893adc9afbc36ae24dbcc068

URL: https://github.com/llvm/llvm-project/commit/b828506ecacee709893adc9afbc36ae24dbcc068
DIFF: https://github.com/llvm/llvm-project/commit/b828506ecacee709893adc9afbc36ae24dbcc068.diff

LOG: [mlir][Linalg] Add a DownscaleDepthwiseConv2DNhwcHwcOp decomposition pattern.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D113907
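
For illustration, when both the h and kh dimensions are of size 1 (as in
the first test case below), the pattern rank-reduces the operands with
tensor.extract_slice, runs a 1-D depthwise convolution, and re-inserts
the result with tensor.insert_slice. A sketch of the resulting IR, with
illustrative SSA names and the offsets/sizes the canonical rank-reducing
slice helpers produce:

    %in = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 113, 96] [1, 1, 1, 1]
        : tensor<1x1x113x96xf32> to tensor<1x113x96xf32>
    %ker = tensor.extract_slice %filter[0, 0, 0] [1, 3, 96] [1, 1, 1]
        : tensor<1x3x96xf32> to tensor<3x96xf32>
    %init = tensor.extract_slice %out[0, 0, 0, 0] [1, 1, 56, 96] [1, 1, 1, 1]
        : tensor<1x1x56x96xf32> to tensor<1x56x96xf32>
    %conv = linalg.depthwise_conv_1d_nwc_wc
              {dilations = dense<1> : vector<1xi64>, strides = dense<2> : vector<1xi64>}
              ins(%in, %ker : tensor<1x113x96xf32>, tensor<3x96xf32>)
              outs(%init : tensor<1x56x96xf32>) -> tensor<1x56x96xf32>
    %0 = tensor.insert_slice %conv into %out[0, 0, 0, 0] [1, 1, 56, 96] [1, 1, 1, 1]
        : tensor<1x56x96xf32> into tensor<1x1x56x96xf32>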

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/Dialect/Linalg/decompose-convolution.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index f15337f0f70e..a4b4d78e1073 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -904,10 +904,83 @@ struct DownscaleSizeOneWindowed2DConvolution final
   };
 };
 
+/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
+/// dimensions into 1-D depthwise convolution ops.
+struct DownscaleDepthwiseConv2DNhwcHwcOp final
+    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = cast<linalg::LinalgOp>(*convOp);
+    if (linalgOp.hasBufferSemantics())
+      return failure(); // To be implemented
+
+    Value input = convOp.inputs().front();
+    Value kernel = convOp.inputs().back();
+    Value output = convOp.outputs().front();
+
+    auto inputType = input.getType().dyn_cast<RankedTensorType>();
+    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+    auto kernelShape = kernelType.getShape();
+    auto outputShape = outputType.getShape();
+
+    // Only handle the case where at least one of the window dimensions is
+    // of size 1. Other cases can rely on tiling to reduce to such cases.
+    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+    int64_t ohSize = outputShape[1], owSize = outputShape[2];
+    bool removeH = (khSize == 1 && ohSize == 1);
+    bool removeW = (kwSize == 1 && owSize == 1);
+    if (!removeH && !removeW)
+      return failure();
+
+    // Get new shapes and types for all operands by removing the size-1
+    // dimension.
+    using RTTBuilder = RankedTensorType::Builder;
+    auto newInputType = RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+    auto newKernelType = RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+    auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+
+    // Rank-reduce operands.
+    Location loc = convOp.getLoc();
+    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, input, newInputType);
+    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, kernel, newKernelType);
+    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+        rewriter, loc, output, newOutputType);
+
+    // Rank-reduce strides and dilations too.
+    // TODO: dropDim 1-liner helper.
+    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+    strides.erase(strides.begin() + (removeH ? 0 : 1));
+    auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+    auto dilations =
+        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+
+    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
+        loc, newOutputType, ValueRange{newInput, newKernel},
+        ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+    // Insert back.
+    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+        rewriter, loc, conv1DOp.getResult(0), output);
+    rewriter.replaceOp(convOp, inserted);
+
+    return success();
+  };
+};
+
 } // namespace
 
 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                   PatternBenefit benefit) {
-  patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
-                                                      benefit);
+  patterns.add<DownscaleSizeOneWindowed2DConvolution,
+               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
+                                                  benefit);
 }
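
The TODO in the hunk above asks for a dropDim one-liner for the
strides/dilations vectors. A minimal sketch of such a helper
(hypothetical, not part of this commit) that mirrors the two erase
calls:

    // Hypothetical helper for the TODO above: return `values` with the
    // element at position `dim` removed (e.g. strides or dilations).
    static SmallVector<int64_t, 4> dropDim(ArrayRef<int64_t> values,
                                           unsigned dim) {
      SmallVector<int64_t, 4> result(values.begin(), values.end());
      result.erase(result.begin() + dim);
      return result;
    }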

diff --git a/mlir/test/Dialect/Linalg/decompose-convolution.mlir b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
index 381ed1bd6080..2eb79ec8935f 100644
--- a/mlir/test/Dialect/Linalg/decompose-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
@@ -68,3 +68,27 @@ func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x3x5x3xf32>, %filter: tensor<2x
     outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
   return %0 : tensor<4x1x2x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_tensor
+func @depthwise_conv_2d_nhwc_hwc_tensor(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>, %out: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32> {
+  //     CHECK: linalg.depthwise_conv_1d_nwc_wc
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+         ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
+         outs(%out: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
+  return %0: tensor<1x1x56x96xf32>
+}
+
+// -----
+
+// Do not convert convolution ops whose window dimensions are not ones.
+
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_tensor
+func @depthwise_conv_2d_nhwc_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>, %out: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> {
+  //     CHECK: linalg.depthwise_conv_2d_nhwc_hwc
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+         ins(%input, %filter: tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
+         outs(%out: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+  return %0: tensor<1x56x56x96xf32>
+}
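
As a usage sketch (not part of this commit), the populated patterns are
applied through the greedy rewriter; the surrounding pass and funcOp are
assumptions for illustration:

    // Sketch: greedily apply the decomposition patterns inside a pass.
    RewritePatternSet patterns(funcOp.getContext());
    linalg::populateDecomposeConvolutionPatterns(patterns, /*benefit=*/1);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      return signalPassFailure();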