[Mlir-commits] [mlir] [mlir][linalg] Extend `FuseElementwiseOps` pattern to work with named ops (PR #144922)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 1 09:48:11 PDT 2025


================
@@ -59,3 +59,154 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
 //       CHECK:   %[[FUSED_OP:.+]] = linalg.generic
 //  CHECK-SAME:       outs(%[[EMPTY]] :
 //   CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+    %fill = tensor.empty() : tensor<8xf32>
+    %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+    %sqrt = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+    return %sqrt : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+//       CHECK:   %[[FUSED_OP:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+//  CHECK-NEXT:   ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+//  CHECK-NEXT:     %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+//  CHECK-NEXT:     %[[SQRT:.*]] = math.sqrt %[[ADD]]
+//  CHECK-NEXT:     linalg.yield %[[SQRT]] 
+//   CHECK-NOT:   linalg.map
+
+// -----
+
+func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+  %init = tensor.empty() : tensor<8xi1>
+  %initf = tensor.empty() : tensor<8xf32>
+  %0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
+    (%in0 : f32, %in1 : f32) {
+      %cmp = arith.cmpf olt, %in0, %in1 : f32
+      linalg.yield %cmp : i1
+  }
+  %3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>) 
+  return %3 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops_mixed_types
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+//       CHECK:   %[[FUSED_OP:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+//  CHECK-NEXT:   ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+//  CHECK-NEXT:     %[[EXP0:.*]] = math.exp %[[IN1]]
+//  CHECK-NEXT:     %[[SQRT0:.*]] = math.sqrt %[[IN0]]
+//  CHECK-NEXT:     %[[EXP1:.*]] = math.exp %[[IN1]]
+//  CHECK-NEXT:     %[[SQRT1:.*]] = math.sqrt %[[IN0]]
+//  CHECK-NEXT:     %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]]
+//  CHECK-NEXT:     %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]]
+//  CHECK-NEXT:     linalg.yield %[[RES]] 
+//   CHECK-NOT:   linalg.map
+
+// -----
+
+#identity = affine_map<(d0, d1) -> (d0, d1)>
+#bcast = affine_map<(d0, d1) -> (d0)>
+func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tensor<8x10xf32> {
+    %fill = tensor.empty() : tensor<8x10xf32>
+    %add = linalg.elementwise
+      kind=#linalg.elementwise_kind<add>
+      indexing_maps = [#bcast, #identity, #identity]
+      ins(%in1, %in2: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) -> tensor<8x10xf32>
+    %sqrt = linalg.elementwise
+      kind=#linalg.elementwise_kind<sqrt>
+      indexing_maps = [#identity, #identity]
+      ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>) -> tensor<8x10xf32>
+    return %sqrt : tensor<8x10xf32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-LABEL: func @elementwise_ops
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
+//       CHECK:   %[[FUSED_OP:.+]] = linalg.generic
+//  CHECK-SAME:       indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
----------------
srcarroll wrote:

Fixed it by just changing how the variable is matched, but I still don't understand why it was an issue to begin with.

https://github.com/llvm/llvm-project/pull/144922


More information about the Mlir-commits mailing list