[Mlir-commits] [mlir] 78dc1e4 - [mlir][linalg][python] Add shape-only tensor support to OpDSL.
Tobias Gysi
llvmlistbot at llvm.org
Thu Jun 24 07:11:53 PDT 2021
Author: Tobias Gysi
Date: 2021-06-24T14:11:15Z
New Revision: 78dc1e497807e6b857fde7f78f4bc9cb5a4f8939
URL: https://github.com/llvm/llvm-project/commit/78dc1e497807e6b857fde7f78f4bc9cb5a4f8939
DIFF: https://github.com/llvm/llvm-project/commit/78dc1e497807e6b857fde7f78f4bc9cb5a4f8939.diff
LOG: [mlir][linalg][python] Add shape-only tensor support to OpDSL.
Add an index_dims annotation to specify the shape-to-loop mapping of shape-only tensors. A shape-only tensor is not accessed within the body of the operation but is required to span the iteration space of certain operations, such as pooling.
Differential Revision: https://reviews.llvm.org/D104767
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
mlir/test/python/dialects/linalg/opsrun.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 58872da9b1dab..82e4d01c4a72c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -309,7 +309,11 @@ structured_op: !LinalgStructuredOpConfig
metadata: !LinalgOpMetadata
name: depthwise_conv_2d_input_nhwc_filter_hwc_poly
cpp_class_name: DepthwiseConv2DInputNhwcFilterHwcPolyOp
- doc: A depth-wise 2-D convolution operation.
+ doc: |-
+ Performs depth-wise 2-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
@@ -317,13 +321,13 @@ structured_op: !LinalgStructuredOpConfig
usage: InputOperand
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
- (s0, s6, s7, s3)>
+ (s0, s4, s5, s3)>
- !LinalgOperandDefConfig
name: K
usage: InputOperand
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
- (s4, s5, s3)>
+ (s6, s7, s3)>
- !LinalgOperandDefConfig
name: O
usage: OutputOperand
@@ -383,6 +387,77 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: K
--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: pooling_nhwc_sum_poly
+ cpp_class_name: PoolingNhwcSumPolyOp
+ doc: |-
+ Performs sum pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ usage: InputOperand
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s4, s5, s3)>
+ - !LinalgOperandDefConfig
+ name: K
+ usage: InputOperand
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s10, s11)>
+ - !LinalgOperandDefConfig
+ name: O
+ usage: OutputOperand
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s1, s2, s3)>
+ - !LinalgOperandDefConfig
+ name: strides
+ usage: IndexAttribute
+ type_var: I64
+ attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s6, s7)>
+ - !LinalgOperandDefConfig
+ name: dilations
+ usage: IndexAttribute
+ type_var: I64
+ attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s8, s9)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+ s10, s11] -> (d2, d3 * s6 + d0 * s8, d4 * s7 + d1 * s9, d5)>
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+ s10, s11] -> (d0, d1)>
+ - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+ s10, s11] -> (d2, d3, d4, d5)>
+ iterator_types:
+ - reduction
+ - reduction
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: O
+ - !ScalarExpression
+ symbolic_cast:
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: fill_rng_2d
cpp_class_name: FillRng2DOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 2b2f57248c515..e89885e975d65 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -151,13 +151,15 @@ class OperandDef:
def __init__(self,
kind: OperandKind,
type_var: TypeVar,
- size_exprs: Optional[Sequence[AffineExprDef]] = None):
+ size_exprs: Optional[Sequence[AffineExprDef]] = None,
+ index_dims: Optional[Sequence[DimDef]] = None):
if not isinstance(type_var, TypeVar):
raise ValueError(
f"OperandDef requires a TypeVar but got {repr(type_var)}")
self.owner = None # type: Optional["LinalgOpDef"]
self.type_var = type_var
self.size_exprs = size_exprs
+ self.index_dims = index_dims
self.kind = kind
self.name = None # type: Optional[str]
self.registered_index = -1 # type: int
@@ -174,7 +176,8 @@ def __hash__(self):
def __repr__(self):
return (f"{self.name}:OperandDef(kind={self.kind.name}, "
- f"type={repr(self.type_var)}, size_exprs={self.size_exprs})")
+ f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
+ f"index_dims={self.index_dims})")
class TensorDef:
@@ -184,15 +187,25 @@ class TensorDef:
to the body of the structured op. A unique name identifies the tensor operands
and an index determines their position in the operation's parameter list. A
tensor definition takes type, a shape, and an optional flag to mark output
- tensors.
+ tensors. Additionally, a tuple of index dimensions may be used to map the
+ tensor to the loop dimensions of the operation. This mapping is needed to
+ compute the indexing map of shape-only tensors that have no uses.
"""
def __init__(self,
type_var: TypeVar,
*shape: AffineExprDef,
+ index_dims: Optional[Sequence[DimDef]] = None,
output: bool = False):
+ if index_dims and len(shape) != len(index_dims):
+ raise ValueError(f"Expected the shape rank {len(shape)} to match the "
+ f"number of index_dims {len(index_dims)}")
+ if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
+ raise ValueError(f"TensorDef requires index dims of type DimDef but "
+ f"got {type(index_dims)}")
kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
- self.operand_def = OperandDef(kind, type_var, size_exprs=shape)
+ self.operand_def = OperandDef(
+ kind, type_var, size_exprs=shape, index_dims=index_dims)
def __getitem__(self, dims) -> TensorUse:
assert self.operand_def.owner, "TensorDef is not attached to an op"
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index 773bd876397f9..78e6f1d6a3083 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -138,12 +138,18 @@ def __init__(self,
read_use.collect_scalar_uses(collected_scalar_uses)
read_use.collect_indices(collected_indices)
- # Collect all attribute definitions
+ # Collect all attribute definitions.
collected_attr_defs = list()
for operand in registered_operands:
if operand.kind == OperandKind.Attribute:
collected_attr_defs.append(operand)
+ # Collect all tensors with manual indexing annotation.
+ collected_index_defs = list()
+ for operand in registered_operands:
+ if operand.index_dims:
+ collected_index_defs.append(operand)
+
# Add all definitions before uses, so process twice.
for use in collected_tensor_uses:
self.add_operand(use.operand_def)
@@ -151,6 +157,10 @@ def __init__(self,
self.add_operand(use.operand_def)
for definition in collected_attr_defs:
self.add_operand(definition)
+ for definition in collected_index_defs:
+ if definition not in self.operands:
+ self.add_operand(definition)
+ self.add_indexed_operand(definition)
for use in collected_tensor_uses:
self.add_tensor_use(use)
@@ -158,6 +168,9 @@ def __init__(self,
# symbols are known.
for cuse in self.uses.values():
cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
+ for definition in collected_index_defs:
+ self.operands[definition].indexing_map = self._normalize_affine_map(
+ self.operands[definition].indexing_map)
for operand_config in self.operands.values():
if operand_config.shape_map:
operand_config.shape_map = self._normalize_affine_map(
@@ -278,6 +291,18 @@ def add_operand(self, operand_def: OperandDef):
self.operands[operand_def] = OperandDefConfig(
operand_def, shape_map=affine_map)
+ def add_indexed_operand(self, operand_def: OperandDef):
+ with self.context:
+ local_state = AffineBuildState(
+ global_state=self.affine_state, allow_new_symbols=False)
+ exprs = []
+ for expr in operand_def.index_dims:
+ exprs.append(expr.build(state=local_state))
+ self.operands[operand_def].indexing_map = _ir.AffineMap.get(
+ dim_count=local_state.dim_count,
+ symbol_count=local_state.symbol_count,
+ exprs=exprs)
+
def add_tensor_use(self, tensor_use: TensorUse):
if tensor_use in self.uses:
return
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index fe8bfc501ebcb..253fca4b41690 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -81,12 +81,32 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
strides=AttributeDef(S.SH, S.SW),
dilations=AttributeDef(S.DH, S.DW)):
- """A depth-wise 2-D convolution operation."""
+ """Performs depth-wise 2-D convolution.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
O[D.n, D.oh, D.ow, D.c] += cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
D.c]) * cast(U, K[D.kh, D.kw, D.c])
+@linalg_structured_op
+def pooling_nhwc_sum_poly(
+ I=TensorDef(T1, S.N, S.H, S.W, S.C),
+ K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+ strides=AttributeDef(S.SH, S.SW),
+ dilations=AttributeDef(S.DH, S.DW)):
+ """Performs sum pooling.
+
+ Numeric casting is performed on the input operand, promoting it to the same
+ data type as the accumulator/output.
+ """
+ O[D.n, D.oh, D.ow, D.c] += cast(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+
+
@linalg_structured_op
def fill_rng_2d(
min=ScalarDef(F64),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index b40ab139c3e73..723859c913c04 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -60,6 +60,34 @@ func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32(%input : tenso
// -----
+func @generalize_pooling_nhwc_sum_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
+ %0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+ ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
+ return %0: tensor<1x2x4x1xf32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_sum_poly_f32
+// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<1x2x4x1xf32>
+
+// -----
+
+func @generalize_pooling_nhwc_sum_poly_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
+ %0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+ ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
+ return %0: tensor<1x2x4x1xi32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_sum_poly_i32
+// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
+// CHECK-NEXT: %[[ADD:.+]] = addi %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: linalg.yield %[[ADD]] : i32
+// CHECK-NEXT: -> tensor<1x2x4x1xi32>
+
+// -----
+
func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
%0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
return %0: tensor<16x32xf32>
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index 0ed32fe4fb293..cbe88dd043f73 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -40,7 +40,18 @@ def conv_poly(
@linalg_structured_op
-def fill_rng(
+def pooling_poly(
+ I=TensorDef(T1, S.N, S.H, S.W, S.C),
+ K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+ strides=AttributeDef(S.SH, S.SW),
+ dilations=AttributeDef(S.DH, S.DW)):
+ O[D.n, D.oh, D.ow, D.c] += cast(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+
+
+@linalg_structured_op
+def fill_rng_poly(
min=ScalarDef(F64),
max=ScalarDef(F64),
seed=ScalarDef(I32),
@@ -65,16 +76,22 @@ def fill_rng(
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
- # Note that these all have the same indexing maps. We verify the first and
- # then do more permutation tests on casting and body generation
- # behavior.
- # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
- # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
- # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+ # Multiplication indexing maps. We verify only the indexing maps of the
+ # first multiplication and then do additional tests on casting and body
+ # generation behavior.
+ # CHECK: #[[$MUL_MAP_A:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+ # CHECK: #[[$MUL_MAP_B:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+ # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+ # Convolution indexing maps.
+ # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)>
+ # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+ # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
- # CHECK: #[[$MAPI:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)>
- # CHECK: #[[$MAPK:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
- # CHECK: #[[$MAPO:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+ # Pooling indexing maps.
+ # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3 * 2 + d0, d4 * 4 + d1 * 2, d5)>
+ # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
+ # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
# CHECK-LABEL: func @test_matmul_mono
# CHECK-SAME: %[[A:.+]]: tensor<4x16xf32>
@@ -82,7 +99,7 @@ def fill_rng(
# CHECK: %[[INITC:.+]] = linalg.init_tensor [4, 8] : tensor<4x8xf32>
# CHECK: linalg.generic
- # CHECK-SAME: indexing_maps = [#[[$MAPA]], #[[$MAPB]], #[[$MAPC]]]
+ # CHECK-SAME: indexing_maps = [#[[$MUL_MAP_A]], #[[$MUL_MAP_B]], #[[$MUL_MAP_C]]]
# CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
# CHECK-SAME: ins(%[[A]], %[[B]]
# CHECK-SAME: outs(%[[INITC]]
@@ -177,28 +194,9 @@ def test_f16f16f32_matmul(lhs, rhs, init_result):
def test_f64f64f32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
- # CHECK-LABEL: @test_fill_rng
- # CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}}
- # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
- # CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
- # CHECK-DAG: %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
- # CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i64
- # CHECK-DAG: %[[CST0_CAST:.+]] = trunci %[[CST0]] : i64 to i32
- # Skip the remaining random number computation and match the scaling logic.
- # CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
- # CHECK-DAG: %[[CST3:.+]] = constant 2.3283063999999999E-10 : f64
- # CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST3]] : f64
- # CHECK-DAG: %[[RND4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
- # CHECK-DAG: %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64
- # CHECK-DAG: %{{.*}} = fptosi %[[RND5]] : f64 to i32
- @builtin.FuncOp.from_py_func(f64, f64, i32,
- RankedTensorType.get((4, 16), i32))
- def test_fill_rng(min, max, seed, init_result):
- return fill_rng(min, max, seed, outs=[init_result])
-
# CHECK-LABEL: @test_f32i32_conv
# CHECK: linalg.generic
- # CHECK-SAME: indexing_maps = [#[[$MAPI]], #[[$MAPK]], #[[$MAPO]]]
+ # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]]
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
# CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
@@ -215,5 +213,40 @@ def test_f32i32_conv(input, filter, init_result):
return conv_poly(
input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+ # CHECK-LABEL: @test_f32i32_pooling
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
+ # CHECK-SAME: iterator_types = ["reduction", "reduction", "parallel", "parallel", "parallel", "parallel"]
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
+ # CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
+ # CHECK-NEXT: %[[SUM:.+]] = addi %[[OUT]], %[[IN_CAST]] : i32
+ # CHECK-NEXT: linalg.yield %[[SUM]] : i32
+ # CHECK-NEXT: -> tensor<2x4xi32>
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
+ RankedTensorType.get((2, 4), i32))
+ def test_f32i32_pooling(input, shape, init_result):
+ return pooling_poly(
+ input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
+ # CHECK-LABEL: @test_i32_fill_rng
+ # CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}}
+ # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+ # CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
+ # CHECK-DAG: %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
+ # CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i64
+ # CHECK-DAG: %[[CST0_CAST:.+]] = trunci %[[CST0]] : i64 to i32
+ # Skip the remaining random number computation and match the scaling logic.
+ # CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
+ # CHECK-DAG: %[[CST3:.+]] = constant 2.3283063999999999E-10 : f64
+ # CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST3]] : f64
+ # CHECK-DAG: %[[RND4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
+ # CHECK-DAG: %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64
+ # CHECK-DAG: %{{.*}} = fptosi %[[RND5]] : f64 to i32
+ @builtin.FuncOp.from_py_func(f64, f64, i32,
+ RankedTensorType.get((4, 16), i32))
+ def test_i32_fill_rng(min, max, seed, init_result):
+ return fill_rng_poly(min, max, seed, outs=[init_result])
+
print(module)
diff --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
index 3132c90046df7..2933852f97cfe 100644
--- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
+++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
@@ -42,3 +42,23 @@ def matmul(
@linalg_structured_op
def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
+
+# Verifies that the index_dims of shape-only operands translate to correct
+# indexing maps.
+# CHECK: ---
+# CHECK-LABEL: pool
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0)>
+# CHECK: static_indexing_maps:
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1 * 2 + d0)>
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0)>
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1)>
+# CHECK: iterator_types:
+# CHECK-NEXT: - reduction
+# CHECK-NEXT: - parallel
+@linalg_structured_op
+def pool(I=TensorDef(T, S.I),
+ K=TensorDef(T, S.K, index_dims=[D.k]),
+ O=TensorDef(U, S.O, output=True)):
+ O[D.o] += cast(U, I[D.o * 2 + D.k])
diff --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py
index e315a5fe9889e..08b13a5352984 100644
--- a/mlir/test/python/dialects/linalg/opsrun.py
+++ b/mlir/test/python/dialects/linalg/opsrun.py
@@ -82,6 +82,28 @@ def log(*args):
}
"""
+pooling_boiler = """
+func @main() -> i32 attributes {llvm.emit_c_interface} {
+ %v0 = constant 0 : i32
+ %v1 = constant 1.0 : f64
+
+ %input = memref.alloc() : memref<1x4x16x1xf64>
+ %shape = memref.alloc() : memref<2x2xf64>
+ %output = memref.alloc() : memref<1x2x4x1xi32>
+ linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
+ linalg.fill(%v1, %shape) : f64, memref<2x2xf64>
+ linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
+
+ call @pooling_on_buffers(%input, %shape, %output) :
+ (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()
+
+ %c0 = constant 0 : index
+ %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>
+
+ // TODO: FFI-based solution to allow testing and printing with python code.
+ return %0 : i32
+}
+"""
def transform(module, boilerplate):
import mlir.conversions
@@ -273,3 +295,72 @@ def conv_on_buffers(input, filter, output):
test_conv_generic()
+
+
+def test_pooling_builtin():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ @builtin.FuncOp.from_py_func(
+ MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
+ MemRefType.get((1, 2, 4, 1), i32))
+ def pooling_on_buffers(input, shape, output):
+ linalg.pooling_nhwc_sum_poly(
+ input,
+ shape,
+ outs=[output],
+ strides=[2, 4],
+ dilations=[1, 2])
+
+ execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: 4
+
+
+test_pooling_builtin()
+
+
+def test_pooling_generic():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ @builtin.FuncOp.from_py_func(
+ MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
+ MemRefType.get((1, 2, 4, 1), i32))
+ def pooling_on_buffers(input, shape, output):
+ linalg.pooling_nhwc_sum_poly(
+ input,
+ shape,
+ outs=[output],
+ strides=[2, 4],
+ dilations=[1, 2],
+ emit_generic=True)
+
+ execution_engine = ExecutionEngine(transform(module, pooling_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: 4
+
+
+test_pooling_generic()
More information about the Mlir-commits
mailing list