[Mlir-commits] [mlir] [MLIR][Complex] Add complex ops support in OPDSL. (PR #162665)
Hugo Trachino
llvmlistbot at llvm.org
Thu Oct 9 08:28:27 PDT 2025
https://github.com/nujaa updated https://github.com/llvm/llvm-project/pull/162665
From 05e6371fd97d1ac0044864bd1d4507abb0c00539 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Wed, 8 Oct 2025 21:49:49 +0800
Subject: [PATCH 1/2] [MLIR][Complex] Democratize complex ops in OPDSL.
---
.../linalg/opdsl/lang/comprehension.py | 1 +
.../dialects/linalg/opdsl/lang/emitter.py | 11 +++
.../python/dialects/linalg/opdsl/emit_misc.py | 73 ++++++++++++++++---
3 files changed, 74 insertions(+), 11 deletions(-)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 4f81a3874650d..3f3ec7b59eb3d 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -299,6 +299,7 @@ class UnaryFn:
square = UnaryFnType("square")
tanh = UnaryFnType("tanh")
erf = UnaryFnType("erf")
+ conj = UnaryFnType("conj")
class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..10f1083b11758 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -468,16 +468,22 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
def _unary_exp(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.ExpOp(x).result
+ if _is_complex_type(x.type):
+ return complex.ExpOp(x).result
raise NotImplementedError("Unsupported 'exp' operand: {x}")
def _unary_log(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.LogOp(x).result
+ if _is_complex_type(x.type):
+ return complex.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")
def _unary_abs(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.AbsFOp(x).result
+ if _is_complex_type(x.type):
+ return complex.AbsOp(x).result
raise NotImplementedError("Unsupported 'abs' operand: {x}")
def _unary_ceil(self, x: Value) -> Value:
@@ -497,6 +503,11 @@ def _unary_negf(self, x: Value) -> Value:
return complex.NegOp(x).result
raise NotImplementedError("Unsupported 'negf' operand: {x}")
+ def _unary_conj(self, x: Value) -> Value:
+ if _is_complex_type(x.type):
+ return complex.ConjOp(x).result
+ raise NotImplementedError("Unsupported 'conj' operand: {x}")
+
def _binary_add(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.AddFOp(lhs, rhs).result
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index f8e034fb0e48b..2afae8b055ed0 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -30,7 +30,7 @@ def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
@linalg_structured_op
-def elemwise_unary_poly(
+def elemwise_unary_poly_cast(
I=TensorDef(T),
O=TensorDef(U, output=True),
fun=UnaryFnAttrDef(default=UnaryFn.exp),
@@ -38,6 +38,13 @@ def elemwise_unary_poly(
):
O[None] = fun(cast(U, I[None]))
+@linalg_structured_op
+def elemwise_unary_poly(
+ I=TensorDef(T),
+ O=TensorDef(U, output=True),
+ fun=UnaryFnAttrDef(default=UnaryFn.exp),
+):
+ O[None] = fun(I[None])
@linalg_structured_op(op_name="custom_op_name")
def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
@@ -84,6 +91,17 @@ def test_i32_index(init_result):
def test_f32_elemwise_exp(input, init_result):
return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+ # CHECK-LABEL: @test_c32_elemwise_exp
+ # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+ # CHECK-NEXT: %[[EXP:.+]] = complex.exp %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[EXP]] : complex<f32>
+ # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+ )
+ def test_c32_elemwise_exp(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+
# CHECK-LABEL: @test_f32_elemwise_log
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
# CHECK-NEXT: %[[LOG:.+]] = math.log %[[IN]] : f32
@@ -95,10 +113,21 @@ def test_f32_elemwise_exp(input, init_result):
def test_f32_elemwise_log(input, init_result):
return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+ # CHECK-LABEL: @test_c32_elemwise_log
+ # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+ # CHECK-NEXT: %[[LOG:.+]] = complex.log %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[LOG]] : complex<f32>
+ # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+ )
+ def test_c32_elemwise_log(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+
# CHECK-LABEL: @test_f32_elemwise_abs
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = math.absf %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: %[[ABS:.+]] = math.absf %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[ABS]] : f32
# CHECK-NEXT: -> tensor<4x16xf32>
@func.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -106,10 +135,21 @@ def test_f32_elemwise_log(input, init_result):
def test_f32_elemwise_abs(input, init_result):
return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+ # CHECK-LABEL: @test_c32_elemwise_abs
+ # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[ABS:.+]] = complex.abs %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[ABS]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), f32)
+ )
+ def test_c32_elemwise_abs(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
# CHECK-LABEL: @test_f32_elemwise_ceil
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = math.ceil %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: %[[CEIL:.+]] = math.ceil %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[CEIL]] : f32
# CHECK-NEXT: -> tensor<4x16xf32>
@func.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -119,8 +159,8 @@ def test_f32_elemwise_ceil(input, init_result):
# CHECK-LABEL: @test_f32_elemwise_floor
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = math.floor %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: %[[FLOOR:.+]] = math.floor %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[FLOOR]] : f32
# CHECK-NEXT: -> tensor<4x16xf32>
@func.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -130,8 +170,8 @@ def test_f32_elemwise_floor(input, init_result):
# CHECK-LABEL: @test_f32_elemwise_neg
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = arith.negf %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: %[[NEG:.+]] = arith.negf %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[NEG]] : f32
# CHECK-NEXT: -> tensor<4x16xf32>
@func.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -141,8 +181,8 @@ def test_f32_elemwise_neg(input, init_result):
# CHECK-LABEL: @test_c32_elemwise_neg
# CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
- # CHECK-NEXT: %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
- # CHECK-NEXT: linalg.yield %[[EXP]] : complex<f32>
+ # CHECK-NEXT: %[[NEG:.+]] = complex.neg %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[NEG]] : complex<f32>
# CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
@func.FuncOp.from_py_func(
RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
@@ -150,6 +190,17 @@ def test_f32_elemwise_neg(input, init_result):
def test_c32_elemwise_neg(input, init_result):
return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+ # CHECK-LABEL: @test_c32_elemwise_conj
+ # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+ # CHECK-NEXT: %[[CONJ:.+]] = complex.conj %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[CONJ]] : complex<f32>
+ # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+ )
+ def test_c32_elemwise_conj(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.conj, cast=None)
+
# Just check that we don't assert out on name mismatch.
# CHECK-LABEL: @test_non_default_op_name
@func.FuncOp.from_py_func(
From 6f3c02acb1dd496cd758cfb085d50ce510f41c27 Mon Sep 17 00:00:00 2001
From: Hugo <hugo.trachino at huawei.com>
Date: Thu, 9 Oct 2025 22:42:36 +0800
Subject: [PATCH 2/2] formatting
---
mlir/test/python/dialects/linalg/opdsl/emit_misc.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index 2afae8b055ed0..d23c48daebad7 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -38,6 +38,7 @@ def elemwise_unary_poly_cast(
):
O[None] = fun(cast(U, I[None]))
+
@linalg_structured_op
def elemwise_unary_poly(
I=TensorDef(T),
@@ -199,7 +200,9 @@ def test_c32_elemwise_neg(input, init_result):
RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
)
def test_c32_elemwise_conj(input, init_result):
- return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.conj, cast=None)
+ return elemwise_unary_poly(
+ input, outs=[init_result], fun=UnaryFn.conj, cast=None
+ )
# Just check that we don't assert out on name mismatch.
# CHECK-LABEL: @test_non_default_op_name
More information about the Mlir-commits
mailing list