[Mlir-commits] [mlir] [mlir][memref] `memref.subview`: Verify result strides with rank reductions (PR #80158)
Matthias Springer
llvmlistbot at llvm.org
Fri Feb 2 00:53:52 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/80158
>From 17c0745a45af86b521739b111b453c16c1301810 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 31 Jan 2024 16:33:43 +0000
Subject: [PATCH] [mlir][memref] `memref.subview`: Verify result strides with
rank reductions
This is a follow-up to #79865. Previously, result strides were verified only when the `memref.subview` op had no rank reductions; they are now also verified when the op has rank reductions.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 23 ++++++++++++-------
.../Transforms/ExpandStridedMetadata.cpp | 17 ++++++++++----
mlir/test/Dialect/MemRef/canonicalize.mlir | 6 ++---
.../Dialect/MemRef/fold-memref-alias-ops.mlir | 4 ++--
mlir/test/Dialect/MemRef/invalid.mlir | 9 ++++++++
5 files changed, 42 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f43217f6f27ae..841c5d1686b44 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2756,17 +2756,26 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
}
/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
-/// static value).
-static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) {
+/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
+/// marked as dropped in `droppedDims` (one bit per dimension of `t1`).
+static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
+ const llvm::SmallBitVector &droppedDims) {
+ assert(static_cast<size_t>(t1.getRank()) == droppedDims.size() &&
+ "incorrect number of bits");
+ assert(static_cast<size_t>(t1.getRank() - t2.getRank()) ==
+ droppedDims.count() && "incorrect number of dropped dims");
int64_t t1Offset, t2Offset;
SmallVector<int64_t> t1Strides, t2Strides;
auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
if (failed(res1) || failed(res2))
return false;
- for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides))
- if (s1 != s2)
+ for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
+ if (droppedDims[i])
+ continue;
+ if (t1Strides[i] != t2Strides[j++])
return false;
+ }
return true;
}
@@ -2843,10 +2852,8 @@ LogicalResult SubViewOp::verify() {
return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
*this, expectedType);
- // Strides must match if there are no rank reductions.
- // TODO: Verify strides when there are rank reductions. Strides are partially
- // checked in `computeMemRefRankReductionMask`.
- if (unusedDims->none() && !haveCompatibleStrides(expectedType, subViewType))
+ // Strides must match.
+ if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
*this, expectedType);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index f6af0791ba756..96eb7cfd2db69 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -144,16 +144,25 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
SmallVector<OpFoldResult> finalStrides;
finalStrides.reserve(subRank);
+#ifndef NDEBUG
+ // Iteration variable for result dimensions of the subview op.
+ int64_t j = 0;
+#endif // NDEBUG
for (unsigned i = 0; i < sourceRank; ++i) {
if (droppedDims.test(i))
continue;
finalSizes.push_back(subSizes[i]);
finalStrides.push_back(strides[i]);
- // TODO: Assert that the computed stride matches the respective stride of
- // the result type of the subview op (if both are static), once the verifier
- // of memref.subview verfies result strides correctly for ops with rank
- // reductions.
+#ifndef NDEBUG
+ // Assert that the computed stride matches the stride of the result type of
+ // the subview op (if both are static).
+ std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
+ if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
+ assert(*computedStride == resultStrides[j] &&
+ "mismatch between computed stride and result type stride");
+ ++j;
+#endif // NDEBUG
}
assert(finalSizes.size() == subRank &&
"Should have populated all the values at this point");
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 993ef32edc9d4..a772a25da5738 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -62,13 +62,13 @@ func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
// -----
func.func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
- %arg2 : index) -> memref<?x?xf32, strided<[?, 1], offset: ?>>
+ %arg2 : index) -> memref<?x?xf32, strided<[?, ?], offset: ?>> // Strides are `?` because the subview strides (%c1) are SSA values.
{
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
- %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
- return %0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
+ %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+ return %0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}
// CHECK-LABEL: func @rank_reducing_subview_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3407bdbc7c8f9..5b853a6cc5a37 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -613,9 +613,9 @@ func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
{
%0 = memref.subview %m[3, 1, 8] [1, %sz, 1] [1, 1, 1]
: memref<?x?x?xf32>
- to memref<?xf32, strided<[1], offset: ?>>
+ to memref<?xf32, strided<[?], offset: ?>>
%1 = memref.subview %0[6] [1] [1]
- : memref<?xf32, strided<[1], offset: ?>>
+ : memref<?xf32, strided<[?], offset: ?>>
to memref<f32, strided<[], offset: ?>>
return %1 : memref<f32, strided<[], offset: ?>>
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index be60a3dcb1b20..8f5ba5ea8fc78 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1082,3 +1082,12 @@ func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) {
: memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
return
}
+
+// -----
+
+func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>) {
+ // expected-error @below{{expected result type to be 'memref<7x11x1x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}}
+ %subview = memref.subview %m[0, 0, 0, 0] [7, 11, 1, 4444] [1, 2, 1, 1] // Dim 2 (size 1) is the dropped dimension.
+ : memref<7x22x333x4444xi32> to memref<7x11x4444xi32> // Identity strides [48884, 4444, 1] differ from the rank-reduced [32556744, 2959704, 1].
+ return
+}
More information about the Mlir-commits
mailing list