[Mlir-commits] [mlir] 31f888e - [mlir][linalg][python] Add attribute support to the OpDSL.

Tobias Gysi llvmlistbot at llvm.org
Thu Jun 24 02:41:41 PDT 2021


Author: Tobias Gysi
Date: 2021-06-24T09:40:32Z
New Revision: 31f888ea9af452ae312c270e569d9fbe23c57c9f

URL: https://github.com/llvm/llvm-project/commit/31f888ea9af452ae312c270e569d9fbe23c57c9f
DIFF: https://github.com/llvm/llvm-project/commit/31f888ea9af452ae312c270e569d9fbe23c57c9f.diff

LOG: [mlir][linalg][python] Add attribute support to the OpDSL.

Extend the OpDSL with index attributes. After tensors and scalars, index attributes are the third operand type. An index attribute represents a compile-time constant that is limited to index expressions. Use cases include the strides and dilations defined by convolution and pooling operations.

The patch only updates the OpDSL. The C++ YAML codegen is updated in a follow-up patch.

Differential Revision: https://reviews.llvm.org/D104711

Added: 
    

Modified: 
    mlir/include/mlir-c/AffineMap.h
    mlir/lib/Bindings/Python/IRAffine.cpp
    mlir/lib/CAPI/IR/AffineMap.cpp
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/python/dialects/linalg/opdsl/arguments.py
    mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
    mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
    mlir/test/python/dialects/linalg/opsrun.py
    mlir/test/python/ir/affine_map.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h
index e35b7cc6b51d5..7359b969127c7 100644
--- a/mlir/include/mlir-c/AffineMap.h
+++ b/mlir/include/mlir-c/AffineMap.h
@@ -169,6 +169,13 @@ mlirAffineMapGetMajorSubMap(MlirAffineMap affineMap, intptr_t numResults);
 MLIR_CAPI_EXPORTED MlirAffineMap
 mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults);
 
+/// Apply AffineExpr::replace(`map`) to each of the results and return a new
+/// new AffineMap with the new results and the specified number of dims and
+/// symbols.
+MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapReplace(
+    MlirAffineMap affineMap, MlirAffineExpr expression,
+    MlirAffineExpr replacement, intptr_t numResultDims, intptr_t numResultSyms);
+
 /// Returns the simplified affine map resulting from dropping the symbols that
 /// do not appear in any of the individual maps in `affineMaps`.
 /// Asserts that all maps in `affineMaps` are normalized to the same number of

diff  --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 5d3b790b35d0e..0a2a5666a9e47 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -654,6 +654,14 @@ void mlir::python::populateIRAffine(py::module &m) {
                  mlirAffineMapGetMinorSubMap(self, nResults);
              return PyAffineMap(self.getContext(), affineMap);
            })
+      .def("replace",
+           [](PyAffineMap &self, PyAffineExpr &expression,
+              PyAffineExpr &replacement, intptr_t numResultDims,
+              intptr_t numResultSyms) {
+             MlirAffineMap affineMap = mlirAffineMapReplace(
+                 self, expression, replacement, numResultDims, numResultSyms);
+             return PyAffineMap(self.getContext(), affineMap);
+           })
       .def_property_readonly(
           "is_permutation",
           [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })

diff  --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp
index e0c07afc3b75e..85557bc576f61 100644
--- a/mlir/lib/CAPI/IR/AffineMap.cpp
+++ b/mlir/lib/CAPI/IR/AffineMap.cpp
@@ -138,6 +138,15 @@ MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap,
   return wrap(unwrap(affineMap).getMinorSubMap(numResults));
 }
 
+MlirAffineMap mlirAffineMapReplace(MlirAffineMap affineMap,
+                                   MlirAffineExpr expression,
+                                   MlirAffineExpr replacement,
+                                   intptr_t numResultDims,
+                                   intptr_t numResultSyms) {
+  return wrap(unwrap(affineMap).replace(unwrap(expression), unwrap(replacement),
+                                        numResultDims, numResultSyms));
+}
+
 void mlirAffineMapCompressUnusedSymbols(
     MlirAffineMap *affineMaps, intptr_t size, void *result,
     void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)) {

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index fe067d6947138..2b2f57248c515 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -9,6 +9,7 @@
 """
 
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
+from enum import Enum
 
 from mlir import ir as _ir
 
@@ -133,18 +134,31 @@ def __repr__(self):
     return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"
 
 
+class OperandKind(Enum):
+  InputTensor = 0
+  Scalar = 1
+  OutputTensor = 2
+  Attribute = 3
+
+
 class OperandDef:
-  """Definition of a Tensor or Scalar operand passed to an operation."""
+  """Definition of an operand passed to an operation.
+
+  Keep the meta information of Tensor, Scalar, and Attribute operands and
+  provide the shared registration functionality.
+  """
 
-  def __init__(self, type_var: TypeVar, shape: Sequence[AffineExprDef],
-               scalar: bool, output: bool):
+  def __init__(self,
+               kind: OperandKind,
+               type_var: TypeVar,
+               size_exprs: Optional[Sequence[AffineExprDef]] = None):
     if not isinstance(type_var, TypeVar):
-      raise ValueError(f"OperandDef requires a TypeVar. Got: {repr(type_var)}")
+      raise ValueError(
+          f"OperandDef requires a TypeVar but got {repr(type_var)}")
     self.owner = None  # type: Optional["LinalgOpDef"]
     self.type_var = type_var
-    self.shape = shape
-    self.scalar = scalar
-    self.output = output
+    self.size_exprs = size_exprs
+    self.kind = kind
     self.name = None  # type: Optional[str]
     self.registered_index = -1  # type: int
 
@@ -159,10 +173,8 @@ def __hash__(self):
     return hash(id(self))
 
   def __repr__(self):
-    output = "OUTPUT " if self.output else ""
-    scalar = "SCALAR " if self.scalar else ""
-    return (f"{self.name}:OperandDef({output}{scalar}"
-            f"{repr(self.type_var)}, shape={self.shape})")
+    return (f"{self.name}:OperandDef(kind={self.kind.name}, "
+            f"type={repr(self.type_var)}, size_exprs={self.size_exprs})")
 
 
 class TensorDef:
@@ -170,14 +182,17 @@ class TensorDef:
 
   Tensor operands are indexed using the associated indexing_map when forwarded
   to the body of the structured op. A unique name identifies the tensor operands
-  and an index determines their position in the operation's parameter list.
+  and an index determines their position in the operation's parameter list. A
+  tensor definition takes type, a shape, and an optional flag to mark output
+  tensors.
   """
 
   def __init__(self,
                type_var: TypeVar,
                *shape: AffineExprDef,
                output: bool = False):
-    self.operand_def = OperandDef(type_var, shape, False, output)
+    kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
+    self.operand_def = OperandDef(kind, type_var, size_exprs=shape)
 
   def __getitem__(self, dims) -> TensorUse:
     assert self.operand_def.owner, "TensorDef is not attached to an op"
@@ -221,7 +236,7 @@ class ScalarDef(TensorExpression):
   """
 
   def __init__(self, type_var: TypeVar):
-    self.operand_def = OperandDef(type_var, (), True, False)
+    self.operand_def = OperandDef(OperandKind.Scalar, type_var)
 
   @property
   def scalar_name(self) -> str:
@@ -233,6 +248,22 @@ def to_scalar_expression(self) -> ScalarExpression:
     return ScalarArg(self.scalar_name).expr()
 
 
+class AttributeDef:
+  """Index Attribute definition.
+
+  Index attributes provide a way to define and set symbols that can be used in
+  indexing expressions. Every attribute specifies a tuple of symbols that at
+  compile-time are replaced by integer values.
+  """
+  yaml_tag = "!LinalgAttributeDef"
+
+  def __init__(self, *sizes: SymbolDef):
+    if any(not isinstance(size, SymbolDef) for size in sizes):
+      raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got "
+                       f"{type(sizes)}")
+    self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)
+
+
 class Comprehension:
   """Represents a single comprehension."""
 
@@ -303,7 +334,7 @@ class ReduceFnType:
   def __init__(self, operator: PrimFnType, *reduce_dims: DimDef):
     """Initializes the ReduceFn with a primitive function and dims."""
     if not isinstance(operator, PrimFnType):
-      raise ValueError(f"Reduce expected a Prim operator. Got: {operator}")
+      raise ValueError(f"Reduce expected a Prim operator but got {operator}")
     self.operator = operator
     self.reduce_dims = tuple(reduce_dims)
 
@@ -353,7 +384,7 @@ def __init__(self, value: Any):
         self.value = str(
             _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
       else:
-        raise ValueError(f"const requires int or float. Got: {type(value)}")
+        raise ValueError(f"const requires int or float but got {type(value)}")
 
   def to_scalar_expression(self) -> ScalarExpression:
     return ScalarConst(self.value).expr()
@@ -475,21 +506,22 @@ def __init__(self,
     self.comprehensions = list()  # type: List[Comprehension]
     self._affine_state = AffineBuildState()
 
-  @property
-  def outputs(self) -> Sequence[OperandDef]:
-    return [
-        operand for operand in self.registered_operands.values()
-        if operand.output
-    ]
-
   def add_operand(self, name: str, operand: OperandDef):
     """Registers an operand."""
     if name in self.registered_operands:
       raise ValueError(f"The operand {name} is already registered "
                        f"to {self.registered_operands['name']}")
-    if not operand.output and self.outputs:
-      raise ValueError(f"The operand {name} is an input registered after "
-                       f"the output {self.outputs[-1]}")
+    # Ensure output tensors are registered after input tensors and scalars and
+    # attributes are registered after all other operand types.
+    registered_kinds = [
+        operand.kind.value for operand in self.registered_operands.values()
+    ]
+    if registered_kinds:
+      maximum = max(registered_kinds)
+      if maximum > operand.kind.value and maximum > OperandKind.Scalar.value:
+        raise ValueError(
+            f"The operand {name} of kind {operand.kind.name} is registered "
+            f"after an operand of kind {OperandKind(maximum).name}")
     operand.attach(len(self.registered_operands), name, self)
     self.registered_operands[name] = operand
 

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index 6dd86334b95a5..773bd876397f9 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -45,9 +45,11 @@ class OperandDefConfig(YAMLObject):
 
   def __init__(self,
                operand_def: OperandDef,
-               shape_map: Optional[_ir.AffineMap] = None):
+               shape_map: Optional[_ir.AffineMap] = None,
+               attribute_map: Optional[_ir.AffineMap] = None):
     self.operand_def = operand_def
     self.shape_map = shape_map  # type: Optional[_ir.AffineMap]
+    self.attribute_map = attribute_map  # type: Optional[_ir.AffineMap]
     self.indexing_map = None  # type: Optional[_ir.AffineMap]
 
   @property
@@ -60,21 +62,25 @@ def type_var(self) -> TypeVar:
 
   @property
   def usage(self) -> str:
-    if self.operand_def.output:
-      return "output"
-    return "input"
+    if self.operand_def.kind == OperandKind.Attribute:
+      return "IndexAttribute"
+    if self.operand_def.kind == OperandKind.OutputTensor:
+      return "OutputOperand"
+    return "InputOperand"
 
   def to_yaml_custom_dict(self):
-    self_dict = dict(name=self.name)
-    self_dict["usage"] = self.usage
-    if not self.operand_def.scalar:
-      self_dict["shape"] = _serialize_affine_map(self.shape_map)
-    self_dict["type_var"] = self.type_var.name
+    self_dict = dict(
+        name=self.name, usage=self.usage, type_var=self.type_var.name)
+    if self.shape_map:
+      self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
+    if self.attribute_map:
+      self_dict["attribute_map"] = _serialize_affine_map(self.attribute_map)
     return self_dict
 
   def __repr__(self):
     return (f"OperandDefConfig({self.operand_def}, "
-            f"shape_map={self.shape_map}, indexing_map={self.indexing_map})")
+            f"shape_map={self.shape_map}, attribute_map={self.attribute_map}, "
+            f"indexing_map={self.indexing_map})")
 
 
 class LinalgIndexingMapsConfig(YAMLObject):
@@ -109,6 +115,7 @@ class LinalgStructuredOpConfig(YAMLObject):
 
   def __init__(self,
                comprehension: Comprehension,
+               registered_operands: Sequence[OperandDef],
                context: Optional[_ir.Context] = None):
     self.context = context if context is not None else _ir.Context()
     self.affine_state = AffineBuildState()
@@ -131,22 +138,33 @@ def __init__(self,
       read_use.collect_scalar_uses(collected_scalar_uses)
       read_use.collect_indices(collected_indices)
 
-    # Need to add all definitions before uses, so process twice.
+    # Collect all attribute definitions
+    collected_attr_defs = list()
+    for operand in registered_operands:
+      if operand.kind == OperandKind.Attribute:
+        collected_attr_defs.append(operand)
+
+    # Add all definitions before uses, so process twice.
     for use in collected_tensor_uses:
       self.add_operand(use.operand_def)
     for use in collected_scalar_uses:
       self.add_operand(use.operand_def)
+    for definition in collected_attr_defs:
+      self.add_operand(definition)
     for use in collected_tensor_uses:
       self.add_tensor_use(use)
 
-    # Now normalize all defs and uses indexing maps now that full count of
-    # dims and symbols are known.
+    # Normalize all shape and indexing maps now that full count of dims and
+    # symbols are known.
     for cuse in self.uses.values():
       cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
-    for cdef in self.operands.values():
-      if not cdef.operand_def.scalar:
-        cdef.shape_map = self._normalize_affine_map(
-            cdef.shape_map, with_dims=False)
+    for operand_config in self.operands.values():
+      if operand_config.shape_map:
+        operand_config.shape_map = self._normalize_affine_map(
+            operand_config.shape_map, with_dims=False)
+      if operand_config.attribute_map:
+        operand_config.attribute_map = self._normalize_affine_map(
+            operand_config.attribute_map, with_dims=False)
 
     # Now for each write use, propagate the indexing maps from the use to the
     # tensor, ensuring that there are not conflicts.
@@ -174,12 +192,16 @@ def __init__(self,
 
     # Set the indexing map of all scalar uses to the empty map.
     for operand_config in self.operands.values():
-      if operand_config.operand_def.scalar:
-        operand_config.indexing_map = self._create_empty_affine_map()
+      if operand_config.operand_def.kind == OperandKind.Scalar:
+        operand_config.indexing_map = self._get_scalar_map()
 
-    # Sanity check that all defs have an indexing map.
-    assert all(d.indexing_map for d in self.operands.values()), (
-        f"Missing indexing map on OperandConfigDef: {self.operands}")
+    # Check all registered tensor and scalar operands have an indexing map.
+    for operand in registered_operands:
+      if operand.kind == OperandKind.Attribute:
+        continue
+      if not (operand in self.operands and self.operands[operand].indexing_map):
+        raise ValueError(f"Failed to compute an indexing map for operand "
+                         f"{operand.name}")
 
     # Collect reduction dims and ensure all the same.
     all_reduction_dims = set(comprehension.all_reduction_dims)
@@ -189,7 +211,7 @@ def __init__(self,
           f"dims. Got: {all_reduction_dims}")
     self.reduction_dims = next(iter(all_reduction_dims))
 
-    # Check the index dimension exists and resolve
+    # Check the index dimension exists and resolve.
     for index in collected_indices:
       if index.dim_def.dimname not in self.affine_state.all_dims:
         raise ValueError(
@@ -221,7 +243,7 @@ def ordered_dims(self) -> Sequence[Tuple[str, int]]:
 
   @property
   def indexing_maps(self) -> Sequence[_ir.AffineMap]:
-    return [d.indexing_map for d in self.ordered_operands]
+    return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
 
   @property
   def iterator_types(self) -> Sequence[str]:
@@ -237,20 +259,24 @@ def get_type(symbolic_name, position):
   def add_operand(self, operand_def: OperandDef):
     if operand_def in self.operands:
       return
-    if operand_def.scalar:
+    if operand_def.kind == OperandKind.Scalar:
       self.operands[operand_def] = OperandDefConfig(operand_def)
       return
     with self.context:
       local_state = AffineBuildState(
           global_state=self.affine_state, allow_new_dims=False)
       exprs = []
-      for expr in operand_def.shape:
+      for expr in operand_def.size_exprs:
         exprs.append(expr.build(state=local_state))
       assert local_state.local_dim_count == 0
-      shape_map = _ir.AffineMap.get(
+      affine_map = _ir.AffineMap.get(
           dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
-      def_config = OperandDefConfig(operand_def, shape_map)
-      self.operands[operand_def] = def_config
+      if operand_def.kind == OperandKind.Attribute:
+        self.operands[operand_def] = OperandDefConfig(
+            operand_def, attribute_map=affine_map)
+      else:
+        self.operands[operand_def] = OperandDefConfig(
+            operand_def, shape_map=affine_map)
 
   def add_tensor_use(self, tensor_use: TensorUse):
     if tensor_use in self.uses:
@@ -261,7 +287,6 @@ def add_tensor_use(self, tensor_use: TensorUse):
       exprs = []
       for expr in tensor_use.indices:
         exprs.append(expr.build(state=local_state))
-      assert local_state.local_symbol_count == 0
       indexing_map = _ir.AffineMap.get(
           dim_count=local_state.dim_count,
           symbol_count=local_state.symbol_count,
@@ -270,8 +295,8 @@ def add_tensor_use(self, tensor_use: TensorUse):
       use_config = TensorUseConfig(tensor_use, indexing_map)
       self.uses[tensor_use] = use_config
 
-  def _create_empty_affine_map(self) -> _ir.AffineMap:
-    """Create an affine map with an empty range."""
+  def _get_scalar_map(self) -> _ir.AffineMap:
+    """Create an empty affine map used to index a scalar."""
     with self.context:
       return _ir.AffineMap.get(
           dim_count=self.affine_state.dim_count,
@@ -345,8 +370,9 @@ def from_linalg_op_def(
     return [
         LinalgOpConfig(
             tc_op_def.metadata,
-            structured_op=LinalgStructuredOpConfig(tc_op_def.comprehensions[0],
-                                                   context)),
+            structured_op=LinalgStructuredOpConfig(
+                tc_op_def.comprehensions[0],
+                tc_op_def.registered_operands.values(), context)),
     ]
 
   def __repr__(self):

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 191b1b34fd836..6dbda1bb7ecbe 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -44,15 +44,20 @@ def __init__(self, op_name: str, model: LinalgOpDef):
     self.op_name = op_name
     self.model = model
 
-  def __call__(self, *args, emit_generic: bool = False, **kwargs):
+  def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs):
     """Emits the corresponding op definition as IR.
 
     Most arguments are passed through to the underlying emitter. The following
-    are interpreted here:
+    keyword argument is interpreted here:
       emit_generic: Emits a generic form as appropriate (default True). If
         False, a named form is emitted (which must have been built in to the
         compiler).
     """
+    emit_generic = kwargs.pop("emit_generic", False)
+    if not isinstance(emit_generic, bool):
+      raise ValueError(f"The named argument 'emit_generic' needs to be "
+                       f" of type bool but got {type(emit_generic)}")
+
     op_configs = LinalgOpConfig.from_linalg_op_def(
         self.model, context=ir.Context.current)
 
@@ -70,12 +75,16 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs):
     op_config = op_configs[0]
     if op_config.structured_op:
       if emit_generic:
-        return emit_generic_structured_op(op_config.structured_op, *args,
-                                          **kwargs)
+        return emit_generic_structured_op(
+            op_config.structured_op, *ins, outs=outs, **kwargs)
       else:
-        return emit_named_structured_op(op_config.structured_op, self.op_name,
-                                        self.model.metadata.cpp_class_name,
-                                        *args, **kwargs)
+        return emit_named_structured_op(
+            op_config.structured_op,
+            self.op_name,
+            self.model.metadata.cpp_class_name,
+            *ins,
+            outs=outs,
+            **kwargs)
 
     raise NotImplementedError(
         f"Emission of linalg op type not supported: {op_config}")
@@ -104,14 +113,12 @@ def linalg_structured_op(dsl_func=None,
   sig = inspect.signature(dsl_func)
   for param_name, param in sig.parameters.items():
     param_default = param.default
-    if isinstance(param_default, TensorDef):
-      tc_model.add_operand(param_name, param_default.operand_def)
-    elif isinstance(param_default, ScalarDef):
+    if isinstance(param_default, (TensorDef, ScalarDef, AttributeDef)):
       tc_model.add_operand(param_name, param_default.operand_def)
     else:
       raise ValueError(f"@tc_def_op function parameters must be defaulted as "
-                       f"TensorDef(...) or ScalarDef(...): Found {param_name}"
-                       f": {param_default}")
+                       f"TensorDef(...), ScalarDef(...), or AttributeDef(...): "
+                       f"Found {param_name}: {param_default}")
     dsl_func_args.append(param_default)
 
   # Invoke the DSL func to finish populating the model.

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 2b8b910507cec..f6fb0cc7d0d0e 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -13,6 +13,7 @@
 
 from .scalar_expr import *
 from .config import *
+import numpy as np
 
 __all__ = [
     "emit_generic_structured_op",
@@ -29,12 +30,14 @@ def isa(cls: Type, ty: Type):
 
 
 def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
-                                 *ins: Value, outs: Sequence[Value]):
+                                 *ins: Value, outs: Sequence[Value],
+                                 **attrs: Sequence[int]):
   all_arg_defs = op_config.ordered_operands
-  in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
-  out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"]
+  in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"]
+  out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"]
+  attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"]
 
-  # Verify outs and captures are sequences.
+  # Verify outs is a sequence.
   if not isinstance(outs, Sequence):
     raise ValueError(f"Expected named argument outs to have type Sequence "
                      f"but got {type(outs)}")
@@ -47,6 +50,40 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
     raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
                      f"{len(outs)} for {op_config}")
 
+  # Compute a replacement list for all attribute symbols.
+  expressions = []  # type: Sequence[AffineExpr]
+  replacements = []  # type: Sequence[AffineExpr]
+  for attr in attr_arg_defs:
+    if attr.name not in attrs:
+      raise ValueError(f"Expected named argument for the attribute {attr.name}")
+    attribute_values = attrs.get(attr.name)
+    if not all(isinstance(value, int) for value in attribute_values):
+      raise ValueError(f"Attribute {attr.name} needs to be of type "
+                       f"Sequence[int] but got {type(attribute_values)}")
+    results = attr.attribute_map.results  # type: AffineExprList
+    if len(attribute_values) != len(results):
+      raise ValueError(f"Attribute {attr.name} has length {len(results)} "
+                       f"but got {len(attribute_values)} values")
+    for expr, value in zip(results, attribute_values):
+      expressions.append(expr)
+      replacements.append(AffineConstantExpr.get(value))
+
+  # Replace all index attribute symbols by their value.
+  # TODO: Add support for shape symbols.
+  indexing_maps = []  # type: Sequence[AffineMap]
+  for curr in op_config.indexing_maps:
+    for expression, replacement in zip(expressions, replacements):
+      curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
+    indexing_maps.append(curr)
+
+  # TODO: Linalg verification does not currently allow symbols.
+  # Compress them for now and verify none are left.
+  indexing_maps = AffineMap.compress_unused_symbols(indexing_maps,
+                                                    Context.current)
+  if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
+    raise ValueError(f"Expected indexing_maps to use no symbols after "
+                     f"replacement and compression but got {indexing_maps}")
+
   outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
                                            out_arg_defs, outs)
 
@@ -67,27 +104,28 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
 
   # Emit the generic op.
   # TODO: Support emission of pure memref form.
-  indexing_maps_attr = ArrayAttr.get([
-      AffineMapAttr.get(am)
-      # TODO: linalg verification does not currently allow symbols.
-      # Compress them for now.
-      for am in AffineMap.compress_unused_symbols(op_config.indexing_maps,
-                                                  Context.current)
-  ])
+  indexing_maps_attr = ArrayAttr.get(
+      [AffineMapAttr.get(am) for am in indexing_maps])
   iterator_types_attr = ArrayAttr.get(
       [StringAttr.get(s) for s in op_config.iterator_types])
 
+  # Compute a dictionary storing all index attributes.
+  index_attributes = {}  # type: Dict[str, DenseElementAttr]
+  for attr in attr_arg_defs:
+    attribute_values = attrs.get(attr.name)
+    array = np.array(attribute_values, dtype=np.int64)
+    index_attributes[attr.name] = DenseElementsAttr.get(array)
+
   return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
           type_mapping, indexing_maps_attr, iterator_types_attr,
-          block_arg_types)
+          index_attributes, block_arg_types)
 
 
-def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
-                               *ins: Value,
-                               outs: Sequence[Value] = ()):
+def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
+                               outs: Sequence[Value], **attrs: Sequence[int]):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  indexing_maps_attr, iterator_types_attr, block_arg_types = \
-     prepare_common_structured_op(op_config, *ins, outs = outs)
+  indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
+     prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
 
   generic_op = linalg.GenericOp(
       result_tensors=result_types,
@@ -114,14 +152,12 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
     return generic_op.results
 
 
-def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
-                             op_name: str,
-                             op_class_name: str,
-                             *ins: Value,
-                             outs: Sequence[Value] = ()):
+def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
+                             op_class_name: str, *ins: Value,
+                             outs: Sequence[Value], **attrs: Sequence[int]):
   all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
-  indexing_maps_attr, iterator_types_attr, block_arg_types = \
-     prepare_common_structured_op(op_config, *ins, outs = outs)
+  indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
+     prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
 
   # If we get here, there must exist a builtin class `op_class_name`.
   ctx = Context.current
@@ -141,6 +177,10 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
       "linalg.memoized_indexing_maps"] = indexing_maps_attr
   # iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
 
+  # Additionally set all named attributes.
+  for name, value in index_attributes.items():
+    named_op.operation.attributes[name] = value
+
   if len(result_types) == 1:
     return named_op.result
   else:
@@ -304,7 +344,7 @@ def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type,
                       block_arg_types: Sequence[Type]):
   element_or_self_type = operand_type
   # Get the element type for tensor operands and the type itself for scalars.
-  if operand_config.operand_def.shape:
+  if operand_config.shape_map:
     try:
       element_or_self_type = ShapedType(operand_type).element_type
     except Exception as e:

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c6586824a840e..fe8bfc501ebcb 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -74,6 +74,19 @@ def dot(
   C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
 
 
+@linalg_structured_op
+def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, S.C),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """A depth-wise 2-D convolution operation."""
+  O[D.n, D.oh, D.ow, D.c] += cast(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+           D.c]) * cast(U, K[D.kh, D.kw, D.c])
+
+
 @linalg_structured_op
 def fill_rng_2d(
     min=ScalarDef(F64),

diff  --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index f9a0b019034b3..6c94bec316293 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -7,17 +7,17 @@
 # CHECK-LABEL: matmul
 # CHECK: args:
 # CHECK:     name: A
-# CHECK:     usage: input
-# CHECK:     shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK:     usage: InputOperand
 # CHECK:     type_var: T
+# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
 # CHECK:     name: B
-# CHECK:     usage: input
-# CHECK:     shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK:     usage: InputOperand
 # CHECK:     type_var: T
+# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
 # CHECK:     name: C
-# CHECK:     usage: output
-# CHECK:     shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK:     usage: OutputOperand
 # CHECK:     type_var: U
+# CHECK:     shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
 @linalg_structured_op
 def matmul(
     A=TensorDef(T, S.M, S.K),
@@ -30,9 +30,32 @@ def matmul(
 # CHECK-LABEL: fill
 # CHECK: args:
 # CHECK:     name: value
-# CHECK:     usage: input
-# CHECK-NOT: shape:
+# CHECK:     usage: InputOperand
+# CHECK-NOT: shape_map:
 # CHECK:     type_var: T
 @linalg_structured_op
 def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
   O[D.m, D.n] = value
+
+
+# CHECK: ---
+# CHECK-LABEL: strided_copy
+# CHECK: args:
+# CHECK:     name: I
+# CHECK:     usage: InputOperand
+# CHECK:     type_var: T
+# CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
+# CHECK:     name: O
+# CHECK:     usage: OutputOperand
+# CHECK:     type_var: T
+# CHECK:     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
+# CHECK:     name: strides
+# CHECK:     usage: IndexAttribute
+# CHECK:     type_var: I64
+# CHECK:     attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
+@linalg_structured_op
+def strided_copy(
+    I=TensorDef(T, S.W, S.H),
+    O=TensorDef(T, S.OH, S.OW, output=True),
+    strides=AttributeDef(S.S0, S.S1)):
+  O[D.oh, D.ow] = I[D.h * S.S0, D.w * S.S1]

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index 6b12dc1167730..0ed32fe4fb293 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -7,6 +7,9 @@
 
 from mlir.dialects.linalg.opdsl.lang import *
 
+T1 = TV.T1
+T2 = TV.T2
+
 
 @linalg_structured_op
 def matmul_mono(
@@ -18,12 +21,24 @@ def matmul_mono(
 
 @linalg_structured_op
 def matmul_poly(
-    A=TensorDef(TV.T1, S.M, S.K),
-    B=TensorDef(TV.T2, S.K, S.N),
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
     C=TensorDef(U, S.M, S.N, output=True)):
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
+@linalg_structured_op
+def conv_poly(
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, S.C),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  O[D.n, D.oh, D.ow, D.c] += cast(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+           D.c]) * cast(U, K[D.kh, D.kw, D.c])
+
+
 @linalg_structured_op
 def fill_rng(
     min=ScalarDef(F64),
@@ -57,6 +72,10 @@ def fill_rng(
     # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
     # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
 
+    # CHECK: #[[$MAPI:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)>
+    # CHECK: #[[$MAPK:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+    # CHECK: #[[$MAPO:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
     # CHECK-LABEL: func @test_matmul_mono
     # CHECK-SAME:  %[[A:.+]]: tensor<4x16xf32>
     # CHECK-SAME: %[[B:.+]]: tensor<16x8xf32>
@@ -161,17 +180,11 @@ def test_f64f64f32_matmul(lhs, rhs, init_result):
     # CHECK-LABEL: @test_fill_rng
     # CHECK:      ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}}
     # CHECK-DAG:    %[[IDX0:.+]] = linalg.index 0 : index
-    # CHECK-DAG:    %[[IDX1:.+]] = linalg.index 1 : index
     # CHECK-DAG:    %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
-    # CHECK-DAG:    %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
     # CHECK-DAG:    %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
     # CHECK-DAG:    %[[CST0:.+]] = constant 1103515245 : i64
     # CHECK-DAG:    %[[CST0_CAST:.+]] = trunci %[[CST0]] : i64 to i32
-    # CHECK-DAG:    %[[CST1:.+]] = constant 12345 : i64
-    # CHECK-DAG:    %[[CST1_CAST:.+]] = trunci %[[CST1]] : i64 to i32
-    # CHECK-DAG:    %[[RND1:.+]] = muli %[[RND0]], %[[CST0_CAST]] : i32
-    # CHECK-DAG:    %[[RND2:.+]] = addi %[[RND1]], %[[CST1_CAST]] : i32
-    # Skip random number computation for the second index.
+    # Skip the remaining random number computation and match the scaling logic.
     # CHECK-DAG:    %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
     # CHECK-DAG:    %[[CST3:.+]] = constant 2.3283063999999999E-10 : f64
     # CHECK-DAG:    %[[FACT:.+]] = mulf %[[DIFF]], %[[CST3]] : f64
@@ -183,5 +196,24 @@ def test_f64f64f32_matmul(lhs, rhs, init_result):
     def test_fill_rng(min, max, seed, init_result):
       return fill_rng(min, max, seed, outs=[init_result])
 
+    # CHECK-LABEL: @test_f32i32_conv
+    # CHECK: linalg.generic
+    # CHECK-SAME: indexing_maps = [#[[$MAPI]], #[[$MAPK]], #[[$MAPO]]]
+    # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
+    # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
+    # CHECK-NEXT:   %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
+    # CHECK-NEXT:   %[[FILTER_CAST:.+]] = fptosi %[[FILTER:.+]] : f32 to i32
+    # CHECK-NEXT:   %[[PROD:.+]] = muli %[[IN_CAST]], %[[FILTER_CAST]] : i32
+    # CHECK-NEXT:   %[[SUM:.+]] = addi %[[OUT]], %[[PROD]] : i32
+    # CHECK-NEXT:   linalg.yield %[[SUM]] : i32
+    # CHECK-NEXT: -> tensor<2x4xi32>
+    @builtin.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2, 1),
+                                                                 f32),
+        RankedTensorType.get((2, 4), i32))
+    def test_f32i32_conv(input, filter, init_result):
+      return conv_poly(
+          input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
 
 print(module)

diff  --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
index 61453da13f49a..3132c90046df7 100644
--- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
+++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
@@ -7,9 +7,9 @@
 # dims auto discovered emits the right shape, indexing maps and iterator types.
 # CHECK: ---
 # CHECK-LABEL: matmul
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
 # CHECK: static_indexing_maps:
 # CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
 # CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
@@ -19,9 +19,10 @@
 # CHECK-NEXT: - parallel
 # CHECK-NEXT: - reduction
 @linalg_structured_op
-def matmul(A=TensorDef(T, S.M, S.K),
-           B=TensorDef(T, S.K, S.N),
-           C=TensorDef(U, S.M, S.N, output=True)):
+def matmul(
+    A=TensorDef(T, S.M, S.K),
+    B=TensorDef(T, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True)):
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
@@ -29,9 +30,9 @@ def matmul(A=TensorDef(T, S.M, S.K),
 # correctly.
 # CHECK: ---
 # CHECK-LABEL: dot
-# CHECK: shape: affine_map<()[s0] -> (s0)>
-# CHECK: shape: affine_map<()[s0] -> (s0)>
-# CHECK: shape: affine_map<()[s0] -> ()>
+# CHECK: shape_map: affine_map<()[s0] -> (s0)>
+# CHECK: shape_map: affine_map<()[s0] -> (s0)>
+# CHECK: shape_map: affine_map<()[s0] -> ()>
 # CHECK: static_indexing_maps:
 # CHECK-NEXT: - affine_map<(d0)[s0] -> (d0)>
 # CHECK-NEXT: - affine_map<(d0)[s0] -> (d0)>

diff  --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py
index ab96c048c1311..14217014fcd98 100644
--- a/mlir/test/python/dialects/linalg/opsrun.py
+++ b/mlir/test/python/dialects/linalg/opsrun.py
@@ -58,6 +58,30 @@ def log(*args):
 }
 """
 
+conv_boiler = """
+func @main() -> i32 attributes {llvm.emit_c_interface} {
+  %v0 = constant 0 : i32
+  %v1 = constant 1.0 : f64
+  %v2 = constant 2.0 : f64
+
+  %input = memref.alloc() : memref<1x4x16x1xf64>
+  %filter = memref.alloc() : memref<2x2x1xf64>
+  %output = memref.alloc() : memref<1x2x4x1xi32>
+  linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
+  linalg.fill(%v2, %filter) : f64, memref<2x2x1xf64>
+  linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
+
+  call @conv_on_buffers(%input, %filter, %output) :
+    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()
+
+  %c0 = constant 0 : index
+  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>
+
+  // TODO: FFI-based solution to allow testing and printing with python code.
+  return %0 : i32
+}
+"""
+
 
 def transform(module, boilerplate):
   import mlir.conversions
@@ -69,8 +93,9 @@ def transform(module, boilerplate):
   mod = Module.parse(
       str(module.operation.regions[0].blocks[0].operations[0].operation) +
       boilerplate)
-  pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," +
-                         "convert-vector-to-llvm," + "convert-std-to-llvm")
+  pm = PassManager.parse("func(convert-linalg-to-loops, lower-affine, " +
+                         "convert-scf-to-std), convert-vector-to-llvm," +
+                         "convert-std-to-llvm")
   pm.run(mod)
   return mod
 
@@ -183,3 +208,38 @@ def fill_on_buffers(min, max, seed, out):
 
 
 test_fill_generic()
+
+
+def test_conv_generic():
+  with Context() as ctx, Location.unknown():
+    module = Module.create()
+    f64 = F64Type.get()
+    i32 = IntegerType.get_signless(32)
+    with InsertionPoint(module.body):
+
+      @builtin.FuncOp.from_py_func(
+          MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2, 1), f64),
+          MemRefType.get((1, 2, 4, 1), i32))
+      def conv_on_buffers(input, filter, output):
+        linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly(
+            input,
+            filter,
+            outs=[output],
+            strides=[2, 4],
+            dilations=[1, 2],
+            emit_generic=True)
+
+    execution_engine = ExecutionEngine(transform(module, conv_boiler))
+
+    # TODO: FFI-based solution to allow testing and printing with python code.
+    # Prepare arguments: one result i32.
+    # Arguments must be passed as pointers.
+    c_int_p = ctypes.c_int * 1
+    res = c_int_p(-1)
+    execution_engine.invoke("main", res)
+
+    log("RESULT: ", res[0])
+    # CHECK: RESULT: 8
+
+
+test_conv_generic()

diff  --git a/mlir/test/python/ir/affine_map.py b/mlir/test/python/ir/affine_map.py
index d7bc098ffdc5a..da5d230f42cde 100644
--- a/mlir/test/python/ir/affine_map.py
+++ b/mlir/test/python/ir/affine_map.py
@@ -3,6 +3,7 @@
 import gc
 from mlir.ir import *
 
+
 def run(f):
   print("\nTEST:", f.__name__)
   f()
@@ -21,6 +22,7 @@ def testAffineMapCapsule():
   assert am2 == am1
   assert am2.context is ctx
 
+
 run(testAffineMapCapsule)
 
 
@@ -97,6 +99,7 @@ def testAffineMapGet():
       # CHECK: number of results out of bounds
       print(e)
 
+
 run(testAffineMapGet)
 
 
@@ -117,6 +120,7 @@ def testAffineMapDerive():
     map34 = map5.get_minor_submap(2)
     print(map34)
 
+
 run(testAffineMapDerive)
 
 
@@ -142,6 +146,7 @@ def testAffineMapProperties():
     # CHECK: False
     print(map3.is_projected_permutation)
 
+
 run(testAffineMapProperties)
 
 
@@ -175,23 +180,22 @@ def testAffineMapExprs():
       print(expr)
     assert list(map3.results) == [d2, d0, d1]
 
+
 run(testAffineMapExprs)
 
+
 # CHECK-LABEL: TEST: testCompressUnusedSymbols
 def testCompressUnusedSymbols():
   with Context() as ctx:
-    d0, d1, d2 = (
-      AffineDimExpr.get(0), 
-      AffineDimExpr.get(1), 
-      AffineDimExpr.get(2))
-    s0, s1, s2 = (
-      AffineSymbolExpr.get(0), 
-      AffineSymbolExpr.get(1), 
-      AffineSymbolExpr.get(2))
+    d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
+                  AffineDimExpr.get(2))
+    s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
+                  AffineSymbolExpr.get(2))
     maps = [
         AffineMap.get(3, 3, [d2, d0, d1]),
         AffineMap.get(3, 3, [d2, d0 + s2, d1]),
-        AffineMap.get(3, 3, [d1, d2, d0])]
+        AffineMap.get(3, 3, [d1, d2, d0])
+    ]
 
     compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)
 
@@ -207,3 +211,29 @@ def testCompressUnusedSymbols():
 
 
 run(testCompressUnusedSymbols)
+
+
+# CHECK-LABEL: TEST: testReplace
+def testReplace():
+  with Context() as ctx:
+    d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
+                  AffineDimExpr.get(2))
+    s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
+                  AffineSymbolExpr.get(2))
+    map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
+
+    replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3)
+    replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3)
+    replace3 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2)
+
+    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
+    print(replace0)
+
+    # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
+    print(replace1)
+
+    # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
+    print(replace3)
+
+
+run(testReplace)


        


More information about the Mlir-commits mailing list