[Mlir-commits] [mlir] e3b442b - [mlir][OpDSL] Separate `ReduceFn` and `ReduceFnUse`.

llvmlistbot at llvm.org
Fri Jan 7 04:57:37 PST 2022


Author: gysit
Date: 2022-01-07T12:51:06Z
New Revision: e3b442b62f44491514b10b0dd3949b3259ce80f3

URL: https://github.com/llvm/llvm-project/commit/e3b442b62f44491514b10b0dd3949b3259ce80f3
DIFF: https://github.com/llvm/llvm-project/commit/e3b442b62f44491514b10b0dd3949b3259ce80f3.diff

LOG: [mlir][OpDSL] Separate `ReduceFn` and `ReduceFnUse`.

The revision distinguishes `ReduceFn` and `ReduceFnUse`. The latter has the reduction dimensions attached, while the former specifies the arithmetic function only. This separation lets us adapt the reduction syntax slightly and specify the reduction dimensions using square brackets (in contrast to the round brackets used for the values to reduce). It is also a preparation for adding reduction function attributes to OpDSL: a reduction function attribute shall only specify the arithmetic function and not the reduction dimensions.

Example:
```
ReduceFn.max_unsigned(D.kh, D.kw)(...)
```
changes to:
```
ReduceFn.max_unsigned[D.kh, D.kw](...)
```

Depends On D115240

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D115241

Added: 
    

Modified: 
    mlir/docs/Dialects/Linalg/OpDSL.md
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/python/dialects/linalg/opdsl/emit_pooling.py

Removed: 
    


################################################################################
diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index 4a08f0bc4d389..79f22a247bb27 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -192,11 +192,18 @@ A number of arithmetic functions are supported:
 As the integer types are signless, signedness is implemented by different
 functions that treat integers as signed or unsigned values.
 
-Reduction functions can appear as the outer-most function on the RHS:
+A subset of the arithmetic functions are supported in reductions. These
+reduction functions can appear as the outermost function on the RHS:
 
 *   `ReduceFn.add` (also overloading the inplace `+=` on a LHS)
 *   `ReduceFn.mul`
 *   `ReduceFn.max`
+*   `ReduceFn.min`
+*   `ReduceFn.max_unsigned`
+*   `ReduceFn.min_unsigned`
+
+As the integer types are signless, signedness is implemented by different
+functions that treat integers as signed or unsigned values.
 
 Additionally, type conversion functions cast an operand to a target type:
 

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index fd0fa72661909..ddbebb29fd6f0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -43,8 +43,8 @@ def visit_affine_exprs(expr):
       if isinstance(expr, TensorUse):
         for ind in expr.indices:
           ind.visit_affine_exprs(visit_dim_def)
-      if isinstance(expr, ReduceApply):
-        for ind in expr.reduce.reduce_dims:
+      if isinstance(expr, TensorReduceFn):
+        for ind in expr.reduce_fn.reduce_dims:
           ind.visit_affine_exprs(visit_dim_def)
 
     self.visit_tensor_exprs(visit_affine_exprs)
@@ -114,8 +114,8 @@ def tensor_name(self) -> str:
     assert name is not None, "TensorDef not attached"
     return name
 
-  def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
-    return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)
+  def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
+    return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
 
   def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
     """For implicit reductions, computes default reduction dims.
@@ -285,7 +285,7 @@ def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
 
     # Find the lhs to reduction rhs.
     for assign, value in bindings:
-      if isinstance(value, ReduceApply):
+      if isinstance(value, TensorReduceFn):
         if value.lhs:
           raise ValueError(f"Reduction expression already assigns: {value}")
         value.lhs = assign
@@ -297,8 +297,8 @@ def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
     """Gets the reduction dims for the comprehension or None."""
     result = set()
     for use in self.values:
-      if isinstance(use, ReduceApply):
-        result.add(use.reduce.reduce_dims)
+      if isinstance(use, TensorReduceFn):
+        result.add(use.reduce_use.reduce_dims)
       else:
         result.add(tuple())
     return result
@@ -360,10 +360,6 @@ def __init__(self, fn_name: str):
   def __call__(self, *args) -> "TensorArithFn":
     return TensorArithFn(self, args)
 
-  def reduce(self, *reduce_dims: DimDef):
-    """Shortcut to create a Reduce operation from this function."""
-    return ReduceFnType(self, *reduce_dims)
-
   def __repr__(self):
     return f"{self.fn_name}"
 
@@ -389,31 +385,49 @@ class ArithFn:
   min_unsigned = ArithFnType("min_unsigned")
 
 
-class ReduceFnType:
-  """A reduction operator that reduces into its LHS from its RHS."""
+class ReduceFnUse:
+  """Reduction function use.
+
+  A reduction use specifies the reduction function and dimensions.
+  """
 
-  def __init__(self, operator: ArithFnType, *reduce_dims: DimDef):
-    """Initializes the ReduceFn with an airthmetic function and dims."""
-    if not isinstance(operator, ArithFnType):
-      raise ValueError(f"Reduce expected a ArithFnType but got {operator}")
-    self.operator = operator
-    self.reduce_dims = tuple(reduce_dims)
+  def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef):
+    self.arith_fn = arith_fn
+    self.reduce_dims = reduce_dims
 
   def __call__(self, *args: TensorExpression):
-    return ReduceApply(self, args)
+    return TensorReduceFn(self, args)
 
   def __repr__(self):
-    return (f"reduce_{self.operator.fn_name}"
+    return (f"reduce_{self.arith_fn.fn_name}"
             f"({', '.join(repr(d) for d in self.reduce_dims)})")
 
 
+class ReduceFnType:
+  """Reduction function.
+
+  An arithmetic function that reduces its RHS into its LHS.
+  """
+
+  def __init__(self, arith_fn: ArithFnType):
+    if not isinstance(arith_fn, ArithFnType):
+      raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}")
+    self.arith_fn = arith_fn
+
+  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+    return ReduceFnUse(self.arith_fn, *reduce_dims)
+
+  def __repr__(self):
+    return (f"reduce_{self.arith_fn.fn_name}")
+
+
 class ReduceFn:
-  add = ArithFn.add.reduce
-  mul = ArithFn.mul.reduce
-  max = ArithFn.max.reduce
-  min = ArithFn.min.reduce
-  max_unsigned = ArithFn.max_unsigned.reduce
-  min_unsigned = ArithFn.min_unsigned.reduce
+  add = ReduceFnType(ArithFn.add)
+  mul = ReduceFnType(ArithFn.mul)
+  max = ReduceFnType(ArithFn.max)
+  min = ReduceFnType(ArithFn.min)
+  max_unsigned = ReduceFnType(ArithFn.max_unsigned)
+  min_unsigned = ReduceFnType(ArithFn.min_unsigned)
 
 
 class TensorArithFn(TensorExpression):
@@ -499,31 +513,31 @@ def __repr__(self):
     return f"index({repr(self.dim)})"
 
 
-class ReduceApply(TensorExpression):
-  """Application of a reduction.
+class TensorReduceFn(TensorExpression):
+  """Application of a reduction function.
 
-  This captures the lhs separately (initial value) separately from the rhs.
+  This captures the lhs (initial value) separately from the rhs.
   """
 
-  def __init__(self, reduce: ReduceFnType, args: Sequence[TensorExpression]):
-    self.reduce = reduce
+  def __init__(self, reduce_use: ReduceFnUse, args: Sequence[TensorExpression]):
+    self.reduce_use = reduce_use
     self.lhs = None  # type: Optional[TensorUse]
     self.args = tuple(args)
 
   def to_scalar_expression(self) -> ScalarExpression:
     if self.lhs is None:
-      raise ValueError(f"Cannot scalarize a ReduceApply that has not been "
+      raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
                        f"bound to its lhs: {self}")
     full_args = [self.lhs.to_scalar_expression()
                 ] + [arg.to_scalar_expression() for arg in self.args]
-    return ScalarArithFn(self.reduce.operator.fn_name, *full_args).expr()
+    return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
 
   def visit_tensor_exprs(self, callback):
     for arg in self.args:
       arg.visit_tensor_exprs(callback)
 
   def __repr__(self):
-    return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})"
+    return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
 
 
 class OpInterfaceDef:

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index afc078d509c54..9fe370ffdabb5 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -479,7 +479,7 @@ def pooling_nhwc_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -499,7 +499,7 @@ def pooling_nhwc_max_unsigned(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
       TypeFn.cast_unsigned(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -519,7 +519,7 @@ def pooling_nchw_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
-  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)(
+  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,]))
@@ -540,7 +540,7 @@ def pooling_nhwc_min(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -560,7 +560,7 @@ def pooling_nhwc_min_unsigned(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
       TypeFn.cast_unsigned(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -600,7 +600,7 @@ def pooling_ndhwc_max(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)(
+  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max[D.kd, D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW, D.c]))
@@ -621,7 +621,7 @@ def pooling_ndhwc_min(
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)(
+  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min[D.kd, D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW, D.c]))

diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
index ec4e9dfda9591..cf10c9d3f1a2a 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
@@ -19,7 +19,7 @@ def pooling_max_poly(
     strides=IndexAttrDef(S.SH, S.SW),
     dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -32,7 +32,7 @@ def pooling_max_unsigned_poly(
     strides=IndexAttrDef(S.SH, S.SW),
     dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
       TypeFn.cast_unsigned(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -45,7 +45,7 @@ def pooling_min_poly(
     strides=IndexAttrDef(S.SH, S.SW),
     dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
@@ -58,7 +58,7 @@ def pooling_min_unsigned_poly(
     strides=IndexAttrDef(S.SH, S.SW),
     dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
       TypeFn.cast_unsigned(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
More information about the Mlir-commits mailing list