[Mlir-commits] [mlir] [MLIR] `vector.constant_mask` to support unaligned cases (PR #116520)

Han-Chung Wang llvmlistbot at llvm.org
Mon Nov 18 11:06:17 PST 2024


================
@@ -125,34 +125,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                 return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                              newMaskOperands);
               })
-          .Case<vector::ConstantMaskOp>([&](auto constantMaskOp)
-                                            -> std::optional<Operation *> {
-            ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
-            size_t numMaskOperands = maskDimSizes.size();
-            int64_t origIndex = maskDimSizes[numMaskOperands - 1];
-            int64_t startIndex = numFrontPadElems / numSrcElemsPerDest;
-            int64_t maskIndex = llvm::divideCeil(numFrontPadElems + origIndex,
-                                                 numSrcElemsPerDest);
-
-            // TODO: we only want the mask between [startIndex, maskIndex]
-            // to be true, the rest are false.
-            if (numFrontPadElems != 0 && maskDimSizes.size() > 1)
-              return std::nullopt;
-
-            SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
-            newMaskDimSizes.push_back(maskIndex);
-
-            if (numFrontPadElems == 0)
-              return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
-                                                             newMaskDimSizes);
-
-            SmallVector<bool> newMaskValues;
-            for (int64_t i = 0; i < numDestElems; ++i)
-              newMaskValues.push_back(i >= startIndex && i < maskIndex);
-            auto newMask = DenseElementsAttr::get(newMaskType, newMaskValues);
-            return rewriter.create<arith::ConstantOp>(loc, newMaskType,
-                                                      newMask);
-          })
+          .Case<vector::ConstantMaskOp>(
+              [&](auto constantMaskOp) -> std::optional<Operation *> {
+                ArrayRef<int64_t> maskDimSizes =
+                    constantMaskOp.getMaskDimSizes();
+                size_t numMaskOperands = maskDimSizes.size();
+                int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+                int64_t maskIndex = llvm::divideCeil(
+                    numFrontPadElems + origIndex, numSrcElemsPerDest);
+                SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+                newMaskDimSizes.push_back(maskIndex);
----------------
hanhanW wrote:

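Instead of building a second vector from `drop_back()` and then pushing the new last size, can we copy the mask dimension sizes into a `SmallVector` once and update its last element in place? Note that `SmallVector`'s `ArrayRef` constructor is `explicit`, so this needs direct-initialization (e.g. `SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes());`) rather than the `=` form. E.g.,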

```cpp
SmallVector<int64_t> maskDimSizes =
                    constantMaskOp.getMaskDimSizes();
int64_t &maskIndex = maskDimSizes.back(); // Or use maskDimSizes.back() directly.
maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex, numSrcElemsPerDest);
```


https://github.com/llvm/llvm-project/pull/116520


More information about the Mlir-commits mailing list