[Mlir-commits] [mlir] [mlir][VectorOps] Add fold `ExtractOp(CreateMask) -> CreateMask` (PR #69456)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Oct 19 03:34:36 PDT 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/69456
>From 2f9f46d8659054c59425419e2eac082dd38545df Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 18 Oct 2023 12:10:03 +0000
Subject: [PATCH 1/2] [mlir][VectorOps] Add fold `ExtractOp(CreateMask) ->
CreateMask`
This allows folding extracts from `vector.create_mask` ops when the
extracted value is known. Currently there is no fold for this, but you get
the same effect from the unrolling in LowerVectorMask (part of
-convert-vector-to-llvm) followed by the folds that run after it. However,
for a future patch, this simplification needs to be done before lowering to
LLVM, hence the need for this fold.
E.g.:
```
%0 = vector.create_mask %c1, %dimA, %dimB : vector<1x[4]x[4]xi1>
%1 = vector.extract %0[0] : vector<[4]x[4]xi1> from vector<1x[4]x[4]xi1>
```
->
```
%0 = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
```
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 62 +++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 106 +++++++++++++++++++++
2 files changed, 167 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 68a5cf209f2fb49..6670c1f98c7b45e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1983,6 +1983,66 @@ class ExtractOpNonSplatConstantFolder final
}
};
+// Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask (or an all-false constant).
+class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto createMaskOp =
+ extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return failure();
+
+ ArrayRef<int64_t> position = extractOp.getStaticPosition();
+ auto maskOperands = createMaskOp.getOperands();
+ VectorType maskType = createMaskOp.getVectorType();
+ VectorType::Builder newMaskType(maskType);
+
+ bool allFalse = false;
+ bool containsUnknownDims = false;
+ for (auto [i, pos] : llvm::enumerate(position)) {
+ newMaskType.dropDim(0);
+ Value operand = maskOperands[i];
+ auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
+ if (!constantOp) {
+ // The bound of this dim is not an arith.constant, so it is unknown.
+ containsUnknownDims = true;
+ continue;
+ }
+
+ int64_t createMaskBound =
+ llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
+ if (pos == ShapedType::kDynamic) {
+ // Dynamic position. Extractions must be in-bounds, so if the matching
+ // `create_mask` size is 0 (or >= the dim size) this dim is false (or true).
+ if (createMaskBound == 0)
+ allFalse = true;
+ else if (createMaskBound < maskType.getDimSize(i))
+ // Unknown if this dim is within the true or false region.
+ containsUnknownDims = true;
+ } else {
+ // Static position. If it is at or beyond the `create_mask` bound for this
+ // dim, then the extracted mask will be all false.
+ allFalse |= pos >= createMaskBound;
+ }
+ }
+
+ if (allFalse) {
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ extractOp, DenseElementsAttr::get(VectorType(newMaskType), false));
+ } else if (!containsUnknownDims) {
+ rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+ extractOp, VectorType(newMaskType),
+ maskOperands.drop_front(position.size()));
+ } else {
+ return failure();
+ }
+ return success();
+ }
+};
+
// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2009,7 +2069,7 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
- ExtractOpFromBroadcast>(context);
+ ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 924886c50030967..dd2c78eb44e9f9e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -67,6 +67,112 @@ func.func @create_mask_transpose_to_transposed_create_mask(
// -----
+// CHECK-LABEL: extract_from_create_mask
+// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
+ %c2 = arith.constant 2 : index
+ %mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
+ // CHECK: vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[4]x[4]xi1>
+ // CHECK-NOT: vector.extract
+ %extract = vector.extract %mask[1] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+ return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_all_false
+func.func @extract_from_create_mask_all_false(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {
+ %c2 = arith.constant 2 : index
+ %mask = vector.create_mask %c2, %dim0, %dim1 : vector<4x[4]x[4]xi1>
+ // CHECK: arith.constant dense<false> : vector<[4]x[4]xi1>
+ // CHECK-NOT: vector.extract
+ %extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+ return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_leading_scalable
+// CHECK-SAME: %[[DIM0:.*]]: index
+func.func @extract_from_create_mask_leading_scalable(%dim0: index) -> vector<8xi1> {
+ %c3 = arith.constant 3 : index
+ %mask = vector.create_mask %c3, %dim0 : vector<[4]x8xi1>
+ // CHECK: vector.create_mask %[[DIM0]] : vector<8xi1>
+ // CHECK-NOT: vector.extract
+ %extract = vector.extract %mask[1] : vector<8xi1> from vector<[4]x8xi1>
+ return %extract : vector<8xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_dynamic_position
+// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index) -> vector<6xi1> {
+ %c4 = arith.constant 4 : index
+ %c3 = arith.constant 3 : index
+ %mask = vector.create_mask %c3, %c4, %dim0 : vector<4x4x6xi1>
+ // CHECK: vector.create_mask %[[DIM0]] : vector<6xi1>
+ // CHECK-NOT: vector.extract
+ %extract = vector.extract %mask[2, %index] : vector<6xi1> from vector<4x4x6xi1>
+ return %extract : vector<6xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
+// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %mask = vector.create_mask %c1, %c0, %dim0 : vector<1x4x6xi1>
+ // CHECK: arith.constant dense<false> : vector<6xi1>
+ // CHECK-NOT: vector.extract
+ %extract = vector.extract %mask[0, %index] : vector<6xi1> from vector<1x4x6xi1>
+ return %extract : vector<6xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_dynamic_position_unknown
+// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_dynamic_position_unknown(%dim0: index, %index: index) -> vector<6xi1> {
+ %c2 = arith.constant 2 : index
+ %mask = vector.create_mask %c2, %dim0 : vector<4x6xi1>
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[DIM0]] : vector<4x6xi1>
+ // CHECK-NEXT: vector.extract %[[MASK]][%[[INDEX]]] : vector<6xi1> from vector<4x6xi1>
+ %extract = vector.extract %mask[%index] : vector<6xi1> from vector<4x6xi1>
+ return %extract : vector<6xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_create_mask_mixed_position_unknown
+// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
+func.func @extract_from_create_mask_mixed_position_unknown(%dim0: index, %index0: index) -> vector<4xi1> {
+ %c2 = arith.constant 2 : index
+ %mask = vector.create_mask %c2, %c2, %dim0 : vector<2x4x4xi1>
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[C2]], %[[DIM0]] : vector<2x4x4xi1>
+ // CHECK-NEXT: vector.extract %[[MASK]][1, %[[INDEX]]] : vector<4xi1> from vector<2x4x4xi1>
+ %extract = vector.extract %mask[1, %index0] : vector<4xi1> from vector<2x4x4xi1>
+ return %extract : vector<4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: extract_from_non_constant_create_mask
+// CHECK-SAME: %[[DIM0:.*]]: index
+func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1> {
+ %mask = vector.create_mask %dim0, %dim0 : vector<[2]x[2]xi1>
+ // CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM0]] : vector<[2]x[2]xi1>
+ // CHECK-NEXT: vector.extract %[[MASK]][0] : vector<[2]xi1> from vector<[2]x[2]xi1>
+ %extract = vector.extract %mask[0] : vector<[2]xi1> from vector<[2]x[2]xi1>
+ return %extract : vector<[2]xi1>
+}
+
+// -----
+
// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
>From 326c4f2f3a6aab297416947bc3fd5b75bc2ae03a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 19 Oct 2023 10:32:41 +0000
Subject: [PATCH 2/2] Shuffle things around
Move all-false mask detection to `getMaskFormat()` and clean up the rewrite
a little.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 57 ++++++++++++++++--------
1 file changed, 38 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6670c1f98c7b45e..77cc3f1f5544733 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -100,6 +100,20 @@ static MaskFormat getMaskFormat(Value mask) {
return MaskFormat::AllTrue;
if (allFalse)
return MaskFormat::AllFalse;
+ } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
+ // Finds all-false create_masks. An all-true create_mask requires all
+ // dims to be constants, so that'll be folded to a constant_mask, then
+ // detected in the constant_mask case.
+ auto maskOperands = m.getOperands();
+ for (Value operand : maskOperands) {
+ if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
+ int64_t dimSize =
+ llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
+ if (dimSize <= 0)
+ return MaskFormat::AllFalse;
+ }
+ }
+ return MaskFormat::Unknown;
}
return MaskFormat::Unknown;
}
@@ -1995,16 +2009,23 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
if (!createMaskOp)
return failure();
- ArrayRef<int64_t> position = extractOp.getStaticPosition();
+ VectorType extractedMaskType =
+ llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
+
+ if (!extractedMaskType)
+ return failure();
+
auto maskOperands = createMaskOp.getOperands();
+ ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
VectorType maskType = createMaskOp.getVectorType();
- VectorType::Builder newMaskType(maskType);
- bool allFalse = false;
bool containsUnknownDims = false;
- for (auto [i, pos] : llvm::enumerate(position)) {
- newMaskType.dropDim(0);
- Value operand = maskOperands[i];
+ bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
+
+ for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
+ dimIdx++) {
+ int64_t pos = extractOpPos[dimIdx];
+ Value operand = maskOperands[dimIdx];
auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
if (!constantOp) {
// Bounds of this dim unknown.
@@ -2014,28 +2035,26 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
int64_t createMaskBound =
llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
- if (pos == ShapedType::kDynamic) {
- // Extractions must be in-bounds. So if the corresponding `create_mask`
- // size is 0 or the size of the dim, we know this dim is false or true.
- if (createMaskBound == 0)
- allFalse = true;
- else if (createMaskBound < maskType.getDimSize(i))
- // Unknown if this dim is within the true or false region.
- containsUnknownDims = true;
- } else {
+
+ if (pos != ShapedType::kDynamic) {
// If any position is outside the range from the `create_mask`, then the
- // extracted mask will be all false.
+ // extracted mask will be all-false.
allFalse |= pos >= createMaskBound;
+ } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
+ // This dim is not all-true and since this is a dynamic index we don't
+ // know if the extraction is within the true or false region.
+ // Note: Zero dims have already been handled via getMaskFormat().
+ containsUnknownDims = true;
}
}
if (allFalse) {
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
- extractOp, DenseElementsAttr::get(VectorType(newMaskType), false));
+ extractOp, DenseElementsAttr::get(extractedMaskType, false));
} else if (!containsUnknownDims) {
rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
- extractOp, VectorType(newMaskType),
- maskOperands.drop_front(position.size()));
+ extractOp, extractedMaskType,
+ maskOperands.drop_front(extractOpPos.size()));
} else {
return failure();
}
More information about the Mlir-commits
mailing list