[Mlir-commits] [mlir] Revert "[mlir][memref] `memref.subview`: Verify result strides" (PR #80116)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 31 00:35:15 PST 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes:

Reverts llvm/llvm-project#79865

I think there is a bug in the stride computation in `SubViewOp::inferResultType`. (It was already present before this change.)

I am reverting this commit for now and will update the original pull request with a fix and more test cases.
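
For reference, a worked sketch of the stride arithmetic involved, based on the `subview_invalid_strides` test case that this revert removes from `mlir/test/Dialect/MemRef/invalid.mlir`: the expected result stride of each dimension is the source stride multiplied by the corresponding subview stride, which is what the error message in the removed test reflects. The wrapper function name below is only for illustration.

```mlir
func.func @subview_stride_example(%m: memref<7x22x333x4444xi32>) {
  // Source strides (row-major): [22*333*4444, 333*4444, 4444, 1]
  //                           = [32556744, 1479852, 4444, 1].
  // With relative subview strides [1, 2, 1, 1], the inferred result type is
  // memref<7x11x333x4444xi32, strided<[32556744, 2959704, 4444, 1]>>, so the
  // verifier change being reverted here rejected the plain result type below.
  %subview = memref.subview %m[0, 0, 0, 0] [7, 11, 333, 4444] [1, 2, 1, 1]
      : memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
  return
}
```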


---
Full diff: https://github.com/llvm/llvm-project/pull/80116.diff


5 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+52-68) 
- (modified) mlir/test/Dialect/GPU/decompose-memrefs.mlir (+3-3) 
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+6-6) 
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (-9) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir (+4-4) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f43217f6f27ae..8b5765b7f8dba 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -917,7 +917,7 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
 /// This accounts for cases where there are multiple unit-dims, but only a
 /// subset of those are dropped. For MemRefTypes these can be disambiguated
 /// using the strides. If a dimension is dropped the stride must be dropped too.
-static FailureOr<llvm::SmallBitVector>
+static std::optional<llvm::SmallBitVector>
 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                                ArrayRef<OpFoldResult> sizes) {
   llvm::SmallBitVector unusedDims(originalType.getRank());
@@ -941,7 +941,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
           getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
       failed(
           getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
-    return failure();
+    return std::nullopt;
 
   // For memrefs, a dimension is truly dropped if its corresponding stride is
   // also dropped. This is particularly important when more than one of the dims
@@ -976,22 +976,22 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
         candidateStridesNumOccurences[originalStride]) {
       // This should never happen. Cant have a stride in the reduced rank type
       // that wasnt in the original one.
-      return failure();
+      return std::nullopt;
     }
   }
 
   if ((int64_t)unusedDims.count() + reducedType.getRank() !=
       originalType.getRank())
-    return failure();
+    return std::nullopt;
   return unusedDims;
 }
 
 llvm::SmallBitVector SubViewOp::getDroppedDims() {
   MemRefType sourceType = getSourceType();
   MemRefType resultType = getType();
-  FailureOr<llvm::SmallBitVector> unusedDims =
+  std::optional<llvm::SmallBitVector> unusedDims =
       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
-  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
+  assert(unusedDims && "unable to find unused dims of subview");
   return *unusedDims;
 }
 
@@ -2745,7 +2745,7 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return getSource(); }
 
-/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
+/// Return true if t1 and t2 have equal offsets (both dynamic or of same
 /// static value).
 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
   int64_t t1Offset, t2Offset;
@@ -2755,41 +2755,56 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
 }
 
-/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
-/// static value).
-static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) {
-  int64_t t1Offset, t2Offset;
-  SmallVector<int64_t> t1Strides, t2Strides;
-  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
-  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
-  if (failed(res1) || failed(res2))
-    return false;
-  for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides))
-    if (s1 != s2)
-      return false;
-  return true;
+/// Checks if `original` Type type can be rank reduced to `reduced` type.
+/// This function is slight variant of `is subsequence` algorithm where
+/// not matching dimension must be 1.
+static SliceVerificationResult
+isRankReducedMemRefType(MemRefType originalType,
+                        MemRefType candidateRankReducedType,
+                        ArrayRef<OpFoldResult> sizes) {
+  auto partialRes = isRankReducedType(originalType, candidateRankReducedType);
+  if (partialRes != SliceVerificationResult::Success)
+    return partialRes;
+
+  auto optionalUnusedDimsMask = computeMemRefRankReductionMask(
+      originalType, candidateRankReducedType, sizes);
+
+  // Sizes cannot be matched in case empty vector is returned.
+  if (!optionalUnusedDimsMask)
+    return SliceVerificationResult::LayoutMismatch;
+
+  if (originalType.getMemorySpace() !=
+      candidateRankReducedType.getMemorySpace())
+    return SliceVerificationResult::MemSpaceMismatch;
+
+  // No amount of stride dropping can reconcile incompatible offsets.
+  if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
+    return SliceVerificationResult::LayoutMismatch;
+
+  return SliceVerificationResult::Success;
 }
 
+template <typename OpTy>
 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
-                                            Operation *op, Type expectedType) {
+                                            OpTy op, Type expectedType) {
   auto memrefType = llvm::cast<ShapedType>(expectedType);
   switch (result) {
   case SliceVerificationResult::Success:
     return success();
   case SliceVerificationResult::RankTooLarge:
-    return op->emitError("expected result rank to be smaller or equal to ")
+    return op.emitError("expected result rank to be smaller or equal to ")
            << "the source rank. ";
   case SliceVerificationResult::SizeMismatch:
-    return op->emitError("expected result type to be ")
+    return op.emitError("expected result type to be ")
            << expectedType
            << " or a rank-reduced version. (mismatch of result sizes) ";
   case SliceVerificationResult::ElemTypeMismatch:
-    return op->emitError("expected result element type to be ")
+    return op.emitError("expected result element type to be ")
            << memrefType.getElementType();
   case SliceVerificationResult::MemSpaceMismatch:
-    return op->emitError("expected result and source memory spaces to match.");
+    return op.emitError("expected result and source memory spaces to match.");
   case SliceVerificationResult::LayoutMismatch:
-    return op->emitError("expected result type to be ")
+    return op.emitError("expected result type to be ")
            << expectedType
            << " or a rank-reduced version. (mismatch of result layout) ";
   }
@@ -2811,46 +2826,13 @@ LogicalResult SubViewOp::verify() {
   if (!isStrided(baseType))
     return emitError("base type ") << baseType << " is not strided";
 
-  // Compute the expected result type, assuming that there are no rank
-  // reductions.
-  auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
-      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
-
-  // Verify all properties of a shaped type: rank, element type and dimension
-  // sizes. This takes into account potential rank reductions.
-  auto shapedTypeVerification = isRankReducedType(
-      /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
-  if (shapedTypeVerification != SliceVerificationResult::Success)
-    return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
-
-  // Make sure that the memory space did not change.
-  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
-    return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
-                                  *this, expectedType);
-
-  // Verify the offset of the layout map.
-  if (!haveCompatibleOffsets(expectedType, subViewType))
-    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
-                                  *this, expectedType);
-
-  // The only thing that's left to verify now are the strides. First, compute
-  // the unused dimensions due to rank reductions. We have to look at sizes and
-  // strides to decide which dimensions were dropped. This function also
-  // partially verifies strides in case of rank reductions.
-  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
-                                                   getMixedSizes());
-  if (failed(unusedDims))
-    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
-                                  *this, expectedType);
-
-  // Strides must match if there are no rank reductions.
-  // TODO: Verify strides when there are rank reductions. Strides are partially
-  // checked in `computeMemRefRankReductionMask`.
-  if (unusedDims->none() && !haveCompatibleStrides(expectedType, subViewType))
-    return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
-                                  *this, expectedType);
+  // Verify result type against inferred type.
+  auto expectedType = SubViewOp::inferResultType(
+      baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides());
 
-  return success();
+  auto result = isRankReducedMemRefType(llvm::cast<MemRefType>(expectedType),
+                                        subViewType, getMixedSizes());
+  return produceSubViewErrorMsg(result, *this, expectedType);
 }
 
 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
@@ -2900,9 +2882,11 @@ static MemRefType getCanonicalSubViewResultType(
     ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
   auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
       sourceType, mixedOffsets, mixedSizes, mixedStrides));
-  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
-      currentSourceType, currentResultType, mixedSizes);
-  if (failed(unusedDims))
+  std::optional<llvm::SmallBitVector> unusedDims =
+      computeMemRefRankReductionMask(currentSourceType, currentResultType,
+                                     mixedSizes);
+  // Return nullptr as failure mode.
+  if (!unusedDims)
     return nullptr;
 
   auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
diff --git a/mlir/test/Dialect/GPU/decompose-memrefs.mlir b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
index 56fc9a66b7ace..d714010d0f254 100644
--- a/mlir/test/Dialect/GPU/decompose-memrefs.mlir
+++ b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
@@ -119,7 +119,7 @@ func.func @decompose_subview(%arg0 : memref<?x?x?xf32>) {
 //       CHECK:  %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#1]
 //       CHECK:  %[[IDX2:.*]] = affine.apply #[[MAP2]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
 //       CHECK:  %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX2]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[IDX]], %[[IDX1]], 4]
-//       CHECK:  "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
+//       CHECK:  "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
 func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -129,8 +129,8 @@ func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
   %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
-    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [2, 3, 4] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>
-    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
+    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [2, 3, 4] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
     gpu.terminator
   }
   return
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3407bdbc7c8f9..96b72e042b9e0 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -595,9 +595,9 @@ func.func @subview_of_subview(%m: memref<1x1024xf32, 3>, %pos: index)
 {
   %0 = memref.subview %m[3, %pos] [1, 2] [1, 1]
       : memref<1x1024xf32, 3>
-        to memref<1x2xf32, strided<[1024, 1], offset: ?>, 3>
+        to memref<1x2xf32, strided<[1024, 2], offset: ?>, 3>
   %1 = memref.subview %0[1, 2] [1, 1] [1, 1]
-      : memref<1x2xf32, strided<[1024, 1], offset: ?>, 3>
+      : memref<1x2xf32, strided<[1024, 2], offset: ?>, 3>
         to memref<f32, strided<[], offset: ?>, 3>
   return %1 : memref<f32, strided<[], offset: ?>, 3>
 }
@@ -675,9 +675,9 @@ func.func @fold_gpu_subgroup_mma_store_matrix_1d(%dst: memref<?xvector<4xf32>>,
 // CHECK-LABEL: func.func @fold_gpu_subgroup_mma_load_matrix_2d
 //  CHECK-SAME: %[[SRC:.+]]: memref<128x128xf32>
 func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
-  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
+  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[64, 1], offset: ?>>
   // CHECK: gpu.subgroup_mma_load_matrix %[[SRC]][{{.+}}] {leadDimension = 32 : index} : memref<128x128xf32> -> !gpu.mma_matrix<16x16xf16, "COp">
-  %matrix = gpu.subgroup_mma_load_matrix %subview[%arg3, %arg4] {leadDimension = 32 : index} : memref<64x32xf32, strided<[256, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
+  %matrix = gpu.subgroup_mma_load_matrix %subview[%arg3, %arg4] {leadDimension = 32 : index} : memref<64x32xf32, strided<[64, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
   return %matrix : !gpu.mma_matrix<16x16xf16, "COp">
 }
 
@@ -686,9 +686,9 @@ func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %ar
 // CHECK-LABEL: func.func @fold_gpu_subgroup_mma_load_matrix_2d
 //  CHECK-SAME: %[[DST:.+]]: memref<128x128xf32>
 func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) {
-  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>>
+  %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[64, 1], offset: ?>>
   // CHECK: gpu.subgroup_mma_store_matrix %{{.+}}, %[[DST]][{{.+}}] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf32>
-  gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} :  !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[256, 1], offset: ?>>
+  gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} :  !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[64, 1], offset: ?>>
   return
 }
 
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index be60a3dcb1b20..7bb7a2affcbd1 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1073,12 +1073,3 @@ func.func @dim_0_ranked(%arg : memref<f32>, %arg1 : index) {
   memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any type values or non-0-ranked.memref of any type values, but got 'memref<f32>'}}
   return
 }
-
-// -----
-
-func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) {
-  // expected-error @below{{expected result type to be 'memref<7x11x333x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}}
-  %subview = memref.subview %m[0, 0, 0, 0] [7, 11, 333, 4444] [1, 2, 1, 1]
-      : memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
-  return
-}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index e7dd0ad32a243..3773cca9c8d69 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -88,10 +88,10 @@ module {
     // Prepare a buffer for x0, x1, x2, y0 and a buffer for y1.
     %xys = memref.alloc() : memref<20xi32>
     %xy = memref.cast %xys : memref<20xi32> to memref<?xi32>
-    %x0 = memref.subview %xy[%i0][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
-    %x1 = memref.subview %xy[%i1][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
-    %x2 = memref.subview %xy[%i2][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
-    %y0 = memref.subview %xy[%i3][%i5][4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %x0 = memref.subview %xy[%i0][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %x1 = memref.subview %xy[%i1][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %x2 = memref.subview %xy[%i2][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+    %y0 = memref.subview %xy[%i3][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
     %y1s = memref.alloc() : memref<7xi32>
     %y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>
 

``````````



https://github.com/llvm/llvm-project/pull/80116

