[Mlir-commits] [mlir] 4361bd9 - [mlir][linalg][python] Explicit shape and dimension order in OpDSL.

Tobias Gysi llvmlistbot at llvm.org
Wed Jun 30 02:17:20 PDT 2021


Author: Tobias Gysi
Date: 2021-06-30T08:59:39Z
New Revision: 4361bd9b7b38c73b69f9a37e52d0b72989e84947

URL: https://github.com/llvm/llvm-project/commit/4361bd9b7b38c73b69f9a37e52d0b72989e84947
DIFF: https://github.com/llvm/llvm-project/commit/4361bd9b7b38c73b69f9a37e52d0b72989e84947.diff

LOG: [mlir][linalg][python] Explicit shape and dimension order in OpDSL.

Extend the OpDSL syntax with an optional `domain` function to specify an explicit dimension order. The extension is needed to provide more control over the dimension order, rather than deducing it implicitly from the formulation of the tensor comprehension. Additionally, the patch ensures the symbols are ordered according to the operand definitions of the operation.

Differential Revision: https://reviews.llvm.org/D105117

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/python/dialects/linalg/opdsl/arguments.py
    mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
    mlir/test/python/dialects/linalg/opdsl/interfaces.py
    mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 82e4d01c4a72c..e536b44fe6fb2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,3 +1,4 @@
+
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
@@ -15,17 +16,17 @@ structured_op: !LinalgStructuredOpConfig
     name: A
     usage: InputOperand
     type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
   - !LinalgOperandDefConfig
     name: B
     usage: InputOperand
     type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
   - !LinalgOperandDefConfig
     name: C
     usage: OutputOperand
     type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
@@ -77,17 +78,17 @@ structured_op: !LinalgStructuredOpConfig
     name: A
     usage: InputOperand
     type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
   - !LinalgOperandDefConfig
     name: B
     usage: InputOperand
     type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
   - !LinalgOperandDefConfig
     name: C
     usage: OutputOperand
     type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
@@ -201,17 +202,17 @@ structured_op: !LinalgStructuredOpConfig
     name: y
     usage: InputOperand
     type_var: T1
-    shape_map: affine_map<()[s0, s1] -> (s1)>
+    shape_map: affine_map<()[s0, s1] -> (s0)>
   - !LinalgOperandDefConfig
     name: A
     usage: InputOperand
     type_var: T2
-    shape_map: affine_map<()[s0, s1] -> (s1, s0)>
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
   - !LinalgOperandDefConfig
     name: x
     usage: OutputOperand
     type_var: U
-    shape_map: affine_map<()[s0, s1] -> (s0)>
+    shape_map: affine_map<()[s0, s1] -> (s1)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1)[s0, s1] -> (d1)>
@@ -321,19 +322,19 @@ structured_op: !LinalgStructuredOpConfig
     usage: InputOperand
     type_var: T1
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s4, s5, s3)>
+      (s0, s1, s2, s3)>
   - !LinalgOperandDefConfig
     name: K
     usage: InputOperand
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s6, s7, s3)>
+      (s4, s5, s3)>
   - !LinalgOperandDefConfig
     name: O
     usage: OutputOperand
     type_var: U
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s1, s2, s3)>
+      (s0, s6, s7, s3)>
   - !LinalgOperandDefConfig
     name: strides
     usage: IndexAttribute
@@ -349,18 +350,18 @@ structured_op: !LinalgStructuredOpConfig
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d0, d1 * s8 + d4 * s10, d2 * s9 + d5 * s11, d3)>
+      s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d4, d5, d3)>
+      s10, s11] -> (d3, d4, d5)>
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d0, d1, d2, d3)>
+      s10, s11] -> (d0, d1, d2, d5)>
   iterator_types:
   - parallel
   - parallel
   - parallel
-  - parallel
   - reduction
   - reduction
+  - parallel
   assignments:
   - !ScalarAssign
     arg: O
@@ -402,45 +403,45 @@ structured_op: !LinalgStructuredOpConfig
     usage: InputOperand
     type_var: T1
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s4, s5, s3)>
+      (s0, s1, s2, s3)>
   - !LinalgOperandDefConfig
     name: K
     usage: InputOperand
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s10, s11)>
+      (s4, s5)>
   - !LinalgOperandDefConfig
     name: O
     usage: OutputOperand
     type_var: U
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s1, s2, s3)>
+      (s0, s6, s7, s3)>
   - !LinalgOperandDefConfig
     name: strides
     usage: IndexAttribute
     type_var: I64
     attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
-      -> (s6, s7)>
+      -> (s8, s9)>
   - !LinalgOperandDefConfig
     name: dilations
     usage: IndexAttribute
     type_var: I64
     attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
-      -> (s8, s9)>
+      -> (s10, s11)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d2, d3 * s6 + d0 * s8, d4 * s7 + d1 * s9, d5)>
+      s10, s11] -> (d0, d1 * s8 + d3 * s10, d2 * s9 + d4 * s11, d5)>
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d0, d1)>
+      s10, s11] -> (d3, d4)>
     - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
-      s10, s11] -> (d2, d3, d4, d5)>
+      s10, s11] -> (d0, d1, d2, d5)>
   iterator_types:
-  - reduction
-  - reduction
   - parallel
   - parallel
   - parallel
+  - reduction
+  - reduction
   - parallel
   assignments:
   - !ScalarAssign

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index e89885e975d65..1f9230de397a2 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -32,13 +32,13 @@ def visit_tensor_exprs(self, callback):
     """Visits all tensor expression reachable by the expression."""
     callback(self)
 
-  def _get_all_dim_defs(self) -> Set[DimDef]:
-    """Recursively gets all DimDef affine expressions that are referenced."""
+  def collect_dim_uses(self, uses: Set["DimDef"]):
+    """Collects all DimDefs reachable through this expression."""
     results = set()
 
     def visit_dim_def(dim_def):
       if isinstance(dim_def, DimDef):
-        results.add(dim_def)
+        uses.add(dim_def)
 
     def visit_affine_exprs(expr):
       if isinstance(expr, TensorUse):
@@ -49,7 +49,6 @@ def visit_affine_exprs(expr):
           ind.visit_affine_exprs(visit_dim_def)
 
     self.visit_tensor_exprs(visit_affine_exprs)
-    return results
 
   def collect_tensor_uses(self, uses: Set["TensorUse"]):
     """Collects all TensorUses reachable through this expression."""
@@ -126,8 +125,10 @@ def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
     reduced into. Any indices referenced on the rhs and not in self are
     considered reduction dims and will be ordered as encountered on the rhs.
     """
-    rhs_dims = rhs._get_all_dim_defs()
-    lhs_dims = self._get_all_dim_defs()
+    rhs_dims = set()
+    lhs_dims = set()
+    rhs.collect_dim_uses(rhs_dims)
+    self.collect_dim_uses(lhs_dims)
     return rhs_dims - lhs_dims
 
   def __repr__(self):
@@ -202,7 +203,7 @@ def __init__(self,
                        f"number of index_dims {len(index_dims)}")
     if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
       raise ValueError(f"TensorDef requires index dims of type DimDef but "
-                       f"got {type(index_dims)}")
+                       f"got {index_dims}")
     kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
     self.operand_def = OperandDef(
         kind, type_var, size_exprs=shape, index_dims=index_dims)
@@ -273,7 +274,7 @@ class AttributeDef:
   def __init__(self, *sizes: SymbolDef):
     if any(not isinstance(size, SymbolDef) for size in sizes):
       raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got "
-                       f"{type(sizes)}")
+                       f"{sizes}")
     self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)
 
 
@@ -516,6 +517,7 @@ def __init__(self,
     self.metadata = OpMetadataDef(
         name=name, cpp_class_name=cpp_class_name, doc=doc)
     self.registered_operands = dict()  # type: Dict[str, OperandDef]
+    self.domain = list()  # type: List[DimDef]
     self.comprehensions = list()  # type: List[Comprehension]
     self._affine_state = AffineBuildState()
 

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index 78e6f1d6a3083..f6d5248ea00fb 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -115,6 +115,7 @@ class LinalgStructuredOpConfig(YAMLObject):
 
   def __init__(self,
                comprehension: Comprehension,
+               domain: Sequence[DimDef],
                registered_operands: Sequence[OperandDef],
                context: Optional[_ir.Context] = None):
     self.context = context if context is not None else _ir.Context()
@@ -123,10 +124,11 @@ def __init__(self,
     self.operands = dict()  # type: Dict[OperandDef, OperandDefConfig]
     self.uses = dict()  # type: Dict[TensorUse, TensorUseConfig]
 
-    # Compute the ordered set of writes and collect the tensor, capture, and
-    # index uses.
+    # Compute the ordered set of writes and collect the tensor, capture, dims,
+    # and index uses.
     collected_tensor_uses = set()
     collected_scalar_uses = set()
+    collected_dim_uses = set()
     collected_indices = set()
     for write_use, read_use in zip(comprehension.definitions,
                                    comprehension.values):
@@ -136,8 +138,28 @@ def __init__(self,
       collected_tensor_uses.add(write_use)
       read_use.collect_tensor_uses(collected_tensor_uses)
       read_use.collect_scalar_uses(collected_scalar_uses)
+      read_use.collect_dim_uses(collected_dim_uses)
+      write_use.collect_dim_uses(collected_dim_uses)
       read_use.collect_indices(collected_indices)
 
+    # Set domain to the sorted list of uses if no domain annotation is given.
+    if not domain:
+      domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname)
+
+    # Verify the domain dimensions match the used dimensions.
+    if (len(domain) != len(collected_dim_uses) or
+        any(dim not in collected_dim_uses for dim in domain)):
+      raise ValueError(f"Expected the annotated domain dimensions {domain} to "
+                       f"match the set of dimension used by the tensor "
+                       f"comprehension {collected_dim_uses}")
+
+    # Instantiate the dimensions in the given order.
+    with self.context:
+      local_state = AffineBuildState(
+          global_state=self.affine_state, allow_new_symbols=False)
+      for dim in domain:
+        dim.build(state=local_state)
+
     # Collect all attribute definitions.
     collected_attr_defs = list()
     for operand in registered_operands:
@@ -148,18 +170,32 @@ def __init__(self,
     collected_index_defs = list()
     for operand in registered_operands:
       if operand.index_dims:
+        if any(dim not in collected_dim_uses for dim in operand.index_dims):
+          raise ValueError(f"Expected all index dims {operand.index_dims} of "
+                           f"operand {operand.name} to have uses.")
         collected_index_defs.append(operand)
 
-    # Add all definitions before uses, so process twice.
+    # Collect the operand definitions of all tensor/scalar uses, attributes, and
+    # shape-only tensors.
+    all_operand_defs = list()
     for use in collected_tensor_uses:
-      self.add_operand(use.operand_def)
+      all_operand_defs.append(use.operand_def)
     for use in collected_scalar_uses:
-      self.add_operand(use.operand_def)
+      all_operand_defs.append(use.operand_def)
     for definition in collected_attr_defs:
-      self.add_operand(definition)
+      all_operand_defs.append(definition)
+    for definition in collected_index_defs:
+      all_operand_defs.append(definition)
+
+    # Add all operands in registration order to ensure the symbols are
+    # registered in the order they appear.
+    all_operand_defs = sorted(
+        all_operand_defs, key=lambda operand_def: operand_def.registered_index)
+    for operand_def in all_operand_defs:
+      self.add_operand(operand_def)
+
+    # Add all shape-only tensor index_dim annotations and all tensor uses.
     for definition in collected_index_defs:
-      if definition not in self.operands:
-        self.add_operand(definition)
       self.add_indexed_operand(definition)
     for use in collected_tensor_uses:
       self.add_tensor_use(use)
@@ -396,7 +432,7 @@ def from_linalg_op_def(
         LinalgOpConfig(
             tc_op_def.metadata,
             structured_op=LinalgStructuredOpConfig(
-                tc_op_def.comprehensions[0],
+                tc_op_def.comprehensions[0], tc_op_def.domain,
                 tc_op_def.registered_operands.values(), context)),
     ]
 

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 6dbda1bb7ecbe..1b42b57670448 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -132,3 +132,11 @@ def linalg_structured_op(dsl_func=None,
 
 def implements(*interfaces: OpInterfaceDef):
   current_op_def().metadata.implements.extend(interfaces)
+
+
+def domain(*dimensions: DimDef):
+  if current_op_def().domain:
+    raise ValueError(f"Expected only one set of domain dimensions per operator")
+  if any(not isinstance(dim, DimDef) for dim in dimensions):
+    raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
+  current_op_def().domain.extend(dimensions)

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 253fca4b41690..5867109279aa4 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -16,6 +16,7 @@ def matmul(
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
   """
+  domain(D.m, D.n, D.k)
   implements(ContractionOpInterface)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
@@ -30,6 +31,7 @@ def batch_matmul(
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
   """
+  domain(D.b, D.m, D.n, D.k)
   implements(ContractionOpInterface)
   C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n])
 
@@ -44,6 +46,7 @@ def matvec(
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
   """
+  domain(D.m, D.n)
   implements(ContractionOpInterface)
   x[D.m] += cast(U, A[D.m, D.n]) * cast(U, y[D.n])
 
@@ -58,6 +61,7 @@ def vecmat(
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
   """
+  domain(D.n, D.m)
   implements(ContractionOpInterface)
   x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n])
 
@@ -86,6 +90,7 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
   Numeric casting is performed on the operands to the inner multiply, promoting
   them to the same data type as the accumulator/output.
   """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
   O[D.n, D.oh, D.ow, D.c] += cast(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
            D.c]) * cast(U, K[D.kh, D.kw, D.c])
@@ -103,6 +108,7 @@ def pooling_nhwc_sum_poly(
   Numeric casting is performed on the input operand, promoting it to the same
   data type as the accumulator/output.
   """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
   O[D.n, D.oh, D.ow, D.c] += cast(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
 
@@ -123,6 +129,7 @@ def fill_rng_2d(
   element seed the random number generation. The min and max operands limit
   the range of the generated random numbers.
   """
+  domain(D.m, D.n)
   multiplier = cast(I32, const(1103515245))
   increment = cast(I32, const(12345))
   rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment

diff  --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index a70e3cdeca99b..572c811d93a4b 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -9,15 +9,15 @@
 # CHECK:     name: A
 # CHECK:     usage: InputOperand
 # CHECK:     type_var: T
-# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
 # CHECK:     name: B
 # CHECK:     usage: InputOperand
 # CHECK:     type_var: T
-# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
 # CHECK:     name: C
 # CHECK:     usage: OutputOperand
 # CHECK:     type_var: U
-# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
 @linalg_structured_op
 def matmul(
     A=TensorDef(T, S.M, S.K),
@@ -44,11 +44,11 @@ def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
 # CHECK:     name: I
 # CHECK:     usage: InputOperand
 # CHECK:     type_var: T
-# CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
+# CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
 # CHECK:     name: O
 # CHECK:     usage: OutputOperand
 # CHECK:     type_var: T
-# CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
+# CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
 # CHECK:     name: strides
 # CHECK:     usage: IndexAttribute
 # CHECK:     type_var: I64
@@ -58,4 +58,4 @@ def strided_copy(
     I=TensorDef(T, S.IH, S.IW),
     O=TensorDef(T, S.OH, S.OW, output=True),
     strides=AttributeDef(S.SH, S.SW)):
-  O[D.oh, D.ow] = I[D.h * S.SH, D.w * S.SW]
+  O[D.oh, D.ow] = I[D.oh * S.SH, D.ow * S.SW]

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index cbe88dd043f73..f7db532dced5c 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -16,6 +16,7 @@ def matmul_mono(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(T, S.M, S.N, output=True)):
+  domain(D.m, D.n, D.k)
   C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
 
 
@@ -24,6 +25,7 @@ def matmul_poly(
     A=TensorDef(T1, S.M, S.K),
     B=TensorDef(T2, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True)):
+  domain(D.m, D.n, D.k)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
@@ -34,6 +36,7 @@ def conv_poly(
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
     strides=AttributeDef(S.SH, S.SW),
     dilations=AttributeDef(S.DH, S.DW)):
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
   O[D.n, D.oh, D.ow, D.c] += cast(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
            D.c]) * cast(U, K[D.kh, D.kw, D.c])
@@ -46,6 +49,7 @@ def pooling_poly(
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
     strides=AttributeDef(S.SH, S.SW),
     dilations=AttributeDef(S.DH, S.DW)):
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
   O[D.n, D.oh, D.ow, D.c] += cast(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
 
@@ -84,14 +88,12 @@ def fill_rng_poly(
     # CHECK: #[[$MUL_MAP_C:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
     # Convolution indexing maps.
-    # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)>
-    # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
-    # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+    # CHECK: #[[$CONV_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d3, d2 * 4 + d4 * 2, d5)>
+    # CHECK: #[[$CONV_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+    # CHECK: #[[$CONV_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>
 
     # Pooling indexing maps.
-    # CHECK: #[[$POOL_MAP_I:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3 * 2 + d0, d4 * 4 + d1 * 2, d5)>
-    # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>
-    # CHECK: #[[$POOL_MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
+    # CHECK: #[[$POOL_MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>
 
     # CHECK-LABEL: func @test_matmul_mono
     # CHECK-SAME:  %[[A:.+]]: tensor<4x16xf32>
@@ -197,7 +199,7 @@ def test_f64f64f32_matmul(lhs, rhs, init_result):
     # CHECK-LABEL: @test_f32i32_conv
     # CHECK: linalg.generic
     # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$CONV_MAP_K]], #[[$CONV_MAP_O]]]
-    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
+    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
     # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
     # CHECK-NEXT:   %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
     # CHECK-NEXT:   %[[FILTER_CAST:.+]] = fptosi %[[FILTER:.+]] : f32 to i32
@@ -215,8 +217,8 @@ def test_f32i32_conv(input, filter, init_result):
 
     # CHECK-LABEL: @test_f32i32_pooling
     # CHECK: linalg.generic
-    # CHECK-SAME: indexing_maps = [#[[$POOL_MAP_I]], #[[$POOL_MAP_K]], #[[$POOL_MAP_O]]]
-    # CHECK-SAME: iterator_types = ["reduction", "reduction", "parallel", "parallel", "parallel", "parallel"]
+    # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
+    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
     # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
     # CHECK-NEXT:   %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
     # CHECK-NEXT:   %[[SUM:.+]] = addi %[[OUT]], %[[IN_CAST]] : i32

diff  --git a/mlir/test/python/dialects/linalg/opdsl/interfaces.py b/mlir/test/python/dialects/linalg/opdsl/interfaces.py
index 46689a07bbbb9..6d75bfcbeefd4 100644
--- a/mlir/test/python/dialects/linalg/opdsl/interfaces.py
+++ b/mlir/test/python/dialects/linalg/opdsl/interfaces.py
@@ -2,13 +2,15 @@
 
 from mlir.dialects.linalg.opdsl.lang import *
 
+
 # CHECK: ---
 # CHECK-LABEL: matmul
 # CHECK:      implements:
 # CHECK-NEXT: - LinalgContractionOpInterface
 @linalg_structured_op
-def matmul(A=TensorDef(T, S.M, S.K),
-           B=TensorDef(T, S.K, S.N),
-           C=TensorDef(U, S.M, S.N, output=True)):
+def matmul(
+    A=TensorDef(T, S.M, S.K),
+    B=TensorDef(T, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True)):
   implements(ContractionOpInterface)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])

diff  --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
index 2933852f97cfe..fbb82f79f6d9e 100644
--- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
+++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
@@ -7,9 +7,9 @@
 # dims auto discovered emits the right shape, indexing maps and iterator types.
 # CHECK: ---
 # CHECK-LABEL: matmul
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
 # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
 # CHECK: static_indexing_maps:
 # CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
 # CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
@@ -23,6 +23,7 @@ def matmul(
     A=TensorDef(T, S.M, S.K),
     B=TensorDef(T, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True)):
+  domain(D.m, D.n, D.k)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
@@ -43,22 +44,25 @@ def matmul(
 def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)):
   C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
 
+
 # Verifies that the index_dims of shape-only operands translate to correct
 # indexing maps.
 # CHECK: ---
 # CHECK-LABEL: pool
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0)>
 # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s1)>
 # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2)>
-# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0)>
 # CHECK: static_indexing_maps:
-# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1 * 2 + d0)>
-# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0)>
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0 * 2 + d1)>
 # CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d1)>
+# CHECK-NEXT: - affine_map<(d0, d1)[s0, s1, s2] -> (d0)>
 # CHECK: iterator_types:
-# CHECK-NEXT: - reduction
 # CHECK-NEXT: - parallel
+# CHECK-NEXT: - reduction
 @linalg_structured_op
-def pool(I=TensorDef(T, S.I),
-         K=TensorDef(T, S.K, index_dims=[D.k]),
-         O=TensorDef(U, S.O, output=True)):
+def pool(
+    I=TensorDef(T, S.I),
+    K=TensorDef(T, S.K, index_dims=[D.k]),
+    O=TensorDef(U, S.O, output=True)):
+  domain(D.o, D.k)
   O[D.o] += cast(U, I[D.o * 2 + D.k])


        


More information about the Mlir-commits mailing list