[Mlir-commits] [mlir] 1a151fd - [mlir][linalg] Downscale 2D pooling with unit dimensions for height to 1D pooling
Murali Vijayaraghavan
llvmlistbot at llvm.org
Mon Dec 19 14:35:36 PST 2022
Author: Murali Vijayaraghavan
Date: 2022-12-19T22:34:43Z
New Revision: 1a151fdc011dc422d740a954c346827613961350
URL: https://github.com/llvm/llvm-project/commit/1a151fdc011dc422d740a954c346827613961350
DIFF: https://github.com/llvm/llvm-project/commit/1a151fdc011dc422d740a954c346827613961350.diff
LOG: [mlir][linalg] Downscale 2D pooling with unit dimensions for height to 1D pooling
Differential Revision: https://reviews.llvm.org/D140187
Added:
Modified:
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/transform-op-decompose.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5f1a4ecab35d8..347c53085aa58 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -67,26 +67,31 @@ DiagnosedSilenceableFailure
transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
- FailureOr<LinalgOp> windowedNhwc =
- tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
- Conv1DNwcWcfOp>>(target);
- if (succeeded(windowedNhwc)) {
- results.push_back(*windowedNhwc);
- return DiagnosedSilenceableFailure::success();
- }
- FailureOr<LinalgOp> windowedNchw =
- tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
- Conv1DNcwFcwOp>>(target);
- if (succeeded(windowedNchw)) {
- results.push_back(*windowedNchw);
- return DiagnosedSilenceableFailure::success();
- }
- FailureOr<LinalgOp> depthwise =
- tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
- if (succeeded(depthwise)) {
- results.push_back(*depthwise);
- return DiagnosedSilenceableFailure::success();
- }
+#define DOWNSCALE(trans) \
+ { \
+ FailureOr<LinalgOp> res = tryApply<trans>(target); \
+ if (succeeded(res)) { \
+ results.push_back(*res); \
+ return DiagnosedSilenceableFailure::success(); \
+ } \
+ }
+
+#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
+#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
+
+ DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
+ DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
+ DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
+ DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
+ DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
+ DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
+ DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
+ DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
+ DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
+ DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
+#undef DOWNSCALE_NORMAL
+#undef DOWNSCALE_CALL
+#undef DOWNSCALE
results.assign(1, nullptr);
return emitDefaultSilenceableFailure(target);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index b8c6115d04474..77ea7aff8eae6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -613,23 +613,39 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
auto outputShape = outputType.getShape();
// Get domain indices based on conv2D layout.
- int khIndex, kwIndex, ohIndex, owIndex;
-
- TypeSwitch<Operation *>(convOp)
+ auto [khIndex, kwIndex, ohIndex, owIndex] =
+ TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t,
+ int64_t>>(convOp)
.Case([&](linalg::Conv2DNhwcHwcfOp op) {
- khIndex = 0;
- kwIndex = 1;
- ohIndex = 1;
- owIndex = 2;
+ return std::make_tuple(0, 1, 1, 2);
})
.Case([&](linalg::Conv2DNchwFchwOp op) {
- khIndex = 2;
- kwIndex = 3;
- ohIndex = 2;
- owIndex = 3;
+ return std::make_tuple(2, 3, 2, 3);
+ })
+ .Case([&](linalg::PoolingNhwcSumOp op) {
+ return std::make_tuple(0, 1, 1, 2);
+ })
+ .Case([&](linalg::PoolingNchwSumOp op) {
+ return std::make_tuple(0, 1, 2, 3);
+ })
+ .Case([&](linalg::PoolingNhwcMaxOp op) {
+ return std::make_tuple(0, 1, 1, 2);
+ })
+ .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
+ return std::make_tuple(0, 1, 1, 2);
+ })
+ .Case([&](linalg::PoolingNhwcMinOp op) {
+ return std::make_tuple(0, 1, 1, 2);
+ })
+ .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
+ return std::make_tuple(0, 1, 1, 2);
+ })
+ .Case([&](linalg::PoolingNchwMaxOp op) {
+ return std::make_tuple(0, 1, 2, 3);
})
.Default([&](Operation *op) {
- llvm_unreachable("unexpected conv2d operation.");
+ llvm_unreachable("unexpected conv2d/pool2d operation.");
+ return std::make_tuple(0, 0, 0, 0);
});
// Only handle the case where at least one of the window dimensions is
@@ -688,6 +704,20 @@ template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
Conv1DNwcWcfOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
Conv1DNcwFcwOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
+ PoolingNwcSumOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
+ PoolingNcwSumOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
+ PoolingNwcMaxOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<
+ PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
+ PoolingNwcMinOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<
+ PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
+template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
+ PoolingNcwMaxOp>;
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
@@ -765,4 +795,15 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
Conv1DNcwFcwOp>,
DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
benefit);
+ patterns.add<
+ DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
+ DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
+ DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
+ DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
+ PoolingNwcMaxUnsignedOp>,
+ DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
+ DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
+ PoolingNwcMinUnsignedOp>,
+ DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
+ patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 81ee39d0c557d..2c873b263bdbe 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -56,6 +56,132 @@ func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: t
return %0: tensor<1x1x56x96xf32>
}
+// CHECK-LABEL: @pooling_nhwc_sum
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_sum(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_sum
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+ outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nchw_sum
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @pooling_nchw_sum(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_sum
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.pooling_nchw_sum {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
+ outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x?x1x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_max
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+ outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_max_unsigned
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_max_unsigned
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+ outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_min
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+ outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nhwc_min_unsigned
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_nwc_min_unsigned
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
+ outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x1x?x?xf32>
+}
+
+// CHECK-LABEL: @pooling_nchw_max
+// CHECK-SAME: (%[[ARG0:[0-9a-z]+]]: tensor<?x?x1x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-z]+]]: tensor<1x?xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-z]+]]: tensor<?x?x1x?xf32>)
+func.func @pooling_nchw_max(%input: tensor<?x?x1x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32> {
+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+ // CHECK: %[[SLICERES:.+]] = linalg.pooling_ncw_max
+ // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+ %0 = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x1x?xf32>, tensor<1x?xf32>)
+ outs (%init: tensor<?x?x1x?xf32>) -> tensor<?x?x1x?xf32>
+ // CHECK: return %[[RES]]
+ return %0 : tensor<?x?x1x?xf32>
+}
+
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match interface{LinalgOp} in %arg1
More information about the Mlir-commits mailing list