[Mlir-commits] [mlir] ceeb5b0 - [mlir][tosa] Add tosa.max_pool2d lowering to linalg, with int max pooling additions
Rob Suderman
llvmlistbot at llvm.org
Thu Apr 8 18:19:02 PDT 2021
Author: Rob Suderman
Date: 2021-04-08T18:17:16-07:00
New Revision: ceeb5b0f87a3f564026603c70d970eb7c1b6872e
URL: https://github.com/llvm/llvm-project/commit/ceeb5b0f87a3f564026603c70d970eb7c1b6872e
DIFF: https://github.com/llvm/llvm-project/commit/ceeb5b0f87a3f564026603c70d970eb7c1b6872e.diff
LOG: [mlir][tosa] Add tosa.max_pool2d lowering to linalg, with int max pooling additions
Lowers tosa.max_pool2d to equivalent linalg operations. Includes
adding the integer max pooling named ops to linalg, with corresponding tests.
Differential Revision: https://reviews.llvm.org/D99824
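For reference, a minimal sketch of what the f32 path produces, assembled
from the shapes and CHECK lines of the added tosa-to-linalg test below (SSA
names are illustrative, not the exact IR):

  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]}
      : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>

lowers to roughly:

  // Seed the accumulator with the smallest f32, then take a running max.
  %cst = constant -3.40282347E+38 : f32
  %init = linalg.init_tensor [1, 4, 32, 62] : tensor<1x4x32x62xf32>
  %fill = linalg.fill(%init, %cst) : tensor<1x4x32x62xf32>, f32 -> tensor<1x4x32x62xf32>
  // The kernel tensor only conveys the window shape; its values are unused.
  %kernel = linalg.init_tensor [3, 3] : tensor<3x3xf32>
  %0 = linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
      ins(%arg0, %kernel : tensor<1x6x34x62xf32>, tensor<3x3xf32>)
      outs(%fill : tensor<1x4x32x62xf32>) -> tensor<1x4x32x62xf32>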
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
mlir/test/Dialect/Linalg/named-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index ae29a1fe6fd91..e0fc25992a491 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -352,6 +352,51 @@ def pooling_nhwc_sum
ow * strides[1] + kw * dilations[1], c));
}
+ods_def<PoolingNHWCMaxI8Op>:
+def pooling_nhwc_i8_max
+ (I: i8(N, H, W, C), K: i8(KH, KW))
+ -> (O: i8(N, OH, OW, C))
+ attr(strides: 2xi64, dilations: 2xi64)
+{
+ O(n, oh, ow, c) =
+ std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
+ ow * strides[1] + kw * dilations[1], c),
+ O(n, oh, ow, c)),
+ I(n, oh * strides[0] + kh * dilations[0],
+ ow * strides[1] + kw * dilations[1], c),
+ O(n, oh, ow, c));
+}
+
+ods_def<PoolingNHWCMaxI16Op>:
+def pooling_nhwc_i16_max
+ (I: i16(N, H, W, C), K: i16(KH, KW))
+ -> (O: i16(N, OH, OW, C))
+ attr(strides: 2xi64, dilations: 2xi64)
+{
+ O(n, oh, ow, c) =
+ std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
+ ow * strides[1] + kw * dilations[1], c),
+ O(n, oh, ow, c)),
+ I(n, oh * strides[0] + kh * dilations[0],
+ ow * strides[1] + kw * dilations[1], c),
+ O(n, oh, ow, c));
+}
+
+ods_def<PoolingNHWCMaxI32Op>:
+def pooling_nhwc_i32_max
+ (I: i32(N, H, W, C), K: i32(KH, KW))
+ -> (O: i32(N, OH, OW, C))
+ attr(strides: 2xi64, dilations: 2xi64)
+{
+ O(n, oh, ow, c) =
+ std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
+ ow * strides[1] + kw * dilations[1], c),
+ O(n, oh, ow, c)),
+ I(n, oh * strides[0] + kh * dilations[0],
+ ow * strides[1] + kw * dilations[1], c),
+ O(n, oh, ow, c));
+}
+
ods_def<PoolingNHWCMaxFOp>:
def pooling_nhwc_max
(I: f32(N, H, W, C), K: f32(KH, KW))
diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
index d100cf7d9bc2e..17d461a708ca4 100644
--- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
@@ -59,6 +59,15 @@ struct CmpFValueBuilder : public ValueBuilder<CmpFOp> {
using std_cmpf_ogt = CmpFValueBuilder<CmpFPredicate::OGT>;
using std_cmpf_olt = CmpFValueBuilder<CmpFPredicate::OLT>;
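+/// Value builder for CmpIOp with the comparison predicate baked into the
+/// type, mirroring CmpFValueBuilder above.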
+template <CmpIPredicate Predicate>
+struct CmpIValueBuilder : public ValueBuilder<CmpIOp> {
+ using ValueBuilder<CmpIOp>::ValueBuilder;
+ template <typename... Args>
+ CmpIValueBuilder(Args... args) : ValueBuilder<CmpIOp>(Predicate, args...) {}
+};
+
+using std_cmpi_sgt = CmpIValueBuilder<CmpIPredicate::sgt>;
+
/// Branches into `block` with `operands`.
BranchOp std_br(Block *block, ValueRange operands);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a6271f7097563..ef317aa3e9b46 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1230,6 +1230,22 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
"Pad converter requires static shaped input / padding values.");
}
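+ // Determine the padding value up front: zero for float and non-quantized
+ // int inputs, or the input zero point for quantized int inputs.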
+ Attribute constantAttr;
+ if (elementTy.isa<FloatType>())
+ constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+ else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
+ constantAttr = rewriter.getIntegerAttr(elementTy, 0);
+ else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
+ auto value = padOp.quantization_info().getValue().input_zp().getValue();
+ constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
+ }
+
+ if (!constantAttr) {
+ return rewriter.notifyMatchFailure(
+ padOp,
+ "tosa.pad to linalg lowering encountered an unknown element type");
+ }
+
Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
Value highIndex =
rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
@@ -1256,22 +1272,6 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
highValues.push_back(highVal);
}
- Attribute constantAttr;
- if (elementTy.isa<FloatType>())
- constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
- else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
- constantAttr = rewriter.getIntegerAttr(elementTy, 0);
- else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
- auto value = padOp.quantization_info().getValue().input_zp().getValue();
- constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
- }
-
- if (!constantAttr) {
- return rewriter.notifyMatchFailure(
- padOp,
- "tosa.pad to linalg lowering encountered an unknown element type");
- }
-
Value constant = rewriter.create<ConstantOp>(loc, constantAttr);
auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
@@ -1523,6 +1523,128 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
}
};
+class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
+public:
+ using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+ PatternRewriter &rewriter) const final {
+ Location loc = op.getLoc();
+ Value input = op.input();
+ ShapedType inputTy = input.getType().cast<ShapedType>();
+ Type inElementTy = inputTy.getElementType();
+
+ ShapedType resultTy = op.getType().cast<ShapedType>();
+ Type outElementTy = inputTy.getElementType();
+ int64_t rank = inputTy.getRank();
+
+ if (!inputTy.hasStaticShape())
+ return failure();
+
+ // Determine what the initial value needs to be for the max pool op.
+ Attribute initialAttr;
+ if (outElementTy.isF32())
+ initialAttr = rewriter.getFloatAttr(
+ outElementTy,
+ APFloat::getLargest(
+ outElementTy.cast<FloatType>().getFloatSemantics(), true));
+
+ if (outElementTy.isa<IntegerType>())
+ initialAttr = rewriter.getIntegerAttr(
+ outElementTy,
+ APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));
+
+ if (!initialAttr)
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported initial value for tosa.maxpool_2d op");
+
+ Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
+
+ SmallVector<int64_t> kernel, stride, pad;
+ getValuesFromIntArrayAttribute(op.kernel(), kernel);
+ getValuesFromIntArrayAttribute(op.stride(), stride);
+ getValuesFromIntArrayAttribute(op.pad(), pad);
+
+ Attribute strideAttr = rewriter.getI64VectorAttr(stride);
+ Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
+
+ // If there is non-zero padding, we need to pad the input.
+ if (llvm::any_of(pad, [](int64_t v) { return v != 0; })) {
+ SmallVector<int64_t, 4> paddedShape;
+ for (int64_t i = 0; i < rank; i++)
+ paddedShape.push_back(inputTy.getDimSize(i));
+
+ paddedShape[1] += pad[0] + pad[1];
+ paddedShape[2] += pad[2] + pad[3];
+
+ OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+ OpFoldResult heightLowPadIndex = rewriter.getIndexAttr(pad[0]);
+ OpFoldResult heightHighPadIndex = rewriter.getIndexAttr(pad[1]);
+ OpFoldResult widthLowPadIndex = rewriter.getIndexAttr(pad[2]);
+ OpFoldResult widthHighPadIndex = rewriter.getIndexAttr(pad[3]);
+
+ SmallVector<OpFoldResult, 4> lowIndices = {zeroIndex, heightLowPadIndex,
+ widthLowPadIndex, zeroIndex};
+ SmallVector<OpFoldResult, 4> highIndices = {zeroIndex, heightHighPadIndex,
+ widthHighPadIndex, zeroIndex};
+
+ input = linalg::PadTensorOp::createPadScalarOp(
+ RankedTensorType::get(paddedShape, inElementTy), input,
+ initialValue, lowIndices, highIndices, loc, rewriter)
+ .result();
+ }
+
+ Value initTensor = rewriter.create<linalg::InitTensorOp>(
+ loc, resultTy.getShape(), resultTy.getElementType());
+
+ Value filledInitTensor =
+ rewriter.create<linalg::FillOp>(loc, initTensor, initialValue).result();
+
+ Value fakeWindowDims =
+ rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);
+
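+ // Instantiate the pooling op selected by `typePtr`; the null pointer only
+ // carries the op type for template deduction and is never dereferenced.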
+ auto createOp = [&](auto *typePtr) -> linalg::LinalgOp {
+ return cast<linalg::LinalgOp>(
+ rewriter
+ .create<std::remove_pointer_t<decltype(typePtr)>>(
+ loc, ArrayRef<Type>{resultTy},
+ ValueRange{input, fakeWindowDims}, filledInitTensor,
+ dilationAttr, strideAttr)
+ .getOperation());
+ };
+
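+ // Dispatch to the named pooling op that matches the input element type.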
+ if (inElementTy.isF32()) {
+ linalg::LinalgOp poolingOp =
+ createOp(static_cast<linalg::PoolingNHWCMaxFOp *>(nullptr));
+ rewriter.replaceOp(op, poolingOp->getResult(0));
+ return success();
+ }
+
+ if (inElementTy.isInteger(8)) {
+ linalg::LinalgOp poolingOp =
+ createOp(static_cast<linalg::PoolingNHWCMaxI8Op *>(nullptr));
+ rewriter.replaceOp(op, poolingOp->getResult(0));
+ return success();
+ }
+
+ if (inElementTy.isInteger(16)) {
+ linalg::LinalgOp poolingOp =
+ createOp(static_cast<linalg::PoolingNHWCMaxI16Op *>(nullptr));
+ rewriter.replaceOp(op, poolingOp->getResult(0));
+ return success();
+ }
+
+ if (inElementTy.isInteger(32)) {
+ linalg::LinalgOp poolingOp =
+ createOp(static_cast<linalg::PoolingNHWCMaxI32Op *>(nullptr));
+ rewriter.replaceOp(op, poolingOp->getResult(0));
+ return success();
+ }
+
+ return failure();
+ }
+};
+
} // namespace
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -1579,6 +1701,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
TileConverter,
TransposeConverter,
MatMulConverter,
+ MaxPool2dConverter,
FullyConnectedConverter>(patterns->getContext());
// clang-format on
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 5d77c932bf121..b33c18f46ddf3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -873,3 +873,53 @@ func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () {
%0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi16>, tensor<513xi16>) -> (tensor<6xi32>)
return
}
+
+// -----
+
+// CHECK-LABEL: @max_pool
+func @max_pool(%arg0: tensor<1x6x34x62xf32>) -> () {
+ // CHECK-DAG: [[CONST:%.+]] = constant -3.40282347E+38
+ // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 32, 62]
+ // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
+ // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3]
+ // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x32x62xf32>)
+ %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>)
+ return
+}
+
+// CHECK-LABEL: @max_pool_padded
+func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () {
+ // CHECK-DAG: [[CONST:%.+]] = constant -3.40282347E+38 : f32
+ // CHECK-DAG: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 0, 0, 0] high[0, 0, 1, 0]
+ // CHECK-DAG: linalg.yield [[CONST]]
+ // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 33, 62]
+ // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
+ // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3]
+ // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x6x35x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x33x62xf32>)
+ %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 1], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x33x62xf32>)
+ return
+}
+
+// CHECK-LABEL: @max_pool_i8
+func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
+ // CHECK: constant -128
+ // CHECK: linalg.pooling_nhwc_i8_max
+ %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi8>) -> (tensor<1x4x32x62xi8>)
+ return
+}
+
+// CHECK-LABEL: @max_pool_i16
+func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
+ // CHECK: constant -32768
+ // CHECK: linalg.pooling_nhwc_i16_max
+ %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi16>) -> (tensor<1x4x32x62xi16>)
+ return
+}
+
+// CHECK-LABEL: @max_pool_i32
+func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
+ // CHECK: constant -2147483648
+ // CHECK: linalg.pooling_nhwc_i32_max
+ %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>)
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 7bd6dcb1404b9..e7b8e3aad9d95 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -340,6 +340,84 @@ func @pooling_nhwc_max(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %ini
// -----
+func @pooling_nhwc_i8_max(%input: memref<?x?x?x?xi8>, %fake: memref<2x3xi8>, %init: memref<?x?x?x?xi8>) {
+ linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
+ ins(%input, %fake: memref<?x?x?x?xi8>, memref<2x3xi8>)
+ outs(%init: memref<?x?x?x?xi8>)
+ return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK: func @pooling_nhwc_i8_max
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi8>, memref<2x3xi8>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi8>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8)
+// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i8
+// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i8
+// CHECK-NEXT: linalg.yield %[[RES]] : i8
+
+// -----
+
+func @pooling_nhwc_i16_max(%input: memref<?x?x?x?xi16>, %fake: memref<2x3xi16>, %init: memref<?x?x?x?xi16>) {
+ linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
+ ins(%input, %fake: memref<?x?x?x?xi16>, memref<2x3xi16>)
+ outs(%init: memref<?x?x?x?xi16>)
+ return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK: func @pooling_nhwc_i16_max
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi16>, memref<2x3xi16>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi16>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i16, %[[BBARG1:.+]]: i16, %[[BBARG2:.+]]: i16)
+// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i16
+// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i16
+// CHECK-NEXT: linalg.yield %[[RES]] : i16
+
+// -----
+
+func @pooling_nhwc_i32_max(%input: memref<?x?x?x?xi32>, %fake: memref<2x3xi32>, %init: memref<?x?x?x?xi32>) {
+ linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
+ ins(%input, %fake: memref<?x?x?x?xi32>, memref<2x3xi32>)
+ outs(%init: memref<?x?x?x?xi32>)
+ return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK: func @pooling_nhwc_i32_max
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi32>, memref<2x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: i32, %[[BBARG2:.+]]: i32)
+// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i32
+// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i32
+// CHECK-NEXT: linalg.yield %[[RES]] : i32
+
+// -----
+
func @pooling_nhwc_min(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init: memref<?x?x?x?xf32>) {
linalg.pooling_nhwc_min {dilations = dense<3> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
ins(%input, %fake: memref<?x?x?x?xf32>, memref<2x3xf32>)
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 4e49afb891a7e..c5a623aa15f42 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -344,6 +344,109 @@ func @pooling_nhwc_max(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %out
return
}
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_i8_max_tensor
+// CHECK: %{{.+}} = linalg.pooling_nhwc_i8_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi8>, tensor<3x3xi8>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8>
+func @pooling_nhwc_i8_max_tensor(%input: tensor<1x4x4x1xi8>) -> tensor<1x2x2x1xi8> {
+ %fake = linalg.init_tensor [3, 3] : tensor<3x3xi8>
+ %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi8>
+ %cst = constant 0 : i8
+ %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xi8>, i8 -> tensor<1x2x2x1xi8>
+ %res = linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi8>, tensor<3x3xi8>)
+ outs(%fill: tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8>
+ return %res : tensor<1x2x2x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_i8_max
+// CHECK: linalg.pooling_nhwc_i8_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi8>, memref<3x3xi8>)
+// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xi8>)
+func @pooling_nhwc_i8_max(%input: memref<1x4x4x1xi8>, %fake: memref<3x3xi8>, %output: memref<1x2x2x1xi8>) {
+ linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: memref<1x4x4x1xi8>, memref<3x3xi8>)
+ outs(%output: memref<1x2x2x1xi8>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_i16_max_tensor
+// CHECK: %{{.+}} = linalg.pooling_nhwc_i16_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi16>, tensor<3x3xi16>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16>
+func @pooling_nhwc_i16_max_tensor(%input: tensor<1x4x4x1xi16>) -> tensor<1x2x2x1xi16> {
+ %fake = linalg.init_tensor [3, 3] : tensor<3x3xi16>
+ %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi16>
+ %cst = constant 0 : i16
+ %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xi16>, i16 -> tensor<1x2x2x1xi16>
+ %res = linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi16>, tensor<3x3xi16>)
+ outs(%fill: tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16>
+ return %res : tensor<1x2x2x1xi16>
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_i16_max
+// CHECK: linalg.pooling_nhwc_i16_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi16>, memref<3x3xi16>)
+// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xi16>)
+func @pooling_nhwc_i16_max(%input: memref<1x4x4x1xi16>, %fake: memref<3x3xi16>, %output: memref<1x2x2x1xi16>) {
+ linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: memref<1x4x4x1xi16>, memref<3x3xi16>)
+ outs(%output: memref<1x2x2x1xi16>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_i32_max_tensor
+// CHECK: %{{.+}} = linalg.pooling_nhwc_i32_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func @pooling_nhwc_i32_max_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+ %fake = linalg.init_tensor [3, 3] : tensor<3x3xi32>
+ %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi32>
+ %cst = constant 0 : i32
+ %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xi32>, i32 -> tensor<1x2x2x1xi32>
+ %res = linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+ outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ return %res : tensor<1x2x2x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_i32_max
+// CHECK: linalg.pooling_nhwc_i32_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi32>, memref<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xi32>)
+func @pooling_nhwc_i32_max(%input: memref<1x4x4x1xi32>, %fake: memref<3x3xi32>, %output: memref<1x2x2x1xi32>) {
+ linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: memref<1x4x4x1xi32>, memref<3x3xi32>)
+ outs(%output: memref<1x2x2x1xi32>)
+ return
+}
+
+
// -----
// CHECK-LABEL: func @pooling_nhwc_min_tensor