[Mlir-commits] [mlir] [mlir][vector] Add tests for `TransferWritePermutationLowering` (PR #95529)

Hugo Trachino llvmlistbot at llvm.org
Mon Jun 17 06:50:11 PDT 2024


================
@@ -1,14 +1,81 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
 ///----------------------------------------------------------------------------------------
-/// vector.transfer_write
+/// vector.transfer_write -> vector.transpose + vector.transfer_write
 ///----------------------------------------------------------------------------------------
-/// Input: 
-///   * vector.transfer_write op with a map which _is not_ the permutation of a
-///     minor identity
+/// Input:
+///   * vector.transfer_write op with a permutation map that, after a transpose,
+///     _would be_ a permutation of a minor identity
 /// Output:
-///   * vector.broadcast + vector.transfer_write with a map which _is_ the permutation of a
+///   * vector.transpose + vector.transfer_write with a map which _is_ a
+///     permutation of a minor identity
+
+// CHECK-LABEL:   func.func @xfer_write_perm_minor_id_with_transpose(
+// CHECK-SAME:       %[[ARG_0:.*]]: vector<4x8xi16>,
+// CHECK-SAME:       %[[MEM:.*]]: memref<2x2x8x4xi16>) {
+// CHECK:           %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
+// CHECK:           vector.transfer_write %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
+func.func @xfer_write_perm_minor_id_with_transpose(
+    %arg0: vector<4x8xi16>,
+    %mem: memref<2x2x8x4xi16>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
+    in_bounds = [true, true],
+    permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+  } : vector<4x8xi16>, memref<2x2x8x4xi16>
+
+  return
+}
+
+// CHECK-LABEL:   func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable(
+// CHECK-SAME:      %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME:      %[[MEM:.*]]: memref<2x2x?x4xi16>,
+// CHECK-SAME:      %[[MASK:.*]]: vector<[8]x4xi1>) {
+// CHECK:           %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16>
+// CHECK:           vector.transfer_write %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16>
+func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable(
+    %arg0: vector<4x[8]xi16>,
+    %mem: memref<2x2x?x4xi16>,
+    %mask: vector<[8]x4xi1>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0], %mask {
+    in_bounds = [true, true],
+    permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+  } : vector<4x[8]xi16>, memref<2x2x?x4xi16>
+
+  return
+}
+
+// Transfers masked via a `vector.mask` region are not supported (no transpose is emitted)
+// CHECK-LABEL:   func.func @xfer_write_perm_minor_id_with_transpose_masked
+// CHECK-NOT: vector.transpose
+func.func @xfer_write_perm_minor_id_with_transpose_masked(
+    %arg0: vector<4x8xi16>,
+    %mem: memref<2x2x8x4xi16>,
+    %mask: vector<8x4xi1>) {
+
+  %c0 = arith.constant 0 : index
+  vector.mask %mask {
+    vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
+    in_bounds = [true, true],
+    permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+    } : vector<4x8xi16>, memref<2x2x8x4xi16>
+  } : vector<8x4xi1>
+
+  return
+}
+
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_write
+///----------------------------------------------------------------------------------------
----------------
nujaa wrote:

Agreed.

https://github.com/llvm/llvm-project/pull/95529


More information about the Mlir-commits mailing list