[Mlir-commits] [mlir] [mlir][memref] `memref.subview`: Verify result strides with rank reductions (PR #80158)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 31 08:36:12 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This is a follow-up to #79865. Result strides are now also verified if the `memref.subview` op has rank reductions.

---
Full diff: https://github.com/llvm/llvm-project/pull/80158.diff


5 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+15-8) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+13-4) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+3-3) 
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+2-2) 
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+9) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f43217f6f27ae..a6624e9c8482c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2756,17 +2756,26 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
 }
 
 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
-/// static value).
-static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) {
+/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
+/// marked as dropped in `droppedDims`.
+static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
+                                  const llvm::SmallBitVector &droppedDims) {
+  assert(t1.getRank() == droppedDims.size() && "incorrect number of bits");
+  assert(t1.getRank() - t2.getRank() == droppedDims.count() &&
+         "incorrect number of dropped dims");
   int64_t t1Offset, t2Offset;
   SmallVector<int64_t> t1Strides, t2Strides;
   auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
   auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
   if (failed(res1) || failed(res2))
     return false;
-  for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides))
-    if (s1 != s2)
+  for (int64_t i = 0, j = 0; i < t1.getRank(); ++i) {
+    if (droppedDims[i])
+      continue;
+    if (t1Strides[i] != t2Strides[j])
       return false;
+    ++j;
+  }
   return true;
 }
 
@@ -2843,10 +2852,8 @@ LogicalResult SubViewOp::verify() {
     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                   *this, expectedType);
 
-  // Strides must match if there are no rank reductions.
-  // TODO: Verify strides when there are rank reductions. Strides are partially
-  // checked in `computeMemRefRankReductionMask`.
-  if (unusedDims->none() && !haveCompatibleStrides(expectedType, subViewType))
+  // Strides must match.
+  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
                                   *this, expectedType);
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index f6af0791ba756..96eb7cfd2db69 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -144,16 +144,25 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
   SmallVector<OpFoldResult> finalStrides;
   finalStrides.reserve(subRank);
 
+#ifndef NDEBUG
+  // Iteration variable for result dimensions of the subview op.
+  int64_t j = 0;
+#endif // NDEBUG
   for (unsigned i = 0; i < sourceRank; ++i) {
     if (droppedDims.test(i))
       continue;
 
     finalSizes.push_back(subSizes[i]);
     finalStrides.push_back(strides[i]);
-    // TODO: Assert that the computed stride matches the respective stride of
-    // the result type of the subview op (if both are static), once the verifier
-    // of memref.subview verfies result strides correctly for ops with rank
-    // reductions.
+#ifndef NDEBUG
+    // Assert that the computed stride matches the stride of the result type of
+    // the subview op (if both are static).
+    std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
+    if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
+      assert(*computedStride == resultStrides[j] &&
+             "mismatch between computed stride and result type stride");
+    ++j;
+#endif // NDEBUG
   }
   assert(finalSizes.size() == subRank &&
          "Should have populated all the values at this point");
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 993ef32edc9d4..a772a25da5738 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -62,13 +62,13 @@ func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
 // -----
 
 func.func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
-  %arg2 : index) -> memref<?x?xf32, strided<[?, 1], offset: ?>>
+  %arg2 : index) -> memref<?x?xf32, strided<[?, ?], offset: ?>>
 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
-  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
-  return %0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
+  %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
 }
 // CHECK-LABEL: func @rank_reducing_subview_canonicalize
 //  CHECK-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3407bdbc7c8f9..5b853a6cc5a37 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -613,9 +613,9 @@ func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
 {
   %0 = memref.subview %m[3, 1, 8] [1, %sz, 1] [1, 1, 1]
       : memref<?x?x?xf32>
-        to memref<?xf32, strided<[1], offset: ?>>
+        to memref<?xf32, strided<[?], offset: ?>>
   %1 = memref.subview %0[6] [1] [1]
-      : memref<?xf32, strided<[1], offset: ?>>
+      : memref<?xf32, strided<[?], offset: ?>>
         to memref<f32, strided<[], offset: ?>>
   return %1 : memref<f32, strided<[], offset: ?>>
 }
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index be60a3dcb1b20..8f5ba5ea8fc78 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1082,3 +1082,12 @@ func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) {
       : memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
   return
 }
+
+// -----
+
+func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>) {
+  // expected-error @below{{expected result type to be 'memref<7x11x1x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}}
+  %subview = memref.subview %m[0, 0, 0, 0] [7, 11, 1, 4444] [1, 2, 1, 1]
+      : memref<7x22x333x4444xi32> to memref<7x11x4444xi32>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/80158


More information about the Mlir-commits mailing list