[Mlir-commits] [mlir] 4af96a9 - [MLIR] Determine contiguousness of memrefs with dynamic dimensions (#142421)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 23 01:28:37 PDT 2025


Author: Momchil Velikov
Date: 2025-06-23T09:28:33+01:00
New Revision: 4af96a9d83335b3b59f3441af47c879c7a9eb183

URL: https://github.com/llvm/llvm-project/commit/4af96a9d83335b3b59f3441af47c879c7a9eb183
DIFF: https://github.com/llvm/llvm-project/commit/4af96a9d83335b3b59f3441af47c879c7a9eb183.diff

LOG: [MLIR] Determine contiguousness of memrefs with dynamic dimensions (#142421)

This patch enhances `MemRefType::areTrailingDimsContiguous` to also
handle memrefs with dynamic dimensions.

The implementation is based on a new member function
`MemRefType::getNumContiguousTrailingDims`, which returns the maximum
number of trailing dimensions that can be collapsed: trivially all
dimensions for memrefs with an identity layout, and otherwise determined
by walking the memref strides from the innermost dimension outwards,
stopping at the first discontiguous or statically unknown stride.
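
For illustration, a minimal sketch of the new API on one of the examples
documented below (it mirrors the new unit tests; not part of the patch):

  #include "mlir/IR/Builders.h"
  #include "mlir/IR/BuiltinAttributes.h"
  #include "mlir/IR/BuiltinTypes.h"
  #include <cassert>

  using namespace mlir;

  int main() {
    MLIRContext ctx;
    Builder b(&ctx);
    const int64_t dyn = ShapedType::kDynamic;

    // memref<5x?x2xi8, strided<[?, 2, 1]>>
    auto layout = StridedLayoutAttr::get(&ctx, /*offset=*/0, {dyn, 2, 1});
    auto m = MemRefType::get({5, dyn, 2}, b.getI8Type(), layout);

    // The two innermost dimensions are provably contiguous; the scan stops
    // at the dynamic stride of dimension 0.
    assert(m.getNumContiguousTrailingDims() == 2);
    assert(m.areTrailingDimsContiguous(2));
    assert(!m.areTrailingDimsContiguous(3));
  }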

Added: 
    mlir/unittests/IR/MemrefLayoutTest.cpp

Modified: 
    mlir/include/mlir/Dialect/Utils/IndexingUtils.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/lib/Dialect/Utils/IndexingUtils.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
    mlir/unittests/IR/CMakeLists.txt

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 99218f491ddef..8524072929793 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -40,7 +40,7 @@ class ArrayAttr;
 /// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
 ///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
 ///
-/// `sizes` elements are asserted to be non-negative.
+/// `sizes` elements `s1` to `sn` are asserted to be non-negative.
 ///
 /// Return an empty vector if `sizes` is empty.
 SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
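
As a small usage sketch (illustrative only): the leading size never enters
the suffix product, which is why it alone may now be dynamic, i.e. negative.

  #include "mlir/Dialect/Utils/IndexingUtils.h"
  #include "mlir/IR/BuiltinTypeInterfaces.h"
  #include "llvm/ADT/SmallVector.h"

  void suffixProductExample() {
    // Suffix product of [5, 3, 2] is [3*2, 2, 1] = [6, 2, 1];
    // sizes[0] == 5 is never multiplied in.
    llvm::SmallVector<int64_t> s1 = mlir::computeSuffixProduct({5, 3, 2});
    // Hence a dynamic (negative) leading size is accepted after this change:
    llvm::SmallVector<int64_t> s2 =
        mlir::computeSuffixProduct({mlir::ShapedType::kDynamic, 3, 2});
    // s1 == s2 == {6, 2, 1}
  }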

diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 89ade79a3ac02..a0c8acea91dc5 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -839,6 +839,25 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     ///
     bool areTrailingDimsContiguous(int64_t n);
 
+    /// Return the number of trailing dimensions that are contiguous.
+    ///
+    /// Examples:
+    ///   - memref<5x3x2xi8, strided<[6,2,1]>>, the number of contiguous
+    ///     trailing dimensions is 3
+    ///   - memref<5x3x2xi8, strided<[12,2,1]>>, the number of contiguous
+    ///     trailing dimensions is 2 (dimension 0 is non-contiguous)
+    ///   - memref<5x3x2xi8, strided<[12,4,1]>>, the number of contiguous
+    ///     trailing dimensions is 1 (dimension 1 is non-contiguous)
+    ///   - memref<5x3x2xi8, strided<[12,4,2]>>, the number of contiguous
+    ///     trailing dimensions is 0 (dimension 2 is non-contiguous)
+    ///   - memref<?x3x2xi8, strided<[6,2,1]>>, the number of contiguous
+    ///     trailing dimensions is 3
+    ///   - memref<?x3x2xi8, strided<[12,2,1]>>, the number of contiguous
+    ///     trailing dimensions is 2 (dimension 0 is non-contiguous)
+    ///   - memref<5x?x2xi8, strided<[?,2,1]>>, the number of contiguous
+    ///     trailing dimensions is 2 (stride 0 is dynamic)
+    int64_t getNumContiguousTrailingDims();
+
     /// Return a version of this type with identity layout if it can be
     /// determined statically that the layout is the canonical contiguous
     /// strided layout. Otherwise pass the layout into `simplifyAffineMap`

diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 8de77e2c3cb08..e1648ab99ff25 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -69,7 +69,8 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
 //===----------------------------------------------------------------------===//
 
 SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
-  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
+  assert((sizes.empty() ||
+          llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
          "sizes must be nonnegative");
   int64_t unit = 1;
   return ::computeSuffixProductImpl(sizes, unit);

diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index e3a00ac5a14b1..6661efa8907b7 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -660,35 +660,45 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
-  if (!isLastDimUnitStride())
-    return false;
+  assert(n <= getRank() &&
+         "number of dimensions to check must not exceed rank");
+  return n <= getNumContiguousTrailingDims();
+}
 
-  auto memrefShape = getShape().take_back(n);
-  if (ShapedType::isDynamicShape(memrefShape))
-    return false;
+int64_t MemRefType::getNumContiguousTrailingDims() {
+  const int64_t n = getRank();
 
+  // Memrefs with an identity layout are entirely contiguous.
   if (getLayout().isIdentity())
-    return true;
+    return n;
 
+  // Get the strides (if any). If they cannot be determined, conservatively
+  // assume a non-contiguous layout.
   int64_t offset;
-  SmallVector<int64_t> stridesFull;
-  if (!succeeded(getStridesAndOffset(stridesFull, offset)))
-    return false;
-  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
-
-  if (strides.empty())
-    return true;
+  SmallVector<int64_t> strides;
+  if (!succeeded(getStridesAndOffset(strides, offset)))
+    return 0;
 
-  // Check whether strides match "flattened" dims.
-  SmallVector<int64_t> flattenedDims;
-  auto dimProduct = 1;
-  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
-    dimProduct *= dim;
-    flattenedDims.push_back(dimProduct);
+  ArrayRef<int64_t> shape = getShape();
+
+  // A memref with dimensions `d0, d1, ..., dn-1` and strides
+  // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
+  // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
+  // for `i` in `[k, n-1]`.
+  // Ignore stride elements if the corresponding dimension is 1, as they are
+  // of no consequence.
+  int64_t dimProduct = 1;
+  for (int64_t i = n - 1; i >= 0; --i) {
+    if (shape[i] == 1)
+      continue;
+    if (strides[i] != dimProduct)
+      return n - i - 1;
+    if (shape[i] == ShapedType::kDynamic)
+      return n - i;
+    dimProduct *= shape[i];
   }
 
-  strides = strides.drop_back(1);
-  return llvm::equal(strides, llvm::reverse(flattenedDims));
+  return n;
 }
 
 MemRefType MemRefType::canonicalizeStridedLayout() {
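
To make the stride walk above concrete, here is a self-contained sketch
(illustrative only) that re-implements the same loop over plain vectors,
using -1 as a stand-in for ShapedType::kDynamic, and checks two of the
documented examples:

  #include <cassert>
  #include <cstdint>
  #include <vector>

  int64_t numContiguousTrailingDims(const std::vector<int64_t> &shape,
                                    const std::vector<int64_t> &strides) {
    const int64_t n = shape.size();
    int64_t dimProduct = 1;
    for (int64_t i = n - 1; i >= 0; --i) {
      if (shape[i] == 1) // Unit dims are of no consequence.
        continue;
      if (strides[i] != dimProduct) // Discontiguous (or dynamic) stride.
        return n - i - 1;
      if (shape[i] == -1) // Dynamic size: stride i-1 cannot be checked.
        return n - i;
      dimProduct *= shape[i];
    }
    return n;
  }

  int main() {
    // memref<5x?x2xi8, strided<[?,2,1]>> -> 2 contiguous trailing dims.
    assert(numContiguousTrailingDims({5, -1, 2}, {-1, 2, 1}) == 2);
    // memref<5x3x2xi8, strided<[12,4,1]>> -> only the innermost dim.
    assert(numContiguousTrailingDims({5, 3, 2}, {12, 4, 1}) == 1);
  }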

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index e840dc6bbf224..45873aa93153d 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -188,9 +188,35 @@ func.func @transfer_read_leading_dynamic_dims(
 
 // -----
 
-// One of the dims to be flattened is dynamic - not supported ATM.
+// The vector is a non-contiguous slice of the input memref.
 
 func.func @negative_transfer_read_dynamic_dim_to_flatten(
+    %mem : memref<4x?x?x2xi8>) -> vector<2x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+    memref<4x?x?x2xi8>, vector<2x2x2xi8>
+  return %res : vector<2x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
+
+// When collapsing memref dimensions, we may include the rightmost dynamic
+// dimension (e.g., at position `k`) provided that the strides for dimensions
+// `k+1`, `k+2`, etc., ensure contiguity in memory. The stride at position `k`
+// itself does not factor into this. (Here "strides" means both explicit
+// strides and those implied by an identity layout.)
+
+func.func @transfer_read_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -203,11 +229,25 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
   return %res : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
+// CHECK-SAME:    %[[IDX_1:arg0]]
+// CHECK-SAME:    %[[IDX_2:arg1]]
+// CHECK-SAME:    %[[MEM:arg2]]
+// CHECK:              %[[C0_I32:.*]] = arith.constant 0 : i32
+// CHECK:              %[[C0:.*]] = arith.constant 0 : index
+// CHECK:              %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
+// CHECK-SAME:           memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK:              %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK:              %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK:              return %[[RESULT]] : vector<1x2x6xi32>
+
+
+// CHECK-128B-LABEL: func @transfer_read_dynamic_dim_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
@@ -451,9 +491,31 @@ func.func @transfer_write_leading_dynamic_dims(
 
 // -----
 
-// One of the dims to be flattened is dynamic - not supported ATM.
+// The vector is a non-contiguous slice of the input memref.
 
 func.func @negative_transfer_write_dynamic_to_flatten(
+    %mem : memref<4x?x?x2xi8>,
+    %vec : vector<2x2x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write  %vec, %mem[%c0, %c0, %c0, %c0]
+    : vector<2x2x2xi8>, memref<4x?x?x2xi8>
+  return
+}
+
+// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
+
+// See the comment in front of @transfer_read_dynamic_dim_to_flatten.
+
+func.func @transfer_write_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %vec : vector<1x2x6xi32>,
@@ -466,11 +528,24 @@ func.func @negative_transfer_write_dynamic_to_flatten(
   return
 }
 
-// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
+// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL: func.func @transfer_write_dynamic_dim_to_flatten
+// CHECK-SAME:    %[[IDX_1:arg0]]: index
+// CHECK-SAME:    %[[IDX_2:arg1]]: index
+// CHECK-SAME:    %[[VEC:arg2]]: vector<1x2x6xi32>
+// CHECK-SAME:    %[[MEM:arg3]]: memref<1x?x4x6xi32>
+
+// CHECK:              %[[C0:.*]] = arith.constant 0 : index
+// CHECK:              %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
+// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK:              %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
 
-// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @transfer_write_dynamic_dim_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----

diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index 7700644864570..d22afb3003e76 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_unittest(MLIRIRTests
   IRMapping.cpp
   InterfaceAttachmentTest.cpp
   LocationTest.cpp
+  MemrefLayoutTest.cpp
   OperationSupportTest.cpp
   PatternMatchTest.cpp
   ShapedTypeTest.cpp

diff --git a/mlir/unittests/IR/MemrefLayoutTest.cpp b/mlir/unittests/IR/MemrefLayoutTest.cpp
new file mode 100644
index 0000000000000..f243a76ee660c
--- /dev/null
+++ b/mlir/unittests/IR/MemrefLayoutTest.cpp
@@ -0,0 +1,111 @@
+//===- MemrefLayoutTest.cpp - unit tests related to memref layout ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+//
+// Test the correctness of `MemRefType::getNumContiguousTrailingDims`
+//
+TEST(MemRefLayout, numContigDim) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const int64_t _ = ShapedType::kDynamic;
+  const FloatType f32 = b.getF32Type();
+  auto strided = [&ctx](ArrayRef<int64_t> s) {
+    return StridedLayoutAttr::get(&ctx, 0, s);
+  };
+
+  // Special case: an identity layout (hence no explicit `strided` attribute).
+  // The memref is entirely contiguous even though the strides cannot be
+  // determined statically.
+
+  // memref<?x?x?xf32>
+  auto m0 = MemRefType::get({_, _, _}, f32);
+  EXPECT_EQ(m0.getNumContiguousTrailingDims(), 3);
+
+  // Conservatively assume the memref is non-contiguous everywhere if the
+  // strides cannot be determined.
+
+  // memref<2x2x2xf32, (i,j,k)->(i,k,j)>
+  auto m1 = MemRefType::get(
+      {2, 2, 2}, f32,
+      AffineMap::getPermutationMap(ArrayRef<int64_t>{0, 2, 1}, &ctx));
+  EXPECT_EQ(m1.getNumContiguousTrailingDims(), 0);
+
+  // A base case: a static memref with the usual (row-major) strides.
+
+  // memref<2x2x2xf32, strided<[4, 2, 1]>>
+  auto m3 = MemRefType::get({2, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m3.getNumContiguousTrailingDims(), 3);
+
+  // A static memref with a discontinuity in the rightmost dimension.
+
+  // memref<2x2x2xf32, strided<[8, 4, 2]>>
+  auto m4 = MemRefType::get({2, 2, 2}, f32, strided({8, 4, 2}));
+  EXPECT_EQ(m4.getNumContiguousTrailingDims(), 0);
+
+  // A static memref with a discontinuity at the outermost dimension.
+
+  // memref<2x2x2xf32, strided<[8, 2, 1]>>
+  auto m5 = MemRefType::get({2, 2, 2}, f32, strided({8, 2, 1}));
+  EXPECT_EQ(m5.getNumContiguousTrailingDims(), 2);
+
+  // A dynamic memref where the dynamic dimension breaks contiguity.
+
+  // memref<2x?x2xf32, strided<[4, 2, 1]>>
+  auto m6 = MemRefType::get({2, _, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m6.getNumContiguousTrailingDims(), 2);
+
+  // An edge case of a dynamic memref where the dynamic dimension is the
+  // outermost one; being outermost, it does not limit contiguity.
+
+  // memref<?x2x2xf32, strided<[4, 2, 1]>>
+  auto m7 = MemRefType::get({_, 2, 2}, f32, strided({4, 2, 1}));
+  EXPECT_EQ(m7.getNumContiguousTrailingDims(), 3);
+
+  // A memref with a unit dimension. Unit dimensions do not affect contiguity,
+  // even if the corresponding stride is dynamic.
+
+  // memref<2x1x2xf32, strided<[2,?,1]>>
+  auto m8 = MemRefType::get({2, 1, 2}, f32, strided({2, _, 1}));
+  EXPECT_EQ(m8.getNumContiguousTrailingDims(), 3);
+}
+
+//
+// Test the member function `MemRefType::areTrailingDimsContiguous`
+//
+TEST(MemRefLayout, contigTrailingDim) {
+  MLIRContext ctx;
+  OpBuilder b(&ctx);
+
+  const int64_t _ = ShapedType::kDynamic;
+  const FloatType f32 = b.getF32Type();
+  auto strided = [&ctx](ArrayRef<int64_t> s) {
+    return StridedLayoutAttr::get(&ctx, 0, s);
+  };
+
+  // A memref that is neither entirely contiguous nor entirely discontiguous.
+  // Ensure `areTrailingDimsContiguous` returns `true` for the value returned
+  // by `getNumContiguousTrailingDims` and `false` for the next larger number.
+
+  // memref<2x?x2xf32, strided<[?,2,1]>>
+  auto m = MemRefType::get({2, _, 2}, f32, strided({_, 2, 1}));
+  int64_t n = m.getNumContiguousTrailingDims();
+  EXPECT_TRUE(m.areTrailingDimsContiguous(n));
+  ASSERT_LE(n + 1, m.getRank());
+  EXPECT_FALSE(m.areTrailingDimsContiguous(n + 1));
+}


        


More information about the Mlir-commits mailing list