[Mlir-commits] [mlir] [MLIR] `vector.constant_mask` to support unaligned cases (PR #116520)
Han-Chung Wang
llvmlistbot at llvm.org
Mon Nov 18 11:06:17 PST 2024
================
@@ -125,34 +125,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
newMaskOperands);
})
- .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
- -> std::optional<Operation *> {
- ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
- size_t numMaskOperands = maskDimSizes.size();
- int64_t origIndex = maskDimSizes[numMaskOperands - 1];
- int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
- int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
- numSrcElemsPerDest);
-
- // TODO: we only want the mask between [startIndex, maskIndex]
- // to be true, the rest are false.
- if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
- return std::nullopt;
-
- SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
- newMaskDimSizes.push_back(maskIndex);
-
- if (numFrontPadElems == 0)
- return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
- newMaskDimSizes);
-
- SmallVector<bool> newMaskValues;
- for (int64_t i = 0; i < numDestElems; ++i)
- newMaskValues.push_back(i >= startIndex && i < maskIndex);
- auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
- return rewriter.create<arith::ConstantOp>(loc, newMaskType,
- newMask);
- })
+ .Case<vector::ConstantMaskOp>(
+ [&](auto constantMaskOp) -> std::optional<Operation *> {
+ ArrayRef<int64_t> maskDimSizes =
+ constantMaskOp.getMaskDimSizes();
+ size_t numMaskOperands = maskDimSizes.size();
+ int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+ int64_t maskIndex = llvm::divideCeil(
+ numFrontPadElems + origIndex, numSrcElemsPerDest);
+ SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+ newMaskDimSizes.push_back(maskIndex);
----------------
hanhanW wrote:
Instead of creating a second vector (`newMaskDimSizes`), can we copy the mask dim sizes into a mutable `SmallVector` up front and update its last element in place? E.g.,
```cpp
SmallVector<int64_t> maskDimSizes =
constantMaskOp.getMaskDimSizes();
int64_t &maskIndex = maskDimSizes.back(); // Or use maskDimSizes.back() directly.
maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex, numSrcElemsPerDest);
```
https://github.com/llvm/llvm-project/pull/116520
More information about the Mlir-commits mailing list