[Mlir-commits] [mlir] [mlir][linalg][NFC] Add lit test for tile and fuse transformation (PR #126216)

Amir Bishara llvmlistbot at llvm.org
Fri Feb 7 02:03:39 PST 2025


https://github.com/amirBish created https://github.com/llvm/llvm-project/pull/126216

Add test coverage for the consumer-fusion transform applied to a
`linalg.generic` operation with projected-permutation indexing maps.

>From bff3aeccb0c465b2a97f391f7ed536eb1b4932bb Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Fri, 7 Feb 2025 11:15:44 +0200
Subject: [PATCH] [mlir][linalg][NFC] Add lit test for tile and fuse
 transformation

Add test coverage for the consumer-fusion transform applied to a
`linalg.generic` operation with projected-permutation
indexing maps.
---
 .../tile-and-fuse-consumer.mlir               | 62 +++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index a2871b30698c527..2d35be403ef9937 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -676,3 +676,65 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:   }
 //      CHECK:   %[[RES_SLICE:.+]] = tensor.insert_slice
 //      CHECK:   return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
+
+// -----
+// Producer: linalg.add tiled by 64 rows inside an scf.for. Consumer: a linalg.generic whose input maps (d0, d1) and (d2) are projected permutations of its 3-D iteration space.
+#in2d_map = affine_map<(d0, d1, d2) -> (d0, d1)>
+#vec_map = affine_map<(d0, d1, d2) -> (d2)>
+#out3d_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+module {
+  func.func @fuse_with_tilable_consumer_with_projected_permutations(%lhs: tensor<256x256xf32>, %rhs: tensor<256x256xf32>, %vec: tensor<24xf32>) -> tensor<256x256x24xf32> {
+    %lb = arith.constant 0 : index
+    %step = arith.constant 64 : index
+    %ub = arith.constant 256 : index
+    %init2d = tensor.empty() : tensor<256x256xf32>
+    %producer = scf.for %iv = %lb to %ub step %step iter_args(%acc = %init2d) -> (tensor<256x256xf32>) {
+      %acc_tile = tensor.extract_slice %acc[%iv, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %lhs_tile = tensor.extract_slice %lhs[%iv, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %rhs_tile = tensor.extract_slice %rhs[%iv, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
+      %sum_tile = linalg.add ins(%lhs_tile, %rhs_tile : tensor<64x256xf32>, tensor<64x256xf32>) outs(%acc_tile : tensor<64x256xf32>) -> tensor<64x256xf32>
+      %updated = tensor.insert_slice %sum_tile into %acc[%iv, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
+      scf.yield %updated : tensor<256x256xf32>
+    }
+    %init3d = tensor.empty() : tensor<256x256x24xf32>
+    %consumer = linalg.generic {indexing_maps = [#in2d_map, #vec_map, #out3d_map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%producer, %vec : tensor<256x256xf32>, tensor<24xf32>) outs(%init3d : tensor<256x256x24xf32>) {
+    ^bb0(%x: f32, %y: f32, %unused: f32):
+      %xy = arith.addf %x, %y : f32
+      linalg.yield %xy : f32
+    } -> tensor<256x256x24xf32>
+    return %consumer : tensor<256x256x24xf32>
+  }
+}
+// Expected IR: the generic consumer is tiled by the loop step along d0 and fused into the producer's scf.for, which gains a second iter_arg/result that the function then returns.
+// CHECK: func.func @fuse_with_tilable_consumer_with_projected_permutations(%[[VAL_0:.*]]: tensor<256x256xf32>, %[[VAL_1:.*]]: tensor<256x256xf32>, %[[VAL_2:.*]]: tensor<24xf32>) -> tensor<256x256x24xf32> {
+// CHECK:             %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:             %[[VAL_4:.*]] = arith.constant 64 : index
+// CHECK:             %[[VAL_5:.*]] = arith.constant 256 : index
+// CHECK:             %[[VAL_6:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK:             %[[VAL_7:.*]] = tensor.empty() : tensor<256x256x24xf32>
+// CHECK:             %[[VAL_8:.*]]:2 = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_6]], %[[VAL_11:.*]] = %[[VAL_7]]) -> (tensor<256x256xf32>, tensor<256x256x24xf32>) {
+// CHECK:               %[[VAL_12:.*]] = tensor.extract_slice %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_13:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_1]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_15:.*]] = linalg.add ins(%[[VAL_13]], %[[VAL_14]] : tensor<64x256xf32>, tensor<64x256xf32>) outs(%[[VAL_12]] : tensor<64x256xf32>) -> tensor<64x256xf32>
+// CHECK:               %[[VAL_16:.*]] = tensor.insert_slice %[[VAL_15]] into %[[VAL_10]]{{\[}}%[[VAL_9]], 0] [64, 256] [1, 1]
+// CHECK:               %[[VAL_17:.*]] = tensor.extract_slice %[[VAL_2]][0] [24] [1] : tensor<24xf32> to tensor<24xf32>
+// CHECK:               %[[VAL_18:.*]] = tensor.extract_slice %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK:               %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_15]], %[[VAL_17]] : tensor<64x256xf32>, tensor<24xf32>) outs(%[[VAL_18]] : tensor<64x256x24xf32>) {
+// CHECK:               ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK:                 %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK:                 linalg.yield %[[VAL_23]] : f32
+// CHECK:               } -> tensor<64x256x24xf32>
+// CHECK:               %[[VAL_24:.*]] = tensor.insert_slice %[[VAL_19]] into %[[VAL_11]]{{\[}}%[[VAL_9]], 0, 0] [64, 256, 24] [1, 1, 1]
+// CHECK:               scf.yield %[[VAL_16]], %[[VAL_24]] : tensor<256x256xf32>, tensor<256x256x24xf32>
+// CHECK:             return %[[VAL_8]]#1 : tensor<256x256x24xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    // Handle to the tensor.insert_slice that writes each linalg.add tile back into the loop's iter_arg.
+    %insert_slice = transform.structured.match ops{["tensor.insert_slice"]} in %root : (!transform.any_op) -> !transform.any_op
+    // Fuse one consumer of the value produced through that slice (the linalg.generic) into the loop.
+    %res0, %res1 = transform.test.fuse_consumer %insert_slice num_consumer_to_fuse = 1 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}



More information about the Mlir-commits mailing list