[Mlir-commits] [mlir] f239026 - [mlir][linalg][python] Add min operation in OpDSL.

Tobias Gysi llvmlistbot at llvm.org
Fri Jul 2 09:31:36 PDT 2021


Author: Tobias Gysi
Date: 2021-07-02T16:27:30Z
New Revision: f239026f89b24e4eeaf16f171f95da53e28f36f0

URL: https://github.com/llvm/llvm-project/commit/f239026f89b24e4eeaf16f171f95da53e28f36f0
DIFF: https://github.com/llvm/llvm-project/commit/f239026f89b24e4eeaf16f171f95da53e28f36f0.diff

LOG: [mlir][linalg][python] Add min operation in OpDSL.

Add the min operation to OpDSL and introduce a min pooling operation to test the implementation. The patch is a sibling of the max operation patch https://reviews.llvm.org/D105203; as with max, the min operation is lowered to a compare and select pair.

Differential Revision: https://reviews.llvm.org/D105345

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
    mlir/test/python/integration/dialects/linalg/opsrun.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 39045a212ce11..1e4277ecd7bdf 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -664,6 +664,77 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: pooling_nhwc_min_poly
+  cpp_class_name: PoolingNhwcMinPolyOp
+  doc: |-
+    Performs min pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s4, s5)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s6, s7, s3)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s8, s9)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s10, s11)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d3, d4)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1, d2, d5)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: min
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          symbolic_cast:
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
   cpp_class_name: FillRng2DOp

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9b729b9db5d10..18c55f4019cab 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -275,17 +275,18 @@ class RegionBuilderHelper {
   }
 
   Value applyfn__max(Value lhs, Value rhs) {
-    OpBuilder builder = getBuilder();
-    if (isFloatingPoint(lhs)) {
-      Value condition =
-          builder.create<CmpFOp>(lhs.getLoc(), CmpFPredicate::OGT, lhs, rhs);
-      return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-    }
-    if (isInteger(lhs)) {
-      Value condition =
-          builder.create<CmpIOp>(lhs.getLoc(), CmpIPredicate::sgt, lhs, rhs);
-      return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
-    }
+    if (isFloatingPoint(lhs))
+      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OGT);
+    if (isInteger(lhs))
+      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::sgt);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
+  Value applyfn__min(Value lhs, Value rhs) {
+    if (isFloatingPoint(lhs))
+      return emitCmpFAndSelect(lhs, rhs, CmpFPredicate::OLT);
+    if (isInteger(lhs))
+      return emitCmpIAndSelect(lhs, rhs, CmpIPredicate::slt);
     llvm_unreachable("unsupported non numeric type");
   }
 
@@ -322,6 +323,17 @@ class RegionBuilderHelper {
   MLIRContext *context;
   Block &block;
 
+  Value emitCmpFAndSelect(Value lhs, Value rhs, CmpFPredicate predicate) {
+    OpBuilder builder = getBuilder();
+    Value condition = builder.create<CmpFOp>(lhs.getLoc(), predicate, lhs, rhs);
+    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
+  }
+  Value emitCmpIAndSelect(Value lhs, Value rhs, CmpIPredicate predicate) {
+    OpBuilder builder = getBuilder();
+    Value condition = builder.create<CmpIOp>(lhs.getLoc(), predicate, lhs, rhs);
+    return builder.create<SelectOp>(lhs.getLoc(), condition, lhs, rhs);
+  }
+
   bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
   bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
 

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 1f9230de397a2..66d7510b68abf 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -339,6 +339,7 @@ class PrimFn:
   log = PrimFnType("log")
   mul = PrimFnType("mul")
   max = PrimFnType("max")
+  min = PrimFnType("min")
   sub = PrimFnType("sub")
 
 
@@ -364,6 +365,7 @@ class ReduceFn:
   add = PrimFn.add.reduce
   mul = PrimFn.mul.reduce
   max = PrimFn.max.reduce
+  min = PrimFn.min.reduce
 
 
 class PrimApply(TensorExpression):

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 9489dec522716..61d2260587116 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -308,17 +308,23 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
     raise NotImplementedError("Unsupported 'mul' operand: {lhs}")
 
   def _eval_max(self, lhs: Value, rhs: Value) -> Value:
-    i1 = IntegerType.get_signless(1)
     if _is_floating_point_type(lhs.type):
       ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
-      cond = std.CmpFOp(i1, ogt_attr, lhs, rhs).result
-      return std.SelectOp(lhs.type, cond, lhs, rhs).result
+      return _emit_cmpf_and_select(lhs, rhs, ogt_attr)
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
-      cond = std.CmpIOp(i1, sgt_attr, lhs, rhs).result
-      return std.SelectOp(lhs.type, cond, lhs, rhs).result
+      return _emit_cmpi_and_select(lhs, rhs, sgt_attr)
     raise NotImplementedError("Unsupported 'max' operand: {lhs}")
 
+  def _eval_min(self, lhs: Value, rhs: Value) -> Value:
+    if _is_floating_point_type(lhs.type):
+      olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4)
+      return _emit_cmpf_and_select(lhs, rhs, olt_attr)
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+      slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2)
+      return _emit_cmpi_and_select(lhs, rhs, slt_attr)
+    raise NotImplementedError(f"Unsupported 'min' operand: {lhs}")
+
 
 def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
                            in_arg_defs: Sequence[OperandDefConfig],
@@ -397,3 +403,13 @@ def _get_floating_point_width(t: Type) -> int:
   if BF16Type.isinstance(t):
     return 16
   raise NotImplementedError(f"Unhandled floating point type switch {t}")
+
+
+def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
+  cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result
+  return std.SelectOp(lhs.type, cond, lhs, rhs).result
+
+
+def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value:
+  cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result
+  return std.SelectOp(lhs.type, cond, lhs, rhs).result

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 04c950e0a44db..a37e1944c1f75 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -166,6 +166,24 @@ def pooling_nhwc_max_poly(
                 D.c]))
 
 
+@linalg_structured_op
+def pooling_nhwc_min_poly(
+    I=TensorDef(T1, S.N, S.H, S.W, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs min pooling.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+      cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+                D.c]))
+
+
 @linalg_structured_op
 def fill_rng_2d(
     min=ScalarDef(F64),

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 4a1cb8dbcfa58..0e1c6a62a7b10 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -90,6 +90,36 @@ func @generalize_pooling_nhwc_max_poly_i32(%input : tensor<1x4x16x1xi32>, %shape
 
 // -----
 
+func @generalize_pooling_nhwc_min_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
+  %0 = linalg.pooling_nhwc_min_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
+  return %0: tensor<1x2x4x1xf32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_min_poly_f32
+// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
+// CHECK-NEXT:   %[[COND:.+]] = cmpf olt, %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT:   %[[MIN:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT:   linalg.yield %[[MIN]] : f32
+// CHECK-NEXT: -> tensor<1x2x4x1xf32>
+
+// -----
+
+func @generalize_pooling_nhwc_min_poly_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
+  %0 = linalg.pooling_nhwc_min_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
+  return %0: tensor<1x2x4x1xi32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_min_poly_i32
+// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
+// CHECK-NEXT:   %[[COND:.+]] = cmpi slt, %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT:   %[[MIN:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT:   linalg.yield %[[MIN]] : i32
+// CHECK-NEXT: -> tensor<1x2x4x1xi32>
+
+// -----
+
 func @generalize_pooling_nhwc_sum_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
   %0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
     ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index 12f6c560cfecc..44ac4e8e8c5b4 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -43,7 +43,7 @@ def conv_poly(
 
 
 @linalg_structured_op
-def pooling_poly(
+def pooling_max_poly(
     I=TensorDef(T1, S.N, S.H, S.W, S.C),
     K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
@@ -55,6 +55,19 @@ def pooling_poly(
                 D.c]))
 
 
+@linalg_structured_op
+def pooling_min_poly(
+    I=TensorDef(T1, S.N, S.H, S.W, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+      cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+                D.c]))
+
+
 @linalg_structured_op
 def fill_rng_poly(
     min=ScalarDef(F64),
@@ -216,7 +229,7 @@ def test_f32i32_conv(input, filter, init_result):
       return conv_poly(
           input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
-    # CHECK-LABEL: @test_f32i32_pooling
+    # CHECK-LABEL: @test_f32i32_max_pooling
     # CHECK: linalg.generic
     # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
     # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
@@ -229,11 +242,11 @@ def test_f32i32_conv(input, filter, init_result):
     @builtin.FuncOp.from_py_func(
         RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
         RankedTensorType.get((2, 4), i32))
-    def test_f32i32_pooling(input, shape, init_result):
-      return pooling_poly(
+    def test_f32i32_max_pooling(input, shape, init_result):
+      return pooling_max_poly(
           input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
-    # CHECK-LABEL: @test_f32f32_pooling
+    # CHECK-LABEL: @test_f32f32_max_pooling
     # CHECK: linalg.generic
     # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
     # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
@@ -245,8 +258,26 @@ def test_f32i32_pooling(input, shape, init_result):
     @builtin.FuncOp.from_py_func(
         RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
         RankedTensorType.get((2, 4), f32))
-    def test_f32f32_pooling(input, shape, init_result):
-      return pooling_poly(
+    def test_f32f32_max_pooling(input, shape, init_result):
+      return pooling_max_poly(
+          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
+    # CHECK-LABEL: @test_f32i32_min_pooling
+    # CHECK:   = cmpi slt,
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
+        RankedTensorType.get((2, 4), i32))
+    def test_f32i32_min_pooling(input, shape, init_result):
+      return pooling_min_poly(
+          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
+    # CHECK-LABEL: @test_f32f32_min_pooling
+    # CHECK:   = cmpf olt,
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
+        RankedTensorType.get((2, 4), f32))
+    def test_f32f32_min_pooling(input, shape, init_result):
+      return pooling_min_poly(
           input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
     # CHECK-LABEL: @test_i32_fill_rng

diff  --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index c6d26d1c6b858..8ec4b6c44da20 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -86,6 +86,8 @@ def log(*args):
 func @main() -> i32 attributes {llvm.emit_c_interface} {
   %v0 = constant 0 : i32
   %v42 = constant 42.0 : f64
+  %v77 = constant 77.0 : f64
+  %v-13 = constant -13.0 : f64
   %v1 = constant 1.0 : f64
 
   %input = memref.alloc() : memref<1x4x16x1xf64>
@@ -96,7 +98,11 @@ def log(*args):
   linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
 
   %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
   memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
+  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
+  memref.store %v-13, %input[%c0, %c0, %c2, %c0] : memref<1x4x16x1xf64>
 
   call @pooling_on_buffers(%input, %shape, %output) :
     (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()
@@ -301,7 +307,7 @@ def conv_on_buffers(input, filter, output):
 test_conv_generic()
 
 
-def test_pooling_builtin():
+def test_max_pooling_builtin():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f64 = F64Type.get()
@@ -325,13 +331,14 @@ def pooling_on_buffers(input, shape, output):
     execution_engine.invoke("main", res)
 
     log("RESULT: ", res[0])
+    # 77 is not selected due to the dilation 2 in the second dimension.
     # CHECK: RESULT: 42
 
 
-test_pooling_builtin()
+test_max_pooling_builtin()
 
 
-def test_pooling_generic():
+def test_max_pooling_generic():
   with Context() as ctx, Location.unknown():
     module = Module.create()
     f64 = F64Type.get()
@@ -360,7 +367,73 @@ def pooling_on_buffers(input, shape, output):
     execution_engine.invoke("main", res)
 
     log("RESULT: ", res[0])
+    # 77 is not selected due to the dilation 2 in the second dimension.
     # CHECK: RESULT: 42
 
 
-test_pooling_generic()
+test_max_pooling_generic()
+
+
+def test_min_pooling_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
+          MemRefType.get((1, 2, 4, 1), i32))
+      def pooling_on_buffers(input, shape, output):
+        linalg.pooling_nhwc_min_poly(
+            input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])
+
+    execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: -13
+
+
+test_min_pooling_builtin()
+
+
+def test_min_pooling_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
+          MemRefType.get((1, 2, 4, 1), i32))
+      def pooling_on_buffers(input, shape, output):
+        linalg.pooling_nhwc_min_poly(
+            input,
+            shape,
+            outs=[output],
+            strides=[2, 4],
+            dilations=[1, 2],
+            emit_generic=True)
+
+    execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: -13
+
+
+test_min_pooling_generic()


        


More information about the Mlir-commits mailing list