[Mlir-commits] [mlir] 34de7fd - [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (1/n) (#95743)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 21 02:55:49 PDT 2024
Author: Andrzej Warzyński
Date: 2024-06-21T10:55:45+01:00
New Revision: 34de7fd4284ce9f02c2ea902f8a8ce5fd256db3d
URL: https://github.com/llvm/llvm-project/commit/34de7fd4284ce9f02c2ea902f8a8ce5fd256db3d
DIFF: https://github.com/llvm/llvm-project/commit/34de7fd4284ce9f02c2ea902f8a8ce5fd256db3d.diff
LOG: [mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (1/n) (#95743)
The main goal of this and subsequent PRs is to unify and categorize
tests in:
* vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), to remove duplicates, and to add tests for scalable
vectors.
The main contributions of this PR:
* split tests that covered `xfer_read` + `xfer_write` into separate
tests (majority of the existing tests check _one_ xfer Op at a time),
* organise tests for `xfer_read` and `xfer_write` into separate
groups (separated by a large banner comment).
Note that all existing test cases are preserved and some new tests are
added. Deletions that you will see in `git diff` correspond to
`xfer_write` and `xfer_read` Ops being extracted to separate functions
(so that there's one xfer Op per function). In particular, the number of
test functions has grown from 26 to 30.
In addition, this PR unifies the tests so that:
* input variable names are consistent (e.g. make sure that the input
memref is always `arg`)
* CHECK lines use similar indentations
* two indentation levels are always used for function arguments and one
level for the function body.
Finally, changes in "VectorTransferOpTransforms.cpp" are merely meant to
unify comments and logic between
* `FlattenContiguousRowMajorTransferWritePattern` and
* `FlattenContiguousRowMajorTransferReadPattern`.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c131fde517f80..4c93d3841bf87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -568,6 +568,7 @@ namespace {
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
+///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
@@ -617,7 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
- dyn_cast<MemRefType>(collapsedSource.getType());
+ cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstDimToCollapse + 1);
@@ -658,6 +659,10 @@ class FlattenContiguousRowMajorTransferReadPattern
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
+///
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector to be written is smaller than the provided
+/// bitwidth.
class FlattenContiguousRowMajorTransferWritePattern
: public OpRewritePattern<vector::TransferWriteOp> {
public:
@@ -674,9 +679,12 @@ class FlattenContiguousRowMajorTransferWritePattern
VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferWriteOp.getSource();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+ // 0. Check pre-conditions
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
+ // If this is already 0D/1D, there's nothing to do.
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
@@ -688,7 +696,6 @@ class FlattenContiguousRowMajorTransferWritePattern
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
- int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
return failure();
@@ -697,10 +704,9 @@ class FlattenContiguousRowMajorTransferWritePattern
if (transferWriteOp.getMask())
return failure();
- SmallVector<Value> collapsedIndices =
- getCollapsedIndices(rewriter, loc, sourceType.getShape(),
- transferWriteOp.getIndices(), firstDimToCollapse);
+ int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+ // 1. Collapse the source memref
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
@@ -708,11 +714,20 @@ class FlattenContiguousRowMajorTransferWritePattern
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstDimToCollapse + 1);
+ // 2. Generate input args for a new vector.transfer_write that will write
+ // to the collapsed memref.
+ // 2.1. New dim exprs + affine map
SmallVector<AffineExpr, 1> dimExprs{
getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+ // 2.2 New indices
+ SmallVector<Value> collapsedIndices =
+ getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+ transferWriteOp.getIndices(), firstDimToCollapse);
+
+ // 3. Create new vector.transfer_write that writes to the collapsed memref
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
Value flatVector =
@@ -721,6 +736,9 @@ class FlattenContiguousRowMajorTransferWritePattern
rewriter.create<vector::TransferWriteOp>(
loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+ // 4. Replace the old transfer_write with the new one writing the
+ // collapsed shape
rewriter.eraseOp(transferWriteOp);
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 42bf7201daaa7..40a8b7e5e0737 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,17 +1,23 @@
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+/// [Pattern: FlattenContiguousRowMajorTransferReadPattern]
+///----------------------------------------------------------------------------------------
+
func.func @transfer_read_dims_match_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
- return %v : vector<5x4x3x2xi8>
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
}
// CHECK-LABEL: func @transfer_read_dims_match_contiguous
-// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
@@ -24,11 +30,12 @@ func.func @transfer_read_dims_match_contiguous(
func.func @transfer_read_dims_match_contiguous_empty_stride(
%arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
- return %v : vector<5x4x3x2xi8>
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
}
// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
@@ -47,16 +54,17 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
// contiguous subset of the memref, so "flattenable".
func.func @transfer_read_dims_mismatch_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
- return %v : vector<1x1x2x2xi8>
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+ return %v : vector<1x1x2x2xi8>
}
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
@@ -70,135 +78,160 @@ func.func @transfer_read_dims_mismatch_contiguous(
// -----
func.func @transfer_read_dims_mismatch_non_zero_indices(
- %idx_1: index,
- %idx_2: index,
- %m_in: memref<1x43x4x6xi32>,
- %m_out: memref<1x2x6xi32>) {
+ %idx_1: index,
+ %idx_2: index,
+ %arg: memref<1x43x4x6xi32>) -> vector<1x2x6xi32>{
+
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
- %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ %v = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x43x4x6xi32>, vector<1x2x6xi32>
- vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
- vector<1x2x6xi32>, memref<1x2x6xi32>
- return
+ return %v : vector<1x2x6xi32>
}
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME: %[[M_IN:.*]]: memref<1x43x4x6xi32>,
-// CHECK-SAME: %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK-SAME: %[[M_IN:.*]]: memref<1x43x4x6xi32>
// CHECK: %[[C_0:.*]] = arith.constant 0 : i32
// CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
-// CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
-// CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
// -----
+// Overall, the source memref is non-contiguous. However, the slice from which
+// the output vector is to be read _is_ contiguous. Hence the flattening works fine.
+
func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
- %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
- %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+ %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+ %idx0 : index,
+ %idx1 : index) -> vector<2x2xf32> {
+
%c0 = arith.constant 0 : index
%cst_1 = arith.constant 0.000000e+00 : f32
- %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+ %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} :
+ memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
return %8 : vector<2x2xf32>
}
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// CHECK-128B: memref.collapse_shape
// -----
+func.func @transfer_read_dims_mismatch_non_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
+ return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
// The input memref has a dynamic trailing shape and hence is not flattened.
// TODO: This case could be supported via memref.dim
func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
- %idx_1: index,
- %idx_2: index,
- %m_in: memref<1x?x4x6xi32>,
- %m_out: memref<1x2x6xi32>) {
+ %idx_1: index,
+ %idx_2: index,
+ %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
- %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ %v = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x?x4x6xi32>, vector<1x2x6xi32>
- vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
- vector<1x2x6xi32>, memref<1x2x6xi32>
- return
+ return %v : vector<1x2x6xi32>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME: %[[M_IN:.*]]: memref<1x?x4x6xi32>,
-// CHECK-SAME: %[[M_OUT:.*]]: memref<1x2x6xi32>) {
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
-// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
-// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
// CHECK-128B-NOT: memref.collapse_shape
// -----
-func.func @transfer_read_dims_mismatch_non_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
- return %v : vector<2x1x2x2xi8>
+// The vector to be read represents a _non-contiguous_ slice of the input
+// memref.
+
+func.func @transfer_read_dims_mismatch_non_contiguous_slice(
+ %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+ return %v : vector<2x1x2x2xi8>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_slice(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_slice(
// CHECK-128B-NOT: memref.collapse_shape
// -----
-func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
- %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
- return %v : vector<2x1x2x2xi8>
+func.func @transfer_read_0d(
+ %arg : memref<i8>) -> vector<i8> {
+
+ %cst = arith.constant 0 : i8
+ %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
+ return %0 : vector<i8>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-LABEL: func.func @transfer_read_0d
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-128B-LABEL: func @transfer_read_0d(
// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
// -----
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write
+/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
+///----------------------------------------------------------------------------------------
+
func.func @transfer_write_dims_match_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
- vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
- return
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+ %vec : vector<5x4x3x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
}
// CHECK-LABEL: func @transfer_write_dims_match_contiguous(
-// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
-// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
@@ -208,42 +241,101 @@ func.func @transfer_write_dims_match_contiguous(
// -----
+func.func @transfer_write_dims_match_contiguous_empty_stride(
+ %arg : memref<5x4x3x2xi8>,
+ %vec : vector<5x4x3x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
+ return
+}
+
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
func.func @transfer_write_dims_mismatch_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
- vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
- return
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+ %vec : vector<1x1x2x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
}
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
-// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
-// CHECK: return
-// CHECK: }
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
// CHECK-128B: memref.collapse_shape
// -----
+func.func @transfer_write_dims_mismatch_non_zero_indices(
+ %idx_1: index,
+ %idx_2: index,
+ %arg: memref<1x43x4x6xi32>,
+ %vec: vector<1x2x6xi32>) {
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ vector.transfer_write %vec, %arg[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
+ vector<1x2x6xi32>, memref<1x43x4x6xi32>
+ return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices(
+// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME: %[[ARG:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME: %[[VEC:.*]]: vector<1x2x6xi32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[ARG]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK: vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>
+
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
+// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
+// Overall, the destination memref is non-contiguous. However, the slice to
+// which the input vector is to be written _is_ contiguous. Hence the
+// flattening works fine.
+
func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
%value : vector<2x2xf32>,
%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
- %idx0 : index, %idx1 : index) {
+ %idx0 : index,
+ %idx1 : index) {
+
%c0 = arith.constant 0 : index
vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
return
}
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
-// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK-DAG: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
// CHECK-128B: memref.collapse_shape
@@ -251,11 +343,13 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
// -----
func.func @transfer_write_dims_mismatch_non_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
- vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
- return
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+ %vec : vector<2x1x2x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
}
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
@@ -267,37 +361,76 @@ func.func @transfer_write_dims_mismatch_non_contiguous(
// -----
-func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
- vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
- return
+// The input memref has a dynamic trailing shape and hence is not flattened.
+// TODO: This case could be supported via memref.dim
+
+func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+ %idx_1: index,
+ %idx_2: index,
+ %vec : vector<1x2x6xi32>,
+ %arg: memref<1x?x4x6xi32>) {
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ vector.transfer_write %vec, %arg[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
+ vector<1x2x6xi32>, memref<1x?x4x6xi32>
+ return
}
-// CHECK-LABEL: func.func @transfer_write_0d
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
-// CHECK-128B-LABEL: func @transfer_write_0d(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
// CHECK-128B-NOT: memref.collapse_shape
-// CHECK-128B-NOT: vector.shape_cast
// -----
-func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
- %cst = arith.constant 0 : i8
- %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
- return %0 : vector<i8>
+// The vector to be written represents a _non-contiguous_ slice of the output
+// memref.
+
+func.func @transfer_write_dims_mismatch_non_contiguous_slice(
+ %arg : memref<5x4x3x2xi8>,
+ %vec : vector<2x1x2x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
+ vector<2x1x2x2xi8>, memref<5x4x3x2xi8>
+ return
}
-// CHECK-LABEL: func.func @transfer_read_0d
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous_slice(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
-// CHECK-128B-LABEL: func @transfer_read_0d(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_slice(
+// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
+func.func @transfer_write_0d(
+ %arg : memref<i8>,
+ %vec : vector<i8>) {
+
+ vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_write_0d(
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast
// -----
+///----------------------------------------------------------------------------------------
+/// TODO: Categorize + re-format
+///----------------------------------------------------------------------------------------
+
func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
%c0_i8 = arith.constant 0 : i8
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list