[Mlir-commits] [mlir] [MLIR][Linalg] Remove elemwise_unary and elemwise_binary (PR #147082)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 4 09:13:14 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir-linalg

Author: Renato Golin (rengolin)

<details>
<summary>Changes</summary>

RFC: https://discourse.llvm.org/t/rfc-deprecate-linalg-elemwise-unary-and-elemwise-binary/87144

Remove the two operations and fix the tests by:
* Removing the simple operation tests that exercised the old ops
* Replacing `linalg.elemwise_{u|bi}nary` with `linalg.{exp|add}` in transform tests
* Surgically removing the `elemwise_*` parts from the Python tests

Nothing else changed.


---

Patch is 54.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147082.diff


15 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (-114) 
- (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (-31) 
- (modified) mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir (-111) 
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (-8) 
- (modified) mlir/test/Dialect/Linalg/library-calls.mlir (-40) 
- (modified) mlir/test/Dialect/Linalg/match-ops-interpreter.mlir (+3-3) 
- (modified) mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir (-38) 
- (modified) mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir (+21-21) 
- (modified) mlir/test/Dialect/Linalg/transform-op-fuse.mlir (+36-36) 
- (modified) mlir/test/Dialect/Linalg/transform-op-generalize.mlir (+2-2) 
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+4-4) 
- (modified) mlir/test/Integration/Dialect/Transform/match_matmul.mlir (+1-1) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+4-4) 
- (modified) mlir/test/python/dialects/linalg/ops.py (-36) 
- (modified) mlir/test/python/integration/dialects/linalg/opsrun.py (-121) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 6344861c53ac5..3637147c5a90d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -44,56 +44,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: elemwise_unary
-  cpp_class_name: ElemwiseUnaryOp
-  doc: |-
-    Applies the unary function fun elementwise.
-
-    Numeric casting is performed on the input operand, promoting it to the same
-    data type as the accumulator/output.
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: I
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<() -> ()>
-  - !LinalgOperandDefConfig
-    name: O
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<() -> ()>
-  - !LinalgOperandDefConfig
-    name: fun
-    kind: unary_fn_attr
-    default_fn: exp
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<() -> ()>
-    - affine_map<() -> ()>
-  iterator_types: []
-  assignments:
-  - !ScalarAssign
-    arg: O
-    value: !ScalarExpression
-      scalar_fn:
-        kind: unary
-        attr_name: fun
-        operands:
-        - !ScalarExpression
-          scalar_fn:
-            kind: type
-            attr_name: cast
-            type_var: U
-            operands:
-            - !ScalarExpression
-              scalar_arg: I
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: exp
   cpp_class_name: ExpOp
@@ -549,70 +499,6 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: elemwise_binary
-  cpp_class_name: ElemwiseBinaryOp
-  doc: |-
-    Applies the binary function fun elementwise.
-
-    Numeric casting is performed on the input operand, promoting it to the same
-    data type as the accumulator/output.
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: lhs
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<() -> ()>
-  - !LinalgOperandDefConfig
-    name: rhs
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<() -> ()>
-  - !LinalgOperandDefConfig
-    name: O
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<() -> ()>
-  - !LinalgOperandDefConfig
-    name: fun
-    kind: binary_fn_attr
-    default_fn: add
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<() -> ()>
-    - affine_map<() -> ()>
-    - affine_map<() -> ()>
-  iterator_types: []
-  assignments:
-  - !ScalarAssign
-    arg: O
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        attr_name: fun
-        operands:
-        - !ScalarExpression
-          scalar_fn:
-            kind: type
-            attr_name: cast
-            type_var: U
-            operands:
-            - !ScalarExpression
-              scalar_arg: lhs
-        - !ScalarExpression
-          scalar_fn:
-            kind: type
-            attr_name: cast
-            type_var: U
-            operands:
-            - !ScalarExpression
-              scalar_arg: rhs
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: add
   cpp_class_name: AddOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 48e724d80c926..1b359da40a291 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -21,21 +21,6 @@ def copy(
     O[None] = cast(U, I[None])
 
 
-@linalg_structured_op
-def elemwise_unary(
-    I=TensorDef(T1),
-    O=TensorDef(U, output=True),
-    fun=UnaryFnAttrDef(default=UnaryFn.exp),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Applies the unary function fun elementwise.
-
-    Numeric casting is performed on the input operand, promoting it to the same
-    data type as the accumulator/output.
-    """
-    O[None] = fun(cast(U, I[None]))
-
-
 @linalg_structured_op
 def exp(
     I=TensorDef(T1),
@@ -192,22 +177,6 @@ def erf(
     O[None] = UnaryFn.erf(I[None])
 
 
-@linalg_structured_op
-def elemwise_binary(
-    lhs=TensorDef(T1),
-    rhs=TensorDef(T2),
-    O=TensorDef(U, output=True),
-    fun=BinaryFnAttrDef(default=BinaryFn.add),
-    cast=TypeFnAttrDef(default=TypeFn.cast_signed),
-):
-    """Applies the binary function fun elementwise.
-
-    Numeric casting is performed on the input operand, promoting it to the same
-    data type as the accumulator/output.
-    """
-    O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None]))
-
-
 @linalg_structured_op
 def add(
     lhs=TensorDef(T1),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index bbd6e0fc8e2cc..290c6c7c36f76 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -434,117 +434,6 @@ func.func @generalize_const(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf
 
 // -----
 
-// Verifies the default value of the fun attribute is an exp op.
-func.func @generalize_elemwise_exp(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  %0 = linalg.elemwise_unary ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_exp
-// CHECK:        = math.exp
-
-// -----
-
-// Verifies the fun attribute controls the unary function used.
-func.func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<log>}
-                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_log
-// CHECK:        = math.log
-
-// -----
-
-// Verifies the fun attribute controls the unary function used.
-func.func @generalize_elemwise_abs(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<abs>}
-                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_abs
-// CHECK:        = math.absf
-
-// -----
-
-// Verifies the fun attribute controls the unary function used.
-func.func @generalize_elemwise_ceil(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<ceil>}
-                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_ceil
-// CHECK:        = math.ceil
-
-// -----
-
-// Verifies the fun attribute controls the unary function used.
-func.func @generalize_elemwise_floor(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<floor>}
-                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_floor
-// CHECK:        = math.floor
-
-// -----
-
-// Verifies the fun attribute controls the unary function used.
-func.func @generalize_elemwise_negf(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  %0 = linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
-                              ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_negf
-// CHECK:        = arith.negf
-
-// -----
-
-// Verifies the default value of the fun attribute is an add op.
-func.func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  %0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
-                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_add
-// CHECK:        = arith.addf
-
-// -----
-
-// Verifies the fun attribute controls the binary function used.
-func.func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<mul>}
-                              ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>)
-                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_mul
-// CHECK:        = arith.mulf
-
-// -----
-
-// Verifies pointwise ops support rank zero input tensors
-func.func @generalize_elemwise_rank_zero(%lhs : tensor<f32>, %rhs : tensor<f32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
-  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<sub>}
-                              ins(%lhs, %rhs: tensor<f32>, tensor<f32>)
-                              outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
-  return %0: tensor<4x8xf32>
-}
-
-// CHECK-LABEL: @generalize_elemwise_rank_zero
-// CHECK:       linalg.generic
-// CHECK-SAME:  iterator_types = ["parallel", "parallel"]
-// CHECK:        = arith.subf
-
-// -----
-
 // Verifies the fun attribute controls the binary function used.
 func.func @generalize_copy(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
   %0 = linalg.copy ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 964681d7dcd92..da1dfc7b6a624 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1909,14 +1909,6 @@ module {
 
 // -----
 
-func.func @elemwise_unary_invalid_mixed_types(%arg0 : tensor<?xi32>) -> tensor<?xi32> {
-  // expected-error @below {{unsupported non numeric type}}
-  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?xi32>) outs(%arg0 : tensor<?xi32>) -> tensor<?xi32>
-  return %0 : tensor<?xi32>
-}
-
-// -----
-
 func.func @matmul_invalid_mixed_types(%t: tensor<?xf16>, %f: vector<4xf16>)
   -> (tensor<?xf16>, vector<4xf16>)
 {
diff --git a/mlir/test/Dialect/Linalg/library-calls.mlir b/mlir/test/Dialect/Linalg/library-calls.mlir
index 1fa675d8b4b68..77c9d4a911447 100644
--- a/mlir/test/Dialect/Linalg/library-calls.mlir
+++ b/mlir/test/Dialect/Linalg/library-calls.mlir
@@ -59,43 +59,3 @@ module {
     return
   }
 }
-
-
-// -----
-
-// CHECK: func.func private @linalg_elemwise_unary_negf_view16x8xf32_view16x8xf32(memref<16x8xf32, strided<[?, ?], offset: ?>>, memref<16x8xf32, strided<[?, ?], offset: ?>>) attributes {llvm.emit_c_interface}
-// CHECK: func.func private @linalg_elemwise_unary_negf_view16xf32_view16xf32(memref<16xf32, strided<[?], offset: ?>>, memref<16xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
-
-func.func @test_neg(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16xf32>, %D: memref<16xf32>) {
-  linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
-                              ins(%A: memref<16x8xf32>) outs(%B: memref<16x8xf32>)
-  linalg.elemwise_unary {fun = #linalg.unary_fn<negf>}
-                              ins(%C: memref<16xf32>) outs(%D: memref<16xf32>)
-  return
-}
-
-// -----
-
-// CHECK: func.func private @linalg_elemwise_unary_exp_view16x8xf32_view16x8xf32(memref<16x8xf32, strided<[?, ?], offset: ?>>, memref<16x8xf32, strided<[?, ?], offset: ?>>) attributes {llvm.emit_c_interface}
-// CHECK: func.func private @linalg_elemwise_unary_exp_view16xf32_view16xf32(memref<16xf32, strided<[?], offset: ?>>, memref<16xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
-
-func.func @test_exp(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16xf32>, %D: memref<16xf32>) {
-  linalg.elemwise_unary {fun = #linalg.unary_fn<exp>}
-                              ins(%A: memref<16x8xf32>) outs(%B: memref<16x8xf32>)
-  linalg.elemwise_unary {fun = #linalg.unary_fn<exp>}
-                              ins(%C: memref<16xf32>) outs(%D: memref<16xf32>)
-  return
-}
-
-// -----
-
-// CHECK: func.func private @linalg_elemwise_binary_add_view16x8xf32_view16x8xf32_view16x8xf32(memref<16x8xf32, strided<[?, ?], offset: ?>>, memref<16x8xf32, strided<[?, ?], offset: ?>>, memref<16x8xf32, strided<[?, ?], offset: ?>>) attributes {llvm.emit_c_interface}
-// CHECK: func.func private @linalg_elemwise_binary_add_view16xf32_view16xf32_view16xf32(memref<16xf32, strided<[?], offset: ?>>, memref<16xf32, strided<[?], offset: ?>>, memref<16xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
-
-func.func @test_add(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C: memref<16x8xf32>, %D: memref<16xf32>, %E: memref<16xf32>, %F: memref<16xf32>) {
-  linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
-                              ins(%A, %B: memref<16x8xf32>, memref<16x8xf32>) outs(%C: memref<16x8xf32>)
-  linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
-                              ins(%D, %E: memref<16xf32>, memref<16xf32>) outs(%F: memref<16xf32>)
-  return
-}
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index bfe7a07cb38a5..618ba3402ff52 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -842,15 +842,15 @@ module attributes { transform.with_named_sequence } {
     // expected-remark @below {{op result}}
     // expected-note @below {{value handle points to an op result #0}}
     // expected-remark @below {{single user}}
-    linalg.elemwise_unary {fun = #linalg.unary_fn<negf>} ins(%2 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+    linalg.negf ins(%2 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
     // expected-remark @below {{matched result value}}
     // expected-remark @below {{op result}}
     // expected-note @below {{value handle points to an op result #0}}
-    linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+    linalg.exp ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
     // expected-remark @below {{matched result value}}
     // expected-remark @below {{op result}}
     // expected-note @below {{value handle points to an op result #0}}
-    linalg.elemwise_unary {fun = #linalg.unary_fn<exp>} ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
+    linalg.exp ins(%3 : tensor<42x42xf32>) outs(%0 : tensor<42x42xf32>) -> tensor<42x42xf32>
     return
   }
 }
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
index 5b7c2baf9d84f..a0922bdfcfbe4 100644
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
@@ -1,43 +1,5 @@
 // RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries test-analysis-only" -split-input-file | FileCheck %s
 
-// CHECK-LABEL: @elementwise_no_conflict
-func.func @elementwise_no_conflict(%a: tensor<5xf32>,
-                                   %b: tensor<5xf32>) -> tensor<5xf32> {
-  // CHECK: linalg.elemwise_binary
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
-  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
-      ins(%a, %b : tensor<5xf32>, tensor<5xf32>)
-      outs(%a : tensor<5xf32>) -> tensor<5xf32>
-  return %0 : tensor<5xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @elementwise_no_conflict_2
-func.func @elementwise_no_conflict_2(%a: tensor<5xf32>) -> tensor<5xf32> {
-  // CHECK: linalg.elemwise_binary
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
-  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
-      ins(%a, %a : tensor<5xf32>, tensor<5xf32>)
-      outs(%a : tensor<5xf32>) -> tensor<5xf32>
-  return %0 : tensor<5xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @elementwise_no_conflict_3
-func.func @elementwise_no_conflict_3(%a: tensor<5xf32>) -> tensor<5xf32> {
-  %c0f = arith.constant 1.0 : f32
-  // CHECK: linalg.elemwise_binary
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none", "true"], fun = #linalg.binary_fn<add>}
-  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
-      ins(%a, %c0f : tensor<5xf32>, f32)
-      outs(%a : tensor<5xf32>) -> tensor<5xf32>
-  return %0 : tensor<5xf32>
-}
-
-// -----
-
 func.func @not_elementwise(%a: tensor<5x6xf32>) -> tensor<5x6xf32> {
   %cst = arith.constant 5.0 : f32
   // CHECK: tensor.extract_slice
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 5bdb5073ee865..312468970ae6d 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -26,8 +26,8 @@ module {
       // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
       %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
 
-      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]
-      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      // CHECK: %[[T2:.*]] = linalg.exp ins(%[[T1]]
+      %7 = linalg.exp ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
       scf.forall.in_parallel {
         tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
       }
@@ -76,8 +76,8 @@ module {
       %4 = affine.min #map2(%arg3)[%arg0]
       %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<64xf32> to tensor<?xf32>
 
-      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[INIT_TENSOR]]
-      %7 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      // CHECK: %[[T2:.*]] = linalg.exp ins(%[[INIT_TENSOR]]
+      %7 = linalg.exp ins(%0 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
       scf.forall.in_parallel {
         tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<64xf32>
       }
@@ -177,8 +177,8 @@ module {
       // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]]
       %6 = tensor.extract_slice %arg1[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
 
-      // CHECK: %[[T2:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T1]]
-      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      // CHECK: %[[T2:.*]] = linalg.exp {{.*}} outs(%[[T1]]
+      %7 = linalg.exp ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
       scf.forall.in_parallel {
         tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
       }
@@ -228,8 +228,8 @@ module {
       // CHECK: %[[T2:.*]] = linalg.fill {{.*}} outs(%[[T1]]
       %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
 
-      // CHECK: %[[T3:.*]] = linalg.elemwise_unary ins(%[[T2]] : tensor<?xf32>) outs(%[[T0]] : tensor<?xf32>)
-      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      // CHECK: %[[T3:.*]] = linalg.exp ins(%[[T2]] : tensor<?xf32>) outs(%[[T0]] : tensor<?xf32>)
+      %7 = linalg.exp ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
       scf.forall.in_parallel {
         tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
       }
@@ -261,7 +261,7 @@ module {
     %c2 = arith.constant...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/147082


More information about the Mlir-commits mailing list