[Mlir-commits] [mlir] [MLIR][Vector] Implement TransferOpReduceRank as MaskableOpRewritePattern (PR #92426)

Han-Chung Wang llvmlistbot at llvm.org
Thu May 30 16:04:04 PDT 2024


================
@@ -187,3 +187,49 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+
+//       CHECK:   #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
+//       CHECK:   func.func @transfer_read_reduce_rank_scalable(
+//  CHECK-SAME:       %[[ARG_0:.*]]: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
+//       CHECK:     %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:     %[[TFR:.*]] = vector.transfer_read %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]{{.*}} permutation_map = #[[MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
+//       CHECK:     %[[BC:.*]] = vector.broadcast %[[TFR]] : vector<[4]x2x3xf32> to vector<8x[4]x2x3xf32>
+//       CHECK:     return %[[BC]] : vector<8x[4]x2x3xf32>
+func.func @transfer_read_reduce_rank_scalable(%mem: memref<?x?x?x?xf32>) -> vector<8x[4]x2x3xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %1 = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
+    {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
+    : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32>
+  return %1 : vector<8x[4]x2x3xf32>
+}
+
+// Masked case not supported.
+// CHECK-LABEL:   func.func @masked_transfer_read_reduce_rank(
+//  CHECK-SAME:       %[[ARG_0:.*]]: memref<?x?x?x?xf32>,
+//  CHECK-SAME:       %[[DIM:.*]]: index) -> vector<8x[4]x2x3xf32> {
+//   CHECK-NOT:     vector.broadcast
+//       CHECK:     %[[MASK:.*]] = vector.mask %0 { vector.transfer_read %arg0{{.*}} : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
+func.func @masked_transfer_read_reduce_rank(%mem: memref<?x?x?x?xf32>, %dim: index) -> vector<8x[4]x2x3xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %mask = vector.create_mask %dim, %dim: vector<[4]x3xi1>
+  %res = vector.mask %mask { vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst_0
+    {in_bounds = [true, true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d3)>}
+    : memref<?x?x?x?xf32>, vector<8x[4]x2x3xf32> } : vector<[4]x3xi1> -> vector<8x[4]x2x3xf32>
+  return %res : vector<8x[4]x2x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.transfer_permutation_patterns
+    } : !transform.any_op
+    transform.yield
+  }
+}
----------------
hanhanW wrote:

Please add a newline at the end of the file.

https://github.com/llvm/llvm-project/pull/92426


More information about the Mlir-commits mailing list