[Mlir-commits] [mlir] [mlir][vector] Add tests for `TransferWritePermutationLowering` (PR #95529)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Jun 18 02:10:01 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/95529

>From 384d72fe5e8a7a863d82c3c33f4f609ee56f6a5e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 14 Jun 2024 12:33:24 +0100
Subject: [PATCH 1/4] [mlir][vector] Add tests for
 `TransferWritePermutationLowering`

Adds more tests to "vector-transfer-permutation-lowering.mlir",
specifically for the `TransferWritePermutationLowering` pattern; such
tests appear to be missing at the moment.

The following edge cases are covered:
  * plain fixed-width (supported)
  * scalable vectors with a mask operand (supported)
  * plain fixed-width, masked via `vector.mask` (not supported)
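Here is a minimal Python model of the index arithmetic behind `TransferWritePermutationLowering`. It assumes the pattern does three things. First, it compresses unused dims out of the permutation map. Second, it inverts the result to get the `vector.transpose` indices. Third, it writes with a minor identity map. This is my reading of how MLIR's `compressUnusedDims` and `inversePermutation` helpers are used there. The sketch makes some simplifying choices:
  * a map is encoded as the list of dim positions of its results;
  * broadcast (constant-0) results are not modelled;
  * all function names are hypothetical.

```python
from typing import List, Tuple


def compress_unused_dims(results: List[int]) -> List[int]:
    """Renumber the dims used by the map results densely, keeping their order.

    E.g. (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4) becomes (d0, d1, d2) -> (d2, d0, d1).
    """
    used = sorted(set(results))
    remap = {dim: pos for pos, dim in enumerate(used)}
    return [remap[dim] for dim in results]


def inverse_permutation(perm: List[int]) -> List[int]:
    """Return inv such that inv[perm[i]] == i."""
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv


def lower_write_permutation(
    num_dims: int, results: List[int], vec_shape: List[int]
) -> Tuple[List[int], List[int], List[int]]:
    """Model transfer_write(permuted map) -> transpose + transfer_write(minor identity).

    Hypothetical helper; returns (transpose_indices, transposed_shape, new_map_results).
    """
    rank = len(results)
    minor_id = list(range(num_dims - rank, num_dims))
    if rank != len(vec_shape):
        raise ValueError("expected one map result per vector dimension")
    if results == minor_id:
        raise ValueError("already a minor identity: nothing to rewrite")
    if sorted(results) != minor_id:
        raise ValueError("not a permutation of a minor identity")
    comp = compress_unused_dims(results)
    # vector.transpose semantics: result dim j is input dim indices[j].
    indices = inverse_permutation(comp)
    transposed_shape = [vec_shape[i] for i in indices]
    # After the transpose, vector dim j is written to memref dim minor_id[j].
    return indices, transposed_shape, minor_id
```

Take the first new test: `(d0, d1, d2, d3) -> (d3, d2)` on `vector<4x8xi16>`. The model gives transpose indices `[1, 0]`, a `vector<8x4xi16>` operand and the map `(d0, d1, d2, d3) -> (d2, d3)`. This agrees with the CHECK lines.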

This is part of a larger effort to make sure that all key cases for
patterns under `populateVectorTransferPermutationMapLoweringPatterns`
(*) are tested. I also want to make sure that tests use consistent
function and variable names.

(*) I.e. `transform.apply_patterns.vector.transfer_permutation_patterns` in
Transform Dialect (TD) parlance.
---
 .../vector-transfer-permutation-lowering.mlir | 83 +++++++++++++++++--
 1 file changed, 77 insertions(+), 6 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 0cd134717b1a0..ac5041d13f893 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -1,14 +1,81 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
 ///----------------------------------------------------------------------------------------
-/// vector.transfer_write
+/// vector.transfer_write -> vector.transpose + vector.transfer_read
 ///----------------------------------------------------------------------------------------
-/// Input: 
-///   * vector.transfer_write op with a map which _is not_ the permutation of a
-///     minor identity
+/// Input:
+///   * vector.transfer_write op with a permutation that under a transpose
+///     _would be_ a permutation of a minor identity
 /// Output:
-///   * vector.broadcast + vector.transfer_write with a map which _is_ the permutation of a
+///   * vector.transpose + vector.transfer_write with a map which _is_ a
+///     permutation of a minor identity
+
+// CHECK-LABEL:   func.func @xfer_write_perm_minor_id_with_transpose(
+// CHECK-SAME:       %[[ARG_0:.*]]: vector<4x8xi16>,
+// CHECK-SAME:       %[[MEM:.*]]: memref<2x2x8x4xi16>) {
+// CHECK:           %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
+// CHECK:           vector.transfer_write %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
+func.func @xfer_write_perm_minor_id_with_transpose(
+    %arg0: vector<4x8xi16>,
+    %mem: memref<2x2x8x4xi16>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
+    in_bounds = [true, true],
+    permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+  } : vector<4x8xi16>, memref<2x2x8x4xi16>
+
+  return
+}
+
+// CHECK-LABEL:   func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable(
+// CHECK-SAME:      %[[ARG_0:.*]]: vector<4x[8]xi16>,
+// CHECK-SAME:      %[[MEM:.*]]: memref<2x2x?x4xi16>,
+// CHECK-SAME:      %[[MASK:.*]]: vector<[8]x4xi1>) {
+// CHECK:           %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16>
+// CHECK:           vector.transfer_write %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16>
+func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable(
+    %arg0: vector<4x[8]xi16>,
+    %mem: memref<2x2x?x4xi16>,
+    %mask: vector<[8]x4xi1>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0], %mask {
+    in_bounds = [true, true],
+    permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+  } : vector<4x[8]xi16>, memref<2x2x?x4xi16>
+
+  return
+}
+
+// Masked version is not supported
+// CHECK-LABEL:   func.func @xfer_write_perm_minor_id_with_transpose_masked
+// CHECK-NOT: vector.transpose
+func.func @xfer_write_perm_minor_id_with_transpose_masked(
+    %arg0: vector<4x8xi16>,
+    %mem: memref<2x2x8x4xi16>,
+    %mask: vector<8x4xi1>) {
+
+  %c0 = arith.constant 0 : index
+  vector.mask %mask {
+    vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
+    in_bounds = [true, true],
+    permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+    } : vector<4x8xi16>, memref<2x2x8x4xi16>
+  } : vector<8x4xi1>
+
+  return
+}
+
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_read
+///----------------------------------------------------------------------------------------
+/// Input:
+///   * vector.transfer_write op with a map which _is not_ a permutation of a
 ///     minor identity
+/// Output:
+///   * vector.broadcast + vector.transpose + vector.transfer_write with a map
+///     which _is_ a permutation of a minor identity
 
 // CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
 //       CHECK:   %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
@@ -94,7 +161,7 @@ func.func @masked_non_permutation_xfer_write_fixed_width(
 ///----------------------------------------------------------------------------------------
 /// vector.transfer_read
 ///----------------------------------------------------------------------------------------
-/// Input: 
+/// Input:
 ///   * vector.transfer_read op with a permutation map
 /// Output:
 ///   * vector.transfer_read with a permutation map composed of leading zeros followed by a minor identiy +
@@ -190,6 +257,10 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+/// TODO: Review and categorize
 
 //       CHECK:   #[[MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
 //       CHECK:   func.func @transfer_read_reduce_rank_scalable(

>From 42b374591583254d39343b9604b0ceead9b2ee04 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 14 Jun 2024 16:12:27 +0100
Subject: [PATCH 2/4] !fixup [mlir][vector] Add tests for
 `TransferWritePermutationLowering`

* Update test names
* Add CHECK-NOT lines for the permutation map, which shouldn't be present
  after the lowering
* Refine comments
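The new CHECK-NOT lines are meant to confirm that the lowered write carries a minor identity map. As far as I know, the transfer-op printer omits `permutation_map` in that case because it is the default. The test headers distinguish two properties: a minor identity, and a permutation of one. A small Python model of both follows. It makes a few simplifying choices:
  * a map is encoded as the list of dim positions of its results;
  * constant/broadcast results are not modelled;
  * the function names are hypothetical, not MLIR's C++ API (e.g. `AffineMap::isMinorIdentity`).

```python
from typing import List


def is_minor_identity(num_dims: int, results: List[int]) -> bool:
    """True for maps like (d0, d1, d2, d3) -> (d2, d3): the trailing dims, in order."""
    rank = len(results)
    return rank <= num_dims and results == list(range(num_dims - rank, num_dims))


def is_permutation_of_minor_identity(num_dims: int, results: List[int]) -> bool:
    """True for maps like (d0, d1, d2, d3) -> (d3, d2): the trailing dims, in any order."""
    rank = len(results)
    return rank <= num_dims and sorted(results) == list(range(num_dims - rank, num_dims))
```

In this model, the input map of the new tests, `(d0, d1, d2, d3) -> (d3, d2)`, is a permutation of a minor identity but not a minor identity itself. The rewritten map `(d0, d1, d2, d3) -> (d2, d3)` is a minor identity.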
---
 .../vector-transfer-permutation-lowering.mlir | 31 +++++++++++--------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index ac5041d13f893..c038baae72e78 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -2,20 +2,23 @@
 
 ///----------------------------------------------------------------------------------------
 /// vector.transfer_write -> vector.transpose + vector.transfer_read
+/// [Pattern: TransferWritePermutationLowering]
 ///----------------------------------------------------------------------------------------
 /// Input:
 ///   * vector.transfer_write op with a permutation that under a transpose
-///     _would be_ a permutation of a minor identity
+///     _would be_ a minor identity permutation map
 /// Output:
-///   * vector.transpose + vector.transfer_write with a map which _is_ a
-///     permutation of a minor identity
+///   * vector.transpose + vector.transfer_write with a permutation map which
+///     _is_ a minor identity
 
-// CHECK-LABEL:   func.func @xfer_write_perm_minor_id_with_transpose(
+// CHECK-LABEL:   func.func @xfer_write_transposing_permutation_map
 // CHECK-SAME:       %[[ARG_0:.*]]: vector<4x8xi16>,
 // CHECK-SAME:       %[[MEM:.*]]: memref<2x2x8x4xi16>) {
 // CHECK:           %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
-// CHECK:           vector.transfer_write %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
-func.func @xfer_write_perm_minor_id_with_transpose(
+// CHECK:           vector.transfer_write
+// CHECK-NOT:       permutation_map
+// CHECK-SAME:      %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
+func.func @xfer_write_transposing_permutation_map
     %arg0: vector<4x8xi16>,
     %mem: memref<2x2x8x4xi16>) {
 
@@ -28,13 +31,15 @@ func.func @xfer_write_perm_minor_id_with_transpose(
   return
 }
 
-// CHECK-LABEL:   func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable(
+// CHECK-LABEL:   func.func @xfer_write_transposing_permutation_map_with_mask_scalable
 // CHECK-SAME:      %[[ARG_0:.*]]: vector<4x[8]xi16>,
 // CHECK-SAME:      %[[MEM:.*]]: memref<2x2x?x4xi16>,
 // CHECK-SAME:      %[[MASK:.*]]: vector<[8]x4xi1>) {
 // CHECK:           %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x[8]xi16> to vector<[8]x4xi16>
-// CHECK:           vector.transfer_write %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16>
-func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable(
+// CHECK:           vector.transfer_write
+// CHECK-NOT:       permutation_map
+// CHECK-SAME:      %[[TR]], %[[MEM]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<[8]x4xi16>, memref<2x2x?x4xi16>
+func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
     %arg0: vector<4x[8]xi16>,
     %mem: memref<2x2x?x4xi16>,
     %mask: vector<[8]x4xi1>) {
@@ -49,9 +54,9 @@ func.func @xfer_write_perm_minor_id_with_transpose_with_mask_scalable(
 }
 
 // Masked version is not supported
-// CHECK-LABEL:   func.func @xfer_write_perm_minor_id_with_transpose_masked
+// CHECK-LABEL:   func.func @xfer_write_transposing_permutation_map_with_transpose_masked
 // CHECK-NOT: vector.transpose
-func.func @xfer_write_perm_minor_id_with_transpose_masked(
+func.func @xfer_write_transposing_permutation_map_with_transpose_masked(
     %arg0: vector<4x8xi16>,
     %mem: memref<2x2x8x4xi16>,
     %mask: vector<8x4xi1>) {
@@ -59,8 +64,8 @@ func.func @xfer_write_perm_minor_id_with_transpose_masked(
   %c0 = arith.constant 0 : index
   vector.mask %mask {
     vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
-    in_bounds = [true, true],
-    permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+      in_bounds = [true, true],
+      permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
     } : vector<4x8xi16>, memref<2x2x8x4xi16>
   } : vector<8x4xi1>
 

>From c0752d00f6fb742ee3e338fd5c7efbf598235a15 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 17 Jun 2024 16:47:16 +0100
Subject: [PATCH 3/4] fixup! !fixup [mlir][vector] Add tests for
 `TransferWritePermutationLowering`

Add missing `(` to a function signature
---
 .../Dialect/Vector/vector-transfer-permutation-lowering.mlir    | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index c038baae72e78..2682b08dee117 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -18,7 +18,7 @@
 // CHECK:           vector.transfer_write
 // CHECK-NOT:       permutation_map
 // CHECK-SAME:      %[[TR]], %[[MEM]]{{.*}} {in_bounds = [true, true]} : vector<8x4xi16>, memref<2x2x8x4xi16>
-func.func @xfer_write_transposing_permutation_map
+func.func @xfer_write_transposing_permutation_map(
     %arg0: vector<4x8xi16>,
     %mem: memref<2x2x8x4xi16>) {
 

>From a05c18787fee8d081e81e46df214e49b5ab93424 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 18 Jun 2024 10:09:31 +0100
Subject: [PATCH 4/4] fixup! fixup! !fixup [mlir][vector] Add tests for
 `TransferWritePermutationLowering`

Refine comments and a test name
---
 .../Vector/vector-transfer-permutation-lowering.mlir     | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 2682b08dee117..35418b38df9b2 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
 
 ///----------------------------------------------------------------------------------------
-/// vector.transfer_write -> vector.transpose + vector.transfer_read
+/// vector.transfer_write -> vector.transpose + vector.transfer_write
 /// [Pattern: TransferWritePermutationLowering]
 ///----------------------------------------------------------------------------------------
 /// Input:
@@ -54,9 +54,9 @@ func.func @xfer_write_transposing_permutation_map_with_mask_scalable(
 }
 
 // Masked version is not supported
-// CHECK-LABEL:   func.func @xfer_write_transposing_permutation_map_with_transpose_masked
+// CHECK-LABEL:   func.func @xfer_write_transposing_permutation_map_masked
 // CHECK-NOT: vector.transpose
-func.func @xfer_write_transposing_permutation_map_with_transpose_masked(
+func.func @xfer_write_transposing_permutation_map_masked(
     %arg0: vector<4x8xi16>,
     %mem: memref<2x2x8x4xi16>,
     %mask: vector<8x4xi1>) {
@@ -73,7 +73,8 @@ func.func @xfer_write_transposing_permutation_map_with_transpose_masked(
 }
 
 ///----------------------------------------------------------------------------------------
-/// vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_read
+/// vector.transfer_write -> vector.broadcast + vector.transpose + vector.transfer_write
+/// [Patterns: TransferWriteNonPermutationLowering + TransferWritePermutationLowering]
 ///----------------------------------------------------------------------------------------
 /// Input:
 ///   * vector.transfer_write op with a map which _is not_ a permutation of a
