[Mlir-commits] [mlir] [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (PR #144447)
Jianhui Li
llvmlistbot at llvm.org
Tue Jun 17 14:46:33 PDT 2025
================
@@ -519,30 +570,51 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- // check if the tensor descriptor type is a 1d vector type
- if (tdescTy.getRank() > 1)
+ if (!tdescTy.isScattered())
return failure();
- VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();
- SmallVector<Type> convertedValTypes =
- getUnrolledTypes(valueTy, *targetShape);
+ SmallVector<int64_t> targetIndiceShape(*targetShape);
+ int64_t originalChunkSize = tdescTy.getChunkSize();
+
+ VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
-
- SmallVector<Value> convertedValues =
- pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- SmallVector<Type> convertedMaskTypes =
- getUnrolledTypes(maskTy, *targetShape);
- SmallVector<Value> convertedMasks =
- pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ SmallVector<Type> convertedMaskTypes;
+ SmallVector<Value> convertedMasks;
+
+ if (originalChunkSize > 1) {
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+ convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+ SmallVector<Value> convertedMasks1D = pack(
+ op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+
+ for (auto mask : convertedMasks1D) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
+ convertedMasks.push_back(mask);
+ }
+ }
+ // This is to handle the transpose effect when chunkSize > 1.
+ std::swap((*targetShape)[0], (*targetShape)[1]);
+
+ } else {
+ convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+ convertedMasks =
+ pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ }
----------------
Jianhui-Li wrote:
The similarity is at a very small scope. Introducing an additional abstraction may not be worth it.
https://github.com/llvm/llvm-project/pull/144447
More information about the Mlir-commits
mailing list