[Mlir-commits] [mlir] [mlir][Vector] Improve vector.transferx store-to-load-forwarding (PR #171840)
Jakub Kuderski
llvmlistbot at llvm.org
Thu Dec 11 07:26:52 PST 2025
================
@@ -5161,60 +5185,51 @@ struct TransferReadAfterWriteToBroadcast
if (readOp.getTransferChunkAccessed() !=
defWrite.getTransferChunkAccessed())
return failure();
- // TODO: Support cases where a dim is explicitly written but implicitly
- // read (i.e., a unit dim that is rank reduced).
- if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
- getUnusedDimsBitVector({defWrite.getPermutationMap()}))
- return failure();
- // This pattern should only catch the broadcast case, the non-broadcast case
- // should be done separately to keep application conditions clean and
- // separate.
- AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
- AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
- bool bcast = !readMap.getBroadcastDims().empty() ||
- !writeMap.getBroadcastDims().empty();
- if (!bcast)
- return failure();
- // At this point, we know we have a bcast.
- // Bail in the masked case (too complex atm and needed to properly account
- // for padding).
- if (readOp.getMask() || defWrite.getMask())
- return failure();
- // If indices are not the same a shift may be required, bail.
- if (readOp.getIndices() != defWrite.getIndices())
+ // WriteMap: tensor -> w_vec
+ // ReadMap: tensor -> r_vec
+ //
+ // inv(WriteMap): w_vec -> tensor
+ // inv(WriteMap) o ReadMap: w_vec -> r_vec
+ AffineMap readMap = readOp.getPermutationMap();
+ AffineMap writeMap = defWrite.getPermutationMap();
+ AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
+ AffineMap composedMap = readMap.compose(invWriteMap);
+ // If there are any unused dims in the composedMap, we have to drop some
+ // unit dims from the written vector before we can do transpose(broadcast).
+ // TODO: Support this case.
+ if (getUnusedDimsBitVector(composedMap).any())
return failure();
-
+ // readVec = transpose(broadcast(writeVec))
+ //
+ // Build a transpose permutation for the above transpose operation.
+ //
+ // Treat the composed map as having extra leading dimensions which are
+ // the broadcasted dimensions, and treat the zeros as these new broadcasted
+ // dimensions.
+ SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
+ int64_t numBroadcastedDims = broadcastedDims.size();
+ SmallVector<int64_t> invPerm(broadcastedDims.begin(),
+ broadcastedDims.end());
----------------
kuhar wrote:
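Use `llvm::to_vector_of<int64_t>(broadcastedDims)` here instead of constructing the `SmallVector<int64_t>` from the begin/end iterator pair.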
https://github.com/llvm/llvm-project/pull/171840
More information about the Mlir-commits
mailing list