[Mlir-commits] [mlir] [MLIR] Determine contiguousness of memrefs with a dynamic dimension (PR #140872)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 21 02:52:44 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>

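Memrefs in which, among the trailing dimensions being checked for contiguity, only the leftmost one is dynamic can now be reasoned about. The size of that leftmost dimension does not affect the expected strides of the dimensions to its right, so it no longer needs to be static.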

---
Full diff: https://github.com/llvm/llvm-project/pull/140872.diff


2 Files Affected:

- (modified) mlir/lib/IR/BuiltinTypes.cpp (+5-2) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+81-9) 


``````````diff
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d47e360e9dc13..facf17551fa12 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -649,7 +649,10 @@ bool MemRefType::areTrailingDimsContiguous(int64_t n) {
   if (!isLastDimUnitStride())
     return false;
 
-  auto memrefShape = getShape().take_back(n);
+  if (n == 1)
+    return true;
+
+  auto memrefShape = getShape().take_back(n-1);
   if (ShapedType::isDynamicShape(memrefShape))
     return false;
 
@@ -668,7 +671,7 @@ bool MemRefType::areTrailingDimsContiguous(int64_t n) {
   // Check whether strides match "flattened" dims.
   SmallVector<int64_t> flattenedDims;
   auto dimProduct = 1;
-  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
+  for (auto dim : llvm::reverse(memrefShape)) {
     dimProduct *= dim;
     flattenedDims.push_back(dimProduct);
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index e840dc6bbf224..aa922415f2669 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -188,18 +188,20 @@ func.func @transfer_read_leading_dynamic_dims(
 
 // -----
 
-// One of the dims to be flattened is dynamic - not supported ATM.
+// One of the dims to be flattened is dynamic and not the leftmost - not
+// possible to reason whether the memref is contiguous as the dynamic dimension
+// could be one and the corresponding stride could be arbitrary.
 
 func.func @negative_transfer_read_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
-    %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+    %mem: memref<1x4x?x6xi32>) -> vector<1x2x6xi32> {
 
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
   %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
     in_bounds = [true, true, true]
-  } : memref<1x?x4x6xi32>, vector<1x2x6xi32>
+  } : memref<1x4x?x6xi32>, vector<1x2x6xi32>
   return %res : vector<1x2x6xi32>
 }
 
@@ -212,6 +214,41 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
 
 // -----
 
+// One of the dims to be flattened is dynamic and leftmost.
+
+func.func @transfer_read_dynamic_leftmost_dim_to_flatten(
+    %idx_1: index,
+    %idx_2: index,
+    %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
+    in_bounds = [true, true, true]
+  } : memref<1x?x4x6xi32>, vector<1x2x6xi32>
+  return %res : vector<1x2x6xi32>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dynamic_leftmost_dim_to_flatten
+// CHECK-SAME:    %[[IDX_1:arg0]]: index
+// CHECK-SAME:    %[[IDX_2:arg1]]: index
+// CHECK-SAME:    %[[MEM:arg2]]: memref<1x?x4x6xi32>
+// CHECK-NEXT:  %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK-NEXT:  %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT:   %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME:    : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-NEXT:  %[[TMP:.+]] = affine.apply #map{{.*}}()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK-NEXT:  %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME:    [%[[C0]], %[[TMP]]], %[[C0_I32]]
+// CHECK-SAME:    {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK-NEXT:  %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK-NEXT:  return %[[RES]] : vector<1x2x6xi32>
+
+// CHECK-128B-LABEL: func @transfer_read_dynamic_leftmost_dim_to_flatten
+//   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
+
 // The vector to be read represents a _non-contiguous_ slice of the input
 // memref.
 
@@ -451,26 +488,61 @@ func.func @transfer_write_leading_dynamic_dims(
 
 // -----
 
-// One of the dims to be flattened is dynamic - not supported ATM.
+// One of the dims to be flattened is dynamic and not leftmost.
 
-func.func @negative_transfer_write_dynamic_to_flatten(
+func.func @negative_transfer_write_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %vec : vector<1x2x6xi32>,
-    %mem: memref<1x?x4x6xi32>) {
+    %mem: memref<1x4x?x6xi32>) {
 
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
   vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x6xi32>, memref<1x?x4x6xi32>
+    vector<1x2x6xi32>, memref<1x4x?x6xi32>
   return
 }
 
-// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
+// CHECK-LABEL: func.func @negative_transfer_write_dynamic_dim_to_flatten
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_dim_to_flatten
+//   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
+
+// One of the dims to be flattened is dynamic and leftmost.
+
+func.func @transfer_write_dynamic_leftmost_dim_to_flatten(
+    %idx_1: index,
+    %idx_2: index,
+    %vec : vector<1x2x6xi32>,
+    %mem: memref<1x?x4x6xi32>) {
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x?x4x6xi32>
+  return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dynamic_leftmost_dim_to_flatten
+// CHECK-SAME:    %[[IDX_1:arg0]]: index
+// CHECK-SAME:    %[[IDX_2:arg1]]: index
+// CHECK-SAME:    %[[VEC:arg2]]: vector<1x2x6xi32>,
+// CHECK-SAME:    %[[MEM:arg3]]: memref<1x?x4x6xi32>
+// CHECK-NEXT:  %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT:   %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME:    : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-NEXT:  %[[TMP:.+]] = affine.apply #map{{.*}}()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK-NEXT:  %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK-NEXT:  vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-SAME:    [%[[C0]], %[[TMP]]]
+// CHECK-SAME:    {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
+// CHECK-NEXT:  return
+
+// CHECK-128B-LABEL: func @transfer_write_dynamic_leftmost_dim_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/140872


More information about the Mlir-commits mailing list