[Mlir-commits] [mlir] [mlir][vector] Add a pattern to fuse extract(constant_mask) (PR #81057)
Hsiangkai Wang
llvmlistbot at llvm.org
Thu Feb 8 02:39:44 PST 2024
https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/81057
>From 379487661bdc7967a1e12f1a62b9d870ac45b0a3 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Fri, 15 Dec 2023 11:33:01 +0000
Subject: [PATCH] [mlir][vector] Add patterns to fuse extract(constant_mask)
Add patterns to rewrite ExtractOp(ConstantMaskOp).
When the result of ExtractOp is a subvector of the input, it can be rewritten
as a ConstantMaskOp whose mask dim sizes are the trailing mask dim sizes of
the input.
ExtractOp(ConstantMaskOp) -> ConstantMaskOp
When the result of ExtractOp is a scalar, it can be folded directly to an i1
constant.
ExtractOp(ConstantMaskOp) -> ConstantOp
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 74 +++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 71 +++++++++++++++++++++
2 files changed, 144 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 452354413e8833..56eca3ff8f7d08 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2039,6 +2039,77 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
}
};
+// Pattern to rewrite ExtractOp(ConstantMaskOp).
+//
+// A constant_mask sets to 1 exactly the elements whose indices lie inside the
+// hyper-rectangle [0, maskDimSizes[0]) x ... x [0, maskDimSizes[n-1]).
+// Extracting at a position `pos` made of k indices therefore yields an
+// all-false value if pos[i] >= maskDimSizes[i] for some i < k, and otherwise
+// the mask restricted to the trailing (n - k) dimensions.
+//
+// When the result of ExtractOp is a subvector, it is rewritten as a
+// ConstantMaskOp over the trailing dimensions (all zeros when all-false):
+//
+//   ExtractOp(ConstantMaskOp) -> ConstantMaskOp
+//
+// When the result of ExtractOp is a scalar, it is rewritten as an i1
+// constant:
+//
+//   ExtractOp(ConstantMaskOp) -> ConstantOp
+//
+// A dynamic position is only accepted in a non-scalable dimension that the
+// mask covers entirely, because only then is every in-bounds index known to
+// lie inside the region. Whenever the result cannot be determined statically,
+// the pattern fails.
+class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto constantMaskOp =
+        extractOp.getVector().getDefiningOp<vector::ConstantMaskOp>();
+    if (!constantMaskOp)
+      return failure();
+
+    ArrayRef<Attribute> maskDimSizes =
+        constantMaskOp.getMaskDimSizes().getValue();
+    auto maskType = cast<VectorType>(constantMaskOp.getResult().getType());
+    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
+
+    // Classify the extracted position against the mask region. `allFalse`
+    // takes precedence over `unknown`: a single index outside the region
+    // determines the result regardless of the other indices.
+    bool allFalse = false;
+    bool unknown = false;
+    for (size_t dim = 0, end = extractOpPos.size(); dim < end; ++dim) {
+      int64_t pos = extractOpPos[dim];
+      int64_t maskDimSize = cast<IntegerAttr>(maskDimSizes[dim]).getInt();
+      if (maskDimSize == 0) {
+        // An empty dimension makes the whole mask all-false.
+        allFalse = true;
+      } else if (maskType.isScalableDim(dim)) {
+        // The runtime extent of a scalable dimension is not known here.
+        unknown = true;
+      } else if (ShapedType::isDynamic(pos)) {
+        // A dynamic index is guaranteed to be inside the region only if the
+        // mask covers the whole dimension.
+        if (maskDimSize != maskType.getDimSize(dim))
+          unknown = true;
+      } else if (pos >= maskDimSize) {
+        allFalse = true;
+      }
+    }
+    if (unknown && !allFalse)
+      return failure();
+
+    Type resultTy = extractOp.getResult().getType();
+    if (auto resultVectorTy = dyn_cast<VectorType>(resultTy)) {
+      // A rank-0 result has no dim size that could encode all-false.
+      if (allFalse && resultVectorTy.getRank() == 0)
+        return failure();
+
+      // Keep the trailing mask dim sizes. An all-false constant_mask must use
+      // zero for every dimension.
+      SmallVector<int64_t> newMaskDimSizes;
+      for (Attribute size : maskDimSizes.take_back(resultVectorTy.getRank()))
+        newMaskDimSizes.push_back(
+            allFalse ? 0 : cast<IntegerAttr>(size).getInt());
+
+      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+          extractOp, resultVectorTy,
+          vector::getVectorSubscriptAttr(rewriter, newMaskDimSizes));
+      return success();
+    }
+
+    if (isa<IntegerType>(resultTy)) {
+      // Scalar extraction: the element is '1' iff it lies inside the region.
+      Type boolType = rewriter.getI1Type();
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          extractOp, boolType, IntegerAttr::get(boolType, allFalse ? 0 : 1));
+      return success();
+    }
+
+    return failure();
+  }
+};
+
// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2065,7 +2136,8 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
- ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ ExtractOpFromBroadcast, ExtractOpFromCreateMask,
+ ExtractOpFromConstantMask>(context);
results.add(foldExtractFromShapeCastToShapeCast);
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e6f045e12e5197..153894defff90f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2567,3 +2567,74 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
tensor<4x4x4xf32>, vector<1x100x4x5xf32>
return %r : vector<1x100x4x5xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_true_from_constant_mask() -> i1 {
+func.func @extract_true_from_constant_mask() -> i1 {
+// CHECK: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: return %[[TRUE]] : i1
+ %mask = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
+ %extract = vector.extract %mask[1, 1, 2] : i1 from vector<4x4x4xi1>
+ return %extract : i1
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_false_from_constant_mask() -> i1 {
+func.func @extract_false_from_constant_mask() -> i1 {
+// CHECK: %[[FALSE:.*]] = arith.constant false
+// CHECK-NEXT: return %[[FALSE]] : i1
+ %mask = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
+ %extract = vector.extract %mask[1, 2, 2] : i1 from vector<4x4x4xi1>
+ return %extract : i1
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_from_create_mask() -> i1 {
+func.func @extract_from_create_mask() -> i1 {
+// CHECK: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: return %[[TRUE]] : i1
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %mask = vector.create_mask %c2, %c2, %c3 : vector<4x4x4xi1>
+ %extract = vector.extract %mask[1, 1, 2] : i1 from vector<4x4x4xi1>
+ return %extract : i1
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_subvector_from_constant_mask() ->
+// CHECK-SAME: vector<6xi1> {
+func.func @extract_subvector_from_constant_mask() -> vector<6xi1> {
+// CHECK: %[[S0:.*]] = vector.constant_mask [4] : vector<6xi1>
+// CHECK-NEXT: return %[[S0]] : vector<6xi1>
+ %mask = vector.constant_mask [2, 3, 4] : vector<4x5x6xi1>
+ %extract = vector.extract %mask[1, 2] : vector<6xi1> from vector<4x5x6xi1>
+ return %extract : vector<6xi1>
+}
+
+// -----
+
+// Position 3 lies outside the masked range [0, 2) of dimension 0, so the
+// extracted subvector is all-false.
+// CHECK-LABEL: func.func @extract_subvector_outside_constant_mask() ->
+// CHECK-SAME: vector<6xi1> {
+func.func @extract_subvector_outside_constant_mask() -> vector<6xi1> {
+// CHECK: %[[S0:.*]] = vector.constant_mask [0] : vector<6xi1>
+// CHECK-NEXT: return %[[S0]] : vector<6xi1>
+ %mask = vector.constant_mask [2, 3, 4] : vector<4x5x6xi1>
+ %extract = vector.extract %mask[3, 2] : vector<6xi1> from vector<4x5x6xi1>
+ return %extract : vector<6xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_scalar_with_dynamic_positions(
+// CHECK-SAME: %[[INDEX:.*]]: index) -> i1 {
+func.func @extract_scalar_with_dynamic_positions(%index: index) -> i1 {
+// CHECK: %[[S0:.*]] = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
+// CHECK-NEXT: %[[S1:.*]] = vector.extract %[[S0]][1, 1, %[[INDEX]]] : i1 from vector<4x4x4xi1>
+// CHECK-NEXT: return %[[S1]] : i1
+ %mask = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
+ %extract = vector.extract %mask[1, 1, %index] : i1 from vector<4x4x4xi1>
+ return %extract : i1
+}
+
+// -----
+
+// Dimension 1 is fully covered by the mask (5 of 5), so any in-bounds dynamic
+// index stays inside the masked region and the extract can be folded.
+// CHECK-LABEL: func.func @extract_subvector_with_dynamic_positions
+func.func @extract_subvector_with_dynamic_positions(%index: index) -> vector<6xi1> {
+// CHECK: %[[S0:.*]] = vector.constant_mask [4] : vector<6xi1>
+// CHECK-NEXT: return %[[S0]] : vector<6xi1>
+ %mask = vector.constant_mask [2, 5, 4] : vector<4x5x6xi1>
+ %extract = vector.extract %mask[1, %index] : vector<6xi1> from vector<4x5x6xi1>
+ return %extract : vector<6xi1>
+}
+
+// -----
+
+// Dimension 1 is only partially masked (3 of 5), so a dynamic index may fall
+// inside or outside the region and the extract must not be folded.
+// CHECK-LABEL: func.func @extract_subvector_with_unknown_dynamic_positions(
+// CHECK-SAME: %[[INDEX:.*]]: index) -> vector<6xi1> {
+func.func @extract_subvector_with_unknown_dynamic_positions(%index: index) -> vector<6xi1> {
+// CHECK: %[[S0:.*]] = vector.constant_mask [2, 3, 4] : vector<4x5x6xi1>
+// CHECK-NEXT: %[[S1:.*]] = vector.extract %[[S0]][1, %[[INDEX]]] : vector<6xi1> from vector<4x5x6xi1>
+// CHECK-NEXT: return %[[S1]] : vector<6xi1>
+ %mask = vector.constant_mask [2, 3, 4] : vector<4x5x6xi1>
+ %extract = vector.extract %mask[1, %index] : vector<6xi1> from vector<4x5x6xi1>
+ return %extract : vector<6xi1>
+}
More information about the Mlir-commits
mailing list