[Mlir-commits] [mlir] [mlir][vector] Convert vector.transfer_read to scalar load and broadcast (PR #159520)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 18 01:18:19 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Hsiangkai Wang (Hsiangkai)
<details>
<summary>Changes</summary>
If vector.transfer_read reads from a 0-d memref, we can convert it to a memref.load from that 0-d memref and then broadcast the loaded scalar to the target vector type.
This avoids generating vector operations that break the requirements of convertVectorToMMAOps. The patterns in convertVectorToMMAOps expect every vector.transfer_read to have a 2-D vector type.
Instead of
%s0 = vector.transfer_read %base[] : memref<dtype> to vector<dtype>
%s1 = vector.broadcast %s0 : vector<dtype> to vector<d0...d1 x dtype>
Use
%s0 = memref.load %base[] : memref<dtype>
%s1 = vector.broadcast %s0 : dtype to vector<d0...d1 x dtype>
---
Full diff: https://github.com/llvm/llvm-project/pull/159520.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+29-11)
- (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+18)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 2cf8f0beaa4de..4f62b6a7f2fde 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -360,17 +360,35 @@ struct TransferOpReduceRank
SmallVector<bool> newScalableDims(
originalVecType.getScalableDims().take_back(reducedShapeRank));
- VectorType newReadType = VectorType::get(
- newShape, originalVecType.getElementType(), newScalableDims);
- ArrayAttr newInBoundsAttr =
- op.getInBounds()
- ? rewriter.getArrayAttr(
- op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
- : ArrayAttr();
- Value newRead = vector::TransferReadOp::create(
- rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
- AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
- newInBoundsAttr);
+ Value newRead;
+ if (newShape.size() == 0 && newScalableDims.size() == 0) {
+ // Handle the scalar case.
+ // Convert
+ // %val = vector.transfer_read %base[] : memref<dtype> to
+ // vector<d0 x d1 x dtype>
+ // into
+ // %scalar = memref.load %base[] : memref<dtype>
+ // %val = vector.broadcast %scalar : dtype to vector<d0 x d1 x dtype>
+ Type baseType = op.getBase().getType();
+ if (isa<MemRefType>(baseType)) {
+ newRead = memref::LoadOp::create(rewriter, op.getLoc(), op.getBase(),
+ op.getIndices());
+ }
+ }
+
+ if (!newRead) {
+ VectorType newReadType = VectorType::get(
+ newShape, originalVecType.getElementType(), newScalableDims);
+ ArrayAttr newInBoundsAttr =
+ op.getInBounds()
+ ? rewriter.getArrayAttr(
+ op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
+ : ArrayAttr();
+ newRead = vector::TransferReadOp::create(
+ rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
+ AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
+ newInBoundsAttr);
+ }
return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
newRead)
.getVector();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 3ae18835c8367..16104aa76e692 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -388,6 +388,24 @@ func.func @xfer_read_minor_identitiy_bcast_dims(
return %res : vector<8x4x2x3xf32>
}
+// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_scalar
+// CHECK-SAME: %[[MEM:.*]]: memref<f32>) -> vector<8x4x2x3xf32> {
+// CHECK: %[[LOAD:.*]] = memref.load %[[MEM]][] : memref<f32>
+// CHECK: %[[BC:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<8x4x2x3xf32>
+// CHECK: return %[[BC]] : vector<8x4x2x3xf32>
+func.func @xfer_read_minor_identitiy_bcast_scalar(
+ %mem: memref<f32>) -> vector<8x4x2x3xf32> {
+
+ %pad = arith.constant 0.000000e+00 : f32
+
+ %res = vector.transfer_read %mem[], %pad {
+ in_bounds = [true, true, true, true],
+ permutation_map = affine_map<() -> (0, 0, 0, 0)>
+ } : memref<f32>, vector<8x4x2x3xf32>
+
+ return %res : vector<8x4x2x3xf32>
+}
+
// CHECK-LABEL: func.func @xfer_read_minor_identitiy_bcast_dims_scalable
// CHECK-SAME: %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
// CHECK: %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/159520
More information about the Mlir-commits
mailing list