[Mlir-commits] [mlir] [mlir][vector] Add mask elimination transform (PR #99314)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Aug 7 08:15:22 PDT 2024
================
@@ -5862,73 +5887,46 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
- VectorType retTy = createMaskOp.getResult().getType();
- bool isScalable = retTy.isScalable();
-
- // Check every mask operand
- for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
- if (auto cst = getConstantIntValue(operand)) {
- // Most basic case - this operand is a constant value. Note that for
- // scalable dimensions, CreateMaskOp can be folded only if the
- // corresponding operand is negative or zero.
- if (retTy.getScalableDims()[opIdx] && *cst > 0)
- return failure();
-
- continue;
- }
-
- // Non-constant operands are not allowed for non-scalable vectors.
- if (!isScalable)
- return failure();
-
- // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
- // true" mask, so can also be treated as constant.
- auto mul = operand.getDefiningOp<arith::MulIOp>();
- if (!mul)
- return failure();
- auto mulLHS = mul.getRhs();
- auto mulRHS = mul.getLhs();
- bool isOneOpVscale =
- (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
- isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
-
- auto isConstantValMatchingDim =
- [=, dim = retTy.getShape()[opIdx]](Value operand) {
- auto constantVal = getConstantIntValue(operand);
- return (constantVal.has_value() && constantVal.value() == dim);
- };
-
- bool isOneOpConstantMatchingDim =
- isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
-
- if (!isOneOpVscale || !isOneOpConstantMatchingDim)
- return failure();
+ VectorType maskType = createMaskOp.getVectorType();
+ ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
+ ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
+
+ // Special case: Rank zero shape.
+ constexpr std::array<int64_t, 1> rankZeroShape{1};
+ constexpr std::array<bool, 1> rankZeroScalableDims{false};
+ if (maskType.getRank() == 0) {
+ maskTypeDimSizes = rankZeroShape;
+ maskTypeDimScalableFlags = rankZeroScalableDims;
}
- // Gather constant mask dimension sizes.
- SmallVector<int64_t, 4> maskDimSizes;
- maskDimSizes.reserve(createMaskOp->getNumOperands());
- for (auto [operand, maxDimSize] : llvm::zip_equal(
- createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
- std::optional dimSize = getConstantIntValue(operand);
- if (!dimSize) {
- // Although not a constant, it is safe to assume that `operand` is
- // "vscale * maxDimSize".
- maskDimSizes.push_back(maxDimSize);
- continue;
- }
- int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
- // If one of dim sizes is zero, set all dims to zero.
- if (dimSize <= 0) {
- maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
- break;
+ SmallVector<int64_t, 4> constantDims;
+ for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
+ if (auto intSize = getConstantIntValue(dimSize)) {
+ // Non scalable dims can have any value. Scalable dims can only be zero.
+ if (intSize >= 0 && maskTypeDimScalableFlags[i])
+ return failure();
+ constantDims.push_back(*intSize);
+ } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
+ // Scalable dims must be all-true.
+ if (vscaleMultiplier < maskTypeDimSizes[i])
+ return failure();
+ constantDims.push_back(*vscaleMultiplier);
+ } else {
+ return failure();
----------------
banach-space wrote:
Please document the purpose of this loop (IIUC, it determines whether folding is possible at all and collects the constant dim sizes used to build the `vector.constant_mask`).
In particular, please make it clear how this loop differs from the logic in [VectorMaskElimination.cpp](https://github.com/llvm/llvm-project/pull/99314/files#diff-bd4772744dd105a11599527d87340a3b6e5b354d5fbe6af65d86437cca59c493) 🙏🏻
https://github.com/llvm/llvm-project/pull/99314
More information about the Mlir-commits mailing list