[Mlir-commits] [mlir] d4cba4a - [mlir][linalg] Add structured op builders from python opdsl.

Stella Laurenzo llvmlistbot at llvm.org
Fri Mar 19 11:21:08 PDT 2021


Author: Stella Laurenzo
Date: 2021-03-19T11:20:36-07:00
New Revision: d4cba4a188f419d1c2fc4b827c4a6a0310b0568e

URL: https://github.com/llvm/llvm-project/commit/d4cba4a188f419d1c2fc4b827c4a6a0310b0568e
DIFF: https://github.com/llvm/llvm-project/commit/d4cba4a188f419d1c2fc4b827c4a6a0310b0568e.diff

LOG: [mlir][linalg] Add structured op builders from python opdsl.

* Makes the wrapped functions of the `@linalg_structured_op` decorator callable such that they emit IR imperatively when invoked.
* There are numerous TODOs that I will keep working through to achieve generality.
* Will true up exception handling tests as the feature progresses (for things that are actually errors once everything is implemented).
* Includes the addition of an `isinstance` method on concrete types in the Python API.

Differential Revision: https://reviews.llvm.org/D98754

Added: 
    mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py

Modified: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
    mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index a544e52c2613..6b4e5434d1d7 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -2477,6 +2477,9 @@ class PyConcreteType : public BaseTy {
   static void bind(py::module &m) {
     auto cls = ClassTy(m, DerivedTy::pyClassName);
     cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
+    // Static predicate `<ConcreteType>.isinstance(type)`: returns the result of
+    // DerivedTy::isaFunction on an arbitrary (generic) PyType.
+    cls.def_static("isinstance", [](PyType &otherType) -> bool {
+      return DerivedTy::isaFunction(otherType);
+    });
     DerivedTy::bindDerived(cls);
   }
 

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
index 115ea40619b8..fdc6cfd9bab0 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -21,6 +21,7 @@
 __all__ = [
     "LinalgStructuredOpConfig",
     "LinalgOpConfig",
+    "TensorDefConfig",
 ]
 
 
@@ -51,17 +52,17 @@ def __init__(self, tensor_def: TensorDef, shape_map: _ir.AffineMap):
     self.shape_map = shape_map
     self.indexing_map = None  # type: Optional[_ir.AffineMap]
 
-  def to_yaml_custom_dict(self):
-
-    def get_usage():
-      if self.tensor_def.output:
-        return "output"
-      else:
-        return "input"
+  @property
+  def usage(self) -> str:
+    """Returns "output" for an output tensor def, otherwise "input"."""
+    if self.tensor_def.output:
+      return "output"
+    else:
+      return "input"

+  def to_yaml_custom_dict(self):
+    """Returns a dict with this arg's name, usage, shape and type var name."""
     return dict(
         name=self.tensor_def.tensor_name,
-        usage=get_usage(),
+        usage=self.usage,
         shape=_serialize_affine_map(self.shape_map),
         element_type_var=self.tensor_def.type_var.name,
     )

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
index d367c5bdde07..cbff41db2d88 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -11,6 +11,8 @@
 
 from mlir import ir
 from .comprehension import *
+from .config import *
+from .emitter import *
 
 _CONTEXT = threading.local()
 
@@ -42,9 +44,34 @@ def __init__(self, op_name: str, model: LinalgOpDef):
     self.op_name = op_name
     self.model = model
 
-  def __call__(self, *args, **kwargs):
-    # TODO: Upstream the emitter and invoke here
-    raise NotImplementedError("Linalg generic emission not yet implemented")
+  def __call__(self, *args, emit_generic: bool = True, **kwargs):
+    """Emits the corresponding op definition as IR.
+
+    Most arguments are passed through to the underlying emitter. The following
+    are interpreted here:
+      emit_generic: Emits a generic form as appropriate (default True). If
+        False, a named form is emitted (which must have been built in to the
+        compiler).
+
+    Op configs are built against `ir.Context.current`.
+
+    Returns:
+      Whatever the selected structured-op emitter returns.
+
+    Raises:
+      NotImplementedError: If the model yields more than one op config
+        (composite op) or a config that is not a structured op.
+    """
+    op_configs = LinalgOpConfig.from_linalg_op_def(self.model,
+                                                   context=ir.Context.current)
+
+    if len(op_configs) != 1:
+      # TODO: Support composite ops.
+      raise NotImplementedError(
+          f"Emission of composite linalg ops not supported: {op_configs}")
+
+    op_config = op_configs[0]
+    if op_config.structured_op:
+      if emit_generic:
+        return emit_generic_structured_op(op_config.structured_op, *args,
+                                          **kwargs)
+      else:
+        return emit_named_structured_op(op_config.structured_op, *args,
+                                        **kwargs)
+
+    # Only structured op configs can be emitted so far.
+    raise NotImplementedError(
+        f"Emission of linalg op type not supported: {op_config}")
 
 
 def linalg_structured_op(dsl_func=None,

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
new file mode 100644
index 000000000000..9a18993e9f62
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -0,0 +1,252 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Dict, Sequence
+
+from mlir.ir import *
+from mlir.dialects import linalg
+from mlir.dialects import std
+
+from .scalar_expr import *
+from .config import *
+
+__all__ = [
+    "emit_generic_structured_op",
+    "emit_named_structured_op",
+]
+
+
+def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
+                               *ins: Value,
+                               outs: Sequence[Value] = ()):
+  """Emits a `linalg.generic` op implementing `op_config`.
+
+  Args:
+    op_config: The structured op configuration to emit.
+    *ins: Input operands, one per input tensor def, in the op's tensor
+      argument order. Each must have a ShapedType.
+    outs: Output operands, one per output tensor def. Output inference is not
+      implemented yet, so these must currently be provided.
+
+  Returns:
+    The single result Value when the op has one output, otherwise all results.
+
+  Raises:
+    ValueError: On an operand arity mismatch, a non-shaped operand, or when
+      one type variable would be bound to different element types.
+    NotImplementedError: If `outs` is empty (output inference unsupported).
+  """
+  all_arg_defs = op_config.ordered_tensor_args
+  in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
+  out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"]
+
+  # Arity validation.
+  if len(ins) != len(in_arg_defs):
+    raise ValueError(f"Expected {len(in_arg_defs)} inputs but got "
+                     f"{len(ins)} for {op_config}")
+  if outs and len(outs) != len(out_arg_defs):
+    raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
+                     f"{len(outs)} for {op_config}")
+
+  outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
+                                           out_arg_defs, outs)
+
+  # Element types of all operands in (ins..., outs...) order. They are both the
+  # type variable bindings and the body block argument types.
+  element_types = _get_shaped_element_types_from_values(*ins, *outs)
+
+  # Bind type vars to operand element types. A type var shared by several
+  # operands must bind consistently; otherwise the body would mix types.
+  type_mapping = dict()  # type: Dict[str, Type]
+  for arg_def, arg_element_type in zip(in_arg_defs + out_arg_defs,
+                                       element_types):
+    tv_name = arg_def.tensor_def.type_var.name
+    bound_type = type_mapping.get(tv_name)
+    if bound_type is not None and bound_type != arg_element_type:
+      raise ValueError(f"Type variable '{tv_name}' bound to conflicting "
+                       f"element types {bound_type} and {arg_element_type} "
+                       f"for {op_config}")
+    type_mapping[tv_name] = arg_element_type
+
+  # Emit the generic op.
+  # TODO: Support emission of pure memref form.
+  indexing_maps_attr = ArrayAttr.get(
+      [AffineMapAttr.get(am) for am in op_config.indexing_maps])
+  iterator_types_attr = ArrayAttr.get(
+      [StringAttr.get(s) for s in op_config.iterator_types])
+  generic_op = linalg.GenericOp(
+      result_tensors=out_types,
+      inputs=ins,
+      outputs=outs,
+      indexing_maps=indexing_maps_attr,
+      iterator_types=iterator_types_attr,
+      doc=None,  # TODO: Make optional.
+      library_call=None,  # TODO: Make optional.
+      sparse=BoolAttr.get(False))  # TODO: Make optional.
+
+  # Construct the body: one block argument per operand, typed by the operand's
+  # element type and keyed by its tensor def name.
+  block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs)
+  block = generic_op.regions[0].blocks.append(*element_types)
+  block_arg_mapping = dict(zip(block_arg_names, block.arguments))
+  with InsertionPoint(block):
+    body_builder = _BodyBuilder(type_mapping, block_arg_mapping)
+    for assignment in op_config.assignments:
+      body_builder.assign(assignment)
+    body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs))
+
+  if len(out_arg_defs) == 1:
+    return generic_op.result
+  else:
+    return generic_op.results
+
+
+def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
+                             *ins: Value,
+                             outs: Sequence[Value] = ()):
+  """Emits the named (non-generic) form of a structured op.
+
+  Not implemented yet: always raises NotImplementedError.
+  """
+  raise NotImplementedError(
+      f"Emission of named structured ops is not supported: {op_config}")
+
+
+class _BodyBuilder:
+  """Constructs a structured op body by evaluating assignments.
+
+  Ops are created at the caller's current insertion point (the generic op's
+  body block when used from `emit_generic_structured_op`).
+  """
+
+  def __init__(self, type_mapping: Dict[str, Type],
+               block_arg_mapping: Dict[str, Value]):
+    # Type variable name -> concrete element type.
+    self.type_mapping = type_mapping
+    # Tensor argument name -> body block argument.
+    self.block_arg_mapping = block_arg_mapping
+    # Assigned argument name -> value to yield for it.
+    self.yield_mapping = dict()  # type: Dict[str, Value]
+
+  def assign(self, assignment: ScalarAssign):
+    """Evaluates an assignment and records its value for `yield_outputs`.
+
+    Raises:
+      ValueError: If `assignment.arg` was already assigned.
+    """
+    if assignment.arg in self.yield_mapping:
+      raise ValueError(
+          f"Multiple assignments to the same argument are forbidden: "
+          f"{assignment}")
+    self.yield_mapping[assignment.arg] = self.expression(assignment.value)
+
+  def expression(self, expr: ScalarExpression) -> Value:
+    """Recursively emits IR for a scalar expression and returns its value.
+
+    Raises:
+      ValueError: If an argument is unbound or a function is unknown.
+      NotImplementedError: If the expression kind is not supported.
+    """
+    if expr.scalar_arg:
+      try:
+        return self.block_arg_mapping[expr.scalar_arg.arg]
+      except KeyError:
+        raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for "
+                         f"this structured op.")
+    elif expr.scalar_apply:
+      # Function names dispatch to the `_eval_<fn_name>` methods below.
+      try:
+        fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}")
+      except AttributeError:
+        raise ValueError(
+            f"Function '{expr.scalar_apply.fn_name}' is not a known "
+            "scalar body function")
+      operand_values = [
+          self.expression(operand) for operand in expr.scalar_apply.operands
+      ]
+      return fn(*operand_values)
+    elif expr.symbolic_cast:
+      operand_value = self.expression(expr.symbolic_cast.operand)
+      return self.cast(expr.symbolic_cast.to_type.name, operand_value)
+    raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
+
+  def cast(self, type_var_name: str, operand: Value) -> Value:
+    """Casts `operand` to the element type bound to `type_var_name`.
+
+    Returns `operand` unchanged when it already has the target type.
+
+    Raises:
+      ValueError: If the type variable is unbound or no cast applies.
+    """
+    try:
+      to_type = self.type_mapping[type_var_name]
+    except KeyError:
+      raise ValueError(f"Unbound type variable '{type_var_name}' ("
+                       f"expected one of {list(self.type_mapping.keys())})")
+    if operand.type == to_type:
+      return operand
+    if _is_integer_type(to_type):
+      return self._cast_to_integer(to_type, operand)
+    elif _is_floating_point_type(to_type):
+      return self._cast_to_floating_point(to_type, operand)
+
+    raise ValueError(f"Unable to cast body expression from {operand.type} to "
+                     f"{to_type}")
+
+  def _cast_to_integer(self, to_type: Type, operand: Value) -> Value:
+    """Emits a float->signed-int cast or a signed integer extend/truncate."""
+    to_width = IntegerType(to_type).width
+    operand_type = operand.type
+    if _is_floating_point_type(operand_type):
+      return std.FPToSIOp(to_type, operand).result
+    # Assume integer; widening uses sign extension.
+    from_width = IntegerType(operand_type).width
+    if to_width > from_width:
+      return std.SignExtendIOp(to_type, operand).result
+    elif to_width < from_width:
+      return std.TruncateIOp(to_type, operand).result
+    # Equal widths but unequal types (cast() already returned on equality).
+    raise ValueError(f"Unable to cast body expression from {operand_type} to "
+                     f"{to_type}")
+
+  def _cast_to_floating_point(self, to_type: Type, operand: Value) -> Value:
+    """Emits a signed-int->float cast or a float extend/truncate."""
+    operand_type = operand.type
+    if _is_integer_type(operand_type):
+      return std.SIToFPOp(to_type, operand).result
+    # Assume FloatType.
+    to_width = _get_floating_point_width(to_type)
+    from_width = _get_floating_point_width(operand_type)
+    if to_width > from_width:
+      return std.FPExtOp(to_type, operand).result
+    elif to_width < from_width:
+      return std.FPTruncOp(to_type, operand).result
+    # Same width but different type (e.g. f16 vs bf16) is not handled.
+    raise ValueError(f"Unable to cast body expression from {operand_type} to "
+                     f"{to_type}")
+
+  def yield_outputs(self, *output_names: str):
+    """Emits a `linalg.yield` of the values assigned to `output_names`.
+
+    Raises:
+      ValueError: If any output name was never assigned.
+    """
+    output_values = []
+    for n in output_names:
+      try:
+        output_values.append(self.yield_mapping[n])
+      except KeyError:
+        raise ValueError(f"Body assignments do not assign all outputs: "
+                         f"missing '{n}'")
+    linalg.YieldOp(output_values)
+
+  def _eval_add(self, lhs: Value, rhs: Value) -> Value:
+    """Emits std.AddFOp or std.AddIOp depending on the type of `lhs`."""
+    if _is_floating_point_type(lhs.type):
+      return std.AddFOp(lhs.type, lhs, rhs).result
+    if _is_integer_type(lhs.type):
+      return std.AddIOp(lhs.type, lhs, rhs).result
+    raise NotImplementedError(f"Unsupported 'add' operand: {lhs}")
+
+  def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
+    """Emits std.MulFOp or std.MulIOp depending on the type of `lhs`."""
+    if _is_floating_point_type(lhs.type):
+      return std.MulFOp(lhs.type, lhs, rhs).result
+    if _is_integer_type(lhs.type):
+      return std.MulIOp(lhs.type, lhs, rhs).result
+    raise NotImplementedError(f"Unsupported 'mul' operand: {lhs}")
+
+
+def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
+                           in_arg_defs: Sequence[TensorDefConfig],
+                           ins: Sequence[Value],
+                           out_arg_defs: Sequence[TensorDefConfig],
+                           outs: Sequence[Value]):
+  """Infers implicit outs and output types.
+
+  Respects existing contents of outs if not empty.
+
+  Returns:
+    normalized outs, output types
+
+  Raises:
+    NotImplementedError: If `outs` is empty. Inference is not implemented
+      yet, so `op_config`, `in_arg_defs`, `ins` and `out_arg_defs` are
+      currently unused.
+  """
+  # If outs were explicitly provided, we accept them verbatim.
+  if outs:
+    return outs, [out.type for out in outs]
+
+  raise NotImplementedError(f"Output tensor inference not yet supported for "
+                            "structured ops")
+
+
+def _get_shaped_element_types_from_values(*values: Value) -> Sequence[Type]:
+  """Returns the ShapedType element type of each value, in order.
+
+  Raises:
+    ValueError: If a value's type cannot be cast to ShapedType.
+  """
+  types = []
+  for v in values:
+    try:
+      t = ShapedType(v.type)
+    except Exception as e:
+      raise ValueError(f"Expected ShapedType but got {v}") from e
+    types.append(t.element_type)
+  return types
+
+
+def _get_tensor_def_names(
+    *tensor_def_configs: TensorDefConfig) -> Sequence[str]:
+  """Returns the tensor name of each TensorDefConfig, in order."""
+  return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs]
+
+
+def _is_floating_point_type(t: Type) -> bool:
+  """Returns whether `t` is an F64, F32, F16 or BF16 type."""
+  # TODO: Create a FloatType in the Python API and implement the switch
+  # there.
+  return (F64Type.isinstance(t) or F32Type.isinstance(t) or
+          F16Type.isinstance(t) or BF16Type.isinstance(t))
+
+
+def _is_integer_type(t: Type) -> bool:
+  """Returns whether `t` is an IntegerType."""
+  return IntegerType.isinstance(t)
+
+
+def _get_floating_point_width(t: Type) -> int:
+  """Returns the bit width of a floating point type.
+
+  Note that F16 and BF16 both report 16.
+
+  Raises:
+    NotImplementedError: For types other than F64/F32/F16/BF16.
+  """
+  # TODO: Create a FloatType in the Python API and implement the switch
+  # there.
+  if F64Type.isinstance(t):
+    return 64
+  if F32Type.isinstance(t):
+    return 32
+  if F16Type.isinstance(t):
+    return 16
+  if BF16Type.isinstance(t):
+    return 16
+  raise NotImplementedError(f"Unhandled floating point type switch {t}")

diff  --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
new file mode 100644
index 000000000000..7f8c11679457
--- /dev/null
+++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -0,0 +1,194 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Optional, Sequence
+
+from mlir.ir import *
+from mlir.dialects import builtin
+from mlir.dialects import linalg
+from mlir.dialects import std
+
+from mlir.dialects.linalg.opdsl.lang import *
+
+
+# TODO: Find a home for this quality of life helper.
+def build_function(*inputs: Type, results: Optional[Sequence[Type]] = None):
+  """Decorator that emits a function in a more pythonic way.
+
+  If result types are not specified, they are inferred from the function
+  returns. The `ReturnOp` is implicitly added upon the wrapped function return.
+
+  The FuncOp is emitted immediately, when the decorator is applied, and the
+  decorated Python function is called with the entry block arguments.
+
+  Args:
+    *inputs: Argument types of the emitted function.
+    results: Result types. If None, they are taken from the types of the
+      values returned by the decorated function.
+  """
+
+  def decorator(f):
+    return_types = results
+    symbol_name = f.__name__
+    function_type = FunctionType.get(inputs=inputs, results=results or [])
+    func_op = builtin.FuncOp(name=symbol_name, type=function_type)
+    with InsertionPoint(func_op.add_entry_block()):
+      func_args = func_op.entry_block.arguments
+      return_values = f(*func_args)
+      # Normalize the return (None, a single Value, or an iterable of Values)
+      # to a list of ReturnOp operands.
+      if return_values is None:
+        return_values = []
+      elif isinstance(return_values, Value):
+        return_values = [return_values]
+      else:
+        return_values = list(return_values)
+      std.ReturnOp(return_values)
+      if return_types is None:
+        # Recompute the function type.
+        return_types = [v.type for v in return_values]
+        function_type = FunctionType.get(inputs=inputs, results=return_types)
+        # TODO: Have an API or a setter for this.
+        func_op.attributes["type"] = TypeAttr.get(function_type)
+
+    # TODO: When turning this into a real facility, return a function that emits
+    # a `call` to the function instead of doing nothing.
+    # For now the decorated name becomes a no-op placeholder that exposes the
+    # emitted op as `func_op`.
+    wrapped = lambda: None
+    wrapped.__name__ = symbol_name
+    wrapped.func_op = func_op
+    return wrapped
+
+  return decorator
+
+
+# Monomorphic matmul: A, B and C all use the element type variable T.
+@linalg_structured_op
+def matmul_mono(A=TensorDef(T, S.M, S.K),
+                B=TensorDef(T, S.K, S.N),
+                C=TensorDef(T, S.M, S.N, output=True)):
+  C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
+
+
+# Polymorphic matmul: A and B have independent element types (T1, T2) that are
+# cast to the output element type U before the multiply-accumulate.
+@linalg_structured_op
+def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
+                B=TensorDef(TV.T2, S.K, S.N),
+                C=TensorDef(U, S.M, S.N, output=True)):
+  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+
+
+# Emit one test function per element type combination into a single module,
+# then print the module for verification by the directives in this file.
+with Context() as ctx, Location.unknown():
+  module = Module.create()
+  f16 = F16Type.get()
+  f32 = F32Type.get()
+  f64 = F64Type.get()
+  i8 = IntegerType.get_signless(8)
+  i16 = IntegerType.get_signless(16)
+  i32 = IntegerType.get_signless(32)
+  with InsertionPoint.at_block_terminator(module.body):
+
+    # Note that these all have the same indexing maps. We verify the first and
+    # then do more permutation tests on casting and body generation
+    # behavior.
+    # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+    # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
+    # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+
+    # CHECK-LABEL: func @test_matmul_mono
+    # CHECK-SAME:  %[[A:.+]]: tensor<4x16xf32>
+    # CHECK-SAME: %[[B:.+]]: tensor<16x8xf32>
+
+    # CHECK: %[[INITC:.+]] = linalg.init_tensor [4, 8] : tensor<4x8xf32>
+    # CHECK: linalg.generic
+    # CHECK-SAME: indexing_maps = [#[[$MAPA]], #[[$MAPB]], #[[$MAPC]]]
+    # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+    # CHECK-SAME: ins(%[[A]], %[[B]]
+    # CHECK-SAME: outs(%[[INITC]]
+
+    @build_function(RankedTensorType.get((4, 16), f32),
+                    RankedTensorType.get((16, 8), f32))
+    def test_matmul_mono(lhs, rhs):
+      # TODO: Enable outs inference and add sugar for InitTensorOp
+      # construction.
+      init_result = linalg.InitTensorOp(result=RankedTensorType.get((4, 8),
+                                                                    f32),
+                                        static_sizes=ArrayAttr.get([
+                                            IntegerAttr.get(IndexType.get(), 4),
+                                            IntegerAttr.get(IndexType.get(), 8)
+                                        ]),
+                                        sizes=[])
+      return matmul_mono(lhs, rhs, outs=[init_result.result])
+
+    # The remaining functions exercise matmul_poly's casts to the output
+    # element type: sexti/trunci between integer widths, sitofp from integer
+    # to float, and fpext/fptrunc between float widths.
+    # CHECK-LABEL: @test_i8i8i32_matmul
+    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
+    # CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
+    # CHECK-NEXT:   %[[B_CAST:.+]] = sexti %[[B_ARG]] : i8 to i32
+    # CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
+    # CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
+    # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+    # CHECK-NEXT: -> tensor<4x8xi32>
+    @build_function(RankedTensorType.get((4, 16), i8),
+                    RankedTensorType.get((16, 8), i8),
+                    RankedTensorType.get((4, 8), i32))
+    def test_i8i8i32_matmul(lhs, rhs, init_result):
+      return matmul_poly(lhs, rhs, outs=[init_result])
+
+    # CHECK-LABEL: @test_i8i16i32_matmul
+    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
+    # CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
+    # CHECK-NEXT:   %[[B_CAST:.+]] = sexti %[[B_ARG]] : i16 to i32
+    # CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
+    # CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
+    # CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+    # CHECK-NEXT: -> tensor<4x8xi32>
+    @build_function(RankedTensorType.get((4, 16), i8),
+                    RankedTensorType.get((16, 8), i16),
+                    RankedTensorType.get((4, 8), i32))
+    def test_i8i16i32_matmul(lhs, rhs, init_result):
+      return matmul_poly(lhs, rhs, outs=[init_result])
+
+    # CHECK-LABEL: @test_i32i32i16_matmul
+    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
+    # CHECK-NEXT:   %[[A_CAST:.+]] = trunci %[[A_ARG]] : i32 to i16
+    # CHECK-NEXT:   %[[B_CAST:.+]] = trunci %[[B_ARG]] : i32 to i16
+    # CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
+    # CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
+    # CHECK-NEXT:   linalg.yield %[[ADD]] : i16
+    # CHECK-NEXT: -> tensor<4x8xi16>
+    @build_function(RankedTensorType.get((4, 16), i32),
+                    RankedTensorType.get((16, 8), i32),
+                    RankedTensorType.get((4, 8), i16))
+    def test_i32i32i16_matmul(lhs, rhs, init_result):
+      return matmul_poly(lhs, rhs, outs=[init_result])
+
+    # CHECK-LABEL: @test_i8i8f32_matmul
+    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
+    # CHECK-NEXT:   %[[A_CAST:.+]] = sitofp %[[A_ARG]] : i8 to f32
+    # CHECK-NEXT:   %[[B_CAST:.+]] = sitofp %[[B_ARG]] : i8 to f32
+    # CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+    # CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+    # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+    # CHECK-NEXT: -> tensor<4x8xf32>
+    @build_function(RankedTensorType.get((4, 16), i8),
+                    RankedTensorType.get((16, 8), i8),
+                    RankedTensorType.get((4, 8), f32))
+    def test_i8i8f32_matmul(lhs, rhs, init_result):
+      return matmul_poly(lhs, rhs, outs=[init_result])
+
+    # CHECK-LABEL: @test_f16f16f32_matmul
+    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+    # CHECK-NEXT:   %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
+    # CHECK-NEXT:   %[[B_CAST:.+]] = fpext %[[B_ARG]] : f16 to f32
+    # CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+    # CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+    # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+    # CHECK-NEXT: -> tensor<4x8xf32>
+    @build_function(RankedTensorType.get((4, 16), f16),
+                    RankedTensorType.get((16, 8), f16),
+                    RankedTensorType.get((4, 8), f32))
+    def test_f16f16f32_matmul(lhs, rhs, init_result):
+      return matmul_poly(lhs, rhs, outs=[init_result])
+
+    # CHECK-LABEL: @test_f64f64f32_matmul
+    # CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+    # CHECK-NEXT:   %[[A_CAST:.+]] = fptrunc %[[A_ARG]] : f64 to f32
+    # CHECK-NEXT:   %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32
+    # CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+    # CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+    # CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+    # CHECK-NEXT: -> tensor<4x8xf32>
+    @build_function(RankedTensorType.get((4, 16), f64),
+                    RankedTensorType.get((16, 8), f64),
+                    RankedTensorType.get((4, 8), f32))
+    def test_f64f64f32_matmul(lhs, rhs, init_result):
+      return matmul_poly(lhs, rhs, outs=[init_result])
+
+
+print(module)

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 59b4b50b533d..ea05c1561f74 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -59,6 +59,21 @@ def testTypeEq():
 run(testTypeEq)
 
 
+# CHECK-LABEL: TEST: testTypeIsInstance
+def testTypeIsInstance():
+  # Concrete type classes expose a static isinstance() predicate that accepts
+  # a generic Type.
+  ctx = Context()
+  t1 = Type.parse("i32", ctx)
+  t2 = Type.parse("f32", ctx)
+  # CHECK: True
+  print(IntegerType.isinstance(t1))
+  # CHECK: False
+  print(F32Type.isinstance(t1))
+  # CHECK: True
+  print(F32Type.isinstance(t2))
+
+run(testTypeIsInstance)
+
+
 # CHECK-LABEL: TEST: testTypeEqDoesNotRaise
 def testTypeEqDoesNotRaise():
   ctx = Context()


        


More information about the Mlir-commits mailing list