[Mlir-commits] [mlir] [mlir][linalg] Extend `FuseElementwiseOps` pattern to work with named ops (PR #144922)
llvmlistbot at llvm.org
Thu Nov 6 09:53:26 PST 2025
================
@@ -1017,9 +1017,75 @@ module {
// -----
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+ %fill = tensor.empty() : tensor<8xf32>
+ %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+ %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+ return %mapped_65 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @map_matmul(%in1: tensor<8x10xf32>, %in2: tensor<10x12xf32>) -> tensor<8x12xf32> {
----------------
srcarroll wrote:
> If the matmul has a transpose/broadcast/reduction map on %exp then it shouldn't be fused.
@rengolin Actually, why not? I think this is only conditionally true. One case I'm thinking of, involving a transpose, is valid for fusion. For example,
```
func.func @map_matmul_transpose_a(%in1: tensor<10x8xf32>, %in2: tensor<10x12xf32>) -> tensor<8x12xf32> {
%fill0 = tensor.empty() : tensor<10x8xf32>
%exp = linalg.map {math.exp} ins(%in1 : tensor<10x8xf32>) outs(%fill0: tensor<10x8xf32>)
%fill1 = tensor.empty() : tensor<8x12xf32>
%matmul = linalg.matmul indexing_maps = [
affine_map<(d0, d1, d2) -> (d2, d0)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
] ins(%exp, %in2 : tensor<10x8xf32>, tensor<10x12xf32>) outs(%fill1 : tensor<8x12xf32>) -> tensor<8x12xf32>
return %matmul : tensor<8x12xf32>
}
```
would fuse to
```
#map = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func.func @map_matmul_transpose_a(%arg0: tensor<10x8xf32>, %arg1: tensor<10x12xf32>) -> tensor<8x12xf32> {
%0 = tensor.empty() : tensor<8x12xf32>
%1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<10x8xf32>, tensor<10x12xf32>) outs(%0 : tensor<8x12xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = math.exp %in : f32
%3 = arith.mulf %2, %in_0 : f32
%4 = arith.addf %out, %3 : f32
linalg.yield %4 : f32
} -> tensor<8x12xf32>
return %1 : tensor<8x12xf32>
}
}
```
Broadcast cases can also be valid. For example,
```
func.func @map_matmul_bcast(%in1: tensor<10xf32>, %in2: tensor<10x12xf32>) -> tensor<8x12xf32> {
%fill0 = tensor.empty() : tensor<10xf32>
%exp = linalg.map {math.exp} ins(%in1 : tensor<10xf32>) outs(%fill0: tensor<10xf32>)
%fill1 = tensor.empty() : tensor<8x12xf32>
%matmul = linalg.matmul indexing_maps = [
affine_map<(d0, d1, d2) -> (d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0, d1)>
] ins(%exp, %in2 : tensor<10xf32>, tensor<10x12xf32>) outs(%fill1 : tensor<8x12xf32>) -> tensor<8x12xf32>
return %matmul : tensor<8x12xf32>
}
```
fuses to
```
#map = affine_map<(d0, d1, d2) -> (d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func.func @map_matmul_bcast(%arg0: tensor<10xf32>, %arg1: tensor<10x12xf32>) -> tensor<8x12xf32> {
%0 = tensor.empty() : tensor<8x12xf32>
%1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<10xf32>, tensor<10x12xf32>) outs(%0 : tensor<8x12xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%2 = math.exp %in : f32
%3 = arith.mulf %2, %in_0 : f32
%4 = arith.addf %out, %3 : f32
linalg.yield %4 : f32
} -> tensor<8x12xf32>
return %1 : tensor<8x12xf32>
}
}
```
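(For reference, the fused IR above is what the elementwise fusion patterns produce for me; something along the lines of the command below should reproduce it, though the exact pass/flags that end up driving the extended pattern in this PR may differ.)
```
# Hypothetical reproduction command; assumes the standard elementwise fusion
# pass (linalg-fuse-elementwise-ops) picks up the extended pattern, and that
# repro.mlir contains the input function from above.
mlir-opt --linalg-fuse-elementwise-ops repro.mlir
```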
I'll still need to think about cases, with valid input IR, that should NOT result in fusion for the elementwise + matmul case; I think I'll need a more complex case than this to show that. But I just wanted to make sure we agree that the above cases are valid fusion cases. And again, if you have a specific test case in mind that I'm not thinking of, I will certainly investigate/add it.
https://github.com/llvm/llvm-project/pull/144922
More information about the Mlir-commits mailing list