[Mlir-commits] [mlir] [mlir][linalg] convert arith ops to destination-passing-style. (PR #157854)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Sep 15 06:36:41 PDT 2025
================
@@ -252,3 +252,96 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Unary elementwise arith op on a statically shaped tensor: the result is
+// produced by a linalg.generic whose init is a tensor.empty of the result
+// type, with the scalar op moved into the region.
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_unary_op(
+// CHECK-SAME: %[[X:.+]]: tensor<64xi32>) -> tensor<64xf32> {
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[X]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) {
+// CHECK: ^bb0(%[[X_ELEM:.+]]: i32, %{{.*}}: f32):
+// CHECK: %[[RES:.+]] = arith.uitofp %[[X_ELEM]] : i32 to f32
+// CHECK: linalg.yield %[[RES]] : f32
+// CHECK: return %[[GENERIC]] : tensor<64xf32>
+
+func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> {
+ %z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
+ return %z : tensor<64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(
+      %root: !transform.any_op {transform.readonly}) {
+    // Match the arith.uitofp ops nested in the payload and rewrite them in
+    // destination-passing style.
+    %uitofp_ops = transform.structured.match ops{["arith.uitofp"]} in %root
+        : (!transform.any_op) -> !transform.any_op
+    transform.structured.rewrite_in_destination_passing_style %uitofp_ops
+        : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Binary elementwise arith op on dynamically shaped tensors: the tensor.empty
+// init is sized by a tensor.dim of the first operand, and both function
+// arguments are passed directly as the linalg.generic inputs.
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_binop(
+// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[X]], %[[Y]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
+// CHECK: ^bb0(%[[X_ELEM:.+]]: f32, %[[Y_ELEM:.+]]: f32, %{{.*}}: f32):
+// CHECK: %[[RES:.+]] = arith.addf %[[X_ELEM]], %[[Y_ELEM]] : f32
+// CHECK: linalg.yield %[[RES]] : f32
+// CHECK: return %[[GENERIC]] : tensor<?xf32>
+
+func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)
+ -> tensor<?xf32> {
+ %z = arith.addf %x, %y : tensor<?xf32>
+ return %z : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(
+      %root: !transform.any_op {transform.readonly}) {
+    // Match the arith.addf ops nested in the payload and rewrite them in
+    // destination-passing style.
+    %addf_ops = transform.structured.match ops{["arith.addf"]} in %root
+        : (!transform.any_op) -> !transform.any_op
+    transform.structured.rewrite_in_destination_passing_style %addf_ops
+        : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Same structure as the plain binary case; additionally verifies that the
+// fastmath flag ends up on the scalar op inside the linalg.generic region.
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_binop_fastmath(
+// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[X]], %[[Y]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
+// CHECK: ^bb0(%[[X_ELEM:.+]]: f32, %[[Y_ELEM:.+]]: f32, %{{.*}}: f32):
+// CHECK: %[[RES:.+]] = arith.addf %[[X_ELEM]], %[[Y_ELEM]] fastmath<fast> : f32
+// CHECK: linalg.yield %[[RES]] : f32
+// CHECK: return %[[GENERIC]] : tensor<?xf32>
----------------
banach-space wrote:
You have already checked all the fine details when testing `@arith_binop`; here, it is sufficient to make sure that `fastmath` is propagated correctly. I suggest simplifying:
```suggestion
// Only fastmath propagation needs verifying here; @arith_binop covers the rest.
// CHECK-LABEL: func @arith_binop_fastmath(
// CHECK: linalg.generic
// CHECK-SAME: ins({{.*}} : tensor<?xf32>, tensor<?xf32>) outs({{.*}} : tensor<?xf32>) {
// CHECK: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32):
// CHECK: arith.addf {{.*}} fastmath<fast> : f32
```
https://github.com/llvm/llvm-project/pull/157854
More information about the Mlir-commits mailing list