[Mlir-commits] [mlir] [MLIR] Determine contiguousness of memrefs with dynamic dimensions (PR #142421)

Momchil Velikov llvmlistbot at llvm.org
Thu Jun 19 06:46:35 PDT 2025


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142421

>From 3f5355acd26929528134890ddab1854af59e2df6 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 2 Jun 2025 13:20:14 +0000
Subject: [PATCH 01/11] [MLIR] Determine contiguousness of memrefs with
 dynamic dimensions

This patch enhances `MemRefType::areTrailingDimsContiguous`
to also handle memrefs with dynamic dimensions.

The implementation itself is based on a new member function
`MemRefType::getMaxCollapsableTrailingDims` that returns
the maximum number of trailing dimensions that can be
collapsed: trivially all dimensions for memrefs with identity
layout, or otherwise determined by examining the memref strides,
stopping at the first discontiguous or statically unknown stride.
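
For illustration, here is a minimal standalone sketch of that stride walk.
It is not the in-tree implementation: the helper name, the use of
std::vector and the kDynamic sentinel value are only for this example.

  #include <cstdint>
  #include <vector>

  // Stand-in for ShapedType::kDynamic in this sketch.
  constexpr int64_t kDynamic = INT64_MIN;

  // Walk the dimensions right to left and count how many trailing dimensions
  // form a contiguous block, i.e. each stride equals the product of the sizes
  // of all dimensions to its right.
  int64_t maxCollapsableTrailingDims(const std::vector<int64_t> &shape,
                                     const std::vector<int64_t> &strides) {
    const int64_t n = static_cast<int64_t>(shape.size());
    int64_t dimProduct = 1;
    for (int64_t i = n - 1; i >= 0; --i) {
      if (strides[i] != dimProduct)
        return n - i - 1;   // discontiguous or statically unknown stride
      if (shape[i] == kDynamic)
        return n - i;       // dynamic size: stride i itself still checked out
      dimProduct *= shape[i];
    }
    return n;               // the whole memref is contiguous
  }

For example, shape {2, 2, kDynamic} with strides {kDynamic, kDynamic, 1}
yields 1, matching the memref<2x2x?xf32, strided<[?,?,1]>> case in the
unit tests below.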
---
 mlir/include/mlir/IR/BuiltinTypes.td          |  14 ++
 mlir/lib/IR/BuiltinTypes.cpp                  |  47 +++--
 .../Vector/vector-transfer-flatten.mlir       |  49 ++++-
 mlir/unittests/Dialect/MemRef/CMakeLists.txt  |   1 +
 mlir/unittests/Dialect/MemRef/LayoutTest.cpp  | 190 ++++++++++++++++++
 5 files changed, 269 insertions(+), 32 deletions(-)
 create mode 100644 mlir/unittests/Dialect/MemRef/LayoutTest.cpp

diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..1d12f70882176 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -838,6 +838,20 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     ///
     bool areTrailingDimsContiguous(int64_t n);
 
+    /// Return the maximum number of trailing dimensions that can be
+    /// collapsed.
+    ///
+    /// Examples:
+    ///   - memref<2x3x2xi8, strided<[24, 12, 2]>, the number of collapsable
+    ///     trailing dimensions is 0
+    ///   - memref<2x3x2xi8, strided<[12, 6, 1]>, the number of collapsable
+    ///     trailing dimensions is 3
+    ///   - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>, the number of
+    ///     collapsable trailing dimensions is 2.
+    ///   - memref<5x4x?x2xi8>, the number of collapsable trailing dimensions
+    ///     is 4.
+    int64_t getMaxCollapsableTrailingDims();
+
     /// Return a version of this type with identity layout if it can be
     /// determined statically that the layout is the canonical contiguous
     /// strided layout. Otherwise pass the layout into `simplifyAffineMap`
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d47e360e9dc13..cc23d08515ff3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,35 +646,40 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
-  if (!isLastDimUnitStride())
-    return false;
+  return getLayout().isIdentity() ||
+         getMaxCollapsableTrailingDims() >= std::min(n, getRank());
+}
 
-  auto memrefShape = getShape().take_back(n);
-  if (ShapedType::isDynamicShape(memrefShape))
-    return false;
+int64_t MemRefType::getMaxCollapsableTrailingDims() {
+  const int64_t n = getRank();
 
+  // memrefs with identity layout are entirely contiguous.
   if (getLayout().isIdentity())
-    return true;
+    return n;
 
+  // Get the strides (if any). Failing to do that, conservatively assume a
+  // non-contiguous layout.
   int64_t offset;
-  SmallVector<int64_t> stridesFull;
-  if (!succeeded(getStridesAndOffset(stridesFull, offset)))
-    return false;
-  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
-
-  if (strides.empty())
-    return true;
+  SmallVector<int64_t> strides;
+  if (!succeeded(getStridesAndOffset(strides, offset)))
+    return 0;
 
-  // Check whether strides match "flattened" dims.
-  SmallVector<int64_t> flattenedDims;
-  auto dimProduct = 1;
-  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
-    dimProduct *= dim;
-    flattenedDims.push_back(dimProduct);
+  auto shape = getShape();
+
+  // A memref with dimensions `d0, d1, ..., dn-1` and strides
+  // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
+  // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
+  // for `i` in `[k, n-1]`.
+  int64_t dimProduct = 1;
+  for (int64_t i = n - 1; i >= 0; --i) {
+    if (strides[i] != dimProduct)
+      return n - i - 1;
+    if (shape[i] == ShapedType::kDynamic)
+      return n - i;
+    dimProduct *= shape[i];
   }
 
-  strides = strides.drop_back(1);
-  return llvm::equal(strides, llvm::reverse(flattenedDims));
+  return n;
 }
 
 MemRefType MemRefType::canonicalizeStridedLayout() {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index e840dc6bbf224..5b2f2ab1f2cef 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -190,7 +190,7 @@ func.func @transfer_read_leading_dynamic_dims(
 
 // One of the dims to be flattened is dynamic - not supported ATM.
 
-func.func @negative_transfer_read_dynamic_dim_to_flatten(
+func.func @transfer_read_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -203,11 +203,25 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
   return %res : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
+// CHECK-SAME:    %[[IDX_1:arg0]]
+// CHECK-SAME:    %[[IDX_2:arg1]]
+// CHECK-SAME:    %[[MEM:arg2]]
+// CHECK:              %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK:              %[[C0:.*]] = arith.constant 0 : index
+// CHECK:              %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL:   [[0], [1, 2, 3]]
+// CHECK-SAME:           memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK:              %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK:              %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK:              return %[[RESULT]] : vector<1x2x6xi32>
+
+
+// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
@@ -453,7 +467,7 @@ func.func @transfer_write_leading_dynamic_dims(
 
 // One of the dims to be flattened is dynamic - not supported ATM.
 
-func.func @negative_transfer_write_dynamic_to_flatten(
+func.func @transfer_write_dynamic_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %vec : vector<1x2x6xi32>,
@@ -466,11 +480,24 @@ func.func @negative_transfer_write_dynamic_to_flatten(
   return
 }
 
-// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
+// CHECK-SAME:    %[[IDX_1:arg0]]: index
+// CHECK-SAME:    %[[IDX_2:arg1]]: index
+// CHECK-SAME:    %[[VEC:arg2]]: vector<1x2x6xi32>
+// CHECK-SAME:    %[[MEM:arg3]]: memref<1x?x4x6xi32>
+
+// CHECK:              %[[C0:.*]] = arith.constant 0 : index
+// CHECK:              %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL:   [[0], [1, 2, 3]]
+// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK:              %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
 
-// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index dede3ba0a885c..1f6df1024f430 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRMemRefTests
   InferShapeTest.cpp
+  LayoutTest.cpp
 )
 mlir_target_link_libraries(MLIRMemRefTests
   PRIVATE
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
new file mode 100644
index 0000000000000..e01c0056d5cec
--- /dev/null
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -0,0 +1,190 @@
+//===- LayoutTest.cpp - unit tests related to memref layout --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+TEST(MemRefLayout, maxCollapseDim) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const auto _ = ShapedType::kDynamic;
+  const auto f32 = b.getF32Type();
+  auto strided = [&ctx](ArrayRef<int64_t> s) {
+    return StridedLayoutAttr::get(&ctx, 0, s);
+  };
+
+  // memref<2x2x2xf32, strided<[4,2,1]>
+  auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+
+  // memref<2x2x2xf32, strided<[8,2,1]>
+  auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 2);
+
+  // memref<2x2x2xf32, strided<[8,4,1]>
+  auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+  EXPECT_EQ(m3.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<2x2x2xf32, strided<[8,4,2]>
+  auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_EQ(m4.getMaxCollapsableTrailingDims(), 0);
+
+  // memref<2x2x?xf32, strided<[?,?,1]>
+  auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+  EXPECT_EQ(m5.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<2x2x?xf32, strided<[?,?,2]>
+  auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+  EXPECT_EQ(m6.getMaxCollapsableTrailingDims(), 0);
+
+  // memref<2x?x2xf32, strided<[?,2,1]>
+  auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+  EXPECT_EQ(m7.getMaxCollapsableTrailingDims(), 2);
+
+  // memref<2x?x2xf32, strided<[?,4,1]>
+  auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+  EXPECT_EQ(m8.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<2x?x2xf32, strided<[?,4,2]>
+  auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+  EXPECT_EQ(m9.getMaxCollapsableTrailingDims(), 0);
+
+  // memref<?x2x2xf32, strided<[4,2,1]>
+  auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m10.getMaxCollapsableTrailingDims(), 3);
+
+  // memref<?x2x2xf32, strided<[8,2,1]>
+  auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_EQ(m11.getMaxCollapsableTrailingDims(), 2);
+
+  // memref<?x2x2xf32, strided<[8,4,1]>
+  auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+  EXPECT_EQ(m12.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<?x2x2xf32, strided<[8,4,2]>
+  auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_EQ(m13.getMaxCollapsableTrailingDims(), 0);
+}
+
+TEST(MemRefLayout, contigTrailingDim) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const auto _ = ShapedType::kDynamic;
+  const auto f32 = b.getF32Type();
+  auto strided = [&ctx](ArrayRef<int64_t> s) {
+    return StridedLayoutAttr::get(&ctx, 0, s);
+  };
+
+  // memref<2x2x2xf32, strided<[4,2,1]>
+  auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+  // memref<2x2x2xf32, strided<[8,2,1]>
+  auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m2.areTrailingDimsContiguous(3));
+
+  // memref<2x2x2xf32, strided<[8,4,1]>
+  auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+  EXPECT_TRUE(m3.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m3.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m3.areTrailingDimsContiguous(3));
+
+  // memref<2x2x2xf32, strided<[8,4,2]>
+  auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_FALSE(m4.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m4.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m4.areTrailingDimsContiguous(3));
+
+  // memref<2x2x?xf32, strided<[?,?,1]>
+  auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+  EXPECT_TRUE(m5.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m5.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m5.areTrailingDimsContiguous(3));
+
+  // memref<2x2x?xf32, strided<[?,?,2]>
+  auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+  EXPECT_FALSE(m6.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m6.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m6.areTrailingDimsContiguous(3));
+
+  // memref<2x?x2xf32, strided<[?,2,1]>
+  auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+  EXPECT_TRUE(m7.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m7.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m7.areTrailingDimsContiguous(3));
+
+  // memref<2x?x2xf32, strided<[?,4,1]>
+  auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+  EXPECT_TRUE(m8.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m8.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m8.areTrailingDimsContiguous(3));
+
+  // memref<2x?x2xf32, strided<[?,4,2]>
+  auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+  EXPECT_FALSE(m9.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m9.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m9.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[4,2,1]>
+  auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_TRUE(m10.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m10.areTrailingDimsContiguous(2));
+  EXPECT_TRUE(m10.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[8,2,1]>
+  auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_TRUE(m11.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m11.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m11.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[8,4,1]>
+  auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+  EXPECT_TRUE(m12.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m12.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m12.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[8,4,2]>
+  auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_FALSE(m13.areTrailingDimsContiguous(1));
+  EXPECT_FALSE(m13.areTrailingDimsContiguous(2));
+  EXPECT_FALSE(m13.areTrailingDimsContiguous(3));
+}
+
+TEST(MemRefLayout, identityMaps) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const auto _ = ShapedType::kDynamic;
+  const auto f32 = b.getF32Type();
+
+  // memref<2x2x2xf32>
+  auto m1 = MemRefType::get({2, 2, 2}, f32);
+  EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+  EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+  // memref<?x?x?xf32>
+  auto m2 = MemRefType::get({_, _, _}, f32);
+  EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 3);
+  EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+  EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+  EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
+}

>From a43a9dc14d2c1adf6d92faa71ae85cc1c3806122 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 3 Jun 2025 16:09:05 +0000
Subject: [PATCH 02/11] [fixup] Fix an assertion

`computeStrides` does not access the first element of `sizes`
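
As a quick illustration of why the first element is never read, a sketch of
the documented behaviour (this is not the in-tree computeSuffixProductImpl;
the function name and use of std::vector are only for the example):

  #include <cstdint>
  #include <vector>

  // Given sizes [s0, s1, ..., sn], return [s1*...*sn, s2*...*sn, ..., sn, 1].
  // Each result element multiplies only sizes strictly to its right, so s0 is
  // never read and does not need to satisfy the non-negativity assertion.
  std::vector<int64_t> suffixProduct(const std::vector<int64_t> &sizes) {
    std::vector<int64_t> result(sizes.size(), 1);
    for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i)
      result[i] = result[i + 1] * sizes[i + 1];
    return result;
  }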
---
 mlir/include/mlir/Dialect/Utils/IndexingUtils.h | 2 +-
 mlir/lib/Dialect/Utils/IndexingUtils.cpp        | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 99218f491ddef..8524072929793 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -40,7 +40,7 @@ class ArrayAttr;
 /// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
 ///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
 ///
-/// `sizes` elements are asserted to be non-negative.
+/// `sizes` elements `s1` to `sn` are asserted to be non-negative.
 ///
 /// Return an empty vector if `sizes` is empty.
 SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 8de77e2c3cb08..3efe0edeaeb04 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -69,7 +69,8 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
 //===----------------------------------------------------------------------===//
 
 SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
-  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
+  assert((sizes.size() == 0 ||
+          llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
          "sizes must be nonnegative");
   int64_t unit = 1;
   return ::computeSuffixProductImpl(sizes, unit);

>From 047b99f08a00263d4082229d661dae84c28aac7b Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 4 Jun 2025 10:37:37 +0000
Subject: [PATCH 03/11] [fixup] Misc NFC changes

 - rename `getMaxCollapsableTrailingDims` to `getMaxContiguousTrailingDims`
 - new set of examples
 - remove redundant call to `isIdentity()`
 - make sure a variable type is visible on the declaration line
 - some micro-optimisation
---
 mlir/include/mlir/IR/BuiltinTypes.td         | 26 ++++++++++-------
 mlir/lib/IR/BuiltinTypes.cpp                 |  9 +++---
 mlir/unittests/Dialect/MemRef/LayoutTest.cpp | 30 ++++++++++----------
 3 files changed, 36 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1d12f70882176..dab6465e8bcf9 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -838,19 +838,25 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     ///
     bool areTrailingDimsContiguous(int64_t n);
 
-    /// Return the maximum number of trailing dimensions that can be
-    /// collapsed.
+    /// Return the maximum number of trailing dimensions that are
+    /// contiguous.
     ///
     /// Examples:
-    ///   - memref<2x3x2xi8, strided<[24, 12, 2]>, the number of collapsable
-    ///     trailing dimensions is 0
-    ///   - memref<2x3x2xi8, strided<[12, 6, 1]>, the number of collapsable
+    ///   - memref<5x3x2xi8, strided<[6,2,1]>>, the number of collapsable
     ///     trailing dimensions is 3
-    ///   - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>, the number of
-    ///     collapsable trailing dimensions is 2.
-    ///   - memref<5x4x?x2xi8>, the number of collapsable trailing dimensions
-    ///     is 4.
-    int64_t getMaxCollapsableTrailingDims();
+    ///   - memref<5x3x2xi8, strided<[12,2,1]>>, the number of collapsable
+    ///     trailing dimensions is 2 (dimension 0 is non-contiguous)
+    ///   - memref<5x3x2xi8, strided<[12,4,1]>>, the number of collapsable
+    ///     trailing dimensions is 1 (dimension 1 is non-contiguous)
+    ///   - memref<5x3x2xi8, strided<[12,4,2]>>, the number of collapsable
+    ///     trailing dimensions is 0 (dimension 2 is non-contiguous)
+    ///   - memref<?x3x2xi8, strided<[6,2,1]>>, the number of collapsable
+    ///     trailing dimensions is 3
+    ///   - memref<?x3x2xi8, strided<[12,2,1]>>, the number of collapsable
+    ///     trailing dimensions is 2 (dimension 0 is non-contiguous)
+    ///   - memref<5x?x2xi8, strided<[?,2,1]>>, the number of collapsable
+    ///     trailing dimensions is 2 (stride 0 is dynamic)
+    int64_t getMaxContiguousTrailingDims();
 
     /// Return a version of this type with identity layout if it can be
     /// determined statically that the layout is the canonical contiguous
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index cc23d08515ff3..f839576f36969 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,11 +646,10 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
-  return getLayout().isIdentity() ||
-         getMaxCollapsableTrailingDims() >= std::min(n, getRank());
+  return getMaxContiguousTrailingDims() >= std::min(n, getRank());
 }
 
-int64_t MemRefType::getMaxCollapsableTrailingDims() {
+int64_t MemRefType::getMaxContiguousTrailingDims() {
   const int64_t n = getRank();
 
   // memrefs with identity layout are entirely contiguous.
@@ -664,7 +663,7 @@ int64_t MemRefType::getMaxCollapsableTrailingDims() {
   if (!succeeded(getStridesAndOffset(strides, offset)))
     return 0;
 
-  auto shape = getShape();
+  ArrayRef<int64_t> shape = getShape();
 
   // A memref with dimensions `d0, d1, ..., dn-1` and strides
   // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
@@ -674,6 +673,8 @@ int64_t MemRefType::getMaxCollapsableTrailingDims() {
   for (int64_t i = n - 1; i >= 0; --i) {
     if (strides[i] != dimProduct)
       return n - i - 1;
+    if (shape[i] == 1)
+      continue;
     if (shape[i] == ShapedType::kDynamic)
       return n - i;
     dimProduct *= shape[i];
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
index e01c0056d5cec..631cac79e6999 100644
--- a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -27,55 +27,55 @@ TEST(MemRefLayout, maxCollapseDim) {
 
   // memref<2x2x2xf32, strided<[4,2,1]>
   auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
-  EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+  EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
 
   // memref<2x2x2xf32, strided<[8,2,1]>
   auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
-  EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 2);
+  EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 2);
 
   // memref<2x2x2xf32, strided<[8,4,1]>
   auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
-  EXPECT_EQ(m3.getMaxCollapsableTrailingDims(), 1);
+  EXPECT_EQ(m3.getMaxContiguousTrailingDims(), 1);
 
   // memref<2x2x2xf32, strided<[8,4,2]>
   auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
-  EXPECT_EQ(m4.getMaxCollapsableTrailingDims(), 0);
+  EXPECT_EQ(m4.getMaxContiguousTrailingDims(), 0);
 
   // memref<2x2x?xf32, strided<[?,?,1]>
   auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
-  EXPECT_EQ(m5.getMaxCollapsableTrailingDims(), 1);
+  EXPECT_EQ(m5.getMaxContiguousTrailingDims(), 1);
 
   // memref<2x2x?xf32, strided<[?,?,2]>
   auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
-  EXPECT_EQ(m6.getMaxCollapsableTrailingDims(), 0);
+  EXPECT_EQ(m6.getMaxContiguousTrailingDims(), 0);
 
   // memref<2x?x2xf32, strided<[?,2,1]>
   auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
-  EXPECT_EQ(m7.getMaxCollapsableTrailingDims(), 2);
+  EXPECT_EQ(m7.getMaxContiguousTrailingDims(), 2);
 
   // memref<2x?x2xf32, strided<[?,4,1]>
   auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
-  EXPECT_EQ(m8.getMaxCollapsableTrailingDims(), 1);
+  EXPECT_EQ(m8.getMaxContiguousTrailingDims(), 1);
 
   // memref<2x?x2xf32, strided<[?,4,2]>
   auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
-  EXPECT_EQ(m9.getMaxCollapsableTrailingDims(), 0);
+  EXPECT_EQ(m9.getMaxContiguousTrailingDims(), 0);
 
   // memref<?x2x2xf32, strided<[4,2,1]>
   auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
-  EXPECT_EQ(m10.getMaxCollapsableTrailingDims(), 3);
+  EXPECT_EQ(m10.getMaxContiguousTrailingDims(), 3);
 
   // memref<?x2x2xf32, strided<[8,2,1]>
   auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
-  EXPECT_EQ(m11.getMaxCollapsableTrailingDims(), 2);
+  EXPECT_EQ(m11.getMaxContiguousTrailingDims(), 2);
 
   // memref<?x2x2xf32, strided<[8,4,1]>
   auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
-  EXPECT_EQ(m12.getMaxCollapsableTrailingDims(), 1);
+  EXPECT_EQ(m12.getMaxContiguousTrailingDims(), 1);
 
   // memref<?x2x2xf32, strided<[8,4,2]>
   auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
-  EXPECT_EQ(m13.getMaxCollapsableTrailingDims(), 0);
+  EXPECT_EQ(m13.getMaxContiguousTrailingDims(), 0);
 }
 
 TEST(MemRefLayout, contigTrailingDim) {
@@ -176,14 +176,14 @@ TEST(MemRefLayout, identityMaps) {
 
   // memref<2x2x2xf32>
   auto m1 = MemRefType::get({2, 2, 2}, f32);
-  EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+  EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
   EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
   EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
   EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
 
   // memref<?x?x?xf32>
   auto m2 = MemRefType::get({_, _, _}, f32);
-  EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 3);
+  EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 3);
   EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
   EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
   EXPECT_TRUE(m2.areTrailingDimsContiguous(3));

>From dfe4ac9ca82f44c7227a103fd87b3c77f764a306 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 5 Jun 2025 11:02:57 +0000
Subject: [PATCH 04/11] [fixup] Handle unit dimensions by ignoring the
 corresponding stride
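
Spelling out the reasoning on one of the new unit-test cases,
memref<2x1x2xf32, strided<[2,?,1]>>: the linearized element offset is

  offset(i, j, k) = 2*i + s1*j + 1*k

and the index j of a unit dimension can only ever be 0, so the (possibly
dynamic) stride s1 never contributes to the address. Such strides can
therefore be skipped when checking contiguity.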

---
 mlir/lib/IR/BuiltinTypes.cpp                 |  6 ++++--
 mlir/unittests/Dialect/MemRef/LayoutTest.cpp | 22 +++++++++++++++++++-
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index f839576f36969..9a8738a18c0d6 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -669,12 +669,14 @@ int64_t MemRefType::getMaxContiguousTrailingDims() {
   // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
   // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
   // for `i` in `[k, n-1]`.
+  // Ignore stride elements if the corresponding dimension is 1, as they are
+  // of no consequence.
   int64_t dimProduct = 1;
   for (int64_t i = n - 1; i >= 0; --i) {
-    if (strides[i] != dimProduct)
-      return n - i - 1;
     if (shape[i] == 1)
       continue;
+    if (strides[i] != dimProduct)
+      return n - i - 1;
     if (shape[i] == ShapedType::kDynamic)
       return n - i;
     dimProduct *= shape[i];
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
index 631cac79e6999..2aa8029ef01f7 100644
--- a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -15,7 +15,7 @@
 using namespace mlir;
 using namespace mlir::memref;
 
-TEST(MemRefLayout, maxCollapseDim) {
+TEST(MemRefLayout, maxContigDim) {
   MLIRContext ctx;
   OpBuilder b(&ctx);
 
@@ -76,6 +76,26 @@ TEST(MemRefLayout, maxCollapseDim) {
   // memref<?x2x2xf32, strided<[8,4,2]>
   auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
   EXPECT_EQ(m13.getMaxContiguousTrailingDims(), 0);
+
+  // memref<2x2x1xf32, strided<[2,1,2]>
+  auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
+  EXPECT_EQ(m14.getMaxContiguousTrailingDims(), 3);
+
+  // memref<2x2x1xf32, strided<[2,1,?]>
+  auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
+  EXPECT_EQ(m15.getMaxContiguousTrailingDims(), 3);
+
+  // memref<2x2x1xf32, strided<[4,2,2]>
+  auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
+  EXPECT_EQ(m16.getMaxContiguousTrailingDims(), 1);
+
+  // memref<2x1x2xf32, strided<[2,4,1]>
+  auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
+  EXPECT_EQ(m17.getMaxContiguousTrailingDims(), 3);
+
+  // memref<2x1x2xf32, strided<[2,?,1]>
+  auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
+  EXPECT_EQ(m18.getMaxContiguousTrailingDims(), 3);
 }
 
 TEST(MemRefLayout, contigTrailingDim) {

>From 92a505154a852e67f736decc3e3dc30ab0cb1983 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 5 Jun 2025 16:25:15 +0000
Subject: [PATCH 05/11] [fixup] Address misc review comments

---
 mlir/include/mlir/IR/BuiltinTypes.td         |  5 +--
 mlir/lib/Dialect/Utils/IndexingUtils.cpp     |  2 +-
 mlir/lib/IR/BuiltinTypes.cpp                 |  6 ++-
 mlir/unittests/Dialect/MemRef/LayoutTest.cpp | 40 ++++++++++----------
 4 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index dab6465e8bcf9..3719f8e895967 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -838,8 +838,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     ///
     bool areTrailingDimsContiguous(int64_t n);
 
-    /// Return the maximum number of trailing dimensions that are
-    /// contiguous.
+    /// Return the number of trailing dimensions that are contiguous.
     ///
     /// Examples:
     ///   - memref<5x3x2xi8, strided<[6,2,1]>>, the number of collapsable
@@ -856,7 +855,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     ///     trailing dimensions is 2 (dimension 0 is non-contiguous)
     ///   - memref<5x?x2xi8, strided<[?,2,1]>>, the number of collapsable
     ///     trailing dimensions is 2 (stride 0 is dynamic)
-    int64_t getMaxContiguousTrailingDims();
+    int64_t getNumContiguousTrailingDims();
 
     /// Return a version of this type with identity layout if it can be
     /// determined statically that the layout is the canonical contiguous
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 3efe0edeaeb04..e1648ab99ff25 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -69,7 +69,7 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
 //===----------------------------------------------------------------------===//
 
 SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
-  assert((sizes.size() == 0 ||
+  assert((sizes.empty() ||
           llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
          "sizes must be nonnegative");
   int64_t unit = 1;
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9a8738a18c0d6..08fac617824aa 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,10 +646,12 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
-  return getMaxContiguousTrailingDims() >= std::min(n, getRank());
+  assert(n <= getRank() &&
+         "number of dimensions to check must not exceed rank");
+  return n <= getNumContiguousTrailingDims();
 }
 
-int64_t MemRefType::getMaxContiguousTrailingDims() {
+int64_t MemRefType::getNumContiguousTrailingDims() {
   const int64_t n = getRank();
 
   // memrefs with identity layout are entirely contiguous.
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
index 2aa8029ef01f7..0ba8927da69e9 100644
--- a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -27,75 +27,75 @@ TEST(MemRefLayout, maxContigDim) {
 
   // memref<2x2x2xf32, strided<[4,2,1]>
   auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
-  EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
+  EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
 
   // memref<2x2x2xf32, strided<[8,2,1]>
   auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
-  EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 2);
+  EXPECT_EQ(m2.getNumContiguousTrailingDims(), 2);
 
   // memref<2x2x2xf32, strided<[8,4,1]>
   auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
-  EXPECT_EQ(m3.getMaxContiguousTrailingDims(), 1);
+  EXPECT_EQ(m3.getNumContiguousTrailingDims(), 1);
 
   // memref<2x2x2xf32, strided<[8,4,2]>
   auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
-  EXPECT_EQ(m4.getMaxContiguousTrailingDims(), 0);
+  EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
 
   // memref<2x2x?xf32, strided<[?,?,1]>
   auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
-  EXPECT_EQ(m5.getMaxContiguousTrailingDims(), 1);
+  EXPECT_EQ(m5.getNumContiguousTrailingDims(), 1);
 
   // memref<2x2x?xf32, strided<[?,?,2]>
   auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
-  EXPECT_EQ(m6.getMaxContiguousTrailingDims(), 0);
+  EXPECT_EQ(m6.getNumContiguousTrailingDims(), 0);
 
   // memref<2x?x2xf32, strided<[?,2,1]>
   auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
-  EXPECT_EQ(m7.getMaxContiguousTrailingDims(), 2);
+  EXPECT_EQ(m7.getNumContiguousTrailingDims(), 2);
 
   // memref<2x?x2xf32, strided<[?,4,1]>
   auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
-  EXPECT_EQ(m8.getMaxContiguousTrailingDims(), 1);
+  EXPECT_EQ(m8.getNumContiguousTrailingDims(), 1);
 
   // memref<2x?x2xf32, strided<[?,4,2]>
   auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
-  EXPECT_EQ(m9.getMaxContiguousTrailingDims(), 0);
+  EXPECT_EQ(m9.getNumContiguousTrailingDims(), 0);
 
   // memref<?x2x2xf32, strided<[4,2,1]>
   auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
-  EXPECT_EQ(m10.getMaxContiguousTrailingDims(), 3);
+  EXPECT_EQ(m10.getNumContiguousTrailingDims(), 3);
 
   // memref<?x2x2xf32, strided<[8,2,1]>
   auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
-  EXPECT_EQ(m11.getMaxContiguousTrailingDims(), 2);
+  EXPECT_EQ(m11.getNumContiguousTrailingDims(), 2);
 
   // memref<?x2x2xf32, strided<[8,4,1]>
   auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
-  EXPECT_EQ(m12.getMaxContiguousTrailingDims(), 1);
+  EXPECT_EQ(m12.getNumContiguousTrailingDims(), 1);
 
   // memref<?x2x2xf32, strided<[8,4,2]>
   auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
-  EXPECT_EQ(m13.getMaxContiguousTrailingDims(), 0);
+  EXPECT_EQ(m13.getNumContiguousTrailingDims(), 0);
 
   // memref<2x2x1xf32, strided<[2,1,2]>
   auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
-  EXPECT_EQ(m14.getMaxContiguousTrailingDims(), 3);
+  EXPECT_EQ(m14.getNumContiguousTrailingDims(), 3);
 
   // memref<2x2x1xf32, strided<[2,1,?]>
   auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
-  EXPECT_EQ(m15.getMaxContiguousTrailingDims(), 3);
+  EXPECT_EQ(m15.getNumContiguousTrailingDims(), 3);
 
   // memref<2x2x1xf32, strided<[4,2,2]>
   auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
-  EXPECT_EQ(m16.getMaxContiguousTrailingDims(), 1);
+  EXPECT_EQ(m16.getNumContiguousTrailingDims(), 1);
 
   // memref<2x1x2xf32, strided<[2,4,1]>
   auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
-  EXPECT_EQ(m17.getMaxContiguousTrailingDims(), 3);
+  EXPECT_EQ(m17.getNumContiguousTrailingDims(), 3);
 
   // memref<2x1x2xf32, strided<[2,?,1]>
   auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
-  EXPECT_EQ(m18.getMaxContiguousTrailingDims(), 3);
+  EXPECT_EQ(m18.getNumContiguousTrailingDims(), 3);
 }
 
 TEST(MemRefLayout, contigTrailingDim) {
@@ -196,14 +196,14 @@ TEST(MemRefLayout, identityMaps) {
 
   // memref<2x2x2xf32>
   auto m1 = MemRefType::get({2, 2, 2}, f32);
-  EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
+  EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
   EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
   EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
   EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
 
   // memref<?x?x?xf32>
   auto m2 = MemRefType::get({_, _, _}, f32);
-  EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 3);
+  EXPECT_EQ(m2.getNumContiguousTrailingDims(), 3);
   EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
   EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
   EXPECT_TRUE(m2.areTrailingDimsContiguous(3));

>From 3b88a09eb3e19548c4c979c11aa107cee52d712f Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 6 Jun 2025 14:51:30 +0000
Subject: [PATCH 06/11] [fixup] Fix a couple of FileCheck lines

---
 mlir/test/Dialect/Vector/vector-transfer-flatten.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 5b2f2ab1f2cef..854800fb98fe4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -212,7 +212,7 @@ func.func @transfer_read_dynamic_dim_to_flatten(
 // CHECK:              %[[C0_I32:.*]] = arith.constant 0 : i32
 // CHECK:              %[[C0:.*]] = arith.constant 0 : index
 // CHECK:              %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL:   [[0], [1, 2, 3]]
+// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
 // CHECK-SAME:           memref<1x?x4x6xi32> into memref<1x?xi32>
 // CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:              %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
@@ -490,7 +490,7 @@ func.func @transfer_write_dynamic_to_flatten(
 
 // CHECK:              %[[C0:.*]] = arith.constant 0 : index
 // CHECK:              %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL:   [[0], [1, 2, 3]]
+// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
 // CHECK-SAME:           : memref<1x?x4x6xi32> into memref<1x?xi32>
 // CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:              %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>

>From 1179bf2a341c3197b6c69e0a5b2995b94c40a7f2 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 9 Jun 2025 14:43:54 +0000
Subject: [PATCH 07/11] [fixup] Add some negative tests with dynamic dimensions

---
 .../Vector/vector-transfer-flatten.mlir       | 48 ++++++++++++++++++-
 1 file changed, 46 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 854800fb98fe4..2b012f1a97971 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -188,7 +188,29 @@ func.func @transfer_read_leading_dynamic_dims(
 
 // -----
 
-// One of the dims to be flattened is dynamic - not supported ATM.
+// The vector could be a non-contiguous slice of the input
+// memref.
+
+func.func @negative_transfer_read_dynamic_dim_to_flatten(
+    %mem : memref<4x?x?x2xi8>) -> vector<2x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+    memref<4x?x?x2xi8>, vector<2x2x2xi8>
+  return %res : vector<2x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
+
+// Can flatten the rightmost dynamic dimension
 
 func.func @transfer_read_dynamic_dim_to_flatten(
     %idx_1: index,
@@ -464,8 +486,28 @@ func.func @transfer_write_leading_dynamic_dims(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
+ 
+// The vector could be a non-contiguous slice of the input
+// memref.
+
+func.func @negative_transfer_write_dynamic_to_flatten(
+    %mem : memref<4x?x?x2xi8>,
+    %vec : vector<2x2x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write  %vec, %mem[%c0, %c0, %c0, %c0]
+    : vector<2x2x2xi8>, memref<4x?x?x2xi8>
+  return
+}
 
-// One of the dims to be flattened is dynamic - not supported ATM.
+// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
 
 func.func @transfer_write_dynamic_to_flatten(
     %idx_1: index,
@@ -578,6 +620,8 @@ func.func @negative_out_of_bound_transfer_read(
 
 // -----
 
+// Can flatten the rightmost dynamic dimension
+
 func.func @negative_out_of_bound_transfer_write(
     %mem : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {
   %c0 = arith.constant 0 : index

>From 20b495df83e2e6b02958c8f3b4f7ecb5772b8c8a Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 9 Jun 2025 15:45:29 +0000
Subject: [PATCH 08/11] [fixup] Move a test and spell out some auto types

---
 mlir/unittests/Dialect/MemRef/CMakeLists.txt       |  1 -
 mlir/unittests/IR/CMakeLists.txt                   |  1 +
 .../LayoutTest.cpp => IR/MemrefLayoutTest.cpp}     | 14 +++++++-------
 3 files changed, 8 insertions(+), 8 deletions(-)
 rename mlir/unittests/{Dialect/MemRef/LayoutTest.cpp => IR/MemrefLayoutTest.cpp} (96%)

diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index 1f6df1024f430..dede3ba0a885c 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_unittest(MLIRMemRefTests
   InferShapeTest.cpp
-  LayoutTest.cpp
 )
 mlir_target_link_libraries(MLIRMemRefTests
   PRIVATE
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 9ab6029c3480d..7e4e57124309f 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_unittest(MLIRIRTests
   IRMapping.cpp
   InterfaceAttachmentTest.cpp
   LocationTest.cpp
+  MemrefLayoutTest.cpp
   OperationSupportTest.cpp
   PatternMatchTest.cpp
   ShapedTypeTest.cpp
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/IR/MemrefLayoutTest.cpp
similarity index 96%
rename from mlir/unittests/Dialect/MemRef/LayoutTest.cpp
rename to mlir/unittests/IR/MemrefLayoutTest.cpp
index 0ba8927da69e9..6c9bb40da52b3 100644
--- a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
+++ b/mlir/unittests/IR/MemrefLayoutTest.cpp
@@ -1,4 +1,4 @@
-//===- LayoutTest.cpp - unit tests related to memref layout --------------===//
+//===- LayoutTest.cpp - unit tests related to memref layout ---------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -19,8 +19,8 @@ TEST(MemRefLayout, maxContigDim) {
   MLIRContext ctx;
   OpBuilder b(&ctx);
 
-  const auto _ = ShapedType::kDynamic;
-  const auto f32 = b.getF32Type();
+  const int64_t _ = ShapedType::kDynamic;
+  const FloatType f32 = b.getF32Type();
   auto strided = [&ctx](ArrayRef<int64_t> s) {
     return StridedLayoutAttr::get(&ctx, 0, s);
   };
@@ -102,8 +102,8 @@ TEST(MemRefLayout, contigTrailingDim) {
   MLIRContext ctx;
   OpBuilder b(&ctx);
 
-  const auto _ = ShapedType::kDynamic;
-  const auto f32 = b.getF32Type();
+  const int64_t _ = ShapedType::kDynamic;
+  const FloatType f32 = b.getF32Type();
   auto strided = [&ctx](ArrayRef<int64_t> s) {
     return StridedLayoutAttr::get(&ctx, 0, s);
   };
@@ -191,8 +191,8 @@ TEST(MemRefLayout, identityMaps) {
   MLIRContext ctx;
   OpBuilder b(&ctx);
 
-  const auto _ = ShapedType::kDynamic;
-  const auto f32 = b.getF32Type();
+  const int64_t _ = ShapedType::kDynamic;
+  const FloatType f32 = b.getF32Type();
 
   // memref<2x2x2xf32>
   auto m1 = MemRefType::get({2, 2, 2}, f32);

>From af235d94411fc10b96df520eba654ecefa2e58ec Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 18 Jun 2025 11:21:26 +0000
Subject: [PATCH 09/11] [fixup] Misc unimportant changes

---
 mlir/test/Dialect/Vector/vector-transfer-flatten.mlir | 4 +---
 mlir/unittests/IR/MemrefLayoutTest.cpp                | 2 +-
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 2b012f1a97971..eae26106862e7 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -188,7 +188,7 @@ func.func @transfer_read_leading_dynamic_dims(
 
 // -----
 
-// The vector could be a non-contiguous slice of the input
+// The vector is a non-contiguous slice of the input
 // memref.
 
 func.func @negative_transfer_read_dynamic_dim_to_flatten(
@@ -620,8 +620,6 @@ func.func @negative_out_of_bound_transfer_read(
 
 // -----
 
-// Can flatten the rightmost dynamic dimension
-
 func.func @negative_out_of_bound_transfer_write(
     %mem : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {
   %c0 = arith.constant 0 : index
diff --git a/mlir/unittests/IR/MemrefLayoutTest.cpp b/mlir/unittests/IR/MemrefLayoutTest.cpp
index 6c9bb40da52b3..d506b7b1a6687 100644
--- a/mlir/unittests/IR/MemrefLayoutTest.cpp
+++ b/mlir/unittests/IR/MemrefLayoutTest.cpp
@@ -15,7 +15,7 @@
 using namespace mlir;
 using namespace mlir::memref;
 
-TEST(MemRefLayout, maxContigDim) {
+TEST(MemRefLayout, numContigDim) {
   MLIRContext ctx;
   OpBuilder b(&ctx);
 

>From 145b055b8157443309f8e0879b6c7ba78a44c322 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 19 Jun 2025 10:20:53 +0000
Subject: [PATCH 10/11] [fixup] Address review comments

---
 .../Vector/vector-transfer-flatten.mlir       |  18 +-
 mlir/unittests/IR/MemrefLayoutTest.cpp        | 177 ++++++------------
 2 files changed, 74 insertions(+), 121 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index eae26106862e7..45873aa93153d 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -210,7 +210,11 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
 
 // -----
 
-// Can flatten the rightmost dynamic dimension
+// When collapsing memref dimensions, we may include the rightmost dynamic
+// dimension (e.g., at position `k`) provided that the strides for dimensions
+// `k+1`, `k+2`, etc., ensure contiguity in memory. The stride at position `k`
+// itself does not factor into this. (Here "strides" mean both explicit and
+// implied by identity map)
 
 func.func @transfer_read_dynamic_dim_to_flatten(
     %idx_1: index,
@@ -486,8 +490,8 @@ func.func @transfer_write_leading_dynamic_dims(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
- 
-// The vector could be a non-contiguous slice of the input
+
+// The vector is a non-contiguous slice of the input
 // memref.
 
 func.func @negative_transfer_write_dynamic_to_flatten(
@@ -509,7 +513,9 @@ func.func @negative_transfer_write_dynamic_to_flatten(
 
 // -----
 
-func.func @transfer_write_dynamic_to_flatten(
+// See the comment in front of @transfer_read_dynamic_dim_to_flatten.
+
+func.func @transfer_write_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %vec : vector<1x2x6xi32>,
@@ -524,7 +530,7 @@ func.func @transfer_write_dynamic_to_flatten(
 
 // CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
-// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
+// CHECK-LABEL: func.func @transfer_write_dynamic_dim_to_flatten
 // CHECK-SAME:    %[[IDX_1:arg0]]: index
 // CHECK-SAME:    %[[IDX_2:arg1]]: index
 // CHECK-SAME:    %[[VEC:arg2]]: vector<1x2x6xi32>
@@ -539,7 +545,7 @@ func.func @transfer_write_dynamic_to_flatten(
 // CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
 // CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
 
-// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @transfer_write_dynamic_dim_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
diff --git a/mlir/unittests/IR/MemrefLayoutTest.cpp b/mlir/unittests/IR/MemrefLayoutTest.cpp
index d506b7b1a6687..73005401b56fc 100644
--- a/mlir/unittests/IR/MemrefLayoutTest.cpp
+++ b/mlir/unittests/IR/MemrefLayoutTest.cpp
@@ -15,6 +15,9 @@
 using namespace mlir;
 using namespace mlir::memref;
 
+//
+// Test the correctness of `memref::getNumContiguousTrailingDims`
+//
 TEST(MemRefLayout, numContigDim) {
   MLIRContext ctx;
   OpBuilder b(&ctx);
@@ -25,79 +28,108 @@ TEST(MemRefLayout, numContigDim) {
     return StridedLayoutAttr::get(&ctx, 0, s);
   };
 
-  // memref<2x2x2xf32, strided<[4,2,1]>
+  // Create a sequence of test cases, starting with the base case of a
+  // contiguous 2x2x2 memref with fixed dimensions and then at each step
+  // introducing one dynamic dimension starting from the right.
+  // With thus obtained memref, start with maximally contiguous strides
+  // and then at each step gradually introduce discontinuity by increasing
+  // a fixed stride size from the left to right.
+
+  // In these and the following test cases the intent is to achieve code
+  // coverage of the main loop in `MemRefType::getNumContiguousTrailingDims()`.
+
+  // memref<2x2x2xf32, strided<[4,2,1]>>
   auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
   EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
 
-  // memref<2x2x2xf32, strided<[8,2,1]>
+  // memref<2x2x2xf32, strided<[8,2,1]>>
   auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
   EXPECT_EQ(m2.getNumContiguousTrailingDims(), 2);
 
-  // memref<2x2x2xf32, strided<[8,4,1]>
+  // memref<2x2x2xf32, strided<[8,4,1]>>
   auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
   EXPECT_EQ(m3.getNumContiguousTrailingDims(), 1);
 
-  // memref<2x2x2xf32, strided<[8,4,2]>
+  // memref<2x2x2xf32, strided<[8,4,2]>>
   auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
   EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
 
-  // memref<2x2x?xf32, strided<[?,?,1]>
+  // memref<2x2x?xf32, strided<[?,?,1]>>
   auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
   EXPECT_EQ(m5.getNumContiguousTrailingDims(), 1);
 
-  // memref<2x2x?xf32, strided<[?,?,2]>
+  // memref<2x2x?xf32, strided<[?,?,2]>>
   auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
   EXPECT_EQ(m6.getNumContiguousTrailingDims(), 0);
 
-  // memref<2x?x2xf32, strided<[?,2,1]>
+  // memref<2x?x2xf32, strided<[?,2,1]>>
   auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
   EXPECT_EQ(m7.getNumContiguousTrailingDims(), 2);
 
-  // memref<2x?x2xf32, strided<[?,4,1]>
+  // memref<2x?x2xf32, strided<[?,4,1]>>
   auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
   EXPECT_EQ(m8.getNumContiguousTrailingDims(), 1);
 
-  // memref<2x?x2xf32, strided<[?,4,2]>
+  // memref<2x?x2xf32, strided<[?,4,2]>>
   auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
   EXPECT_EQ(m9.getNumContiguousTrailingDims(), 0);
 
-  // memref<?x2x2xf32, strided<[4,2,1]>
+  // memref<?x2x2xf32, strided<[4,2,1]>>
   auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
   EXPECT_EQ(m10.getNumContiguousTrailingDims(), 3);
 
-  // memref<?x2x2xf32, strided<[8,2,1]>
+  // memref<?x2x2xf32, strided<[8,2,1]>>
   auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
   EXPECT_EQ(m11.getNumContiguousTrailingDims(), 2);
 
-  // memref<?x2x2xf32, strided<[8,4,1]>
+  // memref<?x2x2xf32, strided<[8,4,1]>>
   auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
   EXPECT_EQ(m12.getNumContiguousTrailingDims(), 1);
 
-  // memref<?x2x2xf32, strided<[8,4,2]>
+  // memref<?x2x2xf32, strided<[8,4,2]>>
   auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
   EXPECT_EQ(m13.getNumContiguousTrailingDims(), 0);
 
-  // memref<2x2x1xf32, strided<[2,1,2]>
+  //
+  // Repeat a similar process, but this time introduce a unit memref dimension
+  // to test that strides corresponding to unit dimensions are immaterial, even
+  // if dynamic.
+  //
+
+  // memref<2x2x1xf32, strided<[2,1,2]>>
   auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
   EXPECT_EQ(m14.getNumContiguousTrailingDims(), 3);
 
-  // memref<2x2x1xf32, strided<[2,1,?]>
+  // memref<2x2x1xf32, strided<[2,1,?]>>
   auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
   EXPECT_EQ(m15.getNumContiguousTrailingDims(), 3);
 
-  // memref<2x2x1xf32, strided<[4,2,2]>
+  // memref<2x2x1xf32, strided<[4,2,2]>>
   auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
   EXPECT_EQ(m16.getNumContiguousTrailingDims(), 1);
 
-  // memref<2x1x2xf32, strided<[2,4,1]>
+  // memref<2x1x2xf32, strided<[2,4,1]>>
   auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
   EXPECT_EQ(m17.getNumContiguousTrailingDims(), 3);
 
-  // memref<2x1x2xf32, strided<[2,?,1]>
+  // memref<2x1x2xf32, strided<[2,?,1]>>
   auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
   EXPECT_EQ(m18.getNumContiguousTrailingDims(), 3);
+
+  //
+  // Special case for identity maps and no explicit `strided` attribute - the
+  // memref is entirely contiguous even if the strides cannot be determined
+  // statically.
+  //
+
+  // memref<?x?x?xf32>
+  auto m19 = MemRefType::get({_, _, _}, f32);
+  EXPECT_EQ(m19.getNumContiguousTrailingDims(), 3);
 }
 
+//
+// Test the member function `memref::areTrailingDimsContiguous`
+//
 TEST(MemRefLayout, contigTrailingDim) {
   MLIRContext ctx;
   OpBuilder b(&ctx);
@@ -108,103 +140,18 @@ TEST(MemRefLayout, contigTrailingDim) {
     return StridedLayoutAttr::get(&ctx, 0, s);
   };
 
-  // memref<2x2x2xf32, strided<[4,2,1]>
-  auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
-  EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
-  EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
-  EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
-
-  // memref<2x2x2xf32, strided<[8,2,1]>
-  auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
-  EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
-  EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
-  EXPECT_FALSE(m2.areTrailingDimsContiguous(3));
-
-  // memref<2x2x2xf32, strided<[8,4,1]>
-  auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
-  EXPECT_TRUE(m3.areTrailingDimsContiguous(1));
-  EXPECT_FALSE(m3.areTrailingDimsContiguous(2));
-  EXPECT_FALSE(m3.areTrailingDimsContiguous(3));
-
-  // memref<2x2x2xf32, strided<[8,4,2]>
-  auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
-  EXPECT_FALSE(m4.areTrailingDimsContiguous(1));
-  EXPECT_FALSE(m4.areTrailingDimsContiguous(2));
-  EXPECT_FALSE(m4.areTrailingDimsContiguous(3));
-
-  // memref<2x2x?xf32, strided<[?,?,1]>
-  auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
-  EXPECT_TRUE(m5.areTrailingDimsContiguous(1));
-  EXPECT_FALSE(m5.areTrailingDimsContiguous(2));
-  EXPECT_FALSE(m5.areTrailingDimsContiguous(3));
-
-  // memref<2x2x?xf32, strided<[?,?,2]>
-  auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
-  EXPECT_FALSE(m6.areTrailingDimsContiguous(1));
-  EXPECT_FALSE(m6.areTrailingDimsContiguous(2));
-  EXPECT_FALSE(m6.areTrailingDimsContiguous(3));
-
-  // memref<2x?x2xf32, strided<[?,2,1]>
-  auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
-  EXPECT_TRUE(m7.areTrailingDimsContiguous(1));
-  EXPECT_TRUE(m7.areTrailingDimsContiguous(2));
-  EXPECT_FALSE(m7.areTrailingDimsContiguous(3));
-
-  // memref<2x?x2xf32, strided<[?,4,1]>
-  auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
-  EXPECT_TRUE(m8.areTrailingDimsContiguous(1));
-  EXPECT_FALSE(m8.areTrailingDimsContiguous(2));
-  EXPECT_FALSE(m8.areTrailingDimsContiguous(3));
-
-  // memref<2x?x2xf32, strided<[?,4,2]>
-  auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
-  EXPECT_FALSE(m9.areTrailingDimsContiguous(1));
-  EXPECT_FALSE(m9.areTrailingDimsContiguous(2));
-  EXPECT_FALSE(m9.areTrailingDimsContiguous(3));
-
-  // memref<?x2x2xf32, strided<[4,2,1]>
-  auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
-  EXPECT_TRUE(m10.areTrailingDimsContiguous(1));
-  EXPECT_TRUE(m10.areTrailingDimsContiguous(2));
-  EXPECT_TRUE(m10.areTrailingDimsContiguous(3));
-
-  // memref<?x2x2xf32, strided<[8,2,1]>
-  auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
-  EXPECT_TRUE(m11.areTrailingDimsContiguous(1));
-  EXPECT_TRUE(m11.areTrailingDimsContiguous(2));
-  EXPECT_FALSE(m11.areTrailingDimsContiguous(3));
-
-  // memref<?x2x2xf32, strided<[8,4,1]>
-  auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
-  EXPECT_TRUE(m12.areTrailingDimsContiguous(1));
-  EXPECT_FALSE(m12.areTrailingDimsContiguous(2));
-  EXPECT_FALSE(m12.areTrailingDimsContiguous(3));
+  // Pick up a random test case among the ones already present in the file and
+  // ensure `areTrailingDimsContiguous(k)` returns `true` up to the value
+  // returned by `getNumContiguousTrailingDims` and `false` from that point on
+  // up to the memref rank.
 
-  // memref<?x2x2xf32, strided<[8,4,2]>
-  auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
-  EXPECT_FALSE(m13.areTrailingDimsContiguous(1));
-  EXPECT_FALSE(m13.areTrailingDimsContiguous(2));
-  EXPECT_FALSE(m13.areTrailingDimsContiguous(3));
-}
-
-TEST(MemRefLayout, identityMaps) {
-  MLIRContext ctx;
-  OpBuilder b(&ctx);
+  // memref<2x?x2xf32, strided<[?,2,1]>>
+  auto m = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+  int64_t n = m.getNumContiguousTrailingDims();
+  for (int64_t i = 0; i <= n; ++i)
+    EXPECT_TRUE(m.areTrailingDimsContiguous(i));
 
-  const int64_t _ = ShapedType::kDynamic;
-  const FloatType f32 = b.getF32Type();
-
-  // memref<2x2x2xf32>
-  auto m1 = MemRefType::get({2, 2, 2}, f32);
-  EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
-  EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
-  EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
-  EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
-
-  // memref<?x?x?xf32>
-  auto m2 = MemRefType::get({_, _, _}, f32);
-  EXPECT_EQ(m2.getNumContiguousTrailingDims(), 3);
-  EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
-  EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
-  EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
+  int64_t r = m.getRank();
+  for (int64_t i = n + 1; i <= r; ++i)
+    EXPECT_FALSE(m.areTrailingDimsContiguous(i));
 }
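
The loop-based checks above encode a monotonicity property of the two
queries: `areTrailingDimsContiguous(k)` holds exactly for k up to the value
returned by `getNumContiguousTrailingDims()`. A minimal standalone sketch of
that invariant, assuming only the MemRefType members exercised by this patch
(the helper name is hypothetical):

  // Returns true iff the two contiguity queries agree for every suffix
  // length from 0 up to the rank of `m`.
  static bool hasMonotoneTrailingContiguity(MemRefType m) {
    int64_t n = m.getNumContiguousTrailingDims();
    for (int64_t k = 0; k <= m.getRank(); ++k)
      if (m.areTrailingDimsContiguous(k) != (k <= n))
        return false;
    return true;
  }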

>From b5425a0c6b6c8407525d88028d228932fc96796a Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 19 Jun 2025 13:43:11 +0000
Subject: [PATCH 11/11] [fixup] Reduce the number of test cases

---
 mlir/unittests/IR/MemrefLayoutTest.cpp | 136 ++++++++-----------------
 1 file changed, 45 insertions(+), 91 deletions(-)

diff --git a/mlir/unittests/IR/MemrefLayoutTest.cpp b/mlir/unittests/IR/MemrefLayoutTest.cpp
index 73005401b56fc..f243a76ee660c 100644
--- a/mlir/unittests/IR/MemrefLayoutTest.cpp
+++ b/mlir/unittests/IR/MemrefLayoutTest.cpp
@@ -28,103 +28,60 @@ TEST(MemRefLayout, numContigDim) {
     return StridedLayoutAttr::get(&ctx, 0, s);
   };
 
-  // Create a sequence of test cases, starting with the base case of a
-  // contiguous 2x2x2 memref with fixed dimensions and then at each step
-  // introducing one dynamic dimension starting from the right.
-  // With thus obtained memref, start with maximally contiguous strides
-  // and then at each step gradually introduce discontinuity by increasing
-  // a fixed stride size from the left to right.
-
-  // In these and the following test cases the intent is to achieve code
-  // coverage of the main loop in `MemRefType::getNumContiguousTrailingDims()`.
-
-  // memref<2x2x2xf32, strided<[4,2,1]>>
-  auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
-  EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
-
-  // memref<2x2x2xf32, strided<[8,2,1]>>
-  auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
-  EXPECT_EQ(m2.getNumContiguousTrailingDims(), 2);
-
-  // memref<2x2x2xf32, strided<[8,4,1]>>
-  auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
-  EXPECT_EQ(m3.getNumContiguousTrailingDims(), 1);
+  // Special case for identity maps and no explicit `strided` attribute - the
+  // memref is entirely contiguous even if the strides cannot be determined
+  // statically.
 
-  // memref<2x2x2xf32, strided<[8,4,2]>>
-  auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
-  EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
+  // memref<?x?x?xf32>
+  auto m0 = MemRefType::get({_, _, _}, f32);
+  EXPECT_EQ(m0.getNumContiguousTrailingDims(), 3);
 
-  // memref<2x2x?xf32, strided<[?,?,1]>>
-  auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
-  EXPECT_EQ(m5.getNumContiguousTrailingDims(), 1);
+  // Conservatively assume the memref is discontiguous everywhere if the
+  // strides cannot be obtained.
 
-  // memref<2x2x?xf32, strided<[?,?,2]>>
-  auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
-  EXPECT_EQ(m6.getNumContiguousTrailingDims(), 0);
+  // memref<2x2x2xf32, (i,j,k)->(i,k,j)>
+  auto m1 = MemRefType::get(
+      {2, 2, 2}, f32,
+      AffineMap::getPermutationMap(ArrayRef<int64_t>{0, 2, 1}, &ctx));
+  EXPECT_EQ(m1.getNumContiguousTrailingDims(), 0);
 
-  // memref<2x?x2xf32, strided<[?,2,1]>>
-  auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
-  EXPECT_EQ(m7.getNumContiguousTrailingDims(), 2);
+  // A base case of a fixed memref with the canonical contiguous strides.
 
-  // memref<2x?x2xf32, strided<[?,4,1]>>
-  auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
-  EXPECT_EQ(m8.getNumContiguousTrailingDims(), 1);
+  // memref<2x2x2xf32, strided<[4, 2, 1]>>
+  auto m3 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m3.getNumContiguousTrailingDims(), 3);
 
-  // memref<2x?x2xf32, strided<[?,4,2]>>
-  auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
-  EXPECT_EQ(m9.getNumContiguousTrailingDims(), 0);
+  // A fixed memref with a discontinuity in the rightmost dimension.
 
-  // memref<?x2x2xf32, strided<[4,2,1]>>
-  auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
-  EXPECT_EQ(m10.getNumContiguousTrailingDims(), 3);
-
-  // memref<?x2x2xf32, strided<[8,2,1]>>
-  auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
-  EXPECT_EQ(m11.getNumContiguousTrailingDims(), 2);
+  // memref<2x2x2xf32, strided<[8, 4, 2]>>
+  auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
 
-  // memref<?x2x2xf32, strided<[8,4,1]>>
-  auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
-  EXPECT_EQ(m12.getNumContiguousTrailingDims(), 1);
+  // A fixed memref with a discontinuity at the outermost dimension.
 
-  // memref<?x2x2xf32, strided<[8,4,2]>>
-  auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
-  EXPECT_EQ(m13.getNumContiguousTrailingDims(), 0);
+  // memref<2x2x2xf32, strided<[8, 2, 1]>>
+  auto m5 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_EQ(m5.getNumContiguousTrailingDims(), 2);
 
-  //
-  // Repeat a similar process, but this time introduce a unit memref dimension
-  // to test that strides corresponding to unit dimensions are immaterial, even
-  // if dynamic.
-  //
+  // A dynamic memref where the dynamic dimension size limits contiguity.
 
-  // memref<2x2x1xf32, strided<[2,1,2]>>
-  auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
-  EXPECT_EQ(m14.getNumContiguousTrailingDims(), 3);
+  // memref<2x?x2xf32, strided<[4, 2, 1]>>
+  auto m6 = MemRefType::get({2, _, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m6.getNumContiguousTrailingDims(), 2);
 
-  // memref<2x2x1xf32, strided<[2,1,?]>>
-  auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
-  EXPECT_EQ(m15.getNumContiguousTrailingDims(), 3);
+  // An edge case of a dynamic memref where the dynamic dimension is the
+  // outermost one.
 
-  // memref<2x2x1xf32, strided<[4,2,2]>>
-  auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
-  EXPECT_EQ(m16.getNumContiguousTrailingDims(), 1);
+  // memref<?x2x2xf32, strided<[4, 2, 1]>>
+  auto m7 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m7.getNumContiguousTrailingDims(), 3);
 
-  // memref<2x1x2xf32, strided<[2,4,1]>>
-  auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
-  EXPECT_EQ(m17.getNumContiguousTrailingDims(), 3);
+  // A memref with a unit dimension. Unit dimensions do not affect contiguity,
+  // even if the corresponding stride is dynamic.
 
   // memref<2x1x2xf32, strided<[2,?,1]>>
-  auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
-  EXPECT_EQ(m18.getNumContiguousTrailingDims(), 3);
-
-  //
-  // Special case for identity maps and no explicit `strided` attribute - the
-  // memref is entirely contiguous even if the strides cannot be determined
-  // statically.
-  //
-
-  // memref<?x?x?xf32>
-  auto m19 = MemRefType::get({_, _, _}, f32);
-  EXPECT_EQ(m19.getNumContiguousTrailingDims(), 3);
+  auto m8 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
+  EXPECT_EQ(m8.getNumContiguousTrailingDims(), 3);
 }
 
 //
@@ -140,18 +97,15 @@ TEST(MemRefLayout, contigTrailingDim) {
     return StridedLayoutAttr::get(&ctx, 0, s);
   };
 
-  // Pick up a random test case among the ones already present in the file and
-  // ensure `areTrailingDimsContiguous(k)` returns `true` up to the value
-  // returned by `getNumContiguousTrailingDims` and `false` from that point on
-  // up to the memref rank.
+  // A memref that is neither entirely contiguous nor entirely discontiguous.
+  // Ensure `areTrailingDimsContiguous` returns `true` for the value returned
+  // by `getNumContiguousTrailingDims` and `false` for the next larger
+  // value.
 
   // memref<2x?x2xf32, strided<[?,2,1]>>
   auto m = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
   int64_t n = m.getNumContiguousTrailingDims();
-  for (int64_t i = 0; i <= n; ++i)
-    EXPECT_TRUE(m.areTrailingDimsContiguous(i));
-
-  int64_t r = m.getRank();
-  for (int64_t i = n + 1; i <= r; ++i)
-    EXPECT_FALSE(m.areTrailingDimsContiguous(i));
+  EXPECT_TRUE(m.areTrailingDimsContiguous(n));
+  ASSERT_TRUE(n + 1 <= m.getRank());
+  EXPECT_FALSE(m.areTrailingDimsContiguous(n + 1));
 }
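
As a usage illustration, a client transform that flattens trailing
dimensions could clamp its requested rank to what the layout provably
allows; this is a sketch under that assumption, the helper name is
hypothetical, and only the `getNumContiguousTrailingDims` query from this
patch is used (`std::min` comes from <algorithm>):

  // Hypothetical helper: never collapse more trailing dimensions than are
  // statically known to be contiguous.
  static int64_t clampToContiguousSuffix(MemRefType memref, int64_t requested) {
    return std::min<int64_t>(requested, memref.getNumContiguousTrailingDims());
  }

For example, on memref<2x?x2xf32, strided<[?, 2, 1]>> (the type used in the
test above, which has two contiguous trailing dimensions) a request to
collapse all three dimensions would be clamped to two.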


