[Mlir-commits] [mlir] [mlir][vector] Don't treat memrefs with empty stride as non-contiguous (PR #76848)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 3 11:20:06 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
As per the docs [1]:
```
In absence of an explicit layout, a memref is considered to have a
multi-dimensional identity affine map layout.
```
This patch makes sure that MemRefs with an empty list of strides (i.e. no
explicit layout, and hence the identity layout) are treated as contiguous when
checking whether a particular vector is a contiguous slice of the given MemRef.
[1] https://mlir.llvm.org/docs/Dialects/Builtin/#layout
Follow-up for #76428.
---
Full diff: https://github.com/llvm/llvm-project/pull/76848.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+15-11)
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+33-15)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index c1c0f5483a6af5..e9eb65aef6a22e 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -270,20 +270,24 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
if (ShapedType::isDynamicShape(memrefShape))
return false;
- // Cond 1: A contiguous memref will always have a unit trailing stride.
- if (strides.empty() || strides.back() != 1)
- return false;
+ // Cond 1: Check whether `memrefType` is contiguous.
+ if (!strides.empty()) {
+ // Cond 1.1: A contiguous memref will always have a unit trailing stride.
+ if (strides.back() != 1)
+ return false;
- // Cond 2: Strides of a contiguous memref have to match the flattened dims.
- strides = strides.drop_back(1);
- SmallVector<int64_t> flattenedDims;
- for (size_t i = 1; i < memrefShape.size(); i++)
- flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
+ // Cond 1.2: Strides of a contiguous memref have to match the flattened
+ // dims.
+ strides = strides.drop_back(1);
+ SmallVector<int64_t> flattenedDims;
+ for (size_t i = 1; i < memrefShape.size(); i++)
+ flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
- if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
- return false;
+ if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
+ return false;
+ }
- // Cond 3: Compare the dims of `vectorType` against `memrefType` (in reverse).
+ // Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
// In the most basic case, all dims will match.
auto firstNonMatchingDim =
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae457ea81ec5b1..79e2b97148f3f4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -18,6 +18,24 @@ func.func @transfer_read_dims_match_contiguous(
// -----
+func.func @transfer_read_dims_match_contiguous_empty_stride(
+ %arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
+}
+// No explicit layout means identity layout, i.e. contiguous, so the read is flattened:
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK: %[[VEC4D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK: return %[[VEC4D]]
+
+// -----
+
// The shape of the memref and the vector don't match, but the vector is a
// contiguous subset of the memref, so "flattenable".
@@ -114,6 +132,21 @@ func.func @transfer_read_dims_mismatch_non_contiguous(
// -----
+func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+ %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+ return %v : vector<2x1x2x2xi8>
+}
+// A 2x1x2x2 read is not a contiguous slice of memref<5x4x3x2xi8> even with the identity layout: no flattening expected.
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
func.func @transfer_write_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
%c0 = arith.constant 0 : index
@@ -356,18 +389,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>
-
-// -----
-
-// This test is to make sure there is no crash for empty stride.
-func.func @stride_empty_test(%1: memref<i16>) -> vector<32x256xi16> {
- %c0_i16 = arith.constant 0 : i16
- %3 = vector.transfer_read %1[], %c0_i16 {permutation_map = affine_map<() -> (0, 0)>} : memref<i16>, vector<32x256xi16>
- return %3 : vector<32x256xi16>
-
- // CHECK-LABEL: func.func @stride_empty_test
- // CHECK: %[[VAL:.*]] = arith.constant 0 : i16
- // CHECK: %[[RET:.*]] = vector.transfer_read {{.*}} vector<32x256xi16>
- // CHECK: return %[[RET]]
- // CHECK-NOT: empty()
-}
``````````
</details>
https://github.com/llvm/llvm-project/pull/76848
More information about the Mlir-commits
mailing list