[Mlir-commits] [mlir] [mlir][vector] Flatten transfer - support multi-dim scalar element (PR #185417)
Adam Siemieniuk
llvmlistbot at llvm.org
Mon Mar 9 06:30:11 PDT 2026
https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/185417
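Adds support for flattening transfers of multi-dimensional vectors that hold a single element (e.g. `vector<1x1x1xi8>`).
Previously such inputs crashed the flattening patterns. They are now rewritten into a 1-D `vector<1xi8>` transfer combined with a `vector.shape_cast`, which also allows cleaner lowering of these scalar vectors.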
From c954224069d8a4bcfc8a938d8d9bf6e5b312b6c3 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 9 Mar 2026 13:37:14 +0100
Subject: [PATCH] [mlir][vector] Flatten transfer - support multi-dim scalar
element
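Adds support for flattening transfers of multi-dimensional vectors that
hold a single element (e.g. vector<1x1x1xi8>).
Previously such inputs crashed the flattening patterns. They are now
rewritten into a 1-D vector<1xi8> transfer combined with a
vector.shape_cast, which also allows cleaner lowering of these scalar
vectors.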
---
.../Transforms/VectorTransferOpTransforms.cpp | 28 ++++++++++---
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 4 ++
.../Vector/vector-transfer-flatten.mlir | 41 +++++++++++++++++++
3 files changed, 67 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 1121d9550f265..3666b904a00c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -835,11 +835,19 @@ class FlattenContiguousRowMajorTransferReadPattern
return failure();
}
+ // Determine vector dimensions to collapse.
+ // Ignore a leading sequence of adjacent unit dimensions in the vector.
+ ArrayRef<int64_t> collapsedVectorShape =
+ vectorType.getShape().drop_while([](auto v) { return v == 1; });
+ size_t collapsedVecRank = collapsedVectorShape.size();
+ // In case of a multi-dimensional scalar vector, restrict the shape collapse
+ // to a 1D vector.
+ if (collapsedVecRank == 0)
+ collapsedVecRank = 1;
+
// Determine the first memref dimension to collapse - just enough so we can
// read a flattened vector.
- int64_t firstDimToCollapse =
- sourceType.getRank() -
- vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
+ int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
LDBG() << " -> First dimension to collapse: " << firstDimToCollapse;
// 1. Collapse the source memref
@@ -939,11 +947,19 @@ class FlattenContiguousRowMajorTransferWritePattern
if (transferWriteOp.getMask())
return failure();
+ // Determine vector dimensions to collapse.
+ // Ignore a leading sequence of adjacent unit dimensions in the vector.
+ ArrayRef<int64_t> collapsedVectorShape =
+ vectorType.getShape().drop_while([](auto v) { return v == 1; });
+ size_t collapsedVecRank = collapsedVectorShape.size();
+ // In case of a multi-dimensional scalar vector, restrict the shape collapse
+ // to a 1D vector.
+ if (collapsedVecRank == 0)
+ collapsedVecRank = 1;
+
// Determine the first memref dimension to collapse - just enough so we can
// read a flattened vector.
- int64_t firstDimToCollapse =
- sourceType.getRank() -
- vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
+ int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
// 1. Collapse the source memref
Value collapsedSource =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d1ce0fad2fb56..f7fa3053af248 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -249,6 +249,10 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
vectorType.getShape().drop_while([](auto v) { return v == 1; });
auto vecRank = vectorShape.size();
+ // A scalar is always contiguous.
+ if (vecRank == 0)
+ return true;
+
if (!memrefType.areTrailingDimsContiguous(vecRank))
return false;
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 1f7c1b7ff7ad5..839cef6157378 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -386,6 +386,28 @@ func.func @transfer_read_non_contiguous_src(
// -----
+func.func @transfer_read_multi_dim_scalar_vector(
+ %mem: memref<5x4x3x2xi8>) -> vector<1x1x1xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<1x1x1xi8>
+ return %res : vector<1x1x1xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_multi_dim_scalar_vector
+// CHECK-SAME: %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[MEM]]{{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
+// CHECK: %[[VEC3D:.+]] = vector.shape_cast %[[READ1D]] : vector<1xi8> to vector<1x1x1xi8>
+// CHECK: return %[[VEC3D]]
+
+// CHECK-128B-LABEL: func @transfer_read_multi_dim_scalar_vector
+// CHECK-128B: vector.transfer_read {{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
+// CHECK-128B: vector.shape_cast {{.*}}: vector<1xi8> to vector<1x1x1xi8>
+
+// -----
+
///----------------------------------------------------------------------------------------
/// vector.transfer_write
/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
@@ -785,3 +807,22 @@ func.func @negative_out_of_bound_transfer_write(
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
+func.func @transfer_write_multi_dim_scalar_vector(
+ %mem: memref<5x4x3x2xi8>, %vec: vector<1x1x1xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
+ vector<1x1x1xi8>, memref<5x4x3x2xi8>
+ return
+}
+
+// CHECK-LABEL: func @transfer_write_multi_dim_scalar_vector
+// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8>, %[[VEC:.+]]: vector<1x1x1xi8>
+// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x1xi8> to vector<1xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[MEM]]{{.*}}: vector<1xi8>, memref<5x4x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_write_multi_dim_scalar_vector
+// CHECK-128B: vector.shape_cast {{.*}}: vector<1x1x1xi8> to vector<1xi8>
+// CHECK-128B: vector.transfer_write {{.*}}: vector<1xi8>, memref<5x4x3x2xi8>
More information about the Mlir-commits
mailing list