[Mlir-commits] [mlir] [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (PR #144447)
Charitha Saumya
llvmlistbot at llvm.org
Mon Jun 16 17:02:08 PDT 2025
================
@@ -462,10 +492,32 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
- SmallVector<Type> convertedMaskTypes =
- getUnrolledTypes(maskTy, *targetShape);
- SmallVector<Value> convertedMasks =
- pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+ SmallVector<Type> convertedMaskTypes;
+ SmallVector<Value> convertedMasks;
+
+ if (originalChunkSize > 1) {
+ targetMaskShape.pop_back();
+ convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+ SmallVector<Value> convertedMasks1D = pack(
+ op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+ int64_t blockedChunkSize = targetShape->back();
+ int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+ for (auto mask : convertedMasks1D) {
+ for (int64_t i = 0; i < numNewChunks; ++i) {
+ convertedMasks.push_back(mask);
+ }
+ }
+ // This is to handle the transpose effect when chunkSize > 1.
+ if (targetShape && targetShape->size() > 1) {
----------------
charithaintc wrote:
targetShape is already valid here, so there is no need to check it again.
https://github.com/llvm/llvm-project/pull/144447
More information about the Mlir-commits
mailing list