[Mlir-commits] [mlir] [MLIR][Linalg] Remove elemwise_unary and elemwise_binary (PR #147082)
Renato Golin
llvmlistbot at llvm.org
Fri Jul 4 09:12:46 PDT 2025
https://github.com/rengolin created https://github.com/llvm/llvm-project/pull/147082
RFC: https://discourse.llvm.org/t/rfc-deprecate-linalg-elemwise-unary-and-elemwise-binary/87144
Remove the two operations and fix the tests by:
* Cleaning simple operation tests of the old ops
* Replacing `linalg.elemwise_{u|bi}nary` with `linalg.{exp|add}` in transform tests
* Surgically removing the `elemwise_*` part in the Python tests
Nothing else changed.
>From dbf79307c54816529b59fc037a61b0d63f212921 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Fri, 4 Jul 2025 16:30:22 +0100
Subject: [PATCH 1/4] Remove old element_wise operations
---
.../Linalg/IR/LinalgNamedStructuredOps.yaml | 114 ------------------
.../linalg/opdsl/ops/core_named_ops.py | 31 -----
2 files changed, 145 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 6344861c53ac5..3637147c5a90d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -44,56 +44,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: elemwise_unary
- cpp_class_name: ElemwiseUnaryOp
- doc: |-
- Applies the unary function fun elementwise.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: I
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<() -> ()>
- - !LinalgOperandDefConfig
- name: O
- kind: output_tensor
- type_var: U
- shape_map: affine_map<() -> ()>
- - !LinalgOperandDefConfig
- name: fun
- kind: unary_fn_attr
- default_fn: exp
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<() -> ()>
- - affine_map<() -> ()>
- iterator_types: []
- assignments:
- - !ScalarAssign
- arg: O
- value: !ScalarExpression
- scalar_fn:
- kind: unary
- attr_name: fun
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: I
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: exp
cpp_class_name: ExpOp
@@ -549,70 +499,6 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
-metadata: !LinalgOpMetadata
- name: elemwise_binary
- cpp_class_name: ElemwiseBinaryOp
- doc: |-
- Applies the binary function fun elementwise.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
-structured_op: !LinalgStructuredOpConfig
- args:
- - !LinalgOperandDefConfig
- name: lhs
- kind: input_tensor
- type_var: T1
- shape_map: affine_map<() -> ()>
- - !LinalgOperandDefConfig
- name: rhs
- kind: input_tensor
- type_var: T2
- shape_map: affine_map<() -> ()>
- - !LinalgOperandDefConfig
- name: O
- kind: output_tensor
- type_var: U
- shape_map: affine_map<() -> ()>
- - !LinalgOperandDefConfig
- name: fun
- kind: binary_fn_attr
- default_fn: add
- - !LinalgOperandDefConfig
- name: cast
- kind: type_fn_attr
- default_fn: cast_signed
- indexing_maps: !LinalgIndexingMapsConfig
- static_indexing_maps:
- - affine_map<() -> ()>
- - affine_map<() -> ()>
- - affine_map<() -> ()>
- iterator_types: []
- assignments:
- - !ScalarAssign
- arg: O
- value: !ScalarExpression
- scalar_fn:
- kind: binary
- attr_name: fun
- operands:
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: lhs
- - !ScalarExpression
- scalar_fn:
- kind: type
- attr_name: cast
- type_var: U
- operands:
- - !ScalarExpression
- scalar_arg: rhs
---- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: add
cpp_class_name: AddOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 48e724d80c926..1b359da40a291 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -21,21 +21,6 @@ def copy(
O[None] = cast(U, I[None])
- at linalg_structured_op
-def elemwise_unary(
- I=TensorDef(T1),
- O=TensorDef(U, output=True),
- fun=UnaryFnAttrDef(default=UnaryFn.exp),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
- """Applies the unary function fun elementwise.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- O[None] = fun(cast(U, I[None]))
-
-
@linalg_structured_op
def exp(
I=TensorDef(T1),
@@ -192,22 +177,6 @@ def erf(
O[None] = UnaryFn.erf(I[None])
- at linalg_structured_op
-def elemwise_binary(
- lhs=TensorDef(T1),
- rhs=TensorDef(T2),
- O=TensorDef(U, output=True),
- fun=BinaryFnAttrDef(default=BinaryFn.add),
- cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
- """Applies the binary function fun elementwise.
-
- Numeric casting is performed on the input operand, promoting it to the same
- data type as the accumulator/output.
- """
- O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
-
-
@linalg_structured_op
def add(
lhs=TensorDef(T1),
>From 33a7435b443ee9801ba9e7c7ba2f88ccfff29f72 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Fri, 4 Jul 2025 16:59:05 +0100
Subject: [PATCH 2/4] replacing elemwise_* with named ops
---
.../generalize-named-polymorphic-ops.mlir | 111 ------------------
mlir/test/Dialect/Linalg/invalid.mlir | 8 --
mlir/test/Dialect/Linalg/library-calls.mlir | 40 -------
.../Dialect/Linalg/match-ops-interpreter.mlir | 6 +-
.../Linalg/one-shot-bufferize-analysis.mlir | 38 ------
.../transform-op-fuse-into-containing.mlir | 42 +++----
.../Dialect/Linalg/transform-op-fuse.mlir | 72 ++++++------
.../Linalg/transform-op-generalize.mlir | 4 +-
8 files changed, 62 insertions(+), 259 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index bbd6e0fc8e2cc..290c6c7c36f76 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -434,117 +434,6 @@ func.func @generalize_const(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf
// -----
-// Verifies the default value of the fun attribute is an exp op.
-func.func @generalize_elemwise_exp(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
- %0 = linalg.elemwise_unary ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
- return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_exp
-// CHECK: = math.exp
-
-// -----
-
-// Verifies the fun attribute controls the unary function used.
-func.func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
- %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<log>}
- ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
- return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_log
-// CHECK: = math.log
-
-// -----
-
-// Verifies the fun attribute controls the unary function used.
-func.func @generalize_elemwise_abs(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
- %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>}
- ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
- return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_abs
-// CHECK: = math.absf
-
-// -----
-
-// Verifies the fun attribute controls the unary function used.
-func.func @generalize_elemwise_ceil(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
- %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<ceil>}
- ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
- return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_ceil
-// CHECK: = math.ceil
-
-// -----
-
-// Verifies the fun attribute controls the unary function used.
-func.func @generalize_elemwise_floor(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
- %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<floor>}
- ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
- return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_floor
-// CHECK: = math.floor
-
-// -----
-
-// Verifies the fun attribute controls the unary function used.
-func.func @generalize_elemwise_negf(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
- %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
- ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
- return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_negf
-// CHECK: = arith.negf
-
-// -----
-
-// Verifies the default value of the fun attribute is an add op.
-func.func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
- %0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
- outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
- return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_add
-// CHECK: = arith.addf
-
-// -----
-
-// Verifies the fun attribute controls the binary function used.
-func.func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
- %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>}
- ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
- outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
- return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_mul
-// CHECK: = arith.mulf
-
-// -----
-
-// Verifies pointwise ops support rank zero input tensors
-func.func @generalize_elemwise_rank_zero(%lhs : tensor<f32>, %rhs : tensor<f32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
- %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<sub>}
- ins(%lhs, %rhs: tensor<f32>, tensor<f32>)
- outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
- return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_rank_zero
-// CHECK: linalg.generic
-// CHECK-SAME: iterator_types = ["parallel", "parallel"]
-// CHECK: = arith.subf
-
-// -----
-
// Verifies the fun attribute controls the binary function used.
func.func @generalize_copy(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = linalg.copy ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 964681d7dcd92..da1dfc7b6a624 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1909,14 +1909,6 @@ module {
// -----
-func.func @elemwise_unary_invalid_mixed_types(%arg0 : tensor<?xi32>) -> tensor<?xi32> {
- // expected-error @below {{unsupported non numeric type}}
- %0 = linalg.elemwise_unary ins(%arg0 : tensor<?xi32>) outs(%arg0 : tensor<?xi32>) -> tensor<?xi32>
- return %0 : tensor<?xi32>
-}
-
-// -----
-
func.func @matmul_invalid_mixed_types(%t: tensor<?xf16>, %f: vector<4xf16>)
-> (tensor<?xf16>, vector<4xf16>)
{
diff --git a/mlir/test/Dialect/Linalg/library-calls.mlir b/mlir/test/Dialect/Linalg/library-calls.mlir
index 1fa675d8b4b68..77c9d4a911447 100644
--- a/mlir/test/Dialect/Linalg/library-calls.mlir
+++ b/mlir/test/Dialect/Linalg/library-calls.mlir
@@ -59,43 +59,3 @@ module {
return
}
}
-
-
-// -----
-
-// CHECK: func.func private @linalg_elemwise_unary_negf_view16x8xf32_view16x8xf32(memref<16x8xf32, strided<[?, ?], offset: ?>>, memref<16x8xf32, strided<[?, ?], offset: ?>>) attributes {llvm.emit_c_interface}
-// CHECK: func.func private @linalg_elemwise_unary_negf_view16xf32_view16xf32(memref<16xf32, strided<[?], offset: ?>>, memref<16xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
-
-func.func @test_neg(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16xf32>, %D: memref<16xf32>) {
- linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
- ins(%A: memref<16x8xf32>) outs(%B: memref<16x8xf32>)
- linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
- ins(%C: memref<16xf32>) outs(%D: memref<16xf32>)
- return
-}
-
-// -----
-
-// CHECK: func.func private @linalg_elemwise_unary_exp_view16x8xf32_view16x8xf32(memref<16x8xf32, strided<[?, ?], offset: ?>>, memref<16x8xf32, strided<[?, ?], offset: ?>>) attributes {llvm.emit_c_interface}
-// CHECK: func.func private @linalg_elemwise_unary_exp_view16xf32_view16xf32(memref<16xf32, strided<[?], offset: ?>>, memref<16xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
-
-func.func @test_exp(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16xf32>, %D: memref<16xf32>) {
- linalg.elemwise_unary {fun = #linalg.unary_fn<exp>}
- ins(%A: memref<16x8xf32>) outs(%B: memref<16x8xf32>)
- linalg.elemwise_unary {fun = #linalg.unary_fn<exp>}
- ins(%C: memref<16xf32>) outs(%D: memref<16xf32>)
- return
-}
-
-// -----
-
-// CHECK: func.func private @linalg_elemwise_binary_add_view16x8xf32_view16x8xf32_view16x8xf32(memref<16x8xf32, strided<[?, ?], offset: ?>>, memref<16x8xf32, strided<[?, ?], offset: ?>>, memref<16x8xf32, strided<[?, ?], offset: ?>>) attributes {llvm.emit_c_interface}
-// CHECK: func.func private @linalg_elemwise_binary_add_view16xf32_view16xf32_view16xf32(memref<16xf32, strided<[?], offset: ?>>, memref<16xf32, strided<[?], offset: ?>>, memref<16xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
-
-func.func @test_add(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>, %D: memref<16xf32>, %E: memref<16xf32>, %F: memref<16xf32>) {
- linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
- ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
- linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
- ins(%D, %E: memref<16xf32>, memref<16xf32>) outs(%F: memref<16xf32>)
- return
-}
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index bfe7a07cb38a5..618ba3402ff52 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -842,15 +842,15 @@ module attributes { transform.with_named_sequence } {
// expected-remark @below {{op result}}
// expected-note @below {{value handle points to an op result #0}}
// expected-remark @below {{single user}}
- linalg.elemwise_unary {fun = #linalg.unary_fn<negf>} ins(%2 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+ linalg.negf ins(%2 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
// expected-remark @below {{matched result value}}
// expected-remark @below {{op result}}
// expected-note @below {{value handle points to an op result #0}}
- linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+ linalg.exp ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
// expected-remark @below {{matched result value}}
// expected-remark @below {{op result}}
// expected-note @below {{value handle points to an op result #0}}
- linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+ linalg.exp ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
return
}
}
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
index 5b7c2baf9d84f..a0922bdfcfbe4 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
@@ -1,43 +1,5 @@
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries test-analysis-only" -split-input-file | FileCheck %s
-// CHECK-LABEL: @elementwise_no_conflict
-func.func @elementwise_no_conflict(%a: tensor<5xf32>,
- %b: tensor<5xf32>) -> tensor<5xf32> {
- // CHECK: linalg.elemwise_binary
- // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
- %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
- ins(%a, %b : tensor<5xf32>, tensor<5xf32>)
- outs(%a : tensor<5xf32>) -> tensor<5xf32>
- return %0 : tensor<5xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @elementwise_no_conflict_2
-func.func @elementwise_no_conflict_2(%a: tensor<5xf32>) -> tensor<5xf32> {
- // CHECK: linalg.elemwise_binary
- // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
- %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
- ins(%a, %a : tensor<5xf32>, tensor<5xf32>)
- outs(%a : tensor<5xf32>) -> tensor<5xf32>
- return %0 : tensor<5xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @elementwise_no_conflict_3
-func.func @elementwise_no_conflict_3(%a: tensor<5xf32>) -> tensor<5xf32> {
- %c0f = arith.constant 1.0 : f32
- // CHECK: linalg.elemwise_binary
- // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none", "true"], fun = #linalg.binary_fn<add>}
- %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
- ins(%a, %c0f : tensor<5xf32>, f32)
- outs(%a : tensor<5xf32>) -> tensor<5xf32>
- return %0 : tensor<5xf32>
-}
-
-// -----
-
func.func @not_elementwise(%a: tensor<5x6xf32>) -> tensor<5x6xf32> {
%cst = arith.constant 5.0 : f32
// CHECK: tensor.extract_slice
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 5bdb5073ee865..312468970ae6d 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -26,8 +26,8 @@ module {
// CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
%6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
- // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
- %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: %[[T2:.*]] = linalg.exp ins(%[[T1]]
+ %7 = linalg.exp ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
}
@@ -76,8 +76,8 @@ module {
%4 = affine.min #map2(%arg3)[%arg0]
%5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>
- // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]]
- %7 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: %[[T2:.*]] = linalg.exp ins(%[[INIT_TENSOR]]
+ %7 = linalg.exp ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
}
@@ -177,8 +177,8 @@ module {
// CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
%6 = tensor.extract_slice %arg1[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
- // CHECK: %[[T2:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T1]]
- %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: %[[T2:.*]] = linalg.exp {{.*}} outs(%[[T1]]
+ %7 = linalg.exp ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
}
@@ -228,8 +228,8 @@ module {
// CHECK: %[[T2:.*]] = linalg.fill {{.*}} outs(%[[T1]]
%6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
- // CHECK: %[[T3:.*]] = linalg.elemwise_unary ins(%[[T2]] : tensor<?xf32>) outs(%[[T0]] : tensor<?xf32>)
- %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: %[[T3:.*]] = linalg.exp ins(%[[T2]] : tensor<?xf32>) outs(%[[T0]] : tensor<?xf32>)
+ %7 = linalg.exp ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
}
@@ -261,7 +261,7 @@ module {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
- %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>} ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ %0 = linalg.exp {fun = #linalg.unary_fn<abs>} ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
%dim = tensor.dim %arg1, %c0 : tensor<?x?x?xf32>
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?x?xf32>
%dim_1 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
@@ -272,12 +272,12 @@ module {
%2 = scf.for %arg4 = %c0 to %dim_0 step %c1 iter_args(%arg5 = %arg3) -> (tensor<?x?x?xf32>) {
%3 = scf.for %arg6 = %c0 to %dim_1 step %c1 iter_args(%arg7 = %arg5) -> (tensor<?x?x?xf32>) {
// CHECK: %[[EX1:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}}: tensor<?x?x?xf32> to tensor<1x1x1xf32>
- // CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<abs>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX1]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+ // CHECK: linalg.exp {fun = #linalg.unary_fn<abs>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX1]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
// CHECK: %[[EX2:.*]] = tensor.extract_slice %[[BBARG2]]{{.*}} : tensor<?x?x?xf32> to tensor<1x1x1xf32>
- // CHECK: linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX2]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+ // CHECK: linalg.exp {fun = #linalg.unary_fn<exp>} ins({{.*}} : tensor<1x1x1xf32>) outs(%[[EX2]] : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
%extracted_slice = tensor.extract_slice %0[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
%extracted_slice_2 = tensor.extract_slice %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<?x?x?xf32> to tensor<1x1x1xf32>
- %4 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%extracted_slice : tensor<1x1x1xf32>) outs(%extracted_slice_2 : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
+ %4 = linalg.exp {fun = #linalg.unary_fn<exp>} ins(%extracted_slice : tensor<1x1x1xf32>) outs(%extracted_slice_2 : tensor<1x1x1xf32>) -> tensor<1x1x1xf32>
%inserted_slice = tensor.insert_slice %4 into %arg7[%arg2, %arg4, %arg6] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf32> into tensor<?x?x?xf32>
scf.yield %inserted_slice : tensor<?x?x?xf32>
}
@@ -290,7 +290,7 @@ module {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%2:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%3:3 = transform.split_handle %1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@@ -340,8 +340,8 @@ module {
// CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}} ins(%[[T0]]
%6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
- // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]#0
- %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: %[[T2:.*]] = linalg.exp ins(%[[T1]]#0
+ %7 = linalg.exp ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
}
@@ -376,8 +376,8 @@ module {
%2 = tensor.extract_slice %0[%i][1][1] : tensor<2xf32> to tensor<1xf32>
%3 = tensor.extract_slice %arg1[%i][1][1] : tensor<2xf32> to tensor<1xf32>
// CHECK: %[[FUSED:.+]] = linalg.fill
- // CHECK: elemwise_unary ins(%[[FUSED]]
- %4 = linalg.elemwise_unary ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
+ // CHECK: exp ins(%[[FUSED]]
+ %4 = linalg.exp ins(%2 : tensor<1xf32>) outs(%3 : tensor<1xf32>) -> tensor<1xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %4 into %arg1[%i][1][1] : tensor<1xf32> into tensor<2xf32>
}
@@ -446,7 +446,7 @@ module {
// CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}}
%6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
- %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ %7 = linalg.exp ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[T1]]#0 into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
@@ -515,7 +515,7 @@ module {
// CHECK: %[[T1:.*]] = linalg.generic {{.*}}
%6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
- %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ %7 = linalg.exp ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
@@ -582,7 +582,7 @@ module {
// CHECK: %[[T1:.*]] = linalg.generic {{.*}}
%6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
- %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ %7 = linalg.exp ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
@@ -658,7 +658,7 @@ module {
// CHECK: %[[T2:.*]] = linalg.generic {{.*}}
%7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
- %8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
+ %8 = linalg.exp ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %8 into %o[%2] [%5] [1] : tensor<?xf32> into tensor<?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 962858076db93..9a44f95afb586 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -5,19 +5,19 @@ func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<
// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
- // CHECK: linalg.elemwise_unary
- // CHECK: linalg.elemwise_binary
+ // CHECK: linalg.exp
+ // CHECK: linalg.add
// CHECK: return %[[RES]]
- %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+ %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+ %1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
@@ -31,23 +31,23 @@ func.func @fuse_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<
// CHECK: %[[PARTIAL_RES:.*]] = scf.for
// CHECK: scf.for
- // CHECK: linalg.elemwise_unary
- // CHECK: linalg.elemwise_binary
+ // CHECK: linalg.exp
+ // CHECK: linalg.add
// CHECK: %[[RES:.*]] = scf.for {{.*}}%[[PARTIAL_RES]]
// CHECK: scf.for
- // CHECK: linalg.elemwise_unary
- // CHECK: linalg.elemwise_binary
+ // CHECK: linalg.exp
+ // CHECK: linalg.add
// CHECK: return %[[RES]]
- %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+ %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
- %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
+ %1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.loop.peel %loops#0 : (!transform.op<"scf.for">) -> (!transform.any_op, !transform.any_op)
@@ -107,20 +107,20 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
// CHECK: linalg.unpack
-// CHECK: linalg.elemwise_unary
+// CHECK: linalg.exp
// CHECK: return %[[RES]]
func.func @unpack_elemwise(%arg0: tensor<16x48x8x8xf32>, %arg1: tensor<128x384xf32>) -> tensor<128x384xf32> {
%0 = tensor.empty() : tensor<128x384xf32>
%1 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
: tensor<16x48x8x8xf32> -> tensor<128x384xf32>
- %2 = linalg.elemwise_unary ins(%1: tensor<128x384xf32>)
+ %2 = linalg.exp ins(%1: tensor<128x384xf32>)
outs(%arg1: tensor<128x384xf32>) -> tensor<128x384xf32>
return %2 : tensor<128x384xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [16, 32], tile_interchange = [0, 1]}
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
@@ -133,20 +133,20 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
// CHECK: linalg.pack
-// CHECK: linalg.elemwise_unary
+// CHECK: linalg.exp
// CHECK: return %[[RES]]
func.func @pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32> {
%0 = tensor.empty() : tensor<16x48x8x8xf32>
%1 = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
: tensor<128x384xf32> -> tensor<16x48x8x8xf32>
- %2 = linalg.elemwise_unary ins(%1: tensor<16x48x8x8xf32>)
+ %2 = linalg.exp ins(%1: tensor<16x48x8x8xf32>)
outs(%arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32>
return %2 : tensor<16x48x8x8xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [3, 5, 0, 0]}
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
@@ -159,20 +159,20 @@ module attributes {transform.with_named_sequence} {
// CHECK: linalg.pack
// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
-// CHECK: linalg.elemwise_unary
+// CHECK: linalg.exp
// CHECK: return %[[RES]]
func.func @nofuse_pack_elemwise(%arg0: tensor<128x384xf32>, %arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32> {
%0 = tensor.empty() : tensor<16x48x8x8xf32>
%1 = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %0
: tensor<128x384xf32> -> tensor<16x48x8x8xf32>
- %2 = linalg.elemwise_unary ins(%1: tensor<16x48x8x8xf32>)
+ %2 = linalg.exp ins(%1: tensor<16x48x8x8xf32>)
outs(%arg1: tensor<16x48x8x8xf32>) -> tensor<16x48x8x8xf32>
return %2 : tensor<16x48x8x8xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.elemwise_unary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.exp"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [3, 5, 2, 0]}
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
@@ -186,24 +186,24 @@ func.func @fuse_through_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
- // CHECK: linalg.elemwise_unary
- // CHECK: linalg.elemwise_binary
+ // CHECK: linalg.exp
+ // CHECK: linalg.add
// CHECK: return %[[RES]]
- %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+ %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ %2 = linalg.add ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.yield
@@ -217,10 +217,10 @@ func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1:
// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
- // CHECK: linalg.elemwise_unary
- // CHECK: linalg.elemwise_binary
+ // CHECK: linalg.exp
+ // CHECK: linalg.add
// CHECK: return %[[RES]]
- %0 = linalg.elemwise_unary ins(%arg0 : tensor<100x100xf32>)
+ %0 = linalg.exp ins(%arg0 : tensor<100x100xf32>)
outs(%arg0: tensor<100x100xf32>) -> tensor<100x100xf32>
%1 = tensor.cast %0 : tensor<100x100xf32> to tensor<100x?xf32>
%2 = tensor.extract_slice %1 [1, 1] [98, 98] [1, 1] : tensor<100x?xf32> to tensor<98x98xf32>
@@ -230,14 +230,14 @@ func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1:
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%4 = tensor.extract_slice %3 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %5 = linalg.elemwise_binary ins(%4, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ %5 = linalg.add ins(%4, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %5 : tensor<?x?xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.yield
@@ -253,8 +253,8 @@ func.func @fuse_unrelated_slices(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[SLICE1]]
// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
- // CHECK: linalg.elemwise_unary
- // CHECK: linalg.elemwise_binary
+ // CHECK: linalg.exp
+ // CHECK: linalg.add
// CHECK: return %[[RES]], %[[SLICE2]]
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -262,17 +262,17 @@ func.func @fuse_unrelated_slices(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%slice1 = tensor.extract_slice %arg0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%slice2 = tensor.extract_slice %slice1 [1, 1] [10, 10] [1, 1] : tensor<?x?xf32> to tensor<10x10xf32>
- %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+ %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+ %2 = linalg.add ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %2, %slice2 : tensor<?x?xf32>, tensor<10x10xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.yield
diff --git a/mlir/test/Dialect/Linalg/transform-op-generalize.mlir b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
index a0aa33c072dd4..331c9c0fbbfd5 100644
--- a/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
@@ -3,9 +3,9 @@
// CHECK-LABEL: func.func @generalize_unary
func.func @generalize_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
- // CHECK-NOT: linalg.elemwise_unary
+ // CHECK-NOT: linalg.exp
// CHECK: linalg.generic
- %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+ %0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
>From cc10325d3b2114f4bad933673bd38d0dac3be2a7 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Fri, 4 Jul 2025 17:02:25 +0100
Subject: [PATCH 3/4] more IR tests
---
mlir/test/Dialect/SCF/canonicalize.mlir | 8 ++++----
mlir/test/Integration/Dialect/Transform/match_matmul.mlir | 2 +-
.../TilingInterface/tile-and-fuse-consumer.mlir | 8 ++++----
3 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 5e32a3a78c032..8ba8013d008a0 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1811,7 +1811,7 @@ module {
%4 = affine.min #map2(%arg3)[%dim, %arg0]
%extracted_slice0 = tensor.extract_slice %arg4[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
%extracted_slice1 = tensor.extract_slice %arg5[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
- %5 = linalg.elemwise_unary ins(%extracted_slice0 : tensor<?xf32>) outs(%extracted_slice1 : tensor<?xf32>) -> tensor<?xf32>
+ %5 = linalg.exp ins(%extracted_slice0 : tensor<?xf32>) outs(%extracted_slice1 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %5 into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
}
@@ -1825,7 +1825,7 @@ module {
// CHECK-SAME: shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
-// CHECK: %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
+// CHECK: %[[ELEM:.*]] = linalg.exp ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
// CHECK: scf.forall.in_parallel {
// CHECK-NEXT: tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_5]]
// CHECK-NEXT: }
@@ -1851,7 +1851,7 @@ module {
%extracted_slice_0 = tensor.extract_slice %arg6[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
%extracted_slice_1 = tensor.extract_slice %arg7[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
%extracted_slice_2 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
- %5 = linalg.elemwise_unary ins(%extracted_slice : tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) -> tensor<?xf32>
+ %5 = linalg.exp ins(%extracted_slice : tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) -> tensor<?xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %5 into %arg6[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
@@ -1868,7 +1868,7 @@ module {
// CHECK-SAME: shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
// CHECK: %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
-// CHECK: %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
+// CHECK: %[[ELEM:.*]] = linalg.exp ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
// CHECK: scf.forall.in_parallel {
// CHECK-NEXT: tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_6]]
// CHECK-NEXT: }
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index 7b695cb027252..a374d9a611258 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -69,7 +69,7 @@ func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<2
// expected-remark @below {{fill}}
%fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
- %real_lhs = linalg.elemwise_binary { fun = #linalg.binary_fn<mul> }
+ %real_lhs = linalg.mul
ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>) outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32>
// expected-remark @below {{matmul}}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 0f69875d596f1..d09373bdb3f14 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -19,7 +19,7 @@ module {
}
%in_operand_2 = tensor.empty() : tensor<64xf32>
%out_operand_3 = tensor.empty() : tensor<64xf32>
- %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+ %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
return %2 : tensor<64xf32>
}
}
@@ -50,7 +50,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
-// CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK: %[[ELEM_OUT:.*]] = linalg.add
// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
// CHECK-SAME: outs(%[[SLICE_OUT]] :
// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
@@ -76,7 +76,7 @@ module {
}
%in_operand_2 = tensor.empty() : tensor<64x64xf32>
%out_operand_3 = tensor.empty() : tensor<64x64xf32>
- %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
+ %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
return %2 : tensor<64x64xf32>
}
}
@@ -109,7 +109,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
-// CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK: %[[ELEM_OUT:.*]] = linalg.add
// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
// CHECK-SAME: outs(%[[SLICE_OUT]] :
// CHECK: scf.forall.in_parallel {
>From 59ee2dd4956117c65183f818f2d907eb91e13f04 Mon Sep 17 00:00:00 2001
From: Renato Golin <rengolin at systemcall.eu>
Date: Fri, 4 Jul 2025 17:08:06 +0100
Subject: [PATCH 4/4] fixed python tests
---
mlir/test/python/dialects/linalg/ops.py | 36 ------
.../integration/dialects/linalg/opsrun.py | 121 ------------------
2 files changed, 157 deletions(-)
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index f1e2afa0f2408..709a1d2424f35 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -49,42 +49,6 @@ def fill_buffer(out):
print(module)
-# CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
- at run
-def testNamedStructuredOpCustomForm():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32)
- )
- def named_form(lhs, rhs):
- init_result = tensor.EmptyOp([4, 8], f32)
- # Check for the named form with custom format
- # CHECK: linalg.elemwise_unary
- # CHECK-SAME: cast = #linalg.type_fn<cast_signed>
- # CHECK-SAME: fun = #linalg.unary_fn<exp>
- # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
- unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result])
- # CHECK: linalg.elemwise_binary
- # CHECK-SAME: cast = #linalg.type_fn<cast_unsigned>
- # CHECK-SAME: fun = #linalg.binary_fn<mul>
- # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>)
- # CHECK: return
- binary_result = linalg.elemwise_binary(
- lhs,
- rhs,
- outs=[init_result.result],
- fun=BinaryFn.mul,
- cast=TypeFn.cast_unsigned,
- )
- return unary_result, binary_result
-
- print(module)
-
-
# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 69f97e15e139d..8f202318146ee 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -19,37 +19,6 @@ def log(*args):
sys.stderr.flush()
-elemwise_boiler = """
-func.func @main() -> f32 attributes {llvm.emit_c_interface} {
- %v0 = arith.constant 0.0 : f32
- %v1 = arith.constant 1.0 : f32
- %v2 = arith.constant 2.0 : f32
-
- %lhs = memref.alloc() : memref<f32>
- %rhs = memref.alloc() : memref<4x8xf32>
- %O0 = memref.alloc() : memref<4x8xf32>
- %O1 = memref.alloc() : memref<4x8xf32>
- linalg.fill ins(%v1 : f32) outs(%lhs : memref<f32>)
- linalg.fill ins(%v2 : f32) outs(%rhs : memref<4x8xf32>)
- linalg.fill ins(%v0 : f32) outs(%O0 : memref<4x8xf32>)
- linalg.fill ins(%v0 : f32) outs(%O1 : memref<4x8xf32>)
-
- call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
- (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
- call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
- (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
-
- %c0 = arith.constant 0 : index
- %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
- %res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>
-
- %0 = arith.addf %res0, %res1 : f32
-
- // TODO: FFI-based solution to allow testing and printing with python code.
- return %0 : f32
-}
-"""
-
fill_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
%O0 = memref.alloc() : memref<i32>
@@ -177,96 +146,6 @@ def transform(module, boilerplate):
return mod
-def test_elemwise_builtin():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i8 = IntegerType.get_signless(8)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((), f32),
- MemRefType.get((4, 8), f32),
- MemRefType.get((4, 8), f32),
- )
- def elemwise_exp_add_on_buffers(lhs, rhs, out):
- linalg.elemwise_unary(lhs, outs=[out])
- linalg.elemwise_binary(out, rhs, outs=[out])
-
- @func.FuncOp.from_py_func(
- MemRefType.get((), f32),
- MemRefType.get((4, 8), f32),
- MemRefType.get((4, 8), f32),
- )
- def elemwise_log_mul_on_buffers(lhs, rhs, out):
- linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
- linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)
-
- execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result f32.
- # Arguments must be passed as pointers.
- c_float_p = ctypes.c_float * 1
- res = c_float_p(-1.0)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
- # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
- # CHECK: RESULT: 4.71828
-
-
-test_elemwise_builtin()
-
-
-def test_elemwise_generic():
- with Context() as ctx, Location.unknown():
- module = Module.create()
- f32 = F32Type.get()
- i8 = IntegerType.get_signless(8)
- with InsertionPoint(module.body):
-
- @func.FuncOp.from_py_func(
- MemRefType.get((), f32),
- MemRefType.get((4, 8), f32),
- MemRefType.get((4, 8), f32),
- )
- def elemwise_exp_add_on_buffers(lhs, rhs, out):
- linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
- linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
-
- @func.FuncOp.from_py_func(
- MemRefType.get((), f32),
- MemRefType.get((4, 8), f32),
- MemRefType.get((4, 8), f32),
- )
- def elemwise_log_mul_on_buffers(lhs, rhs, out):
- linalg.elemwise_unary(
- lhs, outs=[out], fun=UnaryFn.log, emit_generic=True
- )
- linalg.elemwise_binary(
- out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True
- )
-
- execution_engine = ExecutionEngine(transform(module, elemwise_boiler))
-
- # TODO: FFI-based solution to allow testing and printing with python code.
- # Prepare arguments: one result f32.
- # Arguments must be passed as pointers.
- c_float_p = ctypes.c_float * 1
- res = c_float_p(-1.0)
- execution_engine.invoke("main", res)
-
- log("RESULT: ", res[0])
- # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
- # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
- # CHECK: RESULT: 4.71828
-
-
-test_elemwise_generic()
-
-
def test_fill_builtin():
with Context() as ctx, Location.unknown():
module = Module.create()
More information about the Mlir-commits
mailing list