[Mlir-commits] [mlir] [MLIR][Complex] Add complex ops support in OPDSL. (PR #162665)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 9 07:24:42 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Hugo Trachino (nujaa)
<details>
<summary>Changes</summary>
This patch allows OpDSL to generate the complex versions of the following existing OpDSL ops:
* ExpOp
* LogOp
* AbsOp
It also adds support for one new op:
* ConjOp
I needed to rename `elemwise_unary_poly` to `elemwise_unary_poly_cast` and add a new, cast-free `elemwise_unary_poly`, because complex AbsOp has different input and output types (complex vs. float). Additionally, it turns out the cast in `elemwise_unary_poly` was not necessary for the tested use cases. Let me know if you would prefer to see `elemwise_unary_poly_cast` removed entirely, or whether it may be used downstream.
This patch also includes nit-pick renaming of FileCheck variable names for better consistency.
---
Full diff: https://github.com/llvm/llvm-project/pull/162665.diff
3 Files Affected:
- (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py (+1)
- (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py (+11)
- (modified) mlir/test/python/dialects/linalg/opdsl/emit_misc.py (+62-11)
``````````diff
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 4f81a3874650d..3f3ec7b59eb3d 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -299,6 +299,7 @@ class UnaryFn:
square = UnaryFnType("square")
tanh = UnaryFnType("tanh")
erf = UnaryFnType("erf")
+ conj = UnaryFnType("conj")
class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 254458a978828..10f1083b11758 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -468,16 +468,22 @@ def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
def _unary_exp(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.ExpOp(x).result
+ if _is_complex_type(x.type):
+ return complex.ExpOp(x).result
raise NotImplementedError("Unsupported 'exp' operand: {x}")
def _unary_log(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.LogOp(x).result
+ if _is_complex_type(x.type):
+ return complex.LogOp(x).result
raise NotImplementedError("Unsupported 'log' operand: {x}")
def _unary_abs(self, x: Value) -> Value:
if _is_floating_point_type(x.type):
return math.AbsFOp(x).result
+ if _is_complex_type(x.type):
+ return complex.AbsOp(x).result
raise NotImplementedError("Unsupported 'abs' operand: {x}")
def _unary_ceil(self, x: Value) -> Value:
@@ -497,6 +503,11 @@ def _unary_negf(self, x: Value) -> Value:
return complex.NegOp(x).result
raise NotImplementedError("Unsupported 'negf' operand: {x}")
+ def _unary_conj(self, x: Value) -> Value:
+ if _is_complex_type(x.type):
+ return complex.ConjOp(x).result
+ raise NotImplementedError("Unsupported 'conj' operand: {x}")
+
def _binary_add(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
return arith.AddFOp(lhs, rhs).result
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
index f8e034fb0e48b..2afae8b055ed0 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py
@@ -30,7 +30,7 @@ def test_index(O=TensorDef(I32, S.M, S.N, output=True)):
@linalg_structured_op
-def elemwise_unary_poly(
+def elemwise_unary_poly_cast(
I=TensorDef(T),
O=TensorDef(U, output=True),
fun=UnaryFnAttrDef(default=UnaryFn.exp),
@@ -38,6 +38,13 @@ def elemwise_unary_poly(
):
O[None] = fun(cast(U, I[None]))
+@linalg_structured_op
+def elemwise_unary_poly(
+ I=TensorDef(T),
+ O=TensorDef(U, output=True),
+ fun=UnaryFnAttrDef(default=UnaryFn.exp),
+):
+ O[None] = fun(I[None])
@linalg_structured_op(op_name="custom_op_name")
def non_default_op_name(I=TensorDef(T, S.N), O=TensorDef(T, S.N, output=True)):
@@ -84,6 +91,17 @@ def test_i32_index(init_result):
def test_f32_elemwise_exp(input, init_result):
return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+ # CHECK-LABEL: @test_c32_elemwise_exp
+ # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+ # CHECK-NEXT: %[[EXP:.+]] = complex.exp %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[EXP]] : complex<f32>
+ # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+ )
+ def test_c32_elemwise_exp(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.exp)
+
# CHECK-LABEL: @test_f32_elemwise_log
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
# CHECK-NEXT: %[[LOG:.+]] = math.log %[[IN]] : f32
@@ -95,10 +113,21 @@ def test_f32_elemwise_exp(input, init_result):
def test_f32_elemwise_log(input, init_result):
return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+ # CHECK-LABEL: @test_c32_elemwise_log
+ # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+ # CHECK-NEXT: %[[LOG:.+]] = complex.log %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[LOG]] : complex<f32>
+ # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+ )
+ def test_c32_elemwise_log(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.log)
+
# CHECK-LABEL: @test_f32_elemwise_abs
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = math.absf %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: %[[ABS:.+]] = math.absf %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[ABS]] : f32
# CHECK-NEXT: -> tensor<4x16xf32>
@func.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -106,10 +135,21 @@ def test_f32_elemwise_log(input, init_result):
def test_f32_elemwise_abs(input, init_result):
return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+ # CHECK-LABEL: @test_c32_elemwise_abs
+ # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: f32)
+ # CHECK-NEXT: %[[ABS:.+]] = complex.abs %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[ABS]] : f32
+ # CHECK-NEXT: -> tensor<4x16xf32>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), f32)
+ )
+ def test_c32_elemwise_abs(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.abs)
+
# CHECK-LABEL: @test_f32_elemwise_ceil
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = math.ceil %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: %[[CEIL:.+]] = math.ceil %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[CEIL]] : f32
# CHECK-NEXT: -> tensor<4x16xf32>
@func.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -119,8 +159,8 @@ def test_f32_elemwise_ceil(input, init_result):
# CHECK-LABEL: @test_f32_elemwise_floor
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = math.floor %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: %[[FLOOR:.+]] = math.floor %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[FLOOR]] : f32
# CHECK-NEXT: -> tensor<4x16xf32>
@func.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -130,8 +170,8 @@ def test_f32_elemwise_floor(input, init_result):
# CHECK-LABEL: @test_f32_elemwise_neg
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[EXP:.+]] = arith.negf %[[IN]] : f32
- # CHECK-NEXT: linalg.yield %[[EXP]] : f32
+ # CHECK-NEXT: %[[NEG:.+]] = arith.negf %[[IN]] : f32
+ # CHECK-NEXT: linalg.yield %[[NEG]] : f32
# CHECK-NEXT: -> tensor<4x16xf32>
@func.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((4, 16), f32)
@@ -141,8 +181,8 @@ def test_f32_elemwise_neg(input, init_result):
# CHECK-LABEL: @test_c32_elemwise_neg
# CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
- # CHECK-NEXT: %[[EXP:.+]] = complex.neg %[[IN]] : complex<f32>
- # CHECK-NEXT: linalg.yield %[[EXP]] : complex<f32>
+ # CHECK-NEXT: %[[NEG:.+]] = complex.neg %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[NEG]] : complex<f32>
# CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
@func.FuncOp.from_py_func(
RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
@@ -150,6 +190,17 @@ def test_f32_elemwise_neg(input, init_result):
def test_c32_elemwise_neg(input, init_result):
return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.negf)
+ # CHECK-LABEL: @test_c32_elemwise_conj
+ # CHECK: ^{{.*}}(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
+ # CHECK-NEXT: %[[CONJ:.+]] = complex.conj %[[IN]] : complex<f32>
+ # CHECK-NEXT: linalg.yield %[[CONJ]] : complex<f32>
+ # CHECK-NEXT: -> tensor<4x16xcomplex<f32>>
+ @func.FuncOp.from_py_func(
+ RankedTensorType.get((4, 16), c32), RankedTensorType.get((4, 16), c32)
+ )
+ def test_c32_elemwise_conj(input, init_result):
+ return elemwise_unary_poly(input, outs=[init_result], fun=UnaryFn.conj, cast=None)
+
# Just check that we don't assert out on name mismatch.
# CHECK-LABEL: @test_non_default_op_name
@func.FuncOp.from_py_func(
``````````
</details>
https://github.com/llvm/llvm-project/pull/162665
More information about the Mlir-commits
mailing list