[Mlir-commits] [mlir] d50571a - [mlir][OpDSL] Add default value to index attributes.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 14 04:38:17 PST 2022
Author: gysit
Date: 2022-02-14T12:14:12Z
New Revision: d50571ab07e1fc1761cf2a884459fe4892ec75f1
URL: https://github.com/llvm/llvm-project/commit/d50571ab07e1fc1761cf2a884459fe4892ec75f1
DIFF: https://github.com/llvm/llvm-project/commit/d50571ab07e1fc1761cf2a884459fe4892ec75f1.diff
LOG: [mlir][OpDSL] Add default value to index attributes.
Index attributes had no default value, which means the attribute values had to be set on the operation. This revision adds a default parameter to `IndexAttrDef`. After the change, every index attribute has to define a default value. For example, we may define the following strides attribute:
```
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1])
```
When using the operation, the default strides are used if the strides attribute is not set. The mechanism is implemented using `DefaultValuedAttr`.
Additionally, the revision more consistently uses the term "index attribute" instead of "attribute". This prepares for follow-up revisions that will introduce function attributes.
Depends On D119125
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D119126
Added:
Modified:
mlir/docs/Dialects/Linalg/OpDSL.md
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/Dialect/Linalg/named-ops.mlir
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
mlir/test/python/dialects/linalg/opdsl/arguments.py
mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
mlir/test/python/integration/dialects/linalg/opsrun.py
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index deec3eae0fd2b..4f703b7504bc1 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -105,7 +105,7 @@ appear in the parameter list of the operation:
copy_and_scale(val, in_tensor, outs=[out_tensor])
```
-## Attributes
+## Index Attributes
Attributes are compile-time constant parameters only accessible in index
expressions. They can be used to parameterize the access pattern of a structured
@@ -118,7 +118,7 @@ The following example demonstrates the use of attributes:
@linalg_structured_op
def strided_copy(I=TensorDef(T, S.IH, S.IW),
O=TensorDef(T, S.OH, S.OW, output=True),
- strides=IndexAttrDef(S.SH, S.SW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1])):
"""Copy a subset of the input tensor elements to the output tensor"""
O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
```
@@ -129,11 +129,12 @@ the symbols `S.SH` and `S.SW`, which are used to index the input tensor `I`.
When instantiating the operation, the attribute is set using a named argument:
```python
-strided_copy(in_tensor, outs=[out_tensor], strides=[1,2])
+strided_copy(in_tensor, outs=[out_tensor], strides=[1, 2])
```
The `strides` vector elements substitute the symbols `S.SH` and `S.SW` in the
-index expressions of the operation instance.
+index expressions of the operation instance. If no strides are provided the
+`default` vector elements are used instead.
Attributes are currently limited to integer vectors and only accessible in index
expressions. An operation may have multiple attributes all of them placed at the
@@ -157,8 +158,8 @@ def pooling_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(U,
I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
```
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 69a4cc407b9d7..e24efad36b577 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -15,17 +15,17 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- !LinalgOperandDefConfig
name: B
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
- !LinalgOperandDefConfig
name: C
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -79,17 +79,17 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- !LinalgOperandDefConfig
name: B
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
- !LinalgOperandDefConfig
name: C
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -143,25 +143,25 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
- !LinalgOperandDefConfig
name: B
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
- !LinalgOperandDefConfig
name: AZp
- usage: InputOperand
+ usage: Input
type_var: I32
- !LinalgOperandDefConfig
name: BZp
- usage: InputOperand
+ usage: Input
type_var: I32
- !LinalgOperandDefConfig
name: C
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -244,17 +244,17 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: lhs
- usage: InputOperand
+ usage: Input
type_var: LhsType
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)>
- !LinalgOperandDefConfig
name: rhs
- usage: InputOperand
+ usage: Input
type_var: RhsType
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)>
- !LinalgOperandDefConfig
name: accum
- usage: OutputOperand
+ usage: Output
type_var: AccumType
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -314,17 +314,17 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- !LinalgOperandDefConfig
name: B
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- !LinalgOperandDefConfig
name: C
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -379,25 +379,25 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
- !LinalgOperandDefConfig
name: B
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
- !LinalgOperandDefConfig
name: AZp
- usage: InputOperand
+ usage: Input
type_var: I32
- !LinalgOperandDefConfig
name: BZp
- usage: InputOperand
+ usage: Input
type_var: I32
- !LinalgOperandDefConfig
name: C
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -476,17 +476,17 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
- !LinalgOperandDefConfig
name: y
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1] -> (s1)>
- !LinalgOperandDefConfig
name: x
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1] -> (s0)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -539,17 +539,17 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: y
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1] -> (s0)>
- !LinalgOperandDefConfig
name: A
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
- !LinalgOperandDefConfig
name: x
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1] -> (s1)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -602,17 +602,17 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1, s2)>
- !LinalgOperandDefConfig
name: B
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
- !LinalgOperandDefConfig
name: C
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -666,17 +666,17 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: A
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0] -> (s0)>
- !LinalgOperandDefConfig
name: B
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0] -> (s0)>
- !LinalgOperandDefConfig
name: C
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0] -> ()>
indexing_maps: !LinalgIndexingMapsConfig
@@ -728,17 +728,17 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1] -> (s0 + s1)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1] -> (s1)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1] -> (s0)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -791,17 +791,17 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0 + s1, s2 + s3)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s3)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -856,17 +856,17 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0 + s1, s2 + s3, s4 + s5)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s1, s3, s5)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s2, s4)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -924,30 +924,32 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0, s1 * s2 + s3 * s4,
s5)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s3, s5, s6)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0, s1, s6)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s2)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s2)>
+ default_vals:
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s4)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s4)>
+ default_vals:
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4)[s0, s1, s2, s3, s4, s5, s6] -> (d0, d1 * s2
@@ -1006,34 +1008,38 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3,
s7, s9, s10)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s1, s5, s10)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2,
- s6)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s2, s6)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4,
- s8)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s4, s8)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
@@ -1097,42 +1103,46 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3,
s7, s9, s10)>
- !LinalgOperandDefConfig
name: IZp
- usage: InputOperand
+ usage: Input
type_var: I32
- !LinalgOperandDefConfig
name: KZp
- usage: InputOperand
+ usage: Input
type_var: I32
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s1, s5, s10)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2,
- s6)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s2, s6)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4,
- s8)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s4, s8)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
@@ -1221,34 +1231,38 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10,
s1, s4, s8)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s10, s2, s6)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3,
- s7)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s3, s7)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s5,
- s9)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s5, s9)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
@@ -1307,35 +1321,41 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13, s14] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12,
s13)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13, s14] -> (s3, s7, s11, s13, s14)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13, s14] -> (s0, s1, s5, s9, s14)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12, s13, s14] -> (s2, s6, s10)>
+ default_vals:
+ - 1
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12, s13, s14] -> (s4, s8, s12)>
+ default_vals:
+ - 1
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
@@ -1398,29 +1418,31 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1 * s2 + s3 * s4, s5)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s3, s5)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s5)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2)>
+ default_vals:
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4)>
+ default_vals:
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4, s5] -> (d0, d1 * s2 + d3 * s4,
@@ -1475,31 +1497,37 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 *
s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7, s9)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5,
s9)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2,
+ s6)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4,
+ s8)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
@@ -1557,39 +1585,45 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 *
s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7, s9)>
- !LinalgOperandDefConfig
name: IZp
- usage: InputOperand
+ usage: Input
type_var: I32
- !LinalgOperandDefConfig
name: KZp
- usage: InputOperand
+ usage: Input
type_var: I32
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5,
s9)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2,
+ s6)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4,
+ s8)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
@@ -1673,34 +1707,38 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3,
s7, s9, s10)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s1, s5, s9, s10)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2,
- s6)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s2, s6)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4,
- s8)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s4, s8)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
@@ -1759,42 +1797,46 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s3,
s7, s9, s10)>
- !LinalgOperandDefConfig
name: IZp
- usage: InputOperand
+ usage: Input
type_var: I32
- !LinalgOperandDefConfig
name: KZp
- usage: InputOperand
+ usage: Input
type_var: I32
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s1, s5, s9, s10)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s2,
- s6)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s2, s6)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s4,
- s8)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+ (s4, s8)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
@@ -1879,31 +1921,37 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 *
s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5,
s9)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2,
+ s6)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4,
+ s8)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
@@ -1950,31 +1998,37 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 *
s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5,
s9)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2,
+ s6)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4,
+ s8)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
@@ -2021,31 +2075,37 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 *
s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5,
s9)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2,
+ s6)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4,
+ s8)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
@@ -2092,31 +2152,37 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2
* s3 + s4 * s5, s6 * s7 + s8 * s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2,
s6)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3,
+ s7)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s5, s9)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s5,
+ s9)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
@@ -2163,31 +2229,37 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 *
s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5,
s9)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2,
+ s6)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4,
+ s8)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
@@ -2234,31 +2306,37 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1 *
s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3, s7)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s5,
s9)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2, s6)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2,
+ s6)>
+ default_vals:
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4, s8)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4,
+ s8)>
+ default_vals:
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
@@ -2305,34 +2383,40 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, s13)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13] -> (s3, s7, s11)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13] -> (s0, s1, s5, s9, s13)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12, s13] -> (s2, s6, s10)>
+ default_vals:
+ - 1
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12, s13] -> (s4, s8, s12)>
+ default_vals:
+ - 1
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
@@ -2382,34 +2466,40 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, s13)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13] -> (s3, s7, s11)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13] -> (s0, s1, s5, s9, s13)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12, s13] -> (s2, s6, s10)>
+ default_vals:
+ - 1
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12, s13] -> (s4, s8, s12)>
+ default_vals:
+ - 1
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
@@ -2459,34 +2549,40 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, s13)>
- !LinalgOperandDefConfig
name: K
- usage: InputOperand
+ usage: Input
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13] -> (s3, s7, s11)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
s13] -> (s0, s1, s5, s9, s13)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12, s13] -> (s2, s6, s10)>
+ default_vals:
+ - 1
+ - 1
+ - 1
- !LinalgOperandDefConfig
name: dilations
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
s12, s13] -> (s4, s8, s12)>
+ default_vals:
+ - 1
+ - 1
+ - 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
@@ -2535,11 +2631,11 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: value
- usage: InputOperand
+ usage: Input
type_var: T1
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
@@ -2575,19 +2671,19 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: min
- usage: InputOperand
+ usage: Input
type_var: F64
- !LinalgOperandDefConfig
name: max
- usage: InputOperand
+ usage: Input
type_var: F64
- !LinalgOperandDefConfig
name: seed
- usage: InputOperand
+ usage: Input
type_var: I32
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: T
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -2733,12 +2829,12 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
indexing_maps: !LinalgIndexingMapsConfig
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index ea25d85aa7428..4513236b8703f 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -135,7 +135,7 @@ class OperandKind(Enum):
InputTensor = 0
Scalar = 1
OutputTensor = 2
- Attribute = 3
+ IndexAttr = 3
class OperandDef:
@@ -147,16 +147,18 @@ class OperandDef:
def __init__(self,
kind: OperandKind,
- type_var: TypeVar,
+ type_var: Optional[TypeVar] = None,
size_exprs: Optional[Sequence[AffineExprDef]] = None,
- index_dims: Optional[Sequence[DimDef]] = None):
- if not isinstance(type_var, TypeVar):
+ index_dims: Optional[Sequence[DimDef]] = None,
+ default_vals : Optional[Sequence[int]] = None):
+ if type_var and not isinstance(type_var, TypeVar):
raise ValueError(
f"OperandDef requires a TypeVar but got {repr(type_var)}")
self.owner = None # type: Optional["LinalgOpDef"]
self.type_var = type_var
self.size_exprs = size_exprs
self.index_dims = index_dims
+ self.default_vals = default_vals
self.kind = kind
self.name = None # type: Optional[str]
self.registered_index = -1 # type: int
@@ -174,7 +176,7 @@ def __hash__(self):
def __repr__(self):
return (f"{self.name}:OperandDef(kind={self.kind.name}, "
f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), "
- f"index_dims={self.index_dims})")
+ f"index_dims={self.index_dims}, default_vals={self.default_vals})")
class TensorDef:
@@ -202,7 +204,7 @@ def __init__(self,
f"got {index_dims}")
kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
self.operand_def = OperandDef(
- kind, type_var, size_exprs=shape, index_dims=index_dims)
+ kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
def __getitem__(self, dims) -> TensorUse:
assert self.operand_def.owner, "TensorDef is not attached to an op"
@@ -246,7 +248,7 @@ class ScalarDef(TensorExpression):
"""
def __init__(self, type_var: TypeVar):
- self.operand_def = OperandDef(OperandKind.Scalar, type_var)
+ self.operand_def = OperandDef(OperandKind.Scalar, type_var=type_var)
@property
def scalar_name(self) -> str:
@@ -259,18 +261,25 @@ def to_scalar_expression(self) -> ScalarExpression:
class IndexAttrDef:
- """Index Attribute definition.
+ """Index attribute definition.
Index attributes provide a way to define and set symbols that can be used in
indexing expressions. Every attribute specifies a tuple of symbols that at
- compile-time are replaced by integer values.
+ compile-time are replaced by integer values as well as their default values.
"""
- def __init__(self, *sizes: SymbolDef):
+ def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
if any(not isinstance(size, SymbolDef) for size in sizes):
- raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef but got "
- f"{sizes}")
- self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)
+ raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef "
+ f"but got {sizes}")
+ if any(not isinstance(default_val, int) for default_val in default):
+ raise ValueError(f"IndexAttrDef requires default values of type int "
+ f"but got {default}")
+ if len(sizes) != len(default):
+ raise ValueError(f"IndexAttrDef expects {len(sizes)} default values "
+ f"but got {len(default)}")
+ self.operand_def = OperandDef(
+ OperandKind.IndexAttr, size_exprs=sizes, default_vals=default)
class Comprehension:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index 59a10998e102c..21741252f4996 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -45,10 +45,10 @@ class OperandDefConfig(YAMLObject):
def __init__(self,
operand_def: OperandDef,
shape_map: Optional[_ir.AffineMap] = None,
- attribute_map: Optional[_ir.AffineMap] = None):
+ index_attr_map: Optional[_ir.AffineMap] = None):
self.operand_def = operand_def
self.shape_map = shape_map # type: Optional[_ir.AffineMap]
- self.attribute_map = attribute_map # type: Optional[_ir.AffineMap]
+ self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap]
self.indexing_map = None # type: Optional[_ir.AffineMap]
@property
@@ -61,24 +61,28 @@ def type_var(self) -> TypeVar:
@property
def usage(self) -> str:
- if self.operand_def.kind == OperandKind.Attribute:
- return "IndexAttribute"
+ if self.operand_def.kind == OperandKind.IndexAttr:
+ return "IndexAttr"
if self.operand_def.kind == OperandKind.OutputTensor:
- return "OutputOperand"
- return "InputOperand"
+ return "Output"
+ return "Input"
def to_yaml_custom_dict(self):
- self_dict = dict(
- name=self.name, usage=self.usage, type_var=self.type_var.name)
+ self_dict = dict(name=self.name, usage=self.usage)
+ if self.type_var:
+ self_dict["type_var"] = self.type_var.name
if self.shape_map:
self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
- if self.attribute_map:
- self_dict["attribute_map"] = _serialize_affine_map(self.attribute_map)
+ if self.index_attr_map:
+ self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map)
+ if self.operand_def.default_vals:
+ self_dict["default_vals"] = self.operand_def.default_vals
return self_dict
def __repr__(self):
return (f"OperandDefConfig({self.operand_def}, "
- f"shape_map={self.shape_map}, attribute_map={self.attribute_map}, "
+ f"shape_map={self.shape_map}, "
+ f"index_attr_map={self.index_attr_map}, "
f"indexing_map={self.indexing_map})")
@@ -162,7 +166,7 @@ def __init__(self,
# Collect all attribute definitions.
collected_attr_defs = list()
for operand in registered_operands:
- if operand.kind == OperandKind.Attribute:
+ if operand.kind == OperandKind.IndexAttr:
collected_attr_defs.append(operand)
# Collect all tensors with manual indexing annotation.
@@ -210,9 +214,9 @@ def __init__(self,
if operand_config.shape_map:
operand_config.shape_map = self._normalize_affine_map(
operand_config.shape_map, with_dims=False)
- if operand_config.attribute_map:
- operand_config.attribute_map = self._normalize_affine_map(
- operand_config.attribute_map, with_dims=False)
+ if operand_config.index_attr_map:
+ operand_config.index_attr_map = self._normalize_affine_map(
+ operand_config.index_attr_map, with_dims=False)
# Now for each write use, propagate the indexing maps from the use to the
# tensor, ensuring that there are not conflicts.
@@ -245,7 +249,7 @@ def __init__(self,
# Check all registered tensor and scalar operands have an indexing map.
for operand in registered_operands:
- if operand.kind == OperandKind.Attribute:
+ if operand.kind == OperandKind.IndexAttr:
continue
if not (operand in self.operands and self.operands[operand].indexing_map):
raise ValueError(f"Failed to compute an indexing map for operand "
@@ -319,9 +323,9 @@ def add_operand(self, operand_def: OperandDef):
assert local_state.local_dim_count == 0
affine_map = _ir.AffineMap.get(
dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
- if operand_def.kind == OperandKind.Attribute:
+ if operand_def.kind == OperandKind.IndexAttr:
self.operands[operand_def] = OperandDefConfig(
- operand_def, attribute_map=affine_map)
+ operand_def, index_attr_map=affine_map)
else:
self.operands[operand_def] = OperandDefConfig(
operand_def, shape_map=affine_map)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index e4695f0c92a27..3d3b1889ba992 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -39,15 +39,14 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
*ins: Value, outs: ValueList,
**attrs: Sequence[int]):
all_arg_defs = op_config.ordered_operands
- in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"]
- out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"]
- attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"]
+ in_arg_defs = [d for d in all_arg_defs if d.usage == "Input"]
+ out_arg_defs = [d for d in all_arg_defs if d.usage == "Output"]
+ index_attr_arg_defs = [d for d in all_arg_defs if d.usage == "IndexAttr"]
# Verify outs is a sequence or a list of results.
if not isinstance(outs, (Sequence, OpResultList)):
- raise ValueError(
- f"Expected named argument outs to have type Sequence or OpResultLis but got {type(outs)}"
- )
+ raise ValueError(f"Expected named argument outs to have type Sequence or "
+ f"OpResultLis but got {type(outs)}")
# Arity validation.
if len(ins) != len(in_arg_defs):
@@ -60,18 +59,19 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
# Compute a replacement list for all attribute symbols.
expressions = [] # type: Sequence[AffineExpr]
replacements = [] # type: Sequence[AffineExpr]
- for attr in attr_arg_defs:
- if attr.name not in attrs:
- raise ValueError(f"Expected named argument for the attribute {attr.name}")
- attribute_values = attrs.get(attr.name)
- if not all(isinstance(value, int) for value in attribute_values):
- raise ValueError(f"Attribute {attr.name} needs to be of type "
- f"Sequence[int] but got {type(attribute_values)}")
- results = attr.attribute_map.results # type: AffineExprList
- if len(attribute_values) != len(results):
- raise ValueError(f"Attribute {attr.name} has length {len(results)} "
- f"but got {len(attribute_values)} values")
- for expr, value in zip(results, attribute_values):
+ for index_attr in index_attr_arg_defs:
+ index_attr_vals = index_attr.operand_def.default_vals
+ if index_attr.name in attrs:
+ index_attr_vals = attrs.get(index_attr.name)
+ assert index_attr_vals, "Index attribute has no value"
+ if not all(isinstance(value, int) for value in index_attr_vals):
+ raise ValueError(f"Attribute {index_attr.name} needs to be of type "
+ f"Sequence[int] but got {type(index_attr_vals)}")
+ results = index_attr.index_attr_map.results # type: AffineExprList
+ if len(index_attr_vals) != len(results):
+ raise ValueError(f"Attribute {index_attr.name} has length {len(results)} "
+ f"but got {len(index_attr_vals)} values")
+ for expr, value in zip(results, index_attr_vals):
expressions.append(expr)
replacements.append(AffineConstantExpr.get(value))
@@ -116,22 +116,24 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
iterator_types_attr = ArrayAttr.get(
[StringAttr.get(s) for s in op_config.iterator_types])
- # Compute a dictionary storing all index attributes.
- index_attributes = {} # type: Dict[str, DenseElementAttr]
- for attr in attr_arg_defs:
- attribute_values = attrs.get(attr.name)
- array = np.array(attribute_values, dtype=np.int64)
- index_attributes[attr.name] = DenseElementsAttr.get(array)
+ # Compute the index attributes used when emitting a named structured op.
+ index_attrs = {} # type: Dict[str, DenseElementAttr]
+ for index_attr in index_attr_arg_defs:
+ index_attr_vals = attrs.get(index_attr.name)
+ # Only forward attributes set to a non-default value.
+ if index_attr_vals:
+ array = np.array(index_attr_vals, dtype=np.int64)
+ index_attrs[index_attr.name] = DenseElementsAttr.get(array)
return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
type_mapping, indexing_maps_attr, iterator_types_attr,
- index_attributes, block_arg_types)
+ index_attrs, block_arg_types)
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
outs: ValueList, **attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
- indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
+ indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
# An operation that accesses only scalars and scalar/rank zero tensors is
@@ -182,7 +184,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
op_class_name: str, *ins: Value, outs: ValueList,
**attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
- indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
+ indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
# If we get here, there must exist a builtin class `op_class_name`.
@@ -195,7 +197,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
# Set the index attributes used to compute the indexing maps.
named_op = getattr(linalg, op_class_name)(ins, outs, result_types)
- for name, value in index_attributes.items():
+ for name, value in index_attrs.items():
named_op.operation.attributes[name] = value
linalg.fill_builtin_region(named_op.operation)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 80a8fb6ccf091..25bd0c3ab32b4 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -224,8 +224,8 @@ def conv_1d_nwc_wcf(
I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C),
K=TensorDef(T2, S.KW, S.C, S.F),
O=TensorDef(U, S.N, S.OW, S.F, output=True),
- strides=IndexAttrDef(S.SW),
- dilations=IndexAttrDef(S.DW)):
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1])):
"""Performs 1-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -244,8 +244,8 @@ def conv_2d_nhwc_hwcf(
S.C),
K=TensorDef(T2, S.KH, S.KW, S.C, S.F),
O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs 2-D convolution.
Layout:
@@ -270,8 +270,8 @@ def conv_2d_nhwc_hwcf_q(
IZp=ScalarDef(I32),
KZp=ScalarDef(I32),
O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs 2-D convolution with zero point offsets.
Layout:
@@ -297,8 +297,8 @@ def conv_2d_nchw_fchw(
S.OW * S.SW + S.KW * S.DW),
K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs 2-D convolution.
Layout:
@@ -321,8 +321,8 @@ def conv_3d_ndhwc_dhwcf(
S.OW * S.SW + S.KW * S.DW, S.C),
K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F),
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True),
- strides=IndexAttrDef(S.SD, S.SH, S.SW),
- dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
"""Performs 3-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -341,8 +341,8 @@ def depthwise_conv_1d_nwc_wc(
I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC),
K=TensorDef(T2, S.KW, S.IC),
O=TensorDef(U, S.N, S.OW, S.IC, output=True),
- strides=IndexAttrDef(S.SW),
- dilations=IndexAttrDef(S.DW)):
+ strides=IndexAttrDef(S.SW, default=[1]),
+ dilations=IndexAttrDef(S.DW, default=[1])):
"""Performs depth-wise 1-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -362,8 +362,8 @@ def depthwise_conv_2d_nhwc_hwc(
S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs depth-wise 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -385,8 +385,8 @@ def depthwise_conv_2d_nhwc_hwc_q(
IZp=ScalarDef(I32),
KZp=ScalarDef(I32),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs depth-wise 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -407,8 +407,8 @@ def depthwise_conv_2d_nhwc_hwcm(
S.IC),
K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs depth-wise 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -429,8 +429,8 @@ def depthwise_conv_2d_nhwc_hwcm_q(
IZp=ScalarDef(I32),
KZp=ScalarDef(I32),
O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs depth-wise 2-D convolution.
Numeric casting is performed on the operands to the inner multiply, promoting
@@ -451,8 +451,8 @@ def pooling_nhwc_sum(
S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs sum pooling.
Numeric casting is performed on the input operand, promoting it to the same
@@ -470,8 +470,8 @@ def pooling_nhwc_max(
S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs max pooling.
Numeric casting is performed on the input operand, promoting it to the same
@@ -490,8 +490,8 @@ def pooling_nhwc_max_unsigned(
S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs unsigned max pooling.
Numeric casting is performed on the input operand, promoting it to the same
@@ -510,8 +510,8 @@ def pooling_nchw_max(
S.OW * S.SW + S.KW * S.DW),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs max pooling.
Numeric casting is performed on the input operand, promoting it to the same
@@ -531,8 +531,8 @@ def pooling_nhwc_min(
S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs min pooling.
Numeric casting is performed on the input operand, promoting it to the same
@@ -551,8 +551,8 @@ def pooling_nhwc_min_unsigned(
S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
"""Performs unsigned min pooling.
Numeric casting is performed on the input operand, promoting it to the same
@@ -571,8 +571,8 @@ def pooling_ndhwc_sum(
S.OW * S.SW + S.KW * S.DW, S.C),
K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SD, S.SH, S.SW),
- dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
"""Performs 3D sum pooling.
Numeric casting is performed on the input operand, promoting it to the same
@@ -591,8 +591,8 @@ def pooling_ndhwc_max(
S.OW * S.SW + S.KW * S.DW, S.C),
K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SD, S.SH, S.SW),
- dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
"""Performs 3D max pooling.
Numeric casting is performed on the input operand, promoting it to the same
@@ -612,8 +612,8 @@ def pooling_ndhwc_min(
S.OW * S.SW + S.KW * S.DW, S.C),
K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]),
O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SD, S.SH, S.SW),
- dilations=IndexAttrDef(S.DD, S.DH, S.DW)):
+ strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+ dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])):
"""Performs 3D min pooling.
Numeric casting is performed on the input operand, promoting it to the same
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index e9b9656bba493..8de70c2ce8758 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -97,19 +97,12 @@ func @depthwise_conv_2d_nhwc_hwcm_memref_dilated(%input: memref<2x8x9x2xf32>, %f
// -----
-func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
- // expected-error @+1 {{missing indexing map required attribute 'strides'}}
- linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>}
- ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
- outs(%output: memref<1x56x56x96xf32>)
- return
-}
-
-// -----
-
-func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
- // expected-error @+1 {{missing indexing map required attribute 'dilations'}}
- linalg.depthwise_conv_2d_nhwc_hwc {strides = dense<1> : vector<2xi64>}
+// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_default_attributes
+func @depthwise_conv_2d_input_nhwc_filter_default_attributes(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
+ // CHECK: linalg.depthwise_conv_2d_nhwc_hwc
+ // CHECK-NOT: strides =
+ // CHECK-NOT: dilations =
+ linalg.depthwise_conv_2d_nhwc_hwc
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
@@ -118,7 +111,7 @@ func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113
// -----
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
- // expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}}
+ // expected-error @+1 {{incorrect element type for index attribute 'strides'}}
linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
@@ -128,7 +121,7 @@ func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memr
// -----
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
- // expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}}
+ // expected-error @+1 {{incorrect shape for index attribute 'strides'}}
linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> }
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
@@ -566,7 +559,7 @@ func @conv_interface_wrong_input_indexing_map(
%arg0 : tensor<?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
// expected-error @+1 {{unexpected input index map for convolutions}}
%0 = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ({
- ^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
+ ^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
%1 = "arith.mulf"(%arg3, %arg4) : (f32, f32) -> f32
%2 = "arith.addf"(%arg5, %1) : (f32, f32) -> f32
"linalg.yield"(%2) : (f32) -> ()
@@ -583,7 +576,7 @@ func @conv_interface_wrong_num_operands(
%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?x?xf32>, %arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
// expected-error @+1 {{expected output/filter indexing maps to be projected permutations}}
%0 = "linalg.conv_2d_nhwc_hwcf"(%arg0, %arg1, %arg2) ({
- ^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
+ ^bb0(%arg3: f32, %arg4: f32, %arg5 : f32):
%1 = "arith.mulf"(%arg3, %arg4) : (f32, f32) -> f32
%2 = "arith.addf"(%arg5, %1) : (f32, f32) -> f32
"linalg.yield"(%2) : (f32) -> ()
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
index ee36510aaf004..347923825cee3 100644
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml
@@ -21,7 +21,7 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: T
shape_map: affine_map<()[s0, s1] -> (s0, s1)>
indexing_maps: !LinalgIndexingMapsConfig
@@ -95,7 +95,7 @@ structured_op: !LinalgStructuredOpConfig
# @linalg_structured_op
# def test2(I=TensorDef(T, S.M, S.N),
# O=TensorDef(T, S.M, S.N, output=True),
-# strides=IndexAttrDef(S.SM, S.SN)):
+# strides=IndexAttrDef(S.SM, S.SN, default=[1, 2])):
# """Title.
# Detailed description.
@@ -114,19 +114,21 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
- usage: InputOperand
+ usage: Input
type_var: T
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: T
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
- !LinalgOperandDefConfig
name: strides
- usage: IndexAttribute
- type_var: I64
- attribute_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
+ usage: IndexAttr
+ index_attr_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
+ default_vals:
+ - 1
+ - 2
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>
@@ -145,7 +147,8 @@ structured_op: !LinalgStructuredOpConfig
# ODS: let arguments =
# ODS-NEXT: Variadic<AnyType>:$inputs,
# ODS-NEXT: Variadic<AnyShaped>:$outputs,
-# ODS-NEXT: RankedI64ElementsAttr<[2]>:$strides
+# ODS-NEXT: DefaultValuedAttr<RankedI64ElementsAttr<[2]>
+# ODS-SAME: "{ static_cast<int64_t>(1), static_cast<int64_t>(2) }">:$strides
# ODS: "Attribute":$strides
# ODS: $_state.addAttribute("strides", strides);
@@ -169,8 +172,8 @@ structured_op: !LinalgStructuredOpConfig
# IMPL: Test2Op::hasDynamicIndexingMaps() { return true; }
# IMPL: Test2Op::verifyIndexingMapRequiredAttributes()
# IMPL: auto attr = op->getAttrOfType<DenseElementsAttr>("strides")
-# IMPL: "missing indexing map required attribute 'strides'"
-
+# IMPL: "incorrect element type for index attribute 'strides'"
+# IMPL: "incorrect shape for index attribute 'strides'"
# IMPL: void Test2Op::regionBuilder(ImplicitLocOpBuilder &b, Block &block)
# IMPL-NEXT: assert(2 > 0 && block.getNumArguments() == 2 &&
@@ -197,11 +200,11 @@ structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: value
- usage: InputOperand
+ usage: Input
type_var: T1
- !LinalgOperandDefConfig
name: O
- usage: OutputOperand
+ usage: Output
type_var: U
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index 053637582038f..ab8f0d010b683 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -7,15 +7,15 @@
# CHECK-LABEL: matmul
# CHECK: args:
# CHECK: name: A
-# CHECK: usage: InputOperand
+# CHECK: usage: Input
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
# CHECK: name: B
-# CHECK: usage: InputOperand
+# CHECK: usage: Input
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
# CHECK: name: C
-# CHECK: usage: OutputOperand
+# CHECK: usage: Output
# CHECK: type_var: U
# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
@linalg_structured_op
@@ -30,7 +30,7 @@ def matmul(
# CHECK-LABEL: fill
# CHECK: args:
# CHECK: name: value
-# CHECK: usage: InputOperand
+# CHECK: usage: Input
# CHECK-NOT: shape_map:
# CHECK: type_var: T
@linalg_structured_op
@@ -42,20 +42,22 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
# CHECK-LABEL: strided_copy
# CHECK: args:
# CHECK: name: I
-# CHECK: usage: InputOperand
+# CHECK: usage: Input
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
# CHECK: name: O
-# CHECK: usage: OutputOperand
+# CHECK: usage: Output
# CHECK: type_var: T
# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
# CHECK: name: strides
-# CHECK: usage: IndexAttribute
-# CHECK: type_var: I64
-# CHECK: attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
+# CHECK: usage: IndexAttr
+# CHECK: index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
+# CHECK: default_vals:
+# CHECK: - 1
+# CHECK: - 2
@linalg_structured_op
def strided_copy(
I=TensorDef(T, S.IH, S.IW),
O=TensorDef(T, S.OH, S.OW, output=True),
- strides=IndexAttrDef(S.SH, S.SW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 2])):
O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
index 9736fa529eeb1..e15424cfac4e7 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py
@@ -16,8 +16,8 @@ def conv_poly(
I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
K=TensorDef(T2, S.KH, S.KW, S.C),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
@@ -51,8 +51,9 @@ def conv_poly(
RankedTensorType.get((2, 2, 1), f32),
RankedTensorType.get((1, 2, 4, 1), i32))
def test_f32i32_conv(input, filter, init_result):
+ # Use default dilations and set non-default strides.
return conv_poly(
- input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+ input, filter, outs=[init_result], strides=[2, 4])
print(module)
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
index cf10c9d3f1a2a..35ec8540cb4b2 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
@@ -16,8 +16,8 @@ def pooling_max_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
TypeFn.cast(
@@ -29,8 +29,8 @@ def pooling_max_unsigned_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
TypeFn.cast_unsigned(
@@ -42,8 +42,8 @@ def pooling_min_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
TypeFn.cast(
@@ -55,8 +55,8 @@ def pooling_min_unsigned_poly(
I=TensorDef(T1, S.N, S.H, S.W, S.C),
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
- strides=IndexAttrDef(S.SH, S.SW),
- dilations=IndexAttrDef(S.DH, S.DW)):
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
TypeFn.cast_unsigned(
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 5be00f8df3334..644952ef3239d 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -132,7 +132,7 @@ def log(*args):
%c2 = arith.constant 2 : index
memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
- memref.store %v-13, %input[%c0, %c0, %c2, %c0] : memref<1x4x16x1xf64>
+ memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64>
call @pooling_on_buffers(%input, %shape, %output) :
(memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()
@@ -421,9 +421,13 @@ def test_min_pooling_builtin():
@builtin.FuncOp.from_py_func(
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
MemRefType.get((1, 2, 4, 1), i32))
+ # Set the strides and use the default dilations.
def pooling_on_buffers(input, shape, output):
linalg.pooling_nhwc_min(
- input, shape, outs=[output], strides=[2, 4], dilations=[1, 2])
+ input,
+ shape,
+ outs=[output],
+ strides=[2, 4])
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
@@ -451,13 +455,13 @@ def test_min_pooling_generic():
@builtin.FuncOp.from_py_func(
MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2), f64),
MemRefType.get((1, 2, 4, 1), i32))
+ # Set the strides and use the default dilations.
def pooling_on_buffers(input, shape, output):
linalg.pooling_nhwc_min(
input,
shape,
outs=[output],
strides=[2, 4],
- dilations=[1, 2],
emit_generic=True)
execution_engine = ExecutionEngine(transform(module, pooling_boiler))
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index d5d8ba6e0db12..f1fac9f578612 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -61,14 +61,15 @@ struct SerializedAffineMap {
AffineMap affineMap() { return affineMapAttr.getValue(); }
};
-enum class LinalgOperandDefUsage { input, output, attribute };
+enum class LinalgOperandDefUsage { Input, Output, IndexAttr };
struct LinalgOperandDef {
std::string name;
LinalgOperandDefUsage usage;
- std::string typeVar;
+ Optional<std::string> typeVar;
Optional<SerializedAffineMap> shapeMap;
- Optional<SerializedAffineMap> attributeMap;
+ Optional<SerializedAffineMap> indexAttrMap;
+ Optional<SmallVector<int64_t>> defaultVals;
};
enum class LinalgIteratorTypeDef {
@@ -175,18 +176,21 @@ struct MappingTraits<LinalgStructuredOpConfig> {
/// the argument. Only tensor arguments have a `shape_map`. Each shape must
/// be normalized over the same list of symbols and have no dimension
/// inputs.
-/// - `attribute_map`: An optional AffineMap from all op symbols to the
-/// attribute symbols. During op creation these symbols are replaced by the
-/// corresponding `name` attribute values. Only attribute arguments have
-/// an `attribute_map`.
+/// - `index_attr_map`: An optional AffineMap from all op symbols to the
+/// index attribute symbols. During op creation these symbols are replaced
+///     by the corresponding `name` index attribute values. Only index attribute
+/// arguments have an `index_attr_map`.
+/// - `default_vals`: An optional default initialization for index attribute
+/// arguments.
template <>
struct MappingTraits<LinalgOperandDef> {
static void mapping(IO &io, LinalgOperandDef &info) {
io.mapRequired("name", info.name);
io.mapRequired("usage", info.usage);
- io.mapRequired("type_var", info.typeVar);
+ io.mapOptional("type_var", info.typeVar);
io.mapOptional("shape_map", info.shapeMap);
- io.mapOptional("attribute_map", info.attributeMap);
+ io.mapOptional("index_attr_map", info.indexAttrMap);
+ io.mapOptional("default_vals", info.defaultVals);
}
};
@@ -194,9 +198,9 @@ struct MappingTraits<LinalgOperandDef> {
template <>
struct ScalarEnumerationTraits<LinalgOperandDefUsage> {
static void enumeration(IO &io, LinalgOperandDefUsage &value) {
- io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input);
- io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output);
- io.enumCase(value, "IndexAttribute", LinalgOperandDefUsage::attribute);
+ io.enumCase(value, "Input", LinalgOperandDefUsage::Input);
+ io.enumCase(value, "Output", LinalgOperandDefUsage::Output);
+ io.enumCase(value, "IndexAttr", LinalgOperandDefUsage::IndexAttr);
}
};
@@ -395,7 +399,10 @@ findTypeValue(StringRef typeVar, SmallVectorImpl<LinalgOperandDef> &args) {
// Search all argument types.
for (const auto &it : llvm::enumerate(args)) {
- if (it.value().typeVar == typeVar)
+ if (it.value().usage != LinalgOperandDefUsage::Input &&
+ it.value().usage != LinalgOperandDefUsage::Output)
+ continue;
+ if (it.value().typeVar.getValue() == typeVar)
return llvm::formatv("block.getArgument({0}).getType()", it.index())
.str();
}
@@ -674,20 +681,32 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
// Assemble the attribute specific logic required for the op definition.
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
- return arg.usage == LinalgOperandDefUsage::attribute;
+ return arg.usage == LinalgOperandDefUsage::IndexAttr;
})) {
SmallVector<std::string> attrDefs;
SmallVector<std::string> attrParams;
SmallVector<std::string> attrStmts;
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
- if (arg.usage != LinalgOperandDefUsage::attribute)
+ if (arg.usage != LinalgOperandDefUsage::IndexAttr)
continue;
- assert(arg.attributeMap.hasValue() && arg.typeVar == "I64");
- static const char defFmt[] = "RankedI64ElementsAttr<[{0}]>:${1}";
+ assert(arg.indexAttrMap.hasValue());
+ assert(arg.defaultVals.hasValue());
+ size_t size = arg.indexAttrMap->affineMap().getNumResults();
+ assert(arg.defaultVals.getValue().size() == size);
+ static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>";
+ static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}";
static const char paramFmt[] = "\"Attribute\":${0}";
static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});";
- attrDefs.push_back(llvm::formatv(
- defFmt, arg.attributeMap->affineMap().getNumResults(), arg.name));
+ std::string defaultVals;
+ llvm::raw_string_ostream ss(defaultVals);
+ ss << "{ ";
+ llvm::interleave(
+ arg.defaultVals.getValue(), ss,
+ [&](int64_t val) { ss << "static_cast<int64_t>(" << val << ")"; },
+ ", ");
+ ss << " }";
+ attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size),
+ ss.str(), arg.name));
attrParams.push_back(llvm::formatv(paramFmt, arg.name));
attrStmts.push_back(llvm::formatv(stmtFmt, arg.name));
}
@@ -725,7 +744,7 @@ generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
// Compute the number of scalar and tensor arguments.
int64_t numOfArgs =
llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
- return arg.usage != LinalgOperandDefUsage::attribute;
+ return arg.usage != LinalgOperandDefUsage::IndexAttr;
});
// An operation that accesses only scalars and scalar/rank zero tensors is
@@ -796,11 +815,11 @@ exprs.push_back(getAffineConstantExpr(cst{1}, context));
)FMT";
// Update all symbol bindings mapped to an attribute.
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
- if (arg.usage != LinalgOperandDefUsage::attribute)
+ if (arg.usage != LinalgOperandDefUsage::IndexAttr)
continue;
- assert(arg.attributeMap.hasValue());
+ assert(arg.indexAttrMap.hasValue());
for (auto &en :
- llvm::enumerate(arg.attributeMap->affineMap().getResults())) {
+ llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) {
if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
symbolBindings[symbol.getPosition()] =
llvm::formatv(structuredOpAccessAttrFormat, arg.name,
@@ -889,31 +908,26 @@ std::string {0}::getLibraryCallName() {{
// hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes()
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
- return arg.usage == LinalgOperandDefUsage::attribute;
+ return arg.usage == LinalgOperandDefUsage::IndexAttr;
})) {
std::vector<std::string> attrVerifications;
for (LinalgOperandDef &arg : opConfig.structuredOp->args) {
- if (arg.usage != LinalgOperandDefUsage::attribute)
+ if (arg.usage != LinalgOperandDefUsage::IndexAttr)
continue;
- assert(arg.attributeMap.hasValue() && arg.typeVar == "I64");
+ assert(arg.indexAttrMap.hasValue());
// Verify index attribute. Paramters:
// {0}: Attribute name
// {1}: Attribute size
static const char attrFmt[] = R"FMT(
if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
if (!attr.getType().getElementType().isInteger(64))
- return op->emitError(
- "incorrect element type for indexing map required attribute '{0}'");
+ return op->emitError("incorrect element type for index attribute '{0}'");
if (attr.getType().getShape() != ArrayRef<int64_t>{{ {1} })
- return op->emitError(
- "incorrect shape for indexing map required attribute '{0}'");
-} else {
- return op->emitError(
- "missing indexing map required attribute '{0}'");
+ return op->emitError("incorrect shape for index attribute '{0}'");
}
)FMT";
attrVerifications.push_back(llvm::formatv(
- attrFmt, arg.name, arg.attributeMap->affineMap().getNumResults()));
+ attrFmt, arg.name, arg.indexAttrMap->affineMap().getNumResults()));
}
// Generates the verifyIndexingMapRequiredAttributes method. Parameters:
@@ -953,7 +967,7 @@ void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {{
int localCounter = 0;
SmallVector<std::string> stmts;
for (LinalgOperandDef &arg : args) {
- if (arg.usage != LinalgOperandDefUsage::output)
+ if (arg.usage != LinalgOperandDefUsage::Output)
continue;
// Find the assignment that correlates with the argument.
More information about the Mlir-commits
mailing list