[Mlir-commits] [mlir] [mlir][Vector] Improve vector.transferx store-to-load-forwarding (PR #171840)

Jakub Kuderski llvmlistbot at llvm.org
Thu Dec 11 07:26:52 PST 2025


================
@@ -5161,60 +5185,51 @@ struct TransferReadAfterWriteToBroadcast
     if (readOp.getTransferChunkAccessed() !=
         defWrite.getTransferChunkAccessed())
       return failure();
-    // TODO: Support cases where a dim is explicitly written but implicitly
-    // read (i.e., a unit dim that is rank reduced).
-    if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
-        getUnusedDimsBitVector({defWrite.getPermutationMap()}))
-      return failure();
-    // This pattern should only catch the broadcast case, the non-broadcast case
-    // should be done separately to keep application conditions clean and
-    // separate.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
-    bool bcast = !readMap.getBroadcastDims().empty() ||
-                 !writeMap.getBroadcastDims().empty();
-    if (!bcast)
-      return failure();
-    // At this point, we know we have a bcast.
-    // Bail in the masked case (too complex atm and needed to properly account
-    // for padding).
-    if (readOp.getMask() || defWrite.getMask())
-      return failure();
-    // If indices are not the same a shift may be required, bail.
-    if (readOp.getIndices() != defWrite.getIndices())
+    // WriteMap: tensor -> w_vec
+    // ReadMap: tensor -> r_vec
+    //
+    // inv(WriteMap): w_vec -> tensor
+    // inv(WriteMap) o ReadMap: w_vec -> r_vec
+    AffineMap readMap = readOp.getPermutationMap();
+    AffineMap writeMap = defWrite.getPermutationMap();
+    AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
+    AffineMap composedMap = readMap.compose(invWriteMap);
+    // If there are any unused dims in the composedMap, we have to drop some
+    // unit dims from the written vector before we can do transpose(broadcast).
+    // TODO: Support this case.
+    if (getUnusedDimsBitVector(composedMap).any())
       return failure();
-
+    // readVec = transpose(broadcast(writeVec))
+    //
+    // Build a transpose permutation for the above transpose operation.
+    //
+    // Treat the composed map as having extra leading dimensions which are
+    // the broadcasted dimensions, and treat the zeros as these new broadcasted
+    // dimensions.
+    SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
+    int64_t numBroadcastedDims = broadcastedDims.size();
+    SmallVector<int64_t> invPerm(broadcastedDims.begin(),
+                                 broadcastedDims.end());
----------------
kuhar wrote:

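Use `llvm::to_vector_of<int64_t>(broadcastedDims)` here instead of constructing `invPerm` from the `begin()`/`end()` iterator pair.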

https://github.com/llvm/llvm-project/pull/171840


More information about the Mlir-commits mailing list