[Mlir-commits] [mlir] [MLIR][Complex] Add complex ops support in OPDSL. (PR #162665)

Hugo Trachino llvmlistbot at llvm.org
Thu Oct 9 08:28:27 PDT 2025


https://github.com/nujaa updated https://github.com/llvm/llvm-project/pull/162665

From 05e6371fd97d1ac0044864bd1d4507abb0c00539 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Wed, 8 Oct 2025 21:49:49 +0800
Subject: [PATCH 1/2] [MLIR][Complex] Democratize complex ops in OPDSL.

---
 .../linalg/opdsl/lang/comprehension.py        |  1 +
 .../dialects/linalg/opdsl/lang/emitter.py     | 11 +++
 .../python/dialects/linalg/opdsl/emit_misc.py | 73 ++++++++++++++++---
 3 files changed, 74 insertions(+), 11 deletions(-)

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 4f81a3874650d..3f3ec7b59eb3d 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -299,6 +299,7 @@ class UnaryFn:
     square = UnaryFnType("square")
     tanh = UnaryFnType("tanh")
     erf = UnaryFnType("erf")
+    conj = UnaryFnType("conj")
 
 
 class BinaryFnType:
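
With this entry in place, `conj` can be referenced like any other unary function in an OpDSL comprehension. A minimal sketch (the op name `elemwise_conj` is illustrative, not part of the patch):

    from mlir.dialects.linalg.opdsl.lang import *

    @linalg_structured_op
    def elemwise_conj(I=TensorDef(T), O=TensorDef(U, output=True)):
        O[None] = UnaryFn.conj(I[None])
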
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..10f1083b11758 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -468,16 +468,22 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
     def _unary_exp(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.ExpOp(x).result
+        if _is_complex_type(x.type):
+            return complex.ExpOp(x).result
         raise NotImplementedError("Unsupported 'exp' operand: {x}")
 
     def _unary_log(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.LogOp(x).result
+        if _is_complex_type(x.type):
+            return complex.LogOp(x).result
         raise NotImplementedError("Unsupported 'log' operand: {x}")
 
     def _unary_abs(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.AbsFOp(x).result
+        if _is_complex_type(x.type):
+            return complex.AbsOp(x).result
         raise NotImplementedError("Unsupported 'abs' operand: {x}")
 
     def _unary_ceil(self, x: Value) -> Value:
@@ -497,6 +503,11 @@ def _unary_negf(self, x: Value) -> Value:
             return complex.NegOp(x).result
         raise NotImplementedError("Unsupported 'negf' operand: {x}")
 
+    def _unary_conj(self, x: Value) -> Value:
+        if _is_complex_type(x.type):
+            return complex.ConjOp(x).result
+        raise NotImplementedError(f"Unsupported 'conj' operand: {x}")
+
     def _binary_add(self, lhs: Value, rhs: Value) -> Value:
         if _is_floating_point_type(lhs.type):
             return arith.AddFOp(lhs, rhs).result
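
Each `_unary_*` emitter follows the same type-based dispatch: floating-point operands lower to `math`/`arith` ops, complex operands to the `complex` dialect, and anything else raises NotImplementedError. The `_is_complex_type` predicate the new branches rely on already exists in emitter.py; a sketch of such a check with the MLIR Python bindings (assuming the standard `isinstance` helper on concrete type classes) would be:

    from mlir.ir import ComplexType

    def _is_complex_type(t) -> bool:
        # True iff `t` is an MLIR complex<...> type.
        return ComplexType.isinstance(t)
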
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index f8e034fb0e48b..2afae8b055ed0 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -30,7 +30,7 @@ def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
 
 
 @linalg_structured_op
-def elemwise_unary_poly(
+def elemwise_unary_poly_cast(
     I=TensorDef(T),
     O=TensorDef(U, output=True),
     fun=UnaryFnAttrDef(default=UnaryFn.exp),
@@ -38,6 +38,13 @@ def elemwise_unary_poly(
 ):
     O[None] = fun(cast(U, I[None]))
 
+@linalg_structured_op
+def elemwise_unary_poly(
+    I=TensorDef(T),
+    O=TensorDef(U, output=True),
+    fun=UnaryFnAttrDef(default=UnaryFn.exp),
+):
+    O[None] = fun(I[None])
 
 @linalg_structured_op(op_name="custom_op_name")
 def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
@@ -84,6 +91,17 @@ def test_i32_index(init_result):
         def test_f32_elemwise_exp(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
 
+        # CHECK-LABEL: @test_c32_elemwise_exp
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[EXP:.+]] = complex.exp %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_exp(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+
         # CHECK-LABEL: @test_f32_elemwise_log
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
         # CHECK-NEXT:   %[[LOG:.+]] = math.log %[[IN]] : f32
@@ -95,10 +113,21 @@ def test_f32_elemwise_exp(input, init_result):
         def test_f32_elemwise_log(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
 
+        # CHECK-LABEL: @test_c32_elemwise_log
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[LOG:.+]] = complex.log %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[LOG]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_log(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+
         # CHECK-LABEL: @test_f32_elemwise_abs
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.absf %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[ABS:.+]] = math.absf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[ABS]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -106,10 +135,21 @@ def test_f32_elemwise_log(input, init_result):
         def test_f32_elemwise_abs(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
 
+        # CHECK-LABEL: @test_c32_elemwise_abs
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[ABS:.+]] = complex.abs %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[ABS]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_c32_elemwise_abs(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
         # CHECK-LABEL: @test_f32_elemwise_ceil
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.ceil %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[CEIL:.+]] = math.ceil %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[CEIL]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -119,8 +159,8 @@ def test_f32_elemwise_ceil(input, init_result):
 
         # CHECK-LABEL: @test_f32_elemwise_floor
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.floor %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[FLOOR:.+]] = math.floor %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[FLOOR]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -130,8 +170,8 @@ def test_f32_elemwise_floor(input, init_result):
 
         # CHECK-LABEL: @test_f32_elemwise_neg
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = arith.negf %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[NEG:.+]] = arith.negf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[NEG]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -141,8 +181,8 @@ def test_f32_elemwise_neg(input, init_result):
 
         # CHECK-LABEL: @test_c32_elemwise_neg
         # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
-        # CHECK-NEXT:   %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT:   %[[NEG:.+]] = complex.neg %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[NEG]] : complex<f32>
         # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
@@ -150,6 +190,17 @@ def test_f32_elemwise_neg(input, init_result):
         def test_c32_elemwise_neg(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
 
+        # CHECK-LABEL: @test_c32_elemwise_conj
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[CONJ:.+]] = complex.conj %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[CONJ]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_conj(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.conj, cast=None)
+
         # Just check that we don't assert out on name mismatch.
         # CHECK-LABEL: @test_non_default_op_name
         @func.FuncOp.from_py_func(

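For readers who want to exercise the patch outside of the lit test above, a self-contained driver along the following lines emits and prints the op (a sketch mirroring emit_misc.py; `elemwise_conj_poly` and `conj_tensors` are illustrative names, not part of the patch):

    from mlir.dialects import func
    from mlir.dialects.linalg.opdsl.lang import *
    from mlir.ir import *

    @linalg_structured_op
    def elemwise_conj_poly(
        I=TensorDef(T),
        O=TensorDef(U, output=True),
        fun=UnaryFnAttrDef(default=UnaryFn.conj),
    ):
        O[None] = fun(I[None])

    with Context(), Location.unknown():
        module = Module.create()
        c32 = ComplexType.get(F32Type.get())
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                RankedTensorType.get((4, 16), c32),
                RankedTensorType.get((4, 16), c32),
            )
            def conj_tensors(input, init_result):
                # The generated linalg body contains a complex.conj op.
                return elemwise_conj_poly(
                    input, outs=[init_result], fun=UnaryFn.conj
                )

        print(module)
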
From 6f3c02acb1dd496cd758cfb085d50ce510f41c27 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 9 Oct 2025 22:42:36 +0800
Subject: [PATCH 2/2] formatting

---
 mlir/test/python/dialects/linalg/opdsl/emit_misc.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index 2afae8b055ed0..d23c48daebad7 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -38,6 +38,7 @@ def elemwise_unary_poly_cast(
 ):
     O[None] = fun(cast(U, I[None]))
 
+
 @linalg_structured_op
 def elemwise_unary_poly(
     I=TensorDef(T),
@@ -199,7 +200,9 @@ def test_c32_elemwise_neg(input, init_result):
             RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
         )
         def test_c32_elemwise_conj(input, init_result):
-            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.conj, cast=None)
+            return elemwise_unary_poly(
+                input, outs=[init_result], fun=UnaryFn.conj, cast=None
+            )
 
         # Just check that we don't assert out on name mismatch.
         # CHECK-LABEL: @test_non_default_op_name


