[Mlir-commits] [mlir] 4121090 - [mlir][OpDSL] Restructure comprehension.py (NFC).
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 14 05:00:34 PST 2022
Author: gysit
Date: 2022-02-14T12:56:01Z
New Revision: 4121090893d5444514284028e1185ccd3c3abf7a
URL: https://github.com/llvm/llvm-project/commit/4121090893d5444514284028e1185ccd3c3abf7a
DIFF: https://github.com/llvm/llvm-project/commit/4121090893d5444514284028e1185ccd3c3abf7a.diff
LOG: [mlir][OpDSL] Restructure comprehension.py (NFC).
Group and reorder the classes defined by comprehension.py and add type annotations.
Depends On D119126
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D119692
Added:
Modified:
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 4513236b8703f..300ea08332a7f 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -8,7 +8,7 @@
represent actual op definitions (i.e. YAML).
"""
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from enum import Enum
from ..... import ir as _ir
@@ -17,6 +17,10 @@
from .types import *
from .yaml_helper import *
+###############################################################################
+# Tensor expression nodes.
+###############################################################################
+
class TensorExpression:
"""An expression that can appear on the RHS of a comprehension."""
@@ -24,19 +28,18 @@ class TensorExpression:
def to_scalar_expression(self) -> ScalarExpression:
raise NotImplementedError()
- def visit_tensor_exprs(self, callback):
+ def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
"""Visits all tensor expression reachable by the expression."""
callback(self)
def collect_dim_uses(self, uses: Set["DimDef"]):
"""Collects all DimDefs reachable through this expression."""
- results = set()
- def visit_dim_def(dim_def):
+ def visit_dim_def(dim_def: AffineExprDef):
if isinstance(dim_def, DimDef):
uses.add(dim_def)
- def visit_affine_exprs(expr):
+ def visit_affine_exprs(expr: "TensorExpression"):
if isinstance(expr, TensorUse):
for ind in expr.indices:
ind.visit_affine_exprs(visit_dim_def)
@@ -49,7 +52,7 @@ def visit_affine_exprs(expr):
def collect_tensor_uses(self, uses: Set["TensorUse"]):
"""Collects all TensorUses reachable through this expression."""
- def visit_tensor_use(expr):
+ def visit_tensor_use(expr: "TensorExpression"):
if isinstance(expr, TensorUse):
uses.add(expr)
@@ -58,7 +61,7 @@ def visit_tensor_use(expr):
def collect_indices(self, indices: Set["index"]):
"""Collects all index accesses reachable through this expression."""
- def visit_index(expr):
+ def visit_index(expr: "TensorExpression"):
if isinstance(expr, index):
indices.add(expr)
@@ -67,7 +70,7 @@ def visit_index(expr):
def collect_scalar_uses(self, uses: Set["ScalarDef"]):
"""Collects all ScalarDefs reachable through this expression."""
- def visit_scalar_def(expr):
+ def visit_scalar_def(expr: "TensorExpression"):
if isinstance(expr, ScalarDef):
uses.add(expr)
@@ -111,26 +114,261 @@ def tensor_name(self) -> str:
assert name is not None, "TensorDef not attached"
return name
- def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
- return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
-
def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
- """For implicit reductions, computes default reduction dims.
-
- Assumes that the rhs is the expression being reduced and self is being
- reduced into. Any indices referenced on the rhs and not in self are
- considered reduction dims and will be ordered as encountered on the rhs.
- """
+ # Computes the reduction dims for implicit reductions. Assumes that the rhs
+ # is the expression being reduced and self is being reduced into. Any
+ # indices referenced on the rhs and not in self are considered reduction
+ # dims and will be ordered as encountered on the rhs.
rhs_dims = set()
lhs_dims = set()
rhs.collect_dim_uses(rhs_dims)
self.collect_dim_uses(lhs_dims)
return rhs_dims - lhs_dims
+ def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
+ return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
+
def __repr__(self):
return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"
+class TensorArithFn(TensorExpression):
+ """Application of an arithmetic function."""
+
+ def __init__(self, arith_fn: "ArithFnType", args: Sequence[TensorExpression]):
+ self.arith_fn = arith_fn
+ self.args = tuple(args)
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ return ScalarArithFn(self.arith_fn.fn_name,
+ *[arg.to_scalar_expression() for arg in self.args
+ ]).expr()
+
+ def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+ super().visit_tensor_exprs(callback)
+ for arg in self.args:
+ arg.visit_tensor_exprs(callback)
+
+ def __repr__(self):
+ return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
+
+
+class TensorTypeFn(TensorExpression):
+ """Application of a type conversion function."""
+
+ def __init__(self, type_fn: "TypeFn", type_var: TypeVar,
+ arg: TensorExpression):
+ self.type_fn = type_fn
+ self.type_var = type_var
+ self.arg = arg
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ return ScalarTypeFn(self.type_fn.fn_name, self.type_var,
+ self.arg.to_scalar_expression()).expr()
+
+ def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+ super().visit_tensor_exprs(callback)
+ self.arg.visit_tensor_exprs(callback)
+
+ def __repr__(self):
+ return f"{repr(self.type_fn)}({self.type_var}, {self.arg})"
+
+
+class TensorReduceFn(TensorExpression):
+ """Application of a reduction function.
+
+ This captures the lhs (initial value) separately from the rhs.
+ """
+
+ def __init__(self, reduce_use: "ReduceFnUse",
+ args: Sequence[TensorExpression]):
+ self.reduce_use = reduce_use
+ self.lhs = None # type: Optional[TensorUse]
+ self.args = tuple(args)
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ if self.lhs is None:
+ raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
+ f"bound to its lhs: {self}")
+ full_args = [self.lhs.to_scalar_expression()
+ ] + [arg.to_scalar_expression() for arg in self.args]
+ return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
+
+ def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+ for arg in self.args:
+ arg.visit_tensor_exprs(callback)
+
+ def __repr__(self):
+ return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
+
+
+class const(TensorExpression):
+ """Returns the given constant floating point or integer value."""
+
+ def __init__(self, value: Any):
+ with _ir.Context():
+ if isinstance(value, float):
+ self.value = str(_ir.FloatAttr.get_f64(float(value)))
+ elif isinstance(value, int):
+ self.value = str(
+ _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
+ else:
+ raise ValueError(f"const requires int or float but got {type(value)}")
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ return ScalarConst(self.value).expr()
+
+ def __repr__(self):
+ return f"const({self.value})"
+
+
+class index(TensorExpression):
+ """Returns the iteration index for a given dimension name.
+
+ Resolves the given dimension name to obtain its position in the iteration
+ domain of the operation.
+ """
+
+ def __init__(self, dim: DimDef):
+ self.dim_def = dim
+ self.dim = -1
+
+ def resolve_dimension_name(self, affine_state: AffineBuildState):
+ self.dim = affine_state.get_dim(self.dim_def.dimname)
+
+ def to_scalar_expression(self) -> ScalarExpression:
+ assert self.dim != -1, "Dimension name not resolved"
+ return ScalarIndex(self.dim).expr()
+
+ def __repr__(self):
+ return f"index({repr(self.dim)})"
+
+
+###############################################################################
+# Function types and function definitions.
+###############################################################################
+
+
+class TypeFnType:
+ """Type conversion function.
+
+ A type conversion function takes a target type and a tensor expression and
+ returns the casted tensor expression.
+ """
+
+ def __init__(self, fn_name: str):
+ self.fn_name = fn_name
+
+ def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType":
+ return TensorTypeFn(self, type_var, arg)
+
+ def __repr__(self):
+ return f"{self.fn_name}"
+
+
+class TypeFn:
+ """Type conversion function namespace.
+
+ As the integer types are signless, signedness is implement by different cast
+ functions that treat integers as signed (`cast`) or unsigned
+ (`cast_unsigned`) values.
+
+ Examples:
+ - cast(I32 -> I64) -> `arith.ExtSIOp`
+ - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
+ """
+ cast = TypeFnType("cast")
+ cast_unsigned = TypeFnType("cast_unsigned")
+
+
+class ArithFnType:
+ """Arithmetic function.
+
+ An arithmetic function takes one ore more tensor expressions and returns the
+ function evaluation result.
+ """
+
+ def __init__(self, fn_name: str):
+ self.fn_name = fn_name
+
+ def __call__(self, *args) -> "TensorArithFn":
+ return TensorArithFn(self, args)
+
+ def __repr__(self):
+ return f"{self.fn_name}"
+
+
+class ArithFn:
+ """Arithmetic function namespace.
+
+ As the integer types are signless, signedness is implement by different
+ functions that treat integers as signed or unsigned values.
+
+ Examples:
+ - max -> `arith.MaxSIOp`
+ - max_unsinged -> `arith.MaxUIOp`
+ """
+ add = ArithFnType("add")
+ exp = ArithFnType("exp")
+ log = ArithFnType("log")
+ mul = ArithFnType("mul")
+ max = ArithFnType("max")
+ min = ArithFnType("min")
+ sub = ArithFnType("sub")
+ max_unsigned = ArithFnType("max_unsigned")
+ min_unsigned = ArithFnType("min_unsigned")
+
+
+class ReduceFnUse:
+ """Reduction function use.
+
+ A reduction use specifies the reduction function and dimensions.
+ """
+
+ def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef):
+ self.arith_fn = arith_fn
+ self.reduce_dims = reduce_dims
+
+ def __call__(self, *args: TensorExpression):
+ return TensorReduceFn(self, args)
+
+ def __repr__(self):
+ return (f"reduce_{self.arith_fn.fn_name}"
+ f"({', '.join(repr(d) for d in self.reduce_dims)})")
+
+
+class ReduceFnType:
+ """Reduction function.
+
+ An arithmetic function that reduces its RHS into its LHS.
+ """
+
+ def __init__(self, arith_fn: ArithFnType):
+ if not isinstance(arith_fn, ArithFnType):
+ raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}")
+ self.arith_fn = arith_fn
+
+ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+ return ReduceFnUse(self.arith_fn, *reduce_dims)
+
+ def __repr__(self):
+ return (f"reduce_{self.arith_fn.fn_name}")
+
+
+class ReduceFn:
+ add = ReduceFnType(ArithFn.add)
+ mul = ReduceFnType(ArithFn.mul)
+ max = ReduceFnType(ArithFn.max)
+ min = ReduceFnType(ArithFn.min)
+ max_unsigned = ReduceFnType(ArithFn.max_unsigned)
+ min_unsigned = ReduceFnType(ArithFn.min_unsigned)
+
+
+###############################################################################
+# Operand definitions.
+###############################################################################
+
+
class OperandKind(Enum):
InputTensor = 0
Scalar = 1
@@ -150,7 +388,7 @@ def __init__(self,
type_var: Optional[TypeVar] = None,
size_exprs: Optional[Sequence[AffineExprDef]] = None,
index_dims: Optional[Sequence[DimDef]] = None,
- default_vals : Optional[Sequence[int]] = None):
+ default_vals: Optional[Sequence[int]] = None):
if type_var and not isinstance(type_var, TypeVar):
raise ValueError(
f"OperandDef requires a TypeVar but got {repr(type_var)}")
@@ -206,7 +444,7 @@ def __init__(self,
self.operand_def = OperandDef(
kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
- def __getitem__(self, dims) -> TensorUse:
+ def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
assert self.operand_def.owner, "TensorDef is not attached to an op"
state = AffineBuildState(
global_state=self.operand_def.owner._affine_state,
@@ -225,7 +463,7 @@ def __getitem__(self, dims) -> TensorUse:
exprs.append(expr_def)
return TensorUse(self.operand_def, exprs)
- def __setitem__(self, dims, value):
+ def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
"""Creates a new 1:1 comprehension by binding this tensor to an expression.
Note that due to the way assignment works in Python, we have to capture
@@ -282,6 +520,11 @@ def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
OperandKind.IndexAttr, size_exprs=sizes, default_vals=default)
+###############################################################################
+# Operation definition.
+###############################################################################
+
+
class Comprehension:
"""Represents a single comprehension."""
@@ -320,232 +563,6 @@ def __repr__(self):
return f"{defs_repr} = {values_repr}"
-class TypeFnType:
- """Type conversion function.
-
- A type conversion function takes a target type and a tensor expression and
- returns the casted tensor expression.
- """
-
- def __init__(self, fn_name: str):
- self.fn_name = fn_name
-
- def __call__(self, type_var: TypeVar,
- arg: TensorExpression) -> "TensorTypeFn":
- return TensorTypeFn(self, type_var, arg)
-
- def __repr__(self):
- return f"{self.fn_name}"
-
-
-class TypeFn:
- """Type conversion function namespace.
-
- As the integer types are signless, signedness is implement by different cast
- functions that treat integers as signed (`cast`) or unsigned
- (`cast_unsigned`) values.
-
- Examples:
- - cast(I32 -> I64) -> `arith.ExtSIOp`
- - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
- """
- cast = TypeFnType("cast")
- cast_unsigned = TypeFnType("cast_unsigned")
-
-
-class ArithFnType:
- """Arithmetic function.
-
- An arithmetic function takes one ore more tensor expressions and returns the
- function evaluation result.
- """
-
- def __init__(self, fn_name: str):
- self.fn_name = fn_name
-
- def __call__(self, *args) -> "TensorArithFn":
- return TensorArithFn(self, args)
-
- def __repr__(self):
- return f"{self.fn_name}"
-
-
-class ArithFn:
- """Arithmetic function namespace.
-
- As the integer types are signless, signedness is implement by different
- functions that treat integers as signed or unsigned values.
-
- Examples:
- - max -> `arith.MaxSIOp`
- - max_unsinged -> `arith.MaxUIOp`
- """
- add = ArithFnType("add")
- exp = ArithFnType("exp")
- log = ArithFnType("log")
- mul = ArithFnType("mul")
- max = ArithFnType("max")
- min = ArithFnType("min")
- sub = ArithFnType("sub")
- max_unsigned = ArithFnType("max_unsigned")
- min_unsigned = ArithFnType("min_unsigned")
-
-
-class ReduceFnUse:
- """Reduction function use.
-
- A reduction use specifies the reduction function and dimensions.
- """
-
- def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef):
- self.arith_fn = arith_fn
- self.reduce_dims = reduce_dims
-
- def __call__(self, *args: TensorExpression):
- return TensorReduceFn(self, args)
-
- def __repr__(self):
- return (f"reduce_{self.arith_fn.fn_name}"
- f"({', '.join(repr(d) for d in self.reduce_dims)})")
-
-
-class ReduceFnType:
- """Reduction function.
-
- An arithmetic function that reduces its RHS into its LHS.
- """
-
- def __init__(self, arith_fn: ArithFnType):
- if not isinstance(arith_fn, ArithFnType):
- raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}")
- self.arith_fn = arith_fn
-
- def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
- return ReduceFnUse(self.arith_fn, *reduce_dims)
-
- def __repr__(self):
- return (f"reduce_{self.arith_fn.fn_name}")
-
-
-class ReduceFn:
- add = ReduceFnType(ArithFn.add)
- mul = ReduceFnType(ArithFn.mul)
- max = ReduceFnType(ArithFn.max)
- min = ReduceFnType(ArithFn.min)
- max_unsigned = ReduceFnType(ArithFn.max_unsigned)
- min_unsigned = ReduceFnType(ArithFn.min_unsigned)
-
-
-class TensorArithFn(TensorExpression):
- """Application of an arithmetic function."""
-
- def __init__(self, arith_fn: ArithFnType, args: Sequence[TensorExpression]):
- self.arith_fn = arith_fn
- self.args = tuple(args)
-
- def to_scalar_expression(self) -> ScalarExpression:
- return ScalarArithFn(self.arith_fn.fn_name,
- *[arg.to_scalar_expression() for arg in self.args
- ]).expr()
-
- def visit_tensor_exprs(self, callback):
- super().visit_tensor_exprs(callback)
- for arg in self.args:
- arg.visit_tensor_exprs(callback)
-
- def __repr__(self):
- return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
-
-
-class TensorTypeFn(TensorExpression):
- """Application of a type conversion function."""
-
- def __init__(self, type_fn: TypeFn, type_var: TypeVar, arg: TensorExpression):
- self.type_fn = type_fn
- self.type_var = type_var
- self.arg = arg
-
- def to_scalar_expression(self) -> ScalarExpression:
- return ScalarTypeFn(self.type_fn.fn_name, self.type_var,
- self.arg.to_scalar_expression()).expr()
-
- def visit_tensor_exprs(self, callback):
- super().visit_tensor_exprs(callback)
- self.arg.visit_tensor_exprs(callback)
-
- def __repr__(self):
- return f"{repr(self.type_fn)}({self.type_var}, {self.arg})"
-
-
-class const(TensorExpression):
- """Returns the given constant floating point or integer value."""
-
- def __init__(self, value: Any):
- with _ir.Context():
- if isinstance(value, float):
- self.value = str(_ir.FloatAttr.get_f64(float(value)))
- elif isinstance(value, int):
- self.value = str(
- _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
- else:
- raise ValueError(f"const requires int or float but got {type(value)}")
-
- def to_scalar_expression(self) -> ScalarExpression:
- return ScalarConst(self.value).expr()
-
- def __repr__(self):
- return f"const({self.value})"
-
-
-class index(TensorExpression):
- """Returns the iteration index for a given dimension name.
-
- Resolves the given dimension name to obtain its position in the iteration
- domain of the operation.
- """
-
- def __init__(self, dim: DimDef):
- self.dim_def = dim
- self.dim = -1
-
- def resolve_dimension_name(self, affine_state: AffineBuildState):
- self.dim = affine_state.get_dim(self.dim_def.dimname)
-
- def to_scalar_expression(self) -> ScalarExpression:
- assert self.dim != -1, "Dimension name not resolved"
- return ScalarIndex(self.dim).expr()
-
- def __repr__(self):
- return f"index({repr(self.dim)})"
-
-
-class TensorReduceFn(TensorExpression):
- """Application of a reduction function.
-
- This captures the lhs (initial value) separately from the rhs.
- """
-
- def __init__(self, reduce_use: ReduceFnUse, args: Sequence[TensorExpression]):
- self.reduce_use = reduce_use
- self.lhs = None # type: Optional[TensorUse]
- self.args = tuple(args)
-
- def to_scalar_expression(self) -> ScalarExpression:
- if self.lhs is None:
- raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
- f"bound to its lhs: {self}")
- full_args = [self.lhs.to_scalar_expression()
- ] + [arg.to_scalar_expression() for arg in self.args]
- return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
-
- def visit_tensor_exprs(self, callback):
- for arg in self.args:
- arg.visit_tensor_exprs(callback)
-
- def __repr__(self):
- return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
-
-
class OpInterfaceDef:
"""An interface that an op implements."""
More information about the Mlir-commits
mailing list