[Mlir-commits] [mlir] c21a5ac - [mlir][vector] Flatten transfer - support multi-dim scalar element (#185417)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 9 09:45:21 PDT 2026


Author: Adam Siemieniuk
Date: 2026-03-09T17:45:15+01:00
New Revision: c21a5ac736d173e70b1ba1b5c447b382fa1995b0

URL: https://github.com/llvm/llvm-project/commit/c21a5ac736d173e70b1ba1b5c447b382fa1995b0
DIFF: https://github.com/llvm/llvm-project/commit/c21a5ac736d173e70b1ba1b5c447b382fa1995b0.diff

LOG: [mlir][vector] Flatten transfer - support multi-dim scalar element (#185417)

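Adds support for flattening transfers of multi-dimensional single-element
("scalar") vectors, e.g. `vector<1x1x1xi8>`.

The addition prevents the flattening patterns from crashing on such inputs
and lets such vectors be lowered cleanly to a 1-D single-element transfer
(`vector<1xi8>`) combined with a `vector.shape_cast`.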

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
    mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 1121d9550f265..19db8b3b48a25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -835,11 +835,19 @@ class FlattenContiguousRowMajorTransferReadPattern
       return failure();
     }
 
+    // Determine vector dimensions to collapse.
+    // Ignore a leading sequence of adjacent unit dimensions in the vector.
+    ArrayRef<int64_t> collapsedVectorShape =
+        vectorType.getShape().drop_while([](auto v) { return v == 1; });
+    size_t collapsedVecRank = collapsedVectorShape.size();
+    // Limit the collapse of multi-dimensional unit vectors (e.g. <1x1x1xf32>)
+    // to a 1D single-element vector.
+    if (collapsedVecRank == 0)
+      collapsedVecRank = 1;
+
     // Determine the first memref dimension to collapse - just enough so we can
     // read a flattened vector.
-    int64_t firstDimToCollapse =
-        sourceType.getRank() -
-        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
+    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
     LDBG() << "  -> First dimension to collapse: " << firstDimToCollapse;
 
     // 1. Collapse the source memref
@@ -939,11 +947,19 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
+    // Determine vector dimensions to collapse.
+    // Ignore a leading sequence of adjacent unit dimensions in the vector.
+    ArrayRef<int64_t> collapsedVectorShape =
+        vectorType.getShape().drop_while([](auto v) { return v == 1; });
+    size_t collapsedVecRank = collapsedVectorShape.size();
+    // Limit the collapse of multi-dimensional unit vectors (e.g. <1x1x1xf32>)
+    // to a 1D single-element vector.
+    if (collapsedVecRank == 0)
+      collapsedVecRank = 1;
+
     // Determine the first memref dimension to collapse - just enough so we can
     // read a flattened vector.
-    int64_t firstDimToCollapse =
-        sourceType.getRank() -
-        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
+    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
 
     // 1. Collapse the source memref
     Value collapsedSource =

diff  --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d1ce0fad2fb56..df8e6cf167348 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -249,6 +249,10 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
       vectorType.getShape().drop_while([](auto v) { return v == 1; });
   auto vecRank = vectorShape.size();
 
+  // A single element is always contiguous.
+  if (vecRank == 0)
+    return true;
+
   if (!memrefType.areTrailingDimsContiguous(vecRank))
     return false;
 

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 1f7c1b7ff7ad5..b048af24acfcd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -386,6 +386,28 @@ func.func @transfer_read_non_contiguous_src(
 
 // -----
 
+func.func @transfer_read_multi_dim_unit_vector(
+    %mem: memref<5x4x3x2xi8>) -> vector<1x1x1xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<1x1x1xi8>
+  return %res : vector<1x1x1xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_multi_dim_unit_vector
+// CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[MEM]]{{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
+// CHECK:         %[[VEC3D:.+]] = vector.shape_cast %[[READ1D]] : vector<1xi8> to vector<1x1x1xi8>
+// CHECK:         return %[[VEC3D]]
+
+// CHECK-128B-LABEL: func @transfer_read_multi_dim_unit_vector
+//       CHECK-128B:   vector.transfer_read {{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
+//       CHECK-128B:   vector.shape_cast {{.*}}: vector<1xi8> to vector<1x1x1xi8>
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// vector.transfer_write
 /// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
@@ -785,3 +807,22 @@ func.func @negative_out_of_bound_transfer_write(
 // CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
 //   CHECK-128B-NOT:   memref.collapse_shape
 //   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+func.func @transfer_write_multi_dim_unit_vector(
+    %mem: memref<5x4x3x2xi8>, %vec: vector<1x1x1xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
+    vector<1x1x1xi8>, memref<5x4x3x2xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_multi_dim_unit_vector
+// CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8>, %[[VEC:.+]]: vector<1x1x1xi8>
+// CHECK:         %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x1xi8> to vector<1xi8>
+// CHECK:         vector.transfer_write %[[VEC1D]], %[[MEM]]{{.*}}: vector<1xi8>, memref<5x4x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_write_multi_dim_unit_vector
+//       CHECK-128B:   vector.shape_cast {{.*}}: vector<1x1x1xi8> to vector<1xi8>
+//       CHECK-128B:   vector.transfer_write {{.*}}: vector<1xi8>, memref<5x4x3x2xi8>


        


More information about the Mlir-commits mailing list