[Mlir-commits] [mlir] [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (PR #144447)
Jianhui Li
llvmlistbot at llvm.org
Tue Jun 17 14:43:19 PDT 2025
================
@@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();
+ TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+ VectorType indiceVecTy = indiceVec.getType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (!tdescTy.isScattered())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
- auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
- TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
- VectorType indiceVecTy = indiceVec.getType();
+ SmallVector<int64_t> targetIndiceShape(*targetShape);
+ int64_t originalChunkSize = tdescTy.getChunkSize();
+ // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
+ if (originalChunkSize > 1)
+ targetIndiceShape.pop_back();
+ auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
SmallVector<Type> convertedIndiceTypes =
- getUnrolledTypes(indiceVecTy, *targetShape);
+ getUnrolledTypes(indiceVecTy, targetIndiceShape);
SmallVector<Value> convertedIndiceVec =
- pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+ pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
SmallVector<Value> newOps;
- for (auto indice : convertedIndiceVec) {
- auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
- op.getSource(), indice);
- newOps.push_back(newOp);
+
+ // More indices are needed when chunkSize > 1, since a big load from one
+ // address could be broken into multiple small loads.
+ if (originalChunkSize > 1) {
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+ for (auto [indice, indiceType] :
+ llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
+ // Compute the offset
+ Value inc = rewriter.create<arith::ConstantIndexOp>(
+ loc, i * blockedChunkSize);
+ Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
----------------
Jianhui-Li wrote:
added
https://github.com/llvm/llvm-project/pull/144447
More information about the Mlir-commits
mailing list