[Mlir-commits] [mlir] [MLIR][XeGPU] Add lowering from transfer_read/transfer_write to load_gather/store_scatter (PR #152429)

Jianhui Li llvmlistbot at llvm.org
Wed Aug 13 14:41:07 PDT 2025


================
@@ -155,6 +152,300 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
+// Permutes the innermost strides of a memref so that they line up with the
+// dimensions of a transferred vector.
+//
+// When `permMap` is not a minor identity, the last `vecRank` entries of
+// `strides` are rewritten in place: entry `i` of that trailing window receives
+// the stride of the memref dimension selected by result `i` of `permMap`.
+// Entries in front of the window are left untouched. For a minor identity map
+// the function returns without modifying `strides`.
+//
+// Example:
+//   For a rank-4 memref with strides `[s0, s1, s2, s3]`, a rank-2 vector and
+//   the permutation map `(d0, d1, d2, d3) -> (d3, d2)`:
+//     Original strides:  [s0, s1, s2, s3]
+//     After permutation: [s0, s1, s3, s2]
+//
+// Preconditions (checked with assertions):
+//   * the vector rank does not exceed the memref rank;
+//   * `strides` holds exactly one entry per memref dimension;
+//   * every result of `permMap` is a plain dimension expression that refers
+//     to one of the innermost `vecRank` memref dimensions.
+//
+// `op` and `rewriter` are not used by this function.
+//
+void adjustStridesForPermutation(Operation *op, PatternRewriter &rewriter,
+                                 MemRefType memrefType, AffineMap permMap,
+                                 VectorType vecType,
+                                 SmallVectorImpl<Value> &strides) {
+  // A minor identity map already addresses the innermost dimensions in
+  // order, so there is nothing to permute.
+  if (permMap.isMinorIdentity())
+    return;
+
+  unsigned memrefRank = memrefType.getRank();
+  unsigned vecRank = vecType.getRank();
+  // Guard the unsigned subtraction below against wrap-around.
+  assert(vecRank <= memrefRank &&
+         "Vector rank must not exceed the memref rank");
+  // The trailing window that is read and the window that is written only
+  // coincide when there is exactly one stride per memref dimension.
+  assert(strides.size() == memrefRank &&
+         "Expected one stride per memref dimension");
+
+  // Index of the first memref dimension covered by the vector.
+  unsigned firstInnerDim = memrefRank - vecRank;
+
+  // Only the last vecRank strides take part in the permutation.
+  ArrayRef<Value> relevantStrides = ArrayRef<Value>(strides).take_back(vecRank);
+
+  // Gather into a temporary first: the permutation reads from the same
+  // window of `strides` that is overwritten afterwards.
+  SmallVector<Value> adjustedStrides(vecRank);
+  for (unsigned outIdx = 0; outIdx < vecRank; ++outIdx) {
+    auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(outIdx));
+    assert(dimExpr &&
+           "The permutation expr must be an affine dim expression");
+    unsigned pos = dimExpr.getPosition();
+    assert(pos >= firstInnerDim &&
+           "Permuted index must be in the inner dimensions");
+
+    // Vector dimension outIdx walks memref dimension pos, so it takes that
+    // dimension's stride.
+    adjustedStrides[outIdx] = relevantStrides[pos - firstInnerDim];
+  }
+
+  // Write the permuted strides back over the trailing window.
+  for (unsigned i = 0; i < vecRank; ++i)
+    strides[firstInnerDim + i] = adjustedStrides[i];
+}
+
+SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
----------------
Jianhui-Li wrote:

added

https://github.com/llvm/llvm-project/pull/152429


More information about the Mlir-commits mailing list