[Mlir-commits] [mlir] [mlir][memref][WIP] `memref.subview`: Verify result strides (PR #79865)

Matthias Springer llvmlistbot at llvm.org
Mon Jan 29 09:01:03 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/79865

The strides of the result types are currently not verified for `memref.subview` ops that have no rank reductions.
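
For a subview that drops no dimensions, the verifier should expect each result stride to equal the source stride multiplied by the subview step in that dimension. The standalone C++ sketch below models that arithmetic with plain vectors instead of MLIR types. It assumes a static source shape with the default identity (row-major) layout and does not model offsets. `identityStrides` and `subviewResultStrides` are hypothetical helper names, not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: row-major (identity layout) strides for a static
// shape. The innermost stride is 1; each outer stride is the product of
// all inner dimension sizes.
std::vector<int64_t> identityStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (std::ptrdiff_t i = static_cast<std::ptrdiff_t>(shape.size()) - 2;
       i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Hypothetical helper: expected result strides of a subview without rank
// reduction. Each source stride is scaled by the subview step of the same
// dimension. Offsets are not modeled here.
std::vector<int64_t> subviewResultStrides(
    const std::vector<int64_t> &sourceStrides,
    const std::vector<int64_t> &steps) {
  assert(sourceStrides.size() == steps.size() && "rank mismatch");
  std::vector<int64_t> result(sourceStrides.size());
  for (std::size_t i = 0; i < sourceStrides.size(); ++i)
    result[i] = sourceStrides[i] * steps[i];
  return result;
}
```

With the shapes from the new `invalid.mlir` test, `subviewResultStrides(identityStrides({7, 22, 333, 4444}), {1, 2, 1, 1})` yields `[32556744, 2959704, 4444, 1]`. These are the strides quoted in the expected error message. The declared result type `memref<7x11x333x4444xi32>` has identity strides `[16278372, 1479852, 4444, 1]`, so the two disagree in the outer two dimensions.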

WIP: Some test cases still fail. The verification of result strides also appears to be incomplete (and possibly incorrect) for ops with rank reduction.
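
For rank-reducing subviews, one plausible rule follows from the fact that dropping a unit dimension leaves the addressing of the remaining dimensions unchanged. The result's strides should then equal the expected (non-reduced) strides with the dropped dimensions removed. This rule is an assumption here, not something the patch implements. The hypothetical helpers below take the dropped-dimension mask as given. In MLIR that mask comes from `computeMemRefRankReductionMask`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: strides left after a rank reduction. `dropped[i]`
// marks a unit dimension removed from the result type. Removing a size-1
// dimension does not change how the remaining dimensions are addressed,
// so their strides carry over unchanged.
std::vector<int64_t> rankReducedStrides(const std::vector<int64_t> &strides,
                                        const std::vector<bool> &dropped) {
  assert(strides.size() == dropped.size() && "mask/rank mismatch");
  std::vector<int64_t> kept;
  for (std::size_t i = 0; i < strides.size(); ++i)
    if (!dropped[i])
      kept.push_back(strides[i]);
  return kept;
}

// Hypothetical check: does the candidate (rank-reduced) result type carry
// exactly the strides implied by the expected (non-reduced) type?
bool hasRankReducedStrides(const std::vector<int64_t> &expectedStrides,
                           const std::vector<bool> &dropped,
                           const std::vector<int64_t> &candidateStrides) {
  return rankReducedStrides(expectedStrides, dropped) == candidateStrides;
}
```

Consider `memref.subview %m[0, 0, 0, 0] [7, 1, 333, 4444] [1, 1, 1, 1]` applied to the test's `memref<7x22x333x4444xi32>`. The expected strides are `[32556744, 1479852, 4444, 1]`, and dropping dimension 1 leaves `[32556744, 4444, 1]`. Under this rule, a plain `memref<7x333x4444xi32>` result (identity strides `[1479852, 4444, 1]`) would be rejected.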

From dc736637144703d11e780b2b198f24f769ff6205 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 29 Jan 2024 16:58:35 +0000
Subject: [PATCH] [mlir][memref][WIP] `memref.subview`: Verify result strides

The strides of the result types are currently not verified for `memref.subview` ops that have no rank reductions.

WIP: This is still failing some test cases. It also looks like the verification of result strides is incomplete (maybe also incorrect) for ops with rank reduction.
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 24 +++++++++++++++++++++++-
 mlir/test/Dialect/MemRef/invalid.mlir    |  9 +++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b79ab8f3d671e0..2b8899ce1634f3 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -931,6 +931,7 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
 
   // Early exit for the case where the number of unused dims matches the number
   // of ranks reduced.
+  // TODO: Verify strides.
   if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
       originalType.getRank())
     return unusedDims;
@@ -2745,7 +2746,7 @@ void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return getSource(); }
 
-/// Return true if t1 and t2 have equal offsets (both dynamic or of same
+/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
 /// static value).
 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
   int64_t t1Offset, t2Offset;
@@ -2755,6 +2756,21 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
 }
 
+/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
+/// static value).
+static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) {
+  int64_t t1Offset, t2Offset;
+  SmallVector<int64_t> t1Strides, t2Strides;
+  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
+  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
+  if (failed(res1) || failed(res2))
+    return false;
+  for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides))
+    if (s1 != s2)
+      return false;
+  return true;
+}
+
 /// Checks if `original` Type type can be rank reduced to `reduced` type.
 /// This function is slight variant of `is subsequence` algorithm where
 /// not matching dimension must be 1.
@@ -2781,6 +2797,12 @@ isRankReducedMemRefType(MemRefType originalType,
   if (!haveCompatibleOffsets(originalType, candidateRankReducedType))
     return SliceVerificationResult::LayoutMismatch;
 
+  // Strides must match if there are no rank reductions. In case of rank
+  // reductions, the strides are checked by `computeMemRefRankReductionMask`.
+  if (optionalUnusedDimsMask->none() &&
+      !haveCompatibleStrides(originalType, candidateRankReducedType))
+    return SliceVerificationResult::LayoutMismatch;
+
   return SliceVerificationResult::Success;
 }
 
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 7bb7a2affcbd19..be60a3dcb1b201 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -1073,3 +1073,12 @@ func.func @dim_0_ranked(%arg : memref<f32>, %arg1 : index) {
   memref.dim %arg, %arg1 : memref<f32> // expected-error {{'memref.dim' op operand #0 must be unranked.memref of any type values or non-0-ranked.memref of any type values, but got 'memref<f32>'}}
   return
 }
+
+// -----
+
+func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) {
+  // expected-error @below{{expected result type to be 'memref<7x11x333x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}}
+  %subview = memref.subview %m[0, 0, 0, 0] [7, 11, 333, 4444] [1, 2, 1, 1]
+      : memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
+  return
+}


