[Mlir-commits] [mlir] [mlir][vector] Convert vector.transfer_read to scalar load and broadcast (PR #159520)
    Andrzej Warzyński
    llvmlistbot at llvm.org
       
    Fri Sep 26 05:46:11 PDT 2025
    
    
  
================
@@ -360,17 +360,35 @@ struct TransferOpReduceRank
     SmallVector<bool> newScalableDims(
         originalVecType.getScalableDims().take_back(reducedShapeRank));
 
-    VectorType newReadType = VectorType::get(
-        newShape, originalVecType.getElementType(), newScalableDims);
-    ArrayAttr newInBoundsAttr =
-        op.getInBounds()
-            ? rewriter.getArrayAttr(
-                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
-            : ArrayAttr();
-    Value newRead = vector::TransferReadOp::create(
-        rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(),
-        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
-        newInBoundsAttr);
+    Value newRead;
+    if (newShape.size() == 0 && newScalableDims.size() == 0) {
+      // Handle the scalar case.
+      // Convert
+      //   %val = vector.transfer_read %base[] : memref<dtype> to
+      //                                         vector<d0 x d1 x dtype>
+      // into
+      //   %scalar = memref.load %base[] : memref<dtype>
+      //   %val = vector.broadcast %scalar : dtype to vector<d0 x d1 x dtype>
+      Type baseType = op.getBase().getType();
+      if (isa<MemRefType>(baseType)) {
+        newRead = memref::LoadOp::create(rewriter, op.getLoc(), op.getBase(),
+                                         op.getIndices());
+      }
+    }
+
----------------
banach-space wrote:
```suggestion
// Handle the non-scalar case.
```
(and then, above `if (newShape.size() == 0 && newScalableDims.size() == 0) {`, add `// Handle the scalar case.`).
https://github.com/llvm/llvm-project/pull/159520
    
    
More information about the Mlir-commits
mailing list