[Mlir-commits] [mlir] [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n) (PR #95744)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Jun 21 03:03:54 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/95744

>From 08c38a85510e6a7f826486ed4c759690f1ac2a22 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 15 Jun 2024 20:27:07 +0100
Subject: [PATCH] [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc)
 (2/n)

The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates, and add tests for scalable
vectors.

The main contributions of this PR:
1. `@transfer_{read|write}_dims_mismatch_non_contiguous` and
   `@transfer_read_flattenable_negative` duplicated
   `@transfer_{read|write}_dims_mismatch_non_contiguous_slice` (the
   latter only duplicated the xfer_read variant). All three tests are
   deleted (`@transfer_{read|write}_dims_mismatch_non_contiguous_slice`
   is preserved).

2. `@transfer_read_flattenable_negative2` is replaced with
   `@transfer_read_non_contiguous_src` and
   `@transfer_write_non_contiguous_src` (i.e. dedicated tests for
   xfer_read and xfer_write, with more descriptive func names).

Depends on https://github.com/llvm/llvm-project/pull/95743.

**Only review the top commit.**
---
 .../Vector/vector-transfer-flatten.mlir       | 116 +++++++-----------
 1 file changed, 44 insertions(+), 72 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 40a8b7e5e0737..3a5041fca53fc 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -131,25 +131,6 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
-
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0 : i8
-  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
-  return %v : vector<2x1x2x2xi8>
-}
-
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
-//   CHECK-128B-NOT:   memref.collapse_shape
-
-// -----
-
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
@@ -214,6 +195,28 @@ func.func @transfer_read_0d(
 
 // -----
 
+// Strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_read_non_contiguous_src(
+    %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// vector.transfer_write
 /// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
@@ -342,25 +345,6 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
 
 // -----
 
-func.func @transfer_write_dims_mismatch_non_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
-    %vec : vector<2x1x2x2xi8>) {
-
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
-    vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
-  return
-}
-
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
-//   CHECK-128B-NOT:   memref.collapse_shape
-
-// -----
-
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
@@ -427,6 +411,28 @@ func.func @transfer_write_0d(
 
 // -----
 
+// The strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_write_non_contiguous_src(
+    %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>,
+    %vec : vector<5x4x3x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
+   vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func.func @transfer_write_non_contiguous_src
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_write_non_contiguous_src
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// TODO: Categorize + re-format
 ///----------------------------------------------------------------------------------------
@@ -478,40 +484,6 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
 
 // -----
 
-func.func @transfer_read_flattenable_negative(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8>
-    return %v : vector<2x2x2x2xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_negative
-//       CHECK:   vector.transfer_read {{.*}} vector<2x2x2x2xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
-//   CHECK-128B-NOT:   memref.collapse_shape
-
-// -----
-
-func.func @transfer_read_flattenable_negative2(
-      %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_negative2
-//       CHECK:   vector.transfer_read {{.*}} vector<5x4x3x2xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
-//   CHECK-128B-NOT:   memref.collapse_shape
-
-// -----
-
 func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
    %add = arith.addi %arg0, %arg0 : vector<1x8xi32>
    return %add : vector<1x8xi32>



More information about the Mlir-commits mailing list