[Mlir-commits] [mlir] [mlir][linalg] Extend `FuseElementwiseOps` pattern to work with named ops (PR #144922)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 20 08:59:31 PDT 2025
================
@@ -59,3 +59,57 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+ %fill = tensor.empty() : tensor<8xf32>
+ %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+ %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+ return %mapped_65 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.map
+
+// -----
+
+func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+ %init = tensor.empty() : tensor<8xi1>
+ %initf = tensor.empty() : tensor<8xf32>
+ %0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+ %1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+ %2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
+ (%in0 : f32, %in1 : f32) {
+ %cmp = arith.cmpf olt, %in0, %in1 : f32
+ linalg.yield %cmp : i1
+ }
+ %3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>)
+ return %3 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops_mixed_types
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[EXP0:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT: %[[SQRT0:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT: %[[EXP1:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT: %[[SQRT1:.*]] = math.sqrt %[[IN0]]
----------------
srcarroll wrote:
Here's the generic version:
```
#identity = affine_map<(d0) -> (d0)>

// Generic-op spelling of the mixed-type element-wise chain:
//   s = sqrt(%arg0), e = exp(%arg1), result = (s < e, ordered) ? s : e
// Every op uses the identity indexing map over a single parallel dimension.
// Note: %arg3 is part of the signature but is not referenced in the body.
func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg3: tensor<8xf32>) -> tensor<8xf32> {
  // Destination tensors: one i1 buffer for the comparison mask, one f32
  // buffer shared by all floating-point producers.
  %mask_init = tensor.empty() : tensor<8xi1>
  %f32_init = tensor.empty() : tensor<8xf32>

  // Element-wise square root of %arg0.
  %sqrt = linalg.generic {
      indexing_maps = [#identity, #identity],
      iterator_types = ["parallel"]}
      ins(%arg0 : tensor<8xf32>) outs(%f32_init : tensor<8xf32>) {
    ^bb0(%x : f32, %acc : f32):
      %r = math.sqrt %x : f32
      linalg.yield %r : f32
  } -> tensor<8xf32>

  // Element-wise exponential of %arg1.
  %exp = linalg.generic {
      indexing_maps = [#identity, #identity],
      iterator_types = ["parallel"]}
      ins(%arg1 : tensor<8xf32>) outs(%f32_init : tensor<8xf32>) {
    ^bb0(%x : f32, %acc : f32):
      %r = math.exp %x : f32
      linalg.yield %r : f32
  } -> tensor<8xf32>

  // Ordered less-than comparison of the two intermediates; yields an i1 mask.
  %mask = linalg.generic {
      indexing_maps = [#identity, #identity, #identity],
      iterator_types = ["parallel"]}
      ins(%sqrt, %exp : tensor<8xf32>, tensor<8xf32>) outs(%mask_init : tensor<8xi1>) {
    ^bb0(%lhs : f32, %rhs : f32, %acc : i1):
      %lt = arith.cmpf olt, %lhs, %rhs : f32
      linalg.yield %lt : i1
  } -> tensor<8xi1>

  // Pick the sqrt value where the mask is set, the exp value elsewhere.
  // Both intermediates are consumed here as well as by the comparison above.
  %picked = linalg.generic {
      indexing_maps = [#identity, #identity, #identity, #identity],
      iterator_types = ["parallel"]}
      ins(%mask, %sqrt, %exp : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%f32_init : tensor<8xf32>) {
    ^bb0(%cond : i1, %on_true : f32, %on_false : f32, %acc : f32):
      %r = arith.select %cond, %on_true, %on_false : f32
      linalg.yield %r : f32
  } -> tensor<8xf32>

  return %picked : tensor<8xf32>
}
```
https://github.com/llvm/llvm-project/pull/144922
More information about the Mlir-commits mailing list