[Mlir-commits] [mlir] [MLIR][XeGPU] Add lowering from transfer_read/transfer_write to load_gather/store_scatter (PR #152429)

Artem Kroviakov llvmlistbot at llvm.org
Thu Aug 7 02:01:19 PDT 2025


================
@@ -155,6 +152,340 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
+/// Checks the extra preconditions for lowering a vector transfer op to
+/// XeGPU scattered load/store:
+///   1. no dimension of the transfer may be out-of-bounds, and
+///   2. if the base is a statically shaped memref, the vector rank must not
+///      exceed the memref rank, and every one of the innermost vector-rank
+///      memref dimensions must be at least as large as the corresponding
+///      vector dimension.
+/// Returns a match failure (via notifyMatchFailure) if any check fails,
+/// success otherwise. Non-memref or dynamically shaped bases only get check 1.
+static LogicalResult
+extraCheckForScatteredLoadStore(VectorTransferOpInterface xferOp,
+                                PatternRewriter &rewriter) {
+  // 1. Every dimension must be in-bounds; e.g. {in_bounds = [false, true]}
+  // is rejected because dimension 0 may be out of bounds.
+  if (xferOp.hasOutOfBoundsDim())
+    return rewriter.notifyMatchFailure(xferOp,
+                                       "Out-of-bounds access is not supported "
+                                       "for scatter load/store lowering");
+  // 2. For a statically shaped memref, the vector must fit in the innermost
+  // memref dimensions.
+  auto memrefType = dyn_cast<MemRefType>(xferOp.getShapedType());
+  if (!memrefType || !memrefType.hasStaticShape())
+    return success();
+
+  ArrayRef<int64_t> memrefShape = memrefType.getShape();
+  ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
+  size_t vectorRank = vectorShape.size();
+  if (vectorRank > memrefShape.size())
+    return rewriter.notifyMatchFailure(
+        xferOp, "Vector rank cannot exceed memref rank");
+  // Compare the trailing vectorRank memref dimensions with the vector shape.
+  // An inner memref dimension equal to the vector dimension is a valid
+  // in-bounds access, so only strictly smaller memref dimensions are rejected
+  // (the previous `<=` comparison wrongly rejected the equal case).
+  ArrayRef<int64_t> innerMemrefShape = memrefShape.take_back(vectorRank);
+  for (size_t i = 0; i < vectorRank; ++i) {
+    if (innerMemrefShape[i] < vectorShape[i])
+      return rewriter.notifyMatchFailure(
+          xferOp, "Memref lower dimensions must not be smaller than "
+                  "vector shape");
+  }
+  return success();
+}
+
+/// Reorders the innermost `vecRank` entries of `strides` so that entry i holds
+/// the stride of the memref dimension that result i of `permMap` refers to.
+/// A minor-identity `permMap` leaves `strides` untouched. Fails, leaving
+/// `strides` unmodified, if a map result is not a plain dimension expression
+/// or names a memref dimension outside the innermost `vecRank` dimensions.
+/// NOTE(review): assumes `strides` holds one entry per memref dimension and
+/// that the vector rank does not exceed the memref rank — confirm at callers.
+static LogicalResult adjustStridesForPermutation(
+    Operation *op, PatternRewriter &rewriter, MemRefType memrefType,
+    AffineMap permMap, VectorType vecType, SmallVectorImpl<Value> &strides) {
+  // Vector already walks the innermost memref dims in order: nothing to do.
+  if (permMap.isMinorIdentity())
+    return success();
+
+  unsigned vecRank = vecType.getRank();
+  // Index of the first memref dimension covered by the vector.
+  unsigned firstInnerDim = memrefType.getRank() - vecRank;
+  ArrayRef<Value> innerStrides = ArrayRef<Value>(strides).take_back(vecRank);
+
+  // Collect the permuted strides in a scratch buffer first, so that bailing
+  // out part-way through does not leave `strides` half-updated.
+  SmallVector<Value> permuted;
+  permuted.reserve(vecRank);
+  for (unsigned resultIdx = 0; resultIdx < vecRank; ++resultIdx) {
+    auto dim = dyn_cast<AffineDimExpr>(permMap.getResult(resultIdx));
+    if (!dim)
+      return rewriter.notifyMatchFailure(op, "Unsupported permutation expr");
+    unsigned memrefDim = dim.getPosition();
+    // Only permutations among the innermost vecRank memref dims are handled.
+    if (memrefDim < firstInnerDim)
+      return rewriter.notifyMatchFailure(op, "Permutation out of bounds");
+    permuted.push_back(innerStrides[memrefDim - firstInnerDim]);
+  }
+
+  // Overwrite the innermost vecRank strides with their permuted values.
+  for (unsigned k = 0; k < vecRank; ++k)
+    strides[firstInnerDim + k] = permuted[k];
+  return success();
+}
+
+SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+                                  PatternRewriter &rewriter) {
+  SmallVector<Value> strides;
+  Value baseMemref = xferOp.getBase();
+  AffineMap permMap = xferOp.getPermutationMap();
+  VectorType vectorType = xferOp.getVectorType();
+  MemRefType memrefType = llvm::cast<MemRefType>(baseMemref.getType());
+
+  Location loc = xferOp.getLoc();
+  if (memrefType.hasStaticShape()) {
+    int64_t offset;
+    SmallVector<int64_t> intStrides;
+    if (failed(memrefType.getStridesAndOffset(intStrides, offset))) {
+      return {};
+    }
+    // Wrap static strides as MLIR values
+    for (int64_t s : intStrides)
+      strides.push_back(rewriter.create<arith::ConstantIndexOp>(loc, s));
----------------
akroviakov wrote:

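Minor comment: upstream has changed the op-creation notation. For example, `rewriter.create<arith::ConstantIndexOp>(loc, s)` is now written `arith::ConstantIndexOp::create(rewriter, loc, s)`; see: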
https://github.com/llvm/llvm-project/blob/b9e133d5b6e41b652ba579bcb8850c00f72d0f01/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp#L111-L112

https://github.com/llvm/llvm-project/pull/152429


More information about the Mlir-commits mailing list