[Mlir-commits] [mlir] [mlir][vector] Add mask elimination transform (PR #99314)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Jul 29 02:29:00 PDT 2024
================
@@ -0,0 +1,131 @@
+//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+namespace {
+
+/// Returns the constant factor `c` when `value` is either `vector.vscale`
+/// itself (c = 1) or an `arith.muli` of `vector.vscale` and a constant `c`
+/// (in either operand order). Returns std::nullopt for any other form,
+/// including a multiplication whose non-vscale factor is not a constant.
+std::optional<int64_t> getConstantVscaleMultiplier(Value value) {
+  auto isVscale = [](Value v) -> bool {
+    return static_cast<bool>(v.getDefiningOp<vector::VectorScaleOp>());
+  };
+
+  // A bare `vector.vscale` is `1 * vscale`.
+  if (isVscale(value))
+    return 1;
+
+  auto mulOp = value.getDefiningOp<arith::MulIOp>();
+  if (!mulOp)
+    return std::nullopt;
+
+  // The LHS is tested first: when it is vscale, the multiplier must come from
+  // the RHS (and vice versa).
+  Value lhsFactor = mulOp.getLhs();
+  Value rhsFactor = mulOp.getRhs();
+  if (isVscale(lhsFactor))
+    return getConstantIntValue(rhsFactor);
+  if (isVscale(rhsFactor))
+    return getConstantIntValue(lhsFactor);
+  return std::nullopt;
+}
+
+/// Attempts to resolve a (scalable) CreateMaskOp to an all-true constant mask.
+/// All-true masks can then be eliminated by simple folds.
+LogicalResult resolveAllTrueCreateMaskOp(IRRewriter &rewriter,
+ vector::CreateMaskOp createMaskOp,
+ VscaleRange vscaleRange) {
+ auto maskType = createMaskOp.getVectorType();
+ auto maskTypeDimScalableFlags = maskType.getScalableDims();
+ auto maskTypeDimSizes = maskType.getShape();
+
+ struct UnknownMaskDim {
+ size_t position;
+ Value dimSize;
+ };
+
+ // Check for any dims that could be (partially) false before doing the more
+ // expensive value bounds computations.
+ SmallVector<UnknownMaskDim> unknownDims;
+ for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
+ if (auto intSize = getConstantIntValue(dimSize)) {
+ // Mask not all-true for this dim.
+ if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
+ return failure();
+ } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
+ // Mask not all-true for this dim.
+ if (vscaleMultiplier < maskTypeDimSizes[i])
+ return failure();
+ } else {
+ // Unknown (without further analysis).
+ unknownDims.push_back(UnknownMaskDim{i, dimSize});
+ }
+ }
+
+ for (auto [i, dimSize] : unknownDims) {
+ // Compute the lower bound for the unknown dimension (i.e. the smallest
+ // value it could be).
+ FailureOr<ConstantOrScalableBound> dimLowerBound =
+ vector::ScalableValueBoundsConstraintSet::computeScalableBound(
+ dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
+ presburger::BoundType::LB);
+ if (failed(dimLowerBound))
+ return failure();
+ auto dimLowerBoundSize = dimLowerBound->getSize();
+ if (failed(dimLowerBoundSize))
+ return failure();
+ if (dimLowerBoundSize->scalable) {
+ // If the lower bound is scalable and < the mask dim size then this dim is
+ // not all-true.
+ if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
+ return failure();
+ } else {
+ // If the lower bound is a constant:
+ // - If the mask dim size is scalable then this dim is not all-true.
+ if (maskTypeDimScalableFlags[i])
+ return failure();
+ // - If the lower bound is < the _fixed-size_ mask dim size then this dim
+ // is not all-true.
----------------
banach-space wrote:
```suggestion
// 1. The lower bound, LB, is scalable. If LB < the mask dim size then this dim is
// not all-true.
if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
return failure();
} else {
// 2. The lower bound, LB, is a constant.
// 2.1 If the mask dim size is scalable then this dim is not all-true.
if (maskTypeDimScalableFlags[i])
return failure();
// 2.2 If LB < the _fixed-size_ mask dim size then this dim
// is not all-true.
```
[nit] I find it easier to see the distinct cases written this way.
https://github.com/llvm/llvm-project/pull/99314
More information about the Mlir-commits
mailing list