[Mlir-commits] [mlir] [mlir][vector] Restrict vector.insert/vector.extract (PR #121458)

Kunwar Grover llvmlistbot at llvm.org
Thu Jun 19 09:39:19 PDT 2025


================
@@ -1338,12 +1344,33 @@ struct UnrollTransferReadConversion
             insertionIndices.push_back(rewriter.getIndexAttr(i));
 
             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
-            auto newXferOp = b.create<vector::TransferReadOp>(
-                loc, newXferVecType, xferOp.getBase(), xferIndices,
-                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
-                xferOp.getPadding(), Value(), inBoundsAttr);
-            maybeAssignMask(b, xferOp, newXferOp, i);
-            return b.create<vector::InsertOp>(loc, newXferOp, vec,
+
+            // A value that's read after rank-reducing the original
+            // vector.transfer_read Op.
+            Value unpackedReadRes;
+            if (newXferVecType.getRank() != 0) {
+              // Unpacking Vector that's rank > 2
+              // (use vector.transfer_read to load a rank-reduced vector)
+              unpackedReadRes = b.create<vector::TransferReadOp>(
+                  loc, newXferVecType, xferOp.getBase(), xferIndices,
+                  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
+                  xferOp.getPadding(), Value(), inBoundsAttr);
+              maybeAssignMask(b, xferOp,
+                              dyn_cast<vector::TransferReadOp>(
+                                  unpackedReadRes.getDefiningOp()),
+                              i);
+            } else {
+              // Unpacking Vector that's rank == 1
+              // (use memref.load/tensor.extract to load a scalar)
+              unpackedReadRes = dyn_cast<MemRefType>(xferOp.getBase().getType())
+                                    ? b.create<memref::LoadOp>(
+                                           loc, xferOp.getBase(), xferIndices)
+                                          .getResult()
+                                    : b.create<tensor::ExtractOp>(
+                                           loc, xferOp.getBase(), xferIndices)
+                                          .getResult();
+            }
+            return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,
----------------
Groverkss wrote:

This is unrelated to the patch and changes the behavior of other transformations. For now, if the transfer_read returns a 0-D vector, we should extract the scalar from that vector and then insert the scalar.

https://github.com/llvm/llvm-project/pull/121458


More information about the Mlir-commits mailing list