[Mlir-commits] [mlir] [MLIR] Determine contiguousness of memrefs with a dynamic dimension (PR #140872)
Momchil Velikov
llvmlistbot at llvm.org
Wed May 21 02:52:11 PDT 2025
https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/140872
Contiguity can now be determined for memrefs where, among the trailing dimensions being checked, only the leftmost one is dynamic.
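For illustration, here is a minimal standalone sketch (not part of the patch) of how the updated `MemRefType::areTrailingDimsContiguous` behaves on the two shapes used in the tests below. It assumes an MLIR build; the helper `checkTrailingContiguity` and the use of the textual type parser are purely for the example:

// Illustrative only -- assumes the patched MemRefType::areTrailingDimsContiguous.
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Hypothetical helper: parse a memref type from its textual form and report
// whether its trailing `n` dimensions form a contiguous block.
static void checkTrailingContiguity(MLIRContext &ctx, llvm::StringRef typeStr,
                                    int64_t n) {
  auto type = cast<MemRefType>(parseType(typeStr, &ctx));
  llvm::outs() << typeStr << ", n = " << n << ": "
               << (type.areTrailingDimsContiguous(n)
                       ? "contiguous"
                       : "not provably contiguous")
               << "\n";
}

int main() {
  MLIRContext ctx;
  // The dynamic dim is the leftmost of the 3 trailing dims being checked:
  // the strides of the inner static dims are enough to decide contiguity.
  checkTrailingContiguity(ctx, "memref<1x?x4x6xi32>", 3);
  // The dynamic dim is *not* the leftmost: it could be 1, which would make the
  // stride of the enclosing dim arbitrary, so the query stays conservative.
  checkTrailingContiguity(ctx, "memref<1x4x?x6xi32>", 3);
  return 0;
}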
>From a55d4253f1f33aa84b80ac943dc487a0ad7ddda6 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 21 May 2025 09:35:55 +0000
Subject: [PATCH] [MLIR] Determine contiguousness of memrefs with a dynamic
dimension
Contiguity can now be determined for memrefs where, among the trailing
dimensions being checked, only the leftmost one is dynamic.
---
mlir/lib/IR/BuiltinTypes.cpp | 7 +-
.../Vector/vector-transfer-flatten.mlir | 90 +++++++++++++++++--
2 files changed, 86 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index d47e360e9dc13..facf17551fa12 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -649,7 +649,10 @@ bool MemRefType::areTrailingDimsContiguous(int64_t n) {
if (!isLastDimUnitStride())
return false;
- auto memrefShape = getShape().take_back(n);
+ if (n == 1)
+ return true;
+
+ auto memrefShape = getShape().take_back(n - 1);
if (ShapedType::isDynamicShape(memrefShape))
return false;
@@ -668,7 +671,7 @@ bool MemRefType::areTrailingDimsContiguous(int64_t n) {
// Check whether strides match "flattened" dims.
SmallVector<int64_t> flattenedDims;
auto dimProduct = 1;
- for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
+ for (auto dim : llvm::reverse(memrefShape)) {
dimProduct *= dim;
flattenedDims.push_back(dimProduct);
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index e840dc6bbf224..aa922415f2669 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -188,18 +188,20 @@ func.func @transfer_read_leading_dynamic_dims(
// -----
-// One of the dims to be flattened is dynamic - not supported ATM.
+// One of the dims to be flattened is dynamic and not the leftmost - it is not
+// possible to reason about whether the memref is contiguous, as the dynamic
+// dimension could be 1 and the corresponding stride could be arbitrary.
func.func @negative_transfer_read_dynamic_dim_to_flatten(
%idx_1: index,
%idx_2: index,
- %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+ %mem: memref<1x4x?x6xi32>) -> vector<1x2x6xi32> {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
in_bounds = [true, true, true]
- } : memref<1x?x4x6xi32>, vector<1x2x6xi32>
+ } : memref<1x4x?x6xi32>, vector<1x2x6xi32>
return %res : vector<1x2x6xi32>
}
@@ -212,6 +214,41 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
// -----
+// One of the dims to be flattened is dynamic and it is the leftmost one.
+
+func.func @transfer_read_dynamic_leftmost_dim_to_flatten(
+ %idx_1: index,
+ %idx_2: index,
+ %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
+ in_bounds = [true, true, true]
+ } : memref<1x?x4x6xi32>, vector<1x2x6xi32>
+ return %res : vector<1x2x6xi32>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dynamic_leftmost_dim_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]: index
+// CHECK-SAME: %[[IDX_2:arg1]]: index
+// CHECK-SAME: %[[MEM:arg2]]: memref<1x?x4x6xi32>
+// CHECK-NEXT: %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-NEXT: %[[TMP:.+]] = affine.apply #map{{.*}}()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK-NEXT: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME: [%[[C0]], %[[TMP]]], %[[C0_I32]]
+// CHECK-SAME: {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK-NEXT: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK-NEXT: return %[[RES]] : vector<1x2x6xi32>
+
+// CHECK-128B-LABEL: func @transfer_read_dynamic_leftmost_dim_to_flatten
+// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
// The vector to be read represents a _non-contiguous_ slice of the input
// memref.
@@ -451,26 +488,61 @@ func.func @transfer_write_leading_dynamic_dims(
// -----
-// One of the dims to be flattened is dynamic - not supported ATM.
+// One of the dims to be flattened is dynamic and it is not the leftmost one.
-func.func @negative_transfer_write_dynamic_to_flatten(
+func.func @negative_transfer_write_dynamic_dim_to_flatten(
%idx_1: index,
%idx_2: index,
%vec : vector<1x2x6xi32>,
- %mem: memref<1x?x4x6xi32>) {
+ %mem: memref<1x4x?x6xi32>) {
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
- vector<1x2x6xi32>, memref<1x?x4x6xi32>
+ vector<1x2x6xi32>, memref<1x4x?x6xi32>
return
}
-// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten
+// CHECK-LABEL: func.func @negative_transfer_write_dynamic_dim_to_flatten
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
-// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten
+// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_dim_to_flatten
+// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
+// One of the dims to be flattened is dynamic and it is the leftmost one.
+
+func.func @transfer_write_dynamic_leftmost_dim_to_flatten(
+ %idx_1: index,
+ %idx_2: index,
+ %vec : vector<1x2x6xi32>,
+ %mem: memref<1x?x4x6xi32>) {
+
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
+ vector<1x2x6xi32>, memref<1x?x4x6xi32>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dynamic_leftmost_dim_to_flatten
+// CHECK-SAME: %[[IDX_1:arg0]]: index
+// CHECK-SAME: %[[IDX_2:arg1]]: index
+// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>,
+// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
+// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NEXT: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-NEXT: %[[TMP:.+]] = affine.apply #map{{.*}}()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK-NEXT: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK-NEXT: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-SAME: [%[[C0]], %[[TMP]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
+// CHECK-NEXT: return
+
+// CHECK-128B-LABEL: func @transfer_write_dynamic_leftmost_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
// -----