[Mlir-commits] [mlir] e3b442b - [mlir][OpDSL] Separate `ReduceFn` and `ReduceFnUse`.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 7 04:57:37 PST 2022
Author: gysit
Date: 2022-01-07T12:51:06Z
New Revision: e3b442b62f44491514b10b0dd3949b3259ce80f3
URL: https://github.com/llvm/llvm-project/commit/e3b442b62f44491514b10b0dd3949b3259ce80f3
DIFF: https://github.com/llvm/llvm-project/commit/e3b442b62f44491514b10b0dd3949b3259ce80f3.diff
LOG: [mlir][OpDSL] Separate `ReduceFn` and `ReduceFnUse`.
The revision distinguishes `ReduceFn` and `ReduceFnUse`. The latter has the reduction dimensions attached while the former specifies the arithmetic function only. This separation allows us to adapt the reduction syntax a little bit and specify the reduction dimensions using square brackets (in contrast to the round brackets used for the values to reduce). It also is a preparation to add reduction function attributes to OpDSL. A reduction function attribute shall only specify the arithmetic function and not the reduction dimensions.
Example:
```
ReduceFn.max_unsigned(D.kh, D.kw)(...)
```
changes to:
```
ReduceFn.max_unsigned[D.kh, D.kw](...)
```
Depends On D115240
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D115241
Added:
Modified:
mlir/docs/Dialects/Linalg/OpDSL.md
mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index 4a08f0bc4d389..79f22a247bb27 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -192,11 +192,18 @@ A number of arithmetic functions are supported:
As the integer types are signless, signedness is implement by different
functions that treat integers as signed or unsigned values.
-Reduction functions can appear as the outer-most function on the RHS:
+A subset of the arithmetic functions are supported in reductions. These
+reduction functions can appear as the outermost function on the RHS:
* `ReduceFn.add` (also overloading the inplace `+=` on a LHS)
* `ReduceFn.mul`
* `ReduceFn.max`
+* `ReduceFn.min`
+* `ReduceFn.max_unsigned`
+* `ReduceFn.min_unsigned`
+
+As the integer types are signless, signedness is implement by different
+functions that treat integers as signed or unsigned values.
Additionally, type conversion functions cast an operand to a target type:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index fd0fa72661909..ddbebb29fd6f0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -43,8 +43,8 @@ def visit_affine_exprs(expr):
if isinstance(expr, TensorUse):
for ind in expr.indices:
ind.visit_affine_exprs(visit_dim_def)
- if isinstance(expr, ReduceApply):
- for ind in expr.reduce.reduce_dims:
+ if isinstance(expr, TensorReduceFn):
+ for ind in expr.reduce_fn.reduce_dims:
ind.visit_affine_exprs(visit_dim_def)
self.visit_tensor_exprs(visit_affine_exprs)
@@ -114,8 +114,8 @@ def tensor_name(self) -> str:
assert name is not None, "TensorDef not attached"
return name
- def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
- return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)
+ def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
+ return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
"""For implicit reductions, computes default reduction dims.
@@ -285,7 +285,7 @@ def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
# Find the lhs to reduction rhs.
for assign, value in bindings:
- if isinstance(value, ReduceApply):
+ if isinstance(value, TensorReduceFn):
if value.lhs:
raise ValueError(f"Reduction expression already assigns: {value}")
value.lhs = assign
@@ -297,8 +297,8 @@ def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
"""Gets the reduction dims for the comprehension or None."""
result = set()
for use in self.values:
- if isinstance(use, ReduceApply):
- result.add(use.reduce.reduce_dims)
+ if isinstance(use, TensorReduceFn):
+ result.add(use.reduce_use.reduce_dims)
else:
result.add(tuple())
return result
@@ -360,10 +360,6 @@ def __init__(self, fn_name: str):
def __call__(self, *args) -> "TensorArithFn":
return TensorArithFn(self, args)
- def reduce(self, *reduce_dims: DimDef):
- """Shortcut to create a Reduce operation from this function."""
- return ReduceFnType(self, *reduce_dims)
-
def __repr__(self):
return f"{self.fn_name}"
@@ -389,31 +385,49 @@ class ArithFn:
min_unsigned = ArithFnType("min_unsigned")
-class ReduceFnType:
- """A reduction operator that reduces into its LHS from its RHS."""
+class ReduceFnUse:
+ """Reduction function use.
+
+ A reduction use specifies the reduction function and dimensions.
+ """
- def __init__(self, operator: ArithFnType, *reduce_dims: DimDef):
- """Initializes the ReduceFn with an airthmetic function and dims."""
- if not isinstance(operator, ArithFnType):
- raise ValueError(f"Reduce expected a ArithFnType but got {operator}")
- self.operator = operator
- self.reduce_dims = tuple(reduce_dims)
+ def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef):
+ self.arith_fn = arith_fn
+ self.reduce_dims = reduce_dims
def __call__(self, *args: TensorExpression):
- return ReduceApply(self, args)
+ return TensorReduceFn(self, args)
def __repr__(self):
- return (f"reduce_{self.operator.fn_name}"
+ return (f"reduce_{self.arith_fn.fn_name}"
f"({', '.join(repr(d) for d in self.reduce_dims)})")
+class ReduceFnType:
+ """Reduction function.
+
+ An arithmetic function that reduces its RHS into its LHS.
+ """
+
+ def __init__(self, arith_fn: ArithFnType):
+ if not isinstance(arith_fn, ArithFnType):
+ raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}")
+ self.arith_fn = arith_fn
+
+ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+ return ReduceFnUse(self.arith_fn, *reduce_dims)
+
+ def __repr__(self):
+ return (f"reduce_{self.arith_fn.fn_name}")
+
+
class ReduceFn:
- add = ArithFn.add.reduce
- mul = ArithFn.mul.reduce
- max = ArithFn.max.reduce
- min = ArithFn.min.reduce
- max_unsigned = ArithFn.max_unsigned.reduce
- min_unsigned = ArithFn.min_unsigned.reduce
+ add = ReduceFnType(ArithFn.add)
+ mul = ReduceFnType(ArithFn.mul)
+ max = ReduceFnType(ArithFn.max)
+ min = ReduceFnType(ArithFn.min)
+ max_unsigned = ReduceFnType(ArithFn.max_unsigned)
+ min_unsigned = ReduceFnType(ArithFn.min_unsigned)
class TensorArithFn(TensorExpression):
@@ -499,31 +513,31 @@ def __repr__(self):
return f"index({repr(self.dim)})"
-class ReduceApply(TensorExpression):
- """Application of a reduction.
+class TensorReduceFn(TensorExpression):
+ """Application of a reduction function.
- This captures the lhs separately (initial value) separately from the rhs.
+ This captures the lhs (initial value) separately from the rhs.
"""
- def __init__(self, reduce: ReduceFnType, args: Sequence[TensorExpression]):
- self.reduce = reduce
+ def __init__(self, reduce_use: ReduceFnUse, args: Sequence[TensorExpression]):
+ self.reduce_use = reduce_use
self.lhs = None # type: Optional[TensorUse]
self.args = tuple(args)
def to_scalar_expression(self) -> ScalarExpression:
if self.lhs is None:
- raise ValueError(f"Cannot scalarize a ReduceApply that has not been "
+ raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
f"bound to its lhs: {self}")
full_args = [self.lhs.to_scalar_expression()
] + [arg.to_scalar_expression() for arg in self.args]
- return ScalarArithFn(self.reduce.operator.fn_name, *full_args).expr()
+ return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
def visit_tensor_exprs(self, callback):
for arg in self.args:
arg.visit_tensor_exprs(callback)
def __repr__(self):
- return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})"
+ return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
class OpInterfaceDef:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index afc078d509c54..9fe370ffdabb5 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -479,7 +479,7 @@ def pooling_nhwc_max(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
TypeFn.cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@@ -499,7 +499,7 @@ def pooling_nhwc_max_unsigned(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
TypeFn.cast_unsigned(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@@ -519,7 +519,7 @@ def pooling_nchw_max(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
- O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)(
+ O[D.n, D.c, D.oh, D.ow] = ReduceFn.max[D.kh, D.kw](
TypeFn.cast(
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW,]))
@@ -540,7 +540,7 @@ def pooling_nhwc_min(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
TypeFn.cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@@ -560,7 +560,7 @@ def pooling_nhwc_min_unsigned(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
TypeFn.cast_unsigned(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@@ -600,7 +600,7 @@ def pooling_ndhwc_max(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
- O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)(
+ O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max[D.kd, D.kh, D.kw](
TypeFn.cast(
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW, D.c]))
@@ -621,7 +621,7 @@ def pooling_ndhwc_min(
"""
implements(ConvolutionOpInterface)
domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
- O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)(
+ O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min[D.kd, D.kh, D.kw](
TypeFn.cast(
U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
D.ow * S.SW + D.kw * S.DW, D.c]))
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
index ec4e9dfda9591..cf10c9d3f1a2a 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
@@ -19,7 +19,7 @@ def pooling_max_poly(
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
TypeFn.cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@@ -32,7 +32,7 @@ def pooling_max_unsigned_poly(
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
TypeFn.cast_unsigned(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@@ -45,7 +45,7 @@ def pooling_min_poly(
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
TypeFn.cast(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
@@ -58,7 +58,7 @@ def pooling_min_unsigned_poly(
strides=IndexAttrDef(S.SH, S.SW),
dilations=IndexAttrDef(S.DH, S.DW)):
domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
- O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
+ O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
TypeFn.cast_unsigned(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
More information about the Mlir-commits
mailing list