[Mlir-commits] [mlir] d4cba4a - [mlir][linalg] Add structured op builders from python opdsl.
Stella Laurenzo
llvmlistbot@llvm.org
Fri Mar 19 11:21:08 PDT 2021
Author: Stella Laurenzo
Date: 2021-03-19T11:20:36-07:00
New Revision: d4cba4a188f419d1c2fc4b827c4a6a0310b0568e
URL: https://github.com/llvm/llvm-project/commit/d4cba4a188f419d1c2fc4b827c4a6a0310b0568e
DIFF: https://github.com/llvm/llvm-project/commit/d4cba4a188f419d1c2fc4b827c4a6a0310b0568e.diff
LOG: [mlir][linalg] Add structured op builders from python opdsl.
* Makes the wrapped functions of the `@linalg_structured_op` decorator callable such that they emit IR imperatively when invoked (a short usage sketch follows below).
* There are numerous TODOs that I will keep working through to achieve generality.
* Will true up exception handling tests as the feature progresses (for things that are actually errors once everything is implemented).
* Includes the addition of an `isinstance` method on concrete types in the Python API.
Differential Revision: https://reviews.llvm.org/D98754
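A minimal usage sketch, adapted from the emit_structured_generic.py test added
in this change. It assumes an active Context/Location and an insertion point,
and that tensor-typed Values named lhs, rhs, and init_c already exist (these
names are illustrative only); `outs` must be passed explicitly since output
inference is not yet implemented:

  from mlir.dialects.linalg.opdsl.lang import *

  @linalg_structured_op
  def matmul_mono(A=TensorDef(T, S.M, S.K),
                  B=TensorDef(T, S.K, S.N),
                  C=TensorDef(T, S.M, S.N, output=True)):
    C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]

  # Invoking the decorated function now emits a linalg.generic op in place:
  result = matmul_mono(lhs, rhs, outs=[init_c])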
Added:
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
Modified:
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
mlir/test/Bindings/Python/ir_types.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index a544e52c2613..6b4e5434d1d7 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -2477,6 +2477,9 @@ class PyConcreteType : public BaseTy {
static void bind(py::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName);
cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
+ cls.def_static("isinstance", [](PyType &otherType) -> bool {
+ return DerivedTy::isaFunction(otherType);
+ });
DerivedTy::bindDerived(cls);
}
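From Python, this gives every registered concrete type class a static
`isinstance` check, as exercised by the ir_types.py test at the end of this
diff:

  t1 = Type.parse("i32", ctx)
  print(IntegerType.isinstance(t1))  # True
  print(F32Type.isinstance(t1))      # False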
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
index 115ea40619b8..fdc6cfd9bab0 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py
@@ -21,6 +21,7 @@
__all__ = [
"LinalgStructuredOpConfig",
"LinalgOpConfig",
+ "TensorDefConfig",
]
@@ -51,17 +52,17 @@ def __init__(self, tensor_def: TensorDef, shape_map: _ir.AffineMap):
self.shape_map = shape_map
self.indexing_map = None # type: Optional[_ir.AffineMap]
- def to_yaml_custom_dict(self):
-
- def get_usage():
- if self.tensor_def.output:
- return "output"
- else:
- return "input"
+ @property
+ def usage(self) -> str:
+ if self.tensor_def.output:
+ return "output"
+ else:
+ return "input"
+ def to_yaml_custom_dict(self):
return dict(
name=self.tensor_def.tensor_name,
- usage=get_usage(),
+ usage=self.usage,
shape=_serialize_affine_map(self.shape_map),
element_type_var=self.tensor_def.type_var.name,
)
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
index d367c5bdde07..cbff41db2d88 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -11,6 +11,8 @@
from mlir import ir
from .comprehension import *
+from .config import *
+from .emitter import *
_CONTEXT = threading.local()
@@ -42,9 +44,34 @@ def __init__(self, op_name: str, model: LinalgOpDef):
self.op_name = op_name
self.model = model
- def __call__(self, *args, **kwargs):
- # TODO: Upstream the emitter and invoke here
- raise NotImplementedError("Linalg generic emission not yet implemented")
+ def __call__(self, *args, emit_generic: bool = True, **kwargs):
+ """Emits the corresponding op definition as IR.
+
+ Most arguments are passed through to the underlying emitter. The following
+ are interpreted here:
+ emit_generic: Emits a generic form as appropriate (default True). If
+        False, a named form is emitted (which must have been built into the
+ compiler).
+ """
+ op_configs = LinalgOpConfig.from_linalg_op_def(self.model,
+ context=ir.Context.current)
+
+ if len(op_configs) != 1:
+ # TODO: Support composite ops.
+ raise NotImplementedError(
+ f"Emission of composite linalg ops not supported: {op_configs}")
+
+ op_config = op_configs[0]
+ if op_config.structured_op:
+ if emit_generic:
+ return emit_generic_structured_op(op_config.structured_op, *args,
+ **kwargs)
+ else:
+ return emit_named_structured_op(op_config.structured_op, *args,
+ **kwargs)
+
+ raise NotImplementedError(
+ f"Emission of linalg op type not supported: {op_config}")
def linalg_structured_op(dsl_func=None,
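Note that only the generic form is emittable in this change: passing
emit_generic=False routes to emit_named_structured_op, which currently raises
NotImplementedError (see emitter.py below). Continuing the sketch from above:

  matmul_mono(lhs, rhs, outs=[init_c])                      # emits linalg.generic
  matmul_mono(lhs, rhs, outs=[init_c], emit_generic=False)  # not yet supported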
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
new file mode 100644
index 000000000000..9a18993e9f62
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -0,0 +1,252 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Dict, Sequence
+
+from mlir.ir import *
+from mlir.dialects import linalg
+from mlir.dialects import std
+
+from .scalar_expr import *
+from .config import *
+
+__all__ = [
+ "emit_generic_structured_op",
+ "emit_named_structured_op",
+]
+
+
+def emit_generic_structured_op(op_config: LinalgStructuredOpConfig,
+ *ins: Value,
+                               outs: Sequence[Value] = ()):
+ all_arg_defs = op_config.ordered_tensor_args
+ in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"]
+ out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"]
+
+ # Arity validation.
+ if len(ins) != len(in_arg_defs):
+ raise ValueError(f"Expected {len(in_arg_defs)} inputs but got "
+ f"{len(ins)} for {op_config}")
+ if outs and len(outs) != len(out_arg_defs):
+ raise ValueError(f"Expected {len(out_arg_defs)} outputs but got "
+ f"{len(outs)} for {op_config}")
+
+ outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins,
+ out_arg_defs, outs)
+
+ # Extract type vars for input/output based types.
+ type_mapping = dict() # type: Dict[str, Type]
+ for arg_def, arg_element_type in zip(
+ in_arg_defs + out_arg_defs,
+ _get_shaped_element_types_from_values(*ins, *outs)):
+ tv_name = arg_def.tensor_def.type_var.name
+ type_mapping[tv_name] = arg_element_type
+
+ # Emit the generic op.
+ # TODO: Support emission of pure memref form.
+ indexing_maps_attr = ArrayAttr.get(
+ [AffineMapAttr.get(am) for am in op_config.indexing_maps])
+ iterator_types_attr = ArrayAttr.get(
+ [StringAttr.get(s) for s in op_config.iterator_types])
+ generic_op = linalg.GenericOp(
+ result_tensors=out_types,
+ inputs=ins,
+ outputs=outs,
+ indexing_maps=indexing_maps_attr,
+ iterator_types=iterator_types_attr,
+ doc=None, # TODO: Make optional.
+ library_call=None, # TODO: Make optional.
+ sparse=BoolAttr.get(False)) # TODO: Make optional.
+
+ # Construct the body.
+ block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs)
+ block_arg_types = _get_shaped_element_types_from_values(*ins, *outs)
+ block = generic_op.regions[0].blocks.append(*block_arg_types)
+ block_arg_mapping = dict(zip(block_arg_names, block.arguments))
+ with InsertionPoint(block):
+ body_builder = _BodyBuilder(type_mapping, block_arg_mapping)
+ for assignment in op_config.assignments:
+ body_builder.assign(assignment)
+ body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs))
+
+ if len(out_arg_defs) == 1:
+ return generic_op.result
+ else:
+ return generic_op.results
+
+
+def emit_named_structured_op(op_config: LinalgStructuredOpConfig,
+ *ins: Value,
+                             outs: Sequence[Value] = ()):
+ raise NotImplementedError(
+ f"Emission of named structured ops is not supported: {op_config}")
+
+
+class _BodyBuilder:
+ """Constructs a structured op body by evaluating assignments."""
+
+ def __init__(self, type_mapping: Dict[str, Type],
+ block_arg_mapping: Dict[str, Value]):
+ self.type_mapping = type_mapping
+ self.block_arg_mapping = block_arg_mapping
+ self.yield_mapping = dict() # type: Dict[str, Value]
+
+ def assign(self, assignment: ScalarAssign):
+ if assignment.arg in self.yield_mapping:
+ raise ValueError(
+ f"Multiple assignments to the same argument are forbidden: "
+ f"{assignment}")
+ self.yield_mapping[assignment.arg] = self.expression(assignment.value)
+
+ def expression(self, expr: ScalarExpression) -> Value:
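+    # A ScalarExpression behaves as a tagged union: exactly one of scalar_arg,
+    # scalar_apply, or symbolic_cast is expected to be set.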
+ if expr.scalar_arg:
+ try:
+ return self.block_arg_mapping[expr.scalar_arg.arg]
+ except KeyError:
+ raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for "
+ f"this structured op.")
+ elif expr.scalar_apply:
+ try:
+ fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}")
+ except AttributeError:
+ raise ValueError(
+ f"Function '{expr.scalar_apply.fn_name}' is not a known "
+ "scalar body function")
+ operand_values = [
+ self.expression(operand) for operand in expr.scalar_apply.operands
+ ]
+ return fn(*operand_values)
+ elif expr.symbolic_cast:
+ operand_value = self.expression(expr.symbolic_cast.operand)
+ return self.cast(expr.symbolic_cast.to_type.name, operand_value)
+ raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
+
+ def cast(self, type_var_name: str, operand: Value) -> Value:
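+    # Resolves the type variable to a concrete type, then emits whatever std
+    # cast op (if any) is needed to convert the operand to it.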
+ try:
+ to_type = self.type_mapping[type_var_name]
+ except KeyError:
+      raise ValueError(f"Unbound type variable '{type_var_name}' ("
+                       f"expected one of {self.type_mapping.keys()})")
+ if operand.type == to_type:
+ return operand
+ if _is_integer_type(to_type):
+ return self._cast_to_integer(to_type, operand)
+ elif _is_floating_point_type(to_type):
+ return self._cast_to_floating_point(to_type, operand)
+
+ raise ValueError(f"Unable to cast body expression from {operand.type} to "
+ f"{to_type}")
+
+ def _cast_to_integer(self, to_type: Type, operand: Value) -> Value:
+ to_width = IntegerType(to_type).width
+ operand_type = operand.type
+ if _is_floating_point_type(operand_type):
+ return std.FPToSIOp(to_type, operand).result
+ # Assume integer.
+ from_width = IntegerType(operand_type).width
+ if to_width > from_width:
+ return std.SignExtendIOp(to_type, operand).result
+ elif to_width < from_width:
+ return std.TruncateIOp(to_type, operand).result
+ raise ValueError(f"Unable to cast body expression from {operand_type} to "
+ f"{to_type}")
+
+ def _cast_to_floating_point(self, to_type: Type, operand: Value) -> Value:
+ operand_type = operand.type
+ if _is_integer_type(operand_type):
+ return std.SIToFPOp(to_type, operand).result
+ # Assume FloatType.
+ to_width = _get_floating_point_width(to_type)
+ from_width = _get_floating_point_width(operand_type)
+ if to_width > from_width:
+ return std.FPExtOp(to_type, operand).result
+ elif to_width < from_width:
+ return std.FPTruncOp(to_type, operand).result
+ raise ValueError(f"Unable to cast body expression from {operand_type} to "
+ f"{to_type}")
+
+ def yield_outputs(self, *output_names: str):
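+    # Emits the terminating linalg.yield with operands in output declaration
+    # order, raising if the body did not assign every output.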
+ output_values = []
+ for n in output_names:
+ try:
+ output_values.append(self.yield_mapping[n])
+ except KeyError:
+ raise ValueError(f"Body assignments do not assign all outputs: "
+ f"missing '{n}'")
+ linalg.YieldOp(output_values)
+
+ def _eval_add(self, lhs: Value, rhs: Value) -> Value:
+ if _is_floating_point_type(lhs.type):
+ return std.AddFOp(lhs.type, lhs, rhs).result
+ if _is_integer_type(lhs.type):
+ return std.AddIOp(lhs.type, lhs, rhs).result
+    raise NotImplementedError(f"Unsupported 'add' operand: {lhs}")
+
+ def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
+ if _is_floating_point_type(lhs.type):
+ return std.MulFOp(lhs.type, lhs, rhs).result
+ if _is_integer_type(lhs.type):
+ return std.MulIOp(lhs.type, lhs, rhs).result
+    raise NotImplementedError(f"Unsupported 'mul' operand: {lhs}")
+
+
+def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
+ in_arg_defs: Sequence[TensorDefConfig],
+ ins: Sequence[Value],
+ out_arg_defs: Sequence[TensorDefConfig],
+ outs: Sequence[Value]):
+ """Infers implicit outs and output types.
+
+ Respects existing contents of outs if not empty.
+
+ Returns:
+ normalized outs, output types
+ """
+ # If outs were explicitly provided, we accept them verbatim.
+ if outs:
+ return outs, [out.type for out in outs]
+
+  raise NotImplementedError("Output tensor inference not yet supported for "
+ "structured ops")
+
+
+def _get_shaped_element_types_from_values(*values: Value) -> Sequence[Type]:
+ types = []
+ for v in values:
+ try:
+ t = ShapedType(v.type)
+ except Exception as e:
+ raise ValueError(f"Expected ShapedType but got {v}") from e
+ types.append(t.element_type)
+ return types
+
+
+def _get_tensor_def_names(
+ *tensor_def_configs: TensorDefConfig) -> Sequence[str]:
+ return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs]
+
+
+def _is_floating_point_type(t: Type) -> bool:
+ # TODO: Create a FloatType in the Python API and implement the switch
+ # there.
+ return (F64Type.isinstance(t) or F32Type.isinstance(t) or
+ F16Type.isinstance(t) or BF16Type.isinstance(t))
+
+
+def _is_integer_type(t: Type) -> bool:
+ return IntegerType.isinstance(t)
+
+
+def _get_floating_point_width(t: Type) -> int:
+ # TODO: Create a FloatType in the Python API and implement the switch
+ # there.
+ if F64Type.isinstance(t):
+ return 64
+ if F32Type.isinstance(t):
+ return 32
+ if F16Type.isinstance(t):
+ return 16
+ if BF16Type.isinstance(t):
+ return 16
+ raise NotImplementedError(f"Unhandled floating point type switch {t}")
diff --git a/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
new file mode 100644
index 000000000000..7f8c11679457
--- /dev/null
+++ b/mlir/test/Bindings/Python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -0,0 +1,194 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from typing import Optional, Sequence
+
+from mlir.ir import *
+from mlir.dialects import builtin
+from mlir.dialects import linalg
+from mlir.dialects import std
+
+from mlir.dialects.linalg.opdsl.lang import *
+
+
+# TODO: Find a home for this quality of life helper.
+def build_function(*inputs: Type, results: Optional[Sequence[Type]] = None):
+ """Decorator that emits a function in a more pythonic way.
+
+ If result types are not specified, they are inferred from the function
+ returns. The `ReturnOp` is implicitly added upon the wrapped function return.
+ """
+
+ def decorator(f):
+ return_types = results
+ symbol_name = f.__name__
+ function_type = FunctionType.get(inputs=inputs, results=results or [])
+ func_op = builtin.FuncOp(name=symbol_name, type=function_type)
+ with InsertionPoint(func_op.add_entry_block()):
+ func_args = func_op.entry_block.arguments
+ return_values = f(*func_args)
+ if return_values is None:
+ return_values = []
+ elif isinstance(return_values, Value):
+ return_values = [return_values]
+ else:
+ return_values = list(return_values)
+ std.ReturnOp(return_values)
+ if return_types is None:
+ # Recompute the function type.
+ return_types = [v.type for v in return_values]
+ function_type = FunctionType.get(inputs=inputs, results=return_types)
+ # TODO: Have an API or a setter for this.
+ func_op.attributes["type"] = TypeAttr.get(function_type)
+
+ # TODO: When turning this into a real facility, return a function that emits
+ # a `call` to the function instead of doing nothing.
+ wrapped = lambda: None
+ wrapped.__name__ = symbol_name
+ wrapped.func_op = func_op
+ return wrapped
+
+ return decorator
+
+
+@linalg_structured_op
+def matmul_mono(A=TensorDef(T, S.M, S.K),
+ B=TensorDef(T, S.K, S.N),
+ C=TensorDef(T, S.M, S.N, output=True)):
+ C[D.m, D.n] += A[D.m, D.k] * B[D.k, D.n]
+
+
+@linalg_structured_op
+def matmul_poly(A=TensorDef(TV.T1, S.M, S.K),
+ B=TensorDef(TV.T2, S.K, S.N),
+ C=TensorDef(U, S.M, S.N, output=True)):
+ C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
+
+
+with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f16 = F16Type.get()
+ f32 = F32Type.get()
+ f64 = F64Type.get()
+ i8 = IntegerType.get_signless(8)
+ i16 = IntegerType.get_signless(16)
+ i32 = IntegerType.get_signless(32)
+ with InsertionPoint.at_block_terminator(module.body):
+
+ # Note that these all have the same indexing maps. We verify the first and
+ # then do more permutation tests on casting and body generation
+ # behavior.
+ # CHECK: #[[$MAPA:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+ # CHECK: #[[$MAPB:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
+ # CHECK: #[[$MAPC:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+
+ # CHECK-LABEL: func @test_matmul_mono
+ # CHECK-SAME: %[[A:.+]]: tensor<4x16xf32>
+ # CHECK-SAME: %[[B:.+]]: tensor<16x8xf32>
+
+ # CHECK: %[[INITC:.+]] = linalg.init_tensor [4, 8] : tensor<4x8xf32>
+ # CHECK: linalg.generic
+ # CHECK-SAME: indexing_maps = [#[[$MAPA]], #[[$MAPB]], #[[$MAPC]]]
+ # CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+ # CHECK-SAME: ins(%[[A]], %[[B]]
+ # CHECK-SAME: outs(%[[INITC]]
+
+ @build_function(RankedTensorType.get((4, 16), f32),
+ RankedTensorType.get((16, 8), f32))
+ def test_matmul_mono(lhs, rhs):
+ # TODO: Enable outs inference and add sugar for InitTensorOp
+ # construction.
+ init_result = linalg.InitTensorOp(result=RankedTensorType.get((4, 8),
+ f32),
+ static_sizes=ArrayAttr.get([
+ IntegerAttr.get(IndexType.get(), 4),
+ IntegerAttr.get(IndexType.get(), 8)
+ ]),
+ sizes=[])
+ return matmul_mono(lhs, rhs, outs=[init_result.result])
+
+ # CHECK-LABEL: @test_i8i8i32_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
+ # CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
+ # CHECK-NEXT: %[[B_CAST:.+]] = sexti %[[B_ARG]] : i8 to i32
+ # CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
+ # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
+ # CHECK-NEXT: linalg.yield %[[ADD]] : i32
+ # CHECK-NEXT: -> tensor<4x8xi32>
+ @build_function(RankedTensorType.get((4, 16), i8),
+ RankedTensorType.get((16, 8), i8),
+ RankedTensorType.get((4, 8), i32))
+ def test_i8i8i32_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
+
+ # CHECK-LABEL: @test_i8i16i32_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
+ # CHECK-NEXT: %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
+ # CHECK-NEXT: %[[B_CAST:.+]] = sexti %[[B_ARG]] : i16 to i32
+ # CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
+ # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
+ # CHECK-NEXT: linalg.yield %[[ADD]] : i32
+ # CHECK-NEXT: -> tensor<4x8xi32>
+ @build_function(RankedTensorType.get((4, 16), i8),
+ RankedTensorType.get((16, 8), i16),
+ RankedTensorType.get((4, 8), i32))
+ def test_i8i16i32_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
+
+ # CHECK-LABEL: @test_i32i32i16_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
+ # CHECK-NEXT: %[[A_CAST:.+]] = trunci %[[A_ARG]] : i32 to i16
+ # CHECK-NEXT: %[[B_CAST:.+]] = trunci %[[B_ARG]] : i32 to i16
+ # CHECK-NEXT: %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
+ # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
+ # CHECK-NEXT: linalg.yield %[[ADD]] : i16
+ # CHECK-NEXT: -> tensor<4x8xi16>
+ @build_function(RankedTensorType.get((4, 16), i32),
+ RankedTensorType.get((16, 8), i32),
+ RankedTensorType.get((4, 8), i16))
+ def test_i32i32i16_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
+
+ # CHECK-LABEL: @test_i8i8f32_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
+ # CHECK-NEXT: %[[A_CAST:.+]] = sitofp %[[A_ARG]] : i8 to f32
+ # CHECK-NEXT: %[[B_CAST:.+]] = sitofp %[[B_ARG]] : i8 to f32
+ # CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+ # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+ # CHECK-NEXT: linalg.yield %[[ADD]] : f32
+ # CHECK-NEXT: -> tensor<4x8xf32>
+ @build_function(RankedTensorType.get((4, 16), i8),
+ RankedTensorType.get((16, 8), i8),
+ RankedTensorType.get((4, 8), f32))
+ def test_i8i8f32_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
+
+ # CHECK-LABEL: @test_f16f16f32_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
+ # CHECK-NEXT: %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
+ # CHECK-NEXT: %[[B_CAST:.+]] = fpext %[[B_ARG]] : f16 to f32
+ # CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+ # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+ # CHECK-NEXT: linalg.yield %[[ADD]] : f32
+ # CHECK-NEXT: -> tensor<4x8xf32>
+ @build_function(RankedTensorType.get((4, 16), f16),
+ RankedTensorType.get((16, 8), f16),
+ RankedTensorType.get((4, 8), f32))
+ def test_f16f16f32_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
+
+ # CHECK-LABEL: @test_f64f64f32_matmul
+ # CHECK: ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
+ # CHECK-NEXT: %[[A_CAST:.+]] = fptrunc %[[A_ARG]] : f64 to f32
+ # CHECK-NEXT: %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32
+ # CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
+ # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+ # CHECK-NEXT: linalg.yield %[[ADD]] : f32
+ # CHECK-NEXT: -> tensor<4x8xf32>
+ @build_function(RankedTensorType.get((4, 16), f64),
+ RankedTensorType.get((16, 8), f64),
+ RankedTensorType.get((4, 8), f32))
+ def test_f64f64f32_matmul(lhs, rhs, init_result):
+ return matmul_poly(lhs, rhs, outs=[init_result])
+
+
+print(module)
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 59b4b50b533d..ea05c1561f74 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -59,6 +59,21 @@ def testTypeEq():
run(testTypeEq)
+# CHECK-LABEL: TEST: testTypeIsInstance
+def testTypeIsInstance():
+ ctx = Context()
+ t1 = Type.parse("i32", ctx)
+ t2 = Type.parse("f32", ctx)
+ # CHECK: True
+ print(IntegerType.isinstance(t1))
+ # CHECK: False
+ print(F32Type.isinstance(t1))
+ # CHECK: True
+ print(F32Type.isinstance(t2))
+
+run(testTypeIsInstance)
+
+
# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
def testTypeEqDoesNotRaise():
ctx = Context()