[Mlir-commits] [mlir] 78dc1e4 - [mlir][linalg][python] Add shape-only tensor support to OpDSL.

Tobias Gysi llvmlistbot at llvm.org
Thu Jun 24 07:11:53 PDT 2021


Author: Tobias Gysi
Date: 2021-06-24T14:11:15Z
New Revision: 78dc1e497807e6b857fde7f78f4bc9cb5a4f8939

URL: https://github.com/llvm/llvm-project/commit/78dc1e497807e6b857fde7f78f4bc9cb5a4f8939
DIFF: https://github.com/llvm/llvm-project/commit/78dc1e497807e6b857fde7f78f4bc9cb5a4f8939.diff

LOG: [mlir][linalg][python] Add shape-only tensor support to OpDSL.

Add an index_dims annotation to specify the shape-to-loop mapping of shape-only tensors. A shape-only tensor is not accessed within the body of the operation but is required to span the iteration space of certain operations such as pooling.

Differential Revision: https://reviews.llvm.org/D104767

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
    mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
    mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
    mlir/test/python/dialects/linalg/opsrun.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 58872da9b1dab..82e4d01c4a72c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -309,7 +309,11 @@ structured_op: !LinalgStructuredOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_2d_input_nhwc_filter_hwc_poly
   cpp_class_name: DepthwiseConv2DInputNhwcFilterHwcPolyOp
-  doc: A depth-wise 2-D convolution operation.
+  doc: |-
+    Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
@@ -317,13 +321,13 @@ structured_op: !LinalgStructuredOpConfig
     usage: InputOperand
     type_var: T1
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s6, s7, s3)>
+      (s0, s4, s5, s3)>
   - !LinalgOperandDefConfig
     name: K
     usage: InputOperand
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s4, s5, s3)>
+      (s6, s7, s3)>
   - !LinalgOperandDefConfig
     name: O
     usage: OutputOperand
@@ -383,6 +387,77 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: pooling_nhwc_sum_poly
+  cpp_class_name: PoolingNhwcSumPolyOp
+  doc: |-
+    Performs sum pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s4, s5, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s10, s11)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s6, s7)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s8, s9)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d2, d3 * s6 + d0 * s8, d4 * s7 + d1 * s9, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d2, d3, d4, d5)>
+  iterator_types:
+  - reduction
+  - reduction
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          symbolic_cast:
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
   cpp_class_name: FillRng2DOp

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 2b2f57248c515..e89885e975d65 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -151,13 +151,15 @@ class OperandDef:
   def __init__(self,
                kind: OperandKind,
                type_var: TypeVar,
-               size_exprs: Optional[Sequence[AffineExprDef]] = None):
+               size_exprs: Optional[Sequence[AffineExprDef]] = None,
+               index_dims: Optional[Sequence[DimDef]] = None):
     if not isinstance(type_var, TypeVar):
       raise ValueError(
           f"OperandDef requires a TypeVar but got {repr(type_var)}")
     self.owner = None  # type: Optional["LinalgOpDef"]
     self.type_var = type_var
     self.size_exprs = size_exprs
+    self.index_dims = index_dims
     self.kind = kind
     self.name = None  # type: Optional[str]
     self.registered_index = -1  # type: int
@@ -174,7 +176,8 @@ def __hash__(self):
 
   def __repr__(self):
     return (f"{self.name}:OperandDef(kind={self.kind.name}, "
-            f"type={repr(self.type_var)}, size_exprs={self.size_exprs})")
+            f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), "
+            f"index_dims={self.index_dims})")
 
 
 class TensorDef:
@@ -184,15 +187,25 @@ class TensorDef:
   to the body of the structured op. A unique name identifies the tensor operands
   and an index determines their position in the operation's parameter list. A
   tensor definition takes type, a shape, and an optional flag to mark output
-  tensors.
+  tensors. Additionally, a tuple of index dimensions may be used to map the
+  tensor to the loop dimensions of the operation. This mapping is needed to
+  compute the indexing map of shape-only tensors that have no uses.
   """
 
   def __init__(self,
                type_var: TypeVar,
                *shape: AffineExprDef,
+               index_dims: Optional[Sequence[DimDef]] = None,
                output: bool = False):
+    if index_dims and len(shape) != len(index_dims):
+      raise ValueError(f"Expected the shape rank {len(shape)} to match the "
+                       f"number of index_dims {len(index_dims)}")
+    if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
+      raise ValueError(f"TensorDef requires index dims of type DimDef but "
+                       f"got {type(index_dims)}")
     kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
-    self.operand_def = OperandDef(kind, type_var, size_exprs=shape)
+    self.operand_def = OperandDef(
+        kind, type_var, size_exprs=shape, index_dims=index_dims)
 
   def __getitem__(self, dims) -> TensorUse:
     assert self.operand_def.owner, "TensorDef is not attached to an op"

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index 773bd876397f9..78e6f1d6a3083 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -138,12 +138,18 @@ def __init__(self,
       read_use.collect_scalar_uses(collected_scalar_uses)
       read_use.collect_indices(collected_indices)
 
-    # Collect all attribute definitions
+    # Collect all attribute definitions.
     collected_attr_defs = list()
     for operand in registered_operands:
       if operand.kind == OperandKind.Attribute:
         collected_attr_defs.append(operand)
 
+    # Collect all tensors with manual indexing annotation.
+    collected_index_defs = list()
+    for operand in registered_operands:
+      if operand.index_dims:
+        collected_index_defs.append(operand)
+
     # Add all definitions before uses, so process twice.
     for use in collected_tensor_uses:
       self.add_operand(use.operand_def)
@@ -151,6 +157,10 @@ def __init__(self,
       self.add_operand(use.operand_def)
     for definition in collected_attr_defs:
       self.add_operand(definition)
+    for definition in collected_index_defs:
+      if definition not in self.operands:
+        self.add_operand(definition)
+      self.add_indexed_operand(definition)
     for use in collected_tensor_uses:
       self.add_tensor_use(use)
 
@@ -158,6 +168,9 @@ def __init__(self,
     # symbols are known.
     for cuse in self.uses.values():
       cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
+    for definition in collected_index_defs:
+      self.operands[definition].indexing_map = self._normalize_affine_map(
+          self.operands[definition].indexing_map)
     for operand_config in self.operands.values():
       if operand_config.shape_map:
         operand_config.shape_map = self._normalize_affine_map(
@@ -278,6 +291,18 @@ def add_operand(self, operand_def: OperandDef):
         self.operands[operand_def] = OperandDefConfig(
             operand_def, shape_map=affine_map)
 
+  def add_indexed_operand(self, operand_def: OperandDef):
+    with self.context:
+      local_state = AffineBuildState(
+          global_state=self.affine_state, allow_new_symbols=False)
+      exprs = []
+      for expr in operand_def.index_dims:
+        exprs.append(expr.build(state=local_state))
+      self.operands[operand_def].indexing_map = _ir.AffineMap.get(
+          dim_count=local_state.dim_count,
+          symbol_count=local_state.symbol_count,
+          exprs=exprs)
+
   def add_tensor_use(self, tensor_use: TensorUse):
     if tensor_use in self.uses:
       return

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index fe8bfc501ebcb..253fca4b41690 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -81,12 +81,32 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
     strides=AttributeDef(S.SH, S.SW),
     dilations=AttributeDef(S.DH, S.DW)):
-  """A depth-wise 2-D convolution operation."""
+  """Performs depth-wise 2-D convolution.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
   O[D.n, D.oh, D.ow, D.c] += cast(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
            D.c]) * cast(U, K[D.kh, D.kw, D.c])
 
 
+@linalg_structured_op
+def pooling_nhwc_sum_poly(
+    I=TensorDef(T1, S.N, S.H, S.W, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs sum pooling.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  O[D.n, D.oh, D.ow, D.c] += cast(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+
+
 @linalg_structured_op
 def fill_rng_2d(
     min=ScalarDef(F64),

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index b40ab139c3e73..723859c913c04 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -60,6 +60,34 @@ func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32(%input : tenso
 
 // -----
 
+func @generalize_pooling_nhwc_sum_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
+  %0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
+  return %0: tensor<1x2x4x1xf32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_sum_poly_f32
+// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
+// CHECK-NEXT:   %[[ADD:.+]] = addf %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<1x2x4x1xf32>
+
+// -----
+
+func @generalize_pooling_nhwc_sum_poly_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
+  %0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
+  return %0: tensor<1x2x4x1xi32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_sum_poly_i32
+// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
+// CHECK-NEXT:   %[[ADD:.+]] = addi %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+// CHECK-NEXT: -> tensor<1x2x4x1xi32>
+
+// -----
+
 func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index 0ed32fe4fb293..cbe88dd043f73 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -40,7 +40,18 @@ def conv_poly(
 
 
 @linalg_structured_op
-def fill_rng(
+def pooling_poly(
+    I=TensorDef(T1, S.N, S.H, S.W, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  O[D.n, D.oh, D.ow, D.c] += cast(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+
+
+@linalg_structured_op
+def fill_rng_poly(
     min=ScalarDef(F64),
     max=ScalarDef(F64),
     seed=ScalarDef(I32),
@@ -65,16 +76,22 @@ def fill_rng(
   i32 = IntegerType.get_signless(32)
   with InsertionPoint(module.body):
 
-    # Note that these all have the same indexing maps. We verify the first and
-    # then do more permutation tests on casting and body generation
-    # behavior.
-    # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
-    # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
-    # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+    # Multiplication indexing maps. We verify only the indexing maps of the
+    # first multiplication and then do additional tests on casting and body
+    # generation behavior.
+    # CHECK: #[[$MUL_MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+    # CHECK: #[[$MUL_MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+    # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+    # Convolution indexing maps.
+    # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)>
+    # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+    # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
 
-    # CHECK: #[[$MAPI:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)>
-    # CHECK: #[[$MAPK:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
-    # CHECK: #[[$MAPO:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+    # Pooling indexing maps.
+    # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3 * 2 + d0, d4 * 4 + d1 * 2, d5)>
+    # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
+    # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
 
     # CHECK-LABEL: func @test_matmul_mono
     # CHECK-SAME:  %[[A:.+]]: tensor<4x16xf32>
@@ -82,7 +99,7 @@ def fill_rng(
 
     # CHECK: %[[INITC:.+]] = linalg.init_tensor [4, 8] : tensor<4x8xf32>
     # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$MAPA]], #[[$MAPB]], #[[$MAPC]]]
+    # CHECK-SAME: indexing_maps = [#[[$MUL_MAP_A]], #[[$MUL_MAP_B]], #[[$MUL_MAP_C]]]
     # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
     # CHECK-SAME: ins(%[[A]], %[[B]]
     # CHECK-SAME: outs(%[[INITC]]
@@ -177,28 +194,9 @@ def test_f16f16f32_matmul(lhs, rhs, init_result):
     def test_f64f64f32_matmul(lhs, rhs, init_result):
       return matmul_poly(lhs, rhs, outs=[init_result])
 
-    # CHECK-LABEL: @test_fill_rng
-    # CHECK:      ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}}
-    # CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
-    # CHECK-DAG:    %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
-    # CHECK-DAG:    %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
-    # CHECK-DAG:    %[[CST0:.+]] = constant 1103515245 : i64
-    # CHECK-DAG:    %[[CST0_CAST:.+]] = trunci %[[CST0]] : i64 to i32
-    # Skip the remaining random number computation and match the scaling logic.
-    # CHECK-DAG:    %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
-    # CHECK-DAG:    %[[CST3:.+]] = constant 2.3283063999999999E-10 : f64
-    # CHECK-DAG:    %[[FACT:.+]] = mulf %[[DIFF]], %[[CST3]] : f64
-    # CHECK-DAG:    %[[RND4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
-    # CHECK-DAG:    %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64
-    # CHECK-DAG:    %{{.*}} = fptosi %[[RND5]] : f64 to i32
-    @builtin.FuncOp.from_py_func(f64, f64, i32,
-                                 RankedTensorType.get((4, 16), i32))
-    def test_fill_rng(min, max, seed, init_result):
-      return fill_rng(min, max, seed, outs=[init_result])
-
     # CHECK-LABEL: @test_f32i32_conv
     # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$MAPI]], #[[$MAPK]], #[[$MAPO]]]
+    # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]]
     # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
     # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
     # CHECK-NEXT:   %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
@@ -215,5 +213,40 @@ def test_f32i32_conv(input, filter, init_result):
       return conv_poly(
           input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])
 
+    # CHECK-LABEL: @test_f32i32_pooling
+    # CHECK: linalg.generic
+    # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
+    # CHECK-SAME: iterator_types = ["reduction", "reduction", "parallel", "parallel", "parallel", "parallel"]
+    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
+    # CHECK-NEXT:   %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
+    # CHECK-NEXT:   %[[SUM:.+]] = addi %[[OUT]], %[[IN_CAST]] : i32
+    # CHECK-NEXT:   linalg.yield %[[SUM]] : i32
+    # CHECK-NEXT: -> tensor<2x4xi32>
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
+        RankedTensorType.get((2, 4), i32))
+    def test_f32i32_pooling(input, shape, init_result):
+      return pooling_poly(
+          input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
+    # CHECK-LABEL: @test_i32_fill_rng
+    # CHECK:      ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}}
+    # CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
+    # CHECK-DAG:    %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
+    # CHECK-DAG:    %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
+    # CHECK-DAG:    %[[CST0:.+]] = constant 1103515245 : i64
+    # CHECK-DAG:    %[[CST0_CAST:.+]] = trunci %[[CST0]] : i64 to i32
+    # Skip the remaining random number computation and match the scaling logic.
+    # CHECK-DAG:    %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
+    # CHECK-DAG:    %[[CST3:.+]] = constant 2.3283063999999999E-10 : f64
+    # CHECK-DAG:    %[[FACT:.+]] = mulf %[[DIFF]], %[[CST3]] : f64
+    # CHECK-DAG:    %[[RND4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
+    # CHECK-DAG:    %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64
+    # CHECK-DAG:    %{{.*}} = fptosi %[[RND5]] : f64 to i32
+    @builtin.FuncOp.from_py_func(f64, f64, i32,
+                                 RankedTensorType.get((4, 16), i32))
+    def test_i32_fill_rng(min, max, seed, init_result):
+      return fill_rng_poly(min, max, seed, outs=[init_result])
+
 
 print(module)

diff  --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
index 3132c90046df7..2933852f97cfe 100644
--- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
+++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
@@ -42,3 +42,23 @@ def matmul(
 @linalg_structured_op
 def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
   C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
+
+# Verifies that the index_dims of shape-only operands translate to correct
+# indexing maps.
+# CHECK: ---
+# CHECK-LABEL: pool
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0)>
+# CHECK: static_indexing_maps:
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1 * 2 + d0)>
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0)>
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1)>
+# CHECK: iterator_types:
+# CHECK-NEXT: - reduction
+# CHECK-NEXT: - parallel
+@linalg_structured_op
+def pool(I=TensorDef(T, S.I),
+         K=TensorDef(T, S.K, index_dims=[D.k]),
+         O=TensorDef(U, S.O, output=True)):
+  O[D.o] += cast(U, I[D.o * 2 + D.k])

diff  --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py
index e315a5fe9889e..08b13a5352984 100644
--- a/mlir/test/python/dialects/linalg/opsrun.py
+++ b/mlir/test/python/dialects/linalg/opsrun.py
@@ -82,6 +82,28 @@ def log(*args):
 }
 """
 
+pooling_boiler = """
+func @main() -> i32 attributes {llvm.emit_c_interface} {
+  %v0 = constant 0 : i32
+  %v1 = constant 1.0 : f64
+
+  %input = memref.alloc() : memref<1x4x16x1xf64>
+  %shape = memref.alloc() : memref<2x2xf64>
+  %output = memref.alloc() : memref<1x2x4x1xi32>
+  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
+  linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
+  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
+
+  call @pooling_on_buffers(%input, %shape, %output) :
+    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()
+
+  %c0 = constant 0 : index
+  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>
+
+  // TODO: FFI-based solution to allow testing and printing with python code.
+  return %0 : i32
+}
+"""
 
 def transform(module, boilerplate):
   import mlir.conversions
@@ -273,3 +295,72 @@ def conv_on_buffers(input, filter, output):
 
 
 test_conv_generic()
+
+
+def test_pooling_builtin():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
+          MemRefType.get((1, 2, 4, 1), i32))
+      def pooling_on_buffers(input, shape, output):
+        linalg.pooling_nhwc_sum_poly(
+            input,
+            shape,
+            outs=[output],
+            strides=[2, 4],
+            dilations=[1, 2])
+
+    execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: 4
+
+
+test_pooling_builtin()
+
+
+def test_pooling_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
+          MemRefType.get((1, 2, 4, 1), i32))
+      def pooling_on_buffers(input, shape, output):
+        linalg.pooling_nhwc_sum_poly(
+            input,
+            shape,
+            outs=[output],
+            strides=[2, 4],
+            dilations=[1, 2],
+            emit_generic=True)
+
+    execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: 4
+
+
+test_pooling_generic()


        


More information about the Mlir-commits mailing list