[Mlir-commits] [mlir] 85e7428 - [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (3/n) (#95745)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 22 03:11:55 PDT 2024
Author: Andrzej Warzyński
Date: 2024-07-22T11:11:52+01:00
New Revision: 85e74285626191108a78e797179374219b3a67d4
URL: https://github.com/llvm/llvm-project/commit/85e74285626191108a78e797179374219b3a67d4
DIFF: https://github.com/llvm/llvm-project/commit/85e74285626191108a78e797179374219b3a67d4.diff
LOG: [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (3/n) (#95745)
The main goal of this and subsequent PRs is to unify and categorize
tests in:
* vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), to remove duplicates and to add tests for scalable
vectors.
The main contributions of this PR:
1. For consistency with other tests,
`@transfer_read_flattenable_with_dynamic_dims_and_indices` is renamed
as `@transfer_read_leading_dynamic_dims`. It is also moved near other
tests for `xfer_read`, and variable names are updated to match other
`xfer_read` tests.
2. `@transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes`
is renamed as `@negative_transfer_read_dynamic_dim_to_flatten` to
better highlight that it's a negative test and to contrast it with
`@transfer_read_leading_dynamic_dims` (and to emphasise the
difference between the two).
3. Similar changes for tests for `xfer_write`.
4. Make sure that we consistently use `%idx_N` (as opposed to `%idxN`).
Follow-up for #95743 and #95744
Added:
Modified:
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Removed:
################################################################################
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 303f841e8a828..621baef82319f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -110,12 +110,12 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
%arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
- %idx0 : index,
- %idx1 : index) -> vector<2x2xf32> {
+ %idx_1 : index,
+ %idx_2 : index) -> vector<2x2xf32> {
%c0 = arith.constant 0 : index
%cst_1 = arith.constant 0.000000e+00 : f32
- %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} :
+ %8 = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %cst_1 {in_bounds = [true, true]} :
memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
return %8 : vector<2x2xf32>
}
@@ -123,7 +123,8 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]]
+// CHECK-SAME: : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
@@ -131,10 +132,42 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// -----
-// The input memref has a dynamic trailing shape and hence is not flattened.
-// TODO: This case could be supported via memref.dim
+// The leading dynamic shapes don't affect whether this example is flattenable
+// or not. Indeed, those dynamic shapes are not candidates for flattening anyway.
-func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+func.func @transfer_read_leading_dynamic_dims(
+ %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+ %idx_1 : index,
+ %idx_2 : index) -> vector<8x4xi8> {
+
+ %c0_i8 = arith.constant 0 : i8
+ %c0 = arith.constant 0 : index
+ %result = vector.transfer_read %arg[%idx_1, %idx_2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} :
+ memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
+ return %result : vector<8x4xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME: {in_bounds = [true]}
+// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK: return %[[VEC2D]] : vector<8x4xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
+// One of the dims to be flattened is dynamic - not supported ATM.
+
+func.func @negative_transfer_read_dynamic_dim_to_flatten(
%idx_1: index,
%idx_2: index,
%m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -146,11 +179,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
return %v : vector<1x2x6xi32>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----
@@ -326,11 +359,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
%value : vector<2x2xf32>,
%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
- %idx0 : index,
- %idx1 : index) {
+ %idx_1 : index,
+ %idx_2 : index) {
%c0 = arith.constant 0 : index
- vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
+ vector.transfer_write %value, %subview[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
return
}
@@ -345,10 +378,40 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
// -----
-// The input memref has a dynamic trailing shape and hence is not flattened.
-// TODO: This case could be supported via memref.dim
+// The leading dynamic shapes don't affect whether this example is flattenable
+// or not. Indeed, those dynamic shapes are not candidates for flattening anyway.
+
+func.func @transfer_write_leading_dynamic_dims(
+ %vec : vector<8x4xi8>,
+ %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+ %idx_1 : index,
+ %idx_2 : index) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg[%idx_1, %idx_2, %c0, %c0] {in_bounds = [true, true]} :
+ vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
+// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME: {in_bounds = [true]}
+// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+
+// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
+// CHECK-128B: memref.collapse_shape
-func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+// -----
+
+// One of the dims to be flattened is dynamic - not supported ATM.
+
+func.func @negative_transfer_write_dynamic_to_flatten(
%idx_1: index,
%idx_2: index,
%vec : vector<1x2x6xi32>,
@@ -361,11 +424,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
return
}
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----
@@ -434,56 +497,10 @@ func.func @transfer_write_non_contiguous_src(
// -----
///----------------------------------------------------------------------------------------
-/// TODO: Categorize + re-format
+/// [Pattern: DropUnitDimFromElementwiseOps]
+/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
///----------------------------------------------------------------------------------------
-func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
- %c0_i8 = arith.constant 0 : i8
- %c0 = arith.constant 0 : index
- %result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
- return %result : vector<8x4xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
-// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
-// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
-// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
-// CHECK-SAME: {in_bounds = [true]}
-// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
-// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
-// CHECK: return %[[VEC2D]] : vector<8x4xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
-// CHECK-128B: memref.collapse_shape
-
-// -----
-
-func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
- return
-}
-
-// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
-// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
-// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
-// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
-// CHECK-SAME: {in_bounds = [true]}
-// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
-
-// CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
-// CHECK-128B: memref.collapse_shape
-
-// -----
-
func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
%add = arith.addi %arg0, %arg0 : vector<1x8xi32>
return %add : vector<1x8xi32>
More information about the Mlir-commits
mailing list