[Mlir-commits] [mlir] b828506 - [mlir][Linalg] Add a DownscaleDepthwiseConv2DNhwcHwcOp decomposition pattern.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Nov 15 12:48:20 PST 2021
Author: Nicolas Vasilache
Date: 2021-11-15T20:48:16Z
New Revision: b828506ecacee709893adc9afbc36ae24dbcc068
URL: https://github.com/llvm/llvm-project/commit/b828506ecacee709893adc9afbc36ae24dbcc068
DIFF: https://github.com/llvm/llvm-project/commit/b828506ecacee709893adc9afbc36ae24dbcc068.diff
LOG: [mlir][Linalg] Add a DownscaleDepthwiseConv2DNhwcHwcOp decomposition pattern.
Reviewed By: gysit
Differential Revision: https://reviews.llvm.org/D113907
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/decompose-convolution.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index f15337f0f70e..a4b4d78e1073 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -904,10 +904,83 @@ struct DownscaleSizeOneWindowed2DConvolution final
};
};
+/// Rewrites a tensor-based DepthwiseConv2DNhwcHwcOp whose (kh, oh) or (kw, ow)
+/// sizes are both 1 into a DepthwiseConv1DNwcWcOp on rank-reduced operands.
+struct DownscaleDepthwiseConv2DNhwcHwcOp final
+ : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+ using OpRewritePattern::OpRewritePattern;
+ /// Fails for buffer semantics or if neither H nor W has size-1 window/output.
+ LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+ PatternRewriter &rewriter) const override {
+ auto linalgOp = cast<linalg::LinalgOp>(*convOp);
+ if (linalgOp.hasBufferSemantics())
+ return failure(); // TODO: support ops with buffer semantics.
+
+ Value input = convOp.inputs().front();
+ Value kernel = convOp.inputs().back();
+ Value output = convOp.outputs().front();
+ // NOTE(review): the dyn_cast results below are used unchecked - confirm.
+ auto inputType = input.getType().dyn_cast<RankedTensorType>();
+ auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+ auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+ auto kernelShape = kernelType.getShape();
+ auto outputShape = outputType.getShape();
+
+ // Only handle the case where at least one of the window dimensions is
+ // of size 1. Other cases can rely on tiling to reduce to such cases.
+ int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+ int64_t ohSize = outputShape[1], owSize = outputShape[2];
+ bool removeH = (khSize == 1 && ohSize == 1);
+ bool removeW = (kwSize == 1 && owSize == 1);
+ if (!removeH && !removeW)
+ return failure();
+
+ // Get new types for all operands by removing the size-1 dimension; when
+ // both H and W qualify, H is the one dropped (removeH takes precedence).
+ using RTTBuilder = RankedTensorType::Builder;
+ auto newInputType = RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+ auto newKernelType = RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+ auto newOutputType = RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+
+ // Rank-reduce operands.
+ Location loc = convOp.getLoc();
+ Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, input, newInputType);
+ Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, kernel, newKernelType);
+ Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+ rewriter, loc, output, newOutputType);
+
+ // Rank-reduce strides and dilations too: drop the removed dim's entry.
+ // TODO: dropDim 1-liner helper.
+ auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+ strides.erase(strides.begin() + (removeH ? 0 : 1));
+ auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+ auto dilations =
+ llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+ dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+ auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+ // Build the 1-D op on the rank-reduced operands and attributes.
+ auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
+ loc, newOutputType, ValueRange{newInput, newKernel},
+ ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+ // Insert the 1-D result back into the original output tensor.
+ Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+ rewriter, loc, conv1DOp.getResult(0), output);
+ rewriter.replaceOp(convOp, inserted);
+
+ return success();
+ };
+};
+
} // namespace
void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
- patterns.add<DownscaleSizeOneWindowed2DConvolution>(patterns.getContext(),
- benefit);
+ patterns.add<DownscaleSizeOneWindowed2DConvolution,
+ DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
+ benefit);
}
diff --git a/mlir/test/Dialect/Linalg/decompose-convolution.mlir b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
index 381ed1bd6080..2eb79ec8935f 100644
--- a/mlir/test/Dialect/Linalg/decompose-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-convolution.mlir
@@ -68,3 +68,27 @@ func @conv2d_nhwc_4x1x2x8_tensor(%input: tensor<4x3x5x3xf32>, %filter: tensor<2x
outs(%init : tensor<4x1x2x8xf32>) -> tensor<4x1x2x8xf32>
return %0 : tensor<4x1x2x8xf32>
}
+
+// -----
+// The (kh, oh) sizes are both 1 here, so the op is expected to be decomposed.
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_tensor
+func @depthwise_conv_2d_nhwc_hwc_tensor(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>, %out: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32> {
+ // CHECK: linalg.depthwise_conv_1d_nwc_wc
+ %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+ ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
+ outs(%out: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
+ return %0: tensor<1x1x56x96xf32>
+}
+
+// -----
+
+// Do not convert convolution ops when neither window dimension is of size 1.
+
+// CHECK-LABEL: func @depthwise_conv_2d_nhwc_hwc_tensor
+func @depthwise_conv_2d_nhwc_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>, %out: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32> {
+ // CHECK: linalg.depthwise_conv_2d_nhwc_hwc
+ %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+ ins(%input, %filter: tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
+ outs(%out: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
+ return %0: tensor<1x56x56x96xf32>
+}
More information about the Mlir-commits
mailing list