[Mlir-commits] [mlir] 48f4407 - [mlir][linalg] Extend opdsl to support operations on complex types.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 17 09:34:32 PDT 2022


Author: bixia1
Date: 2022-06-17T09:34:26-07:00
New Revision: 48f4407c1aafbf953b501381d725320fa867eaf5

URL: https://github.com/llvm/llvm-project/commit/48f4407c1aafbf953b501381d725320fa867eaf5
DIFF: https://github.com/llvm/llvm-project/commit/48f4407c1aafbf953b501381d725320fa867eaf5.diff

LOG: [mlir][linalg] Extend opdsl to support operations on complex types.

Linalg opdsl now supports negf/add/sub/mul on complex types.

Add a test.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D128010

Added: 
    

Modified: 
    mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
    mlir/test/python/dialects/linalg/opdsl/emit_misc.py

Removed: 
    


################################################################################
diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 2e71e561a7f54..cc99081b440d0 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -10,6 +10,7 @@
 from .... import linalg
 from .... import math
 from .... import arith
+from .... import complex
 from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
 
 from .scalar_expr import *
@@ -408,6 +409,8 @@ def _unary_floor(self, x: Value) -> Value:
   def _unary_negf(self, x: Value) -> Value:
     if _is_floating_point_type(x.type):
       return arith.NegFOp(x).result
+    if _is_complex_type(x.type):
+      return complex.NegOp(x).result
     raise NotImplementedError("Unsupported 'negf' operand: {x}")
 
   def _binary_add(self, lhs: Value, rhs: Value) -> Value:
@@ -415,6 +418,8 @@ def _binary_add(self, lhs: Value, rhs: Value) -> Value:
       return arith.AddFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.AddIOp(lhs, rhs).result
+    if _is_complex_type(lhs.type):
+      return complex.AddOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}")
 
   def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
@@ -422,6 +427,8 @@ def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
       return arith.SubFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.SubIOp(lhs, rhs).result
+    if _is_complex_type(lhs.type):
+      return complex.SubOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}")
 
   def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
@@ -429,6 +436,8 @@ def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
       return arith.MulFOp(lhs, rhs).result
     if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return arith.MulIOp(lhs, rhs).result
+    if _is_complex_type(lhs.type):
+      return complex.MulOp(lhs, rhs).result
     raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}")
 
   def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
@@ -512,6 +521,10 @@ def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type,
   block_arg_types.append(element_or_self_type)
 
 
+def _is_complex_type(t: Type) -> bool:
+  return ComplexType.isinstance(t)
+
+
 def _is_floating_point_type(t: Type) -> bool:
   # TODO: Create a FloatType in the Python API and implement the switch
   # there.

diff  --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index 2d045125f2858..ddb5cc8248024 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -44,6 +44,7 @@ def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
 with Context() as ctx, Location.unknown():
   module = Module.create()
   f32 = F32Type.get()
+  c32 = ComplexType.get(f32)
   i32 = IntegerType.get_signless(32)
   with InsertionPoint(module.body):
 
@@ -129,6 +130,16 @@ def test_f32_elemwise_floor(input, init_result):
     def test_f32_elemwise_neg(input, init_result):
       return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
 
+    # CHECK-LABEL: @test_c32_elemwise_neg
+    # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+    # CHECK-NEXT:   %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
+    # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+    # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+    @func.FuncOp.from_py_func(
+        RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32))
+    def test_c32_elemwise_neg(input, init_result):
+      return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+
     # Just check that we don't assert out on name mismatch.
     # CHECK-LABEL: @test_non_default_op_name
     @func.FuncOp.from_py_func(


        


More information about the Mlir-commits mailing list