[Mlir-commits] [mlir] 48f4407 - [mlir][linalg] Extend opdsl to support operations on complex types.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 17 09:34:32 PDT 2022
Author: bixia1
Date: 2022-06-17T09:34:26-07:00
New Revision: 48f4407c1aafbf953b501381d725320fa867eaf5
URL: https://github.com/llvm/llvm-project/commit/48f4407c1aafbf953b501381d725320fa867eaf5
DIFF: https://github.com/llvm/llvm-project/commit/48f4407c1aafbf953b501381d725320fa867eaf5.diff
LOG: [mlir][linalg] Extend opdsl to support operations on complex types.
Linalg opdsl now supports negf/add/sub/mul on complex types.
Add a test.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D128010
Added:
Modified:
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/test/python/dialects/linalg/opdsl/emit_misc.py
Removed:
################################################################################
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 2e71e561a7f54..cc99081b440d0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -10,6 +10,7 @@
from .... import linalg
from .... import math
from .... import arith
+from .... import complex
from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
from .scalar_expr import *
@@ -408,6 +409,8 @@ def _unary_floor(self, x: Value) -> Value:
def _unary_negf(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return arith.NegFOp(x).result
+ if _is_complex_type(x.type):
+ return complex.NegOp(x).result
raise NotImplementedError("Unsupported 'negf' operand: {x}")
def _binary_add(self, lhs: Value, rhs: Value) -> Value:
@@ -415,6 +418,8 @@ def _binary_add(self, lhs: Value, rhs: Value) -> Value:
return arith.AddFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.AddIOp(lhs, rhs).result
+ if _is_complex_type(lhs.type):
+ return complex.AddOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
@@ -422,6 +427,8 @@ def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
return arith.SubFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.SubIOp(lhs, rhs).result
+ if _is_complex_type(lhs.type):
+ return complex.SubOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
@@ -429,6 +436,8 @@ def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
return arith.MulFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
return arith.MulIOp(lhs, rhs).result
+ if _is_complex_type(lhs.type):
+ return complex.MulOp(lhs, rhs).result
raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
@@ -512,6 +521,10 @@ def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type,
block_arg_types.append(element_or_self_type)
+def _is_complex_type(t: Type) -> bool:
+ return ComplexType.isinstance(t)
+
+
def _is_floating_point_type(t: Type) -> bool:
# TODO: Create a FloatType in the Python API and implement the switch
# there.
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index 2d045125f2858..ddb5cc8248024 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -44,6 +44,7 @@ def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
+ c32 = ComplexType.get(f32)
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
@@ -129,6 +130,16 @@ def test_f32_elemwise_floor(input, init_result):
def test_f32_elemwise_neg(input, init_result):
return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+ # CHECK-LABEL: @test_c32_elemwise_neg
+ # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+ # CHECK-NEXT: %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[EXP]] : complex<f32>
+ # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32))
+ def test_c32_elemwise_neg(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
# Just check that we don't assert out on name mismatch.
# CHECK-LABEL: @test_non_default_op_name
@func.FuncOp.from_py_func(
More information about the Mlir-commits
mailing list