[Mlir-commits] [mlir] c2be2cd - [mlir][Python][Linalg] Adding const, capture, and index support to the OpDSL.
Tobias Gysi
llvmlistbot at llvm.org
Thu Apr 29 00:26:00 PDT 2021
Author: Tobias Gysi
Date: 2021-04-29T07:24:47Z
New Revision: c2be2cda8d268d4a0adbede149a20e3fd284f1d7
URL: https://github.com/llvm/llvm-project/commit/c2be2cda8d268d4a0adbede149a20e3fd284f1d7
DIFF: https://github.com/llvm/llvm-project/commit/c2be2cda8d268d4a0adbede149a20e3fd284f1d7.diff
LOG: [mlir][Python][Linalg] Adding const, capture, and index support to the OpDSL.
The patch extends the OpDSL with support for:
- Constant values
- Captured scalar parameters
- Access to the iteration indices using the index operation
- Predefined floating point and integer types
For now, the patch only supports emitting the new nodes; the C++/YAML path is not fully implemented yet. The fill_rng_2d operation defined in emit_structured_generic.py makes use of the new DSL constructs.
Differential Revision: https://reviews.llvm.org/D101364
Added:
Modified:
mlir/docs/Tools/LinalgOpDsl.md
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py
mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
Removed:
################################################################################
diff --git a/mlir/docs/Tools/LinalgOpDsl.md b/mlir/docs/Tools/LinalgOpDsl.md
index 140c2eab3ee2..4ef6fb1f7449 100644
--- a/mlir/docs/Tools/LinalgOpDsl.md
+++ b/mlir/docs/Tools/LinalgOpDsl.md
@@ -72,6 +72,34 @@ The docstring will be transferred to the op definition verbatim.
Special identifying op interfaces can be declared for the op via
`implements(interface1[, interface2...])`.
+## Parameters
+
+Structured operations can take two types of parameters, namely input/output
+tensors and captures. Assignment expressions index the tensor parameters to
+access the individual elements, while captures are scalars that can be
+accessed directly.
+
+The following example demonstrates the use of the two parameter types:
+
+```python
+@linalg_structured_op
+def copy_and_scale(I=TensorDef(T, S.M, S.K),
+                   O=TensorDef(T, S.M, S.K, output=True),
+                   val=CaptureDef(T)):
+  """Scale the input by the captured value and store the result"""
+  O[D.m, D.n] = I[D.m, D.n] * val
+```
+
+The operation scales the elements of the input tensor `I` by the value `val`
+and writes the result to the output tensor `O`. The capture `val` is bound to
+a `CaptureDef`, which specifies the type of the captured value. The tensors
+are bound to a `TensorDef` as demonstrated by the matmul example. All
+parameters appear in the parameter list of the operation:
+
+```python
+copy_and_scale(in_tensor, outs=[out_tensor], captures=[captured_val])
+```
+
## Assignments
The bulk of the language consists of assignment expressions of the form above.
@@ -99,22 +127,30 @@ Reduction functions can appear as the outer-most function on the RHS:
There are also special forms:
-* `cast(TypeVar, operand)`
+* `cast(TypeVar, operand)` casts the `operand` to the target type `TypeVar`.
+* `const(TypeVar, value)` returns a constant value of type `TypeVar`.
+* `index(dim)` returns the iteration index in the given dimension `dim`.
## Types
All types in assignment expressions are late bound based on actual input
-and output types of constructed ops. Assignment expressions with no `cast`
-calls will generally require uniform types throughout and will fail to
-verify if violated. The presence of a `cast` allows for a limited form of
-numeric type conversion between element types that can be derived from inputs
-and outputs (and in the future, attributes). `cast` calls with a `TypeVar`
-first argument are emitted as `symbolic_cast` primitives in the YAML definition.
-
-Casting will perform `int<->float` type conversions and will perform any
-necessary extension or truncation within type family. Note that presently,
-any integer type is assumed to be signed for the purpose of determing how to
-extend or truncate. Supporting unsigned integer types is left for future work.
+and output types of constructed ops. The exceptions are predefined types such
+as `I32`, `I64`, `F32`, and `F64`. These hardwired types enable intermediate
+computations with a type that is independent of the input and output types.
+For example, parts of a floating point computation may require double
+precision arithmetic even though all inputs and outputs are single precision.
+Assignment expressions with no `cast` calls will generally require uniform
+types throughout and will fail to verify if violated. The presence of a
+`cast` allows for a limited form of numeric type conversion between element
+types that can be derived from inputs and outputs (and in the future,
+attributes). `cast` calls with a `TypeVar` first argument are emitted as
+`symbolic_cast` primitives in the YAML definition.
+
+Casting will perform `int<->float` and `index->int` type conversions and will
+perform any necessary extension or truncation within a type family. Note that
+presently, any integer type is assumed to be signed for the purpose of
+determining how to extend or truncate. Supporting unsigned integer types is
+left for future work.
Not all functions are applicable for all numeric types, and on mismatch, op
verification will fail.
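To make the documented semantics concrete, the following plain-Python sketch spells out the loop nest that the `copy_and_scale` example describes, plus a hypothetical op body that combines `index` and `const`. It only models the meaning of the DSL, not the IR the emitter produces; the function names are invented for illustration, and `T` is assumed to be bound to an integer element type in the second function because the documented casts cover `index->int`.

```python
from typing import List


def copy_and_scale_reference(I: List[List[float]], val: float) -> List[List[float]]:
    """Loop-nest meaning of O[D.m, D.n] = I[D.m, D.n] * val."""
    rows, cols = len(I), len(I[0])
    O = [[0.0] * cols for _ in range(rows)]
    for m in range(rows):        # iteration dimension D.m
        for n in range(cols):    # iteration dimension D.n
            # The tensor I is indexed at (m, n); the capture val is read as-is.
            O[m][n] = I[m][n] * val
    return O


def index_sum_reference(rows: int, cols: int) -> List[List[int]]:
    """Loop-nest meaning of the hypothetical body
    O[D.m, D.n] = cast(T, index(D.m)) + cast(T, index(D.n)) + const(T, 1)
    with T bound to an integer element type.
    """
    O = [[0] * cols for _ in range(rows)]
    for m in range(rows):
        for n in range(cols):
            # index(...) yields the current loop counter of that dimension;
            # const(T, 1) contributes the same value at every point.
            O[m][n] = m + n + 1
    return O
```

For example, `copy_and_scale_reference([[1.0, 2.0], [3.0, 4.0]], 2.0)` yields `[[2.0, 4.0], [6.0, 8.0]]`, and `index_sum_reference(2, 3)` yields `[[1, 2, 3], [2, 3, 4]]`.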
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py
index 34a8d6d307f8..6db3bcfcc5b2 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py
@@ -232,7 +232,6 @@ class DimDef(AffineExprDef):
"""
ALL_DIMS = dict() # type: Dict[str, "DimDef"]
- dimname: str
def __new__(cls, dimname: str):
existing = cls.ALL_DIMS.get(dimname)
@@ -276,7 +275,6 @@ class SymbolDef(AffineExprDef):
True
"""
ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"]
- symname: str
def __new__(cls, symname: str):
existing = cls.ALL_SYMBOLS.get(symname)
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 85da3323cac6..9b93d33b3246 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -8,7 +8,7 @@
represent actual op definitions (i.e. YAML).
"""
-from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from mlir import ir as _ir
@@ -27,24 +27,49 @@ class TensorExpression:
def to_scalar_expression(self) -> ScalarExpression:
raise NotImplementedError()
- def visit_affine_exprs(self, callback):
- """Visits all affine expressions reachable by the expression."""
- pass
+ def visit_tensor_exprs(self, callback):
+ """Visits all tensor expressions reachable by the expression."""
+ callback(self)
def _get_all_dim_defs(self) -> Set[DimDef]:
"""Recursively gets all DimDef affine expressions that are referenced."""
results = set()
- def visitor(affine_expr):
- if isinstance(affine_expr, DimDef):
- results.add(affine_expr)
+ def visit_dim_def(dim_def):
+ if isinstance(dim_def, DimDef):
+ results.add(dim_def)
- self.visit_affine_exprs(visitor)
+ def visit_affine_exprs(expr):
+ if isinstance(expr, TensorUse):
+ for ind in expr.indices:
+ ind.visit_affine_exprs(visit_dim_def)
+ if isinstance(expr, ReduceApply):
+ for ind in expr.reduce.reduce_dims:
+ ind.visit_affine_exprs(visit_dim_def)
+
+ self.visit_tensor_exprs(visit_affine_exprs)
return results
def collect_uses(self, uses: Set["TensorUse"]):
"""Collects all TensorUses reachable through this expression."""
- pass
+ def visit_tensor_use(expr):
+ if isinstance(expr, TensorUse):
+ uses.add(expr)
+ self.visit_tensor_exprs(visit_tensor_use)
+
+ def collect_indices(self, indices: Set["index"]):
+ """Collects all index accesses reachable through this expression."""
+ def visit_index(expr):
+ if isinstance(expr, index):
+ indices.add(expr)
+ self.visit_tensor_exprs(visit_index)
+
+ def collect_captures(self, captures: Set["CaptureDef"]):
+ """Collects all CaptureDefs reachable through this expression."""
+ def visit_capture_def(expr):
+ if isinstance(expr, CaptureDef):
+ captures.add(expr)
+ self.visit_tensor_exprs(visit_capture_def)
def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
return PrimFn.add(self, rhs)
@@ -84,13 +109,6 @@ def tensor_name(self) -> str:
assert n is not None, "TensorDef not attached"
return n
- def visit_affine_exprs(self, callback):
- for ind in self.indices:
- ind.visit_affine_exprs(callback)
-
- def collect_uses(self, uses: Set["TensorUse"]):
- uses.add(self)
-
def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)
@@ -178,6 +196,35 @@ def __repr__(self):
return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, "
f"shape={self.shape})")
+class CaptureDef(TensorExpression):
+ """Defines an SSA value captured by the operation.
+
+ The captured SSA values are not indexed by the indexing_maps of the
+ structured op (as opposed to memrefs and tensors). A unique name
+ identifies the captures and an index determines their position in the
+ operation's parameter list.
+ """
+
+ def __init__(self, type_var: TypeVar):
+ if not isinstance(type_var, TypeVar):
+ raise ValueError(f"CaptureDef requires a TypeVar. Got: {repr(type_var)}")
+ self.owner = None # type: Optional["LinalgOpDef"]
+ self.type_var = type_var
+ self.capture_name = None # type: Optional[str]
+ self.registered_index = -1 # type: int
+
+ def attach(self, index: int, capture_name: str, owner: "LinalgOpDef"):
+ if self.owner:
+ raise ValueError(f"CaptureDef already registered with op: {self}")
+ self.registered_index = index
+ self.capture_name = capture_name
+ self.owner = owner
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ return ScalarCapture(self.capture_name).expr()
+
+ def __repr__(self):
+ return (f"{self.capture_name}:CaptureDef({repr(self.type_var)})")
class Comprehension:
"""Represents a single comprehension."""
@@ -279,17 +326,52 @@ def to_scalar_expression(self) -> ScalarExpression:
*[arg.to_scalar_expression() for arg in self.args
]).expr()
- def visit_affine_exprs(self, callback):
- for arg in self.args:
- arg.visit_affine_exprs(callback)
-
- def collect_uses(self, uses: Set["TensorUse"]):
+ def visit_tensor_exprs(self, callback):
+ super().visit_tensor_exprs(callback)
for arg in self.args:
- arg.collect_uses(uses)
+ arg.visit_tensor_exprs(callback)
def __repr__(self):
return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"
+class const(TensorExpression):
+ """Returns the given constant floating point or integer value."""
+
+ def __init__(self, type_var: TypeVar, value: Any):
+ if not isinstance(type_var, TypeVar):
+ raise ValueError(f"const requires a TypeVar. Got: {repr(type_var)}")
+ if not (isinstance(value, float) or isinstance(value, int)):
+ raise ValueError(f"const requires int or float. Got: {type(value)}")
+ self.type_var = type_var
+ self.value = value
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ return ScalarConst(self.type_var, self.value).expr()
+
+ def __repr__(self):
+ return f"const({self.type_var}, {self.value})"
+
+class index(TensorExpression):
+ """Returns the iteration index for a given dimension name.
+
+ Resolves the given dimension name to obtain its position in the iteration
+ domain of the operation.
+ """
+
+ def __init__(self, dim : DimDef):
+ self.dim_def = dim
+ self.dim = -1
+
+ def resolve_dimension_name(self, affine_state: AffineBuildState):
+ self.dim = affine_state.get_dim(self.dim_def.dimname)
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ assert self.dim != -1, "Dimension name not resolved"
+ return ScalarIndex(self.dim).expr()
+
+ def __repr__(self):
+ return f"index({repr(self.dim)})"
+
class cast(TensorExpression):
"""Casts the element type to a type (typically symbolic TypeVar)."""
@@ -302,11 +384,9 @@ def to_scalar_expression(self) -> ScalarExpression:
return ScalarSymbolicCast(self.to_type,
self.operand.to_scalar_expression()).expr()
- def visit_affine_exprs(self, callback):
- self.operand.visit_affine_exprs(callback)
-
- def collect_uses(self, uses: Set["TensorUse"]):
- self.operand.collect_uses(uses)
+ def visit_tensor_exprs(self, callback):
+ super().visit_tensor_exprs(callback)
+ self.operand.visit_tensor_exprs(callback)
def __repr__(self):
return f"cast({self.to_type}, {repr(self.operand)})"
@@ -331,15 +411,9 @@ def to_scalar_expression(self) -> ScalarExpression:
] + [arg.to_scalar_expression() for arg in self.args]
return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr()
- def visit_affine_exprs(self, callback):
- for ind in self.reduce.reduce_dims:
- ind.visit_affine_exprs(callback)
- for arg in self.args:
- arg.visit_affine_exprs(callback)
-
- def collect_uses(self, uses: Set["TensorUse"]):
+ def visit_tensor_exprs(self, callback):
for arg in self.args:
- arg.collect_uses(uses)
+ arg.visit_tensor_exprs(callback)
def __repr__(self):
return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})"
@@ -385,6 +459,7 @@ def __init__(self,
doc: Optional[str] = None):
self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc)
self.registered_tensors = dict() # type: Dict[str, TensorDef]
+ self.registered_captures = dict() # type: Dict[str, CaptureDef]
self.comprehensions = list() # type: List[Comprehension]
self._affine_state = AffineBuildState()
@@ -404,12 +479,13 @@ def add_tensor(self, tensor_name: str, tensor: TensorDef):
tensor.attach(len(self.registered_tensors), tensor_name, self)
self.registered_tensors[tensor_name] = tensor
- def tensor(self, name):
- """Gets a registered tensor by name."""
- try:
- return self.registered_tensors[name]
- except KeyError:
- raise KeyError(f"Tensor {name} is not registered")
+ def add_capture(self, capture_name: str, capture: CaptureDef):
+ """Registers a capture."""
+ if capture_name in self.registered_captures:
+ raise ValueError(f"Capture {capture_name} is already registered "
+ f"to {self.registered_captures[capture_name]}")
+ capture.attach(len(self.registered_captures), capture_name, self)
+ self.registered_captures[capture_name] = capture
def __repr__(self):
lines = [
@@ -417,6 +493,8 @@ def __repr__(self):
]
for name, tensor in self.registered_tensors.items():
lines.append(f" {tensor}")
+ for name, capture in self.registered_captures.items():
+ lines.append(f" {capture}")
if self.comprehensions:
lines[-1] += " {"
for comprehension in self.comprehensions:
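The comprehension.py changes replace the per-class `visit_affine_exprs` and `collect_uses` overrides with a single `visit_tensor_exprs` traversal: leaves report themselves through the base class, composite nodes report themselves and then recurse, and each `collect_*` method filters the visited nodes by type. The sketch below models that pattern with hypothetical stand-in classes (`Expr`, `Use`, `Capture`, `Apply`); unlike the patch, its generic `collect` returns a new set instead of filling a caller-provided one.

```python
from typing import Callable, Set


class Expr:
    """Hypothetical stand-in for TensorExpression."""

    def visit_tensor_exprs(self, callback: Callable[["Expr"], None]) -> None:
        # Default behaviour: report this node only (leaves such as uses,
        # captures, constants, and index accesses rely on it).
        callback(self)

    def collect(self, kind: type) -> Set["Expr"]:
        # One generic collector instead of per-class collect_* overrides.
        found: Set["Expr"] = set()

        def visit(expr: "Expr") -> None:
            if isinstance(expr, kind):
                found.add(expr)

        self.visit_tensor_exprs(visit)
        return found


class Use(Expr):
    """Hypothetical stand-in for TensorUse."""

    def __init__(self, name: str):
        self.name = name


class Capture(Expr):
    """Hypothetical stand-in for CaptureDef."""

    def __init__(self, name: str):
        self.name = name


class Apply(Expr):
    """Hypothetical stand-in for PrimApply."""

    def __init__(self, *args: Expr):
        self.args = args

    def visit_tensor_exprs(self, callback: Callable[[Expr], None]) -> None:
        # Report this node first, then recurse into the operands.
        super().visit_tensor_exprs(callback)
        for arg in self.args:
            arg.visit_tensor_exprs(callback)
```

With `e = Apply(i, Apply(v, j))`, `e.collect(Use)` returns `{i, j}` and `e.collect(Capture)` returns `{v}`. Adding a new leaf kind then only needs a new `isinstance` filter, not a new override in every node class.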
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
index fdc6cfd9bab0..a67d18cc37ad 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -70,6 +70,22 @@ def to_yaml_custom_dict(self):
def __repr__(self):
return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})"
+class CaptureDefConfig(YAMLObject):
+ """Wrapper around a CaptureDef."""
+ yaml_tag = "LinalgCaptureDef"
+
+ def __init__(self, capture_def: CaptureDef):
+ self.capture_def = capture_def
+
+ def to_yaml_custom_dict(self):
+ return dict(
+ name=self.capture_def.capture_name,
+ type_var=self.capture_def.type_var.name,
+ )
+
+ def __repr__(self):
+ return f"Def({self.capture_def})"
+
class LinalgIndexingMapsConfig(YAMLObject):
"""Abstracts the style of indexing maps that the op exports.
@@ -109,10 +125,14 @@ def __init__(self,
self.affine_state = AffineBuildState()
self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]]
self.tensor_args = dict() # type: Dict[TensorDef, TensorDefConfig]
+ self.capture_args = dict() # type: Dict[CaptureDef, CaptureDefConfig]
self.uses = dict() # type: Dict[TensorUse, TensorUseConfig]
- # Compute the ordered set of writes.
+ # Compute the ordered set of writes and collect the tensor, capture, and
+ # index uses.
collected_uses = set()
+ collected_captures = set()
+ collected_indices = set()
for write_use, read_use in zip(comprehension.definitions,
comprehension.values):
self.writes.append((write_use, read_use))
@@ -120,10 +140,14 @@ def __init__(self,
for write_use, read_use in self.writes:
collected_uses.add(write_use)
read_use.collect_uses(collected_uses)
+ read_use.collect_captures(collected_captures)
+ read_use.collect_indices(collected_indices)
# Need to add all definitions before uses, so process twice.
for use in collected_uses:
self.add_tensor_arg(use.tensor_def)
+ for capture in collected_captures:
+ self.add_capture_arg(capture)
for use in collected_uses:
self.add_use(use)
@@ -170,6 +194,14 @@ def __init__(self,
f"dims. Got: {all_reduction_dims}")
self.reduction_dims = next(iter(all_reduction_dims))
+ # Check that each index dimension exists and resolve it.
+ for index in collected_indices:
+ if index.dim_def.dimname not in self.affine_state.all_dims:
+ raise ValueError(
+ f"The dimension {index.dim_def.dimname} is not part of the iteration "
+ f"domain {self.affine_state.all_dims}")
+ index.resolve_dimension_name(self.affine_state)
+
# Generate the scalar assignments (used to build a body).
self.assignments = [
ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression())
@@ -186,6 +218,11 @@ def ordered_tensor_uses(self) -> Sequence[TensorUseConfig]:
return sorted(self.uses.values(),
key=lambda tuc: tuc.tensor_use.tensor_def.registered_index)
+ @property
+ def ordered_capture_args(self) -> Sequence[CaptureDefConfig]:
+ return sorted(self.capture_args.values(),
+ key=lambda cdc: cdc.capture_def.registered_index)
+
@property
def ordered_dims(self) -> Sequence[Tuple[str, int]]:
"""Gets the ordered list of dim bindings (symbolic name, position).
@@ -245,6 +282,12 @@ def add_use(self, tensor_use: TensorUse):
use_config = TensorUseConfig(tensor_use, indexing_map)
self.uses[tensor_use] = use_config
+ def add_capture_arg(self, capture_def: CaptureDef):
+ if capture_def in self.capture_args:
+ return
+ def_config = CaptureDefConfig(capture_def)
+ self.capture_args[capture_def] = def_config
+
def _normalize_affine_map(self,
affine_map: _ir.AffineMap,
with_dims: bool = True) -> _ir.AffineMap:
@@ -258,6 +301,7 @@ def _normalize_affine_map(self,
def to_yaml_custom_dict(self):
self_dict = dict(
args=self.ordered_tensor_args,
+ captures=self.ordered_capture_args,
# TODO: Refactor the hierarchy internally when supporting more
# than static (preserving this serialized form).
indexing_maps=LinalgIndexingMapsConfig(
@@ -272,6 +316,9 @@ def __repr__(self):
lines.append("tensor_args=[")
for def_config in self.ordered_tensor_args:
lines.append(f" {repr(def_config)}")
+ lines.append("], capture_args=[")
+ for def_config in self.ordered_capture_args:
+ lines.append(f" {repr(def_config)}")
lines.append("], indexing_maps=[")
for m in self.indexing_maps:
lines.append(f" {repr(m)}")
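Two steps in the config.py changes are easy to model in isolation: `ordered_capture_args` sorts the collected captures by `registered_index`, and the index check resolves every `index(dim)` against the dimensions of the iteration domain. The sketch below reproduces both with plain Python. `CaptureSketch` is hypothetical, and `all_dims` is assumed here to be a mapping from dimension name to position, which is how the patch uses `affine_state.all_dims` together with `get_dim`.

```python
from typing import Dict, List


class CaptureSketch:
    """Hypothetical stand-in for a CaptureDef after attach()."""

    def __init__(self, name: str, registered_index: int):
        self.name = name
        self.registered_index = registered_index


def ordered_captures(captures: List[CaptureSketch]) -> List[str]:
    # Captures are gathered through an unordered set; sorting by the
    # registration index restores the order of the op's parameter list.
    return [c.name for c in sorted(captures, key=lambda c: c.registered_index)]


def resolve_index_dims(dim_names: List[str], all_dims: Dict[str, int]) -> List[int]:
    # Map each dimension name used by index(...) to its position in the
    # iteration domain and reject names that the domain does not contain.
    positions = []
    for name in dim_names:
        if name not in all_dims:
            raise ValueError(f"The dimension {name} is not part of the iteration "
                             f"domain {all_dims}")
        positions.append(all_dims[name])
    return positions
```

For instance, `resolve_index_dims(["n", "m"], {"m": 0, "n": 1})` returns `[1, 0]`, while an unknown name such as `"k"` raises `ValueError`.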
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
index 002ae51ba1b0..428eadfe0168 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -105,11 +105,15 @@ def linalg_structured_op(dsl_func=None,
sig = inspect.signature(dsl_func)
for param_name, param in sig.parameters.items():
param_default = param.default
- if not isinstance(param_default, TensorDef):
+ if isinstance(param_default, TensorDef):
+ tc_model.add_tensor(param_name, param_default)
+ elif isinstance(param_default, CaptureDef):
+ tc_model.add_capture(param_name, param_default)
+ else:
raise ValueError(f"@tc_def_op function parameters must be defaulted as "
- f"TensorDef(...): Found {param_name}: {param_default}")
+ f"TensorDef(...) or CaptureDef(...): Found {param_name}"
+ f": {param_default}")
dsl_func_args.append(param_default)
- tc_model.add_tensor(param_name, param_default)
# Invoke the DSL func to finish populating the model.
with bind_op_def(tc_model):
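The dsl.py change dispatches each parameter of the decorated function on the type of its default value. Because `add_tensor` and `add_capture` each use the size of their own registry as the next index, tensors and captures are numbered independently. The sketch below mimics that with `inspect.signature`; `TensorDefSketch`, `CaptureDefSketch`, and `register_params` are hypothetical names, not part of the OpDSL.

```python
import inspect
from typing import Callable, Dict, Tuple


class TensorDefSketch:
    """Hypothetical stand-in for TensorDef."""


class CaptureDefSketch:
    """Hypothetical stand-in for CaptureDef."""


def register_params(dsl_func: Callable) -> Dict[str, Tuple[str, int]]:
    """Dispatches DSL parameters on the type of their default value.

    Tensors and captures are numbered independently, mirroring how
    add_tensor and add_capture use the size of their own registry.
    """
    tensors: Dict[str, int] = {}
    captures: Dict[str, int] = {}
    for name, param in inspect.signature(dsl_func).parameters.items():
        default = param.default
        if isinstance(default, TensorDefSketch):
            tensors[name] = len(tensors)
        elif isinstance(default, CaptureDefSketch):
            captures[name] = len(captures)
        else:
            raise ValueError(f"Parameters must be defaulted as TensorDef(...) or "
                             f"CaptureDef(...): found {name}: {default}")
    registered = {n: ("tensor", i) for n, i in tensors.items()}
    registered.update({n: ("capture", i) for n, i in captures.items()})
    return registered


def copy_and_scale(I=TensorDefSketch(), O=TensorDefSketch(), val=CaptureDefSketch()):
    """Mirrors the parameter list of the documented copy_and_scale op."""
```

Here `register_params(copy_and_scale)` gives `val` the capture index 0 even though it is the third parameter, which matches the separate `ins`/`outs` and `captures` arguments used when the op is invoked.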
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 682f19138701..4a037025d46a 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -2,7 +2,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Dict, Sequence
+from typing import Any, Dict, Sequence
from mlir.ir import *
from mlir.dialects import linalg
@@ -28,10 +28,20 @@ def isa(cls : Type, ty : Type):
def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
*ins: Value,
- outs: Value):
+ outs: Sequence[Value],
+ captures: Sequence[Value]):
all_arg_defs = op_config.ordered_tensor_args
in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"]
+ capture_arg_defs = op_config.ordered_capture_args
+
+ # Verify outs and captures are sequences.
+ if not isinstance(outs, Sequence):
+ raise ValueError(f"Expected named argument outs to have type Sequence "
+ f"but got {type(outs)}")
+ if not isinstance(captures, Sequence):
+ raise ValueError(f"Expected named argument captures to have type Sequence "
+ f"but got {type(captures)}")
# Arity validation.
if len(ins) != len(in_arg_defs):
@@ -40,19 +50,35 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
if outs and len(outs) != len(out_arg_defs):
raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
f"{len(outs)} for {op_config}")
+ if captures and len(captures) != len(capture_arg_defs):
+ raise ValueError(f"Expected {len(capture_arg_defs)} captures but got "
+ f"{len(captures)} for {op_config}")
outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
out_arg_defs, outs)
result_types = [t for t in out_types if isa(RankedTensorType, t)]
- # Extract type vars for input/output based types.
+ # Initialize the type dictionary with the predefined types.
type_mapping = dict() # type: Dict[str, Type]
+ type_mapping["F32"] = F32Type.get()
+ type_mapping["F64"] = F64Type.get()
+ type_mapping["I32"] = IntegerType.get_signless(32)
+ type_mapping["I64"] = IntegerType.get_signless(64)
+
+ # Extract type vars for input/output based types.
for arg_def, arg_element_type in zip(
in_arg_defs + out_arg_defs,
_get_shaped_element_types_from_values(*ins, *outs)):
- tv_name = arg_def.tensor_def.type_var.name
- type_mapping[tv_name] = arg_element_type
+ _add_type_mapping(arg_def.tensor_def.type_var.name, arg_element_type,
+ type_mapping)
+
+ # Extract type vars for captures and compute capture argument mapping.
+ capture_arg_mapping = dict() # type: Dict[str, Value]
+ for arg_def, capture_value in zip(capture_arg_defs, captures):
+ _add_type_mapping(arg_def.capture_def.type_var.name, capture_value.type,
+ type_mapping)
+ capture_arg_mapping[arg_def.capture_def.capture_name] = capture_value
# Emit the generic op.
# TODO: Support emission of pure memref form.
@@ -63,21 +89,22 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig,
for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, Context.current)])
iterator_types_attr = ArrayAttr.get(
[StringAttr.get(s) for s in op_config.iterator_types])
- sparse_attr = ArrayAttr.get(
- [BoolAttr.get(False) for s in list(ins) + list(outs) if isa(RankedTensorType, s.type)])
- if len(sparse_attr) == 0:
- sparse_attr = None
+ # TODO: Add support for sparse operands once there is a stable interface.
+ sparse_attr = None
return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types,
- type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr)
+ type_mapping, capture_arg_mapping, indexing_maps_attr,
+ iterator_types_attr, sparse_attr)
def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
*ins: Value,
- outs: Value = ()):
- all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
- type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
- prepare_common_structured_op(op_config, *ins, outs = outs)
+ outs: Sequence[Value] = (),
+ captures: Sequence[Value] = ()):
+ all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
+ capture_arg_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
+ prepare_common_structured_op(op_config, *ins, outs = outs,
+ captures=captures)
generic_op = linalg.GenericOp(
result_tensors=result_types,
@@ -95,7 +122,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
block = generic_op.regions[0].blocks.append(*block_arg_types)
block_arg_mapping = dict(zip(block_arg_names, block.arguments))
with InsertionPoint(block):
- body_builder = _BodyBuilder(type_mapping, block_arg_mapping)
+ body_builder = _BodyBuilder(type_mapping, block_arg_mapping,
+ capture_arg_mapping)
for assignment in op_config.assignments:
body_builder.assign(assignment)
body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs))
@@ -110,10 +138,12 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
op_name: str,
op_class_name: str,
*ins: Value,
- outs: Value = ()):
- all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \
- type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
- prepare_common_structured_op(op_config, *ins, outs = outs)
+ outs: Sequence[Value] = (),
+ captures: Sequence[Value] = ()):
+ all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \
+ capture_arg_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \
+ prepare_common_structured_op(op_config, *ins, outs = outs,
+ captures = captures)
# If we get here, there must exist a builtin class `op_class_name`.
ctx = Context.current
@@ -127,7 +157,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
linalgDialect = ctx.get_dialect_descriptor("linalg")
fill_builtin_region(linalgDialect, named_op.operation)
# Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps
- # attribute that the non-yaml path does not. The non-yaml path hardcodes the
+ # attribute that the non-yaml path does not. The non-yaml path hardcodes the
# indexing_maps in C++ directly.
named_op.operation.attributes["linalg.memoized_indexing_maps"] = indexing_maps_attr
# iterator_types are hardcoded in C++ both in the yaml and non-yaml path.
@@ -141,10 +171,13 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
class _BodyBuilder:
"""Constructs a structured op body by evaluating assignments."""
- def __init__(self, type_mapping: Dict[str, Type],
- block_arg_mapping: Dict[str, Value]):
+ def __init__(self,
+ type_mapping: Dict[str, Type],
+ block_arg_mapping: Dict[str, Value],
+ capture_arg_mapping: Dict[str, Value]):
self.type_mapping = type_mapping
self.block_arg_mapping = block_arg_mapping
+ self.capture_arg_mapping = capture_arg_mapping
self.yield_mapping = dict() # type: Dict[str, Value]
def assign(self, assignment: ScalarAssign):
@@ -161,6 +194,16 @@ def expression(self, expr: ScalarExpression) -> Value:
except KeyError:
raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for "
f"this structured op.")
+ elif expr.scalar_capture:
+ try:
+ return self.capture_arg_mapping[expr.scalar_capture.capture]
+ except KeyError:
+ raise ValueError(f"Capture {expr.scalar_capture.capture} is not bound for "
+ f"this structured op.")
+ elif expr.scalar_const:
+ return self.constant(expr.scalar_const.type_var.name, expr.scalar_const.value)
+ elif expr.scalar_index:
+ return self.index(expr.scalar_index.dim)
elif expr.scalar_apply:
try:
fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}")
@@ -177,6 +220,25 @@ def expression(self, expr: ScalarExpression) -> Value:
return self.cast(expr.symbolic_cast.to_type.name, operand_value)
raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
+ def constant(self, type_var_name: str, value: Any) -> Value:
+ try:
+ type = self.type_mapping[type_var_name]
+ except KeyError:
+ raise ValueError(f"Unbound type variable '{type_var_name}' ("
+ f"expected one of {self.type_mapping.keys()})")
+ try:
+ if _is_floating_point_type(type):
+ return std.ConstantOp(type, FloatAttr.get(type, float(value))).result
+ elif _is_integer_type(type):
+ return std.ConstantOp(type, IntegerAttr.get(type, int(value))).result
+ except ValueError:
+ raise ValueError(f"Unable to cast value {value} to type {type}")
+ raise NotImplementedError(f"Unimplemented constant type {type}")
+
+ def index(self, dim: int) -> Value:
+ dim_attr = IntegerAttr.get(IntegerType.get_signless(64), dim)
+ return linalg.IndexOp(IndexType.get(), dim_attr).result
+
def cast(self, type_var_name: str, operand: Value) -> Value:
try:
to_type = self.type_mapping[type_var_name]
@@ -189,15 +251,13 @@ def cast(self, type_var_name: str, operand: Value) -> Value:
return self._cast_to_integer(to_type, operand)
elif _is_floating_point_type(to_type):
return self._cast_to_floating_point(to_type, operand)
-
- raise ValueError(f"Unable to cast body expression from {operand.type} to "
- f"{to_type}")
-
def _cast_to_integer(self, to_type: Type, operand: Value) -> Value:
to_width = IntegerType(to_type).width
operand_type = operand.type
if _is_floating_point_type(operand_type):
return std.FPToSIOp(to_type, operand).result
+ if _is_index_type(operand_type):
+ return std.IndexCastOp(to_type, operand).result
# Assume integer.
from_width = IntegerType(operand_type).width
if to_width > from_width:
@@ -234,14 +294,21 @@ def yield_outputs(self, *output_names: str):
def _eval_add(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return std.AddFOp(lhs.type, lhs, rhs).result
- if _is_integer_type(lhs.type):
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return std.AddIOp(lhs.type, lhs, rhs).result
raise NotImplementedError(f"Unsupported 'add' operand: {lhs}")
+ def _eval_sub(self, lhs: Value, rhs: Value) -> Value:
+ if _is_floating_point_type(lhs.type):
+ return std.SubFOp(lhs.type, lhs, rhs).result
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+ return std.SubIOp(lhs.type, lhs, rhs).result
+ raise NotImplementedError(f"Unsupported 'sub' operand: {lhs}")
+
def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return std.MulFOp(lhs.type, lhs, rhs).result
- if _is_integer_type(lhs.type):
+ if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return std.MulIOp(lhs.type, lhs, rhs).result
raise NotImplementedError(f"Unsupported 'mul' operand: {lhs}")
@@ -281,6 +348,12 @@ def _get_tensor_def_names(
*tensor_def_configs: TensorDefConfig) -> Sequence[str]:
return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs]
+def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]):
+ if name in type_mapping:
+ if type_mapping[name] != type:
+ raise ValueError(f"Cannot overwrite type mapping {name} = "
+ f"{type_mapping[name]} by type {type}")
+ type_mapping[name] = type
def _is_floating_point_type(t: Type) -> bool:
# TODO: Create a FloatType in the Python API and implement the switch
@@ -288,10 +361,11 @@ def _is_floating_point_type(t: Type) -> bool:
return (F64Type.isinstance(t) or F32Type.isinstance(t) or
F16Type.isinstance(t) or BF16Type.isinstance(t))
-
def _is_integer_type(t: Type) -> bool:
return IntegerType.isinstance(t)
+def _is_index_type(t: Type) -> bool:
+ return IndexType.isinstance(t)
def _get_floating_point_width(t: Type) -> int:
# TODO: Create a FloatType in the Python API and implement the switch
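`_add_type_mapping` lets a type name be bound more than once only if every binding agrees. The sketch below reproduces that logic with plain strings standing in for MLIR `Type` objects (an assumption made so it runs without the bindings). This means a TypeVar such as `T` that is bound by one operand cannot be silently rebound to a different type by another.

```python
from typing import Dict


def add_type_mapping(name: str, type_: str, type_mapping: Dict[str, str]) -> None:
    """Bind `name` to `type_`, rejecting a conflicting rebinding."""
    if name in type_mapping:
        if type_mapping[name] != type_:
            raise ValueError(f"Cannot overwrite type mapping {name} = "
                             f"{type_mapping[name]} by type {type_}")
    type_mapping[name] = type_


mapping: Dict[str, str] = {}
add_type_mapping("T", "i32", mapping)  # first binding
add_type_mapping("T", "i32", mapping)  # consistent rebinding is accepted
```

A subsequent `add_type_mapping("T", "f32", mapping)` raises `ValueError`, because `T` is already bound to `i32`.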
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
index 9ebf7a9a0fb0..bb1938d71f07 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py
@@ -13,7 +13,7 @@
can be easily consumed from the C++ side, not necessarily for ergonomics.
"""
-from typing import Optional, Sequence
+from typing import Any, Optional, Sequence
from .yaml_helper import *
from .types import *
@@ -22,6 +22,9 @@
"ScalarAssign",
"ScalarApplyFn",
"ScalarArg",
+ "ScalarCapture",
+ "ScalarConst",
+ "ScalarIndex",
"ScalarExpression",
"ScalarSymbolicCast",
]
@@ -53,6 +56,42 @@ def expr(self) -> "ScalarExpression":
def __repr__(self):
return f"(ScalarArg({self.arg})"
+class ScalarCapture:
+ """A type of ScalarExpression that references a named capture."""
+
+ def __init__(self, capture: str):
+ self.capture = capture
+
+ def expr(self) -> "ScalarExpression":
+ return ScalarExpression(scalar_capture=self)
+
+ def __repr__(self):
+ return f"(ScalarCapture({self.capture})"
+
+class ScalarConst:
+ """A type of ScalarExpression representing a constant."""
+
+ def __init__(self, type_var: TypeVar, value: Any):
+ self.type_var = type_var
+ self.value = value
+
+ def expr(self) -> "ScalarExpression":
+ return ScalarExpression(scalar_const=self)
+
+ def __repr__(self):
+ return f"(ScalarConst({self.type_var}, {self.value})"
+
+class ScalarIndex:
+ """A type of ScalarExpression accessing an iteration index."""
+
+ def __init__(self, dim : int):
+ self.dim = dim
+
+ def expr(self) -> "ScalarExpression":
+ return ScalarExpression(scalar_index=self)
+
+ def __repr__(self):
+ return f"(ScalarIndex({self.dim})"
class ScalarSymbolicCast:
"""A type of ScalarExpression that symbolically casts an operand to a TypeVar.
@@ -75,6 +114,9 @@ class ScalarExpression(YAMLObject):
Can be one of:
- ScalarApplyFn
- ScalarArg
+ - ScalarCapture
+ - ScalarConst
+ - ScalarIndex
- ScalarSymbolicCast
"""
yaml_tag = "!ScalarExpression"
@@ -82,13 +124,20 @@ class ScalarExpression(YAMLObject):
def __init__(self,
scalar_apply: Optional[ScalarApplyFn] = None,
scalar_arg: Optional[ScalarArg] = None,
+ scalar_capture: Optional[ScalarCapture] = None,
+ scalar_const: Optional[ScalarConst] = None,
+ scalar_index: Optional[ScalarIndex] = None,
symbolic_cast: Optional[ScalarSymbolicCast] = None):
- if (bool(scalar_apply) + bool(scalar_arg) + bool(symbolic_cast)) != 1:
+ if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_capture) +
+ bool(scalar_const) + bool(scalar_index) + bool(symbolic_cast)) != 1:
raise ValueError(
- "One of 'scalar_apply', 'scalar_block_arg', 'symbolic_cast' must be "
- "specified")
+ "One of 'scalar_apply', 'scalar_arg', 'scalar_capture', 'scalar_const', "
+ "'scalar_index', 'symbolic_cast' must be specified")
self.scalar_apply = scalar_apply
self.scalar_arg = scalar_arg
+ self.scalar_capture = scalar_capture
+ self.scalar_const = scalar_const
+ self.scalar_index = scalar_index
self.symbolic_cast = symbolic_cast
def to_yaml_custom_dict(self):
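`ScalarExpression` behaves like a tagged union: exactly one of its optional fields must be set. The trimmed-down sketch below uses a hypothetical `OneOfExpr` class that keeps only three of the fields. It counts set fields with `is not None` because its payloads are raw values, and `bool(0)` would miss an index of dimension 0. The real class avoids that pitfall by wrapping payloads in `ScalarIndex`, `ScalarCapture`, and similar objects. These define no `__bool__` and are therefore always truthy.

```python
from typing import Optional


class OneOfExpr:
    """Hypothetical analogue of ScalarExpression: exactly one field is set."""

    def __init__(self,
                 scalar_arg: Optional[str] = None,
                 scalar_capture: Optional[str] = None,
                 scalar_index: Optional[int] = None):
        # Count with "is not None": a raw int payload such as dim 0 is falsy,
        # so summing bool() values would miscount it.
        num_set = sum(v is not None
                      for v in (scalar_arg, scalar_capture, scalar_index))
        if num_set != 1:
            raise ValueError("One of 'scalar_arg', 'scalar_capture', "
                             "'scalar_index' must be specified")
        self.scalar_arg = scalar_arg
        self.scalar_capture = scalar_capture
        self.scalar_index = scalar_index
```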
@@ -99,6 +148,13 @@ def to_yaml_custom_dict(self):
))
elif self.scalar_arg:
return dict(scalar_arg=self.scalar_arg.arg)
+ elif self.scalar_capture:
+ return dict(scalar_capture=self.scalar_capture.capture)
+ elif self.scalar_const:
+ return dict(scalar_const=dict(type_var=self.scalar_const.type_var.name,
+ attributes=[self.scalar_const.value]))
+ elif self.scalar_index:
+ return dict(scalar_index=self.scalar_index.dim)
elif self.symbolic_cast:
# Note that even though operands must be arity 1, we write it the
# same way as for apply because it allows handling code to be more
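The new `to_yaml_custom_dict` branches serialize each node as a single-key dict. A capture becomes its name, an index becomes its dimension, and a constant becomes its type variable name plus a one-element `attributes` list. The sketch below reproduces those shapes with hypothetical dataclasses (`Capture`, `Const`, `Index`) in place of the real `ScalarCapture`/`ScalarConst`/`ScalarIndex`. It takes the type variable name directly rather than reading `type_var.name`.

```python
from dataclasses import dataclass
from typing import Any, Union


@dataclass
class Capture:  # hypothetical stand-in for ScalarCapture
    capture: str


@dataclass
class Const:  # hypothetical stand-in for ScalarConst
    type_var_name: str
    value: Any


@dataclass
class Index:  # hypothetical stand-in for ScalarIndex
    dim: int


def to_yaml_custom_dict(node: Union[Capture, Const, Index]) -> dict:
    """Mirror the three new branches of ScalarExpression.to_yaml_custom_dict."""
    if isinstance(node, Capture):
        return dict(scalar_capture=node.capture)
    if isinstance(node, Const):
        return dict(scalar_const=dict(type_var=node.type_var_name,
                                      attributes=[node.value]))
    if isinstance(node, Index):
        return dict(scalar_index=node.dim)
    raise TypeError(f"unexpected node: {node!r}")
```

For instance, `const(I32, 12345)` from `fill_rng_2d` would correspond to `Const("I32", 12345)` here, assuming the TypeVar's name is `I32`.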
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py
index 35bbfe712541..ddac87287e61 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py
@@ -22,6 +22,12 @@
"TypeVar",
"TV",
+ # Predefined types.
+ "I32",
+ "I64",
+ "F32",
+ "F64",
+
# TypeVar aliases.
"T",
"U",
@@ -63,6 +69,12 @@ def __getattr__(self, n):
# Expando access via TV.foo
TV = TypeVar.create_expando()
+# Predefined types.
+I32 = TV.I32
+I64 = TV.I64
+F32 = TV.F32
+F64 = TV.F64
+
# Some common type name aliases.
T = TV.T
U = TV.U
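The predefined `I32`, `I64`, `F32`, and `F64` are obtained through the same `TV` expando as `T` and `U`. They therefore appear to be symbolic TypeVars with fixed names rather than concrete MLIR types. The simplified sketch below illustrates the expando pattern. The `_registry` interning cache is an assumption of this sketch and is not taken from `types.py`.

```python
from typing import Dict


class TypeVar:
    """Hypothetical, simplified TypeVar; the real class lives in types.py."""

    # Interning cache (assumption of this sketch): one TypeVar per name.
    _registry: Dict[str, "TypeVar"] = {}

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return f"TypeVar({self.name})"

    @staticmethod
    def create(name: str) -> "TypeVar":
        if name not in TypeVar._registry:
            TypeVar._registry[name] = TypeVar(name)
        return TypeVar._registry[name]

    @staticmethod
    def create_expando():
        class ExpandoTypeVars:
            # Any attribute access, e.g. TV.I32, looks up or creates a TypeVar.
            def __getattr__(self, n: str) -> "TypeVar":
                return TypeVar.create(n)

        return ExpandoTypeVars()


TV = TypeVar.create_expando()

# Predefined names are ordinary TypeVars, just like the T and U aliases.
I32 = TV.I32
F64 = TV.F64
T = TV.T
```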
diff --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
index 5445daefa49f..91274dd79740 100644
--- a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -23,6 +23,18 @@ def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
C=TensorDef(U, S.M, S.N, output=True)):
C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+@linalg_structured_op
+def fill_rng_2d(A=TensorDef(T, S.M, S.N, output=True),
+ min=CaptureDef(F64),
+ max=CaptureDef(F64),
+ seed=CaptureDef(I32)):
+ multiplier = const(I32, 1103515245)
+ increment = const(I32, 12345)
+ temp1 = (cast(I32, index(D.m)) + seed) * multiplier + increment
+ temp2 = (cast(I32, index(D.n)) + temp1) * multiplier + increment
+ inv_randmax = const(F64, 2.3283064e-10)
+ scaling = (max - min) * inv_randmax
+ A[D.m, D.n] = cast(T, cast(F64, temp2) * scaling + min)
with Context() as ctx, Location.unknown():
module = Module.create()
@@ -142,5 +154,27 @@ def test_f16f16f32_matmul(lhs, rhs, init_result):
def test_f64f64f32_matmul(lhs, rhs, init_result):
return matmul_poly(lhs, rhs, outs=[init_result])
+ # CHECK-LABEL: @test_fill_rng_2d
+ # CHECK-SAME: %{{.*}} tensor<4x16xi32>, %[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32
+ # CHECK-DAG: %[[IDX0:.+]] = linalg.index 0 : index
+ # CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
+ # CHECK-DAG: %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
+ # CHECK-DAG: %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
+ # CHECK-DAG: %[[RND0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
+ # CHECK-DAG: %[[CST0:.+]] = constant 1103515245 : i32
+ # CHECK-DAG: %[[CST1:.+]] = constant 12345 : i32
+ # CHECK-DAG: %[[RND1:.+]] = muli %[[RND0]], %[[CST0]] : i32
+ # CHECK-DAG: %[[RND2:.+]] = addi %[[RND1]], %[[CST1]] : i32
+ # CHECK: %[[RND3:.+]] = sitofp %{{.*}} : i32 to f64
+ # CHECK-DAG: %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
+ # CHECK-DAG: %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64
+ # CHECK-DAG: %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64
+ # CHECK-DAG: %[[RND4:.+]] = mulf %[[RND3]], %[[FACT]] : f64
+ # CHECK-DAG: %[[RND5:.+]] = addf %[[RND4]], %[[MIN]] : f64
+ # CHECK-DAG: %{{.*}} = fptosi %[[RND5]] : f64 to i32
+ @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32),
+ f64, f64, i32)
+ def test_fill_rng_2d(init_result, min, max, seed):
+ return fill_rng_2d(outs=[init_result], captures=[min, max, seed])
print(module)
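`fill_rng_2d` combines the new constructs into a linear congruential generator. It seeds the state from the iteration indices and the `seed` capture using `i32` arithmetic, then scales the result by `(max - min)` times a constant of roughly 2^-32. The following pure-Python sketch reproduces that arithmetic. It assumes `T = i32`, as in `test_fill_rng_2d`, and it follows the CHECK lines: `sitofp` for the signed i32-to-f64 conversion and `fptosi` (truncation toward zero) for the final cast. The helper names are hypothetical.

```python
MULTIPLIER = 1103515245
INCREMENT = 12345
INV_RANDMAX = 2.3283064e-10  # approximately 2**-32


def to_i32(x: int) -> int:
    """Wrap a Python int to a signed 32-bit value (two's complement)."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x


def fill_rng_value(m: int, n: int, min_val: float, max_val: float,
                   seed: int) -> float:
    """One element of fill_rng_2d before the final cast to T."""
    # i32 addi/muli wrap modulo 2**32, emulated by to_i32 after each op.
    temp1 = to_i32(to_i32(to_i32(m + seed) * MULTIPLIER) + INCREMENT)
    temp2 = to_i32(to_i32(to_i32(n + temp1) * MULTIPLIER) + INCREMENT)
    scaling = (max_val - min_val) * INV_RANDMAX
    # float(temp2) treats temp2 as signed, like sitofp.
    return float(temp2) * scaling + min_val


def fill_rng_2d(rows: int, cols: int, min_val: float, max_val: float,
                seed: int) -> list:
    # int() truncates toward zero, matching fptosi for an i32 output.
    return [[int(fill_rng_value(m, n, min_val, max_val, seed))
             for n in range(cols)] for m in range(rows)]


grid = fill_rng_2d(4, 16, 0.0, 100.0, seed=42)
```

The generator is deterministic, so the same indices, bounds, and seed always reproduce the same tensor contents.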