[Mlir-commits] [mlir] [MLIR][XeGPU] Add lowering from transfer_read/transfer_write to load_gather/store_scatter (PR #152429)
Chao Chen
llvmlistbot at llvm.org
Wed Aug 13 15:29:58 PDT 2025
================
@@ -155,6 +153,284 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
return ndDesc;
}
+// Adjusts the strides of a memref according to a given permutation map for
+// vector operations.
+//
+// This function updates the last `vecRank` elements of the `strides` array to
+// reflect the permutation specified by `permMap`. The permutation is applied
+// to the innermost dimensions of the memref, corresponding to the vector
+// shape. This is typically used when lowering vector transfer operations with
+// permutation maps to memory accesses, ensuring that the memory strides match
+// the logical permutation of vector dimensions.
+//
+// Example:
+// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]` and a
+// vector of rank 2. If the permutation map swaps the last two dimensions
+// (e.g., [0, 1] -> [1, 0]), then after calling this function, the last two
+// strides will be swapped:
+// Original strides: [s0, s1, s2, s3]
+// After permutation: [s0, s1, s3, s2]
+//
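+// As a concrete (hypothetical) illustration, a transposing read such as
+//   %v = vector.transfer_read %m[%c0, %c0], %pad
+//          {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
+//          : memref<8x16xf32>, vector<16x8xf32>
+// starts with memref strides [16, 1]; after this adjustment they become
+// [1, 16], so the computed offsets walk the memory in transposed order.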
+static void adjustStridesForPermutation(Operation *op,
+ PatternRewriter &rewriter,
+ MemRefType memrefType,
+ AffineMap permMap, VectorType vecType,
+ SmallVectorImpl<Value> &strides) {
+
+ AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
+ SmallVector<unsigned> perms;
+ invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
+ SmallVector<int64_t> perms64(perms.begin(), perms.end());
+ strides = applyPermutation(strides, perms64);
+}
+
+// Computes memory strides for vector transfer operations, handling both
+// static and dynamic memrefs while applying permutation transformations
+// for XeGPU lowering.
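+//
+// For example (assuming the default contiguous layout), the static
+// memref<8x4x2x6x32xbf16> used in the computeOffsets example below yields
+// the index constants [1536, 384, 192, 32, 1] as strides.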
+SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
+ SmallVector<Value> strides;
+ Value baseMemref = xferOp.getBase();
+ AffineMap permMap = xferOp.getPermutationMap();
+ VectorType vectorType = xferOp.getVectorType();
+ MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+
+ Location loc = xferOp.getLoc();
+ if (memrefType.hasStaticShape()) {
+ int64_t offset;
+ SmallVector<int64_t> intStrides;
+ if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
+ return {};
+ // Wrap static strides as MLIR values
+ for (int64_t s : intStrides)
+ strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+ } else {
+ // For dynamic shape memref, use memref.extract_strided_metadata to get
+ // stride values
+ unsigned rank = memrefType.getRank();
+ Type indexType = rewriter.getIndexType();
+
+ // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
+ // size0, size1, ..., sizeN-1]
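+    // For a hypothetical memref<?x?xf32> source this materializes roughly as:
+    //   %base, %off, %sizes:2, %strides:2 =
+    //       memref.extract_strided_metadata %src
+    //       : memref<?x?xf32> -> memref<f32>, index, index, index, index, index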
+ SmallVector<Type> resultTypes;
+ resultTypes.push_back(MemRefType::get(
+ {}, memrefType.getElementType())); // base memref (unranked)
+ resultTypes.push_back(indexType); // offset
+
+ for (unsigned i = 0; i < rank; ++i)
+ resultTypes.push_back(indexType); // strides
+
+ for (unsigned i = 0; i < rank; ++i)
+ resultTypes.push_back(indexType); // sizes
+
+ auto meta = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, resultTypes, baseMemref);
+ strides.append(meta.getStrides().begin(), meta.getStrides().end());
+ }
+ // Adjust strides according to the permutation map (e.g., for transpose)
+ adjustStridesForPermutation(xferOp, rewriter, memrefType, permMap, vectorType,
+ strides);
+ return strides;
+}
+
+// This function computes the vector of local offsets for scattered
+// loads/stores. It is used in the lowering of vector.transfer_read/write to
+// load_gather/store_scatter.
+//
+// Example:
+// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
+// %cst {in_bounds = [true, true, true, true]} :
+// memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+//
+// %6 = vector.step: vector<4xindex>
+// %7 = vector.step: vector<2xindex>
+// %8 = vector.step: vector<6xindex>
+// %9 = vector.step: vector<32xindex>
+// %10 = arith.mul %6, 384
+// %11 = arith.mul %7, 192
+// %12 = arith.mul %8, 32
+// %13 = arith.mul %9, 1
+// %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xindex>
+// %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xindex>
+// %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xindex>
+// %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xindex>
+// %18 = vector.broadcast %14: vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
+// %19 = vector.broadcast %15: vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
+// %20 = vector.broadcast %16: vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
+// %21 = vector.broadcast %17: vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
+// %22 = arith.add %18, %19
+// %23 = arith.add %20, %21
+// %local_offsets = arith.add %22, %23
+// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
+// %offsets = orig_offset + local_offsets
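+//
+// Here 384, 192, 32, and 1 are the strides of the four innermost dimensions
+// of memref<8x4x2x6x32xbf16>, and 4x2x6x32 (= 1536) is the stride of the
+// outermost dimension, which scales the %block_id_y index.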
+static Value computeOffsets(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter,
+ ArrayRef<Value> strides) {
+ Location loc = xferOp.getLoc();
+ VectorType vectorType = xferOp.getVectorType();
+ SmallVector<Value> indices(xferOp.getIndices().begin(),
+ xferOp.getIndices().end());
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+ // Create vector.step operations for each dimension
+  SmallVector<Value> stepVectors =
+      llvm::map_to_vector(vectorShape, [&](int64_t dim) -> Value {
+        auto stepType = VectorType::get({dim}, rewriter.getIndexType());
+        return vector::StepOp::create(rewriter, loc, stepType);
+      });
+
+ // Multiply step vectors by corresponding strides
+ size_t memrefRank = strides.size();
+ size_t vectorRank = vectorShape.size();
+ SmallVector<Value> strideMultiplied;
+ for (size_t i = 0; i < vectorRank; ++i) {
+ size_t memrefDim = memrefRank - vectorRank + i;
+ Value strideValue = strides[memrefDim];
+ auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
+ auto bcastOp =
+ vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
+ auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
+ strideMultiplied.push_back(mulOp);
+ }
+
+ // Shape cast each multiplied vector to add singleton dimensions
+ SmallVector<Value> shapeCasted;
+ for (size_t i = 0; i < vectorRank; ++i) {
+ SmallVector<int64_t> newShape(vectorRank, 1);
+ newShape[i] = vectorShape[i];
+ auto newType = VectorType::get(newShape, rewriter.getIndexType());
+ auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
+ strideMultiplied[i]);
+ shapeCasted.push_back(castOp);
+ }
+
+ // Broadcast each shape-casted vector to full vector shape
+ SmallVector<Value> broadcasted;
+ auto fullIndexVectorType =
+ VectorType::get(vectorShape, rewriter.getIndexType());
+ for (Value shapeCastVal : shapeCasted) {
+ auto broadcastOp = vector::BroadcastOp::create(
+ rewriter, loc, fullIndexVectorType, shapeCastVal);
+ broadcasted.push_back(broadcastOp);
+ }
+
+ // Add all broadcasted vectors together to compute local offsets
+ Value localOffsets = broadcasted[0];
+ for (size_t i = 1; i < broadcasted.size(); ++i) {
+ localOffsets =
+ arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
----------------
chencha3 wrote:
nit: braces are not necessary.
https://github.com/llvm/llvm-project/pull/152429
More information about the Mlir-commits mailing list