[Mlir-commits] [mlir] [mlir][vector] Flatten transfer - support multi-dim scalar element (PR #185417)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 9 06:30:46 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: Adam Siemieniuk (adam-smnk)

<details>
<summary>Changes</summary>

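Adds support for flattening transfers of multi-dimensional "scalar" vectors, i.e. vectors whose dimensions are all of size one and which therefore hold a single element (e.g. `vector<1x1x1xi8>`).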

The addition prevents the transfer flattening patterns from crashing on such inputs and allows for cleaner lowering of scalar vectors.

---
Full diff: https://github.com/llvm/llvm-project/pull/185417.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+22-6) 
- (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+4) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+41) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 1121d9550f265..3666b904a00c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -835,11 +835,19 @@ class FlattenContiguousRowMajorTransferReadPattern
       return failure();
     }
 
+    // Determine vector dimensions to collapse.
+    // Ignore a leading sequence of adjacent unit dimensions in the vector.
+    ArrayRef<int64_t> collapsedVectorShape =
+        vectorType.getShape().drop_while([](auto v) { return v == 1; });
+    size_t collapsedVecRank = collapsedVectorShape.size();
+    // In case of a multi-dimensional scalar vector, restrict the shape collapse
+    // to a 1D vector.
+    if (collapsedVecRank == 0)
+      collapsedVecRank = 1;
+
     // Determine the first memref dimension to collapse - just enough so we can
     // read a flattened vector.
-    int64_t firstDimToCollapse =
-        sourceType.getRank() -
-        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
+    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
     LDBG() << "  -> First dimension to collapse: " << firstDimToCollapse;
 
     // 1. Collapse the source memref
@@ -939,11 +947,19 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
+    // Determine vector dimensions to collapse.
+    // Ignore a leading sequence of adjacent unit dimensions in the vector.
+    ArrayRef<int64_t> collapsedVectorShape =
+        vectorType.getShape().drop_while([](auto v) { return v == 1; });
+    size_t collapsedVecRank = collapsedVectorShape.size();
+    // In case of a multi-dimensional scalar vector, restrict the shape collapse
+    // to a 1D vector.
+    if (collapsedVecRank == 0)
+      collapsedVecRank = 1;
+
     // Determine the first memref dimension to collapse - just enough so we can
     // read a flattened vector.
-    int64_t firstDimToCollapse =
-        sourceType.getRank() -
-        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
+    int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
 
     // 1. Collapse the source memref
     Value collapsedSource =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d1ce0fad2fb56..f7fa3053af248 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -249,6 +249,10 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
       vectorType.getShape().drop_while([](auto v) { return v == 1; });
   auto vecRank = vectorShape.size();
 
+  // A scalar is always contiguous.
+  if (vecRank == 0)
+    return true;
+
   if (!memrefType.areTrailingDimsContiguous(vecRank))
     return false;
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 1f7c1b7ff7ad5..839cef6157378 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -386,6 +386,28 @@ func.func @transfer_read_non_contiguous_src(
 
 // -----
 
+func.func @transfer_read_multi_dim_scalar_vector(
+    %mem: memref<5x4x3x2xi8>) -> vector<1x1x1xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<1x1x1xi8>
+  return %res : vector<1x1x1xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_multi_dim_scalar_vector
+// CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[MEM]]{{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
+// CHECK:         %[[VEC3D:.+]] = vector.shape_cast %[[READ1D]] : vector<1xi8> to vector<1x1x1xi8>
+// CHECK:         return %[[VEC3D]]
+
+// CHECK-128B-LABEL: func @transfer_read_multi_dim_scalar_vector
+//       CHECK-128B:   vector.transfer_read {{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
+//       CHECK-128B:   vector.shape_cast {{.*}}: vector<1xi8> to vector<1x1x1xi8>
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// vector.transfer_write
 /// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
@@ -785,3 +807,22 @@ func.func @negative_out_of_bound_transfer_write(
 // CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
 //   CHECK-128B-NOT:   memref.collapse_shape
 //   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+func.func @transfer_write_multi_dim_scalar_vector(
+    %mem: memref<5x4x3x2xi8>, %vec: vector<1x1x1xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
+    vector<1x1x1xi8>, memref<5x4x3x2xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_multi_dim_scalar_vector
+// CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8>, %[[VEC:.+]]: vector<1x1x1xi8>
+// CHECK:         %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x1xi8> to vector<1xi8>
+// CHECK:         vector.transfer_write %[[VEC1D]], %[[MEM]]{{.*}}: vector<1xi8>, memref<5x4x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_write_multi_dim_scalar_vector
+//       CHECK-128B:   vector.shape_cast {{.*}}: vector<1x1x1xi8> to vector<1xi8>
+//       CHECK-128B:   vector.transfer_write {{.*}}: vector<1xi8>, memref<5x4x3x2xi8>

``````````

</details>


https://github.com/llvm/llvm-project/pull/185417


More information about the Mlir-commits mailing list