[Mlir-commits] [mlir] 1c85c71 - [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n) (#95744)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 21 04:48:29 PDT 2024
Author: Andrzej Warzyński
Date: 2024-06-21T12:48:26+01:00
New Revision: 1c85c711aadb65943f5187524274fc96d1151b02
URL: https://github.com/llvm/llvm-project/commit/1c85c711aadb65943f5187524274fc96d1151b02
DIFF: https://github.com/llvm/llvm-project/commit/1c85c711aadb65943f5187524274fc96d1151b02.diff
LOG: [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n) (#95744)
The main goal of this and subsequent PRs is to unify and categorize
tests in:
* vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), to remove duplicates, and to add tests for scalable
vectors.
Below are the main contributions of this PR:
1. Two tests duplicated
`@transfer_{read|write}_dims_mismatch_non_contiguous_slice`:
* `@transfer_{read|write}_dims_mismatch_non_contiguous` and
* `@transfer_read_flattenable_negative`.
These duplicate tests are removed (the original test,
`@transfer_{read|write}_dims_mismatch_non_contiguous_slice`, is preserved).
2. `@transfer_read_flattenable_negative2` is replaced with
two tests with more descriptive names:
* `@transfer_read_non_contiguous_src` (for `xfer_read`) and
* `@transfer_write_non_contiguous_src` (for `xfer_write`)
Added:
Modified:
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Removed:
################################################################################
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 40a8b7e5e0737..3a5041fca53fc 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -131,25 +131,6 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// -----
-func.func @transfer_read_dims_mismatch_non_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
-
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
- return %v : vector<2x1x2x2xi8>
-}
-
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
-// CHECK-128B-NOT: memref.collapse_shape
-
-// -----
-
// The input memref has a dynamic trailing shape and hence is not flattened.
// TODO: This case could be supported via memref.dim
@@ -214,6 +195,28 @@ func.func @transfer_read_0d(
// -----
+// Strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_read_non_contiguous_src(
+ %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
///----------------------------------------------------------------------------------------
/// vector.transfer_write
/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
@@ -342,25 +345,6 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
// -----
-func.func @transfer_write_dims_mismatch_non_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
- %vec : vector<2x1x2x2xi8>) {
-
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
- vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
- return
-}
-
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
-// CHECK-128B-NOT: memref.collapse_shape
-
-// -----
-
// The input memref has a dynamic trailing shape and hence is not flattened.
// TODO: This case could be supported via memref.dim
@@ -427,6 +411,28 @@ func.func @transfer_write_0d(
// -----
+// The strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_write_non_contiguous_src(
+ %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>,
+ %vec : vector<5x4x3x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
+ vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_non_contiguous_src
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_write_non_contiguous_src
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
///----------------------------------------------------------------------------------------
/// TODO: Categorize + re-format
///----------------------------------------------------------------------------------------
@@ -478,40 +484,6 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
// -----
-func.func @transfer_read_flattenable_negative(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8>
- return %v : vector<2x2x2x2xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_negative
-// CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
-// CHECK-128B-NOT: memref.collapse_shape
-
-// -----
-
-func.func @transfer_read_flattenable_negative2(
- %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
- return %v : vector<5x4x3x2xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_negative2
-// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
-// CHECK-128B-NOT: memref.collapse_shape
-
-// -----
-
func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
%add = arith.addi %arg0, %arg0 : vector<1x8xi32>
return %add : vector<1x8xi32>
More information about the Mlir-commits
mailing list