[Mlir-commits] [mlir] c21a5ac - [mlir][vector] Flatten transfer - support multi-dim scalar element (#185417)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 9 09:45:21 PDT 2026
Author: Adam Siemieniuk
Date: 2026-03-09T17:45:15+01:00
New Revision: c21a5ac736d173e70b1ba1b5c447b382fa1995b0
URL: https://github.com/llvm/llvm-project/commit/c21a5ac736d173e70b1ba1b5c447b382fa1995b0
DIFF: https://github.com/llvm/llvm-project/commit/c21a5ac736d173e70b1ba1b5c447b382fa1995b0.diff
LOG: [mlir][vector] Flatten transfer - support multi-dim scalar element (#185417)
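Adds support for flattening transfers of multi-dimensional single-element
("scalar") vectors, i.e. vectors whose dimensions are all unit, such as
vector<1x1x1xf32>. Such transfers are now flattened to a 1-D
single-element vector (vector<1xf32>). The addition prevents pattern
crashes on such inputs and allows for cleaner lowering of scalar vectors.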
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 1121d9550f265..19db8b3b48a25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -835,11 +835,19 @@ class FlattenContiguousRowMajorTransferReadPattern
return failure();
}
+ // Determine vector dimensions to collapse.
+ // Ignore a leading sequence of adjacent unit dimensions in the vector.
+ ArrayRef<int64_t> collapsedVectorShape =
+ vectorType.getShape().drop_while([](auto v) { return v == 1; });
+ size_t collapsedVecRank = collapsedVectorShape.size();
+ // Limit the collapse of multi-dimensional unit vectors (e.g. <1x1x1xf32>)
+ // to a 1D single-element vector.
+ if (collapsedVecRank == 0)
+ collapsedVecRank = 1;
+
// Determine the first memref dimension to collapse - just enough so we can
// read a flattened vector.
- int64_t firstDimToCollapse =
- sourceType.getRank() -
- vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
+ int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
LDBG() << " -> First dimension to collapse: " << firstDimToCollapse;
// 1. Collapse the source memref
@@ -939,11 +947,19 @@ class FlattenContiguousRowMajorTransferWritePattern
if (transferWriteOp.getMask())
return failure();
+ // Determine vector dimensions to collapse.
+ // Ignore a leading sequence of adjacent unit dimensions in the vector.
+ ArrayRef<int64_t> collapsedVectorShape =
+ vectorType.getShape().drop_while([](auto v) { return v == 1; });
+ size_t collapsedVecRank = collapsedVectorShape.size();
+ // Limit the collapse of multi-dimensional unit vectors (e.g. <1x1x1xf32>)
+ // to a 1D single-element vector.
+ if (collapsedVecRank == 0)
+ collapsedVecRank = 1;
+
// Determine the first memref dimension to collapse - just enough so we can
// read a flattened vector.
- int64_t firstDimToCollapse =
- sourceType.getRank() -
- vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
+ int64_t firstDimToCollapse = sourceType.getRank() - collapsedVecRank;
// 1. Collapse the source memref
Value collapsedSource =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d1ce0fad2fb56..df8e6cf167348 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -249,6 +249,10 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
vectorType.getShape().drop_while([](auto v) { return v == 1; });
auto vecRank = vectorShape.size();
+ // A single element is always contiguous.
+ if (vecRank == 0)
+ return true;
+
if (!memrefType.areTrailingDimsContiguous(vecRank))
return false;
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 1f7c1b7ff7ad5..b048af24acfcd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -386,6 +386,28 @@ func.func @transfer_read_non_contiguous_src(
// -----
+func.func @transfer_read_multi_dim_unit_vector(
+ %mem: memref<5x4x3x2xi8>) -> vector<1x1x1xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<1x1x1xi8>
+ return %res : vector<1x1x1xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_multi_dim_unit_vector
+// CHECK-SAME: %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[MEM]]{{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
+// CHECK: %[[VEC3D:.+]] = vector.shape_cast %[[READ1D]] : vector<1xi8> to vector<1x1x1xi8>
+// CHECK: return %[[VEC3D]]
+
+// CHECK-128B-LABEL: func @transfer_read_multi_dim_unit_vector
+// CHECK-128B: vector.transfer_read {{.*}}: memref<5x4x3x2xi8>, vector<1xi8>
+// CHECK-128B: vector.shape_cast {{.*}}: vector<1xi8> to vector<1x1x1xi8>
+
+// -----
+
///----------------------------------------------------------------------------------------
/// vector.transfer_write
/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
@@ -785,3 +807,22 @@ func.func @negative_out_of_bound_transfer_write(
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
+func.func @transfer_write_multi_dim_unit_vector(
+ %mem: memref<5x4x3x2xi8>, %vec: vector<1x1x1xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
+ vector<1x1x1xi8>, memref<5x4x3x2xi8>
+ return
+}
+
+// CHECK-LABEL: func @transfer_write_multi_dim_unit_vector
+// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8>, %[[VEC:.+]]: vector<1x1x1xi8>
+// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x1xi8> to vector<1xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[MEM]]{{.*}}: vector<1xi8>, memref<5x4x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_write_multi_dim_unit_vector
+// CHECK-128B: vector.shape_cast {{.*}}: vector<1x1x1xi8> to vector<1xi8>
+// CHECK-128B: vector.transfer_write {{.*}}: vector<1xi8>, memref<5x4x3x2xi8>
More information about the Mlir-commits
mailing list