[Mlir-commits] [mlir] [MLIR] Determine contiguousness of memrefs with dynamic dimensions (PR #142421)

James Newling llvmlistbot at llvm.org
Wed Jun 4 08:54:41 PDT 2025


================
@@ -646,35 +646,40 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 bool MemRefType::areTrailingDimsContiguous(int64_t n) {
-  if (!isLastDimUnitStride())
-    return false;
+  return getLayout().isIdentity() ||
+         getMaxCollapsableTrailingDims() >= std::min(n, getRank());
+}
 
-  auto memrefShape = getShape().take_back(n);
-  if (ShapedType::isDynamicShape(memrefShape))
-    return false;
+int64_t MemRefType::getMaxCollapsableTrailingDims() {
+  const int64_t n = getRank();
 
+  // memrefs with identity layout are entirely contiguous.
   if (getLayout().isIdentity())
-    return true;
+    return n;
 
+  // Get the strides (if any). Failing to do that, conservatively assume a
+  // non-contiguous layout.
   int64_t offset;
-  SmallVector<int64_t> stridesFull;
-  if (!succeeded(getStridesAndOffset(stridesFull, offset)))
-    return false;
-  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
-
-  if (strides.empty())
-    return true;
+  SmallVector<int64_t> strides;
+  if (!succeeded(getStridesAndOffset(strides, offset)))
+    return 0;
 
-  // Check whether strides match "flattened" dims.
-  SmallVector<int64_t> flattenedDims;
-  auto dimProduct = 1;
-  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
-    dimProduct *= dim;
-    flattenedDims.push_back(dimProduct);
+  auto shape = getShape();
+
+  // A memref with dimensions `d0, d1, ..., dn-1` and strides
+  // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
+  // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
+  // for `i` in `[k, n-1]`.
+  int64_t dimProduct = 1;
+  for (int64_t i = n - 1; i >= 0; --i) {
----------------
newling wrote:

`memref<1x4xi8, strided<[x,1]>>` is fully collapsible for any `x`. The stride of a unit dimension carries no information, because there is no 'next element' along that dimension. One way I think about it is to consider iterating through the elements:
 
 ```
 for d0 in range(1):
   for d1 in range(4):
     element = vals[d0*x + d1] 
 ```

Unit dims correspond to loops with a loop count of 1, which can be removed. Can you please add some tests for unit dims? It is also fine if you prefer to postpone the unit-dim case to a later PR.
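
To make the unit-dim argument concrete, here is a minimal Python sketch. It is not the PR's implementation: `linear_offsets` and `max_collapsible_trailing_dims` are hypothetical helper names, and the sketch assumes fully static shapes and strides. It shows that the stride `x` of a unit dimension never affects which elements are visited, and how a trailing-dims check could skip such strides.

```python
from itertools import product


def linear_offsets(shape, strides):
    """Offsets (relative to the base) visited by a row-major walk."""
    return [
        sum(i * s for i, s in zip(idx, strides))
        for idx in product(*(range(d) for d in shape))
    ]


def max_collapsible_trailing_dims(shape, strides):
    """Hypothetical helper: count trailing dims that form one contiguous
    block, ignoring the stride of any unit dim. Static shapes only."""
    count = 0
    expected = 1  # stride a non-unit dim needs to be contiguous with later dims
    for dim, stride in zip(reversed(shape), reversed(strides)):
        # A unit dim is only ever indexed with 0, so its stride is irrelevant.
        if dim != 1 and stride != expected:
            break
        count += 1
        expected *= dim
    return count


# The walk over a 1x4 memref is the same whatever the outer stride is.
same_walk = linear_offsets([1, 4], [1000, 1]) == linear_offsets([1, 4], [4, 1])
```

Under these assumptions, `max_collapsible_trailing_dims([1, 4], [1000, 1])` reports both dims as collapsible, while `max_collapsible_trailing_dims([2, 4], [8, 1])` stops after the innermost dim because the outer dim is not a unit dim and its stride does not match.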

https://github.com/llvm/llvm-project/pull/142421

