[Mlir-commits] [mlir] [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (3/n) (PR #95745)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Jun 21 05:04:15 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/95745

>From 9062edd6fcb7436c59af7fc3d49c6c7c88e21c79 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 16 Jun 2024 13:44:12 +0100
Subject: [PATCH 1/2] [mlir][vector] Refactor vector-transfer-flatten.mlir
 (nfc) (3/n)

The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), to remove duplicates, and to add tests for scalable
vectors.

The main contributions of this PR:

1. Refactor `@transfer_read_flattenable_with_dynamic_dims_and_indices`,
   i.e. rename it to `@transfer_read_leading_dynamic_dims`, move it near
   other tests for xfer_read, unify variable names to match other
   xfer_read tests, and highlight what makes this a positive test to
   better contrast it with
   `@transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim`

2. Similar changes for
   `@transfer_write_flattenable_with_dynamic_dims_and_indices`
   (renamed to `@transfer_write_leading_dynamic_dims`).

Depends on https://github.com/llvm/llvm-project/pull/95743 and https://github.com/llvm/llvm-project/pull/95744

**Only review the top commit**
---
 .../Vector/vector-transfer-flatten.mlir       | 124 ++++++++++--------
 1 file changed, 70 insertions(+), 54 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 3a5041fca53fc..477540fa71dc2 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -131,10 +131,42 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 
 // -----
 
+// The leading dynamic dims don't affect whether this example is flattenable
+// or not, as those dims are not candidates for flattening anyway.
+
+func.func @transfer_read_leading_dynamic_dims(
+    %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+    %idx_1 : index,
+    %idx_2 : index) -> vector<8x4xi8> {
+
+  %c0_i8 = arith.constant 0 : i8
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %arg[%idx_1, %idx_2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
+  return %result : vector<8x4xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
+// CHECK-SAME:    %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK:       %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:       %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME:    [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME:    {in_bounds = [true]}
+// CHECK-SAME:    : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK:       %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK:       return %[[VEC2D]] : vector<8x4xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
-func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
     %idx_1: index,
     %idx_2: index,
     %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -146,11 +178,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
   return %v : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
@@ -345,10 +377,40 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
 
 // -----
 
+// The leading dynamic dims don't affect whether this example is flattenable
+// or not, as those dims are not candidates for flattening anyway.
+
+func.func @transfer_write_leading_dynamic_dims(
+    %vec : vector<8x4xi8>,
+    %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+    %idx_1 : index,
+    %idx_2 : index) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%idx_1, %idx_2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
+// CHECK-SAME:    %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:       %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
+// CHECK:       vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-SAME:    [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME:    {in_bounds = [true]}
+// CHECK-SAME:    : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+
+// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
-func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+func.func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
     %idx_1: index,
     %idx_2: index,
     %vec : vector<1x2x6xi32>,
@@ -361,11 +423,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
   return
 }
 
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
@@ -434,56 +496,10 @@ func.func @transfer_write_non_contiguous_src(
 // -----
 
 ///----------------------------------------------------------------------------------------
-/// TODO: Categorize + re-format
+/// [Pattern: DropUnitDimFromElementwiseOps]
+/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
 ///----------------------------------------------------------------------------------------
 
-func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
-    %c0_i8 = arith.constant 0 : i8
-    %c0 = arith.constant 0 : index
-    %result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
-    return %result : vector<8x4xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
-// CHECK-SAME:    %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
-// CHECK:       %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK:       %[[C0:.+]] = arith.constant 0 : index
-// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
-// CHECK:       %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME:    [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
-// CHECK-SAME:    {in_bounds = [true]}
-// CHECK-SAME:    : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
-// CHECK:       %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
-// CHECK:       return %[[VEC2D]] : vector<8x4xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
-//       CHECK-128B:   memref.collapse_shape
-
-// -----
-
-func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) {
-    %c0 = arith.constant 0 : index
-    vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
-    return
-}
-
-// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
-// CHECK-SAME:    %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
-// CHECK:       %[[C0:.+]] = arith.constant 0 : index
-// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
-// CHECK:       %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
-// CHECK:       vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME:    [%[[ARG2]], %[[ARG3]], %[[C0]]]
-// CHECK-SAME:    {in_bounds = [true]}
-// CHECK-SAME:    : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
-
-// CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
-//       CHECK-128B:   memref.collapse_shape
-
-// -----
-
 func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
    %add = arith.addi %arg0, %arg0 : vector<1x8xi32>
    return %add : vector<1x8xi32>

>From ac89cdd3b0eb5c5ed158f478b908bdd198f9011d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 21 Jun 2024 13:04:00 +0100
Subject: [PATCH 2/2] fixup! [mlir][vector] Refactor
 vector-transfer-flatten.mlir (nfc) (3/n)

Minor updates - addressing PR comments
---
 .../Dialect/Vector/vector-transfer-flatten.mlir    | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 477540fa71dc2..8954c24a8b558 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -110,12 +110,12 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 
 func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
     %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
-    %idx0 : index,
-    %idx1 : index) -> vector<2x2xf32> {
+    %idx_1 : index,
+    %idx_2 : index) -> vector<2x2xf32> {
 
   %c0 = arith.constant 0 : index
   %cst_1 = arith.constant 0.000000e+00 : f32
-  %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} :
+  %8 = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %cst_1 {in_bounds = [true, true]} :
     memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
   return %8 : vector<2x2xf32>
 }
@@ -358,11 +358,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
 func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
     %value : vector<2x2xf32>,
     %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
-    %idx0 : index,
-    %idx1 : index) {
+    %idx_1 : index,
+    %idx_2 : index) {
 
   %c0 = arith.constant 0 : index
-  vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
+  vector.transfer_write %value, %subview[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
   return
 }
 
@@ -392,7 +392,7 @@ func.func @transfer_write_leading_dynamic_dims(
 }
 
 // CHECK-LABEL: func @transfer_write_leading_dynamic_dims
-// CHECK-SAME:    %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
+// CHECK-SAME:  %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
 // CHECK:       %[[C0:.+]] = arith.constant 0 : index
 // CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>



More information about the Mlir-commits mailing list