[Mlir-commits] [mlir] [mlir][vector] Add a pattern to fuse extract(constant_mask) (PR #81057)

Hsiangkai Wang llvmlistbot at llvm.org
Wed Feb 7 15:47:59 PST 2024


https://github.com/Hsiangkai created https://github.com/llvm/llvm-project/pull/81057

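This pattern rewrites a vector.extract whose source is a vector.constant_mask:
ExtractOp(ConstantMaskOp) -> ConstantMaskOp (when a sub-vector is extracted)
or
ExtractOp(ConstantMaskOp) -> arith.constant i1 (when a single element is extracted)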

>From cafd3223dbbb75bd2fab88a73198ae016d72121d Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Fri, 15 Dec 2023 11:33:01 +0000
Subject: [PATCH] [mlir][vector] Add a pattern to fuse extract(constant_mask)

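This pattern rewrites a vector.extract whose source is a vector.constant_mask:
ExtractOp(ConstantMaskOp) -> ConstantMaskOp (when a sub-vector is extracted)
or
ExtractOp(ConstantMaskOp) -> arith.constant i1 (when a single element is extracted)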
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 57 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 47 ++++++++++++++++++
 2 files changed, 103 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 452354413e8833..697989c4326236 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2039,6 +2039,60 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+// Pattern to rewrite
+// ExtractOp(ConstantMaskOp) -> ConstantMaskOp
+// or
+// ExtractOp(ConstantMaskOp) -> ConstantOp
+//
+// A vector.constant_mask sets exactly the elements whose index is below the
+// mask size in every dimension. The slice selected by the extract position is
+// therefore either entirely inside the mask region (every position index is
+// below its mask size) or entirely unset:
+//   * sub-vector result: a constant_mask keeping the trailing (non-indexed)
+//     mask sizes, or all-zero sizes when the slice is outside the mask;
+//   * scalar result: an i1 arith.constant that is true iff inside the mask.
+// Dynamic position indices cannot be decided statically, so such extracts are
+// left alone.
+class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto constantMaskOp =
+        extractOp.getVector().getDefiningOp<vector::ConstantMaskOp>();
+    if (!constantMaskOp)
+      return failure();
+
+    // Dynamic indices are encoded as kDynamic in the static position; whether
+    // they fall inside the mask region is unknown at compile time.
+    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
+    if (llvm::any_of(extractOpPos, ShapedType::isDynamic))
+      return failure();
+
+    auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+    size_t maskRank = maskDimSizes.size();
+    size_t numIndexed = extractOpPos.size();
+    if (numIndexed > maskRank)
+      return failure();
+    auto maskDimSize = [&](size_t dim) {
+      return cast<IntegerAttr>(maskDimSizes[dim]).getInt();
+    };
+
+    // The selected slice is inside the mask region iff every position index
+    // is below the mask size of its dimension.
+    bool insideMask = true;
+    for (size_t dim = 0; dim < numIndexed; ++dim) {
+      if (extractOpPos[dim] >= maskDimSize(dim)) {
+        insideMask = false;
+        break;
+      }
+    }
+
+    Type resultTy = extractOp.getResult().getType();
+    if (auto resultVectorTy = dyn_cast<VectorType>(resultTy)) {
+      // The result spans the trailing, non-indexed mask dimensions. Bail out
+      // on shapes that do not line up, including 0-D results (constant_mask
+      // encodes those with a single-entry size array).
+      int64_t resultRank = resultVectorTy.getRank();
+      if (resultRank == 0 ||
+          static_cast<int64_t>(maskRank - numIndexed) != resultRank)
+        return failure();
+
+      // Outside the mask region the slice is all-false, which constant_mask
+      // expresses with all-zero sizes.
+      SmallVector<int64_t> newMaskDimSizes;
+      newMaskDimSizes.reserve(resultRank);
+      for (size_t dim = numIndexed; dim < maskRank; ++dim)
+        newMaskDimSizes.push_back(insideMask ? maskDimSize(dim) : 0);
+
+      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+          extractOp, resultVectorTy,
+          vector::getVectorSubscriptAttr(rewriter, newMaskDimSizes));
+      return success();
+    }
+
+    // Scalar extraction: every mask dimension must be indexed.
+    if (!isa<IntegerType>(resultTy) || numIndexed != maskRank)
+      return failure();
+
+    Type boolType = rewriter.getI1Type();
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        extractOp, boolType, IntegerAttr::get(boolType, insideMask ? 1 : 0));
+    return success();
+  }
+};
+
 // Folds extract(shape_cast(..)) into shape_cast when the total element count
 // does not change.
 LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
@@ -2065,7 +2119,8 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
+  // Class-based rewrite patterns (ExtractOpFromConstantMask is the one added
+  // by this patch), followed by the function-based shape_cast fold.
   results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
-              ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+              ExtractOpFromBroadcast, ExtractOpFromCreateMask,
+              ExtractOpFromConstantMask>(context);
   results.add(foldExtractFromShapeCastToShapeCast);
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e6f045e12e5197..3e1b638f8f292a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2567,3 +2567,50 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
+
+// -----
+
+// Every index of [1, 1, 2] is below the matching mask size in [2, 2, 3], so
+// the extracted element is set and should fold to a true constant.
+// CHECK-LABEL: func.func @extract_true_from_constant_mask() -> i1 {
+func.func @extract_true_from_constant_mask() -> i1 {
+// CHECK:      %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: return %[[TRUE]] : i1
+  %mask = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
+  %extract = vector.extract %mask[1, 1, 2] : i1 from vector<4x4x4xi1>
+  return %extract : i1
+}
+
+// -----
+
+// The second index (2) is not below its mask size (2), so the element at
+// [1, 2, 2] lies outside the mask region and should fold to a false constant.
+// CHECK-LABEL: func.func @extract_false_from_constant_mask() -> i1 {
+func.func @extract_false_from_constant_mask() -> i1 {
+// CHECK:      %[[FALSE:.*]] = arith.constant false
+// CHECK-NEXT: return %[[FALSE]] : i1
+  %mask = vector.constant_mask [2, 2, 3] : vector<4x4x4xi1>
+  %extract = vector.extract %mask[1, 2, 2] : i1 from vector<4x4x4xi1>
+  return %extract : i1
+}
+
+// -----
+
+// Same position and mask sizes as the constant_mask "true" case, but the mask
+// comes from vector.create_mask with constant operands; it should also fold
+// to a true constant. NOTE(review): confirm whether this goes through
+// ExtractOpFromConstantMask (after create_mask folds to constant_mask) or
+// through ExtractOpFromCreateMask.
+// CHECK-LABEL: func.func @extract_from_create_mask() -> i1 {
+func.func @extract_from_create_mask() -> i1 {
+// CHECK:      %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: return %[[TRUE]] : i1
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c2, %c2, %c3 : vector<4x4x4xi1>
+  %extract = vector.extract %mask[1, 1, 2] : i1 from vector<4x4x4xi1>
+  return %extract : i1
+}
+
+// -----
+
+// Indices [1, 2] are below mask sizes [2, 3], so the extracted vector<6xi1>
+// keeps the innermost mask size and becomes constant_mask [4].
+// NOTE(review): no case here extracts a sub-vector at an index outside the
+// mask (e.g. [3, 2]), which should produce an all-false mask.
+// CHECK-LABEL: func.func @extract_subvector_from_constant_mask() ->
+// CHECK-SAME:  vector<6xi1> {
+func.func @extract_subvector_from_constant_mask() -> vector<6xi1> {
+// CHECK:      %[[S0:.*]] = vector.constant_mask [4] : vector<6xi1>
+// CHECK-NEXT: return %[[S0]] : vector<6xi1>
+  %mask = vector.constant_mask [2, 3, 4] : vector<4x5x6xi1>
+  %extract = vector.extract %mask[1, 2] : vector<6xi1> from vector<4x5x6xi1>
+  return %extract : vector<6xi1>
+}



More information about the Mlir-commits mailing list