[Mlir-commits] [mlir] [MLIR] Determine contiguousness of memrefs with dynamic dimensions (PR #142421)
Momchil Velikov
llvmlistbot at llvm.org
Thu Jun 19 06:46:35 PDT 2025
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142421
>From 3f5355acd26929528134890ddab1854af59e2df6 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 2 Jun 2025 13:20:14 +0000
Subject: [PATCH 01/11] [MLIR] Determine contiguousness of memrefs with
dynamic dimensions
This patch enhances `MemRefType::areTrailingDimsContiguous`
to also handle memrefs with dynamic dimensions.
The implementation itself is based on a new member function
`MemRefType::getMaxCollapsableTrailingDims` that returns
the maximum number of trailing dimensions that can be
collapsed - trivially all dimensions for memrefs with identity
layout, or otherwise found by examining the memref strides, stopping
at discontiguous or statically unknown strides.
---
mlir/include/mlir/IR/BuiltinTypes.td | 14 ++
mlir/lib/IR/BuiltinTypes.cpp | 47 +++--
.../Vector/vector-transfer-flatten.mlir | 49 ++++-
mlir/unittests/Dialect/MemRef/CMakeLists.txt | 1 +
mlir/unittests/Dialect/MemRef/LayoutTest.cpp | 190 ++++++++++++++++++
5 files changed, 269 insertions(+), 32 deletions(-)
create mode 100644 mlir/unittests/Dialect/MemRef/LayoutTest.cpp
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..1d12f70882176 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -838,6 +838,20 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
///
bool areTrailingDimsContiguous(int64_t n);
+ /// Return the maximum number of trailing dimensions that can be
+ /// collapsed.
+ ///
+ /// Examples:
+ /// - memref<2x3x2xi8, strided<[24, 12, 2]>, the number of collapsable
+ /// trailing dimensions is 0
+ /// - memref<2x3x2xi8, strided<[12, 6, 1]>, the number of collapsable
+ /// trailing dimensions is 3
+ /// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>, the number of
+ /// collapsable trailing dimensions is 2.
+ /// - memref<5x4x?x2xi8>, the number of collapsable trailing dimensions
+ /// is 4.
+ int64_t getMaxCollapsableTrailingDims();
+
/// Return a version of this type with identity layout if it can be
/// determined statically that the layout is the canonical contiguous
/// strided layout. Otherwise pass the layout into `simplifyAffineMap`
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d47e360e9dc13..cc23d08515ff3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,35 +646,40 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
}
bool MemRefType::areTrailingDimsContiguous(int64_t n) {
- if (!isLastDimUnitStride())
- return false;
+ return getLayout().isIdentity() ||
+ getMaxCollapsableTrailingDims() >= std::min(n, getRank());
+}
- auto memrefShape = getShape().take_back(n);
- if (ShapedType::isDynamicShape(memrefShape))
- return false;
+int64_t MemRefType::getMaxCollapsableTrailingDims() {
+ const int64_t n = getRank();
+ // memrefs with identity layout are entirely contiguous.
if (getLayout().isIdentity())
- return true;
+ return n;
+ // Get the strides (if any). Failing to do that, conservatively assume a
+ // non-contiguous layout.
int64_t offset;
- SmallVector<int64_t> stridesFull;
- if (!succeeded(getStridesAndOffset(stridesFull, offset)))
- return false;
- auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
-
- if (strides.empty())
- return true;
+ SmallVector<int64_t> strides;
+ if (!succeeded(getStridesAndOffset(strides, offset)))
+ return 0;
- // Check whether strides match "flattened" dims.
- SmallVector<int64_t> flattenedDims;
- auto dimProduct = 1;
- for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
- dimProduct *= dim;
- flattenedDims.push_back(dimProduct);
+ auto shape = getShape();
+
+ // A memref with dimensions `d0, d1, ..., dn-1` and strides
+ // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
+ // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
+ // for `i` in `[k, n-1]`.
+ int64_t dimProduct = 1;
+ for (int64_t i = n - 1; i >= 0; --i) {
+ if (strides[i] != dimProduct)
+ return n - i - 1;
+ if (shape[i] == ShapedType::kDynamic)
+ return n - i;
+ dimProduct *= shape[i];
}
- strides = strides.drop_back(1);
- return llvm::equal(strides, llvm::reverse(flattenedDims));
+ return n;
}
MemRefType MemRefType::canonicalizeStridedLayout() {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index e840dc6bbf224..5b2f2ab1f2cef 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -190,7 +190,7 @@ func.func @transfer_read_leading_dynamic_dims(
// One of the dims to be flattened is dynamic - not supported ATM.
-func.func @negative_transfer_read_dynamic_dim_to_flatten(
+func.func @transfer_read_dynamic_dim_to_flatten(
%idx_1: index,
%idx_2: index,
%mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -203,11 +203,25 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
return %res : vector<1x2x6xi32>
}
-// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]
+// CHECK-SAME: %[[IDX_2:arg1]]
+// CHECK-SAME: %[[MEM:arg2]]
+// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
+// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
+
+
+// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----
@@ -453,7 +467,7 @@ func.func @transfer_write_leading_dynamic_dims(
// One of the dims to be flattened is dynamic - not supported ATM.
-func.func @negative_transfer_write_dynamic_to_flatten(
+func.func @transfer_write_dynamic_to_flatten(
%idx_1: index,
%idx_2: index,
%vec : vector<1x2x6xi32>,
@@ -466,11 +480,24 @@ func.func @negative_transfer_write_dynamic_to_flatten(
return
}
-// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]: index
+// CHECK-SAME: %[[IDX_2:arg1]]: index
+// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
+// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
-// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index dede3ba0a885c..1f6df1024f430 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRMemRefTests
InferShapeTest.cpp
+ LayoutTest.cpp
)
mlir_target_link_libraries(MLIRMemRefTests
PRIVATE
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
new file mode 100644
index 0000000000000..e01c0056d5cec
--- /dev/null
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -0,0 +1,190 @@
+//===- LayoutTest.cpp - unit tests related to memref layout --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+TEST(MemRefLayout, maxCollapseDim) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+ auto strided = [&ctx](ArrayRef<int64_t> s) {
+ return StridedLayoutAttr::get(&ctx, 0, s);
+ };
+
+ // memref<2x2x2xf32, strided<[4,2,1]>
+ auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+
+ // memref<2x2x2xf32, strided<[8,2,1]>
+ auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 2);
+
+ // memref<2x2x2xf32, strided<[8,4,1]>
+ auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_EQ(m3.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<2x2x2xf32, strided<[8,4,2]>
+ auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_EQ(m4.getMaxCollapsableTrailingDims(), 0);
+
+ // memref<2x2x?xf32, strided<[?,?,1]>
+ auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+ EXPECT_EQ(m5.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<2x2x?xf32, strided<[?,?,2]>
+ auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+ EXPECT_EQ(m6.getMaxCollapsableTrailingDims(), 0);
+
+ // memref<2x?x2xf32, strided<[?,2,1]>
+ auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+ EXPECT_EQ(m7.getMaxCollapsableTrailingDims(), 2);
+
+ // memref<2x?x2xf32, strided<[?,4,1]>
+ auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+ EXPECT_EQ(m8.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<2x?x2xf32, strided<[?,4,2]>
+ auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+ EXPECT_EQ(m9.getMaxCollapsableTrailingDims(), 0);
+
+ // memref<?x2x2xf32, strided<[4,2,1]>
+ auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_EQ(m10.getMaxCollapsableTrailingDims(), 3);
+
+ // memref<?x2x2xf32, strided<[8,2,1]>
+ auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_EQ(m11.getMaxCollapsableTrailingDims(), 2);
+
+ // memref<?x2x2xf32, strided<[8,4,1]>
+ auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_EQ(m12.getMaxCollapsableTrailingDims(), 1);
+
+ // memref<?x2x2xf32, strided<[8,4,2]>
+ auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_EQ(m13.getMaxCollapsableTrailingDims(), 0);
+}
+
+TEST(MemRefLayout, contigTrailingDim) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+ auto strided = [&ctx](ArrayRef<int64_t> s) {
+ return StridedLayoutAttr::get(&ctx, 0, s);
+ };
+
+ // memref<2x2x2xf32, strided<[4,2,1]>
+ auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+ // memref<2x2x2xf32, strided<[8,2,1]>
+ auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m2.areTrailingDimsContiguous(3));
+
+ // memref<2x2x2xf32, strided<[8,4,1]>
+ auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_TRUE(m3.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m3.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m3.areTrailingDimsContiguous(3));
+
+ // memref<2x2x2xf32, strided<[8,4,2]>
+ auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(3));
+
+ // memref<2x2x?xf32, strided<[?,?,1]>
+ auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+ EXPECT_TRUE(m5.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m5.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m5.areTrailingDimsContiguous(3));
+
+ // memref<2x2x?xf32, strided<[?,?,2]>
+ auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(3));
+
+ // memref<2x?x2xf32, strided<[?,2,1]>
+ auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+ EXPECT_TRUE(m7.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m7.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m7.areTrailingDimsContiguous(3));
+
+ // memref<2x?x2xf32, strided<[?,4,1]>
+ auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+ EXPECT_TRUE(m8.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m8.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m8.areTrailingDimsContiguous(3));
+
+ // memref<2x?x2xf32, strided<[?,4,2]>
+ auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[4,2,1]>
+ auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[8,2,1]>
+ auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_TRUE(m11.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m11.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m11.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[8,4,1]>
+ auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_TRUE(m12.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m12.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m12.areTrailingDimsContiguous(3));
+
+ // memref<?x2x2xf32, strided<[8,4,2]>
+ auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(3));
+}
+
+TEST(MemRefLayout, identityMaps) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+
+ // memref<2x2x2xf32>
+ auto m1 = MemRefType::get({2, 2, 2}, f32);
+ EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+ // memref<?x?x?xf32>
+ auto m2 = MemRefType::get({_, _, _}, f32);
+ EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
+}
>From a43a9dc14d2c1adf6d92faa71ae85cc1c3806122 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 3 Jun 2025 16:09:05 +0000
Subject: [PATCH 02/11] [fixup] Fix an assertion
`computeStrides` does not access the first element of `sizes`, so only the remaining elements need to be non-negative.
---
mlir/include/mlir/Dialect/Utils/IndexingUtils.h | 2 +-
mlir/lib/Dialect/Utils/IndexingUtils.cpp | 3 ++-
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 99218f491ddef..8524072929793 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -40,7 +40,7 @@ class ArrayAttr;
/// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
///
-/// `sizes` elements are asserted to be non-negative.
+/// `sizes` elements `s1` to `sn` are asserted to be non-negative.
///
/// Return an empty vector if `sizes` is empty.
SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 8de77e2c3cb08..3efe0edeaeb04 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -69,7 +69,8 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
//===----------------------------------------------------------------------===//
SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
- assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
+ assert((sizes.size() == 0 ||
+ llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
"sizes must be nonnegative");
int64_t unit = 1;
return ::computeSuffixProductImpl(sizes, unit);
>From 047b99f08a00263d4082229d661dae84c28aac7b Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 4 Jun 2025 10:37:37 +0000
Subject: [PATCH 03/11] [fixup] Misc NFC changes
- rename `getMaxCollapsableTrailingDims` to `getMaxContiguousTrailingDims`
- new set of examples
- remove redundant call to `isIdentity()`
- make sure a variable type is visible on the declaration line
- micro-optimisation: skip the stride-product update for unit dimensions
---
mlir/include/mlir/IR/BuiltinTypes.td | 26 ++++++++++-------
mlir/lib/IR/BuiltinTypes.cpp | 9 +++---
mlir/unittests/Dialect/MemRef/LayoutTest.cpp | 30 ++++++++++----------
3 files changed, 36 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1d12f70882176..dab6465e8bcf9 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -838,19 +838,25 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
///
bool areTrailingDimsContiguous(int64_t n);
- /// Return the maximum number of trailing dimensions that can be
- /// collapsed.
+ /// Return the maximum number of trailing dimensions that are
+ /// contiguous.
///
/// Examples:
- /// - memref<2x3x2xi8, strided<[24, 12, 2]>, the number of collapsable
- /// trailing dimensions is 0
- /// - memref<2x3x2xi8, strided<[12, 6, 1]>, the number of collapsable
+ /// - memref<5x3x2xi8, strided<[6,2,1]>>, the number of collapsable
/// trailing dimensions is 3
- /// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>, the number of
- /// collapsable trailing dimensions is 2.
- /// - memref<5x4x?x2xi8>, the number of collapsable trailing dimensions
- /// is 4.
- int64_t getMaxCollapsableTrailingDims();
+ /// - memref<5x3x2xi8, strided<[12,2,1]>>, the number of collapsable
+ /// trailing dimensions is 2 (dimension 0 is non-contiguous)
+ /// - memref<5x3x2xi8, strided<[12,4,1]>>, the number of collapsable
+ /// trailing dimensions is 1 (dimension 1 is non-contiguous)
+ /// - memref<5x3x2xi8, strided<[12,4,2]>>, the number of collapsable
+ /// trailing dimensions is 0 (dimension 2 is non-contiguous)
+ /// - memref<?x3x2xi8, strided<[6,2,1]>>, the number of collapsable
+ /// trailing dimensions is 3
+ /// - memref<?x3x2xi8, strided<[12,2,1]>>, the number of collapsable
+ /// trailing dimensions is 2 (dimension 0 is non-contiguous)
+ /// - memref<5x?x2xi8, strided<[?,2,1]>>, the number of collapsable
+ /// trailing dimensions is 2 (stride 0 is dynamic)
+ int64_t getMaxContiguousTrailingDims();
/// Return a version of this type with identity layout if it can be
/// determined statically that the layout is the canonical contiguous
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index cc23d08515ff3..f839576f36969 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,11 +646,10 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
}
bool MemRefType::areTrailingDimsContiguous(int64_t n) {
- return getLayout().isIdentity() ||
- getMaxCollapsableTrailingDims() >= std::min(n, getRank());
+ return getMaxContiguousTrailingDims() >= std::min(n, getRank());
}
-int64_t MemRefType::getMaxCollapsableTrailingDims() {
+int64_t MemRefType::getMaxContiguousTrailingDims() {
const int64_t n = getRank();
// memrefs with identity layout are entirely contiguous.
@@ -664,7 +663,7 @@ int64_t MemRefType::getMaxCollapsableTrailingDims() {
if (!succeeded(getStridesAndOffset(strides, offset)))
return 0;
- auto shape = getShape();
+ ArrayRef<int64_t> shape = getShape();
// A memref with dimensions `d0, d1, ..., dn-1` and strides
// `s0, s1, ..., sn-1` is contiguous up to dimension `k`
@@ -674,6 +673,8 @@ int64_t MemRefType::getMaxCollapsableTrailingDims() {
for (int64_t i = n - 1; i >= 0; --i) {
if (strides[i] != dimProduct)
return n - i - 1;
+ if (shape[i] == 1)
+ continue;
if (shape[i] == ShapedType::kDynamic)
return n - i;
dimProduct *= shape[i];
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
index e01c0056d5cec..631cac79e6999 100644
--- a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -27,55 +27,55 @@ TEST(MemRefLayout, maxCollapseDim) {
// memref<2x2x2xf32, strided<[4,2,1]>
auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
- EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
// memref<2x2x2xf32, strided<[8,2,1]>
auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
- EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 2);
+ EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 2);
// memref<2x2x2xf32, strided<[8,4,1]>
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
- EXPECT_EQ(m3.getMaxCollapsableTrailingDims(), 1);
+ EXPECT_EQ(m3.getMaxContiguousTrailingDims(), 1);
// memref<2x2x2xf32, strided<[8,4,2]>
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
- EXPECT_EQ(m4.getMaxCollapsableTrailingDims(), 0);
+ EXPECT_EQ(m4.getMaxContiguousTrailingDims(), 0);
// memref<2x2x?xf32, strided<[?,?,1]>
auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
- EXPECT_EQ(m5.getMaxCollapsableTrailingDims(), 1);
+ EXPECT_EQ(m5.getMaxContiguousTrailingDims(), 1);
// memref<2x2x?xf32, strided<[?,?,2]>
auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
- EXPECT_EQ(m6.getMaxCollapsableTrailingDims(), 0);
+ EXPECT_EQ(m6.getMaxContiguousTrailingDims(), 0);
// memref<2x?x2xf32, strided<[?,2,1]>
auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
- EXPECT_EQ(m7.getMaxCollapsableTrailingDims(), 2);
+ EXPECT_EQ(m7.getMaxContiguousTrailingDims(), 2);
// memref<2x?x2xf32, strided<[?,4,1]>
auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
- EXPECT_EQ(m8.getMaxCollapsableTrailingDims(), 1);
+ EXPECT_EQ(m8.getMaxContiguousTrailingDims(), 1);
// memref<2x?x2xf32, strided<[?,4,2]>
auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
- EXPECT_EQ(m9.getMaxCollapsableTrailingDims(), 0);
+ EXPECT_EQ(m9.getMaxContiguousTrailingDims(), 0);
// memref<?x2x2xf32, strided<[4,2,1]>
auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
- EXPECT_EQ(m10.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_EQ(m10.getMaxContiguousTrailingDims(), 3);
// memref<?x2x2xf32, strided<[8,2,1]>
auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
- EXPECT_EQ(m11.getMaxCollapsableTrailingDims(), 2);
+ EXPECT_EQ(m11.getMaxContiguousTrailingDims(), 2);
// memref<?x2x2xf32, strided<[8,4,1]>
auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
- EXPECT_EQ(m12.getMaxCollapsableTrailingDims(), 1);
+ EXPECT_EQ(m12.getMaxContiguousTrailingDims(), 1);
// memref<?x2x2xf32, strided<[8,4,2]>
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
- EXPECT_EQ(m13.getMaxCollapsableTrailingDims(), 0);
+ EXPECT_EQ(m13.getMaxContiguousTrailingDims(), 0);
}
TEST(MemRefLayout, contigTrailingDim) {
@@ -176,14 +176,14 @@ TEST(MemRefLayout, identityMaps) {
// memref<2x2x2xf32>
auto m1 = MemRefType::get({2, 2, 2}, f32);
- EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
// memref<?x?x?xf32>
auto m2 = MemRefType::get({_, _, _}, f32);
- EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 3);
EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
>From dfe4ac9ca82f44c7227a103fd87b3c77f764a306 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 5 Jun 2025 11:02:57 +0000
Subject: [PATCH 04/11] [fixup] Handle unit dimensions by ignoring the
corresponding stride
---
mlir/lib/IR/BuiltinTypes.cpp | 6 ++++--
mlir/unittests/Dialect/MemRef/LayoutTest.cpp | 22 +++++++++++++++++++-
2 files changed, 25 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index f839576f36969..9a8738a18c0d6 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -669,12 +669,14 @@ int64_t MemRefType::getMaxContiguousTrailingDims() {
// `s0, s1, ..., sn-1` is contiguous up to dimension `k`
// if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
// for `i` in `[k, n-1]`.
+ // Ignore stride elements if the corresponding dimension is 1, as they are
+ // of no consequence.
int64_t dimProduct = 1;
for (int64_t i = n - 1; i >= 0; --i) {
- if (strides[i] != dimProduct)
- return n - i - 1;
if (shape[i] == 1)
continue;
+ if (strides[i] != dimProduct)
+ return n - i - 1;
if (shape[i] == ShapedType::kDynamic)
return n - i;
dimProduct *= shape[i];
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
index 631cac79e6999..2aa8029ef01f7 100644
--- a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -15,7 +15,7 @@
using namespace mlir;
using namespace mlir::memref;
-TEST(MemRefLayout, maxCollapseDim) {
+TEST(MemRefLayout, maxContigDim) {
MLIRContext ctx;
OpBuilder b(&ctx);
@@ -76,6 +76,26 @@ TEST(MemRefLayout, maxCollapseDim) {
// memref<?x2x2xf32, strided<[8,4,2]>
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
EXPECT_EQ(m13.getMaxContiguousTrailingDims(), 0);
+
+ // memref<2x2x1xf32, strided<[2,1,2]>
+ auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
+ EXPECT_EQ(m14.getMaxContiguousTrailingDims(), 3);
+
+ // memref<2x2x1xf32, strided<[2,1,?]>
+ auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
+ EXPECT_EQ(m15.getMaxContiguousTrailingDims(), 3);
+
+ // memref<2x2x1xf32, strided<[4,2,2]>
+ auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
+ EXPECT_EQ(m16.getMaxContiguousTrailingDims(), 1);
+
+ // memref<2x1x2xf32, strided<[2,4,1]>
+ auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
+ EXPECT_EQ(m17.getMaxContiguousTrailingDims(), 3);
+
+ // memref<2x1x2xf32, strided<[2,?,1]>
+ auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
+ EXPECT_EQ(m18.getMaxContiguousTrailingDims(), 3);
}
TEST(MemRefLayout, contigTrailingDim) {
>From 92a505154a852e67f736decc3e3dc30ab0cb1983 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 5 Jun 2025 16:25:15 +0000
Subject: [PATCH 05/11] [fixup] Address misc review comments
---
mlir/include/mlir/IR/BuiltinTypes.td | 5 +--
mlir/lib/Dialect/Utils/IndexingUtils.cpp | 2 +-
mlir/lib/IR/BuiltinTypes.cpp | 6 ++-
mlir/unittests/Dialect/MemRef/LayoutTest.cpp | 40 ++++++++++----------
4 files changed, 27 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index dab6465e8bcf9..3719f8e895967 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -838,8 +838,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
///
bool areTrailingDimsContiguous(int64_t n);
- /// Return the maximum number of trailing dimensions that are
- /// contiguous.
+ /// Return the number of trailing dimensions that are contiguous.
///
/// Examples:
/// - memref<5x3x2xi8, strided<[6,2,1]>>, the number of collapsable
@@ -856,7 +855,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
/// trailing dimensions is 2 (dimension 0 is non-contiguous)
/// - memref<5x?x2xi8, strided<[?,2,1]>>, the number of collapsable
/// trailing dimensions is 2 (stride 0 is dynamic)
- int64_t getMaxContiguousTrailingDims();
+ int64_t getNumContiguousTrailingDims();
/// Return a version of this type with identity layout if it can be
/// determined statically that the layout is the canonical contiguous
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 3efe0edeaeb04..e1648ab99ff25 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -69,7 +69,7 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
//===----------------------------------------------------------------------===//
SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
- assert((sizes.size() == 0 ||
+ assert((sizes.empty() ||
llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
"sizes must be nonnegative");
int64_t unit = 1;
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9a8738a18c0d6..08fac617824aa 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,10 +646,12 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
}
bool MemRefType::areTrailingDimsContiguous(int64_t n) {
- return getMaxContiguousTrailingDims() >= std::min(n, getRank());
+ assert(n <= getRank() &&
+ "number of dimensions to check must not exceed rank");
+ return n <= getNumContiguousTrailingDims();
}
-int64_t MemRefType::getMaxContiguousTrailingDims() {
+int64_t MemRefType::getNumContiguousTrailingDims() {
const int64_t n = getRank();
// memrefs with identity layout are entirely contiguous.
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
index 2aa8029ef01f7..0ba8927da69e9 100644
--- a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -27,75 +27,75 @@ TEST(MemRefLayout, maxContigDim) {
// memref<2x2x2xf32, strided<[4,2,1]>
auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
- EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
+ EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
// memref<2x2x2xf32, strided<[8,2,1]>
auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
- EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 2);
+ EXPECT_EQ(m2.getNumContiguousTrailingDims(), 2);
// memref<2x2x2xf32, strided<[8,4,1]>
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
- EXPECT_EQ(m3.getMaxContiguousTrailingDims(), 1);
+ EXPECT_EQ(m3.getNumContiguousTrailingDims(), 1);
// memref<2x2x2xf32, strided<[8,4,2]>
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
- EXPECT_EQ(m4.getMaxContiguousTrailingDims(), 0);
+ EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
// memref<2x2x?xf32, strided<[?,?,1]>
auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
- EXPECT_EQ(m5.getMaxContiguousTrailingDims(), 1);
+ EXPECT_EQ(m5.getNumContiguousTrailingDims(), 1);
// memref<2x2x?xf32, strided<[?,?,2]>
auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
- EXPECT_EQ(m6.getMaxContiguousTrailingDims(), 0);
+ EXPECT_EQ(m6.getNumContiguousTrailingDims(), 0);
// memref<2x?x2xf32, strided<[?,2,1]>
auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
- EXPECT_EQ(m7.getMaxContiguousTrailingDims(), 2);
+ EXPECT_EQ(m7.getNumContiguousTrailingDims(), 2);
// memref<2x?x2xf32, strided<[?,4,1]>
auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
- EXPECT_EQ(m8.getMaxContiguousTrailingDims(), 1);
+ EXPECT_EQ(m8.getNumContiguousTrailingDims(), 1);
// memref<2x?x2xf32, strided<[?,4,2]>
auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
- EXPECT_EQ(m9.getMaxContiguousTrailingDims(), 0);
+ EXPECT_EQ(m9.getNumContiguousTrailingDims(), 0);
// memref<?x2x2xf32, strided<[4,2,1]>
auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
- EXPECT_EQ(m10.getMaxContiguousTrailingDims(), 3);
+ EXPECT_EQ(m10.getNumContiguousTrailingDims(), 3);
// memref<?x2x2xf32, strided<[8,2,1]>
auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
- EXPECT_EQ(m11.getMaxContiguousTrailingDims(), 2);
+ EXPECT_EQ(m11.getNumContiguousTrailingDims(), 2);
// memref<?x2x2xf32, strided<[8,4,1]>
auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
- EXPECT_EQ(m12.getMaxContiguousTrailingDims(), 1);
+ EXPECT_EQ(m12.getNumContiguousTrailingDims(), 1);
// memref<?x2x2xf32, strided<[8,4,2]>
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
- EXPECT_EQ(m13.getMaxContiguousTrailingDims(), 0);
+ EXPECT_EQ(m13.getNumContiguousTrailingDims(), 0);
// memref<2x2x1xf32, strided<[2,1,2]>
auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
- EXPECT_EQ(m14.getMaxContiguousTrailingDims(), 3);
+ EXPECT_EQ(m14.getNumContiguousTrailingDims(), 3);
// memref<2x2x1xf32, strided<[2,1,?]>
auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
- EXPECT_EQ(m15.getMaxContiguousTrailingDims(), 3);
+ EXPECT_EQ(m15.getNumContiguousTrailingDims(), 3);
// memref<2x2x1xf32, strided<[4,2,2]>
auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
- EXPECT_EQ(m16.getMaxContiguousTrailingDims(), 1);
+ EXPECT_EQ(m16.getNumContiguousTrailingDims(), 1);
// memref<2x1x2xf32, strided<[2,4,1]>
auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
- EXPECT_EQ(m17.getMaxContiguousTrailingDims(), 3);
+ EXPECT_EQ(m17.getNumContiguousTrailingDims(), 3);
// memref<2x1x2xf32, strided<[2,?,1]>
auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
- EXPECT_EQ(m18.getMaxContiguousTrailingDims(), 3);
+ EXPECT_EQ(m18.getNumContiguousTrailingDims(), 3);
}
TEST(MemRefLayout, contigTrailingDim) {
@@ -196,14 +196,14 @@ TEST(MemRefLayout, identityMaps) {
// memref<2x2x2xf32>
auto m1 = MemRefType::get({2, 2, 2}, f32);
- EXPECT_EQ(m1.getMaxContiguousTrailingDims(), 3);
+ EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
// memref<?x?x?xf32>
auto m2 = MemRefType::get({_, _, _}, f32);
- EXPECT_EQ(m2.getMaxContiguousTrailingDims(), 3);
+ EXPECT_EQ(m2.getNumContiguousTrailingDims(), 3);
EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
>From 3b88a09eb3e19548c4c979c11aa107cee52d712f Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 6 Jun 2025 14:51:30 +0000
Subject: [PATCH 06/11] [fixup] Fix a couple of FileCheck lines
---
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 5b2f2ab1f2cef..854800fb98fe4 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -212,7 +212,7 @@ func.func @transfer_read_dynamic_dim_to_flatten(
// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
@@ -490,7 +490,7 @@ func.func @transfer_write_dynamic_to_flatten(
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
>From 1179bf2a341c3197b6c69e0a5b2995b94c40a7f2 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 9 Jun 2025 14:43:54 +0000
Subject: [PATCH 07/11] [fixup] Add some negative tests with dynamic dimensions
---
.../Vector/vector-transfer-flatten.mlir | 48 ++++++++++++++++++-
1 file changed, 46 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 854800fb98fe4..2b012f1a97971 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -188,7 +188,29 @@ func.func @transfer_read_leading_dynamic_dims(
// -----
-// One of the dims to be flattened is dynamic - not supported ATM.
+// The vector could be a non-contiguous slice of the input
+// memref.
+
+func.func @negative_transfer_read_dynamic_dim_to_flatten(
+ %mem : memref<4x?x?x2xi8>) -> vector<2x2x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+ memref<4x?x?x2xi8>, vector<2x2x2xi8>
+ return %res : vector<2x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten(
+// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
+// Can flatten the righmost dynamic dimension
func.func @transfer_read_dynamic_dim_to_flatten(
%idx_1: index,
@@ -464,8 +486,28 @@ func.func @transfer_write_leading_dynamic_dims(
// CHECK-128B: memref.collapse_shape
// -----
+
+// The vector could be a non-contiguous slice of the input
+// memref.
+
+func.func @negative_transfer_write_dynamic_to_flatten(
+ %mem : memref<4x?x?x2xi8>,
+ %vec : vector<2x2x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0]
+ : vector<2x2x2xi8>, memref<4x?x?x2xi8>
+ return
+}
-// One of the dims to be flattened is dynamic - not supported ATM.
+// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten(
+// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
func.func @transfer_write_dynamic_to_flatten(
%idx_1: index,
@@ -578,6 +620,8 @@ func.func @negative_out_of_bound_transfer_read(
// -----
+// Can flatten the righmost dynamic dimension
+
func.func @negative_out_of_bound_transfer_write(
%mem : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {
%c0 = arith.constant 0 : index
>From 20b495df83e2e6b02958c8f3b4f7ecb5772b8c8a Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 9 Jun 2025 15:45:29 +0000
Subject: [PATCH 08/11] [fixup] Move a test and spell out some auto types
---
mlir/unittests/Dialect/MemRef/CMakeLists.txt | 1 -
mlir/unittests/IR/CMakeLists.txt | 1 +
.../LayoutTest.cpp => IR/MemrefLayoutTest.cpp} | 14 +++++++-------
3 files changed, 8 insertions(+), 8 deletions(-)
rename mlir/unittests/{Dialect/MemRef/LayoutTest.cpp => IR/MemrefLayoutTest.cpp} (96%)
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index 1f6df1024f430..dede3ba0a885c 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,6 +1,5 @@
add_mlir_unittest(MLIRMemRefTests
InferShapeTest.cpp
- LayoutTest.cpp
)
mlir_target_link_libraries(MLIRMemRefTests
PRIVATE
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 9ab6029c3480d..7e4e57124309f 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_unittest(MLIRIRTests
IRMapping.cpp
InterfaceAttachmentTest.cpp
LocationTest.cpp
+ MemrefLayoutTest.cpp
OperationSupportTest.cpp
PatternMatchTest.cpp
ShapedTypeTest.cpp
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/IR/MemrefLayoutTest.cpp
similarity index 96%
rename from mlir/unittests/Dialect/MemRef/LayoutTest.cpp
rename to mlir/unittests/IR/MemrefLayoutTest.cpp
index 0ba8927da69e9..6c9bb40da52b3 100644
--- a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
+++ b/mlir/unittests/IR/MemrefLayoutTest.cpp
@@ -1,4 +1,4 @@
-//===- LayoutTest.cpp - unit tests related to memref layout --------------===//
+//===- MemrefLayoutTest.cpp - unit tests related to memref layout ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -19,8 +19,8 @@ TEST(MemRefLayout, maxContigDim) {
MLIRContext ctx;
OpBuilder b(&ctx);
- const auto _ = ShapedType::kDynamic;
- const auto f32 = b.getF32Type();
+ const int64_t _ = ShapedType::kDynamic;
+ const FloatType f32 = b.getF32Type();
auto strided = [&ctx](ArrayRef<int64_t> s) {
return StridedLayoutAttr::get(&ctx, 0, s);
};
@@ -102,8 +102,8 @@ TEST(MemRefLayout, contigTrailingDim) {
MLIRContext ctx;
OpBuilder b(&ctx);
- const auto _ = ShapedType::kDynamic;
- const auto f32 = b.getF32Type();
+ const int64_t _ = ShapedType::kDynamic;
+ const FloatType f32 = b.getF32Type();
auto strided = [&ctx](ArrayRef<int64_t> s) {
return StridedLayoutAttr::get(&ctx, 0, s);
};
@@ -191,8 +191,8 @@ TEST(MemRefLayout, identityMaps) {
MLIRContext ctx;
OpBuilder b(&ctx);
- const auto _ = ShapedType::kDynamic;
- const auto f32 = b.getF32Type();
+ const int64_t _ = ShapedType::kDynamic;
+ const FloatType f32 = b.getF32Type();
// memref<2x2x2xf32>
auto m1 = MemRefType::get({2, 2, 2}, f32);
>From af235d94411fc10b96df520eba654ecefa2e58ec Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 18 Jun 2025 11:21:26 +0000
Subject: [PATCH 09/11] [fixup] Misc unimportant changes
---
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir | 4 +---
mlir/unittests/IR/MemrefLayoutTest.cpp | 2 +-
2 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 2b012f1a97971..eae26106862e7 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -188,7 +188,7 @@ func.func @transfer_read_leading_dynamic_dims(
// -----
-// The vector could be a non-contiguous slice of the input
+// The vector is a non-contiguous slice of the input
// memref.
func.func @negative_transfer_read_dynamic_dim_to_flatten(
@@ -620,8 +620,6 @@ func.func @negative_out_of_bound_transfer_read(
// -----
-// Can flatten the righmost dynamic dimension
-
func.func @negative_out_of_bound_transfer_write(
%mem : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {
%c0 = arith.constant 0 : index
diff --git a/mlir/unittests/IR/MemrefLayoutTest.cpp b/mlir/unittests/IR/MemrefLayoutTest.cpp
index 6c9bb40da52b3..d506b7b1a6687 100644
--- a/mlir/unittests/IR/MemrefLayoutTest.cpp
+++ b/mlir/unittests/IR/MemrefLayoutTest.cpp
@@ -15,7 +15,7 @@
using namespace mlir;
using namespace mlir::memref;
-TEST(MemRefLayout, maxContigDim) {
+TEST(MemRefLayout, numContigDim) {
MLIRContext ctx;
OpBuilder b(&ctx);
>From 145b055b8157443309f8e0879b6c7ba78a44c322 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 19 Jun 2025 10:20:53 +0000
Subject: [PATCH 10/11] [fixup] Address review comments
---
.../Vector/vector-transfer-flatten.mlir | 18 +-
mlir/unittests/IR/MemrefLayoutTest.cpp | 177 ++++++------------
2 files changed, 74 insertions(+), 121 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index eae26106862e7..45873aa93153d 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -210,7 +210,11 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
// -----
-// Can flatten the righmost dynamic dimension
+// When collapsing memref dimensions, we may include the rightmost dynamic
+// dimension (e.g., at position `k`) provided that the strides for dimensions
+// `k+1`, `k+2`, etc., ensure contiguity in memory. The stride at position `k`
+// itself does not factor into this. (Here "strides" means both explicit
+// strides and those implied by an identity map.)
func.func @transfer_read_dynamic_dim_to_flatten(
%idx_1: index,
@@ -486,8 +490,8 @@ func.func @transfer_write_leading_dynamic_dims(
// CHECK-128B: memref.collapse_shape
// -----
-
-// The vector could be a non-contiguous slice of the input
+
+// The vector is a non-contiguous slice of the input
// memref.
func.func @negative_transfer_write_dynamic_to_flatten(
@@ -509,7 +513,9 @@ func.func @negative_transfer_write_dynamic_to_flatten(
// -----
-func.func @transfer_write_dynamic_to_flatten(
+// See the comment in front of @transfer_read_dynamic_dim_to_flatten.
+
+func.func @transfer_write_dynamic_dim_to_flatten(
%idx_1: index,
%idx_2: index,
%vec : vector<1x2x6xi32>,
@@ -524,7 +530,7 @@ func.func @transfer_write_dynamic_to_flatten(
// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
-// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
+// CHECK-LABEL: func.func @transfer_write_dynamic_dim_to_flatten
// CHECK-SAME: %[[IDX_1:arg0]]: index
// CHECK-SAME: %[[IDX_2:arg1]]: index
// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
@@ -539,7 +545,7 @@ func.func @transfer_write_dynamic_to_flatten(
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
-// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @transfer_write_dynamic_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----
diff --git a/mlir/unittests/IR/MemrefLayoutTest.cpp b/mlir/unittests/IR/MemrefLayoutTest.cpp
index d506b7b1a6687..73005401b56fc 100644
--- a/mlir/unittests/IR/MemrefLayoutTest.cpp
+++ b/mlir/unittests/IR/MemrefLayoutTest.cpp
@@ -15,6 +15,9 @@
using namespace mlir;
using namespace mlir::memref;
+//
+// Test the correctness of `MemRefType::getNumContiguousTrailingDims`
+//
TEST(MemRefLayout, numContigDim) {
MLIRContext ctx;
OpBuilder b(&ctx);
@@ -25,79 +28,108 @@ TEST(MemRefLayout, numContigDim) {
return StridedLayoutAttr::get(&ctx, 0, s);
};
- // memref<2x2x2xf32, strided<[4,2,1]>
+ // Create a sequence of test cases, starting with the base case of a
+ // contiguous 2x2x2 memref with fixed dimensions and then at each step
+ // introducing one dynamic dimension starting from the right.
+ // With thus obtained memref, start with maximally contiguous strides
+ // and then at each step gradually introduce discontinuity by increasing
+ // a fixed stride size from the left to right.
+
+ // In these and the following test cases the intent is to achieve code
+ // coverage of the main loop in `MemRefType::getNumContiguousTrailingDims()`.
+
+ // memref<2x2x2xf32, strided<[4,2,1]>>
auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
- // memref<2x2x2xf32, strided<[8,2,1]>
+ // memref<2x2x2xf32, strided<[8,2,1]>>
auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
EXPECT_EQ(m2.getNumContiguousTrailingDims(), 2);
- // memref<2x2x2xf32, strided<[8,4,1]>
+ // memref<2x2x2xf32, strided<[8,4,1]>>
auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
EXPECT_EQ(m3.getNumContiguousTrailingDims(), 1);
- // memref<2x2x2xf32, strided<[8,4,2]>
+ // memref<2x2x2xf32, strided<[8,4,2]>>
auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
- // memref<2x2x?xf32, strided<[?,?,1]>
+ // memref<2x2x?xf32, strided<[?,?,1]>>
auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
EXPECT_EQ(m5.getNumContiguousTrailingDims(), 1);
- // memref<2x2x?xf32, strided<[?,?,2]>
+ // memref<2x2x?xf32, strided<[?,?,2]>>
auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
EXPECT_EQ(m6.getNumContiguousTrailingDims(), 0);
- // memref<2x?x2xf32, strided<[?,2,1]>
+ // memref<2x?x2xf32, strided<[?,2,1]>>
auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
EXPECT_EQ(m7.getNumContiguousTrailingDims(), 2);
- // memref<2x?x2xf32, strided<[?,4,1]>
+ // memref<2x?x2xf32, strided<[?,4,1]>>
auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
EXPECT_EQ(m8.getNumContiguousTrailingDims(), 1);
- // memref<2x?x2xf32, strided<[?,4,2]>
+ // memref<2x?x2xf32, strided<[?,4,2]>>
auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
EXPECT_EQ(m9.getNumContiguousTrailingDims(), 0);
- // memref<?x2x2xf32, strided<[4,2,1]>
+ // memref<?x2x2xf32, strided<[4,2,1]>>
auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
EXPECT_EQ(m10.getNumContiguousTrailingDims(), 3);
- // memref<?x2x2xf32, strided<[8,2,1]>
+ // memref<?x2x2xf32, strided<[8,2,1]>>
auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
EXPECT_EQ(m11.getNumContiguousTrailingDims(), 2);
- // memref<?x2x2xf32, strided<[8,4,1]>
+ // memref<?x2x2xf32, strided<[8,4,1]>>
auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
EXPECT_EQ(m12.getNumContiguousTrailingDims(), 1);
- // memref<?x2x2xf32, strided<[8,4,2]>
+ // memref<?x2x2xf32, strided<[8,4,2]>>
auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
EXPECT_EQ(m13.getNumContiguousTrailingDims(), 0);
- // memref<2x2x1xf32, strided<[2,1,2]>
+ //
+ // Repeat a similar process, but this time introduce a unit memref dimension
+ // to test that strides corresponding to unit dimensions are immaterial, even
+ // if dynamic.
+ //
+
+ // memref<2x2x1xf32, strided<[2,1,2]>>
auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
EXPECT_EQ(m14.getNumContiguousTrailingDims(), 3);
- // memref<2x2x1xf32, strided<[2,1,?]>
+ // memref<2x2x1xf32, strided<[2,1,?]>>
auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
EXPECT_EQ(m15.getNumContiguousTrailingDims(), 3);
- // memref<2x2x1xf32, strided<[4,2,2]>
+ // memref<2x2x1xf32, strided<[4,2,2]>>
auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
EXPECT_EQ(m16.getNumContiguousTrailingDims(), 1);
- // memref<2x1x2xf32, strided<[2,4,1]>
+ // memref<2x1x2xf32, strided<[2,4,1]>>
auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
EXPECT_EQ(m17.getNumContiguousTrailingDims(), 3);
- // memref<2x1x2xf32, strided<[2,?,1]>
+ // memref<2x1x2xf32, strided<[2,?,1]>>
auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
EXPECT_EQ(m18.getNumContiguousTrailingDims(), 3);
+
+ //
+ // Special case for identity maps and no explicit `strided` attribute - the
+ // memref is entirely contiguous even if the strides cannot be determined
+ // statically.
+ //
+
+ // memref<?x?x?xf32>
+ auto m19 = MemRefType::get({_, _, _}, f32);
+ EXPECT_EQ(m19.getNumContiguousTrailingDims(), 3);
}
+//
+// Test the member function `MemRefType::areTrailingDimsContiguous`
+//
TEST(MemRefLayout, contigTrailingDim) {
MLIRContext ctx;
OpBuilder b(&ctx);
@@ -108,103 +140,18 @@ TEST(MemRefLayout, contigTrailingDim) {
return StridedLayoutAttr::get(&ctx, 0, s);
};
- // memref<2x2x2xf32, strided<[4,2,1]>
- auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
- EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
- EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
- EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
-
- // memref<2x2x2xf32, strided<[8,2,1]>
- auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
- EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
- EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
- EXPECT_FALSE(m2.areTrailingDimsContiguous(3));
-
- // memref<2x2x2xf32, strided<[8,4,1]>
- auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
- EXPECT_TRUE(m3.areTrailingDimsContiguous(1));
- EXPECT_FALSE(m3.areTrailingDimsContiguous(2));
- EXPECT_FALSE(m3.areTrailingDimsContiguous(3));
-
- // memref<2x2x2xf32, strided<[8,4,2]>
- auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
- EXPECT_FALSE(m4.areTrailingDimsContiguous(1));
- EXPECT_FALSE(m4.areTrailingDimsContiguous(2));
- EXPECT_FALSE(m4.areTrailingDimsContiguous(3));
-
- // memref<2x2x?xf32, strided<[?,?,1]>
- auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
- EXPECT_TRUE(m5.areTrailingDimsContiguous(1));
- EXPECT_FALSE(m5.areTrailingDimsContiguous(2));
- EXPECT_FALSE(m5.areTrailingDimsContiguous(3));
-
- // memref<2x2x?xf32, strided<[?,?,2]>
- auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
- EXPECT_FALSE(m6.areTrailingDimsContiguous(1));
- EXPECT_FALSE(m6.areTrailingDimsContiguous(2));
- EXPECT_FALSE(m6.areTrailingDimsContiguous(3));
-
- // memref<2x?x2xf32, strided<[?,2,1]>
- auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
- EXPECT_TRUE(m7.areTrailingDimsContiguous(1));
- EXPECT_TRUE(m7.areTrailingDimsContiguous(2));
- EXPECT_FALSE(m7.areTrailingDimsContiguous(3));
-
- // memref<2x?x2xf32, strided<[?,4,1]>
- auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
- EXPECT_TRUE(m8.areTrailingDimsContiguous(1));
- EXPECT_FALSE(m8.areTrailingDimsContiguous(2));
- EXPECT_FALSE(m8.areTrailingDimsContiguous(3));
-
- // memref<2x?x2xf32, strided<[?,4,2]>
- auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
- EXPECT_FALSE(m9.areTrailingDimsContiguous(1));
- EXPECT_FALSE(m9.areTrailingDimsContiguous(2));
- EXPECT_FALSE(m9.areTrailingDimsContiguous(3));
-
- // memref<?x2x2xf32, strided<[4,2,1]>
- auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
- EXPECT_TRUE(m10.areTrailingDimsContiguous(1));
- EXPECT_TRUE(m10.areTrailingDimsContiguous(2));
- EXPECT_TRUE(m10.areTrailingDimsContiguous(3));
-
- // memref<?x2x2xf32, strided<[8,2,1]>
- auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
- EXPECT_TRUE(m11.areTrailingDimsContiguous(1));
- EXPECT_TRUE(m11.areTrailingDimsContiguous(2));
- EXPECT_FALSE(m11.areTrailingDimsContiguous(3));
-
- // memref<?x2x2xf32, strided<[8,4,1]>
- auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
- EXPECT_TRUE(m12.areTrailingDimsContiguous(1));
- EXPECT_FALSE(m12.areTrailingDimsContiguous(2));
- EXPECT_FALSE(m12.areTrailingDimsContiguous(3));
+ // Pick up a random test case among the ones already present in the file and
+ // ensure `areTrailingDimsContiguous(k)` returns `true` up to the value
+ // returned by `getNumContiguousTrailingDims` and `false` from that point on
+ // up to the memref rank.
- // memref<?x2x2xf32, strided<[8,4,2]>
- auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
- EXPECT_FALSE(m13.areTrailingDimsContiguous(1));
- EXPECT_FALSE(m13.areTrailingDimsContiguous(2));
- EXPECT_FALSE(m13.areTrailingDimsContiguous(3));
-}
-
-TEST(MemRefLayout, identityMaps) {
- MLIRContext ctx;
- OpBuilder b(&ctx);
+ // memref<2x?x2xf32, strided<[?,2,1]>>
+ auto m = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+ int64_t n = m.getNumContiguousTrailingDims();
+ for (int64_t i = 0; i <= n; ++i)
+ EXPECT_TRUE(m.areTrailingDimsContiguous(i));
- const int64_t _ = ShapedType::kDynamic;
- const FloatType f32 = b.getF32Type();
-
- // memref<2x2x2xf32>
- auto m1 = MemRefType::get({2, 2, 2}, f32);
- EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
- EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
- EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
- EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
-
- // memref<?x?x?xf32>
- auto m2 = MemRefType::get({_, _, _}, f32);
- EXPECT_EQ(m2.getNumContiguousTrailingDims(), 3);
- EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
- EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
- EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
+ int64_t r = m.getRank();
+ for (int64_t i = n + 1; i <= r; ++i)
+ EXPECT_FALSE(m.areTrailingDimsContiguous(i));
}
>From b5425a0c6b6c8407525d88028d228932fc96796a Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 19 Jun 2025 13:43:11 +0000
Subject: [PATCH 11/11] [fixup] Reduce the number of test cases
---
mlir/unittests/IR/MemrefLayoutTest.cpp | 136 ++++++++-----------------
1 file changed, 45 insertions(+), 91 deletions(-)
diff --git a/mlir/unittests/IR/MemrefLayoutTest.cpp b/mlir/unittests/IR/MemrefLayoutTest.cpp
index 73005401b56fc..f243a76ee660c 100644
--- a/mlir/unittests/IR/MemrefLayoutTest.cpp
+++ b/mlir/unittests/IR/MemrefLayoutTest.cpp
@@ -28,103 +28,60 @@ TEST(MemRefLayout, numContigDim) {
return StridedLayoutAttr::get(&ctx, 0, s);
};
- // Create a sequence of test cases, starting with the base case of a
- // contiguous 2x2x2 memref with fixed dimensions and then at each step
- // introducing one dynamic dimension starting from the right.
- // With thus obtained memref, start with maximally contiguous strides
- // and then at each step gradually introduce discontinuity by increasing
- // a fixed stride size from the left to right.
-
- // In these and the following test cases the intent is to achieve code
- // coverage of the main loop in `MemRefType::getNumContiguousTrailingDims()`.
-
- // memref<2x2x2xf32, strided<[4,2,1]>>
- auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
- EXPECT_EQ(m1.getNumContiguousTrailingDims(), 3);
-
- // memref<2x2x2xf32, strided<[8,2,1]>>
- auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
- EXPECT_EQ(m2.getNumContiguousTrailingDims(), 2);
-
- // memref<2x2x2xf32, strided<[8,4,1]>>
- auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
- EXPECT_EQ(m3.getNumContiguousTrailingDims(), 1);
+ // Special case for identity maps and no explicit `strided` attribute - the
+ // memref is entirely contiguous even if the strides cannot be determined
+ // statically.
- // memref<2x2x2xf32, strided<[8,4,2]>>
- auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
- EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
+ // memref<?x?x?xf32>
+ auto m0 = MemRefType::get({_, _, _}, f32);
+ EXPECT_EQ(m0.getNumContiguousTrailingDims(), 3);
- // memref<2x2x?xf32, strided<[?,?,1]>>
- auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
- EXPECT_EQ(m5.getNumContiguousTrailingDims(), 1);
+  // Conservatively assume the memref is non-contiguous in every dimension if
+  // its strides cannot be determined.
- // memref<2x2x?xf32, strided<[?,?,2]>>
- auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
- EXPECT_EQ(m6.getNumContiguousTrailingDims(), 0);
+ // memref<2x2x2xf32, (i,j,k)->(i,k,j)>
+ auto m1 = MemRefType::get(
+ {2, 2, 2}, f32,
+ AffineMap::getPermutationMap(ArrayRef<int64_t>{0, 2, 1}, &ctx));
+ EXPECT_EQ(m1.getNumContiguousTrailingDims(), 0);
- // memref<2x?x2xf32, strided<[?,2,1]>>
- auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
- EXPECT_EQ(m7.getNumContiguousTrailingDims(), 2);
+  // The base case: a fixed memref with the usual (contiguous) strides.
- // memref<2x?x2xf32, strided<[?,4,1]>>
- auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
- EXPECT_EQ(m8.getNumContiguousTrailingDims(), 1);
+ // memref<2x2x2xf32, strided<[4, 2, 1]>>
+ auto m3 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_EQ(m3.getNumContiguousTrailingDims(), 3);
- // memref<2x?x2xf32, strided<[?,4,2]>>
- auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
- EXPECT_EQ(m9.getNumContiguousTrailingDims(), 0);
+  // A fixed memref that is non-contiguous already in the rightmost dimension.
- // memref<?x2x2xf32, strided<[4,2,1]>>
- auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
- EXPECT_EQ(m10.getNumContiguousTrailingDims(), 3);
-
- // memref<?x2x2xf32, strided<[8,2,1]>>
- auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
- EXPECT_EQ(m11.getNumContiguousTrailingDims(), 2);
+ // memref<2x2x2xf32, strided<[8, 4, 2]>>
+ auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
- // memref<?x2x2xf32, strided<[8,4,1]>>
- auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
- EXPECT_EQ(m12.getNumContiguousTrailingDims(), 1);
+  // A fixed memref with a contiguity break in the "middle".
- // memref<?x2x2xf32, strided<[8,4,2]>>
- auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
- EXPECT_EQ(m13.getNumContiguousTrailingDims(), 0);
+ // memref<2x2x2xf32, strided<[8, 2, 1]>>
+ auto m5 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_EQ(m5.getNumContiguousTrailingDims(), 2);
- //
- // Repeat a similar process, but this time introduce a unit memref dimension
- // to test that strides corresponding to unit dimensions are immaterial, even
- // if dynamic.
- //
+  // A dynamic memref where the dynamic dimension breaks contiguity.
- // memref<2x2x1xf32, strided<[2,1,2]>>
- auto m14 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, 2}));
- EXPECT_EQ(m14.getNumContiguousTrailingDims(), 3);
+ // memref<2x?x2xf32, strided<[4, 2, 1]>>
+ auto m6 = MemRefType::get({2, _, 2}, f32, strided({4, 2, 1}));
+ EXPECT_EQ(m6.getNumContiguousTrailingDims(), 2);
- // memref<2x2x1xf32, strided<[2,1,?]>>
- auto m15 = MemRefType::get({2, 2, 1}, f32, strided({2, 1, _}));
- EXPECT_EQ(m15.getNumContiguousTrailingDims(), 3);
+  // An edge case: the dynamic dimension is the outermost one, whose size never
+  // affects contiguity, so all three dimensions are contiguous.
-  // memref<2x2x1xf32, strided<[4,2,2]>>
-  auto m16 = MemRefType::get({2, 2, 1}, f32, strided({4, 2, 2}));
-  EXPECT_EQ(m16.getNumContiguousTrailingDims(), 1);
+  // memref<?x2x2xf32, strided<[4, 2, 1]>>
+  auto m7 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m7.getNumContiguousTrailingDims(), 3);
- // memref<2x1x2xf32, strided<[2,4,1]>>
- auto m17 = MemRefType::get({2, 1, 2}, f32, strided({2, 4, 1}));
- EXPECT_EQ(m17.getNumContiguousTrailingDims(), 3);
+  // A memref with a unit dimension. Unit dimensions do not affect contiguity,
+  // even if the corresponding stride is dynamic.
// memref<2x1x2xf32, strided<[2,?,1]>>
- auto m18 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
- EXPECT_EQ(m18.getNumContiguousTrailingDims(), 3);
-
- //
- // Special case for identity maps and no explicit `strided` attribute - the
- // memref is entirely contiguous even if the strides cannot be determined
- // statically.
- //
-
- // memref<?x?x?xf32>
- auto m19 = MemRefType::get({_, _, _}, f32);
- EXPECT_EQ(m19.getNumContiguousTrailingDims(), 3);
+ auto m8 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
+ EXPECT_EQ(m8.getNumContiguousTrailingDims(), 3);
}
//
@@ -140,18 +97,15 @@ TEST(MemRefLayout, contigTrailingDim) {
return StridedLayoutAttr::get(&ctx, 0, s);
};
- // Pick up a random test case among the ones already present in the file and
- // ensure `areTrailingDimsContiguous(k)` returns `true` up to the value
- // returned by `getNumContiguousTrailingDims` and `false` from that point on
- // up to the memref rank.
+  // A memref that is neither entirely contiguous nor entirely non-contiguous:
+  // ensure `areTrailingDimsContiguous` returns `true` for the value
+  // returned by `getNumContiguousTrailingDims` and `false` for the next larger
+  // number.
// memref<2x?x2xf32, strided<[?,2,1]>>
auto m = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
int64_t n = m.getNumContiguousTrailingDims();
- for (int64_t i = 0; i <= n; ++i)
- EXPECT_TRUE(m.areTrailingDimsContiguous(i));
-
- int64_t r = m.getRank();
- for (int64_t i = n + 1; i <= r; ++i)
- EXPECT_FALSE(m.areTrailingDimsContiguous(i));
+ EXPECT_TRUE(m.areTrailingDimsContiguous(n));
+ ASSERT_TRUE(n + 1 <= m.getRank());
+ EXPECT_FALSE(m.areTrailingDimsContiguous(n + 1));
}
More information about the Mlir-commits
mailing list