[Mlir-commits] [mlir] [mlir][vector] Add a pattern to fuse extract(constant_mask) (PR #81057)

Jakub Kuderski llvmlistbot at llvm.org
Mon Feb 26 11:01:31 PST 2024


================
@@ -2040,6 +2040,77 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
   }
 };
 
+// Pattern to rewrite ExtractOp(ConstantMaskOp).
+//
+// When the result of ExtractOp is a subvector of the input, we can rewrite it
+// as a ConstantMaskOp built from the trailing (result-rank) mask dim sizes.
+//
+// ExtractOp(ConstantMaskOp) -> ConstantMaskOp
+//
+// When the result of ExtractOp is a scalar, we can get the scalar value
+// directly.
+//
+// ExtractOp(ConstantMaskOp) -> ConstantOp
+class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto constantMaskOp =
+        extractOp.getVector().getDefiningOp<vector::ConstantMaskOp>();
+    if (!constantMaskOp)
+      return failure();
+
+    // All indices must be static.
+    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
+    unsigned dynamicPosCount =
+        llvm::count_if(extractOpPos, ShapedType::isDynamic);
+    // If there is any dynamic position in ExtractOp, we cannot determine the
+    // scalar value.
+    if (dynamicPosCount)
+      return failure();
+
+    ArrayRef<Attribute> maskDimSizes =
+        constantMaskOp.getMaskDimSizes().getValue();
+    Type resultTy = extractOp.getResult().getType();
+    if (resultTy.isa<mlir::VectorType>()) {
+      auto resultVectorTy = resultTy.cast<mlir::VectorType>();
+      int64_t resultRank = resultVectorTy.getRank();
+      int64_t n = maskDimSizes.size();
+      std::vector<int64_t> indices;
+      for (auto i = n - resultRank; i < n; ++i)
+        indices.push_back(cast<IntegerAttr>(maskDimSizes[i]).getInt());
+
+      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+          extractOp, resultVectorTy,
+          vector::getVectorSubscriptAttr(rewriter, indices));
+
+      return success();
+    } else if (resultTy.isa<mlir::IntegerType>()) {
+      // ConstantMaskOp creates and returns a vector mask where elements of the
+      // result vector are set to ‘0’ or ‘1’, based on whether the element
+      // indices are contained within a hyper-rectangular region.
+      // We go through ExtractOp static positions to determine the position is
+      // within the hyper-rectangular region or not.
+      Type boolType = rewriter.getI1Type();
+      IntegerAttr setAttr = IntegerAttr::get(boolType, 1);
+      for (size_t i = 0, end = extractOpPos.size(); i < end; ++i) {
----------------
kuhar wrote:

```suggestion
      for (auto [idx, extractPos] : llvm::enumerate(extractOpPos)) {
```

https://github.com/llvm/llvm-project/pull/81057


More information about the Mlir-commits mailing list