[Mlir-commits] [mlir] [mlir][vector] Convert vector.transfer_read to scalar load and broadcast (PR #159520)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 18 01:18:19 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Hsiangkai Wang (Hsiangkai)

<details>
<summary>Changes</summary>

When a vector.transfer_read reads from a 0-d memref, it can be converted into a memref.load of the scalar element followed by a vector.broadcast of that scalar to the target vector type.

This avoids generating vector operations that violate the requirements of convertVectorToMMAOps: the patterns in convertVectorToMMAOps expect every vector.transfer_read to have a 2-D vector type.

Instead of
  %s0 = vector.transfer_read %base[] : memref<dtype> to vector<dtype>
  %s1 = vector.broadcast %s0 : vector<dtype> to vector<d0...d1 x dtype>

Use
  %s0 = memref.load %base[] : memref<dtype>
  %s1 = vector.broadcast %s0 : dtype to vector<d0...d1 x dtype>

---
Full diff: https://github.com/llvm/llvm-project/pull/159520.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp (+29-11) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir (+18) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 2cf8f0beaa4de..4f62b6a7f2fde 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -360,17 +360,35 @@ struct TransferOpReduceRank
     SmallVector<bool> newScalableDims(
         originalVecType.getScalableDims().take_back(reducedShapeRank));
 
-    VectorType newReadType = VectorType::get(
-        newShape, originalVecType.getElementType(), newScalableDims);
-    ArrayAttr newInBoundsAttr =
-        op.getInBounds()
-            ? rewriter.getArrayAttr(
-                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
-            : ArrayAttr();
-    Value newRead = vector::TransferReadOp::create(
-        rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
-        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
-        newInBoundsAttr);
+    Value newRead;
+    if (newShape.size() == 0 && newScalableDims.size() == 0) {
+      // Handle the scalar case.
+      // Convert
+      //   %val = vector.transfer_read %base[] : memref<dtype> to
+      //                                         vector<d0 x d1 x dtype>
+      // into
+      //   %scalar = memref.load %base[] : memref<dtype>
+      //   %val = vector.broadcast %scalar : dtype to vector<d0 x d1 x dtype>
+      Type baseType = op.getBase().getType();
+      if (isa<MemRefType>(baseType)) {
+        newRead = memref::LoadOp::create(rewriter, op.getLoc(), op.getBase(),
+                                         op.getIndices());
+      }
+    }
+
+    if (!newRead) {
+      VectorType newReadType = VectorType::get(
+          newShape, originalVecType.getElementType(), newScalableDims);
+      ArrayAttr newInBoundsAttr =
+          op.getInBounds()
+              ? rewriter.getArrayAttr(
+                    op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
+              : ArrayAttr();
+      newRead = vector::TransferReadOp::create(
+          rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
+          AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
+          newInBoundsAttr);
+    }
     return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType,
                                        newRead)
         .getVector();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
index 3ae18835c8367..16104aa76e692 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -388,6 +388,24 @@ func.func @xfer_read_minor_identitiy_bcast_dims(
   return %res : vector<8x4x2x3xf32>
 }
 
+// CHECK-LABEL:   func.func @xfer_read_minor_identitiy_bcast_scalar
+//  CHECK-SAME:     %[[MEM:.*]]: memref<f32>) -> vector<8x4x2x3xf32> {
+//       CHECK:     %[[LOAD:.*]] = memref.load %[[MEM]][] : memref<f32>
+//       CHECK:     %[[BC:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<8x4x2x3xf32>
+//       CHECK:     return %[[BC]] : vector<8x4x2x3xf32>
+func.func @xfer_read_minor_identitiy_bcast_scalar(
+    %mem: memref<f32>) -> vector<8x4x2x3xf32> {
+
+  %pad = arith.constant 0.000000e+00 : f32
+
+  %res = vector.transfer_read %mem[], %pad {
+    in_bounds = [true, true, true, true],
+    permutation_map = affine_map<() -> (0, 0, 0, 0)>
+  } : memref<f32>, vector<8x4x2x3xf32>
+
+  return %res : vector<8x4x2x3xf32>
+}
+
 // CHECK-LABEL:   func.func @xfer_read_minor_identitiy_bcast_dims_scalable
 //  CHECK-SAME:     %[[MEM:.*]]: memref<?x?x?x?xf32>, %[[IDX:.*]]: index) -> vector<8x[4]x2x3xf32> {
 //       CHECK:     %[[T_READ:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]]{{.*}} permutation_map = #[[$MAP]]} : memref<?x?x?x?xf32>, vector<[4]x2x3xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/159520


More information about the Mlir-commits mailing list