[Mlir-commits] [mlir] [MLIR][XeGPU] Add lowering from transfer_read/transfer_write to load_gather/store_scatter (PR #152429)
Jianhui Li
llvmlistbot at llvm.org
Wed Aug 13 14:40:57 PDT 2025
================
@@ -155,6 +152,300 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
return ndDesc;
}
+// Adjusts the strides of a memref according to a given permutation map for
+// vector operations.
+//
+// `strides` must hold exactly one stride per memref dimension. This function
+// rewrites the last `vecRank` entries so that entry
+// `memrefRank - vecRank + i` holds the stride of the memref dimension named
+// by result `i` of `permMap`. In other words, the innermost strides follow
+// the order in which the vector dimensions walk memory. This is used when
+// lowering vector transfer operations with permutation maps to memory
+// accesses.
+//
+// Preconditions (checked with asserts):
+//   - the vector rank does not exceed the memref rank;
+//   - every result of `permMap` is a plain dimension expression (broadcast
+//     constants are not supported);
+//   - every such dimension is one of the innermost `vecRank` memref dims.
+// Minor-identity maps leave `strides` untouched. `op` and `rewriter` are
+// currently unused.
+//
+// Example:
+// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]` and a
+// vector of rank 2 accessed through `(d0, d1, d2, d3) -> (d3, d2)`, i.e. the
+// last two dimensions are swapped. After calling this function the last two
+// strides are swapped as well:
+//   Original strides:  [s0, s1, s2, s3]
+//   After permutation: [s0, s1, s3, s2]
+//
+void adjustStridesForPermutation(Operation *op, PatternRewriter &rewriter,
+                                 MemRefType memrefType, AffineMap permMap,
+                                 VectorType vecType,
+                                 SmallVectorImpl<Value> &strides) {
+  if (permMap.isMinorIdentity())
+    return;
+
+  const unsigned memrefRank = memrefType.getRank();
+  const unsigned vecRank = vecType.getRank();
+  // Guard the unsigned arithmetic below: with vecRank > memrefRank the
+  // offset `memrefRank - vecRank` would silently wrap around.
+  assert(vecRank <= memrefRank &&
+         "Vector rank must not exceed the memref rank");
+  assert(strides.size() == memrefRank &&
+         "Expected exactly one stride per memref dimension");
+  // Index of the first (outermost) memref dimension covered by the vector.
+  const unsigned innerStart = memrefRank - vecRank;
+
+  // Only adjust the last vecRank strides according to the permutation. The
+  // results are collected separately and written back afterwards, so every
+  // read below sees the original, unpermuted strides.
+  ArrayRef<Value> relevantStrides = ArrayRef<Value>(strides).take_back(vecRank);
+  SmallVector<Value> adjustedStrides(vecRank);
+  for (unsigned outIdx = 0; outIdx < vecRank; ++outIdx) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(outIdx));
+    assert(dimExpr && "The permutation expr must be an affine dim expression "
+                      "(broadcast is not supported)");
+    unsigned pos = dimExpr.getPosition();
+    assert(pos >= innerStart && pos < memrefRank &&
+           "Permuted index must be in the inner dimensions");
+    // The stride for output dimension outIdx is the stride of input
+    // dimension pos.
+    adjustedStrides[outIdx] = relevantStrides[pos - innerStart];
+  }
+  // Replace the last vecRank strides with the adjusted ones.
+  for (unsigned i = 0; i < vecRank; ++i)
+    strides[innerStart + i] = adjustedStrides[i];
+}
+
+SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
+ SmallVector<Value> strides;
+ Value baseMemref = xferOp.getBase();
+ AffineMap permMap = xferOp.getPermutationMap();
+ VectorType vectorType = xferOp.getVectorType();
+ MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+
+ Location loc = xferOp.getLoc();
+ if (memrefType.hasStaticShape()) {
+ int64_t offset;
+ SmallVector<int64_t> intStrides;
+ if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
+ return {};
+ // Wrap static strides as MLIR values
+ for (int64_t s : intStrides)
+ strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+ } else {
+ // For dynamic shape memref, use memref.extract_strided_metadata to get
+ // stride values
+ unsigned rank = memrefType.getRank();
+ Type indexType = rewriter.getIndexType();
+
+ // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
+ // size0, size1, ..., sizeN-1]
+ SmallVector<Type> resultTypes;
+ resultTypes.push_back(MemRefType::get(
+ {}, memrefType.getElementType())); // base memref (unranked)
+ resultTypes.push_back(indexType); // offset
+ for (unsigned i = 0; i < rank; ++i) {
+ resultTypes.push_back(indexType); // strides
+ }
+ for (unsigned i = 0; i < rank; ++i) {
+ resultTypes.push_back(indexType); // sizes
+ }
----------------
Jianhui-Li wrote:
removed
https://github.com/llvm/llvm-project/pull/152429
More information about the Mlir-commits
mailing list