[Mlir-commits] [mlir] [MLIR] Determine contiguousness of memrefs with dynamic dimensions (PR #142421)
Momchil Velikov
llvmlistbot at llvm.org
Mon Jun 2 09:12:35 PDT 2025
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/142421
This patch enhances `MemRefType::areTrailingDimsContiguous` to also handle memrefs with dynamic dimensions.
The implementation is based on a new member function, `MemRefType::getMaxCollapsableTrailingDims`, which returns the maximum number of trailing dimensions that can be collapsed: trivially all dimensions for memrefs with an identity layout, or otherwise determined by walking the memref strides from the innermost dimension outwards and stopping at the first discontiguous or statically unknown stride.
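As a quick illustration, here is a minimal sketch of the intended semantics (it mirrors the new unit tests below; the names `ctx`, `b`, `dyn`, and `m` are local to this example):

  MLIRContext ctx;
  OpBuilder b(&ctx);
  const int64_t dyn = ShapedType::kDynamic;

  // memref<2x?x2xf32, strided<[?, 2, 1]>>
  auto m = MemRefType::get(
      {2, dyn, 2}, b.getF32Type(),
      StridedLayoutAttr::get(&ctx, /*offset=*/0, {dyn, 2, 1}));

  // The two innermost dimensions are contiguous: the innermost stride is 1,
  // and the next stride, 2, equals the product of the trailing sizes. The
  // dynamic size of dimension 1 prevents verifying the outermost stride, so
  // the count stops at 2.
  assert(m.getMaxCollapsableTrailingDims() == 2);
  assert(m.areTrailingDimsContiguous(2));
  assert(!m.areTrailingDimsContiguous(3));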
From 1d025f3272712eeb0d07b7adb0722e1966410216 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 2 Jun 2025 13:20:14 +0000
Subject: [PATCH] [MLIR] Determine contiguousness of memrefs with dynamic
 dimensions
This patch enhances `MemRefType::areTrailingDimsContiguous`
to also handle memrefs with dynamic dimensions.
The implementation is based on a new member function
`MemRefType::getMaxCollapsableTrailingDims` that returns
the maximum number of trailing dimensions that can be
collapsed: trivially all dimensions for memrefs with an
identity layout, or otherwise determined by examining the
memref strides, stopping at the first discontiguous or
statically unknown stride.
---
mlir/include/mlir/IR/BuiltinTypes.td | 14 ++
mlir/lib/IR/BuiltinTypes.cpp | 47 +++--
.../Vector/vector-transfer-flatten.mlir | 49 ++++-
mlir/unittests/Dialect/MemRef/CMakeLists.txt | 1 +
mlir/unittests/Dialect/MemRef/LayoutTest.cpp | 190 ++++++++++++++++++
5 files changed, 269 insertions(+), 32 deletions(-)
create mode 100644 mlir/unittests/Dialect/MemRef/LayoutTest.cpp
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 771de01fc8d5d..1d12f70882176 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -838,6 +838,20 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
///
bool areTrailingDimsContiguous(int64_t n);
+ /// Return the maximum number of trailing dimensions that can be
+ /// collapsed.
+ ///
+ /// Examples:
+  /// - memref<2x3x2xi8, strided<[24, 12, 2]>>, the number of collapsable
+  ///   trailing dimensions is 0.
+  /// - memref<2x3x2xi8, strided<[12, 6, 1]>>, the number of collapsable
+  ///   trailing dimensions is 3.
+  /// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>>, the number of
+  ///   collapsable trailing dimensions is 2.
+  /// - memref<5x4x?x2xi8>, the number of collapsable trailing dimensions
+  ///   is 4.
+ int64_t getMaxCollapsableTrailingDims();
+
/// Return a version of this type with identity layout if it can be
/// determined statically that the layout is the canonical contiguous
/// strided layout. Otherwise pass the layout into `simplifyAffineMap`
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d47e360e9dc13..cc23d08515ff3 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -646,35 +646,40 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
}
bool MemRefType::areTrailingDimsContiguous(int64_t n) {
- if (!isLastDimUnitStride())
- return false;
+ return getLayout().isIdentity() ||
+ getMaxCollapsableTrailingDims() >= std::min(n, getRank());
+}
- auto memrefShape = getShape().take_back(n);
- if (ShapedType::isDynamicShape(memrefShape))
- return false;
+int64_t MemRefType::getMaxCollapsableTrailingDims() {
+ const int64_t n = getRank();
+  // Memrefs with an identity layout are entirely contiguous.
if (getLayout().isIdentity())
- return true;
+ return n;
+  // Get the strides (if any). If that fails, conservatively assume a
+  // non-contiguous layout.
int64_t offset;
- SmallVector<int64_t> stridesFull;
- if (!succeeded(getStridesAndOffset(stridesFull, offset)))
- return false;
- auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
-
- if (strides.empty())
- return true;
+ SmallVector<int64_t> strides;
+ if (!succeeded(getStridesAndOffset(strides, offset)))
+ return 0;
- // Check whether strides match "flattened" dims.
- SmallVector<int64_t> flattenedDims;
- auto dimProduct = 1;
- for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
- dimProduct *= dim;
- flattenedDims.push_back(dimProduct);
+ auto shape = getShape();
+
+  // A memref with dimensions `d0, d1, ..., dn-1` and strides
+  // `s0, s1, ..., sn-1` is contiguous in its trailing dimensions
+  // `k, ..., n-1` if each stride `si` equals the product of the sizes
+  // `di+1, ..., dn-1`, for `i` in `[k, n-1]`.
+ int64_t dimProduct = 1;
+ for (int64_t i = n - 1; i >= 0; --i) {
+ if (strides[i] != dimProduct)
+ return n - i - 1;
+ if (shape[i] == ShapedType::kDynamic)
+ return n - i;
+ dimProduct *= shape[i];
}
- strides = strides.drop_back(1);
- return llvm::equal(strides, llvm::reverse(flattenedDims));
+ return n;
}
MemRefType MemRefType::canonicalizeStridedLayout() {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index e840dc6bbf224..5b2f2ab1f2cef 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -190,7 +190,7 @@ func.func @transfer_read_leading_dynamic_dims(
// One of the dims to be flattened is dynamic - not supported ATM.
-func.func @negative_transfer_read_dynamic_dim_to_flatten(
+func.func @transfer_read_dynamic_dim_to_flatten(
%idx_1: index,
%idx_2: index,
%mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -203,11 +203,25 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
return %res : vector<1x2x6xi32>
}
-// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]
+// CHECK-SAME: %[[IDX_2:arg1]]
+// CHECK-SAME: %[[MEM:arg2]]
+// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}:   [[0], [1, 2, 3]]
+// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
+
+
+// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----
@@ -453,7 +467,7 @@ func.func @transfer_write_leading_dynamic_dims(
// One of the dims to be flattened is dynamic - not supported ATM.
-func.func @negative_transfer_write_dynamic_to_flatten(
+func.func @transfer_write_dynamic_to_flatten(
%idx_1: index,
%idx_2: index,
%vec : vector<1x2x6xi32>,
@@ -466,11 +480,24 @@ func.func @negative_transfer_write_dynamic_to_flatten(
return
}
-// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]: index
+// CHECK-SAME: %[[IDX_2:arg1]]: index
+// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
+// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
+
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
-// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----
diff --git a/mlir/unittests/Dialect/MemRef/CMakeLists.txt b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
index dede3ba0a885c..1f6df1024f430 100644
--- a/mlir/unittests/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/unittests/Dialect/MemRef/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRMemRefTests
InferShapeTest.cpp
+ LayoutTest.cpp
)
mlir_target_link_libraries(MLIRMemRefTests
PRIVATE
diff --git a/mlir/unittests/Dialect/MemRef/LayoutTest.cpp b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
new file mode 100644
index 0000000000000..e01c0056d5cec
--- /dev/null
+++ b/mlir/unittests/Dialect/MemRef/LayoutTest.cpp
@@ -0,0 +1,190 @@
+//===- LayoutTest.cpp - unit tests related to memref layout --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+TEST(MemRefLayout, maxCollapseDim) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+ auto strided = [&ctx](ArrayRef<int64_t> s) {
+ return StridedLayoutAttr::get(&ctx, 0, s);
+ };
+
+  // memref<2x2x2xf32, strided<[4,2,1]>>
+ auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+
+  // memref<2x2x2xf32, strided<[8,2,1]>>
+ auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 2);
+
+  // memref<2x2x2xf32, strided<[8,4,1]>>
+ auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_EQ(m3.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<2x2x2xf32, strided<[8,4,2]>>
+ auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_EQ(m4.getMaxCollapsableTrailingDims(), 0);
+
+  // memref<2x2x?xf32, strided<[?,?,1]>>
+ auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+ EXPECT_EQ(m5.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<2x2x?xf32, strided<[?,?,2]>>
+ auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+ EXPECT_EQ(m6.getMaxCollapsableTrailingDims(), 0);
+
+  // memref<2x?x2xf32, strided<[?,2,1]>>
+ auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+ EXPECT_EQ(m7.getMaxCollapsableTrailingDims(), 2);
+
+  // memref<2x?x2xf32, strided<[?,4,1]>>
+ auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+ EXPECT_EQ(m8.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<2x?x2xf32, strided<[?,4,2]>>
+ auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+ EXPECT_EQ(m9.getMaxCollapsableTrailingDims(), 0);
+
+  // memref<?x2x2xf32, strided<[4,2,1]>>
+ auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_EQ(m10.getMaxCollapsableTrailingDims(), 3);
+
+  // memref<?x2x2xf32, strided<[8,2,1]>>
+ auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_EQ(m11.getMaxCollapsableTrailingDims(), 2);
+
+  // memref<?x2x2xf32, strided<[8,4,1]>>
+ auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_EQ(m12.getMaxCollapsableTrailingDims(), 1);
+
+  // memref<?x2x2xf32, strided<[8,4,2]>>
+ auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_EQ(m13.getMaxCollapsableTrailingDims(), 0);
+}
+
+TEST(MemRefLayout, contigTrailingDim) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+ auto strided = [&ctx](ArrayRef<int64_t> s) {
+ return StridedLayoutAttr::get(&ctx, 0, s);
+ };
+
+  // memref<2x2x2xf32, strided<[4,2,1]>>
+ auto m1 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+  // memref<2x2x2xf32, strided<[8,2,1]>>
+ auto m2 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m2.areTrailingDimsContiguous(3));
+
+  // memref<2x2x2xf32, strided<[8,4,1]>>
+ auto m3 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_TRUE(m3.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m3.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m3.areTrailingDimsContiguous(3));
+
+  // memref<2x2x2xf32, strided<[8,4,2]>>
+ auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m4.areTrailingDimsContiguous(3));
+
+  // memref<2x2x?xf32, strided<[?,?,1]>>
+ auto m5 = MemRefType::get({2, 2, _}, f32, strided({_, _, 1}));
+ EXPECT_TRUE(m5.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m5.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m5.areTrailingDimsContiguous(3));
+
+  // memref<2x2x?xf32, strided<[?,?,2]>>
+ auto m6 = MemRefType::get({2, 2, _}, f32, strided({_, _, 2}));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m6.areTrailingDimsContiguous(3));
+
+  // memref<2x?x2xf32, strided<[?,2,1]>>
+ auto m7 = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+ EXPECT_TRUE(m7.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m7.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m7.areTrailingDimsContiguous(3));
+
+  // memref<2x?x2xf32, strided<[?,4,1]>>
+ auto m8 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 1}));
+ EXPECT_TRUE(m8.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m8.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m8.areTrailingDimsContiguous(3));
+
+  // memref<2x?x2xf32, strided<[?,4,2]>>
+ auto m9 = MemRefType::get({2, _, 2}, f32, strided({_, 4, 2}));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m9.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[4,2,1]>>
+ auto m10 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m10.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[8,2,1]>>
+ auto m11 = MemRefType::get({_, 2, 2}, f32, strided({8, 2, 1}));
+ EXPECT_TRUE(m11.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m11.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m11.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[8,4,1]>>
+ auto m12 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 1}));
+ EXPECT_TRUE(m12.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m12.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m12.areTrailingDimsContiguous(3));
+
+  // memref<?x2x2xf32, strided<[8,4,2]>>
+ auto m13 = MemRefType::get({_, 2, 2}, f32, strided({8, 4, 2}));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(1));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(2));
+ EXPECT_FALSE(m13.areTrailingDimsContiguous(3));
+}
+
+TEST(MemRefLayout, identityMaps) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+
+ const auto _ = ShapedType::kDynamic;
+ const auto f32 = b.getF32Type();
+
+ // memref<2x2x2xf32>
+ auto m1 = MemRefType::get({2, 2, 2}, f32);
+ EXPECT_EQ(m1.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m1.areTrailingDimsContiguous(3));
+
+ // memref<?x?x?xf32>
+ auto m2 = MemRefType::get({_, _, _}, f32);
+ EXPECT_EQ(m2.getMaxCollapsableTrailingDims(), 3);
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(1));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(2));
+ EXPECT_TRUE(m2.areTrailingDimsContiguous(3));
+}