[Mlir-commits] [mlir] [mlir][linalg][NFC] Add lit test for tile and fuse transformation (PR #126216)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 7 02:04:17 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Amir Bishara (amirBish)

<details>
<summary>Changes</summary>

Add test coverage for the fuse-consumer transform applied to a
`linalg.generic` operation with projected-permutation indexing maps.

---
Full diff: https://github.com/llvm/llvm-project/pull/126216.diff


1 File Affected:

- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+62) 


``````````diff
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index a2871b30698c527..2d35be403ef9937 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -676,3 +676,65 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:   }
 //      CHECK:   %[[RES_SLICE:.+]] = tensor.insert_slice
 //      CHECK:   return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
+
+// -----
+// Fuse a linalg.generic consumer with projected-permutation indexing maps (1-D operand indexed by d2 only) into the producer's scf.for loop.
+#map = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+module {
+  func.func @fuse_with_tilable_consumer_with_projected_permutations(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<24xf32>) -> tensor<256x256x24xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c256 = arith.constant 256 : index
+    %0 = tensor.empty() : tensor<256x256xf32>
+    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %0) -> (tensor<256x256xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %4 = linalg.add ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice : tensor<64x256xf32>) -> tensor<64x256xf32>
+      %inserted_slice = tensor.insert_slice %4 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+      scf.yield %inserted_slice : tensor<256x256xf32>
+    }
+    %2 = tensor.empty() : tensor<256x256x24xf32>
+    %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1, %arg2 : tensor<256x256xf32>, tensor<24xf32>) outs(%2 : tensor<256x256x24xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %4 = arith.addf %in, %in_0 : f32
+      linalg.yield %4 : f32
+    } -> tensor<256x256x24xf32>
+    return %3 : tensor<256x256x24xf32>
+  }
+}
+
+// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> {
+// CHECK:             %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:             %[[VAL_4:.*]] = arith.constant 64 : index
+// CHECK:             %[[VAL_5:.*]] = arith.constant 256 : index
+// CHECK:             %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK:             %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32>
+// CHECK:             %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) {
+// CHECK:               %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32>
+// CHECK:               %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32>
+// CHECK:               %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK:               %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) {
+// CHECK:               ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK:                 %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK:                 linalg.yield %[[VAL_23]] : f32
+// CHECK:               } -> tensor<64x256x24xf32>
+// CHECK:               %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_19]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK:               scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
+// CHECK:             }
+// CHECK:             return %[[VAL_8]]#1 : tensor<256x256x24xf32>
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/126216


More information about the Mlir-commits mailing list