[Mlir-commits] [mlir] [mlir][vector] Extend `CreateMaskFolder` (PR #75842)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Dec 19 08:24:01 PST 2023
================
@@ -5657,49 +5657,104 @@ LogicalResult CreateMaskOp::verify() {
namespace {
-// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+///
+/// Ex 1:
+/// %c2 = arith.constant 2 : index
+/// %c3 = arith.constant 3 : index
+/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
+/// Becomes:
+/// vector.constant_mask [3, 2] : vector<4x3xi1>
+///
+/// Ex 2:
+/// %c_neg_1 = arith.constant -1 : index
+/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
+/// Becomes:
+/// vector.constant_mask [0] : vector<[8]xi1>
+///
+/// Ex 3:
+/// %c8 = arith.constant 8 : index
+/// %c16 = arith.constant 16 : index
+/// %0 = vector.vscale
+/// %1 = arith.muli %0, %c16 : index
+/// %2 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
+/// Becomes:
+/// vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
- // Return if any of 'createMaskOp' operands are not defined by a constant.
- auto isNotDefByConstant = [](Value operand) {
- return !getConstantIntValue(operand).has_value();
- };
- if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
- return failure();
+ VectorType retTy = createMaskOp.getResult().getType();
+ bool isScalable = retTy.isScalable();
+
+ // Check every mask operand
+ for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
+ // Most basic case - this operand is a constant value. Note that for
+ // scalable dimensions, CreateMaskOp can be folded only if the
+ // corresponding operand is negative or zero.
+ if (getConstantIntValue(operand)) {
+ APInt intVal;
+ if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) ||
+ intVal.isStrictlyPositive()))
----------------
MacDue wrote:
Likely the intended check:
```suggestion
if (isScalable && (!matchPattern(operand, m_ConstantInt(&intVal)) ||
intVal.isStrictlyPositive()))
```
https://github.com/llvm/llvm-project/pull/75842
More information about the Mlir-commits
mailing list