[Mlir-commits] [mlir] [MLIR][Complex] Add complex ops support in OPDSL. (PR #162665)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 9 07:24:42 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Hugo Trachino (nujaa)

<details>
<summary>Changes</summary>

This patch allows OpDSL to generate the complex versions of the following existing OpDSL ops:
* ExpOp
* LogOp
* AbsOp

It also adds support for one new op:
* ConjOp

I needed to rename the existing `elemwise_unary_poly` test op to `elemwise_unary_poly_cast` and add a cast-free `elemwise_unary_poly`, since the complex AbsOp has different input and output types (complex vs. float). Additionally, it turns out that the cast in `elemwise_unary_poly` was not necessary for the tested use cases. Let me know if you would prefer to see `elemwise_unary_poly_cast` removed entirely, or whether it might be used downstream.

This patch also includes a nit-picking rename of the FileCheck variable names for better consistency.

---
Full diff: https://github.com/llvm/llvm-project/pull/162665.diff


3 Files Affected:

- (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py (+1) 
- (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py (+11) 
- (modified) mlir/test/python/dialects/linalg/opdsl/emit_misc.py (+62-11) 


``````````diff
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 4f81a3874650d..3f3ec7b59eb3d 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -299,6 +299,7 @@ class UnaryFn:
     square = UnaryFnType("square")
     tanh = UnaryFnType("tanh")
     erf = UnaryFnType("erf")
+    conj = UnaryFnType("conj")
 
 
 class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..10f1083b11758 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -468,16 +468,22 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
     def _unary_exp(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.ExpOp(x).result
+        if _is_complex_type(x.type):
+            return complex.ExpOp(x).result
         raise NotImplementedError("Unsupported 'exp' operand: {x}")
 
     def _unary_log(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.LogOp(x).result
+        if _is_complex_type(x.type):
+            return complex.LogOp(x).result
         raise NotImplementedError("Unsupported 'log' operand: {x}")
 
     def _unary_abs(self, x: Value) -> Value:
         if _is_floating_point_type(x.type):
             return math.AbsFOp(x).result
+        if _is_complex_type(x.type):
+            return complex.AbsOp(x).result
         raise NotImplementedError("Unsupported 'abs' operand: {x}")
 
     def _unary_ceil(self, x: Value) -> Value:
@@ -497,6 +503,11 @@ def _unary_negf(self, x: Value) -> Value:
             return complex.NegOp(x).result
         raise NotImplementedError("Unsupported 'negf' operand: {x}")
 
+    def _unary_conj(self, x: Value) -> Value:
+        if _is_complex_type(x.type):
+            return complex.ConjOp(x).result
+        raise NotImplementedError("Unsupported 'conj' operand: {x}")
+
     def _binary_add(self, lhs: Value, rhs: Value) -> Value:
         if _is_floating_point_type(lhs.type):
             return arith.AddFOp(lhs, rhs).result
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index f8e034fb0e48b..2afae8b055ed0 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -30,7 +30,7 @@ def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
 
 
 @linalg_structured_op
-def elemwise_unary_poly(
+def elemwise_unary_poly_cast(
     I=TensorDef(T),
     O=TensorDef(U, output=True),
     fun=UnaryFnAttrDef(default=UnaryFn.exp),
@@ -38,6 +38,13 @@ def elemwise_unary_poly(
 ):
     O[None] = fun(cast(U, I[None]))
 
+@linalg_structured_op
+def elemwise_unary_poly(
+    I=TensorDef(T),
+    O=TensorDef(U, output=True),
+    fun=UnaryFnAttrDef(default=UnaryFn.exp),
+):
+    O[None] = fun(I[None])
 
 @linalg_structured_op(op_name="custom_op_name")
 def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
@@ -84,6 +91,17 @@ def test_i32_index(init_result):
         def test_f32_elemwise_exp(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
 
+        # CHECK-LABEL: @test_c32_elemwise_exp
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[EXP:.+]] = complex.exp %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_exp(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+
         # CHECK-LABEL: @test_f32_elemwise_log
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
         # CHECK-NEXT:   %[[LOG:.+]] = math.log %[[IN]] : f32
@@ -95,10 +113,21 @@ def test_f32_elemwise_exp(input, init_result):
         def test_f32_elemwise_log(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
 
+        # CHECK-LABEL: @test_c32_elemwise_log
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[LOG:.+]] = complex.log %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[LOG]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_log(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+
         # CHECK-LABEL: @test_f32_elemwise_abs
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.absf %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[ABS:.+]] = math.absf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[ABS]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -106,10 +135,21 @@ def test_f32_elemwise_log(input, init_result):
         def test_f32_elemwise_abs(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
 
+        # CHECK-LABEL: @test_c32_elemwise_abs
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: f32)
+        # CHECK-NEXT:   %[[ABS:.+]] = complex.abs %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[ABS]] : f32
+        # CHECK-NEXT: -> tensor<4x16xf32>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), f32)
+        )
+        def test_c32_elemwise_abs(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
         # CHECK-LABEL: @test_f32_elemwise_ceil
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.ceil %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[CEIL:.+]] = math.ceil %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[CEIL]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -119,8 +159,8 @@ def test_f32_elemwise_ceil(input, init_result):
 
         # CHECK-LABEL: @test_f32_elemwise_floor
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = math.floor %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[FLOOR:.+]] = math.floor %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[FLOOR]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -130,8 +170,8 @@ def test_f32_elemwise_floor(input, init_result):
 
         # CHECK-LABEL: @test_f32_elemwise_neg
         # CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
-        # CHECK-NEXT:   %[[EXP:.+]] = arith.negf %[[IN]] : f32
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : f32
+        # CHECK-NEXT:   %[[NEG:.+]] = arith.negf %[[IN]] : f32
+        # CHECK-NEXT:   linalg.yield %[[NEG]] : f32
         # CHECK-NEXT: -> tensor<4x16xf32>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -141,8 +181,8 @@ def test_f32_elemwise_neg(input, init_result):
 
         # CHECK-LABEL: @test_c32_elemwise_neg
         # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
-        # CHECK-NEXT:   %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
-        # CHECK-NEXT:   linalg.yield %[[EXP]] : complex<f32>
+        # CHECK-NEXT:   %[[NEG:.+]] = complex.neg %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[NEG]] : complex<f32>
         # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
         @func.FuncOp.from_py_func(
             RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
@@ -150,6 +190,17 @@ def test_f32_elemwise_neg(input, init_result):
         def test_c32_elemwise_neg(input, init_result):
             return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
 
+        # CHECK-LABEL: @test_c32_elemwise_conj
+        # CHECK:      ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+        # CHECK-NEXT:   %[[CONJ:.+]] = complex.conj %[[IN]] : complex<f32>
+        # CHECK-NEXT:   linalg.yield %[[CONJ]] : complex<f32>
+        # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+        @func.FuncOp.from_py_func(
+            RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+        )
+        def test_c32_elemwise_conj(input, init_result):
+            return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.conj, cast=None)
+
         # Just check that we don't assert out on name mismatch.
         # CHECK-LABEL: @test_non_default_op_name
         @func.FuncOp.from_py_func(

``````````

</details>


https://github.com/llvm/llvm-project/pull/162665


More information about the Mlir-commits mailing list