[Mlir-commits] [mlir] [MLIR][XeGPU] Add lowering from transfer_read/transfer_write to load_gather/store_scatter (PR #152429)
Adam Siemieniuk
llvmlistbot at llvm.org
Mon Aug 11 03:14:01 PDT 2025
================
@@ -155,6 +152,340 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
return ndDesc;
}
+static LogicalResult
+extraCheckForScatteredLoadStore(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
+ // 1. It must be an in-bounds access, i.e., all entries of the in_bounds
+ // attribute must be true (e.g., {in_bounds = [true, true]}).
+ if (xferOp.hasOutOfBoundsDim())
+ return rewriter.notifyMatchFailure(xferOp,
+ "Out-of-bounds access is not supported "
+ "for scatter load/store lowering");
+ // 2. If the memref has a static shape, its innermost dimensions must exactly
+ // match the vector shape.
+ if (auto memrefType = dyn_cast<MemRefType>(xferOp.getShapedType())) {
+ if (memrefType.hasStaticShape()) {
+ ArrayRef<int64_t> memrefShape = memrefType.getShape();
+ ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
+ size_t memrefRank = memrefShape.size();
+ size_t vectorRank = vectorShape.size();
+ if (vectorRank > memrefRank)
+ return rewriter.notifyMatchFailure(
+ xferOp, "Vector rank cannot exceed memref rank");
+ // Compare the last vectorRank dimensions of memref with vector shape
+ for (size_t i = 0; i < vectorRank; ++i) {
+ if (memrefShape[memrefRank - vectorRank + i] != vectorShape[i])
+ return rewriter.notifyMatchFailure(
+ xferOp, "Memref innermost dimensions must match vector shape");
+ }
+ }
+ }
+ return success();
+}
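For reference, the trailing-shape comparison above can also be expressed with ArrayRef helpers. A minimal sketch (not part of the patch, reusing the function's locals, and assuming the intent is an exact match of the innermost dimensions):

  // Sketch only: compare the innermost memref dims against the vector shape.
  ArrayRef<int64_t> memrefTail = memrefShape.take_back(vectorRank);
  if (!llvm::equal(memrefTail, vectorShape))
    return rewriter.notifyMatchFailure(
        xferOp, "Memref innermost dimensions must match vector shape");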
+
+static LogicalResult adjustStridesForPermutation(
+ Operation *op, PatternRewriter &rewriter, MemRefType memrefType,
+ AffineMap permMap, VectorType vecType, SmallVectorImpl<Value> &strides) {
+ if (permMap.isMinorIdentity())
+ return success();
+ unsigned memrefRank = memrefType.getRank();
+ unsigned vecRank = vecType.getRank();
+ // Only adjust the last vecRank strides according to the permutation
+ ArrayRef<Value> relevantStrides = ArrayRef<Value>(strides).take_back(vecRank);
+ SmallVector<Value> adjustedStrides(vecRank);
+ // For each output dimension in the permutation map, find which input dim it
+ // refers to, and assign the corresponding stride.
+ for (unsigned outIdx = 0; outIdx < vecRank; ++outIdx) {
+ AffineExpr expr = permMap.getResult(outIdx);
+ auto dimExpr = dyn_cast<AffineDimExpr>(expr);
+ if (!dimExpr) {
+ return rewriter.notifyMatchFailure(op, "Unsupported permutation expr");
+ }
+ unsigned pos = dimExpr.getPosition();
+ // Map permutation to the relevant strides (innermost dims)
+ if (pos < memrefRank - vecRank) {
+ return rewriter.notifyMatchFailure(op, "Permutation out of bounds");
+ }
+ // The stride for output dimension outIdx is the stride of input dimension
+ // pos
+ adjustedStrides[outIdx] = relevantStrides[pos - (memrefRank - vecRank)];
+ }
+ // Replace the last vecRank strides with the adjusted ones
+ for (unsigned i = 0; i < vecRank; ++i)
+ strides[memrefRank - vecRank + i] = adjustedStrides[i];
+
+ return success();
+}
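To make the adjustment concrete: for a 2-D transposed read with permutation map (d0, d1) -> (d1, d0) and memref strides [64, 1], the loop above yields [1, 64]. A simplified, hypothetical sketch of the same position lookup, assuming the number of strides equals the map's number of dims:

  // Sketch only: reorder integer strides according to a permutation map,
  // mirroring the position lookup in adjustStridesForPermutation.
  static SmallVector<int64_t> permuteStrides(AffineMap permMap,
                                             ArrayRef<int64_t> strides) {
    SmallVector<int64_t> result;
    for (AffineExpr expr : permMap.getResults())
      result.push_back(strides[cast<AffineDimExpr>(expr).getPosition()]);
    // (d0, d1) -> (d1, d0) with strides [64, 1] yields [1, 64].
    return result;
  }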
+
+static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
+ SmallVector<Value> strides;
+ Value baseMemref = xferOp.getBase();
+ AffineMap permMap = xferOp.getPermutationMap();
+ VectorType vectorType = xferOp.getVectorType();
+ MemRefType memrefType = llvm::cast<MemRefType>(baseMemref.getType());
+
+ Location loc = xferOp.getLoc();
+ if (memrefType.hasStaticShape()) {
+ int64_t offset;
+ SmallVector<int64_t> intStrides;
+ if (failed(memrefType.getStridesAndOffset(intStrides, offset))) {
+ return {};
+ }
+ // Wrap static strides as MLIR values
+ for (int64_t s : intStrides)
+ strides.push_back(rewriter.create<arith::ConstantIndexOp>(loc, s));
+ } else {
+ // For dynamically shaped memrefs, use memref.extract_strided_metadata to
+ // get the stride values.
+ unsigned rank = memrefType.getRank();
+ Type indexType = rewriter.getIndexType();
+
+ // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
+ // size0, size1, ..., sizeN-1]
+ SmallVector<Type> resultTypes;
+ resultTypes.push_back(MemRefType::get(
+ {}, memrefType.getElementType())); // base memref (rank 0)
+ resultTypes.push_back(indexType); // offset
+ for (unsigned i = 0; i < rank; ++i) {
+ resultTypes.push_back(indexType); // strides
+ }
+ for (unsigned i = 0; i < rank; ++i) {
+ resultTypes.push_back(indexType); // sizes
+ }
+
+ auto meta = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, resultTypes, baseMemref);
+ strides.append(meta.getStrides().begin(), meta.getStrides().end());
+ }
+ // Adjust strides according to the permutation map (e.g., for transpose)
+ if (failed(adjustStridesForPermutation(xferOp, rewriter, memrefType, permMap,
+ vectorType, strides))) {
+ return {};
+ }
+ return strides;
+}
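A side note on the dynamic-shape branch above: memref.extract_strided_metadata can infer its result types from the source memref, so the hand-built result-type list may be avoidable. A minimal sketch (an assumption about the available builder, not verified against this patch's revision; reuses the function's locals):

  // Sketch only: rely on result-type inference of extract_strided_metadata.
  auto meta =
      rewriter.create<memref::ExtractStridedMetadataOp>(loc, baseMemref);
  llvm::append_range(strides, meta.getStrides());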
+
+// This function computes the vector of local offsets for scattered
+// loads/stores. It is used when lowering vector.transfer_read/write to
+// load_gather/store_scatter. Example:
+// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
+// %cst {in_bounds = [true, true, true, true]} :
+// memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+//
+// %6 = vector.step: vector<4xindex>
+// %7 = vector.step: vector<2xindex>
+// %8 = vector.step: vector<6xindex>
+// %9 = vector.step: vector<32xindex>
+// %10 = arith.mul %6, 384
+// %11 = arith.mul %7, 192
+// %12 = arith.mul %8, 32
+// %13 = arith.mul %9, 1
+// %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xindex>
+// %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xindex>
+// %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xindex>
+// %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xindex>
+// %18 = vector.broadcast %14: vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
+// %19 = vector.broadcast %15: vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
+// %20 = vector.broadcast %16: vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
+// %21 = vector.broadcast %17: vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
+// %22 = arith.add %18, %19
+// %23 = arith.add %20, %21
+// %local_offsets = arith.add %22, %23
+// %orig_offset = %block_id_y * 1536 // = 4*2*6*32; consider an affine map
+// %offsets = %orig_offset + %local_offsets
+static Value computeOffsets(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter,
+ ArrayRef<Value> strides) {
+ Location loc = xferOp.getLoc();
+ VectorType vectorType = xferOp.getVectorType();
+ SmallVector<Value> indices(xferOp.getIndices().begin(),
+ xferOp.getIndices().end());
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+ // Step 1: Create vector.step operations for each dimension
----------------
adam-smnk wrote:
nit: I'd avoid explicit `1, 2, 3...` enumeration; it's always annoying to maintain in case of future refactors and changes
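For illustration only (not the PR's implementation, and untested): the per-dimension contributions sketched in the comment above could be built in a single loop over the vector dimensions, which also avoids numbered step comments. The fragment below reuses the surrounding function's locals (rewriter, loc, vectorShape) and assumes `strides` holds one index Value per vector dimension, innermost last:

  // Sketch: accumulate per-dimension offset contributions.
  VectorType offsetsType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  Value localOffsets;
  for (auto [dim, dimSize] : llvm::enumerate(vectorShape)) {
    // vector.step over this dimension: [0, 1, ..., dimSize - 1].
    auto stepType = VectorType::get({dimSize}, rewriter.getIndexType());
    Value step = rewriter.create<vector::StepOp>(loc, stepType);
    // Scale by this dimension's stride (splatted to the step vector).
    Value strideVec =
        rewriter.create<vector::BroadcastOp>(loc, stepType, strides[dim]);
    Value scaled = rewriter.create<arith::MulIOp>(loc, step, strideVec);
    // Reshape to rank N with size 1 everywhere except this dimension,
    // then broadcast to the full offset shape and accumulate.
    SmallVector<int64_t> shape(vectorShape.size(), 1);
    shape[dim] = dimSize;
    auto castType = VectorType::get(shape, rewriter.getIndexType());
    Value expanded =
        rewriter.create<vector::ShapeCastOp>(loc, castType, scaled);
    Value bcast =
        rewriter.create<vector::BroadcastOp>(loc, offsetsType, expanded);
    if (!localOffsets)
      localOffsets = bcast;
    else
      localOffsets = rewriter.create<arith::AddIOp>(loc, localOffsets, bcast);
  }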
https://github.com/llvm/llvm-project/pull/152429