[Mlir-commits] [mlir] 8e484b5 - [mlir][linalg] Add decomposition from conv_2d_nchw
Hanhan Wang
llvmlistbot at llvm.org
Fri Sep 9 16:01:03 PDT 2022
Author: Stanley Winata
Date: 2022-09-09T16:00:37-07:00
New Revision: 8e484b522b7a2f72f21632bd30cc858269db6a1f
URL: https://github.com/llvm/llvm-project/commit/8e484b522b7a2f72f21632bd30cc858269db6a1f
DIFF: https://github.com/llvm/llvm-project/commit/8e484b522b7a2f72f21632bd30cc858269db6a1f.diff
LOG: [mlir][linalg] Add decomposition from conv_2d_nchw
Decompose conv_2d_nchw_fchw -> conv_1d_ncw_fcw when a convolution window
dimension (and the corresponding output dimension) has static size 1.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D133551
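
To illustrate the new rewrite, here is a sketch of what the pattern produces
for the NCHW/FCHW layout when the H window and H output dimensions are
statically 1. It mirrors the test case added below; the slice
offsets/sizes/strides are elided with [...]:

    // Before: 2-D convolution with a size-1 H window.
    %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
                                   strides = dense<1> : tensor<2xi64>}
           ins(%input, %filter : tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
           outs(%init : tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>

    // After: the unit dimension is rank-reduced away, a 1-D convolution is
    // emitted, and its result is inserted back into the original output.
    %in  = tensor.extract_slice %input [...]  : tensor<?x?x1x?xf32> to tensor<?x?x?xf32>
    %ker = tensor.extract_slice %filter [...] : tensor<?x?x1x?xf32> to tensor<?x?x?xf32>
    %out = tensor.extract_slice %init [...]   : tensor<?x?x1x?xf32> to tensor<?x?x?xf32>
    %1 = linalg.conv_1d_ncw_fcw {dilations = dense<1> : vector<1xi64>,
                                 strides = dense<1> : vector<1xi64>}
           ins(%in, %ker : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
           outs(%out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    %2 = tensor.insert_slice %1 into %init [...]
           : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>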
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-op-decompose.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index abd0243c7119a..e6a1b07b21088 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -754,20 +754,19 @@ struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
+template <typename Conv2DOp, typename Conv1DOp>
struct DownscaleSizeOneWindowed2DConvolution final
- : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+ : public OpRewritePattern<Conv2DOp> {
DownscaleSizeOneWindowed2DConvolution(
MLIRContext *context,
LinalgTransformationFilter f = LinalgTransformationFilter(),
PatternBenefit benefit = 1)
- : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
- filter(std::move(f)) {}
+ : OpRewritePattern<Conv2DOp>(context, benefit), filter(std::move(f)) {}
- FailureOr<Conv1DNwcWcfOp>
- returningMatchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
- PatternRewriter &rewriter) const;
+ FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
+ PatternRewriter &rewriter) const;
- LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+ LogicalResult matchAndRewrite(Conv2DOp convOp,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(convOp, rewriter);
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 80e4d313c1ad1..f4241f44ea0f6 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -76,10 +76,18 @@ DiagnosedSilenceableFailure
transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
- FailureOr<LinalgOp> windowed =
- tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
- if (succeeded(windowed)) {
- results.push_back(*windowed);
+ FailureOr<LinalgOp> windowedNhwc =
+ tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
+ Conv1DNwcWcfOp>>(target);
+ if (succeeded(windowedNhwc)) {
+ results.push_back(*windowedNhwc);
+ return DiagnosedSilenceableFailure(success());
+ }
+ FailureOr<LinalgOp> windowedNchw =
+ tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
+ Conv1DNcwFcwOp>>(target);
+ if (succeeded(windowedNchw)) {
+ results.push_back(*windowedNchw);
return DiagnosedSilenceableFailure(success());
}
FailureOr<LinalgOp> depthwise =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 2fcbe680259fb..1c4ceaaae034f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -828,9 +828,9 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
// and then turning back to named ops. But for now it's fine to have a few
// patterns matching special ops to get started.
-FailureOr<Conv1DNwcWcfOp>
-DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
- linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
+template <typename Conv2DOp, typename Conv1DOp>
+FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
+ returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
if (failed(filter.checkAndNotify(rewriter, convOp)))
return failure();
if (convOp.hasBufferSemantics())
@@ -847,10 +847,30 @@ DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
auto kernelShape = kernelType.getShape();
auto outputShape = outputType.getShape();
+ // Get domain indices based on conv2D layout.
+ int khIndex, kwIndex, ohIndex, owIndex;
+
+ TypeSwitch<Operation *>(convOp)
+ .Case([&](linalg::Conv2DNhwcHwcfOp op) {
+ khIndex = 0;
+ kwIndex = 1;
+ ohIndex = 1;
+ owIndex = 2;
+ })
+ .Case([&](linalg::Conv2DNchwFchwOp op) {
+ khIndex = 2;
+ kwIndex = 3;
+ ohIndex = 2;
+ owIndex = 3;
+ })
+ .Default([&](Operation *op) {
+ llvm_unreachable("unexpected conv2d operation.");
+ });
+
// Only handle the case where at least one of the window dimensions is
// of size 1. Other cases can rely on tiling to reduce to such cases.
- int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
- int64_t ohSize = outputShape[1], owSize = outputShape[2];
+ int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
+ int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
bool removeH = (khSize == 1 && ohSize == 1);
bool removeW = (kwSize == 1 && owSize == 1);
if (!removeH && !removeW)
@@ -860,11 +880,11 @@ DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
// dimension.
using RTTBuilder = RankedTensorType::Builder;
RankedTensorType newInputType =
- RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+ RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
RankedTensorType newKernelType =
- RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+ RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
RankedTensorType newOutputType =
- RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+ RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
// Rank-reduce operands.
Location loc = convOp.getLoc();
@@ -877,16 +897,17 @@ DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
// Rank-reduce strides and dilations too.
// TODO: dropDim 1-liner helper.
- auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
+ auto strides =
+ llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
strides.erase(strides.begin() + (removeH ? 0 : 1));
auto stridesAttr = rewriter.getI64VectorAttr(strides);
auto dilations =
- llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
+ llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
dilations.erase(dilations.begin() + (removeH ? 0 : 1));
auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
- auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
+ auto conv1DOp = rewriter.create<Conv1DOp>(
loc, newOutputType, ValueRange{newInput, newKernel},
ValueRange{newOutput}, stridesAttr, dilationsAttr);
@@ -973,7 +994,10 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
void linalg::populateDecomposeConvolutionPatterns(
RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
PatternBenefit benefit) {
- patterns.add<DownscaleSizeOneWindowed2DConvolution,
+ patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
+ Conv1DNwcWcfOp>,
+ DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
+ Conv1DNcwFcwOp>,
DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
benefit);
}
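
As a usage sketch (not part of this change): a pass wanting these
decompositions can populate them through the entry point updated above and
hand them to the greedy rewrite driver. The helper name below is hypothetical;
populateDecomposeConvolutionPatterns, LinalgTransformationFilter, and
applyPatternsAndFoldGreedily are the existing MLIR APIs at this revision:

    // Hypothetical pass body; registers both the NHWC/HWCF and NCHW/FCHW
    // downscaling instantiations plus the depthwise pattern, as wired up in
    // the hunk above, then applies them greedily to the function.
    static void decomposeConvolutions(func::FuncOp funcOp) {
      MLIRContext *ctx = funcOp.getContext();
      RewritePatternSet patterns(ctx);
      linalg::populateDecomposeConvolutionPatterns(
          patterns, linalg::LinalgTransformationFilter());
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
        funcOp.emitError("failed to decompose convolutions");
    }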
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 988d706c01ff9..73482897583a9 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -18,6 +18,24 @@ func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x
return %0 : tensor<?x1x?x?xf32>
}
+// CHECK-LABEL: @conv_2d_nchw_fchw
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @conv_2d_nchw_fchw(%input: tensor<?x?x1x?xf32>, %filter: tensor<?x?x1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_ncw_fcw
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<?x?x1x?xf32>)
+ outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x?x1x?xf32>
+}
+
// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
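
For reference, this test file drives the decomposition through the transform
dialect; a minimal driver sequence for the new case looks roughly like the
following (the exact sequence boilerplate at this revision may differ):

    transform.sequence failures(propagate) {
    ^bb1(%arg1: !pdl.operation):
      %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1
      %1 = transform.structured.decompose %0
    }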