[Mlir-commits] [mlir] [mlir][vector] Don't treat memrefs with empty stride as non-contiguous (PR #76848)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 3 11:20:06 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

As per the docs [1]:

```
In absence of an explicit layout, a memref is considered to have a
multi-dimensional identity affine map layout.
```

This patch makes sure that MemRefs with no strides (i.e. no explicit
layout) are treated as contiguous when checking whether a particular
vector is a contiguous slice of the given MemRef.

[1] https://mlir.llvm.org/docs/Dialects/Builtin/#layout

Follow-up for #76428.


---
Full diff: https://github.com/llvm/llvm-project/pull/76848.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+15-11) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+33-15) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index c1c0f5483a6af5..e9eb65aef6a22e 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -270,20 +270,24 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   if (ShapedType::isDynamicShape(memrefShape))
     return false;
 
-  // Cond 1: A contiguous memref will always have a unit trailing stride.
-  if (strides.empty() || strides.back() != 1)
-    return false;
+  // Cond 1: Check whether `memrefType` is contiguous.
+  if (!strides.empty()) {
+    // Cond 1.1: A contiguous memref will always have a unit trailing stride.
+    if (strides.back() != 1)
+      return false;
 
-  // Cond 2: Strides of a contiguous memref have to match the flattened dims.
-  strides = strides.drop_back(1);
-  SmallVector<int64_t> flattenedDims;
-  for (size_t i = 1; i < memrefShape.size(); i++)
-    flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
+    // Cond 1.2: Strides of a contiguous memref have to match the flattened
+    // dims.
+    strides = strides.drop_back(1);
+    SmallVector<int64_t> flattenedDims;
+    for (size_t i = 1; i < memrefShape.size(); i++)
+      flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
 
-  if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
-    return false;
+    if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
+      return false;
+  }
 
-  // Cond 3: Compare the dims of `vectorType` against `memrefType` (in reverse).
+  // Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
   // In the most basic case, all dims will match.
   auto firstNonMatchingDim =
       std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae457ea81ec5b1..79e2b97148f3f4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -18,6 +18,24 @@ func.func @transfer_read_dims_match_contiguous(
 
 // -----
 
+func.func @transfer_read_dims_match_contiguous_empty_stride(
+    %arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+    return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
+// CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK:         return %[[VEC2D]]
+
+// -----
+
 // The shape of the memref and the vector don't match, but the vector is a
 // contiguous subset of the memref, so "flattenable".
 
@@ -114,6 +132,21 @@ func.func @transfer_read_dims_mismatch_non_contiguous(
 
 // -----
 
+func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+    return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
 func.func @transfer_write_dims_match_contiguous(
       %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
     %c0 = arith.constant 0 : index
@@ -356,18 +389,3 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // CHECK:           %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
 // CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
 // CHECK:           return %[[VAL_4]] : vector<8xi32>
-
-// -----
-
-// This test is to make sure there is no crash for empty stride.
-func.func @stride_empty_test(%1: memref<i16>) -> vector<32x256xi16> {
-  %c0_i16 = arith.constant 0 : i16
-  %3 = vector.transfer_read %1[], %c0_i16 {permutation_map = affine_map<() -> (0, 0)>} : memref<i16>, vector<32x256xi16>
-  return %3 : vector<32x256xi16>
-
-  // CHECK-LABEL: func.func @stride_empty_test
-  // CHECK: %[[VAL:.*]] = arith.constant 0 : i16
-  // CHECK: %[[RET:.*]] = vector.transfer_read {{.*}} vector<32x256xi16>
-  // CHECK: return %[[RET]]
-  // CHECK-NOT: empty()
-}

``````````

</details>


https://github.com/llvm/llvm-project/pull/76848


More information about the Mlir-commits mailing list