[Mlir-commits] [mlir] ceeb5b0 - [mlir][tosa] Add tosa.max_pool2d lowering to linalg int max pooling additions

Rob Suderman llvmlistbot at llvm.org
Thu Apr 8 18:19:02 PDT 2021


Author: Rob Suderman
Date: 2021-04-08T18:17:16-07:00
New Revision: ceeb5b0f87a3f564026603c70d970eb7c1b6872e

URL: https://github.com/llvm/llvm-project/commit/ceeb5b0f87a3f564026603c70d970eb7c1b6872e
DIFF: https://github.com/llvm/llvm-project/commit/ceeb5b0f87a3f564026603c70d970eb7c1b6872e.diff

LOG: [mlir][tosa] Add tosa.max_pool2d lowering to linalg int max pooling additions

Lowers tosa.max_pool2d to equivalent linalg operations. Also adds integer
(i8, i16, i32) max pooling named operations to linalg, with corresponding tests.

Differential Revision: https://reviews.llvm.org/D99824

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
    mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir
    mlir/test/Dialect/Linalg/named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index ae29a1fe6fd91..e0fc25992a491 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -352,6 +352,51 @@ def pooling_nhwc_sum
                                        ow * strides[1] + kw * dilations[1], c));
 }
 
+ods_def<PoolingNHWCMaxI8Op>:
+def pooling_nhwc_i8_max // Signed (cmpi sgt) max pooling on i8 NHWC data.
+    (I: i8(N, H, W, C), K: i8(KH, KW)) // K only supplies the KH x KW window.
+  -> (O: i8(N, OH, OW, C))
+  attr(strides: 2xi64, dilations: 2xi64)
+{
+  O(n, oh, ow, c) = // Keep I if I > O (signed), else O; kh, kw are reduced.
+      std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
+                                        ow * strides[1] + kw * dilations[1], c),
+                                      O(n, oh, ow, c)),
+                         I(n, oh * strides[0] + kh * dilations[0],
+                           ow * strides[1] + kw * dilations[1], c),
+                         O(n, oh, ow, c));
+}
+
+ods_def<PoolingNHWCMaxI16Op>:
+def pooling_nhwc_i16_max // Signed (cmpi sgt) max pooling on i16 NHWC data.
+    (I: i16(N, H, W, C), K: i16(KH, KW)) // K only supplies the KH x KW window.
+  -> (O: i16(N, OH, OW, C))
+  attr(strides: 2xi64, dilations: 2xi64)
+{
+  O(n, oh, ow, c) = // Keep I if I > O (signed), else O; kh, kw are reduced.
+      std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
+                                        ow * strides[1] + kw * dilations[1], c),
+                                      O(n, oh, ow, c)),
+                         I(n, oh * strides[0] + kh * dilations[0],
+                           ow * strides[1] + kw * dilations[1], c),
+                         O(n, oh, ow, c));
+}
+
+ods_def<PoolingNHWCMaxI32Op>:
+def pooling_nhwc_i32_max // Signed (cmpi sgt) max pooling on i32 NHWC data.
+    (I: i32(N, H, W, C), K: i32(KH, KW)) // K only supplies the KH x KW window.
+  -> (O: i32(N, OH, OW, C))
+  attr(strides: 2xi64, dilations: 2xi64)
+{
+  O(n, oh, ow, c) = // Keep I if I > O (signed), else O; kh, kw are reduced.
+      std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
+                                        ow * strides[1] + kw * dilations[1], c),
+                                      O(n, oh, ow, c)),
+                         I(n, oh * strides[0] + kh * dilations[0],
+                           ow * strides[1] + kw * dilations[1], c),
+                         O(n, oh, ow, c));
+}
+
 ods_def<PoolingNHWCMaxFOp>:
 def pooling_nhwc_max
     (I: f32(N, H, W, C), K: f32(KH, KW))

diff  --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
index d100cf7d9bc2e..17d461a708ca4 100644
--- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
@@ -59,6 +59,15 @@ struct CmpFValueBuilder : public ValueBuilder<CmpFOp> {
 using std_cmpf_ogt = CmpFValueBuilder<CmpFPredicate::OGT>;
 using std_cmpf_olt = CmpFValueBuilder<CmpFPredicate::OLT>;
 
+template <CmpIPredicate Predicate> // Builds a CmpIOp with Predicate fixed.
+struct CmpIValueBuilder : public ValueBuilder<CmpIOp> {
+  using ValueBuilder<CmpIOp>::ValueBuilder;
+  template <typename... Args>
+  CmpIValueBuilder(Args... args) : ValueBuilder<CmpIOp>(Predicate, args...) {}
+};
+
+using std_cmpi_sgt = CmpIValueBuilder<CmpIPredicate::sgt>; // Signed lhs > rhs.
+
 /// Branches into `block` with `operands`.
 BranchOp std_br(Block *block, ValueRange operands);
 

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a6271f7097563..ef317aa3e9b46 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1230,6 +1230,22 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
           "Pad converter requires static shaped input / padding values.");
     }
 
+    Attribute constantAttr;
+    if (elementTy.isa<FloatType>())
+      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
+      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
+    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
+      auto value = padOp.quantization_info().getValue().input_zp().getValue();
+      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
+    }
+
+    if (!constantAttr) {
+      return rewriter.notifyMatchFailure(
+          padOp,
+          "tosa.pad to linalg lowering encountered an unknown element type");
+    }
+
     Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
     Value highIndex =
         rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
@@ -1256,22 +1272,6 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
       highValues.push_back(highVal);
     }
 
-    Attribute constantAttr;
-    if (elementTy.isa<FloatType>())
-      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
-      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
-      auto value = padOp.quantization_info().getValue().input_zp().getValue();
-      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
-    }
-
-    if (!constantAttr) {
-      return rewriter.notifyMatchFailure(
-          padOp,
-          "tosa.pad to linalg lowering encountered an unknown element type");
-    }
-
     Value constant = rewriter.create<ConstantOp>(loc, constantAttr);
 
     auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
@@ -1523,6 +1523,128 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
   }
 };
 
+class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
+public:
+  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
+
+  /// Lowers a statically shaped tosa.max_pool2d to an optional
+  /// linalg.pad_tensor, a linalg.fill with the initial value and one of the
+  /// linalg.pooling_nhwc_{max,i8_max,i16_max,i32_max} named ops.
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Value input = op.input();
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    Type inElementTy = inputTy.getElementType();
+
+    ShapedType resultTy = op.getType().cast<ShapedType>();
+    Type outElementTy = resultTy.getElementType();
+    int64_t rank = inputTy.getRank();
+
+    // The padded input and init tensor shapes are built from static sizes.
+    if (!inputTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "tosa.max_pool2d lowering requires static shapes");
+
+    // Check all preconditions before creating any op: a pattern must not
+    // return failure() once it has modified the IR. The same value fills the
+    // output and (if needed) pads the input, so the element types must match,
+    // and only types with a named pooling op (f32, i8, i16, i32) are handled.
+    if (inElementTy != outElementTy ||
+        !(inElementTy.isF32() || inElementTy.isInteger(8) ||
+          inElementTy.isInteger(16) || inElementTy.isInteger(32)))
+      return rewriter.notifyMatchFailure(
+          op, "unsupported element type for tosa.max_pool2d lowering");
+
+    // The initial value is the lowest finite f32 (-FLT_MAX) or the signed
+    // minimum of the integer element type.
+    Attribute initialAttr;
+    if (outElementTy.isF32())
+      initialAttr = rewriter.getFloatAttr(
+          outElementTy,
+          APFloat::getLargest(
+              outElementTy.cast<FloatType>().getFloatSemantics(), true));
+    else
+      initialAttr = rewriter.getIntegerAttr(
+          outElementTy,
+          APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));
+
+    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
+
+    // kernel and stride are [H, W] pairs; pad holds four entries used below.
+    SmallVector<int64_t> kernel, stride, pad;
+    getValuesFromIntArrayAttribute(op.kernel(), kernel);
+    getValuesFromIntArrayAttribute(op.stride(), stride);
+    getValuesFromIntArrayAttribute(op.pad(), pad);
+
+    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
+    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
+
+    // If any padding is non-zero, pad H by (pad[0], pad[1]) and W by
+    // (pad[2], pad[3]) using the same initial value that fills the output.
+    if (llvm::any_of(pad, [](int64_t v) { return v != 0; })) {
+      SmallVector<int64_t, 4> paddedShape;
+      for (int64_t i = 0; i < rank; i++)
+        paddedShape.push_back(inputTy.getDimSize(i));
+
+      paddedShape[1] += pad[0] + pad[1];
+      paddedShape[2] += pad[2] + pad[3];
+
+      OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+      OpFoldResult heightLowPadIndex = rewriter.getIndexAttr(pad[0]);
+      OpFoldResult heightHighPadIndex = rewriter.getIndexAttr(pad[1]);
+      OpFoldResult widthLowPadIndex = rewriter.getIndexAttr(pad[2]);
+      OpFoldResult widthHighPadIndex = rewriter.getIndexAttr(pad[3]);
+
+      SmallVector<OpFoldResult, 4> lowIndices = {zeroIndex, heightLowPadIndex,
+                                                 widthLowPadIndex, zeroIndex};
+      SmallVector<OpFoldResult, 4> highIndices = {zeroIndex, heightHighPadIndex,
+                                                  widthHighPadIndex, zeroIndex};
+
+      input = linalg::PadTensorOp::createPadScalarOp(
+                  RankedTensorType::get(paddedShape, inElementTy), input,
+                  initialValue, lowIndices, highIndices, loc, rewriter)
+                  .result();
+    }
+
+    // Start the reduction from a result tensor filled with the initial value.
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultTy.getShape(), resultTy.getElementType());
+
+    Value filledInitTensor =
+        rewriter.create<linalg::FillOp>(loc, initTensor, initialValue).result();
+
+    // Only the shape of this tensor is used: it gives the pooling op its
+    // KH x KW window size.
+    Value fakeWindowDims =
+        rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);
+
+    // Builds the named pooling op whose C++ class is the pointee of typePtr.
+    auto createOp = [&](auto *typePtr) -> Operation * {
+      return rewriter
+          .create<std::remove_pointer_t<decltype(typePtr)>>(
+              loc, ArrayRef<Type>{resultTy},
+              ValueRange{input, fakeWindowDims}, filledInitTensor,
+              dilationAttr, strideAttr)
+          .getOperation();
+    };
+
+    // The element type was validated above, so exactly one branch applies.
+    Operation *poolingOp = nullptr;
+    if (inElementTy.isF32())
+      poolingOp = createOp(static_cast<linalg::PoolingNHWCMaxFOp *>(nullptr));
+    else if (inElementTy.isInteger(8))
+      poolingOp = createOp(static_cast<linalg::PoolingNHWCMaxI8Op *>(nullptr));
+    else if (inElementTy.isInteger(16))
+      poolingOp = createOp(static_cast<linalg::PoolingNHWCMaxI16Op *>(nullptr));
+    else
+      poolingOp = createOp(static_cast<linalg::PoolingNHWCMaxI32Op *>(nullptr));
+
+    rewriter.replaceOp(op, poolingOp->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -1579,6 +1701,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
       TileConverter,
       TransposeConverter,
       MatMulConverter,
+      MaxPool2dConverter,
       FullyConnectedConverter>(patterns->getContext());
       // clang-format on
 }

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 5d77c932bf121..b33c18f46ddf3 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -873,3 +873,53 @@ func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () {
   %0 = "tosa.table"(%arg0, %arg1)  : (tensor<6xi16>, tensor<513xi16>)  -> (tensor<6xi32>)
   return
 }
+
+// -----
+// Float max pooling starts from -FLT_MAX (-3.40282347E+38).
+// CHECK-LABEL: @max_pool
+func @max_pool(%arg0: tensor<1x6x34x62xf32>) -> () {
+  // CHECK-DAG: [[CONST:%.+]] = constant -3.40282347E+38
+  // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 32, 62]
+  // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
+  // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3]
+  // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x32x62xf32>)
+  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>)  -> (tensor<1x4x32x62xf32>)
+  return
+}
+// pad[3] = 1 pads the high end of W by one (34 -> 35) with the same constant.
+// CHECK-LABEL: @max_pool_padded
+func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () {
+  // CHECK-DAG: [[CONST:%.+]] = constant -3.40282347E+38 : f32
+  // CHECK-DAG: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 0, 0, 0] high[0, 0, 1, 0]
+  // CHECK-DAG:   linalg.yield [[CONST]]
+  // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 33, 62]
+  // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
+  // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3]
+  // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x6x35x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x33x62xf32>)
+  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 1], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>)  -> (tensor<1x4x33x62xf32>)
+  return
+}
+// Integer variants start from the signed minimum of the element type.
+// CHECK-LABEL: @max_pool_i8
+func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
+  // CHECK: constant -128
+  // CHECK: linalg.pooling_nhwc_i8_max
+  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi8>)  -> (tensor<1x4x32x62xi8>)
+  return
+}
+
+// CHECK-LABEL: @max_pool_i16
+func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
+  // CHECK: constant -32768
+  // CHECK: linalg.pooling_nhwc_i16_max
+  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi16>)  -> (tensor<1x4x32x62xi16>)
+  return
+}
+
+// CHECK-LABEL: @max_pool_i32
+func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
+  // CHECK: constant -2147483648
+  // CHECK: linalg.pooling_nhwc_i32_max
+  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>)  -> (tensor<1x4x32x62xi32>)
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index 7bd6dcb1404b9..e7b8e3aad9d95 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -340,6 +340,84 @@ func @pooling_nhwc_max(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %ini
 
 // -----
 
+func @pooling_nhwc_i8_max(%input: memref<?x?x?x?xi8>, %fake: memref<2x3xi8>, %init: memref<?x?x?x?xi8>) {
+  linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
+    ins(%input, %fake: memref<?x?x?x?xi8>, memref<2x3xi8>)
+    outs(%init: memref<?x?x?x?xi8>)
+  return
+}
+// Strides [2, 3] scale d1/d2 in the input map; the window dims d4/d5 are reductions.
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK: func @pooling_nhwc_i8_max
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi8>, memref<2x3xi8>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi8>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8)
+// CHECK-NEXT:      %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i8
+// CHECK-NEXT:      %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i8
+// CHECK-NEXT:      linalg.yield %[[RES]] : i8
+
+// -----
+
+func @pooling_nhwc_i16_max(%input: memref<?x?x?x?xi16>, %fake: memref<2x3xi16>, %init: memref<?x?x?x?xi16>) {
+  linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
+    ins(%input, %fake: memref<?x?x?x?xi16>, memref<2x3xi16>)
+    outs(%init: memref<?x?x?x?xi16>)
+  return
+}
+// Same generalization as the i8 case, with i16 block arguments.
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK: func @pooling_nhwc_i16_max
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi16>, memref<2x3xi16>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi16>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i16, %[[BBARG1:.+]]: i16, %[[BBARG2:.+]]: i16)
+// CHECK-NEXT:      %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i16
+// CHECK-NEXT:      %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i16
+// CHECK-NEXT:      linalg.yield %[[RES]] : i16
+
+// -----
+
+func @pooling_nhwc_i32_max(%input: memref<?x?x?x?xi32>, %fake: memref<2x3xi32>, %init: memref<?x?x?x?xi32>) {
+  linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
+    ins(%input, %fake: memref<?x?x?x?xi32>, memref<2x3xi32>)
+    outs(%init: memref<?x?x?x?xi32>)
+  return
+}
+// Same generalization as the i8 case, with i32 block arguments.
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK: func @pooling_nhwc_i32_max
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi32>, memref<2x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: i32, %[[BBARG2:.+]]: i32)
+// CHECK-NEXT:      %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i32
+// CHECK-NEXT:      %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i32
+// CHECK-NEXT:      linalg.yield %[[RES]] : i32
+
+// -----
+
 func @pooling_nhwc_min(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init: memref<?x?x?x?xf32>) {
   linalg.pooling_nhwc_min {dilations = dense<3> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
     ins(%input, %fake: memref<?x?x?x?xf32>, memref<2x3xf32>)

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 4e49afb891a7e..c5a623aa15f42 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -344,6 +344,109 @@ func @pooling_nhwc_max(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %out
   return
 }
 
+// -----
+// Parse/print tests for the integer max pooling named ops on tensors and memrefs.
+// CHECK-LABEL: func @pooling_nhwc_i8_max_tensor
+// CHECK:         %{{.+}} = linalg.pooling_nhwc_i8_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi8>, tensor<3x3xi8>)
+// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8>
+func @pooling_nhwc_i8_max_tensor(%input: tensor<1x4x4x1xi8>) -> tensor<1x2x2x1xi8> {
+  %fake = linalg.init_tensor [3, 3] : tensor<3x3xi8>
+  %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi8>
+  %cst = constant 0 : i8
+  %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xi8>, i8 -> tensor<1x2x2x1xi8>
+  %res = linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: tensor<1x4x4x1xi8>, tensor<3x3xi8>)
+    outs(%fill: tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8>
+  return %res : tensor<1x2x2x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_i8_max
+// CHECK:         linalg.pooling_nhwc_i8_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi8>, memref<3x3xi8>)
+// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xi8>)
+func @pooling_nhwc_i8_max(%input: memref<1x4x4x1xi8>, %fake: memref<3x3xi8>, %output: memref<1x2x2x1xi8>) {
+  linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: memref<1x4x4x1xi8>, memref<3x3xi8>)
+    outs(%output: memref<1x2x2x1xi8>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_i16_max_tensor
+// CHECK:         %{{.+}} = linalg.pooling_nhwc_i16_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi16>, tensor<3x3xi16>)
+// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16>
+func @pooling_nhwc_i16_max_tensor(%input: tensor<1x4x4x1xi16>) -> tensor<1x2x2x1xi16> {
+  %fake = linalg.init_tensor [3, 3] : tensor<3x3xi16>
+  %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi16>
+  %cst = constant 0 : i16
+  %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xi16>, i16 -> tensor<1x2x2x1xi16>
+  %res = linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: tensor<1x4x4x1xi16>, tensor<3x3xi16>)
+    outs(%fill: tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16>
+  return %res : tensor<1x2x2x1xi16>
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_i16_max
+// CHECK:         linalg.pooling_nhwc_i16_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi16>, memref<3x3xi16>)
+// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xi16>)
+func @pooling_nhwc_i16_max(%input: memref<1x4x4x1xi16>, %fake: memref<3x3xi16>, %output: memref<1x2x2x1xi16>) {
+  linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: memref<1x4x4x1xi16>, memref<3x3xi16>)
+    outs(%output: memref<1x2x2x1xi16>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_i32_max_tensor
+// CHECK:         %{{.+}} = linalg.pooling_nhwc_i32_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func @pooling_nhwc_i32_max_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+  %fake = linalg.init_tensor [3, 3] : tensor<3x3xi32>
+  %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi32>
+  %cst = constant 0 : i32
+  %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xi32>, i32 -> tensor<1x2x2x1xi32>
+  %res = linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+    outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+  return %res : tensor<1x2x2x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_i32_max
+// CHECK:         linalg.pooling_nhwc_i32_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi32>, memref<3x3xi32>)
+// CHECK-SAME:      outs(%{{.+}} : memref<1x2x2x1xi32>)
+func @pooling_nhwc_i32_max(%input: memref<1x4x4x1xi32>, %fake: memref<3x3xi32>, %output: memref<1x2x2x1xi32>) {
+  linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: memref<1x4x4x1xi32>, memref<3x3xi32>)
+    outs(%output: memref<1x2x2x1xi32>)
+  return
+}
+
+
 // -----
 
 // CHECK-LABEL: func @pooling_nhwc_min_tensor


        


More information about the Mlir-commits mailing list