[Mlir-commits] [mlir] 4121090 - [mlir][OpDSL] Restructure comprehension.py (NFC).

llvmlistbot at llvm.org
Mon Feb 14 05:00:34 PST 2022


Author: gysit
Date: 2022-02-14T12:56:01Z
New Revision: 4121090893d5444514284028e1185ccd3c3abf7a

URL: https://github.com/llvm/llvm-project/commit/4121090893d5444514284028e1185ccd3c3abf7a
DIFF: https://github.com/llvm/llvm-project/commit/4121090893d5444514284028e1185ccd3c3abf7a.diff

LOG: [mlir][OpDSL] Restructure comprehension.py (NFC).

Group and reorder the classes defined in comprehension.py and add type annotations.

Depends On D119126

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D119692
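
For context, a minimal usage sketch of how the regrouped classes compose
(illustrative only, not part of this commit; the op name `matmul_sketch` is
an assumption):

  from mlir.dialects.linalg.opdsl.lang import *

  @linalg_structured_op
  def matmul_sketch(
      A=TensorDef(T1, S.M, S.K),
      B=TensorDef(T2, S.K, S.N),
      C=TensorDef(U, S.M, S.N, output=True)):
    # TensorDef.__getitem__ builds TensorUse reads, TypeFn.cast wraps them in
    # a TensorTypeFn, `*` applies ArithFn.mul (a TensorArithFn), and `+=`
    # creates an implicit TensorReduceFn over the otherwise-unused dimension
    # D.k (see TensorUse._compute_reduce_dims).
    domain(D.m, D.n, D.k)
    C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n])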

Added: 
    

Modified: 
    mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 4513236b8703f..300ea08332a7f 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -8,7 +8,7 @@
 represent actual op definitions (i.e. YAML).
 """
 
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
 from enum import Enum
 
 from ..... import ir as _ir
@@ -17,6 +17,10 @@
 from .types import *
 from .yaml_helper import *
 
+###############################################################################
+# Tensor expression nodes.
+###############################################################################
+
 
 class TensorExpression:
   """An expression that can appear on the RHS of a comprehension."""
@@ -24,19 +28,18 @@ class TensorExpression:
   def to_scalar_expression(self) -> ScalarExpression:
     raise NotImplementedError()
 
-  def visit_tensor_exprs(self, callback):
+  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
     """Visits all tensor expression reachable by the expression."""
     callback(self)
 
   def collect_dim_uses(self, uses: Set["DimDef"]):
     """Collects all DimDefs reachable through this expression."""
-    results = set()
 
-    def visit_dim_def(dim_def):
+    def visit_dim_def(dim_def: AffineExprDef):
       if isinstance(dim_def, DimDef):
         uses.add(dim_def)
 
-    def visit_affine_exprs(expr):
+    def visit_affine_exprs(expr: "TensorExpression"):
       if isinstance(expr, TensorUse):
         for ind in expr.indices:
           ind.visit_affine_exprs(visit_dim_def)
@@ -49,7 +52,7 @@ def visit_affine_exprs(expr):
   def collect_tensor_uses(self, uses: Set["TensorUse"]):
     """Collects all TensorUses reachable through this expression."""
 
-    def visit_tensor_use(expr):
+    def visit_tensor_use(expr: "TensorExpression"):
       if isinstance(expr, TensorUse):
         uses.add(expr)
 
@@ -58,7 +61,7 @@ def visit_tensor_use(expr):
   def collect_indices(self, indices: Set["index"]):
     """Collects all index accesses reachable through this expression."""
 
-    def visit_index(expr):
+    def visit_index(expr: "TensorExpression"):
       if isinstance(expr, index):
         indices.add(expr)
 
@@ -67,7 +70,7 @@ def visit_index(expr):
   def collect_scalar_uses(self, uses: Set["ScalarDef"]):
     """Collects all ScalarDefs reachable through this expression."""
 
-    def visit_scalar_def(expr):
+    def visit_scalar_def(expr: "TensorExpression"):
       if isinstance(expr, ScalarDef):
         uses.add(expr)
 
@@ -111,26 +114,261 @@ def tensor_name(self) -> str:
     assert name is not None, "TensorDef not attached"
     return name
 
-  def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
-    return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
-
   def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
-    """For implicit reductions, computes default reduction dims.
-
-    Assumes that the rhs is the expression being reduced and self is being
-    reduced into. Any indices referenced on the rhs and not in self are
-    considered reduction dims and will be ordered as encountered on the rhs.
-    """
+    # Computes the reduction dims for implicit reductions. Assumes that the rhs
+    # is the expression being reduced and self is being reduced into. Any
+    # indices referenced on the rhs and not in self are considered reduction
+    # dims and will be ordered as encountered on the rhs.
     rhs_dims = set()
     lhs_dims = set()
     rhs.collect_dim_uses(rhs_dims)
     self.collect_dim_uses(lhs_dims)
     return rhs_dims - lhs_dims
 
+  def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
+    return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
+
   def __repr__(self):
     return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"
 
 
+class TensorArithFn(TensorExpression):
+  """Application of an arithmetic function."""
+
+  def __init__(self, arith_fn: "ArithFnType", args: Sequence[TensorExpression]):
+    self.arith_fn = arith_fn
+    self.args = tuple(args)
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    return ScalarArithFn(self.arith_fn.fn_name,
+                         *[arg.to_scalar_expression() for arg in self.args
+                          ]).expr()
+
+  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+    super().visit_tensor_exprs(callback)
+    for arg in self.args:
+      arg.visit_tensor_exprs(callback)
+
+  def __repr__(self):
+    return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
+
+
+class TensorTypeFn(TensorExpression):
+  """Application of a type conversion function."""
+
+  def __init__(self, type_fn: "TypeFn", type_var: TypeVar,
+               arg: TensorExpression):
+    self.type_fn = type_fn
+    self.type_var = type_var
+    self.arg = arg
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    return ScalarTypeFn(self.type_fn.fn_name, self.type_var,
+                        self.arg.to_scalar_expression()).expr()
+
+  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+    super().visit_tensor_exprs(callback)
+    self.arg.visit_tensor_exprs(callback)
+
+  def __repr__(self):
+    return f"{repr(self.type_fn)}({self.type_var}, {self.arg})"
+
+
+class TensorReduceFn(TensorExpression):
+  """Application of a reduction function.
+
+  This captures the lhs (initial value) separately from the rhs.
+  """
+
+  def __init__(self, reduce_use: "ReduceFnUse",
+               args: Sequence[TensorExpression]):
+    self.reduce_use = reduce_use
+    self.lhs = None  # type: Optional[TensorUse]
+    self.args = tuple(args)
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    if self.lhs is None:
+      raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
+                       f"bound to its lhs: {self}")
+    full_args = [self.lhs.to_scalar_expression()
+                ] + [arg.to_scalar_expression() for arg in self.args]
+    return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
+
+  def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]):
+    for arg in self.args:
+      arg.visit_tensor_exprs(callback)
+
+  def __repr__(self):
+    return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
+
+
+class const(TensorExpression):
+  """Returns the given constant floating point or integer value."""
+
+  def __init__(self, value: Any):
+    with _ir.Context():
+      if isinstance(value, float):
+        self.value = str(_ir.FloatAttr.get_f64(float(value)))
+      elif isinstance(value, int):
+        self.value = str(
+            _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
+      else:
+        raise ValueError(f"const requires int or float but got {type(value)}")
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    return ScalarConst(self.value).expr()
+
+  def __repr__(self):
+    return f"const({self.value})"
+
+
+class index(TensorExpression):
+  """Returns the iteration index for a given dimension name.
+
+  Resolves the given dimension name to obtain its position in the iteration
+  domain of the operation.
+  """
+
+  def __init__(self, dim: DimDef):
+    self.dim_def = dim
+    self.dim = -1
+
+  def resolve_dimension_name(self, affine_state: AffineBuildState):
+    self.dim = affine_state.get_dim(self.dim_def.dimname)
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    assert self.dim != -1, "Dimension name not resolved"
+    return ScalarIndex(self.dim).expr()
+
+  def __repr__(self):
+    return f"index({repr(self.dim)})"
+
+
+###############################################################################
+# Function types and function definitions.
+###############################################################################
+
+
+class TypeFnType:
+  """Type conversion function.
+
+  A type conversion function takes a target type and a tensor expression and
+  returns the casted tensor expression.
+  """
+
+  def __init__(self, fn_name: str):
+    self.fn_name = fn_name
+
+  def __call__(self, type_var: TypeVar,
+               arg: TensorExpression) -> "TensorTypeFn":
+    return TensorTypeFn(self, type_var, arg)
+
+  def __repr__(self):
+    return f"{self.fn_name}"
+
+
+class TypeFn:
+  """Type conversion function namespace.
+
+  As the integer types are signless, signedness is implemented by different cast
+  functions that treat integers as signed (`cast`) or unsigned
+  (`cast_unsigned`) values.
+
+  Examples:
+  - cast(I32 -> I64) -> `arith.ExtSIOp`
+  - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
+  """
+  cast = TypeFnType("cast")
+  cast_unsigned = TypeFnType("cast_unsigned")
+
+
+class ArithFnType:
+  """Arithmetic function.
+
+  An arithmetic function takes one or more tensor expressions and returns the
+  function evaluation result.
+  """
+
+  def __init__(self, fn_name: str):
+    self.fn_name = fn_name
+
+  def __call__(self, *args) -> "TensorArithFn":
+    return TensorArithFn(self, args)
+
+  def __repr__(self):
+    return f"{self.fn_name}"
+
+
+class ArithFn:
+  """Arithmetic function namespace.
+
+  As the integer types are signless, signedness is implemented by different
+  functions that treat integers as signed or unsigned values.
+
+  Examples:
+  - max -> `arith.MaxSIOp`
+  - max_unsigned -> `arith.MaxUIOp`
+  """
+  add = ArithFnType("add")
+  exp = ArithFnType("exp")
+  log = ArithFnType("log")
+  mul = ArithFnType("mul")
+  max = ArithFnType("max")
+  min = ArithFnType("min")
+  sub = ArithFnType("sub")
+  max_unsigned = ArithFnType("max_unsigned")
+  min_unsigned = ArithFnType("min_unsigned")
+
+
+class ReduceFnUse:
+  """Reduction function use.
+
+  A reduction use specifies the reduction function and dimensions.
+  """
+
+  def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef):
+    self.arith_fn = arith_fn
+    self.reduce_dims = reduce_dims
+
+  def __call__(self, *args: TensorExpression):
+    return TensorReduceFn(self, args)
+
+  def __repr__(self):
+    return (f"reduce_{self.arith_fn.fn_name}"
+            f"({', '.join(repr(d) for d in self.reduce_dims)})")
+
+
+class ReduceFnType:
+  """Reduction function.
+
+  An arithmetic function that reduces its RHS into its LHS.
+  """
+
+  def __init__(self, arith_fn: ArithFnType):
+    if not isinstance(arith_fn, ArithFnType):
+      raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}")
+    self.arith_fn = arith_fn
+
+  def __getitem__(self, reduce_dims: Tuple[DimDef, ...]) -> ReduceFnUse:
+    return ReduceFnUse(self.arith_fn, *reduce_dims)
+
+  def __repr__(self):
+    return (f"reduce_{self.arith_fn.fn_name}")
+
+
+class ReduceFn:
+  add = ReduceFnType(ArithFn.add)
+  mul = ReduceFnType(ArithFn.mul)
+  max = ReduceFnType(ArithFn.max)
+  min = ReduceFnType(ArithFn.min)
+  max_unsigned = ReduceFnType(ArithFn.max_unsigned)
+  min_unsigned = ReduceFnType(ArithFn.min_unsigned)
+
+
+###############################################################################
+# Operand definitions.
+###############################################################################
+
+
 class OperandKind(Enum):
   InputTensor = 0
   Scalar = 1
@@ -150,7 +388,7 @@ def __init__(self,
                type_var: Optional[TypeVar] = None,
                size_exprs: Optional[Sequence[AffineExprDef]] = None,
                index_dims: Optional[Sequence[DimDef]] = None,
-               default_vals : Optional[Sequence[int]] = None):
+               default_vals: Optional[Sequence[int]] = None):
     if type_var and not isinstance(type_var, TypeVar):
       raise ValueError(
           f"OperandDef requires a TypeVar but got {repr(type_var)}")
@@ -206,7 +444,7 @@ def __init__(self,
     self.operand_def = OperandDef(
         kind, type_var=type_var, size_exprs=shape, index_dims=index_dims)
 
-  def __getitem__(self, dims) -> TensorUse:
+  def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse:
     assert self.operand_def.owner, "TensorDef is not attached to an op"
     state = AffineBuildState(
         global_state=self.operand_def.owner._affine_state,
@@ -225,7 +463,7 @@ def __getitem__(self, dims) -> TensorUse:
       exprs.append(expr_def)
     return TensorUse(self.operand_def, exprs)
 
-  def __setitem__(self, dims, value):
+  def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression):
     """Creates a new 1:1 comprehension by binding this tensor to an expression.
 
     Note that due to the way assignment works in Python, we have to capture
@@ -282,6 +520,11 @@ def __init__(self, *sizes: SymbolDef, default: Sequence[int]):
         OperandKind.IndexAttr, size_exprs=sizes, default_vals=default)
 
 
+###############################################################################
+# Operation definition.
+###############################################################################
+
+
 class Comprehension:
   """Represents a single comprehension."""
 
@@ -320,232 +563,6 @@ def __repr__(self):
     return f"{defs_repr} = {values_repr}"
 
 
-class TypeFnType:
-  """Type conversion function.
-
-  A type conversion function takes a target type and a tensor expression and
-  returns the casted tensor expression.
-  """
-
-  def __init__(self, fn_name: str):
-    self.fn_name = fn_name
-
-  def __call__(self, type_var: TypeVar,
-               arg: TensorExpression) -> "TensorTypeFn":
-    return TensorTypeFn(self, type_var, arg)
-
-  def __repr__(self):
-    return f"{self.fn_name}"
-
-
-class TypeFn:
-  """Type conversion function namespace.
-
-  As the integer types are signless, signedness is implemented by different cast
-  functions that treat integers as signed (`cast`) or unsigned
-  (`cast_unsigned`) values.
-
-  Examples:
-  - cast(I32 -> I64) -> `arith.ExtSIOp`
-  - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp`
-  """
-  cast = TypeFnType("cast")
-  cast_unsigned = TypeFnType("cast_unsigned")
-
-
-class ArithFnType:
-  """Arithmetic function.
-
-  An arithmetic function takes one or more tensor expressions and returns the
-  function evaluation result.
-  """
-
-  def __init__(self, fn_name: str):
-    self.fn_name = fn_name
-
-  def __call__(self, *args) -> "TensorArithFn":
-    return TensorArithFn(self, args)
-
-  def __repr__(self):
-    return f"{self.fn_name}"
-
-
-class ArithFn:
-  """Arithmetic function namespace.
-
-  As the integer types are signless, signedness is implemented by different
-  functions that treat integers as signed or unsigned values.
-
-  Examples:
-  - max -> `arith.MaxSIOp`
-  - max_unsigned -> `arith.MaxUIOp`
-  """
-  add = ArithFnType("add")
-  exp = ArithFnType("exp")
-  log = ArithFnType("log")
-  mul = ArithFnType("mul")
-  max = ArithFnType("max")
-  min = ArithFnType("min")
-  sub = ArithFnType("sub")
-  max_unsigned = ArithFnType("max_unsigned")
-  min_unsigned = ArithFnType("min_unsigned")
-
-
-class ReduceFnUse:
-  """Reduction function use.
-
-  A reduction use specifies the reduction function and dimensions.
-  """
-
-  def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef):
-    self.arith_fn = arith_fn
-    self.reduce_dims = reduce_dims
-
-  def __call__(self, *args: TensorExpression):
-    return TensorReduceFn(self, args)
-
-  def __repr__(self):
-    return (f"reduce_{self.arith_fn.fn_name}"
-            f"({', '.join(repr(d) for d in self.reduce_dims)})")
-
-
-class ReduceFnType:
-  """Reduction function.
-
-  An arithmetic function that reduces its RHS into its LHS.
-  """
-
-  def __init__(self, arith_fn: ArithFnType):
-    if not isinstance(arith_fn, ArithFnType):
-      raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}")
-    self.arith_fn = arith_fn
-
-  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
-    return ReduceFnUse(self.arith_fn, *reduce_dims)
-
-  def __repr__(self):
-    return (f"reduce_{self.arith_fn.fn_name}")
-
-
-class ReduceFn:
-  add = ReduceFnType(ArithFn.add)
-  mul = ReduceFnType(ArithFn.mul)
-  max = ReduceFnType(ArithFn.max)
-  min = ReduceFnType(ArithFn.min)
-  max_unsigned = ReduceFnType(ArithFn.max_unsigned)
-  min_unsigned = ReduceFnType(ArithFn.min_unsigned)
-
-
-class TensorArithFn(TensorExpression):
-  """Application of an arithmetic function."""
-
-  def __init__(self, arith_fn: ArithFnType, args: Sequence[TensorExpression]):
-    self.arith_fn = arith_fn
-    self.args = tuple(args)
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarArithFn(self.arith_fn.fn_name,
-                         *[arg.to_scalar_expression() for arg in self.args
-                          ]).expr()
-
-  def visit_tensor_exprs(self, callback):
-    super().visit_tensor_exprs(callback)
-    for arg in self.args:
-      arg.visit_tensor_exprs(callback)
-
-  def __repr__(self):
-    return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})"
-
-
-class TensorTypeFn(TensorExpression):
-  """Application of a type conversion function."""
-
-  def __init__(self, type_fn: TypeFn, type_var: TypeVar, arg: TensorExpression):
-    self.type_fn = type_fn
-    self.type_var = type_var
-    self.arg = arg
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarTypeFn(self.type_fn.fn_name, self.type_var,
-                        self.arg.to_scalar_expression()).expr()
-
-  def visit_tensor_exprs(self, callback):
-    super().visit_tensor_exprs(callback)
-    self.arg.visit_tensor_exprs(callback)
-
-  def __repr__(self):
-    return f"{repr(self.type_fn)}({self.type_var}, {self.arg})"
-
-
-class const(TensorExpression):
-  """Returns the given constant floating point or integer value."""
-
-  def __init__(self, value: Any):
-    with _ir.Context():
-      if isinstance(value, float):
-        self.value = str(_ir.FloatAttr.get_f64(float(value)))
-      elif isinstance(value, int):
-        self.value = str(
-            _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
-      else:
-        raise ValueError(f"const requires int or float but got {type(value)}")
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarConst(self.value).expr()
-
-  def __repr__(self):
-    return f"const({self.value})"
-
-
-class index(TensorExpression):
-  """Returns the iteration index for a given dimension name.
-
-  Resolves the given dimension name to obtain its position in the iteration
-  domain of the operation.
-  """
-
-  def __init__(self, dim: DimDef):
-    self.dim_def = dim
-    self.dim = -1
-
-  def resolve_dimension_name(self, affine_state: AffineBuildState):
-    self.dim = affine_state.get_dim(self.dim_def.dimname)
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    assert self.dim != -1, "Dimension name not resolved"
-    return ScalarIndex(self.dim).expr()
-
-  def __repr__(self):
-    return f"index({repr(self.dim)})"
-
-
-class TensorReduceFn(TensorExpression):
-  """Application of a reduction function.
-
-  This captures the lhs (initial value) separately from the rhs.
-  """
-
-  def __init__(self, reduce_use: ReduceFnUse, args: Sequence[TensorExpression]):
-    self.reduce_use = reduce_use
-    self.lhs = None  # type: Optional[TensorUse]
-    self.args = tuple(args)
-
-  def to_scalar_expression(self) -> ScalarExpression:
-    if self.lhs is None:
-      raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
-                       f"bound to its lhs: {self}")
-    full_args = [self.lhs.to_scalar_expression()
-                ] + [arg.to_scalar_expression() for arg in self.args]
-    return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
-
-  def visit_tensor_exprs(self, callback):
-    for arg in self.args:
-      arg.visit_tensor_exprs(callback)
-
-  def __repr__(self):
-    return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
-
-
 class OpInterfaceDef:
   """An interface that an op implements."""
 
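
For reference, a similarly hedged sketch of the explicit reduction form built
from ReduceFnType and ReduceFnUse above (illustrative; the op name and shapes
are assumptions, not part of this commit):

  from mlir.dialects.linalg.opdsl.lang import *

  @linalg_structured_op
  def window_max_sketch(
      A=TensorDef(T1, S.M, S.KH, S.KW),
      O=TensorDef(U, S.M, output=True)):
    # ReduceFn.max[D.kh, D.kw] calls ReduceFnType.__getitem__ to build a
    # ReduceFnUse; applying it to an expression yields a TensorReduceFn whose
    # lhs is bound to O[D.m] when the comprehension is assigned.
    domain(D.m, D.kh, D.kw)
    O[D.m] = ReduceFn.max[D.kh, D.kw](TypeFn.cast(U, A[D.m, D.kh, D.kw]))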