[Mlir-commits] [mlir] [MLIR] Refactor mask compression logic when emulating `vector.maskedload` ops (PR #116520)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Nov 22 06:57:46 PST 2024
================
@@ -128,34 +128,16 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
newMaskOperands);
})
- .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
- -> std::optional<Operation *> {
- ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
- size_t numMaskOperands = maskDimSizes.size();
- int64_t origIndex = maskDimSizes[numMaskOperands - 1];
- int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
- int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
- numSrcElemsPerDest);
-
- // TODO: we only want the mask between [startIndex, maskIndex]
- // to be true, the rest are false.
- if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
- return std::nullopt;
-
- SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
- newMaskDimSizes.push_back(maskIndex);
-
- if (numFrontPadElems == 0)
- return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
- newMaskDimSizes);
-
- SmallVector<bool> newMaskValues;
- for (int64_t i = 0; i < numDestElems; ++i)
- newMaskValues.push_back(i >= startIndex && i < maskIndex);
- auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
- return rewriter.create<arith::ConstantOp>(loc, newMaskType,
- newMask);
- })
+ .Case<vector::ConstantMaskOp>(
+ [&](auto constantMaskOp) -> std::optional<Operation *> {
+ SmallVector<int64_t> maskDimSizes(
+ constantMaskOp.getMaskDimSizes());
+ int64_t &maskIndex = maskDimSizes.back();
+ maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
+ numSrcElemsPerDest);
----------------
banach-space wrote:
Perhaps add a note that you are updating the _trailing_ dim of the mask _in place_? That wasn't immediately obvious to me.
https://github.com/llvm/llvm-project/pull/116520
More information about the Mlir-commits
mailing list