[Mlir-commits] [mlir] 31f888e - [mlir][linalg][python] Add attribute support to the OpDSL.
Tobias Gysi
llvmlistbot at llvm.org
Thu Jun 24 02:41:41 PDT 2021
Author: Tobias Gysi
Date: 2021-06-24T09:40:32Z
New Revision: 31f888ea9af452ae312c270e569d9fbe23c57c9f
URL: https://github.com/llvm/llvm-project/commit/31f888ea9af452ae312c270e569d9fbe23c57c9f
DIFF: https://github.com/llvm/llvm-project/commit/31f888ea9af452ae312c270e569d9fbe23c57c9f.diff
LOG: [mlir][linalg][python] Add attribute support to the OpDSL.
Extend the OpDSL with index attributes. After tensors and scalars, index attributes are the third operand type. An index attribute represents a compile-time constant that is limited to index expressions. Example use cases are the strides and dilations defined by convolution and pooling operations.
This patch only updates the OpDSL. The C++ YAML code generator is updated in a follow-up patch.
Differential Revision: https://reviews.llvm.org/D104711
Added:
Modified:
mlir/include/mlir-c/AffineMap.h
mlir/lib/Bindings/Python/IRAffine.cpp
mlir/lib/CAPI/IR/AffineMap.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/python/dialects/linalg/opdsl/arguments.py
mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
mlir/test/python/dialects/linalg/opsrun.py
mlir/test/python/ir/affine_map.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h
index e35b7cc6b51d5..7359b969127c7 100644
--- a/mlir/include/mlir-c/AffineMap.h
+++ b/mlir/include/mlir-c/AffineMap.h
@@ -169,6 +169,13 @@ mlirAffineMapGetMajorSubMap(MlirAffineMap affineMap, intptr_t numResults);
MLIR_CAPI_EXPORTED MlirAffineMap
mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults);
+/// Apply AffineExpr::replace(`map`) to each of the results and return a new
+/// new AffineMap with the new results and the specified number of dims and
+/// symbols.
+MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapReplace(
+ MlirAffineMap affineMap, MlirAffineExpr expression,
+ MlirAffineExpr replacement, intptr_t numResultDims, intptr_t numResultSyms);
+
/// Returns the simplified affine map resulting from dropping the symbols that
/// do not appear in any of the individual maps in `affineMaps`.
/// Asserts that all maps in `affineMaps` are normalized to the same number of
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index 5d3b790b35d0e..0a2a5666a9e47 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -654,6 +654,14 @@ void mlir::python::populateIRAffine(py::module &m) {
mlirAffineMapGetMinorSubMap(self, nResults);
return PyAffineMap(self.getContext(), affineMap);
})
+ .def("replace",
+ [](PyAffineMap &self, PyAffineExpr &expression,
+ PyAffineExpr &replacement, intptr_t numResultDims,
+ intptr_t numResultSyms) {
+ MlirAffineMap affineMap = mlirAffineMapReplace(
+ self, expression, replacement, numResultDims, numResultSyms);
+ return PyAffineMap(self.getContext(), affineMap);
+ })
.def_property_readonly(
"is_permutation",
[](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp
index e0c07afc3b75e..85557bc576f61 100644
--- a/mlir/lib/CAPI/IR/AffineMap.cpp
+++ b/mlir/lib/CAPI/IR/AffineMap.cpp
@@ -138,6 +138,15 @@ MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap,
return wrap(unwrap(affineMap).getMinorSubMap(numResults));
}
+MlirAffineMap mlirAffineMapReplace(MlirAffineMap affineMap,
+ MlirAffineExpr expression,
+ MlirAffineExpr replacement,
+ intptr_t numResultDims,
+ intptr_t numResultSyms) {
+ return wrap(unwrap(affineMap).replace(unwrap(expression), unwrap(replacement),
+ numResultDims, numResultSyms));
+}
+
void mlirAffineMapCompressUnusedSymbols(
MlirAffineMap *affineMaps, intptr_t size, void *result,
void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)) {
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index fe067d6947138..2b2f57248c515 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -9,6 +9,7 @@
"""
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
+from enum import Enum
from mlir import ir as _ir
@@ -133,18 +134,31 @@ def __repr__(self):
return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"
+class OperandKind(Enum):
+ InputTensor = 0
+ Scalar = 1
+ OutputTensor = 2
+ Attribute = 3
+
+
class OperandDef:
- """Definition of a Tensor or Scalar operand passed to an operation."""
+ """Definition of an operand passed to an operation.
+
+ Keep the meta information of Tensor, Scalar, and Attribute operands and
+ provide the shared registration functionality.
+ """
- def __init__(self, type_var: TypeVar, shape: Sequence[AffineExprDef],
- scalar: bool, output: bool):
+ def __init__(self,
+ kind: OperandKind,
+ type_var: TypeVar,
+ size_exprs: Optional[Sequence[AffineExprDef]] = None):
if not isinstance(type_var, TypeVar):
- raise ValueError(f"OperandDef requires a TypeVar. Got: {repr(type_var)}")
+ raise ValueError(
+ f"OperandDef requires a TypeVar but got {repr(type_var)}")
self.owner = None # type: Optional["LinalgOpDef"]
self.type_var = type_var
- self.shape = shape
- self.scalar = scalar
- self.output = output
+ self.size_exprs = size_exprs
+ self.kind = kind
self.name = None # type: Optional[str]
self.registered_index = -1 # type: int
@@ -159,10 +173,8 @@ def __hash__(self):
return hash(id(self))
def __repr__(self):
- output = "OUTPUT " if self.output else ""
- scalar = "SCALAR " if self.scalar else ""
- return (f"{self.name}:OperandDef({output}{scalar}"
- f"{repr(self.type_var)}, shape={self.shape})")
+ return (f"{self.name}:OperandDef(kind={self.kind.name}, "
+ f"type={repr(self.type_var)}, size_exprs={self.size_exprs})")
class TensorDef:
@@ -170,14 +182,17 @@ class TensorDef:
Tensor operands are indexed using the associated indexing_map when forwarded
to the body of the structured op. A unique name identifies the tensor operands
- and an index determines their position in the operation's parameter list.
+ and an index determines their position in the operation's parameter list. A
+ tensor definition takes type, a shape, and an optional flag to mark output
+ tensors.
"""
def __init__(self,
type_var: TypeVar,
*shape: AffineExprDef,
output: bool = False):
- self.operand_def = OperandDef(type_var, shape, False, output)
+ kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
+ self.operand_def = OperandDef(kind, type_var, size_exprs=shape)
def __getitem__(self, dims) -> TensorUse:
assert self.operand_def.owner, "TensorDef is not attached to an op"
@@ -221,7 +236,7 @@ class ScalarDef(TensorExpression):
"""
def __init__(self, type_var: TypeVar):
- self.operand_def = OperandDef(type_var, (), True, False)
+ self.operand_def = OperandDef(OperandKind.Scalar, type_var)
@property
def scalar_name(self) -> str:
@@ -233,6 +248,22 @@ def to_scalar_expression(self) -> ScalarExpression:
return ScalarArg(self.scalar_name).expr()
+class AttributeDef:
+ """Index Attribute definition.
+
+ Index attributes provide a way to define and set symbols that can be used in
+ indexing expressions. Every attribute specifies a tuple of symbols that at
+ compile-time are replaced by integer values.
+ """
+ yaml_tag = "!LinalgAttributeDef"
+
+ def __init__(self, *sizes: SymbolDef):
+ if any(not isinstance(size, SymbolDef) for size in sizes):
+ raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got "
+ f"{type(sizes)}")
+ self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)
+
+
class Comprehension:
"""Represents a single comprehension."""
@@ -303,7 +334,7 @@ class ReduceFnType:
def __init__(self, operator: PrimFnType, *reduce_dims: DimDef):
"""Initializes the ReduceFn with a primitive function and dims."""
if not isinstance(operator, PrimFnType):
- raise ValueError(f"Reduce expected a Prim operator. Got: {operator}")
+ raise ValueError(f"Reduce expected a Prim operator but got {operator}")
self.operator = operator
self.reduce_dims = tuple(reduce_dims)
@@ -353,7 +384,7 @@ def __init__(self, value: Any):
self.value = str(
_ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
else:
- raise ValueError(f"const requires int or float. Got: {type(value)}")
+ raise ValueError(f"const requires int or float but got {type(value)}")
def to_scalar_expression(self) -> ScalarExpression:
return ScalarConst(self.value).expr()
@@ -475,21 +506,22 @@ def __init__(self,
self.comprehensions = list() # type: List[Comprehension]
self._affine_state = AffineBuildState()
- @property
- def outputs(self) -> Sequence[OperandDef]:
- return [
- operand for operand in self.registered_operands.values()
- if operand.output
- ]
-
def add_operand(self, name: str, operand: OperandDef):
"""Registers an operand."""
if name in self.registered_operands:
raise ValueError(f"The operand {name} is already registered "
f"to {self.registered_operands['name']}")
- if not operand.output and self.outputs:
- raise ValueError(f"The operand {name} is an input registered after "
- f"the output {self.outputs[-1]}")
+ # Ensure output tensors are registered after input tensors and scalars and
+ # attributes are registered after all other operand types.
+ registered_kinds = [
+ operand.kind.value for operand in self.registered_operands.values()
+ ]
+ if registered_kinds:
+ maximum = max(registered_kinds)
+ if maximum > operand.kind.value and maximum > OperandKind.Scalar.value:
+ raise ValueError(
+ f"The operand {name} of kind {operand.kind.name} is registered "
+ f"after an operand of kind {OperandKind(maximum).name}")
operand.attach(len(self.registered_operands), name, self)
self.registered_operands[name] = operand
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
index 6dd86334b95a5..773bd876397f9 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -45,9 +45,11 @@ class OperandDefConfig(YAMLObject):
def __init__(self,
operand_def: OperandDef,
- shape_map: Optional[_ir.AffineMap] = None):
+ shape_map: Optional[_ir.AffineMap] = None,
+ attribute_map: Optional[_ir.AffineMap] = None):
self.operand_def = operand_def
self.shape_map = shape_map # type: Optional[_ir.AffineMap]
+ self.attribute_map = attribute_map # type: Optional[_ir.AffineMap]
self.indexing_map = None # type: Optional[_ir.AffineMap]
@property
@@ -60,21 +62,25 @@ def type_var(self) -> TypeVar:
@property
def usage(self) -> str:
- if self.operand_def.output:
- return "output"
- return "input"
+ if self.operand_def.kind == OperandKind.Attribute:
+ return "IndexAttribute"
+ if self.operand_def.kind == OperandKind.OutputTensor:
+ return "OutputOperand"
+ return "InputOperand"
def to_yaml_custom_dict(self):
- self_dict = dict(name=self.name)
- self_dict["usage"] = self.usage
- if not self.operand_def.scalar:
- self_dict["shape"] = _serialize_affine_map(self.shape_map)
- self_dict["type_var"] = self.type_var.name
+ self_dict = dict(
+ name=self.name, usage=self.usage, type_var=self.type_var.name)
+ if self.shape_map:
+ self_dict["shape_map"] = _serialize_affine_map(self.shape_map)
+ if self.attribute_map:
+ self_dict["attribute_map"] = _serialize_affine_map(self.attribute_map)
return self_dict
def __repr__(self):
return (f"OperandDefConfig({self.operand_def}, "
- f"shape_map={self.shape_map}, indexing_map={self.indexing_map})")
+ f"shape_map={self.shape_map}, attribute_map={self.attribute_map}, "
+ f"indexing_map={self.indexing_map})")
class LinalgIndexingMapsConfig(YAMLObject):
@@ -109,6 +115,7 @@ class LinalgStructuredOpConfig(YAMLObject):
def __init__(self,
comprehension: Comprehension,
+ registered_operands: Sequence[OperandDef],
context: Optional[_ir.Context] = None):
self.context = context if context is not None else _ir.Context()
self.affine_state = AffineBuildState()
@@ -131,22 +138,33 @@ def __init__(self,
read_use.collect_scalar_uses(collected_scalar_uses)
read_use.collect_indices(collected_indices)
- # Need to add all definitions before uses, so process twice.
+ # Collect all attribute definitions
+ collected_attr_defs = list()
+ for operand in registered_operands:
+ if operand.kind == OperandKind.Attribute:
+ collected_attr_defs.append(operand)
+
+ # Add all definitions before uses, so process twice.
for use in collected_tensor_uses:
self.add_operand(use.operand_def)
for use in collected_scalar_uses:
self.add_operand(use.operand_def)
+ for definition in collected_attr_defs:
+ self.add_operand(definition)
for use in collected_tensor_uses:
self.add_tensor_use(use)
- # Now normalize all defs and uses indexing maps now that full count of
- # dims and symbols are known.
+ # Normalize all shape and indexing maps now that full count of dims and
+ # symbols are known.
for cuse in self.uses.values():
cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
- for cdef in self.operands.values():
- if not cdef.operand_def.scalar:
- cdef.shape_map = self._normalize_affine_map(
- cdef.shape_map, with_dims=False)
+ for operand_config in self.operands.values():
+ if operand_config.shape_map:
+ operand_config.shape_map = self._normalize_affine_map(
+ operand_config.shape_map, with_dims=False)
+ if operand_config.attribute_map:
+ operand_config.attribute_map = self._normalize_affine_map(
+ operand_config.attribute_map, with_dims=False)
# Now for each write use, propagate the indexing maps from the use to the
# tensor, ensuring that there are not conflicts.
@@ -174,12 +192,16 @@ def __init__(self,
# Set the indexing map of all scalar uses to the empty map.
for operand_config in self.operands.values():
- if operand_config.operand_def.scalar:
- operand_config.indexing_map = self._create_empty_affine_map()
+ if operand_config.operand_def.kind == OperandKind.Scalar:
+ operand_config.indexing_map = self._get_scalar_map()
- # Sanity check that all defs have an indexing map.
- assert all(d.indexing_map for d in self.operands.values()), (
- f"Missing indexing map on OperandConfigDef: {self.operands}")
+ # Check all registered tensor and scalar operands have an indexing map.
+ for operand in registered_operands:
+ if operand.kind == OperandKind.Attribute:
+ continue
+ if not (operand in self.operands and self.operands[operand].indexing_map):
+ raise ValueError(f"Failed to compute an indexing map for operand "
+ f"{operand.name}")
# Collect reduction dims and ensure all the same.
all_reduction_dims = set(comprehension.all_reduction_dims)
@@ -189,7 +211,7 @@ def __init__(self,
f"dims. Got: {all_reduction_dims}")
self.reduction_dims = next(iter(all_reduction_dims))
- # Check the index dimension exists and resolve
+ # Check the index dimension exists and resolve.
for index in collected_indices:
if index.dim_def.dimname not in self.affine_state.all_dims:
raise ValueError(
@@ -221,7 +243,7 @@ def ordered_dims(self) -> Sequence[Tuple[str, int]]:
@property
def indexing_maps(self) -> Sequence[_ir.AffineMap]:
- return [d.indexing_map for d in self.ordered_operands]
+ return [o.indexing_map for o in self.ordered_operands if o.indexing_map]
@property
def iterator_types(self) -> Sequence[str]:
@@ -237,20 +259,24 @@ def get_type(symbolic_name, position):
def add_operand(self, operand_def: OperandDef):
if operand_def in self.operands:
return
- if operand_def.scalar:
+ if operand_def.kind == OperandKind.Scalar:
self.operands[operand_def] = OperandDefConfig(operand_def)
return
with self.context:
local_state = AffineBuildState(
global_state=self.affine_state, allow_new_dims=False)
exprs = []
- for expr in operand_def.shape:
+ for expr in operand_def.size_exprs:
exprs.append(expr.build(state=local_state))
assert local_state.local_dim_count == 0
- shape_map = _ir.AffineMap.get(
+ affine_map = _ir.AffineMap.get(
dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs)
- def_config = OperandDefConfig(operand_def, shape_map)
- self.operands[operand_def] = def_config
+ if operand_def.kind == OperandKind.Attribute:
+ self.operands[operand_def] = OperandDefConfig(
+ operand_def, attribute_map=affine_map)
+ else:
+ self.operands[operand_def] = OperandDefConfig(
+ operand_def, shape_map=affine_map)
def add_tensor_use(self, tensor_use: TensorUse):
if tensor_use in self.uses:
@@ -261,7 +287,6 @@ def add_tensor_use(self, tensor_use: TensorUse):
exprs = []
for expr in tensor_use.indices:
exprs.append(expr.build(state=local_state))
- assert local_state.local_symbol_count == 0
indexing_map = _ir.AffineMap.get(
dim_count=local_state.dim_count,
symbol_count=local_state.symbol_count,
@@ -270,8 +295,8 @@ def add_tensor_use(self, tensor_use: TensorUse):
use_config = TensorUseConfig(tensor_use, indexing_map)
self.uses[tensor_use] = use_config
- def _create_empty_affine_map(self) -> _ir.AffineMap:
- """Create an affine map with an empty range."""
+ def _get_scalar_map(self) -> _ir.AffineMap:
+ """Create an empty affine map used to index a scalar."""
with self.context:
return _ir.AffineMap.get(
dim_count=self.affine_state.dim_count,
@@ -345,8 +370,9 @@ def from_linalg_op_def(
return [
LinalgOpConfig(
tc_op_def.metadata,
- structured_op=LinalgStructuredOpConfig(tc_op_def.comprehensions[0],
- context)),
+ structured_op=LinalgStructuredOpConfig(
+ tc_op_def.comprehensions[0],
+ tc_op_def.registered_operands.values(), context)),
]
def __repr__(self):
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 191b1b34fd836..6dbda1bb7ecbe 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -44,15 +44,20 @@ def __init__(self, op_name: str, model: LinalgOpDef):
self.op_name = op_name
self.model = model
- def __call__(self, *args, emit_generic: bool = False, **kwargs):
+ def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs):
"""Emits the corresponding op definition as IR.
Most arguments are passed through to the underlying emitter. The following
- are interpreted here:
+ keyword argument is interpreted here:
emit_generic: Emits a generic form as appropriate (default True). If
False, a named form is emitted (which must have been built in to the
compiler).
"""
+ emit_generic = kwargs.pop("emit_generic", False)
+ if not isinstance(emit_generic, bool):
+ raise ValueError(f"The named argument 'emit_generic' needs to be "
+ f" of type bool but got {type(emit_generic)}")
+
op_configs = LinalgOpConfig.from_linalg_op_def(
self.model, context=ir.Context.current)
@@ -70,12 +75,16 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs):
op_config = op_configs[0]
if op_config.structured_op:
if emit_generic:
- return emit_generic_structured_op(op_config.structured_op, *args,
- **kwargs)
+ return emit_generic_structured_op(
+ op_config.structured_op, *ins, outs=outs, **kwargs)
else:
- return emit_named_structured_op(op_config.structured_op, self.op_name,
- self.model.metadata.cpp_class_name,
- *args, **kwargs)
+ return emit_named_structured_op(
+ op_config.structured_op,
+ self.op_name,
+ self.model.metadata.cpp_class_name,
+ *ins,
+ outs=outs,
+ **kwargs)
raise NotImplementedError(
f"Emission of linalg op type not supported: {op_config}")
@@ -104,14 +113,12 @@ def linalg_structured_op(dsl_func=None,
sig = inspect.signature(dsl_func)
for param_name, param in sig.parameters.items():
param_default = param.default
- if isinstance(param_default, TensorDef):
- tc_model.add_operand(param_name, param_default.operand_def)
- elif isinstance(param_default, ScalarDef):
+ if isinstance(param_default, (TensorDef, ScalarDef, AttributeDef)):
tc_model.add_operand(param_name, param_default.operand_def)
else:
raise ValueError(f"@tc_def_op function parameters must be defaulted as "
- f"TensorDef(...) or ScalarDef(...): Found {param_name}"
- f": {param_default}")
+ f"TensorDef(...), ScalarDef(...), or AttributeDef(...): "
+ f"Found {param_name}: {param_default}")
dsl_func_args.append(param_default)
# Invoke the DSL func to finish populating the model.
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 2b8b910507cec..f6fb0cc7d0d0e 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -13,6 +13,7 @@
from .scalar_expr import *
from .config import *
+import numpy as np
__all__ = [
"emit_generic_structured_op",
@@ -29,12 +30,14 @@ def isa(cls: Type, ty: Type):
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
- *ins: Value, outs: Sequence[Value]):
+ *ins: Value, outs: Sequence[Value],
+ **attrs: Sequence[int]):
all_arg_defs = op_config.ordered_operands
- in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
- out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"]
+ in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"]
+ out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"]
+ attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"]
- # Verify outs and captures are sequences.
+ # Verify outs is a sequence.
if not isinstance(outs, Sequence):
raise ValueError(f"Expected named argument outs to have type Sequence "
f"but got {type(outs)}")
@@ -47,6 +50,40 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
f"{len(outs)} for {op_config}")
+ # Compute a replacement list for all attribute symbols.
+ expressions = [] # type: Sequence[AffineExpr]
+ replacements = [] # type: Sequence[AffineExpr]
+ for attr in attr_arg_defs:
+ if attr.name not in attrs:
+ raise ValueError(f"Expected named argument for the attribute {attr.name}")
+ attribute_values = attrs.get(attr.name)
+ if not all(isinstance(value, int) for value in attribute_values):
+ raise ValueError(f"Attribute {attr.name} needs to be of type "
+ f"Sequence[int] but got {type(attribute_values)}")
+ results = attr.attribute_map.results # type: AffineExprList
+ if len(attribute_values) != len(results):
+ raise ValueError(f"Attribute {attr.name} has length {len(results)} "
+ f"but got {len(attribute_values)} values")
+ for expr, value in zip(results, attribute_values):
+ expressions.append(expr)
+ replacements.append(AffineConstantExpr.get(value))
+
+ # Replace all index attribute symbols by their value.
+ # TODO: Add support for shape symbols.
+ indexing_maps = [] # type: Sequence[AffineMap]
+ for curr in op_config.indexing_maps:
+ for expression, replacement in zip(expressions, replacements):
+ curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
+ indexing_maps.append(curr)
+
+ # TODO: Linalg verification does not currently allow symbols.
+ # Compress them for now and verify none are left.
+ indexing_maps = AffineMap.compress_unused_symbols(indexing_maps,
+ Context.current)
+ if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
+ raise ValueError(f"Expected indexing_maps to use no symbols after "
+ f"replacement and compression but got {indexing_maps}")
+
outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
out_arg_defs, outs)
@@ -67,27 +104,28 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
# Emit the generic op.
# TODO: Support emission of pure memref form.
- indexing_maps_attr = ArrayAttr.get([
- AffineMapAttr.get(am)
- # TODO: linalg verification does not currently allow symbols.
- # Compress them for now.
- for am in AffineMap.compress_unused_symbols(op_config.indexing_maps,
- Context.current)
- ])
+ indexing_maps_attr = ArrayAttr.get(
+ [AffineMapAttr.get(am) for am in indexing_maps])
iterator_types_attr = ArrayAttr.get(
[StringAttr.get(s) for s in op_config.iterator_types])
+ # Compute a dictionary storing all index attributes.
+ index_attributes = {} # type: Dict[str, DenseElementAttr]
+ for attr in attr_arg_defs:
+ attribute_values = attrs.get(attr.name)
+ array = np.array(attribute_values, dtype=np.int64)
+ index_attributes[attr.name] = DenseElementsAttr.get(array)
+
return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
type_mapping, indexing_maps_attr, iterator_types_attr,
- block_arg_types)
+ index_attributes, block_arg_types)
-def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
- *ins: Value,
- outs: Sequence[Value] = ()):
+def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
+ outs: Sequence[Value], **attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
- indexing_maps_attr, iterator_types_attr, block_arg_types = \
- prepare_common_structured_op(op_config, *ins, outs = outs)
+ indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
+ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
generic_op = linalg.GenericOp(
result_tensors=result_types,
@@ -114,14 +152,12 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
return generic_op.results
-def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
- op_name: str,
- op_class_name: str,
- *ins: Value,
- outs: Sequence[Value] = ()):
+def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str,
+ op_class_name: str, *ins: Value,
+ outs: Sequence[Value], **attrs: Sequence[int]):
all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
- indexing_maps_attr, iterator_types_attr, block_arg_types = \
- prepare_common_structured_op(op_config, *ins, outs = outs)
+ indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
+ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)
# If we get here, there must exist a builtin class `op_class_name`.
ctx = Context.current
@@ -141,6 +177,10 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
"linalg.memoized_indexing_maps"] = indexing_maps_attr
# iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
+ # Additionally set all named attributes.
+ for name, value in index_attributes.items():
+ named_op.operation.attributes[name] = value
+
if len(result_types) == 1:
return named_op.result
else:
@@ -304,7 +344,7 @@ def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type,
block_arg_types: Sequence[Type]):
element_or_self_type = operand_type
# Get the element type for tensor operands and the type itself for scalars.
- if operand_config.operand_def.shape:
+ if operand_config.shape_map:
try:
element_or_self_type = ShapedType(operand_type).element_type
except Exception as e:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index c6586824a840e..fe8bfc501ebcb 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -74,6 +74,19 @@ def dot(
C[None] += cast(U, A[D.m]) * cast(U, B[D.m])
+@linalg_structured_op
+def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
+ I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
+ K=TensorDef(T2, S.KH, S.KW, S.C),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+ strides=AttributeDef(S.SH, S.SW),
+ dilations=AttributeDef(S.DH, S.DW)):
+ """A depth-wise 2-D convolution operation."""
+ O[D.n, D.oh, D.ow, D.c] += cast(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+ D.c]) * cast(U, K[D.kh, D.kw, D.c])
+
+
@linalg_structured_op
def fill_rng_2d(
min=ScalarDef(F64),
diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py
index f9a0b019034b3..6c94bec316293 100644
--- a/mlir/test/python/dialects/linalg/opdsl/arguments.py
+++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py
@@ -7,17 +7,17 @@
# CHECK-LABEL: matmul
# CHECK: args:
# CHECK: name: A
-# CHECK: usage: input
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK: usage: InputOperand
# CHECK: type_var: T
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
# CHECK: name: B
-# CHECK: usage: input
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK: usage: InputOperand
# CHECK: type_var: T
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
# CHECK: name: C
-# CHECK: usage: output
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: usage: OutputOperand
# CHECK: type_var: U
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
@linalg_structured_op
def matmul(
A=TensorDef(T, S.M, S.K),
@@ -30,9 +30,32 @@ def matmul(
# CHECK-LABEL: fill
# CHECK: args:
# CHECK: name: value
-# CHECK: usage: input
-# CHECK-NOT: shape:
+# CHECK: usage: InputOperand
+# CHECK-NOT: shape_map:
# CHECK: type_var: T
@linalg_structured_op
def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)):
O[D.m, D.n] = value
+
+
+# CHECK: ---
+# CHECK-LABEL: strided_copy
+# CHECK: args:
+# CHECK: name: I
+# CHECK: usage: InputOperand
+# CHECK: type_var: T
+# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s2, s3)>
+# CHECK: name: O
+# CHECK: usage: OutputOperand
+# CHECK: type_var: T
+# CHECK: shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1)>
+# CHECK: name: strides
+# CHECK: usage: IndexAttribute
+# CHECK: type_var: I64
+# CHECK: attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s5)>
+@linalg_structured_op
+def strided_copy(
+ I=TensorDef(T, S.W, S.H),
+ O=TensorDef(T, S.OH, S.OW, output=True),
+ strides=AttributeDef(S.S0, S.S1)):
+ O[D.oh, D.ow] = I[D.h * S.S0, D.w * S.S1]
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index 6b12dc1167730..0ed32fe4fb293 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -7,6 +7,9 @@
from mlir.dialects.linalg.opdsl.lang import *
+T1 = TV.T1
+T2 = TV.T2
+
@linalg_structured_op
def matmul_mono(
@@ -18,12 +21,24 @@ def matmul_mono(
@linalg_structured_op
def matmul_poly(
- A=TensorDef(TV.T1, S.M, S.K),
- B=TensorDef(TV.T2, S.K, S.N),
+ A=TensorDef(T1, S.M, S.K),
+ B=TensorDef(T2, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)):
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+@linalg_structured_op
+def conv_poly(
+ I=TensorDef(T1, S.N, S.IH, S.IW, S.C),
+ K=TensorDef(T2, S.KH, S.KW, S.C),
+ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+ strides=AttributeDef(S.SH, S.SW),
+ dilations=AttributeDef(S.DH, S.DW)):
+ O[D.n, D.oh, D.ow, D.c] += cast(
+ U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+ D.c]) * cast(U, K[D.kh, D.kw, D.c])
+
+
@linalg_structured_op
def fill_rng(
min=ScalarDef(F64),
@@ -57,6 +72,10 @@ def fill_rng(
# CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
# CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+ # CHECK: #[[$MAPI:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 4 + d5 * 2, d3)>
+ # CHECK: #[[$MAPK:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
+ # CHECK: #[[$MAPO:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
# CHECK-LABEL: func @test_matmul_mono
# CHECK-SAME: %[[A:.+]]: tensor<4x16xf32>
# CHECK-SAME: %[[B:.+]]: tensor<16x8xf32>
@@ -161,17 +180,11 @@ def test_f64f64f32_matmul(lhs, rhs, init_result):
# CHECK-LABEL: @test_fill_rng
# CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %{{.*}}
# CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
- # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
# CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
- # CHECK-DAG: %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
# CHECK-DAG: %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
# CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i64
# CHECK-DAG: %[[CST0_CAST:.+]] = trunci %[[CST0]] : i64 to i32
- # CHECK-DAG: %[[CST1:.+]] = constant 12345 : i64
- # CHECK-DAG: %[[CST1_CAST:.+]] = trunci %[[CST1]] : i64 to i32
- # CHECK-DAG: %[[RND1:.+]] = muli %[[RND0]], %[[CST0_CAST]] : i32
- # CHECK-DAG: %[[RND2:.+]] = addi %[[RND1]], %[[CST1_CAST]] : i32
- # Skip random number computation for the second index.
+ # Skip the remaining random number computation and match the scaling logic.
# CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
# CHECK-DAG: %[[CST3:.+]] = constant 2.3283063999999999E-10 : f64
# CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST3]] : f64
@@ -183,5 +196,24 @@ def test_f64f64f32_matmul(lhs, rhs, init_result):
def test_fill_rng(min, max, seed, init_result):
return fill_rng(min, max, seed, outs=[init_result])
+ # CHECK-LABEL: @test_f32i32_conv
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$MAPI]], #[[$MAPK]], #[[$MAPO]]]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
+ # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[FILTER:.+]]: f32, %[[OUT:.+]]: i32)
+ # CHECK-NEXT: %[[IN_CAST:.+]] = fptosi %[[IN:.+]] : f32 to i32
+ # CHECK-NEXT: %[[FILTER_CAST:.+]] = fptosi %[[FILTER:.+]] : f32 to i32
+ # CHECK-NEXT: %[[PROD:.+]] = muli %[[IN_CAST]], %[[FILTER_CAST]] : i32
+ # CHECK-NEXT: %[[SUM:.+]] = addi %[[OUT]], %[[PROD]] : i32
+ # CHECK-NEXT: linalg.yield %[[SUM]] : i32
+ # CHECK-NEXT: -> tensor<2x4xi32>
+ @builtin.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2, 1),
+ f32),
+ RankedTensorType.get((2, 4), i32))
+ def test_f32i32_conv(input, filter, init_result):
+ return conv_poly(
+ input, filter, outs=[init_result], strides=[2, 4], dilations=[1, 2])
+
print(module)
diff --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
index 61453da13f49a..3132c90046df7 100644
--- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
+++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py
@@ -7,9 +7,9 @@
# dims auto discovered emits the right shape, indexing maps and iterator types.
# CHECK: ---
# CHECK-LABEL: matmul
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
-# CHECK: shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+# CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
# CHECK: static_indexing_maps:
# CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
# CHECK-NEXT: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
@@ -19,9 +19,10 @@
# CHECK-NEXT: - parallel
# CHECK-NEXT: - reduction
@linalg_structured_op
-def matmul(A=TensorDef(T, S.M, S.K),
- B=TensorDef(T, S.K, S.N),
- C=TensorDef(U, S.M, S.N, output=True)):
+def matmul(
+ A=TensorDef(T, S.M, S.K),
+ B=TensorDef(T, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
@@ -29,9 +30,9 @@ def matmul(A=TensorDef(T, S.M, S.K),
# correctly.
# CHECK: ---
# CHECK-LABEL: dot
-# CHECK: shape: affine_map<()[s0] -> (s0)>
-# CHECK: shape: affine_map<()[s0] -> (s0)>
-# CHECK: shape: affine_map<()[s0] -> ()>
+# CHECK: shape_map: affine_map<()[s0] -> (s0)>
+# CHECK: shape_map: affine_map<()[s0] -> (s0)>
+# CHECK: shape_map: affine_map<()[s0] -> ()>
# CHECK: static_indexing_maps:
# CHECK-NEXT: - affine_map<(d0)[s0] -> (d0)>
# CHECK-NEXT: - affine_map<(d0)[s0] -> (d0)>
diff --git a/mlir/test/python/dialects/linalg/opsrun.py b/mlir/test/python/dialects/linalg/opsrun.py
index ab96c048c1311..14217014fcd98 100644
--- a/mlir/test/python/dialects/linalg/opsrun.py
+++ b/mlir/test/python/dialects/linalg/opsrun.py
@@ -58,6 +58,30 @@ def log(*args):
}
"""
+conv_boiler = """
+func @main() -> i32 attributes {llvm.emit_c_interface} {
+ %v0 = constant 0 : i32
+ %v1 = constant 1.0 : f64
+ %v2 = constant 2.0 : f64
+
+ %input = memref.alloc() : memref<1x4x16x1xf64>
+ %filter = memref.alloc() : memref<2x2x1xf64>
+ %output = memref.alloc() : memref<1x2x4x1xi32>
+ linalg.fill(%v1, %input) : f64, memref<1x4x16x1xf64>
+ linalg.fill(%v2, %filter) : f64, memref<2x2x1xf64>
+ linalg.fill(%v0, %output) : i32, memref<1x2x4x1xi32>
+
+ call @conv_on_buffers(%input, %filter, %output) :
+ (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()
+
+ %c0 = constant 0 : index
+ %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>
+
+ // TODO: FFI-based solution to allow testing and printing with python code.
+ return %0 : i32
+}
+"""
+
def transform(module, boilerplate):
import mlir.conversions
@@ -69,8 +93,9 @@ def transform(module, boilerplate):
mod = Module.parse(
str(module.operation.regions[0].blocks[0].operations[0].operation) +
boilerplate)
- pm = PassManager.parse("func(convert-linalg-to-loops, convert-scf-to-std)," +
- "convert-vector-to-llvm," + "convert-std-to-llvm")
+ pm = PassManager.parse("func(convert-linalg-to-loops, lower-affine, " +
+ "convert-scf-to-std), convert-vector-to-llvm," +
+ "convert-std-to-llvm")
pm.run(mod)
return mod
@@ -183,3 +208,38 @@ def fill_on_buffers(min, max, seed, out):
test_fill_generic()
+
+
+def test_conv_generic():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f64 = F64Type.get()
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint(module.body):
+
+ @builtin.FuncOp.from_py_func(
+ MemRefType.get((1, 4, 16, 1), f64), MemRefType.get((2, 2, 1), f64),
+ MemRefType.get((1, 2, 4, 1), i32))
+ def conv_on_buffers(input, filter, output):
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly(
+ input,
+ filter,
+ outs=[output],
+ strides=[2, 4],
+ dilations=[1, 2],
+ emit_generic=True)
+
+ execution_engine = ExecutionEngine(transform(module, conv_boiler))
+
+ # TODO: FFI-based solution to allow testing and printing with python code.
+ # Prepare arguments: one result i32.
+ # Arguments must be passed as pointers.
+ c_int_p = ctypes.c_int * 1
+ res = c_int_p(-1)
+ execution_engine.invoke("main", res)
+
+ log("RESULT: ", res[0])
+ # CHECK: RESULT: 8
+
+
+test_conv_generic()
diff --git a/mlir/test/python/ir/affine_map.py b/mlir/test/python/ir/affine_map.py
index d7bc098ffdc5a..da5d230f42cde 100644
--- a/mlir/test/python/ir/affine_map.py
+++ b/mlir/test/python/ir/affine_map.py
@@ -3,6 +3,7 @@
import gc
from mlir.ir import *
+
def run(f):
print("\nTEST:", f.__name__)
f()
@@ -21,6 +22,7 @@ def testAffineMapCapsule():
assert am2 == am1
assert am2.context is ctx
+
run(testAffineMapCapsule)
@@ -97,6 +99,7 @@ def testAffineMapGet():
# CHECK: number of results out of bounds
print(e)
+
run(testAffineMapGet)
@@ -117,6 +120,7 @@ def testAffineMapDerive():
map34 = map5.get_minor_submap(2)
print(map34)
+
run(testAffineMapDerive)
@@ -142,6 +146,7 @@ def testAffineMapProperties():
# CHECK: False
print(map3.is_projected_permutation)
+
run(testAffineMapProperties)
@@ -175,23 +180,22 @@ def testAffineMapExprs():
print(expr)
assert list(map3.results) == [d2, d0, d1]
+
run(testAffineMapExprs)
+
# CHECK-LABEL: TEST: testCompressUnusedSymbols
def testCompressUnusedSymbols():
with Context() as ctx:
- d0, d1, d2 = (
- AffineDimExpr.get(0),
- AffineDimExpr.get(1),
- AffineDimExpr.get(2))
- s0, s1, s2 = (
- AffineSymbolExpr.get(0),
- AffineSymbolExpr.get(1),
- AffineSymbolExpr.get(2))
+ d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
+ AffineDimExpr.get(2))
+ s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
+ AffineSymbolExpr.get(2))
maps = [
AffineMap.get(3, 3, [d2, d0, d1]),
AffineMap.get(3, 3, [d2, d0 + s2, d1]),
- AffineMap.get(3, 3, [d1, d2, d0])]
+ AffineMap.get(3, 3, [d1, d2, d0])
+ ]
compressed_maps = AffineMap.compress_unused_symbols(maps, ctx)
@@ -207,3 +211,29 @@ def testCompressUnusedSymbols():
run(testCompressUnusedSymbols)
+
+
+# CHECK-LABEL: TEST: testReplace
+def testReplace():
+ with Context() as ctx:
+ d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1),
+ AffineDimExpr.get(2))
+ s0, s1, s2 = (AffineSymbolExpr.get(0), AffineSymbolExpr.get(1),
+ AffineSymbolExpr.get(2))
+ map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0])
+
+ replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3)
+ replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3)
+ replace3 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2)
+
+ # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42)
+ print(replace0)
+
+ # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0)
+ print(replace1)
+
+ # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0)
+ print(replace3)
+
+
+run(testReplace)
More information about the Mlir-commits mailing list