[Mlir-commits] [mlir] [mlir][vector] Flatten transfer - support multi-dim scalar element (PR #185417)

Adam Siemieniuk llvmlistbot at llvm.org
Mon Mar 9 09:06:57 PDT 2026


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/185417

>From c954224069d8a4bcfc8a938d8d9bf6e5b312b6c3 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 9 Mar 2026 13:37:14 +0100
Subject: [PATCH 1/2] [mlir][vector] Flatten transfer - support multi-dim
 scalar element

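Adds support for flattening transfers of multi-dimensional unit vectors
(vectors whose dimensions are all 1, e.g. vector<1x1x1xi8>).

Previously, dropping all leading unit dimensions from such a vector left an
empty (rank-0) shape, which crashed the flattening patterns. The collapse is
now restricted to a 1D single-element vector. This prevents the crash and
allows for cleaner lowering of such vectors.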
---
 .../Transforms/VectorTransferOpTransforms.cpp | 28 ++++++++++---
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  4 ++
 .../Vector/vector-transfer-flatten.mlir       | 41 +++++++++++++++++++
 3 files changed, 67 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 1121d9550f265..3666b904a00c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -835,11 +835,19 @@ class FlattenContiguousRowMajorTransferReadPattern
       return failure();
     }
 
+    // Determine vector dimensions to collapse.
+    // Ignore a leading sequence of adjacent unit dimensions in the vector.
+    ArrayRef<int64_t> collapsedVectorShape =
+        vectorType.getShape().drop_while([](auto v) { return v == 1; });
+    size_t collapsedVecRank = collapsedVectorShape.size();
+    // In case of a multi-dimensional scalar vector, restrict the shape collapse
+    // to a 1D vector.
+    if (collapsedVecRank == 0)
+      collapsedVecRank = 1;
+
     // Determine the first memref dimension to collapse - just enough so we can
     // read a flattened vector.
-    int64_t firstDimToCollapse =
-        sourceType.getRank() -
-        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
+    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
     LDBG() << "  -> First dimension to collapse: " << firstDimToCollapse;
 
     // 1. Collapse the source memref
@@ -939,11 +947,19 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
+    // Determine vector dimensions to collapse.
+    // Ignore a leading sequence of adjacent unit dimensions in the vector.
+    ArrayRef<int64_t> collapsedVectorShape =
+        vectorType.getShape().drop_while([](auto v) { return v == 1; });
+    size_t collapsedVecRank = collapsedVectorShape.size();
+    // In case of a multi-dimensional scalar vector, restrict the shape collapse
+    // to a 1D vector.
+    if (collapsedVecRank == 0)
+      collapsedVecRank = 1;
+
     // Determine the first memref dimension to collapse - just enough so we can
     // read a flattened vector.
-    int64_t firstDimToCollapse =
-        sourceType.getRank() -
-        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
+    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
 
     // 1. Collapse the source memref
     Value collapsedSource =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d1ce0fad2fb56..f7fa3053af248 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -249,6 +249,10 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
       vectorType.getShape().drop_while([](auto v) { return v == 1; });
   auto vecRank = vectorShape.size();
 
+  // A scalar is always contiguous.
+  if (vecRank == 0)
+    return true;
+
   if (!memrefType.areTrailingDimsContiguous(vecRank))
     return false;
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 1f7c1b7ff7ad5..839cef6157378 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -386,6 +386,28 @@ func.func @transfer_read_non_contiguous_src(
 
 // -----
 
+func.func @transfer_read_multi_dim_scalar_vector(
+    %mem: memref<5x4x3x2xi8>) -> vector<1x1x1xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<1x1x1xi8>
+  return %res : vector<1x1x1xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_multi_dim_scalar_vector
+// CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[MEM]]{{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
+// CHECK:         %[[VEC3D:.+]] = vector.shape_cast %[[READ1D]] : vector<1xi8> to vector<1x1x1xi8>
+// CHECK:         return %[[VEC3D]]
+
+// CHECK-128B-LABEL: func @transfer_read_multi_dim_scalar_vector
+//       CHECK-128B:   vector.transfer_read {{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
+//       CHECK-128B:   vector.shape_cast {{.*}}: vector<1xi8> to vector<1x1x1xi8>
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// vector.transfer_write
 /// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
@@ -785,3 +807,22 @@ func.func @negative_out_of_bound_transfer_write(
 // CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
 //   CHECK-128B-NOT:   memref.collapse_shape
 //   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+func.func @transfer_write_multi_dim_scalar_vector(
+    %mem: memref<5x4x3x2xi8>, %vec: vector<1x1x1xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
+    vector<1x1x1xi8>, memref<5x4x3x2xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_multi_dim_scalar_vector
+// CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8>, %[[VEC:.+]]: vector<1x1x1xi8>
+// CHECK:         %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x1xi8> to vector<1xi8>
+// CHECK:         vector.transfer_write %[[VEC1D]], %[[MEM]]{{.*}}: vector<1xi8>, memref<5x4x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_write_multi_dim_scalar_vector
+//       CHECK-128B:   vector.shape_cast {{.*}}: vector<1x1x1xi8> to vector<1xi8>
+//       CHECK-128B:   vector.transfer_write {{.*}}: vector<1xi8>, memref<5x4x3x2xi8>

>From a6383a584ff205f0c51ff8187561da75966d1ad0 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 9 Mar 2026 17:06:46 +0100
Subject: [PATCH 2/2] Rephrase: scalar -> unit

---
 .../Vector/Transforms/VectorTransferOpTransforms.cpp |  8 ++++----
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp        |  2 +-
 .../test/Dialect/Vector/vector-transfer-flatten.mlir | 12 ++++++------
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 3666b904a00c6..19db8b3b48a25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -840,8 +840,8 @@ class FlattenContiguousRowMajorTransferReadPattern
     ArrayRef<int64_t> collapsedVectorShape =
         vectorType.getShape().drop_while([](auto v) { return v == 1; });
     size_t collapsedVecRank = collapsedVectorShape.size();
-    // In case of a multi-dimensional scalar vector, restrict the shape collapse
-    // to a 1D vector.
+    // Limit the collapse of multi-dimensional unit vectors (e.g. <1x1x1xf32>)
+    // to a 1D single-element vector.
     if (collapsedVecRank == 0)
       collapsedVecRank = 1;
 
@@ -952,8 +952,8 @@ class FlattenContiguousRowMajorTransferWritePattern
     ArrayRef<int64_t> collapsedVectorShape =
         vectorType.getShape().drop_while([](auto v) { return v == 1; });
     size_t collapsedVecRank = collapsedVectorShape.size();
-    // In case of a multi-dimensional scalar vector, restrict the shape collapse
-    // to a 1D vector.
+    // Limit the collapse of multi-dimensional unit vectors (e.g. <1x1x1xf32>)
+    // to a 1D single-element vector.
     if (collapsedVecRank == 0)
       collapsedVecRank = 1;
 
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index f7fa3053af248..df8e6cf167348 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -249,7 +249,7 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
       vectorType.getShape().drop_while([](auto v) { return v == 1; });
   auto vecRank = vectorShape.size();
 
-  // A scalar is always contiguous.
+  // A single element is always contiguous.
   if (vecRank == 0)
     return true;
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 839cef6157378..b048af24acfcd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -386,7 +386,7 @@ func.func @transfer_read_non_contiguous_src(
 
 // -----
 
-func.func @transfer_read_multi_dim_scalar_vector(
+func.func @transfer_read_multi_dim_unit_vector(
     %mem: memref<5x4x3x2xi8>) -> vector<1x1x1xi8> {
 
   %c0 = arith.constant 0 : index
@@ -396,13 +396,13 @@ func.func @transfer_read_multi_dim_scalar_vector(
   return %res : vector<1x1x1xi8>
 }
 
-// CHECK-LABEL: func @transfer_read_multi_dim_scalar_vector
+// CHECK-LABEL: func @transfer_read_multi_dim_unit_vector
 // CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[MEM]]{{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
 // CHECK:         %[[VEC3D:.+]] = vector.shape_cast %[[READ1D]] : vector<1xi8> to vector<1x1x1xi8>
 // CHECK:         return %[[VEC3D]]
 
-// CHECK-128B-LABEL: func @transfer_read_multi_dim_scalar_vector
+// CHECK-128B-LABEL: func @transfer_read_multi_dim_unit_vector
 //       CHECK-128B:   vector.transfer_read {{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
 //       CHECK-128B:   vector.shape_cast {{.*}}: vector<1xi8> to vector<1x1x1xi8>
 
@@ -810,7 +810,7 @@ func.func @negative_out_of_bound_transfer_write(
 
 // -----
 
-func.func @transfer_write_multi_dim_scalar_vector(
+func.func @transfer_write_multi_dim_unit_vector(
     %mem: memref<5x4x3x2xi8>, %vec: vector<1x1x1xi8>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
@@ -818,11 +818,11 @@ func.func @transfer_write_multi_dim_scalar_vector(
   return
 }
 
-// CHECK-LABEL: func @transfer_write_multi_dim_scalar_vector
+// CHECK-LABEL: func @transfer_write_multi_dim_unit_vector
 // CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8>, %[[VEC:.+]]: vector<1x1x1xi8>
 // CHECK:         %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x1xi8> to vector<1xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[MEM]]{{.*}}: vector<1xi8>, memref<5x4x3x2xi8>
 
-// CHECK-128B-LABEL: func @transfer_write_multi_dim_scalar_vector
+// CHECK-128B-LABEL: func @transfer_write_multi_dim_unit_vector
 //       CHECK-128B:   vector.shape_cast {{.*}}: vector<1x1x1xi8> to vector<1xi8>
 //       CHECK-128B:   vector.transfer_write {{.*}}: vector<1xi8>, memref<5x4x3x2xi8>



More information about the Mlir-commits mailing list